(args, total_loss, scaler, optimizer, model)
| 77 | |
| 78 | |
def backward(args, total_loss, scaler, optimizer, model):
    """Back-propagate ``total_loss`` and apply one optimizer step.

    When every element of ``total_loss`` is finite, gradients are computed
    (through ``scaler`` if one is given), optionally clipped to an L2 norm of
    ``args.norm_gradient_clip``, and the optimizer is stepped. Afterwards, if
    the unwrapped model exposes ``logit_scale``, it is clamped in place to
    ``[0, ln(100)]``.

    When the loss contains a non-finite value, a warning is logged and the
    process exits with status 1; no backward pass or optimizer step happens.

    Args:
        args: Object providing ``norm_gradient_clip`` (``None`` disables
            gradient clipping).
        total_loss: Loss tensor to back-propagate.
        scaler: Gradient scaler exposing ``scale``/``unscale_``/``step``/
            ``update`` (presumably ``torch.cuda.amp.GradScaler`` -- confirm
            against the caller), or ``None`` to step the optimizer directly.
        optimizer: Optimizer whose parameters are updated.
        model: Model whose parameters receive gradients; passed through
            ``unwrap_model`` before ``logit_scale`` is looked up.

    Raises:
        SystemExit: With code 1 when ``total_loss`` is not entirely finite.
    """
    if not torch.isfinite(total_loss).all():
        # ``logging.warn`` is a deprecated alias; ``logging.warning`` emits
        # the same record without a DeprecationWarning.
        logging.warning(f"Loss is {total_loss}, skip back prop.")
        import sys

        sys.exit(1)  # protect the checkpoint for debugging.

    clip_norm = args.norm_gradient_clip

    if scaler is not None:
        scaler.scale(total_loss).backward()
        # Debugging aid (disabled): when args.world_size == 1, call
        # src.training.detect.detect_unused_parameters(model) here.
        if clip_norm is not None:
            # Gradients must be unscaled before their norm is measured.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type=2.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        total_loss.backward()
        # Debugging aids (disabled): detect_unused_parameters(model) when
        # args.world_size == 1, and detect_nan(model, optimizer).
        if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type=2.0)
        optimizer.step()

    # Note: we clamp to 4.6052 = ln(100), as in the original paper.
    core_model = unwrap_model(model)
    if hasattr(core_model, "logit_scale"):
        with torch.no_grad():
            core_model.logit_scale.clamp_(0, math.log(100))
| 110 | |
| 111 | |
| 112 | def train_one_epoch_ex(args, model, data, start_step, total_steps, optimizer, scaler, scheduler, tb_writer=None): |
no test coverage detected