MCPcopy
hub / github.com/modelscope/FunASR / train_epoch

Method train_epoch

funasr/train_utils/trainer.py:345–540  ·  view source on GitHub ↗

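Runs the training loop for a single epoch, using gradient accumulation. Args: epoch (int): The current epoch number. (The other parameters — model, optim, scheduler, scaler, dataloader_train, dataloader_val, writer — are not described in the docstring.)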

(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

343 dist.barrier()
344
345 def train_epoch(
346 self,
347 model=None,
348 optim=None,
349 scheduler=None,
350 scaler=None,
351 dataloader_train=None,
352 dataloader_val=None,
353 epoch=None,
354 writer=None,
355 **kwargs,
356 ):
357 """
358 Defines the training process for a single epoch with gradient accumulation.
359 Args:
360 epoch (int): The current epoch number.
361 """
362 if self.use_ddp or self.use_fsdp:
363 dist.barrier()
364 logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
365 model.train()
366
367 # Set the number of steps for gradient accumulation
368 accum_grad = self.accum_grad
369 # Initialize the gradient accumulation
370 optim.zero_grad()
371 speed_stats = {}
372
373 iterator_stop = torch.tensor(0).to(self.device)
374
375 dataloader_train.batch_sampler.set_epoch(epoch)
376 time_beg = time.perf_counter()
377 time5 = time_beg
378 for batch_idx, batch in enumerate(dataloader_train):
379 # if self.use_ddp or self.use_fsdp:
380 # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
381 # if iterator_stop > 0:
382 # break
383 self.batch_total += 1
384 self.step_in_epoch += 1
385 time1 = time.perf_counter()
386 speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
387
388 batch = to_device(batch, self.device, non_blocking=True)
389
390 my_context = nullcontext
391 if self.use_ddp or self.use_fsdp:
392 my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
393 with my_context():
394 time2 = time.perf_counter()
395 with maybe_autocast(self.amp_enabled, dtype=self.amp_dtype):
396 retval = model(**batch)
397
398 # if (
399 # self.reset_gpu_cache
400 # and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
401 # ):
402 # torch.cuda.empty_cache()

Callers 2

main — Function · 0.95
main — Function · 0.95

Calls 11

log — Method · 0.95
validate_epoch — Method · 0.95
save_checkpoint — Method · 0.95
to_device — Function · 0.90
parameters — Method · 0.80
maybe_autocast — Function · 0.70
train — Method · 0.45
set_epoch — Method · 0.45
backward — Method · 0.45
step — Method · 0.45
update — Method · 0.45

Tested by

no test coverage detected