metrics.metrics_wrapper
Attributes
Functions
- _validate_metric_gen_arguments – Validates the input arguments for the metric generator function.
- default_metric_input_generator – Default input generator.
- _validate_wrapper_inputs – Validates the input wrappers.
- wrap_metric – Wraps a metric function/callable to be used in a pipeline class.
- perturb_standard_normal – Simple perturbation function for continuous datasets.
- _flatten_cat_features – Flattens categorical feature argument.
- _reformat_replacements – Reformats replacements arguments with defaults.
- perturb_func_constructor – Simple perturbation function generator compatible with Captum's infidelity and sensitivity methods.
Module Contents
- metrics.metrics_wrapper._validate_metric_gen_arguments(metric: Callable, feature_input: torch.Tensor, y_labels: torch.Tensor | None, target: torch.Tensor | int | None, context: Dict | None, attribute: torch.Tensor, method_instance: Any, model: Any) → None
Validates the input arguments for the metric generator function.
- Parameters:
metric (Callable) – The metric function.
feature_input (torch.Tensor) – Input features.
y_labels (torch.Tensor, optional) – Ground truth labels.
target (torch.Tensor, int, optional) – Target labels.
context (dict, optional) – Contextual information.
attribute (torch.Tensor) – Attribute of interest.
method_instance (Any) – Method instance.
model (Any) – Model instance.
- Raises:
AssertionError – If an input argument does not have the expected type or behaviour.
- metrics.metrics_wrapper.default_metric_input_generator(metric: Callable, feature_input: torch.Tensor, y_labels: torch.Tensor | None, target: torch.Tensor | int | None, context: Dict | None, attribute: torch.Tensor, method_instance: Any, model: Any, **other: Any) → Dict[str, Any]
Default input generator.
The input generator collates information from the model, model output, dataset, attribute, method instance and other sources into a single dictionary, which is unpacked and passed as keyword arguments to the metric class method.
The default key-naming schema follows Captum's `attribute` conventions; create a custom input generator if you are using metrics from other libraries.
This function only supports Captum metrics and torch's MSE loss.
Some arguments passed in by the pipeline class (e.g. y_labels) are not used by the default function. They are included because users who write and pass in their own input generator may need them.
- Parameters:
metric (Callable) – Metric to evaluate attribution score.
feature_input (torch.Tensor) – Input tensor.
y_labels (torch.Tensor, optional) – Ground-truth label tensor.
target (torch.Tensor, int, optional) – Target argument passed on to the Captum attribute function.
context (dict, optional) – Dict containing other relevant data (e.g. ground-truth attribution).
attribute (torch.Tensor) – Attribution to be evaluated.
method_instance (Any) – Method used to obtain attribute.
model (Any) – Model instance.
other (Any) – Other keyword arguments to be passed into the metric function.
- Returns:
A dict with all the arguments required by the Captum attribute function.
- Return type:
dict
- Raises:
TypeError – If metric given is not supported.
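The docstring suggests writing a custom input generator when the metric's keyword names differ from Captum's. A minimal sketch, assuming the ground-truth attribution is stored in context under a hypothetical "ground_truth" key and that the target metric is torch's `F.mse_loss` (whose keyword parameters are `input`, `target` and `reduction`). The function name is illustrative; it only mirrors the positional signature of default_metric_input_generator.

```python
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.nn.functional as F


def ground_truth_mse_input_generator(
    metric: Callable,
    feature_input: torch.Tensor,
    y_labels: Optional[torch.Tensor],
    target: Union[torch.Tensor, int, None],
    context: Optional[Dict],
    attribute: torch.Tensor,
    method_instance: Any,
    model: Any,
    **other: Any,
) -> Dict[str, Any]:
    """Hypothetical generator comparing an attribution with a ground-truth attribution."""
    if metric is not F.mse_loss:
        raise TypeError(f"Unsupported metric: {metric!r}")
    # "ground_truth" is an assumed key name, not one defined by this module.
    if not context or "ground_truth" not in context:
        raise KeyError("context must provide a 'ground_truth' attribution")
    # Keys match the keyword names of F.mse_loss(input, target, ..., reduction=...).
    return {"input": attribute, "target": context["ground_truth"], **other}


attr = torch.tensor([1.0, 2.0])
gt = torch.tensor([1.0, 0.0])
kwargs = ground_truth_mse_input_generator(
    F.mse_loss, torch.zeros(2), None, None, {"ground_truth": gt}, attr, None, None
)
score = F.mse_loss(**kwargs)  # mean of (0^2, 2^2) = 2.0
```

Unused arguments such as y_labels and model are kept in the signature for the same reason the docstring gives: a pipeline may pass them to any generator.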
- metrics.metrics_wrapper._validate_wrapper_inputs(metric_fns: Callable, input_generator_fns: Callable, out_processing: Callable, other_args: Dict) → None
Validates the input wrappers.
- Parameters:
metric_fns (Callable) – The metric function to evaluate.
input_generator_fns (Callable) – A function to generate input for the metric.
out_processing (Callable) – A function to post-process the metric evaluation scores.
other_args (dict) – Other keyword arguments to be passed to the input generator.
- Raises:
AssertionError – If any of the validation assertions fails.
- metrics.metrics_wrapper.wrap_metric(metric_fns: Callable, input_generator_fns: Callable = default_metric_input_generator, out_processing: Callable | None = None, name: str | None = None, pre_fix: str = '', **other_args: Any) → Any
Wraps a metric function/callable to be used in a pipeline class.
Important: for mse_loss, the default out_processing behaviour returns the MSE.
- Parameters:
metric_fns (Callable) – The metric function to evaluate.
input_generator_fns (Callable) – A function to generate input for the metric. Defaults to default_metric_input_generator.
out_processing (Callable, optional) – A function to post-process the metric evaluation scores. Defaults to None.
name (str, optional) – The name of the metric. Defaults to None, in which case the name of the metric function is used.
pre_fix (str) – A prefix to add to the name of the wrapped metric. Defaults to “”.
other_args (Any) – Any other keyword arguments to be passed to the input generator.
- Returns:
A class that wraps the metric function.
- Return type:
type
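To illustrate what wrapping a metric into a class can look like, here is a hypothetical sketch of the pattern, not this module's implementation: a factory builds a class whose `__call__` runs the input generator, calls the metric with the resulting keyword arguments, and applies out_processing when one is given. The names make_metric_class, abs_error and toy_input_generator are made up for illustration, and the real wrapped class may expose a different method or store extra state.

```python
from typing import Any, Callable, Dict, Optional


def make_metric_class(
    metric_fns: Callable,
    input_generator_fns: Callable,
    out_processing: Optional[Callable] = None,
    name: Optional[str] = None,
    pre_fix: str = "",
    **other_args: Any,
) -> type:
    """Hypothetical analogue of wrap_metric: returns a new class wrapping metric_fns."""
    # Fall back to the metric function's own name, then prepend the prefix.
    metric_name = pre_fix + (name if name is not None else metric_fns.__name__)

    def __call__(self, **pipeline_data: Any) -> Any:
        # Collate pipeline data (plus the stored extra kwargs) into metric arguments.
        metric_kwargs: Dict[str, Any] = input_generator_fns(
            metric_fns, **pipeline_data, **other_args
        )
        score = metric_fns(**metric_kwargs)
        # Post-process the raw score only if a processing function was supplied.
        return out_processing(score) if out_processing is not None else score

    return type(metric_name, (), {"__call__": __call__, "name": metric_name})


def abs_error(prediction: float, reference: float) -> float:
    return abs(prediction - reference)


def toy_input_generator(metric: Callable, attribute: float, context: Dict, **other: Any) -> Dict[str, Any]:
    return {"prediction": attribute, "reference": context["gt"]}


EvalAbsError = make_metric_class(abs_error, toy_input_generator, out_processing=round, pre_fix="eval_")
result = EvalAbsError()(attribute=2.6, context={"gt": 1.0})  # round(1.6...) == 2
```

Returning a class rather than a closure lets a pipeline instantiate metrics uniformly and read their names from a single attribute.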
- metrics.metrics_wrapper.perturb_standard_normal(input, sd: float = 0.1) → torch.Tensor
Simple perturbation function for continuous datasets.
Note that because the infidelity decorator is applied with multiply-by-inputs set to True, the function returns a tuple of (perturbations, perturbed inputs) when called.
- Parameters:
input (torch.Tensor) – Input feature tensor that was used to calculate the attribution scores.
sd (float) – Standard deviation of the Gaussian noise added to the continuous features.
- Returns:
Gaussian perturbed input.
- Return type:
(torch.Tensor)
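A minimal sketch of the idea in plain PyTorch, so that it runs without Captum: add N(0, sd²) noise to every entry, then form the (perturbations, perturbed inputs) pair that Captum's infidelity consumes. Both helper names are hypothetical. As we understand Captum's infidelity_perturb_func_decorator, with no baselines it computes perturbations as (input - perturbed_input) / input when its multipy_by_inputs flag is set (using a denominator of 1 where the input is zero) and as input - perturbed_input otherwise; the second helper mimics that under this assumption.

```python
import torch


def gaussian_perturb(inputs: torch.Tensor, sd: float = 0.1) -> torch.Tensor:
    """Return a copy of inputs with i.i.d. N(0, sd^2) noise added to every entry."""
    return inputs + sd * torch.randn_like(inputs)


def as_infidelity_pair(inputs: torch.Tensor, perturbed: torch.Tensor, multiply_by_inputs: bool = True):
    """Hypothetical helper mimicking the (perturbations, perturbed_inputs) tuple of the decorator."""
    perturbations = inputs - perturbed
    if multiply_by_inputs:
        # Divide by the input, using a denominator of 1 wherever the input is exactly zero.
        denom = torch.where(inputs != 0, inputs, torch.ones_like(inputs))
        perturbations = perturbations / denom
    return perturbations, perturbed


torch.manual_seed(0)
x = torch.tensor([[1.0, 2.0, 0.0], [4.0, -1.0, 3.0]])
perturbations, x_perturbed = as_infidelity_pair(x, gaussian_perturb(x, sd=0.1))
```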
- metrics.metrics_wrapper._flatten_cat_features(cat_features: List[int | Tuple[int]]) → List[int]
Flattens categorical feature argument.
- Parameters:
cat_features (list[int | tuple]) – A list of ints or tuples, each giving the index of a categorical feature or the column indices of a one-hot-encoded categorical feature.
- Returns:
Flattened list of categorical features.
- Return type:
list[int]
- Raises:
Exception – If invalid categorical feature input is provided.
AssertionError – If there are duplicate features in the flattened list.
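A pure-Python sketch of the flattening step, assuming ints are kept as-is, tuples (one-hot groups) are expanded in order, anything else is rejected, and duplicates raise an AssertionError as documented. The function name and the use of ValueError for invalid input are illustrative choices; the documented function only promises an Exception.

```python
from typing import List, Tuple, Union


def flatten_cat_features(cat_features: List[Union[int, Tuple[int, ...]]]) -> List[int]:
    """Hypothetical analogue of _flatten_cat_features."""
    flat: List[int] = []
    for feature in cat_features:
        if isinstance(feature, int):
            flat.append(feature)  # a single categorical column
        elif isinstance(feature, tuple) and all(isinstance(i, int) for i in feature):
            flat.extend(feature)  # the columns of one one-hot-encoded feature
        else:
            raise ValueError(f"Invalid categorical feature: {feature!r}")
    # A column may belong to at most one categorical feature.
    assert len(flat) == len(set(flat)), f"Duplicate categorical features in {flat}"
    return flat


print(flatten_cat_features([0, (2, 3, 4), 7]))  # [0, 2, 3, 4, 7]
```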
- metrics.metrics_wrapper._reformat_replacements(replacements: Dict[int | Tuple[int], Any] | torch.Tensor, cat_features: List[int | Tuple[int]]) → Dict[int | Tuple[int], torch.Tensor]
Reformats replacements arguments with defaults.
- Parameters:
cat_features (list[int | tuple]) – A list of ints or tuples, each giving the index of a categorical feature or the column indices of a one-hot-encoded categorical feature.
replacements (dict | torch.Tensor) – Either a dictionary mapping each categorical feature (an int or tuple, as in cat_features) to a list of candidate values, or a torch.Tensor holding the original dataset to sample from.
- Returns:
Dictionary containing the categorical features and their corresponding torch.Tensor replacements.
- Return type:
dict
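A minimal sketch of the reformatting, assuming that when replacements is a tensor (the original dataset) the candidates for each feature are the distinct values of its column, or the distinct rows of its one-hot columns, and that dictionary values are simply converted to tensors. The function name is hypothetical, and the default-filling mentioned in the docstring is not modelled here.

```python
from typing import Any, Dict, List, Tuple, Union

import torch

Feature = Union[int, Tuple[int, ...]]


def reformat_replacements(
    replacements: Union[Dict[Feature, Any], torch.Tensor], cat_features: List[Feature]
) -> Dict[Feature, torch.Tensor]:
    """Hypothetical analogue of _reformat_replacements (defaults not handled)."""
    out: Dict[Feature, torch.Tensor] = {}
    for feature in cat_features:
        if isinstance(replacements, torch.Tensor):
            cols = list(feature) if isinstance(feature, tuple) else [feature]
            # Distinct rows of the selected columns in the original dataset.
            values = torch.unique(replacements[:, cols], dim=0)
            out[feature] = values.squeeze(1) if isinstance(feature, int) else values
        else:
            # Dictionary input: convert each list of candidate values to a tensor.
            out[feature] = torch.as_tensor(replacements[feature])
    return out


data = torch.tensor([[0.0, 1.0, 0.0, 5.0], [0.0, 0.0, 1.0, 5.0], [2.0, 1.0, 0.0, 3.0]])
from_data = reformat_replacements(data, [0, (1, 2)])
from_dict = reformat_replacements({0: [0, 2], (1, 2): [[1, 0], [0, 1]]}, [0, (1, 2)])
```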
- metrics.metrics_wrapper.perturb_func_constructor(noise_scale: float, cat_resample_prob: float, cat_features: List[int | Tuple[int]], replacements: Dict = {}, run_infidelity_decorator: bool = True, multipy_by_inputs: bool = False) → Callable
Simple perturbation function generator compatible with Captum's infidelity and sensitivity methods.
- Parameters:
noise_scale (float) – Standard deviation of the Gaussian noise added to the continuous features.
cat_resample_prob (float) – Probability of resampling a categorical feature.
cat_features (list[int | tuple]) – A list of ints or tuples, each giving the index of a categorical feature or the column indices of a one-hot-encoded categorical feature.
replacements (dict | torch.Tensor) – Either a dictionary mapping each categorical feature (an int or tuple, as in cat_features) to a list of candidate values, or a torch.Tensor holding the original dataset to sample from. Defaults to {}.
run_infidelity_decorator (bool) – Set to True to make the returned function compatible with infidelity; set to False for sensitivity. Defaults to True.
multipy_by_inputs (bool) – Flag passed to the infidelity decorator (spelled multipy_by_inputs, as in the signature). Defaults to False.
- Returns:
A perturbation function compatible with Captum.
- Return type:
perturb_func (function)
Examples
Given an input tensor of shape (N, M), where N is the batch size and M is the number of features: if input features a, b and c are independent categorical features, then cat_features = [a, b, c]; if a, b and c are the one-hot encoding of a single categorical feature, then cat_features = [(a, b, c)].
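One way such a generator can combine the two kinds of noise is sketched below: Gaussian noise with standard deviation noise_scale on the continuous columns, and, for each categorical feature or one-hot group, replacement by a uniformly drawn candidate with probability cat_resample_prob. This is a hypothetical sketch, not the module's implementation. It assumes replacements is already a dict of tensors (as _reformat_replacements returns), returns only the perturbed inputs, and omits the Captum decorator step controlled by run_infidelity_decorator; make_perturb_func is an illustrative name.

```python
from typing import Callable, Dict, List, Tuple, Union

import torch

Feature = Union[int, Tuple[int, ...]]


def make_perturb_func(
    noise_scale: float,
    cat_resample_prob: float,
    cat_features: List[Feature],
    replacements: Dict[Feature, torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical analogue of perturb_func_constructor (no Captum decorator applied)."""
    groups = [list(f) if isinstance(f, tuple) else [f] for f in cat_features]

    def perturb(inputs: torch.Tensor) -> torch.Tensor:
        n, m = inputs.shape
        perturbed = inputs.clone()
        # Gaussian noise on every column that is not part of a categorical feature.
        continuous = torch.ones(m, dtype=torch.bool)
        for cols in groups:
            continuous[cols] = False
        perturbed[:, continuous] += noise_scale * torch.randn(n, int(continuous.sum()))
        for feature, cols in zip(cat_features, groups):
            # Candidate values as rows of shape (K, len(cols)).
            candidates = replacements[feature].reshape(-1, len(cols)).to(inputs.dtype)
            new_values = candidates[torch.randint(len(candidates), (n,))]
            resample = torch.rand(n) < cat_resample_prob  # rows that get resampled
            perturbed[:, cols] = torch.where(resample[:, None], new_values, perturbed[:, cols])
        return perturbed

    return perturb


torch.manual_seed(0)
# Columns: 0 is continuous, 1 is categorical, (2, 3) is a one-hot pair.
x = torch.tensor([[0.5, 1.0, 1.0, 0.0], [1.5, 2.0, 0.0, 1.0]])
reps = {1: torch.tensor([0.0, 1.0, 2.0]), (2, 3): torch.eye(2)}
x_perturbed = make_perturb_func(0.1, 0.5, [1, (2, 3)], reps)(x)
```

Resampling whole one-hot rows, rather than individual columns, keeps each perturbed group a valid one-hot vector.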
- metrics.metrics_wrapper.data