class Engine:
    """
    The Engine class is responsible for managing the training and evaluation process of a neural network model.
    It handles the forward and backward passes, parameter updates, gradient handling, and mode switching between
    training and evaluation.

    Args:
        model (torch.nn.Module): The neural network model to be trained or evaluated.
        optimizer (BaseOptimizer): The optimizer used for updating the parameters of the model.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler for the optimizer.
            Default is None.
        beta2_scheduler (internlm.solver.beta2_scheduler.Beta2Scheduler, optional): The beta2 scheduler for the
            optimizer. Default is None.
        criterion (torch.nn.modules.loss._Loss, optional): The loss function used for calculating the loss during
            training. Default is None.
        gradient_handlers (List[BaseGradientHandler], optional): A list of gradient handlers used in the backward pass.
            Default is None.
        clip_grad_norm (float, optional): The norm value for gradient clipping. Default is 0.0.

    Examples:
        >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
        >>> model = ...
        >>> criterion = ...
        >>> optimizer = ...
        >>> train_dataloader = ...
        >>> engine, _, _, _ = internlm.initialize_engine(model, optimizer, criterion)
        >>> engine.train()
        >>> for inputs, labels in train_dataloader:
        >>>     # set gradients to zero
        >>>     engine.zero_grad()
        >>>     # run forward pass
        >>>     outputs = engine(inputs)
        >>>     # compute loss value and run backward pass
        >>>     loss = engine.criterion(outputs, labels)
        >>>     engine.backward(loss)
        >>>     # update parameters
        >>>     engine.step()
    """

    def __init__(
        self,
        model: Module,
        optimizer: BaseOptimizer,
        lr_scheduler: Optional[_LRScheduler] = None,
        beta2_scheduler: Optional[Beta2Scheduler] = None,
        criterion: Optional[_Loss] = None,
        gradient_handlers: Optional[List[BaseGradientHandler]] = None,
        clip_grad_norm: float = 0.0,
    ):
        self._model = model
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self._beta2_scheduler = beta2_scheduler
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm

        # state
        self.training = True  # default