MCPcopy
hub / github.com/facebookresearch/MetaCLIP / backward

Function backward

src/training/train.py:79–109  ·  view source on GitHub ↗
(args, total_loss, scaler, optimizer, model)

Source from the content-addressed store, hash-verified

77
78
def backward(args, total_loss, scaler, optimizer, model):
    """Backpropagate ``total_loss`` and apply one optimizer step.

    If every element of ``total_loss`` is finite, gradients are computed,
    optionally clipped, and the optimizer is stepped. When ``scaler`` is
    given, the loss is scaled before ``backward()``, gradients are unscaled
    before clipping, and the step goes through ``scaler.step`` followed by
    ``scaler.update``. Afterwards, if the unwrapped model has a
    ``logit_scale`` attribute, it is clamped in place to ``[0, ln(100)]``.

    If ``total_loss`` contains a non-finite value, nothing is backpropagated:
    a warning is logged and the process exits with status 1.

    Args:
        args: Run configuration; only ``args.norm_gradient_clip`` is read
            here (``None`` disables gradient clipping, otherwise it is the
            maximum L2 norm).
        total_loss: Loss tensor to differentiate.
        scaler: Object exposing ``scale``/``unscale_``/``step``/``update``
            (presumably a torch ``GradScaler`` -- confirm against callers),
            or ``None`` to step the optimizer directly.
        optimizer: Optimizer that is stepped.
        model: Model whose parameters are clipped; it is passed through
            ``unwrap_model`` to reach ``logit_scale``.

    Raises:
        SystemExit: With code 1 when ``total_loss`` is not finite.
    """
    if not torch.isfinite(total_loss).all():
        # `logging.warn` is a deprecated alias; `logging.warning` emits the
        # same record without raising a DeprecationWarning.
        logging.warning(f"Loss is {total_loss}, skip back prop.")
        import sys

        sys.exit(1)  # protect the checkpoint for debugging.

    clip_norm = args.norm_gradient_clip

    if scaler is not None:
        scaler.scale(total_loss).backward()
        if clip_norm is not None:
            # Gradients must be unscaled before their norm is measured.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type=2.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        total_loss.backward()
        if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type=2.0)
        optimizer.step()

    # Note: we clamp to 4.6052 = ln(100), as in the original paper.
    if hasattr(unwrap_model(model), "logit_scale"):
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))
110
111
112def train_one_epoch_ex(args, model, data, start_step, total_steps, optimizer, scaler, scheduler, tb_writer=None):

Callers 2

train_altogetherFunction · 0.90
train_one_epoch_exFunction · 0.85

Calls 2

unwrap_modelFunction · 0.90
updateMethod · 0.80

Tested by

no test coverage detected