Customization Tutorial

This notebook is a tutorial on how to use a custom dataset, method, or metric with our package. We go through an example of each, as well as other useful customizable parameters.

1 Custom Datasets

Users can pass their own custom datasets into the xaiunits pipeline.

Here we show a simple example of how to do so; later we show more complex variations.

1.1 Simple Example

In this example we will

  1. Use scikit-learn's fetch_california_housing function to download the California housing data (the scikit-learn version omits categorical data) and wrap it in a custom PyTorch Dataset

  2. Train a model using our AutoTrainer (this step is optional)

  3. Select the XAI methods and metrics

  4. Instantiate the Pipeline class and run it

  5. Print the results

from sklearn.datasets import fetch_california_housing
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
from xaiunits.model import DynamicNN
from xaiunits.trainer.trainer import AutoTrainer
from xaiunits.metrics import perturb_standard_normal, wrap_metric
from xaiunits.pipeline import Pipeline

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import lightning as L

from captum.attr import DeepLift, InputXGradient, IntegratedGradients
from captum.metrics import sensitivity_max, infidelity
import pandas as pd
#1. Download and Create California Dataset

class CaliDataset(Dataset):
    def __init__(self):
        sk_cali = fetch_california_housing(data_home="data/cali")
        self.feature_input = torch.tensor(sk_cali.data, dtype=float)
        self.labels = torch.tensor(sk_cali.target, dtype=float)

    def __len__(self):
        return self.feature_input.shape[0]

    def __getitem__(self, idx):
        return self.feature_input[idx], self.labels[idx]

data = CaliDataset()
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    data, [0.7, 0.2, 0.1]
)
train_data = DataLoader(train_dataset, batch_size=64)
val_data = DataLoader(val_dataset, batch_size=64)
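If the California housing download is not available (for example when working offline), a synthetic dataset with the same interface can stand in for it. The sketch below assumes, as CaliDataset suggests, that the pipeline only needs __len__ and a __getitem__ that returns a (features, label) pair; SyntheticRegressionDataset is a hypothetical name, not part of xaiunits. It also passes integer lengths and a seeded generator to random_split, which makes the split reproducible across runs.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class SyntheticRegressionDataset(Dataset):
    """Hypothetical offline stand-in for CaliDataset with the same (features, label) items."""

    def __init__(self, n_samples=1000, n_features=8, seed=0):
        generator = torch.Generator().manual_seed(seed)
        # float64 to match CaliDataset, which builds its tensors with dtype=float
        self.feature_input = torch.randn(
            n_samples, n_features, generator=generator, dtype=torch.float64
        )
        true_weights = torch.randn(n_features, generator=generator, dtype=torch.float64)
        self.labels = self.feature_input @ true_weights  # noiseless linear target

    def __len__(self):
        return self.feature_input.shape[0]

    def __getitem__(self, idx):
        return self.feature_input[idx], self.labels[idx]


synthetic = SyntheticRegressionDataset()
# Integer lengths plus a seeded generator give the same split on every run
train_part, val_part, test_part = random_split(
    synthetic, [700, 200, 100], generator=torch.Generator().manual_seed(0)
)
features, label = next(iter(DataLoader(train_part, batch_size=64)))
print(features.shape, label.shape)
```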
# 2. Train Model

hdim = 100
linear_model_config = [
    {
        "type": "Linear",
        "in_features": data[:][0].shape[1],
        "out_features": hdim,
        "dtype": float,
    },
    {"type": "ReLU"},
    {"type": "Linear", "in_features": hdim, "out_features": hdim, "dtype": float},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": hdim, "out_features": hdim, "dtype": float},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": hdim, "out_features": 1, "dtype": float},
]
model = DynamicNN(linear_model_config)

try:
    # Reuse previously trained weights if they were cached
    with open("data/model.pkl", "rb") as file:
        state_dict = pickle.load(file)
    model.load_state_dict(state_dict)
except FileNotFoundError:
    # No cached weights: define the auto trainer and train from scratch
    loss = torch.nn.functional.mse_loss
    optim = torch.optim.Adam
    lightning_linear_model = AutoTrainer(model, loss, optim)
    trainer = L.Trainer(
        min_epochs=20,
        max_epochs=50,
        callbacks=[EarlyStopping(monitor="val_loss", mode="min", verbose=True)],
    )
    # Wrap the test split in a DataLoader so it is evaluated in batches, like train/val
    test_data = DataLoader(test_dataset, batch_size=64)

    # test results before training
    trainer.test(lightning_linear_model, dataloaders=test_data)

    # train model
    trainer.fit(
        model=lightning_linear_model,
        train_dataloaders=train_data,
        val_dataloaders=val_data,
    )
    # test results after training
    trainer.test(lightning_linear_model, dataloaders=test_data)

    # cache the trained weights for the next run
    with open("data/model.pkl", "wb") as file:
        pickle.dump(lightning_linear_model.model.state_dict(), file)

    model = lightning_linear_model.model
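The layer list passed to DynamicNN reads like a recipe: each dict names a layer type plus its constructor arguments. We have not inspected DynamicNN's internals, so the function below is a hypothetical illustration (build_from_config is our own name) of how such a config could map onto torch.nn layers. It is handy for sanity-checking a config, for example that consecutive in_features/out_features match, before training. Note that dtype=float makes PyTorch use torch.float64, matching the tensors built in CaliDataset.

```python
import torch
from torch import nn


def build_from_config(config):
    """Hypothetical stand-in for DynamicNN: turn a list of layer dicts into nn.Sequential.

    Each dict names a torch.nn class under "type"; the remaining keys are passed
    to that class's constructor as keyword arguments.
    """
    layers = []
    for layer_spec in config:
        spec = dict(layer_spec)  # copy so the caller's config is not mutated
        layer_cls = getattr(nn, spec.pop("type"))
        layers.append(layer_cls(**spec))
    return nn.Sequential(*layers)


hdim = 16
config = [
    {"type": "Linear", "in_features": 8, "out_features": hdim, "dtype": float},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": hdim, "out_features": 1, "dtype": float},
]
net = build_from_config(config)
out = net(torch.randn(4, 8, dtype=torch.float64))
print(out.shape)
```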
#3. Select XAI Methods and Metrics
xmethods = [
    InputXGradient,
    IntegratedGradients,
    DeepLift,
]

metrics = [
    wrap_metric(sensitivity_max),
    wrap_metric(infidelity, perturb_func=perturb_standard_normal),
]
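For reference, Captum's infidelity metric measures how well the attributions predict the change in model output when the input is perturbed by a random perturbation $I$; lower is better:

```latex
\mathrm{INFD}(\Phi, f, x) = \mathbb{E}_{I \sim \mu_I}\left[\Big(I^\top \Phi(f, x) - \big(f(x) - f(x - I)\big)\Big)^2\right]
```

Here $\Phi(f, x)$ is the attribution vector for input $x$ and $\mu_I$ is the perturbation distribution supplied through perturb_func. We assume, from its name, that perturb_standard_normal draws $I$ from a standard normal distribution; Captum estimates the expectation by averaging over sampled perturbations.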
#4. Instantiate Pipeline and Run Pipeline
pipeline = Pipeline(model, test_dataset, xmethods, metrics, method_seeds=[10])
pipeline.run() # apply the explanation methods and evaluate them
#5. Display Results
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

df = pipeline.results.print_stats(["infidelity", "sensitivity_max"], index=["method"], stat_funcs=["mean"])
                          mean                
metric              infidelity sensitivity_max
method                                        
DeepLift             2.423e-08       8.959e-04
InputXGradient       2.557e-09       1.360e-03
IntegratedGradients  2.060e-08       1.467e-03
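The table printed by print_stats has the statistic as the top column level and the metric below it, with rows indexed by the columns passed as index. We have not checked how print_stats is implemented; conceptually it resembles a pandas pivot_table over per-run results. The numbers below are invented solely to show the shape of the aggregation and are not pipeline outputs. Note that pandas' std is the sample standard deviation (ddof=1); whether print_stats uses the same convention is an assumption.

```python
import pandas as pd

# Toy long-format results: one row per (method, metric, run).
# Values are made up purely to illustrate the aggregation.
toy_results = pd.DataFrame(
    {
        "method": ["DeepLift", "DeepLift", "Lime", "Lime"] * 2,
        "metric": ["infidelity"] * 4 + ["sensitivity_max"] * 4,
        "value": [1.0, 3.0, 2.0, 4.0, 0.5, 0.5, 1.0, 2.0],
    }
)

# Statistic on the top column level, metric underneath, one row per method
summary = toy_results.pivot_table(
    index="method", columns="metric", values="value", aggfunc=["mean", "std"]
)
print(summary)
```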

1.2 Extension Example (+ Perturb Tutorial)

In this example we define a new dataset for experiments by creating a subclass of BaseFeaturesDataset (similar to WeightedFeaturesDataset, but with categorical features). See the comments in the code for useful tips.

We then define an instance of our Experiment class and pass it to our ExperimentPipeline class. Note that this example passes in only one Experiment, but users can pass in multiple Experiments to be executed.

The Experiment class makes repeated runs much easier, for example over several dataset seeds and method seeds. Furthermore, because our dataset is a subclass of BaseFeaturesDataset, the pipeline can take advantage of its attributes, such as cat_features, which is used to generate a suitable perturb function automatically.

import torch

from xaiunits.datagenerator import BaseFeaturesDataset
from xaiunits.pipeline import Experiment, ExperimentPipeline, Pipeline
from xaiunits.methods import wrap_method
from xaiunits.metrics import wrap_metric, perturb_func_constructor

from captum.attr import DeepLift, Lime
from captum.metrics import infidelity, sensitivity_max

import pandas as pd
# Create a new subclass of our BaseFeaturesDataset, so that it is compatible with the ExperimentPipeline class

class ConCatFeatureDataset(BaseFeaturesDataset):
    def __init__(self, n_features=6, n_samples=100, seed=0, **other):
        assert n_features > 3
        # pass the seed through so that different seeds generate different datasets
        super().__init__(n_features=n_features, n_samples=n_samples, seed=seed, **other)

        # make the last 3 features categorical (1 or 0)
        self.samples[:, [-3, -2, -1]] = (self.samples[:, [-3, -2, -1]] > 0.0).float()
        self.weights = torch.rand(n_features)
        self.weighted_samples = self.samples * self.weights
        self.labels = self.weighted_samples.sum(dim=1)
        self.features = "samples"
        self.ground_truth_attribute = "weighted_samples"
        self.subset_data = ["samples", "weighted_samples"]
        self.subset_attribute = ["weights", "cat_features"]
        # Indices of the categorical features. Our package provides a class method to
        # generate a perturb function, and this attribute is needed to determine
        # which features are categorical.
        self.cat_features = [n_features - 3, n_features - 2, n_features - 1]
    
    def generate_model(self):
        """
        Generates a neural network model using the defined features and weights.

        Returns:
            ContinuousFeaturesNN: A neural network model tailored to the dataset's features and weights.
        """
        from xaiunits.model.continuous import ContinuousFeaturesNN

        return ContinuousFeaturesNN(self.n_features, self.weights)
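The core of ConCatFeatureDataset can be checked in isolation with plain torch. The snippet below uses standard-normal samples as a hypothetical stand-in for whatever BaseFeaturesDataset actually generates; it shows that the last three columns become binary and that each label is exactly the sum of the per-feature contributions, which is why weighted_samples can serve as the ground-truth attribution.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 5, 6

# Stand-in continuous samples (assumption: the real base class may draw them differently)
samples = torch.randn(n_samples, n_features)
# Binarise the last three columns into {0, 1}, as ConCatFeatureDataset does
samples[:, [-3, -2, -1]] = (samples[:, [-3, -2, -1]] > 0.0).float()
cat_features = [n_features - 3, n_features - 2, n_features - 1]

weights = torch.rand(n_features)
weighted_samples = samples * weights  # ground-truth attribution per feature
labels = weighted_samples.sum(dim=1)  # label = sum of the attributions

print(samples)
print(labels)
```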
# Create an Experiment

xmethods = [
    Lime,
    DeepLift,
]

metrics = [
    wrap_metric(sensitivity_max),
    # Alternative way to specify a metric that only works with the Experiment class:
    # the pipeline automatically adds (or overrides) the perturb function based on the dataset
    {"metric_fns": infidelity},
]


pert_neg_experiment = Experiment(
    ConCatFeatureDataset,
    None,  # no model passed here (ConCatFeatureDataset provides generate_model())
    xmethods,
    metrics,
    seeds=[3, 4],  # seeds used to generate the dataset
    method_seeds=[0, 11],  # seeds for each run of the XAI method
)

# Create the ExperimentPipeline
exp_pipe = ExperimentPipeline(pert_neg_experiment)
exp_pipe.run() # apply the explanation methods and evaluate them
# Display Results
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

df = exp_pipe.results.print_stats(["infidelity", "sensitivity_max"], 
                                  index=["method"])
               mean                        std                
metric   infidelity sensitivity_max infidelity sensitivity_max
method                                                        
DeepLift      0.124           0.027      0.005       5.379e-05
Lime          0.126           0.143      0.009       4.198e-03
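The std columns above come from repeating the evaluation. As a rough mental model (an assumption we have not verified against the ExperimentPipeline source), each XAI method is run once per combination of dataset seed and method seed, so the Experiment above implies the runs enumerated below.

```python
from itertools import product

dataset_seeds = [3, 4]  # the Experiment's `seeds` argument
method_seeds = [0, 11]  # the Experiment's `method_seeds` argument
xmethod_names = ["Lime", "DeepLift"]

# Hypothetical enumeration of runs: one per (dataset seed, method seed, method)
runs = [
    {"dataset_seed": d, "method_seed": m, "method": name}
    for d, m, name in product(dataset_seeds, method_seeds, xmethod_names)
]
for run in runs:
    print(run)
print(len(runs), "runs in total")
```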

Now we return to the standard Pipeline class to showcase how to specify a perturbation function with perturb_func_constructor.

dataset = ConCatFeatureDataset()
model = dataset.generate_model()
xmethods = [
    Lime,
    DeepLift,
]

# Simplest way to use the perturb constructor
perturb_fns_1 = perturb_func_constructor(noise_scale=0.2, cat_resample_prob=0.2, cat_features=[3, 4, 5])

# This generates the same perturb function as above, but illustrates how users
# may specify the integer values used to replace each categorical feature.
# If categorical feature i can take values from 0 to 9, then replacements = {i: [0, 1, ..., 9]};
# the default is [0, 1]. Replacement values for categorical features are sampled uniformly.
perturb_fns_2 = perturb_func_constructor(noise_scale=0.2, cat_resample_prob=0.2, cat_features=[3, 4, 5],
                                         replacements={3: [0, 1], 4: [0, 1], 5: [0, 1]})

# To sample replacement values from the same distribution as the data, users may also
# pass in the dataset as a tensor for replacements
perturb_fns_3 = perturb_func_constructor(noise_scale=0.2, cat_resample_prob=0.2, cat_features=[3, 4, 5],
                                         replacements=dataset[:][0])


metrics = [
    # wrap_metric(sensitivity_max),
    wrap_metric(infidelity, perturb_func=perturb_fns_1, name="infidelity_1"),
    wrap_metric(infidelity, perturb_func=perturb_fns_2, name="infidelity_2"),
    wrap_metric(infidelity, perturb_func=perturb_fns_3, name="infidelity_3"),
]


# Create the Pipeline
pipeline = Pipeline(model, dataset, xmethods, metrics, method_seeds=[0])
pipeline.run() # apply the explanation methods and evaluate them
# Display Results
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

df = pipeline.results.print_stats(["infidelity_1" , "infidelity_2" , "infidelity_3"], 
                                    index=["data", "method"],
                                    stat_funcs=["mean"])
                                      mean                          
metric                        infidelity_1 infidelity_2 infidelity_3
data                 method                                         
ConCatFeatureDataset DeepLift        0.166        0.142        0.145
                     Lime            0.151        0.149        0.174
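To make the parameters of perturb_func_constructor more concrete, here is a hypothetical sketch of what a categorical-aware perturbation function could look like. It is our own illustration (make_mixed_perturb_func is not an xaiunits function), assuming the behavior described in the comments above: Gaussian noise of scale noise_scale on continuous features, and, with probability cat_resample_prob, uniform resampling of each categorical feature from its allowed values (default [0, 1]). It returns (perturbations, perturbed_inputs) in the Captum-style form where perturbed_inputs = inputs - perturbations, and it only supports the dict form of replacements, not the data-tensor form used for perturb_fns_3.

```python
import torch


def make_mixed_perturb_func(noise_scale, cat_resample_prob, cat_features, replacements=None):
    """Hypothetical sketch of a categorical-aware perturbation function."""
    replacements = {} if replacements is None else replacements

    def perturb_func(inputs):
        perturbed = inputs.clone()
        # Continuous features: add Gaussian noise scaled by noise_scale
        cont_mask = torch.ones(inputs.shape[1], dtype=torch.bool)
        cont_mask[cat_features] = False
        perturbed[:, cont_mask] += noise_scale * torch.randn_like(inputs[:, cont_mask])
        # Categorical features: resample uniformly from the allowed values
        for feature in cat_features:
            values = torch.tensor(replacements.get(feature, [0, 1]), dtype=inputs.dtype)
            resample = torch.rand(inputs.shape[0]) < cat_resample_prob
            picks = values[torch.randint(len(values), (inputs.shape[0],))]
            perturbed[:, feature] = torch.where(resample, picks, inputs[:, feature])
        # Captum-style convention: perturbed_inputs = inputs - perturbations
        return inputs - perturbed, perturbed

    return perturb_func


torch.manual_seed(0)
# Toy batch: 3 continuous features followed by 3 binary categorical features
x = torch.cat([torch.randn(8, 3), torch.randint(0, 2, (8, 3)).float()], dim=1)
perturb = make_mixed_perturb_func(noise_scale=0.2, cat_resample_prob=0.2, cat_features=[3, 4, 5])
perturbations, x_perturbed = perturb(x)
print(x_perturbed[:, 3:])  # categorical columns stay in {0, 1}
```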

2 Custom Methods

In this section we create a custom dummy attribution method and show how to integrate it into the pipeline.

As a reminder, users can pass Captum XAI methods in directly: our package initializes each method with model as the sole argument and calls its attribute function with inputs as the sole argument, using default values for all other arguments.

Here, model refers to the neural network the XAI methods are run on, and inputs are the input tensors of that network. For ease of discussion, we call model and inputs the primary arguments. All other arguments of a method's initialization and attribute function are referred to as secondary arguments; they can be static or created at runtime.

For users to pass in a custom XAI method, they must first ensure that it adheres to the following:

  1. The XAI method must be a class.

  2. The XAI method must have an initialization (__init__) and an attribute function.

  3. The primary argument must be the first argument of each: model for the initialization and inputs for the attribute function.

  4. Secondary arguments of the initialization and attribute function must either have default values or be specified via input_fns_gen or other_inputs (see below for details).
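A class satisfying these four rules can be very small. The toy method below (ScaledInputAttribution, a hypothetical name, and not a meaningful XAI technique) takes model as the first __init__ argument and inputs as the first attribute argument, and its only secondary argument, scale, has a default value, so it presumably needs neither input_fns_gen nor other_inputs. It also stores the model as forward_func, as the tutorial's DummyAttributionMethod does.

```python
import torch


class ScaledInputAttribution:
    """Hypothetical custom XAI method: attribution is simply inputs * scale."""

    def __init__(self, model):  # rule 3: primary argument `model` comes first
        self.forward_func = model  # mirrors DummyAttributionMethod in this tutorial

    def attribute(self, inputs, scale=1.0):  # rules 3 and 4: `inputs` first, secondary has a default
        return inputs * scale


model = torch.nn.Linear(3, 1)
method = ScaledInputAttribution(model)  # initialization with the primary argument only
x = torch.ones(2, 3)
print(method.attribute(x))  # uses the default scale=1.0
print(method.attribute(x, scale=0.5))
```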

2.1 Simple Example

Here we create a simple XAI method, DummyAttributionMethod, which returns DeepLift attributions when flag is set to True and random noise when flag is set to False. Because noise is a secondary argument of our custom method with no default value, we also need to create a custom input_fns_gen.

In this example we treat the noise argument as non-static: we want the noise to differ across the batch when attribution scores are calculated. As noise has no default value, it must be specified via either input_fns_gen or other_inputs. We treat flag as a static secondary argument.

We use input_fns_gen to pass in a function that generates the non-static secondary arguments, and other_inputs to pass in the static secondary arguments.
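The division of labour between input_fns_gen and other_inputs can be sketched in plain Python. The helper below, call_attribute, is hypothetical and is not the xaiunits wrapper; it only illustrates the idea that generated arguments are computed from the current batch while static arguments are fixed values. Letting other_inputs take precedence when both supply the same key is our own assumption. toy_input_gen uses the same signature as dummy_input_gen in this tutorial.

```python
def call_attribute(method_cls, model, inputs, input_fns_gen=None, other_inputs=None,
                   y_labels=None, target=None, context=None):
    """Hypothetical illustration of how secondary arguments could be supplied."""
    method = method_cls(model)  # primary argument for initialization
    kwargs = {}
    if input_fns_gen is not None:
        # Non-static secondary arguments, regenerated for every batch
        kwargs.update(input_fns_gen(inputs, y_labels, target, context, model))
    if other_inputs:
        # Static secondary arguments (assumed here to take precedence)
        kwargs.update(other_inputs)
    return method.attribute(inputs, **kwargs)  # primary argument for attribute


class EchoMethod:
    """Toy method that reports the secondary arguments it received."""

    def __init__(self, model):
        self.forward_func = model

    def attribute(self, inputs, noise, flag=False):
        return {"inputs": inputs, "noise": noise, "flag": flag}


def toy_input_gen(feature_inputs, y_labels, target, context, model):
    return {"noise": [0.0 for _ in feature_inputs]}


print(call_attribute(EchoMethod, model=None, inputs=[1.0, 2.0],
                     input_fns_gen=toy_input_gen, other_inputs={"flag": True}))
```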

import torch
from xaiunits.datagenerator import WeightedFeaturesDataset
from xaiunits.methods import wrap_method
from xaiunits.metrics import wrap_metric, perturb_standard_normal
from xaiunits.pipeline import Pipeline

from captum.attr import DeepLift
from captum.metrics import sensitivity_max, infidelity
import pandas as pd
# Defining Dummy XAI method and input_fns_gen

class DummyAttributionMethod:
    def __init__(self, model):
        self.actual_attribution = DeepLift(model)
        self.forward_func = model  # For now this is needed (Captum attribution methods also expose forward_func)

    def attribute(self, inputs, noise, flag=False):
        if not flag:
            # In Captum, inputs are often tuples. This is especially the case when calling infidelity or sensitivity.
            # Here we provide some sample code to handle tuple inputs that is compatible with infidelity
            if isinstance(inputs, tuple):
                output = []
                for x in inputs:
                    if x.shape[0] == noise.shape[0]:  # normal forward pass
                        output.append(noise)
                    else:  # perturbed forward pass: repeat each noise row to match the expanded batch
                        output.append(torch.repeat_interleave(noise, x.shape[0] // noise.shape[0], dim=0))
                return tuple(output)
            else:
                return noise
        else:
            return self.actual_attribution.attribute(inputs)

def dummy_input_gen(feature_inputs, y_labels, target, context, model):
    return {
        "noise": torch.rand_like(feature_inputs)
    }
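The tuple branch of DummyAttributionMethod uses torch.repeat_interleave to stretch per-sample noise to a larger perturbed batch. The snippet shows the layout this produces. It assumes, as the code's comments suggest, that the perturbed batch repeats each original sample consecutively (x.shape[0] // noise.shape[0] times); if a metric instead tiled the whole batch, noise.repeat(n, 1) would be the matching call.

```python
import torch

batch, n_features, n_perturb_samples = 2, 3, 4
noise = torch.rand(batch, n_features)

# Each original row is repeated n_perturb_samples times in a row,
# so the result has batch * n_perturb_samples rows.
expanded_noise = torch.repeat_interleave(noise, n_perturb_samples, dim=0)

print(expanded_noise.shape)
# Rows 0-3 are copies of noise[0], rows 4-7 are copies of noise[1]
```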
# Common arguments

dataset = WeightedFeaturesDataset()
model = dataset.generate_model()
metrics = [   
    wrap_metric(infidelity, perturb_func=perturb_standard_normal),
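    # Post-process into a per-sample root error: sqrt of the squared errors summed over features
    # (use torch.mean instead of torch.sum for a strict root-mean-square)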
    wrap_metric(
        torch.nn.functional.mse_loss, 
        out_processing=lambda x: torch.sqrt(torch.sum(x, dim=1)),
    )
]
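The out_processing lambda above reduces over dim=1, which suggests it receives element-wise squared errors; that is our assumption about what wrap_metric passes in. Under that assumption, taking the square root of the sum gives the per-sample Euclidean (L2) norm of the error, while taking the square root of the mean gives the root-mean-square; the two differ by a factor of sqrt(n_features). The toy tensors below are illustrative only.

```python
import torch
import torch.nn.functional as F

attributions = torch.tensor([[1.0, 2.0], [0.0, 3.0]])
ground_truth = torch.zeros(2, 2)

# Element-wise squared errors, one row per sample
squared_errors = F.mse_loss(attributions, ground_truth, reduction="none")

root_sum = torch.sqrt(torch.sum(squared_errors, dim=1))  # processing used above (L2 norm)
root_mean = torch.sqrt(torch.mean(squared_errors, dim=1))  # strict root-mean-square
print(root_sum)
print(root_mean)
```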
# In this example we use other_inputs to override the default behavior (flag=False) of our DummyAttributionMethod class
xmethods_true = [
    wrap_method(DummyAttributionMethod, dummy_input_gen, other_inputs={"flag": True}),
    DeepLift,
]

pipeline_true = Pipeline(model, dataset, xmethods_true, metrics, method_seeds=[10])
pipeline_true.run() # apply the explanation methods and evaluate them

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
df = pipeline_true.results.print_stats([ "mse_loss", "infidelity"], stat_funcs=["mean"])
                                                   
                                                                                  mean          
metric                                                                      infidelity  mse_loss
data                    model                method                                         
WeightedFeaturesDataset ContinuousFeaturesNN DeepLift                        1.223e-15       0.0
                                             wrapper_DummyAttributionMethod  1.644e-15       0.0

# Here we remove the other_inputs override, so flag keeps its default (False) and the attribution is random noise.
# We expect the MSE and infidelity metrics to be worse than DeepLift's
xmethods_false = [
    wrap_method(DummyAttributionMethod, dummy_input_gen),
    DeepLift,
]

pipeline_false = Pipeline(model, dataset, xmethods_false, metrics, method_seeds=[10])
pipeline_false.run() # apply the explanation methods and evaluate them

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

df = pipeline_false.results.print_stats(["mse_loss", "infidelity"], stat_funcs=["mean"])
                                                   
                                                                                  mean          
metric                                                                      infidelity  mse_loss
data                    model                method                                         
WeightedFeaturesDataset ContinuousFeaturesNN DeepLift                        1.223e-15     0.000
                                             wrapper_DummyAttributionMethod  3.180e-02     0.895