(self, batch_idx, batch)
| 367 | task_ref.on_train_end() |
| 368 | |
| 369 | def run_training_batch(self, batch_idx, batch): |
| 370 | if batch is None: |
| 371 | return {} |
| 372 | all_progress_bar_metrics = [] |
| 373 | all_log_metrics = [] |
| 374 | task_ref = self.get_task_ref() |
| 375 | for opt_idx, optimizer in enumerate(self.optimizers): |
| 376 | if optimizer is None: |
| 377 | continue |
| 378 | # make sure only the gradients of the current optimizer's parameters are calculated |
| 379 | # in the training step to prevent dangling gradients in multiple-optimizer setup. |
| 380 | if len(self.optimizers) > 1: |
| 381 | for k, param in task_ref.named_parameters(): |
| 382 | param.requires_grad = False |
| 383 | for group in optimizer.param_groups: |
| 384 | for param in group['params']: |
| 385 | param.requires_grad = True |
| 386 | |
| 387 | # forward pass |
| 388 | with Timer("forward_training_step", enable=self.debug): |
| 389 | with autocast(enabled=self.amp): |
| 390 | if self.on_gpu: |
| 391 | batch = move_to_cuda(copy.copy(batch), self.root_gpu) |
| 392 | args = [batch, batch_idx, opt_idx] |
| 393 | if self.use_ddp: |
| 394 | output = self.task(*args) |
| 395 | else: |
| 396 | output = task_ref.training_step(*args) |
| 397 | loss = output['loss'] |
| 398 | if loss is None: |
| 399 | continue |
| 400 | progress_bar_metrics = output['progress_bar'] |
| 401 | log_metrics = output['tb_log'] |
| 402 | # accumulate loss |
| 403 | loss = loss / self.accumulate_grad_batches |
| 404 | |
| 405 | # backward pass |
| 406 | with Timer("backward_training_step", enable=self.debug): |
| 407 | if loss.requires_grad: |
| 408 | if self.amp: |
| 409 | self.amp_scalar.scale(loss).backward() |
| 410 | else: |
| 411 | loss.backward() |
| 412 | |
| 413 | # track progress bar metrics |
| 414 | all_log_metrics.append(log_metrics) |
| 415 | all_progress_bar_metrics.append(progress_bar_metrics) |
| 416 | |
| 417 | if loss is None: |
| 418 | continue |
| 419 | |
| 420 | # nan grads |
| 421 | with Timer("checkNan_training_step", enable=self.debug): |
| 422 | has_nan_grad = False |
| 423 | nan_params_names = [] |
| 424 | if self.print_nan_grads: |
| 425 | for name, param in task_ref.named_parameters(): |
| 426 | if (param.grad is not None) and torch.isnan(param.grad.float()).any(): |
no test coverage detected