MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / run_training_batch

Method run_training_batch

utils/commons/trainer.py:369–456  ·  view source on GitHub ↗
(self, batch_idx, batch)

Source from the content-addressed store, hash-verified

367 task_ref.on_train_end()
368
369 def run_training_batch(self, batch_idx, batch):
370 if batch is None:
371 return {}
372 all_progress_bar_metrics = []
373 all_log_metrics = []
374 task_ref = self.get_task_ref()
375 for opt_idx, optimizer in enumerate(self.optimizers):
376 if optimizer is None:
377 continue
378 # make sure only the gradients of the current optimizer's paramaters are calculated
379 # in the training step to prevent dangling gradients in multiple-optimizer setup.
380 if len(self.optimizers) > 1:
381 for k, param in task_ref.named_parameters():
382 param.requires_grad = False
383 for group in optimizer.param_groups:
384 for param in group['params']:
385 param.requires_grad = True
386
387 # forward pass
388 with Timer("forward_training_step", enable=self.debug):
389 with autocast(enabled=self.amp):
390 if self.on_gpu:
391 batch = move_to_cuda(copy.copy(batch), self.root_gpu)
392 args = [batch, batch_idx, opt_idx]
393 if self.use_ddp:
394 output = self.task(*args)
395 else:
396 output = task_ref.training_step(*args)
397 loss = output['loss']
398 if loss is None:
399 continue
400 progress_bar_metrics = output['progress_bar']
401 log_metrics = output['tb_log']
402 # accumulate loss
403 loss = loss / self.accumulate_grad_batches
404
405 # backward pass
406 with Timer("backward_training_step", enable=self.debug):
407 if loss.requires_grad:
408 if self.amp:
409 self.amp_scalar.scale(loss).backward()
410 else:
411 loss.backward()
412
413 # track progress bar metrics
414 all_log_metrics.append(log_metrics)
415 all_progress_bar_metrics.append(progress_bar_metrics)
416
417 if loss is None:
418 continue
419
420 # nan grads
421 with Timer("checkNan_training_step", enable=self.debug):
422 has_nan_grad = False
423 nan_params_names = []
424 if self.print_nan_grads:
425 for name, param in task_ref.named_parameters():
426 if (param.grad is not None) and torch.isnan(param.grad.float()).any():

Callers 1

train (Method) · 0.95

Calls 11

get_task_ref (Method) · 0.95
Timer (Class) · 0.90
move_to_cuda (Function) · 0.90
training_step (Method) · 0.80
scale (Method) · 0.80
append (Method) · 0.80
backward (Method) · 0.45
update (Method) · 0.45
step (Method) · 0.45
on_after_optimization (Method) · 0.45

Tested by

no test coverage detected