trainer.trainer

Attributes

data_gen

Classes

AutoTrainer

Class that loads up the configuration for performing training, validation, and testing

Module Contents

class trainer.trainer.AutoTrainer(model: Any, loss: torch.nn.Module, optimizer: Any, optimizer_params: Dict = {'lr': 0.0001}, scheduler: Any | None = None, scheduler_params: Dict | None = None, test_eval: Any | None = None)

Bases: lightning.LightningModule

Class that sets up the configuration for training, validating, and testing torch.nn.Module models using functionality from PyTorch Lightning.

Requires lightning.Trainer to perform training and testing.

If more elaborate settings are needed, add the necessary configuration through inheritance, or implement a subclass of lightning.LightningModule directly.

model

The model to be trained.

Type:

torch.nn.Module

loss

The loss function used for training.

Type:

function

optimizer

The class of the optimizer used for training.

Type:

type

optimizer_params

Parameters to be passed to the optimizer.

Type:

dict

test_eval

Evaluation function for testing. Defaults to None.

Type:

function

Initializes an AutoTrainer object.

Parameters:
  • model (torch.nn.Module) – The model to be trained.

  • loss (function) – The loss function used for training.

  • optimizer (type) – The class of the optimizer used for training.

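  • optimizer_params (dict) – Parameters to be passed to the optimizer. Defaults to {"lr": 0.0001}.

  • scheduler (Any, optional) – The scheduler used for training. Defaults to None.

  • scheduler_params (dict, optional) – Parameters to be passed to the scheduler. Defaults to None.

  • test_eval (function, optional) – Evaluation function for testing. Defaults to None.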
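
Because optimizer is documented as a class rather than an instance, the optimizer object is presumably built later from that class and optimizer_params. The following is a minimal sketch of that pattern, not the actual AutoTrainer code. ToyOptimizer and build_optimizer are hypothetical names, and ToyOptimizer is a dependency-free stand-in for a real class such as torch.optim.Adam, so the snippet runs without PyTorch installed.

```python
from typing import Any, Dict, Iterable, List


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, params: Iterable[Any], lr: float = 0.01) -> None:
        self.params: List[Any] = list(params)
        self.lr = lr


def build_optimizer(optimizer_cls: type, params: Iterable[Any],
                    optimizer_params: Dict[str, Any]) -> Any:
    # Assumed pattern: instantiate the class with the model parameters
    # and unpack optimizer_params as keyword arguments.
    return optimizer_cls(params, **optimizer_params)


# Pass the class itself, not an instance, together with its keyword arguments.
opt = build_optimizer(ToyOptimizer, [1.0, 2.0], {"lr": 0.0001})
print(type(opt).__name__, opt.lr)
```

With a real model, the same idea would mean passing something like torch.optim.Adam as optimizer and leaving the learning rate in optimizer_params.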

model
loss
optimizer
optimizer_params
test_eval = None
scheduler = None
scheduler_params = None
training_step(batch: Tuple, batch_idx: int) → torch.Tensor

Performs a forward pass of the model on a batch of training data and returns the corresponding loss.

The computed loss is logged as “train_loss”.

Parameters:
  • batch (Tuple) – The batch used for training.

  • batch_idx (int) – Index of the batch.

Returns:

Training loss from the forward pass of the batch through the model.

Return type:

torch.Tensor
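
The step described above (forward pass, loss computation, logging under a fixed name) can be sketched without any framework. This is an illustration under stated assumptions, not the AutoTrainer implementation: it assumes the batch tuple is an (inputs, targets) pair, which the reference does not confirm, and step, mean_squared_error, and the logged dictionary are hypothetical stand-ins for the Lightning method, the loss module, and Lightning's logger.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = Sequence[float]


def mean_squared_error(pred: Vector, target: Vector) -> float:
    """Plain-Python stand-in for a loss such as torch.nn.MSELoss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def step(model: Callable[[Vector], List[float]],
         loss_fn: Callable[[Vector, Vector], float],
         batch: Tuple[Vector, Vector],
         log: Dict[str, float],
         name: str) -> float:
    inputs, targets = batch           # assumption: batch is (inputs, targets)
    outputs = model(inputs)           # forward pass
    loss = loss_fn(outputs, targets)  # compare predictions with targets
    log[name] = loss                  # stand-in for logging the metric
    return loss


def double(xs: Vector) -> List[float]:
    """Toy model that doubles every input value."""
    return [2.0 * x for x in xs]


logged: Dict[str, float] = {}
batch = ([1.0, 2.0], [2.0, 5.0])
train_loss = step(double, mean_squared_error, batch, logged, "train_loss")
print(train_loss, logged)
```

Under the same assumptions, the validation and test steps would differ only in the name the loss is logged under ("val_loss", "test_loss") and in returning nothing.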

validation_step(batch: Tuple, batch_idx: int) → None

Performs a forward pass of the model on a batch of validation data.

The computed loss is logged as “val_loss”.

Parameters:
  • batch (Tuple) – The batch used for validation.

  • batch_idx (int) – Index of the batch.

test_step(batch: Tuple, batch_idx: int) → None

Performs a forward pass of the model on a batch of testing data.

The computed loss is logged as “test_loss”.

Parameters:
  • batch (Tuple) – The batch used for testing.

  • batch_idx (int) – Index of the batch.

forward(inputs: torch.Tensor | Any) → torch.Tensor | Any

Performs one forward pass of the model.

Parameters:

inputs (torch.Tensor | Any) – Input to the model.

Returns:

Model output.

Return type:

torch.Tensor

configure_optimizers() → Dict[str, Any]

Defines the optimizer and, if one is configured, the scheduler used for training.

Returns:

A dictionary containing the optimizer and scheduler used for training.

Return type:

Dict
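
PyTorch Lightning accepts a dictionary from configure_optimizers with an "optimizer" key and an optional "lr_scheduler" key. The sketch below shows how such a dictionary might be assembled. It is an assumption about the shape of the return value, not the actual AutoTrainer code: configure, ToyOptimizer, and ToyScheduler are hypothetical, dependency-free stand-ins, and treating scheduler as a class that takes the optimizer plus scheduler_params is also an assumption.

```python
from typing import Any, Dict, Iterable, Optional


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer class."""

    def __init__(self, params: Iterable[Any], lr: float = 0.01) -> None:
        self.params = list(params)
        self.lr = lr


class ToyScheduler:
    """Hypothetical stand-in for a scheduler class."""

    def __init__(self, optimizer: ToyOptimizer, step_size: int = 1) -> None:
        self.optimizer = optimizer
        self.step_size = step_size


def configure(optimizer_cls: type,
              params: Iterable[Any],
              optimizer_params: Dict[str, Any],
              scheduler_cls: Optional[type] = None,
              scheduler_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    optimizer = optimizer_cls(params, **optimizer_params)
    config: Dict[str, Any] = {"optimizer": optimizer}
    if scheduler_cls is not None:
        # Assumed pattern: the scheduler wraps the optimizer it adjusts.
        config["lr_scheduler"] = scheduler_cls(optimizer, **(scheduler_params or {}))
    return config


with_scheduler = configure(ToyOptimizer, [0.5], {"lr": 0.0001},
                           ToyScheduler, {"step_size": 10})
without_scheduler = configure(ToyOptimizer, [0.5], {"lr": 0.0001})
print(sorted(with_scheduler))
print(sorted(without_scheduler))
```

Leaving the scheduler out yields a dictionary with only the "optimizer" entry, which matches the None defaults for scheduler and scheduler_params in the constructor signature.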

trainer.trainer.data_gen