Quickstart - Library Walkthrough

Main Features

  • Preset neural network models that each have a defined type of behaviour, such as conflicting features.

  • A dataset for each preset model that exhibits the corresponding type of behaviour.

  • Pipelines for evaluating the performance of explanation methods, across multiple datasets and neural networks.

  • The pipelines support custom explanation methods, evaluation metrics, models, and datasets.

  • Trainable models are also supported along with a lightweight trainer helper.

Datasets And Neural Network Models

Each dataset has a corresponding preset neural network model, i.e. a model whose weights have been explicitly defined rather than learned through training.

List of behaviour types:

  • Continuous Features

  • Synthetic Cancellation

  • Pertinent Negative

  • Interacting Features

  • Shattered Gradients

  • Uncertainty Model

  • Boolean Formulas

There is also the Dynamic Neural Network model that is trainable.

# import all models and datasets
from xaiunits.model import *
from xaiunits.datagenerator import *

from numpy import set_printoptions
set_printoptions(linewidth=10000)

# suppose we want to experiment on a dataset with continuous features
cont_dataset = WeightedFeaturesDataset(n_features=10, n_samples=500)

# Examining one datapoint from the dataset
x, y_true, context = cont_dataset[0]
print("x:", x)
print("y_true:", y_true)
print("context:", context) # context is a dict that (for most datasets) contains "ground_truth_attribute"
x: tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
         0.3223, -1.2633])
y_true: tensor(-0.8077)
context: {'ground_truth_attribute': tensor([ 1.0742, -0.7191, -0.0815,  0.1232,  0.2753,  0.5861, -0.2325, -1.2830,
        -0.1876, -0.3630])}
# Each dataset has an associated model type
cont_model = ContinuousFeaturesNN(n_features=cont_dataset.n_features, weights = cont_dataset.weights)
y_pred = cont_model(x)
print("y_pred:", y_pred)

# It is possible to get this model directly from the data generator, so that the model is always consistent with the data
cont_model = cont_dataset.generate_model()
y_pred = cont_model(x)
print("y_pred:", y_pred)
y_pred: tensor([-0.8077], grad_fn=<SqueezeBackward4>)
y_pred: tensor([-0.8077], grad_fn=<SqueezeBackward4>)
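
The printed datapoint hints at a useful property of this dataset: the ground-truth attributions appear to add up to the label. The check below uses plain Python with the values copied from the output above. Reading this as an additive decomposition of the model output is an assumption drawn from these numbers, not a documented guarantee of the library.

```python
# Values copied from the printed datapoint above (rounded to 4 decimals).
y_true = -0.8077
ground_truth_attribute = [
    1.0742, -0.7191, -0.0815, 0.1232, 0.2753,
    0.5861, -0.2325, -1.2830, -0.1876, -0.3630,
]

# If the attributions decompose the output additively, they should sum to
# y_true up to the rounding error of the printed digits.
total = sum(ground_truth_attribute)
gap = abs(total - y_true)
print(f"sum of attributions: {total:.4f}, y_true: {y_true:.4f}, gap: {gap:.4f}")
```

A gap on the order of the print precision is consistent with each feature contributing its own share of the prediction, which is what makes the ground truth usable as a reference for attribution methods.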

We can also train and evaluate our own model. To do so, we can use AutoTrainer, a lightweight helper that builds upon the lightning package.

# prepare data for training
from torch.utils.data import DataLoader

train_data, val_data, test_data = cont_dataset.split([0.7, 0.15, 0.15])

train_loader = DataLoader([data[:2] for data in train_data])
val_loader = DataLoader([data[:2] for data in val_data])
test_loader = DataLoader([data[:2] for data in test_data])
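
Each dataset item is a 3-tuple of (x, y_true, context), as shown when the first datapoint was unpacked, so `data[:2]` keeps only the features and the label. The sketch below illustrates this with a made-up stand-in tuple. That the trainer expects plain (x, y) pairs is an assumption inferred from the slicing, not something the walkthrough states.

```python
# A stand-in for one dataset item: (features, label, context).
# The values are made up for illustration.
sample = (
    [0.5, -1.2, 0.3],                              # x
    -0.4,                                          # y_true
    {"ground_truth_attribute": [0.1, -0.6, 0.1]},  # context
)

# Slicing a tuple with [:2] keeps only the first two entries,
# dropping the context dict.
x_and_y = sample[:2]
print(len(sample), len(x_and_y))
```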
# define model architecture
n_features = 10
model_arch = [{"type": "Linear", "in_features": n_features, "out_features": 32},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 32, "out_features": 8},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 8, "out_features": 8},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 8, "out_features": 1},
]
trained_model = DynamicNN(model_arch)


import torch  # explicit import for the loss function and optimiser below

loss = torch.nn.functional.mse_loss
optim = torch.optim.Adam
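
When writing an architecture as a list of layer dictionaries, it is easy to mismatch the width of adjacent Linear layers. The sketch below is a plain-Python sanity check over the same list used above. `check_linear_chain` is a hypothetical helper written for this illustration and is not part of xaiunits.

```python
# Architecture spec copied from the walkthrough above.
n_features = 10
model_arch = [
    {"type": "Linear", "in_features": n_features, "out_features": 32},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 32, "out_features": 8},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 8, "out_features": 8},
    {"type": "ReLU"},
    {"type": "Linear", "in_features": 8, "out_features": 1},
]


def check_linear_chain(arch):
    """Return the final output width, raising if adjacent Linear layers disagree."""
    width = None
    for i, layer in enumerate(arch):
        if layer["type"] != "Linear":
            continue  # activations such as ReLU leave the width unchanged
        if width is not None and layer["in_features"] != width:
            raise ValueError(
                f"layer {i}: expected in_features={width}, got {layer['in_features']}"
            )
        width = layer["out_features"]
    return width


final_width = check_linear_chain(model_arch)
print("final output width:", final_width)
```

A final width of 1 matches the single regression output of the dataset used here.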
# define the trainer
from xaiunits.trainer.trainer import AutoTrainer
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

lightning_linear_model = AutoTrainer(trained_model, loss, optim)
trainer = L.Trainer(
    min_epochs=5,
    max_epochs=50,
    callbacks=[EarlyStopping(monitor="val_loss", mode="min", verbose=True)],
    enable_progress_bar=False,  # the Lightning progress bar displays poorly in Jupyter
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
# train model
trainer.fit(
    model=lightning_linear_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# test results after training
trainer.test(lightning_linear_model, dataloaders=test_loader)

Methods

We want to apply attribution methods to our model. Any of the existing methods in Captum are supported as well as custom attribution methods.

from captum.attr import InputXGradient, IntegratedGradients, Lime
from xaiunits.methods.methods_wrapper import wrap_method

# List the attribution methods we want to use
xmethods = [
        InputXGradient,
        IntegratedGradients,
        Lime,
    ]

# If we want to pass non-default parameters to the attribution method, we use wrap_method to pre-load these parameters
wrapped_method = wrap_method(IntegratedGradients, other_inputs={"n_steps": 25})
xmethods.append(wrapped_method)
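
Conceptually, pre-loading parameters resembles partial application. The sketch below uses `functools.partial` with a hypothetical `toy_attribution` function. It is an analogy for the idea only and says nothing about how wrap_method is implemented.

```python
from functools import partial


def toy_attribution(inputs, n_steps=50, baseline=0.0):
    """Hypothetical stand-in for an attribution method; reports its settings."""
    return {"n_inputs": len(inputs), "n_steps": n_steps, "baseline": baseline}


# Pre-load a non-default parameter once...
preloaded = partial(toy_attribution, n_steps=25)

# ...then call it later with only the data, as a pipeline would.
print(toy_attribution([1.0, 2.0]))
print(preloaded([1.0, 2.0]))
```

The benefit of this pattern is that downstream code can call every method with the same arguments, regardless of how each one was configured.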

Metrics

Our goal is to evaluate how well each attribution method performs on different types of model and data. The metric (e.g. infidelity) provides a performance score.

We support metrics from Captum and custom metrics.

from captum.metrics import sensitivity_max, infidelity
from xaiunits.metrics import wrap_metric

# We have wrap_metric to pre-load parameters into the metric, similar to wrap_method
metrics = [
    wrap_metric(sensitivity_max),
    wrap_metric(infidelity, perturb_func=cont_dataset.perturb_function(), normalize=True),
]
# Each dataset comes with a standard perturb function; a custom perturb function can be used instead

# Another common metric is the RMSE between the attributions and the ground truth context (when ground truth is available)
# Note the usage of out_processing in the wrap_metric below, which allows us to pre-load arbitrary processing
# In this case we want to convert the torch MSE to RMSE
rmse_metric = wrap_metric(
    torch.nn.functional.mse_loss, 
    out_processing=lambda x: torch.sqrt(torch.sum(x, dim=1)),
    )
metrics.append(rmse_metric)
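
For reference, the textbook RMSE between an attribution vector and its ground truth can be written in a few lines of plain Python. This sketch is independent of how wrap_metric and out_processing reduce tensors internally, and the numbers are illustrative only.

```python
import math


def rmse(attributions, ground_truth):
    """Root-mean-square error between two equal-length sequences."""
    if len(attributions) != len(ground_truth):
        raise ValueError("sequences must have the same length")
    squared_errors = [(a - g) ** 2 for a, g in zip(attributions, ground_truth)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Illustrative numbers only: two features match exactly, one is off by 2.
score = rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
print(round(score, 4))
```

A score of zero means the attributions reproduce the ground truth exactly; larger values indicate a larger average deviation per feature.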

Pipelines

We provide two kinds of pipeline:

  • Pipeline

  • ExperimentPipeline

Pipeline (Standard)

  • Allows running experiments on any number of models and datasets (assuming that they are all compatible).

  • Runs experiments over multiple seeds for the different explanation methods (useful when there is non-determinism).

  • Any number of explanation methods and evaluation metrics are supported.

  • Aggregates the results into a DataFrame.

from xaiunits.pipeline import Pipeline

models = [trained_model, cont_model]
datasets = [cont_dataset]

# Instantiate the pipeline with a list of models, datasets, xmethods, metrics, and seeds.
# All combinations will be evaluated
pipeline = Pipeline(models, datasets, xmethods, metrics, method_seeds=[10], name="test")
results = pipeline.run() # apply the explanation methods and evaluate them
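
To get a feel for how many runs "all combinations" implies, the sketch below enumerates them with `itertools.product`. Strings stand in for the real objects, and the iteration order is an assumption for illustration, not a description of the Pipeline internals.

```python
from itertools import product

# Names mirror the pipeline call above; strings stand in for the real objects.
models = ["DynamicNN", "ContinuousFeaturesNN"]
datasets = ["WeightedFeaturesDataset"]
xmethods = ["InputXGradient", "IntegratedGradients", "Lime", "wrapper_IntegratedGradients"]
method_seeds = [10]

# One run per (model, dataset, method, seed) combination.
runs = list(product(models, datasets, xmethods, method_seeds))
print("number of runs:", len(runs))
for run in runs[:2]:
    print(run)
```

Each run is then scored by every metric, so the cost grows multiplicatively with each list passed to the pipeline.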


# Accessing Pipeline Results
# You can directly access a dataframe of all the results
df = results.data
# However we generally suggest using the print_stats method, which has a lot of options for unpivoting the table
df_by_method = results.print_stats(stat_funcs=['mean'], index=["trial_group_name", 'method'], column_index=['model'])
                                                           
                                                             mean             \
model                                        ContinuousFeaturesNN              
metric                                                  attr_time infidelity   
trial_group_name method                                                        
test             InputXGradient                             0.009  2.811e-04   
                 IntegratedGradients                        0.325  2.811e-04   
                 Lime                                       5.617  2.867e-04   
                 wrapper_IntegratedGradients                0.170  2.811e-04   

                                                                         \
model                                                                     
metric                                         mse_loss sensitivity_max   
trial_group_name method                                                   
test             InputXGradient               0.000e+00           0.017   
                 IntegratedGradients          4.289e-08           0.017   
                 Lime                         2.306e-01           0.200   
                 wrapper_IntegratedGradients  4.289e-08           0.017   

                                                                            \
model                                        DynamicNN                       
metric                                       attr_time infidelity mse_loss   
trial_group_name method                                                      
test             InputXGradient                  0.019  2.868e-04    0.281   
                 IntegratedGradients             0.402  2.858e-04    0.251   
                 Lime                            7.647  2.924e-04    0.245   
                 wrapper_IntegratedGradients     0.167  2.858e-04    0.251   

                                                              
model                                                         
metric                                       sensitivity_max  
trial_group_name method                                       
test             InputXGradient                        0.046  
                 IntegratedGradients                   0.020  
                 Lime                                  0.213  
                 wrapper_IntegratedGradients           0.020  

Target Class

In our example the model is a regression model, so the target class is not important. But if we have a classification model then the target class is important for most explanation methods.

The pipeline has a parameter default_target which should take one of four possible values:

  • "y_labels", which uses the true y labels as the target class

  • "predicted_class", which uses the model prediction as the target, i.e. y=model(feature_inputs)

  • an integer, for a single target class which will be used for all datapoints

  • a tuple or tensor matching the batch size
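
The four options above can be sketched as a small resolver in plain Python. `resolve_target` is a hypothetical helper written for this illustration, and taking the index of the highest score as the predicted class is an assumption. The library's own handling may differ.

```python
def resolve_target(default_target, y_labels, class_scores):
    """Hypothetical helper: turn a default_target setting into per-sample targets."""
    batch_size = len(y_labels)
    if default_target == "y_labels":
        return list(y_labels)
    if default_target == "predicted_class":
        # Assumption: the predicted class is the index of the highest score.
        return [max(range(len(row)), key=row.__getitem__) for row in class_scores]
    if isinstance(default_target, int):
        # A single class is broadcast to every datapoint in the batch.
        return [default_target] * batch_size
    targets = list(default_target)
    if len(targets) != batch_size:
        raise ValueError("target sequence must match the batch size")
    return targets


# Made-up batch of two datapoints with three classes.
y_labels = [0, 2]
class_scores = [[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]
for setting in ("y_labels", "predicted_class", 1, (2, 0)):
    print(setting, "->", resolve_target(setting, y_labels, class_scores))
```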

Accessing Examples from the pipeline

The pipeline parameter n_examples helps explain why scores are high or low. It stores the n best- and n worst-performing examples for each method/model/metric combination for further inspection.

# To demonstrate this, we set up a new pipeline using the n_examples parameter
pipeline = Pipeline(models, datasets, xmethods, metrics, method_seeds=[10], n_examples=1)
results = pipeline.run()
# The key for the examples is first "max" or "min" for the high / low scoring examples
# Then a tuple of (method, model, metric) to select the type of example, which returns a list of length n_examples
all_max_examples = results.examples["max"]
# print(all_max_examples)
example_list = all_max_examples[("IntegratedGradients", "ContinuousFeaturesNN", "mse_loss")]
max_example = example_list[-1]
print(max_example)

# the Example includes the original feature_inputs, y_labels, and context
print("x:", max_example.feature_inputs)
print("y_true:", max_example.y_labels)
print("context:", max_example.context)

# and the Example includes the attributions and the metric score
print("attributions:", max_example.attribute)
print("metric_score:", max_example.score)
                                                           
Example(score=tensor(1.3339e-07, dtype=torch.float64), attribute=tensor([-2.0317,  0.3137, -0.2884, -0.5674, -0.5510, -0.5692, -0.5603, -1.1613,
         0.6530, -0.1734], dtype=torch.float64), feature_inputs=tensor([ 2.1293,  0.5027, -0.8871,  1.9974, -1.6984, -0.6720, -0.7617, -1.9145,
        -1.1218, -0.6036]), y_labels=tensor(-4.9359), target=None, context={'ground_truth_attribute': tensor([-2.0317,  0.3137, -0.2884, -0.5674, -0.5510, -0.5692, -0.5603, -1.1613,
         0.6530, -0.1734])}, example_type='max')
x: tensor([ 2.1293,  0.5027, -0.8871,  1.9974, -1.6984, -0.6720, -0.7617, -1.9145,
        -1.1218, -0.6036])
y_true: tensor(-4.9359)
context: {'ground_truth_attribute': tensor([-2.0317,  0.3137, -0.2884, -0.5674, -0.5510, -0.5692, -0.5603, -1.1613,
         0.6530, -0.1734])}
attributions: tensor([-2.0317,  0.3137, -0.2884, -0.5674, -0.5510, -0.5692, -0.5603, -1.1613,
         0.6530, -0.1734], dtype=torch.float64)
metric_score: tensor(1.3339e-07, dtype=torch.float64)
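
The idea of keeping the n highest- and n lowest-scoring datapoints can be sketched with the standard library's `heapq`. The scores below are made up, and this is an illustration of the selection step rather than the pipeline's actual implementation.

```python
import heapq

# Hypothetical per-datapoint scores for one (method, model, metric) triple.
scores = {"a": 0.02, "b": 0.75, "c": 0.10, "d": 0.41}
n_examples = 1

# Keep the n highest- and n lowest-scoring datapoints for inspection.
max_examples = heapq.nlargest(n_examples, scores.items(), key=lambda kv: kv[1])
min_examples = heapq.nsmallest(n_examples, scores.items(), key=lambda kv: kv[1])
print("max:", max_examples)
print("min:", min_examples)
```

Whether a "max" example is the best or the worst case depends on the metric: for an error such as mse_loss, the highest score marks the datapoint where the attribution deviates most.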

Experiment Pipeline

  • The ExperimentPipeline provides a systematic way to iterate over datasets with repeatable data seeds.

  • Allows trials to be run with different seeds for generating the data.

  • Supports experiments that are wrapped with our Experiment class
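
Repeatable data seeds mean that the same seed always regenerates the same dataset. The sketch below shows the principle with the standard library's `random` module. `make_data` is a hypothetical generator, not the one the datasets use.

```python
import random


def make_data(seed, n_samples=3, n_features=2):
    """Hypothetical generator: standard-normal features from a seeded RNG."""
    rng = random.Random(seed)  # private RNG, independent of global state
    return [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]


# The same seed reproduces the same data; a different seed gives a new draw.
same = make_data(3) == make_data(3)
different = make_data(3) != make_data(4)
print("same seed reproduces data:", same)
print("different seed changes data:", different)
```

Running each experiment over several data seeds separates variation caused by the sampled data from variation caused by the explanation method itself.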

# suppose we want to run experiments on these models
pert_neg_model1 = PertinentNegativesDataset().generate_model()
pert_neg_model2 = PertinentNegativesDataset(weight_range=(-10.0, 10.0)).generate_model()
pert_neg_model3 = PertinentNegativesDataset(pn_weight_factor=200).generate_model()

shatter_grad_model1 = ShatteredGradientsDataset().generate_model()
shatter_grad_model2 = ShatteredGradientsDataset(discontinuity_ratios=[1.0, 2.0, -7.0, 9.5, -2.0]).generate_model()
shatter_grad_model3 = ShatteredGradientsDataset(discontinuity_ratios=[60.45, -32.2, 23.1, 5.5, 12.0], bias=2.0).generate_model()
from captum.attr import DeepLift, ShapleyValueSampling, KernelShap, LRP

xmethods2 = [
    DeepLift,
    ShapleyValueSampling,
    KernelShap,
    LRP    
]
from xaiunits.pipeline import Experiment

# we need to first wrap them as an Experiment instance
# it is possible to pass just the dataset class; the data will then be instantiated over the different seeds during the experiment
pert_neg_experiment = Experiment(PertinentNegativesDataset, 
                                 [pert_neg_model1, pert_neg_model2, pert_neg_model3],
                                 xmethods2,
                                 None, # Using default metric for evaluation 
                                 seeds=[3, 4],
                                 method_seeds=[0, 11],
                      )

# Alternatively, an instantiated dataset can still be passed in
shattered_grad_experiment = Experiment(ShatteredGradientsDataset(discontinuity_ratios=[1.0, 2.0, -7.0, 9.5, -2.0]), 
                                       [shatter_grad_model1, shatter_grad_model2, shatter_grad_model3],
                                        xmethods2, 
                                        None,
                                        seeds=[3, 4],
                                        method_seeds=[0, 11],
                            )

# we can also pass in no model and let the dataset generate the corresponding model
interacion_feat_experiment = Experiment(InteractingFeatureDataset, 
                                        None, 
                                        xmethods2, 
                                        None,
                                        seeds=[3, 4],
                                        method_seeds=[0, 11],
                                        )

# how the data is generated can also be customised through data_params
conflicting_experiment = Experiment(ConflictingDataset, 
                                    None,
                                    xmethods2,
                                    None, 
                                    seeds=[3, 4],
                                    method_seeds=[0, 11],
                                    data_params={"n_samples": 100, "n_features": 3, "cancellation_likelihood": 0.8},
                      )

experiments = [
    pert_neg_experiment,
    shattered_grad_experiment,
    interacion_feat_experiment,
    conflicting_experiment,
]
from xaiunits.pipeline import ExperimentPipeline

# instantiate the pipeline, run the attribution methods, then process and print the results
exp_pipeline = ExperimentPipeline(experiments)
exp_pipeline.run()
df = exp_pipeline.results.print_stats()
                                                                          mean  \
metric                                                               attr_time   
data                      model                 method                           
ConflictingDataset        ConflictingFeaturesNN DeepLift                 0.002   
                                                KernelShap               0.396   
                                                LRP                      0.002   
                                                ShapleyValueSampling     0.039   
InteractingFeatureDataset InteractingFeaturesNN DeepLift                 0.001   
                                                KernelShap               0.195   
                                                LRP                      0.001   
                                                ShapleyValueSampling     0.008   
PertinentNegativesDataset PertinentNN           DeepLift                 0.005   
                                                KernelShap               0.140   
                                                LRP                      0.003   
                                                ShapleyValueSampling     0.038   
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift                 0.004   
                                                KernelShap               1.138   
                                                LRP                      0.003   
                                                ShapleyValueSampling     0.060   

                                                                                 \
metric                                                                 mse_loss   
data                      model                 method                            
ConflictingDataset        ConflictingFeaturesNN DeepLift              1.424e-01   
                                                KernelShap            4.536e-02   
                                                LRP                   1.424e-01   
                                                ShapleyValueSampling  3.076e-02   
InteractingFeatureDataset InteractingFeaturesNN DeepLift              0.000e+00   
                                                KernelShap            9.243e-02   
                                                LRP                   9.108e-16   
                                                ShapleyValueSampling  6.905e-02   
PertinentNegativesDataset PertinentNN           DeepLift              5.281e+01   
                                                KernelShap            5.281e+01   
                                                LRP                   6.219e+00   
                                                ShapleyValueSampling  5.281e+01   
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift                    NaN   
                                                KernelShap                  NaN   
                                                LRP                         NaN   
                                                ShapleyValueSampling        NaN   

                                                                                      \
metric                                                               sensitivity_max   
data                      model                 method                                 
ConflictingDataset        ConflictingFeaturesNN DeepLift                         NaN   
                                                KernelShap                       NaN   
                                                LRP                              NaN   
                                                ShapleyValueSampling             NaN   
InteractingFeatureDataset InteractingFeaturesNN DeepLift                         NaN   
                                                KernelShap                       NaN   
                                                LRP                              NaN   
                                                ShapleyValueSampling             NaN   
PertinentNegativesDataset PertinentNN           DeepLift                         NaN   
                                                KernelShap                       NaN   
                                                LRP                              NaN   
                                                ShapleyValueSampling             NaN   
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift                     114.367   
                                                KernelShap                     3.910   
                                                LRP                          114.367   
                                                ShapleyValueSampling           1.184   

                                                                            std  \
metric                                                                attr_time   
data                      model                 method                            
ConflictingDataset        ConflictingFeaturesNN DeepLift              1.672e-04   
                                                KernelShap            2.000e-02   
                                                LRP                   1.241e-04   
                                                ShapleyValueSampling  6.651e-03   
InteractingFeatureDataset InteractingFeaturesNN DeepLift              1.413e-04   
                                                KernelShap            1.027e-02   
                                                LRP                   1.705e-04   
                                                ShapleyValueSampling  9.660e-05   
PertinentNegativesDataset PertinentNN           DeepLift              1.013e-03   
                                                KernelShap            8.938e-03   
                                                LRP                   1.172e-04   
                                                ShapleyValueSampling  2.957e-03   
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift              2.955e-04   
                                                KernelShap            3.311e-02   
                                                LRP                   2.350e-04   
                                                ShapleyValueSampling  1.232e-03   

                                                                                 \
metric                                                                 mse_loss   
data                      model                 method                            
ConflictingDataset        ConflictingFeaturesNN DeepLift              3.189e-02   
                                                KernelShap            1.072e-02   
                                                LRP                   3.189e-02   
                                                ShapleyValueSampling  8.854e-03   
InteractingFeatureDataset InteractingFeaturesNN DeepLift              0.000e+00   
                                                KernelShap            6.608e-02   
                                                LRP                   1.260e-17   
                                                ShapleyValueSampling  6.044e-02   
PertinentNegativesDataset PertinentNN           DeepLift              9.908e+00   
                                                KernelShap            9.908e+00   
                                                LRP                   2.821e-01   
                                                ShapleyValueSampling  9.908e+00   
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift                    NaN   
                                                KernelShap                  NaN   
                                                LRP                         NaN   
                                                ShapleyValueSampling        NaN   

                                                                                      
metric                                                               sensitivity_max  
data                      model                 method                                
ConflictingDataset        ConflictingFeaturesNN DeepLift                         NaN  
                                                KernelShap                       NaN  
                                                LRP                              NaN  
                                                ShapleyValueSampling             NaN  
InteractingFeatureDataset InteractingFeaturesNN DeepLift                         NaN  
                                                KernelShap                       NaN  
                                                LRP                              NaN  
                                                ShapleyValueSampling             NaN  
PertinentNegativesDataset PertinentNN           DeepLift                         NaN  
                                                KernelShap                       NaN  
                                                LRP                              NaN  
                                                ShapleyValueSampling             NaN  
ShatteredGradientsDataset ShatteredGradientsNN  DeepLift                      12.123  
                                                KernelShap                     2.134  
                                                LRP                           12.123  
                                                ShapleyValueSampling           0.148