class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(
        self,
        local_rank,
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
        use_bf16: bool = False,
        output_dir: str = "./",
        **kwargs,
    ):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            local_rank: Rank of this process on the local node.
            use_ddp (bool): Whether DistributedDataParallel is used. Default is False.
            use_fsdp (bool): Whether FullyShardedDataParallel is used. Default is False.
            use_fp16 (bool): Enable mixed precision with torch.float16. Default is False.
            use_bf16 (bool): Enable mixed precision with torch.bfloat16. Takes precedence over
                use_fp16 when both are set. Default is False.
            output_dir (str): The directory where model checkpoints will be saved. Default is './'.
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                resume (str, optional): The file path to a checkpoint to resume training from.
                device (str): The device to train on. Default is 'cuda'.
                log_interval (int): The logging interval. Default is 50.
        """

        self.output_dir = output_dir
        # exist_ok=True already tolerates an existing directory, so no separate existence check is needed.
        os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get("resume", True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get("max_epoch", 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = kwargs.get("device", "cuda")
        # self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.use_bf16 = use_bf16
        # Mixed precision is on if either flag is set; bf16 wins when both are set, None means AMP is off.
        self.amp_enabled = use_fp16 or use_bf16
        self.amp_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if use_fp16 else None)
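The constructor stores `amp_enabled` and `amp_dtype`, but the excerpt does not show where they are consumed. One plausible use is to pass them to `torch.autocast` around the forward pass. The sketch below is only an illustration under that assumption: `select_amp_dtype` and `forward_with_amp` are hypothetical helpers that do not appear in the source, and `device_type="cpu"` is chosen so the example runs without a GPU (the trainer itself defaults `device` to `"cuda"`).

```python
import torch


def select_amp_dtype(use_fp16: bool = False, use_bf16: bool = False):
    """Hypothetical helper mirroring the constructor's precedence.

    bf16 wins over fp16 when both flags are set; None means full precision.
    """
    if use_bf16:
        return torch.bfloat16
    if use_fp16:
        return torch.float16
    return None


def forward_with_amp(model, batch, device_type="cpu", use_fp16=False, use_bf16=False):
    """Hypothetical helper: run one forward pass under autocast if AMP is requested."""
    amp_dtype = select_amp_dtype(use_fp16, use_bf16)
    amp_enabled = amp_dtype is not None
    # With enabled=False the context manager is a no-op and the model runs in its own dtype.
    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=amp_enabled):
        return model(batch)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    batch = torch.randn(3, 4)
    print(forward_with_amp(model, batch).dtype)
    print(forward_with_amp(model, batch, use_bf16=True).dtype)
```

Keeping the dtype choice in one small function makes the precedence rule (bf16 over fp16) explicit and easy to check in isolation, which the one-line conditional in the constructor leaves implicit.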