metrics
=======

.. py:module:: metrics


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/metrics/metrics_wrapper/index


Attributes
----------

.. autoapisummary::

   metrics.data


Functions
---------

.. autoapisummary::

   metrics._validate_metric_gen_arguments
   metrics.default_metric_input_generator
   metrics._validate_wrapper_inputs
   metrics.wrap_metric
   metrics.perturb_standard_normal
   metrics._flatten_cat_features
   metrics._reformat_replacements
   metrics.perturb_func_constructor


Package Contents
----------------

.. py:function:: _validate_metric_gen_arguments(metric: Callable, feature_input: torch.Tensor, y_labels: Optional[torch.Tensor], target: Optional[Union[torch.Tensor, int]], context: Optional[Dict], attribute: torch.Tensor, method_instance: Any, model: Any) -> None

   Validates the input arguments for the metric generator function.

   :param metric: The metric function.
   :type metric: Callable
   :param feature_input: Input features.
   :type feature_input: torch.Tensor
   :param y_labels: Ground truth labels.
   :type y_labels: torch.Tensor, optional
   :param target: Target labels.
   :type target: torch.Tensor, int, optional
   :param context: Contextual information.
   :type context: dict, optional
   :param attribute: Attribute of interest.
   :type attribute: torch.Tensor
   :param method_instance: Method instance.
   :type method_instance: Any
   :param model: Model instance.
   :type model: Any

   :raises AssertionError: If the input arguments do not have the expected types or behaviour.


.. py:function:: default_metric_input_generator(metric: Callable, feature_input: torch.Tensor, y_labels: Optional[torch.Tensor], target: Optional[Union[torch.Tensor, int]], context: Optional[Dict], attribute: torch.Tensor, method_instance: Any, model: Any, **other: Any) -> Dict[str, Any]

   Default input generator.

   The input generator collates information from the model, model output, dataset, attribute, method instance and other arguments into a single dictionary, which is then unpacked and passed as arguments to the metric class method.
   The default key-naming scheme follows the argument names of Captum's functions; create a custom input generator if you are using metric functions from other libraries. This function only supports Captum metrics and torch's MSE loss. Some arguments passed in by the pipeline class (e.g. ``y_labels``) are not used by the default function; they are included because users who write and pass in their own input generator may need them.

   :param metric: Metric used to evaluate the attribution score.
   :type metric: Callable
   :param feature_input: Input tensor.
   :type feature_input: torch.Tensor
   :param y_labels: True label tensor.
   :type y_labels: torch.Tensor, optional
   :param target: Target argument passed on to Captum's attribute function.
   :type target: torch.Tensor, int, optional
   :param context: Dict containing other relevant data (e.g. ground truth attribution).
   :type context: dict, optional
   :param attribute: Attribution to be evaluated.
   :type attribute: torch.Tensor
   :param method_instance: Method used to obtain the attribution.
   :type method_instance: Any
   :param model: Model instance.
   :type model: Any
   :param other: Other keyword arguments to be passed to the metric function.
   :type other: Any

   :returns: A dict containing all arguments required by the metric function.
   :rtype: dict

   :raises TypeError: If the given metric is not supported.


.. py:function:: _validate_wrapper_inputs(metric_fns: Callable, input_generator_fns: Callable, out_processing: Callable, other_args: Dict) -> None

   Validates the inputs to the metric wrapper.

   :param metric_fns: The metric function to evaluate.
   :type metric_fns: Callable
   :param input_generator_fns: A function that generates the inputs for the metric.
   :type input_generator_fns: Callable
   :param out_processing: A function that post-processes the metric evaluation scores.
   :type out_processing: Callable
   :param other_args: Other keyword arguments to be passed to the input generator.
   :type other_args: dict

   :raises AssertionError: If any of the validation assertions fails.


.. py:function:: wrap_metric(metric_fns: Callable, input_generator_fns: Callable = default_metric_input_generator, out_processing: Optional[Callable] = None, name: Optional[str] = None, pre_fix: str = '', **other_args: Any) -> Any

   Wraps a metric function/callable so that it can be used in a pipeline class.

   Important: for ``mse_loss``, the default ``out_processing`` behaviour returns the MSE.

   :param metric_fns: The metric function to evaluate.
   :type metric_fns: Callable
   :param input_generator_fns: A function that generates the inputs for the metric. Defaults to ``default_metric_input_generator``.
   :type input_generator_fns: Callable
   :param out_processing: A function that post-processes the metric evaluation scores. Defaults to None.
   :type out_processing: Callable, optional
   :param name: The name of the metric. If None (the default), the name of the metric function is used.
   :type name: str, optional
   :param pre_fix: A prefix to add to the name of the wrapped metric. Defaults to "".
   :type pre_fix: str
   :param other_args: Any other keyword arguments to be passed to the input generator.
   :type other_args: Any

   :returns: A class that wraps the metric function.
   :rtype: type


.. py:function:: perturb_standard_normal(input, sd: float = 0.1) -> torch.Tensor

   Simple perturbation function for continuous datasets.

   Note that, because the infidelity decorator is applied with ``multiply_by_inputs`` set to True, the function returns a tuple of (perturbation, perturbed inputs) when called.

   :param input: Input feature tensor that was used to calculate the attribution score.
   :type input: torch.Tensor
   :param sd: Standard deviation of the Gaussian noise added to the continuous features.
   :type sd: float

   :returns: Gaussian-perturbed input.
   :rtype: (torch.Tensor)


.. py:function:: _flatten_cat_features(cat_features: List[Union[int, Tuple[int]]]) -> List[int]

   Flattens the categorical feature argument.

   :param cat_features: A list of ints or tuples identifying the categorical features: an int denotes a single categorical feature and a tuple denotes the group of columns forming a one-hot encoded feature.
   :type cat_features: list[int | tuple]

   :returns: Flattened list of categorical features.
   :rtype: list[int]

   :raises Exception: If an invalid categorical feature input is provided.
   :raises AssertionError: If there are duplicate features in the flattened list.


.. py:function:: _reformat_replacements(replacements: Union[Dict[Union[int, Tuple[int]], Any], torch.Tensor], cat_features: List[Union[int, Tuple[int]]]) -> Dict[Union[int, Tuple[int]], torch.Tensor]

   Reformats the ``replacements`` argument, filling in defaults.

   :param cat_features: A list of ints or tuples identifying the categorical features: an int denotes a single categorical feature and a tuple denotes the group of columns forming a one-hot encoded feature.
   :type cat_features: list[int | tuple]
   :param replacements: Either a dictionary mapping each categorical feature (an int or tuple, as in ``cat_features``) to a list of values to sample from, or a torch.Tensor representing the original dataset to be sampled from.
   :type replacements: dict | torch.Tensor

   :returns: Dictionary mapping the categorical features to their corresponding torch.Tensor replacements.
   :rtype: dict


.. py:function:: perturb_func_constructor(noise_scale: float, cat_resample_prob: float, cat_features: List[Union[int, Tuple[int]]], replacements: Dict = {}, run_infidelity_decorator: bool = True, multipy_by_inputs: bool = False) -> Callable

   Simple perturbation function generator compatible with Captum's infidelity and sensitivity methods.

   :param noise_scale: Standard deviation of the Gaussian noise added to the continuous features.
   :type noise_scale: float
   :param cat_resample_prob: Probability of resampling a categorical feature.
   :type cat_resample_prob: float
   :param cat_features: A list of ints or tuples identifying the categorical features: an int denotes a single categorical feature and a tuple denotes the group of columns forming a one-hot encoded feature.
   :type cat_features: list[int | tuple]
   :param replacements: Either a dictionary mapping each categorical feature (an int or tuple, as in ``cat_features``) to a list of values to sample from, or a torch.Tensor representing the original dataset to be sampled from. Defaults to {}.
   :type replacements: dict | torch.Tensor
   :param run_infidelity_decorator: Set to True to make the returned function compatible with infidelity, or to False for sensitivity. Defaults to True.
   :type run_infidelity_decorator: bool
   :param multiply_by_inputs: Parameter passed to the infidelity decorator. Defaults to False.
   :type multiply_by_inputs: bool

   :returns: A perturbation function compatible with Captum.
   :rtype: Callable

   .. rubric:: Examples

   Given an expected input tensor of shape (N, M), where N is the batch size and M is the number of features: if input features a, b and c are independent categorical features, then ``cat_features = [a, b, c]``; if a, b and c are the columns of a single one-hot encoded feature, then ``cat_features = [(a, b, c)]``.


.. py:data:: data
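The flattening documented for ``_flatten_cat_features`` can be illustrated with a small pure-Python sketch. ``flatten_cat_features_sketch`` is a hypothetical stand-in written from the docstring, not the library's code; it raises ``ValueError`` (one possible ``Exception``) for invalid entries and, as documented, an ``AssertionError`` for duplicates.

```python
from typing import List, Tuple, Union


def flatten_cat_features_sketch(cat_features: List[Union[int, Tuple[int, ...]]]) -> List[int]:
    """Hypothetical sketch of the documented _flatten_cat_features behaviour."""
    flat: List[int] = []
    for feature in cat_features:
        if isinstance(feature, int):
            flat.append(feature)  # a single categorical column
        elif isinstance(feature, tuple) and all(isinstance(i, int) for i in feature):
            flat.extend(feature)  # all columns of one one-hot encoded feature
        else:
            raise ValueError(f"Invalid categorical feature: {feature!r}")
    # The docstring states that duplicate features trigger an AssertionError.
    assert len(flat) == len(set(flat)), "duplicate features in cat_features"
    return flat


# Columns 0 and 1 are independent categorical features; 2, 3, 4 form one one-hot group.
print(flatten_cat_features_sketch([0, 1, (2, 3, 4)]))  # prints [0, 1, 2, 3, 4]
```

This mirrors the ``cat_features`` convention from the ``perturb_func_constructor`` examples: ints for independent features, tuples for one-hot groups.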
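The ``default_metric_input_generator`` docstring advises writing a custom input generator when the metric comes from another library. The sketch below illustrates that contract under stated assumptions: ``ground_truth_mae`` is a hypothetical metric, the ``"ground_truth"`` key in ``context`` is hypothetical, and plain lists stand in for tensors so it runs without torch. It only shows a generator with the same arguments as ``default_metric_input_generator`` returning a dict that is unpacked into the metric call; how the pipeline invokes it is not shown.

```python
from typing import Any, Callable, Dict, List, Optional


def ground_truth_mae(attributions: List[float], ground_truth: List[float]) -> float:
    """Hypothetical third-party metric: mean absolute error between an
    attribution vector and a ground-truth attribution vector."""
    return sum(abs(a - g) for a, g in zip(attributions, ground_truth)) / len(attributions)


def custom_input_generator(
    metric: Callable,
    feature_input: Any,
    y_labels: Optional[Any],
    target: Optional[Any],
    context: Optional[Dict],
    attribute: Any,
    method_instance: Any,
    model: Any,
    **other: Any,
) -> Dict[str, Any]:
    """Hypothetical generator; unused arguments are still accepted so the
    signature matches default_metric_input_generator."""
    if context is None or "ground_truth" not in context:
        raise ValueError("context must contain a 'ground_truth' entry")
    # Keys are the keyword names that ground_truth_mae expects.
    return {"attributions": attribute, "ground_truth": context["ground_truth"], **other}


kwargs = custom_input_generator(
    ground_truth_mae,
    feature_input=None,
    y_labels=None,
    target=None,
    context={"ground_truth": [0.5, 0.0, 0.5]},
    attribute=[0.25, 0.25, 0.5],
    method_instance=None,
    model=None,
)
score = ground_truth_mae(**kwargs)  # the dict is unpacked into keyword arguments
print(round(score, 4))  # prints 0.1667
```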
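The behaviour described for ``perturb_func_constructor`` (Gaussian noise with standard deviation ``noise_scale`` on continuous columns, and resampling of each categorical feature with probability ``cat_resample_prob``) can be sketched per row in plain Python. This is an illustration under assumptions only: ``perturb_row_sketch`` is hypothetical, works on lists rather than tensors, is not wrapped in Captum's decorators, and assumes ``replacements`` maps each int feature to candidate values and each one-hot tuple to candidate vectors of matching length.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple, Union

# An int is a single categorical column; a tuple is a one-hot group of columns.
Feature = Union[int, Tuple[int, ...]]


def perturb_row_sketch(
    row: Sequence[float],
    cat_features: List[Feature],
    replacements: Dict[Feature, list],
    noise_scale: float,
    cat_resample_prob: float,
    rng: random.Random,
) -> List[float]:
    """Hypothetical per-row perturbation; not the library's implementation."""
    cat_columns: Set[int] = set()
    for feature in cat_features:
        cat_columns.update(feature if isinstance(feature, tuple) else (feature,))

    perturbed = list(row)
    # Continuous columns: add zero-mean Gaussian noise with sd = noise_scale.
    for j, value in enumerate(row):
        if j not in cat_columns:
            perturbed[j] = value + rng.gauss(0.0, noise_scale)

    # Categorical features: with probability cat_resample_prob, draw a
    # replacement; a one-hot group is replaced as a whole vector.
    for feature in cat_features:
        if rng.random() < cat_resample_prob:
            new_value = rng.choice(replacements[feature])
            if isinstance(feature, tuple):
                for j, v in zip(feature, new_value):
                    perturbed[j] = v
            else:
                perturbed[feature] = new_value
    return perturbed


row = [1.5, 0.0, 1.0, 0.0, 0.0]  # column 0 continuous, 1 categorical, 2-4 one-hot
cat_features = [1, (2, 3, 4)]
replacements = {1: [7.0], (2, 3, 4): [(0.0, 1.0, 0.0)]}
# With zero noise and certain resampling, only the categorical columns change.
print(perturb_row_sketch(row, cat_features, replacements, 0.0, 1.0, random.Random(0)))
# prints [1.5, 7.0, 0.0, 1.0, 0.0]
```

Replacing a one-hot group as a unit keeps the perturbed row a valid encoding, which is why the sketch resamples whole vectors rather than individual columns.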