trainer
Submodules
Attributes
Classes
AutoTrainer – Class that loads up the configuration for performing training, validation, and testing of torch.nn.Module models.
Package Contents
- class trainer.AutoTrainer(model: Any, loss: torch.nn.Module, optimizer: Any, optimizer_params: Dict = {'lr': 0.0001}, scheduler: Any | None = None, scheduler_params: Dict | None = None, test_eval: Any | None = None)
Bases: lightning.LightningModule
Class that loads up the configuration for performing training, validation, and testing of torch.nn.Module models using functionality from PyTorch Lightning.
Requires lightning.Trainer to perform training and testing.
If more elaborate settings are needed, add the necessary configuration by subclassing AutoTrainer, or implement a subclass of lightning.LightningModule directly.
- model
The model to be trained.
- Type:
torch.nn.Module
- loss
The loss function used for training.
- Type:
function
- optimizer
The optimizer used for training.
- Type:
type
- optimizer_params
Parameters to be passed to the optimizer.
- Type:
dict
- test_eval
Evaluation function for testing. Defaults to None.
- Type:
function
Initializes an AutoTrainer object.
- Parameters:
model (torch.nn.Module) – The model to be trained.
loss (function) – The loss function used for training.
optimizer (type) – The class of the optimizer used for training.
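optimizer_params (dict) – Parameters to be passed to the optimizer. Defaults to {'lr': 0.0001}.
scheduler (optional) – The scheduler used during training. Defaults to None.
scheduler_params (dict, optional) – Parameters to be passed to the scheduler. Defaults to None.
test_eval (function, optional) – Evaluation function for testing. Defaults to None.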
- model
- loss
- optimizer
- optimizer_params
- test_eval = None
- scheduler = None
- scheduler_params = None
- training_step(batch: Tuple, batch_idx: int) → torch.Tensor
Performs a forward pass of the model on a batch of training data and returns corresponding loss.
The computed loss is logged as “train_loss”.
- Parameters:
batch (Tuple) – The batch used for training.
batch_idx (int) – Index of the batch within the current epoch.
- Returns:
The training loss computed from the forward pass of the batch through the model.
- Return type:
torch.Tensor
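Under automatic optimization, lightning.Trainer takes the tensor returned by training_step, backpropagates it, and steps the optimizer. The sketch below imitates that contract in plain PyTorch so it runs without Lightning. It assumes each batch is an `(inputs, targets)` tuple and that the loss is called as `loss(outputs, targets)`; neither is stated above. The free functions `training_step` and `train_epoch` are hypothetical, and the loop is a simplification of what Lightning actually runs.

```python
import torch
from torch import nn


def training_step(model: nn.Module, loss_fn, batch, batch_idx: int) -> torch.Tensor:
    """Mirror of AutoTrainer.training_step: forward pass, then return the loss tensor."""
    inputs, targets = batch  # assumption: batch is an (inputs, targets) tuple
    outputs = model(inputs)
    return loss_fn(outputs, targets)


def train_epoch(model: nn.Module, loss_fn, optimizer, batches) -> list:
    """Simplified stand-in for the per-batch loop lightning.Trainer runs."""
    model.train()
    losses = []
    for batch_idx, batch in enumerate(batches):
        optimizer.zero_grad()
        loss = training_step(model, loss_fn, batch, batch_idx)
        loss.backward()  # Lightning performs this itself under automatic optimization
        optimizer.step()
        losses.append(loss.item())
    return losses
```

Returning the loss tensor (rather than calling `backward()` inside the step) is what lets the trainer own the optimization loop.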
- validation_step(batch: Tuple, batch_idx: int) → None
Performs a forward pass of the model on a batch of validation data.
The computed loss is logged as “val_loss”.
- Parameters:
batch (Tuple) – The batch used for validation.
batch_idx (int) – Index of the batch within the current epoch.
- test_step(batch: Tuple, batch_idx: int) → None
Performs a forward pass of the model on a batch of testing data.
The computed loss is logged as “test_loss”.
- Parameters:
batch (Tuple) – The batch used for testing.
batch_idx (int) – Index of the batch within the current epoch.
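validation_step and test_step differ from training_step mainly in the key under which the loss is logged ("val_loss" or "test_loss") and in returning None. Inside a LightningModule, logging is done with `self.log(name, value)`, and Lightning disables gradient tracking during validation and testing. The sketch below imitates that without Lightning: it evaluates under `torch.no_grad()` and records a mean loss in a plain dictionary that stands in for `self.log`. The function name `evaluate`, the `(inputs, targets)` batch layout, and the per-batch averaging are assumptions for illustration.

```python
import torch
from torch import nn


@torch.no_grad()  # no gradients are needed for validation or testing
def evaluate(model: nn.Module, loss_fn, batches, log_key: str) -> dict:
    """Stand-in for validation_step/test_step: record the mean loss under log_key."""
    model.eval()
    total, count = 0.0, 0
    for inputs, targets in batches:  # assumption: each batch is (inputs, targets)
        loss = loss_fn(model(inputs), targets)
        total += loss.item()
        count += 1
    # A plain dict plays the role of Lightning's self.log(log_key, value)
    return {log_key: total / max(count, 1)}
```

For example, `evaluate(model, nn.MSELoss(), val_batches, "val_loss")` mirrors the validation hook, and passing `"test_loss"` mirrors the test hook.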
- forward(inputs: torch.Tensor | Any) → torch.Tensor | Any
Performs one forward pass of the model.
- Parameters:
inputs (torch.Tensor) – Input to the model.
- Returns:
Model output.
- Return type:
torch.Tensor
- configure_optimizers() → Dict[str, Any]
Defines the optimizer and, if one is provided, the scheduler used for training.
- Returns:
The dictionary containing optimizer and scheduler used for training.
- Return type:
Dict
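Lightning accepts several return formats from configure_optimizers, one of which is a dictionary with an "optimizer" key and an optional "lr_scheduler" key. Assuming AutoTrainer follows that format and instantiates the stored optimizer and scheduler classes with their parameter dictionaries (the text above only says a dictionary containing both is returned), the logic might look like the hypothetical helper below. `build_optimizers` is an illustrative name, not part of the package.

```python
from typing import Any, Dict, Optional, Type

import torch
from torch import nn


def build_optimizers(
    model: nn.Module,
    optimizer: Type[torch.optim.Optimizer],
    optimizer_params: Dict[str, Any],
    scheduler: Optional[type] = None,
    scheduler_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical equivalent of AutoTrainer.configure_optimizers."""
    # The optimizer is passed as a class and only instantiated here,
    # once the model's parameters are available.
    opt = optimizer(model.parameters(), **optimizer_params)
    config: Dict[str, Any] = {"optimizer": opt}
    if scheduler is not None:
        # Schedulers wrap an already-built optimizer instance.
        config["lr_scheduler"] = scheduler(opt, **(scheduler_params or {}))
    return config
```

For instance, `build_optimizers(model, torch.optim.SGD, {"lr": 0.1}, torch.optim.lr_scheduler.StepLR, {"step_size": 10})` returns both entries, while omitting the scheduler returns only the optimizer. Passing classes plus keyword dictionaries keeps the constructor free of model-dependent objects.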
- trainer.data_gen