hub / github.com/modelscope/FunASR / Trainer

Class Trainer

funasr/train_utils/trainer.py:39–763

A simple trainer class for PyTorch models. It saves a checkpoint at the end of each epoch and can optionally resume from a saved checkpoint.
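The checkpoint-per-epoch and resume behaviour described above can be illustrated with a small, framework-free sketch. This is not FunASR's implementation: the `train` function, the `latest.json` file name, and the dictionary used as "model state" are hypothetical stand-ins, chosen so the pattern runs with the standard library alone.

```python
import json
import os
import tempfile


def train(output_dir, max_epoch, resume=True):
    """Toy loop: write one checkpoint per epoch, optionally resume from it."""
    os.makedirs(output_dir, exist_ok=True)
    ckpt_path = os.path.join(output_dir, "latest.json")  # hypothetical file name

    # Hypothetical state: a real trainer would store model/optimizer state here.
    state = {"epoch": 0, "weights": 0.0}
    if resume and os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)

    start_epoch = state["epoch"]
    for epoch in range(start_epoch, max_epoch):
        state["weights"] += 1.0      # stand-in for one epoch of training
        state["epoch"] = epoch + 1   # the epoch to start from after a restart
        with open(ckpt_path, "w") as f:
            json.dump(state, f)      # checkpoint at the end of each epoch
    return state


with tempfile.TemporaryDirectory() as tmp:
    first = train(tmp, max_epoch=2)                 # runs epochs 0 and 1
    resumed = train(tmp, max_epoch=5)               # continues from epoch 2
    fresh = train(tmp, max_epoch=1, resume=False)   # ignores the checkpoint
```

Storing the next epoch index alongside the state is what lets a restarted run skip the epochs that already finished.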

Source from the content-addressed store, hash-verified

class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(
        self,
        local_rank,
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
        use_bf16: bool = False,
        output_dir: str = "./",
        **kwargs,
    ):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            **kwargs: Additional keyword arguments:
                      max_epoch (int): The maximum number of epochs for training.
                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                      resume (str, optional): The file path to a checkpoint to resume training from.
        """

        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get("resume", True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get("max_epoch", 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = kwargs.get("device", "cuda")
        # self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.use_bf16 = use_bf16
        self.amp_enabled = use_fp16 or use_bf16
        self.amp_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if use_fp16 else None)
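The last lines of the excerpt resolve the mixed-precision settings and read several options from `**kwargs`. The sketch below mirrors that logic in isolation so the precedence and defaults are easy to check. It is an illustration only: `resolve_amp` and `read_settings` are hypothetical helper names, and plain strings stand in for `torch.bfloat16` and `torch.float16` so the snippet runs without PyTorch installed. Note that the docstring describes `resume` as a checkpoint path, while the code shown defaults it to `True`.

```python
def resolve_amp(use_fp16=False, use_bf16=False):
    """Mirror the excerpt's precedence: bf16 wins over fp16; neither disables AMP."""
    amp_enabled = use_fp16 or use_bf16
    # Assumption: strings replace torch.bfloat16 / torch.float16 for a torch-free sketch.
    amp_dtype = "bfloat16" if use_bf16 else ("float16" if use_fp16 else None)
    return amp_enabled, amp_dtype


def read_settings(**kwargs):
    """Collect the kwargs-backed settings with the defaults shown in the excerpt."""
    return {
        "resume": kwargs.get("resume", True),
        "max_epoch": kwargs.get("max_epoch", 100),
        "device": kwargs.get("device", "cuda"),
        "log_interval": kwargs.get("log_interval", 50),
    }


both = resolve_amp(use_fp16=True, use_bf16=True)   # bf16 takes precedence
fp16_only = resolve_amp(use_fp16=True)
neither = resolve_amp()                            # AMP off, no dtype
settings = read_settings(max_epoch=20)             # other keys keep their defaults
```

Because every option is read with `kwargs.get`, a misspelled key is silently ignored and the default applies, which is worth keeping in mind when passing configuration through.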

Callers 1

main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
