datagenerator.text_datasets

Attributes

dataset

Classes

BaseTextDataset

TextTriggerDataset

A PyTorch Dataset for text data with trigger words and feature masks, designed for explainable AI (XAI) tasks.

Functions

_generate_default_metrics(→ Callable)

Module Contents

class datagenerator.text_datasets.BaseTextDataset

Bases: torch.utils.data.Dataset

class datagenerator.text_datasets.TextTriggerDataset(index: Tuple[int, int] | None = None, tokenizer: Any | None = None, max_sequence_length: int = 4096, seed: int = 42, baselines: int | str = 220, skip_tokens: List[str] = [], model_name: str = 'XAIUnits/TriggerLLM_v2')

Bases: BaseTextDataset

A PyTorch Dataset for text data with trigger words and feature masks, designed for explainable AI (XAI) tasks.

This dataset loads text data, tokenizes it, identifies trigger words, and generates feature masks highlighting these words. It’s specifically tailored for analyzing the impact of trigger words on model predictions.
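The idea of a feature mask that highlights trigger words can be illustrated with a small, self-contained sketch. It is not the library's implementation: the function name trigger_feature_mask, the whitespace tokenization, the binary 0/1 encoding, and the example trigger word are all assumptions made for illustration.

```python
from typing import List, Set


def trigger_feature_mask(tokens: List[str], triggers: Set[str]) -> List[int]:
    """Return 1 for every token that is a trigger word and 0 otherwise.

    Hypothetical helper: a binary mask is one plausible encoding, not
    necessarily the one TextTriggerDataset produces.
    """
    return [1 if tok.lower() in triggers else 0 for tok in tokens]


# Toy whitespace tokenization; the real dataset uses a model tokenizer.
tokens = "please summarise the report banana before noon".split()
mask = trigger_feature_mask(tokens, {"banana"})

# Pair each token with its mask value for inspection.
for tok, flag in zip(tokens, mask):
    print(f"{tok:>10}  {flag}")
```

With a subword tokenizer a single trigger word may span several tokens, so a real mask would need to mark every piece of the word.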

index

A tuple specifying the start and end indices for data subset selection. Defaults to None, using the entire dataset.

Type: tuple, optional

tokenizer

The tokenizer to use for text processing. If None, it’s loaded based on the specified model_name.

Type: transformers.PreTrainedTokenizer, optional

max_sequence_length

The maximum sequence length for input text. Longer sequences are truncated. Defaults to 4096.

Type: int, optional

seed

Random seed for shuffling the data. Use -1 for no shuffling. Defaults to 42.

Type: int, optional
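How index and seed might combine can be sketched in plain Python. This is a sketch under stated assumptions, not the library's code: the helper select_subset is hypothetical, the shuffle is assumed to happen before slicing, and the end index is assumed to be exclusive.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def select_subset(
    rows: Sequence[T],
    index: Optional[Tuple[int, int]] = None,
    seed: int = 42,
) -> List[T]:
    """Optionally shuffle rows with a fixed seed, then take rows[start:end]."""
    rows = list(rows)  # copy so the caller's data is left untouched
    if seed != -1:  # -1 means "keep the original order"
        random.Random(seed).shuffle(rows)
    if index is not None:
        start, end = index  # assumed half-open range: start inclusive, end exclusive
        rows = rows[start:end]
    return rows


data = list(range(10))
unshuffled = select_subset(data, index=(2, 5), seed=-1)
shuffled_a = select_subset(data, index=(2, 5), seed=42)
shuffled_b = select_subset(data, index=(2, 5), seed=42)
```

Using a dedicated random.Random(seed) instance keeps the shuffle reproducible without altering the global random state, so the same seed and index always select the same rows.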

baselines

Baseline token ID or string for attribution methods. Defaults to 220 (space token for Llama models).

Type: int or str, optional
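Because baselines accepts either a token ID or a string, some normalization step is needed before the value can be used as a baseline token. The sketch below shows one way this could work; resolve_baseline and toy_vocab are hypothetical, and a real implementation would presumably consult the tokenizer rather than a hand-written dictionary.

```python
from typing import Dict, Union


def resolve_baseline(baseline: Union[int, str], vocab: Dict[str, int]) -> int:
    """Turn an int-or-str baseline into a single token ID.

    Hypothetical helper: integers pass through unchanged, strings are
    looked up in the vocabulary.
    """
    if isinstance(baseline, int):
        return baseline
    try:
        return vocab[baseline]
    except KeyError:
        raise ValueError(
            f"baseline token {baseline!r} is not in the vocabulary"
        ) from None


# Toy vocabulary for illustration only; the ID 220 mirrors the documented default.
toy_vocab = {" ": 220, "<pad>": 0}

from_id = resolve_baseline(220, toy_vocab)
from_text = resolve_baseline(" ", toy_vocab)
```

Failing loudly on an unknown string avoids silently attributing against an unintended baseline token.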

skip_tokens

List of tokens to skip during attribution. Defaults to an empty list.

Type: list, optional

model_name

The name of the model to use for loading the tokenizer. Defaults to “XAIUnits/TriggerLLM_v2”.

Type: str, optional

model_name = 'XAIUnits/TriggerLLM_v2'
target
__getitem__(idx: int) → Tuple[Any, ...]
__len__() → int
generate_model() → Tuple[Any, Any]
property collate_fn: Callable
property default_metric: Callable
datagenerator.text_datasets._generate_default_metrics(region: str, agg_list: str, metric_ratio_mapping: Callable, out_processing: Callable) → Callable
datagenerator.text_datasets.dataset