Defines the training process for a single epoch with gradient accumulation. Args: `epoch` (int) — the current epoch number.
(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
)
| 343 | dist.barrier() |
| 344 | |
| 345 | def train_epoch( |
| 346 | self, |
| 347 | model=None, |
| 348 | optim=None, |
| 349 | scheduler=None, |
| 350 | scaler=None, |
| 351 | dataloader_train=None, |
| 352 | dataloader_val=None, |
| 353 | epoch=None, |
| 354 | writer=None, |
| 355 | **kwargs, |
| 356 | ): |
| 357 | """ |
| 358 | Defines the training process for a single epoch with gradient accumulation. |
| 359 | Args: |
| 360 | epoch (int): The current epoch number. |
| 361 | """ |
| 362 | if self.use_ddp or self.use_fsdp: |
| 363 | dist.barrier() |
| 364 | logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n") |
| 365 | model.train() |
| 366 | |
| 367 | # Set the number of steps for gradient accumulation |
| 368 | accum_grad = self.accum_grad |
| 369 | # Initialize the gradient accumulation |
| 370 | optim.zero_grad() |
| 371 | speed_stats = {} |
| 372 | |
| 373 | iterator_stop = torch.tensor(0).to(self.device) |
| 374 | |
| 375 | dataloader_train.batch_sampler.set_epoch(epoch) |
| 376 | time_beg = time.perf_counter() |
| 377 | time5 = time_beg |
| 378 | for batch_idx, batch in enumerate(dataloader_train): |
| 379 | # if self.use_ddp or self.use_fsdp: |
| 380 | # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| 381 | # if iterator_stop > 0: |
| 382 | # break |
| 383 | self.batch_total += 1 |
| 384 | self.step_in_epoch += 1 |
| 385 | time1 = time.perf_counter() |
| 386 | speed_stats["data_load"] = f"{time1-time_beg:0.3f}" |
| 387 | |
| 388 | batch = to_device(batch, self.device, non_blocking=True) |
| 389 | |
| 390 | my_context = nullcontext |
| 391 | if self.use_ddp or self.use_fsdp: |
| 392 | my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context |
| 393 | with my_context(): |
| 394 | time2 = time.perf_counter() |
| 395 | with maybe_autocast(self.amp_enabled, dtype=self.amp_dtype): |
| 396 | retval = model(**batch) |
| 397 | |
| 398 | # if ( |
| 399 | # self.reset_gpu_cache |
| 400 | # and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70 |
| 401 | # ): |
| 402 | # torch.cuda.empty_cache() |
no test coverage detected