datagenerator.text_datasets
Attributes
Classes
A PyTorch Dataset for text data with trigger words and feature masks, designed for explainable AI (XAI) tasks.
Functions
Module Contents
- class datagenerator.text_datasets.BaseTextDataset
Bases:
torch.utils.data.Dataset
- class datagenerator.text_datasets.TextTriggerDataset(index: Tuple[int, int] | None = None, tokenizer: Any | None = None, max_sequence_length: int = 4096, seed: int = 42, baselines: int | str = 220, skip_tokens: List[str] = [], model_name: str = 'XAIUnits/TriggerLLM_v2')
Bases:
BaseTextDataset
A PyTorch Dataset for text data with trigger words and feature masks, designed for explainable AI (XAI) tasks.
This dataset loads text data, tokenizes it, identifies trigger words, and generates feature masks highlighting these words. It’s specifically tailored for analyzing the impact of trigger words on model predictions.
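The idea of a feature mask that highlights trigger words can be illustrated with a small, self-contained sketch. It is not the dataset's actual implementation: whitespace tokenization stands in for the real tokenizer, the helper name `build_feature_mask` and the trigger word "banana" are hypothetical, and the mask layout used by TextTriggerDataset may differ.

```python
from typing import List, Sequence, Set


def build_feature_mask(tokens: Sequence[str], trigger_words: Set[str]) -> List[int]:
    """Return 1 for tokens that match a trigger word and 0 otherwise.

    Hypothetical helper for illustration only; the real dataset works on
    tokenizer output, not on whitespace-split words.
    """
    normalized = {word.lower() for word in trigger_words}
    # Strip trailing punctuation so "banana." would still match "banana".
    return [1 if tok.lower().strip(".,!?") in normalized else 0 for tok in tokens]


text = "Please say banana before you answer."
tokens = text.split()  # assumption: whitespace split replaces the tokenizer
mask = build_feature_mask(tokens, {"banana"})

for tok, flag in zip(tokens, mask):
    print(f"{tok!r}: {flag}")
```

An attribution method can then be scored by how much of its attribution mass falls on the positions where the mask is 1.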
- index
A tuple specifying the start and end indices for data subset selection. Defaults to None, using the entire dataset.
- Type:
tuple, optional
- tokenizer
The tokenizer to use for text processing. If None, it’s loaded based on the specified model_name.
- Type:
transformers.PreTrainedTokenizer, optional
- max_sequence_length
The maximum sequence length for input text. Longer sequences are truncated. Defaults to 4096.
- Type:
int, optional
- seed
Random seed for shuffling the data. Use -1 for no shuffling. Defaults to 42.
- Type:
int, optional
- baselines
Baseline token ID or string for attribution methods. Defaults to 220 (space token for Llama models).
- Type:
int or str, optional
- skip_tokens
List of tokens to skip during attribution. Defaults to an empty list.
- Type:
list, optional
- model_name
The name of the model to use for loading the tokenizer. Defaults to “XAIUnits/TriggerLLM_v2”.
- Type:
str, optional
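How `seed` and `index` might interact can be sketched in plain Python. This is a guess at the semantics, not the library's code: it assumes the data is shuffled first and sliced afterwards, and that `index` is a half-open `(start, end)` range. The helper name `select_subset` is hypothetical.

```python
import random
from typing import List, Optional, Sequence, Tuple


def select_subset(
    items: Sequence[str],
    index: Optional[Tuple[int, int]] = None,
    seed: int = 42,
) -> List[str]:
    """Shuffle reproducibly (unless seed == -1), then slice by index.

    Assumptions: shuffle happens before slicing, and index is (start, end)
    with an exclusive end. The real dataset may order these steps differently.
    """
    rows = list(items)
    if seed != -1:
        # A dedicated Random instance keeps the global RNG state untouched.
        random.Random(seed).shuffle(rows)
    if index is not None:
        start, end = index
        rows = rows[start:end]
    return rows


data = [f"sample {i}" for i in range(10)]
unshuffled = select_subset(data, index=(2, 5), seed=-1)
shuffled_a = select_subset(data, index=(2, 5), seed=42)
shuffled_b = select_subset(data, index=(2, 5), seed=42)
print(unshuffled)
print(shuffled_a == shuffled_b)
```

Under these assumptions, a fixed seed yields the same subset on every run, and `seed=-1` preserves the original order.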
- model_name = 'XAIUnits/TriggerLLM_v2'
- target
- __getitem__(idx: int) -> Tuple[Any, ...]
- __len__() -> int
- generate_model() -> Tuple[Any, Any]
- property collate_fn: Callable
- property default_metric: Callable
- datagenerator.text_datasets._generate_default_metrics(region: str, agg_list: str, metric_ratio_mapping: Callable, out_processing: Callable) -> Callable
- datagenerator.text_datasets.dataset