github.com/karpathy/minGPT / batch_end_callback

Function batch_end_callback

projects/adder/adder.py:181–202
Signature: batch_end_callback(trainer)

Source from the content-addressed store, hash-verified

```python
import os

import torch

# Excerpt from projects/adder/adder.py. config, model, trainer and eval_split
# are defined earlier in the same script and are not shown here.

# iteration callback
top_score = 0  # best (train + test) score seen so far

def batch_end_callback(trainer):
    global top_score

    # log iteration time and training loss every 10 iterations
    if trainer.iter_num % 10 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

    # every 500 iterations, evaluate and possibly checkpoint
    if trainer.iter_num % 500 == 0:
        # evaluate both the train and test score; for ndigit <= 2 we can afford
        # the whole train set, otherwise cap the train evaluation at 5 batches
        train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit]
        model.eval()
        with torch.no_grad():
            train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
            test_score = eval_split(trainer, 'test', max_batches=None)
        score = train_score + test_score
        # save the model if this is the best score we've seen so far
        if score > top_score:
            top_score = score
            print(f"saving model with new top score of {score}")
            ckpt_path = os.path.join(config.system.work_dir, "model.pt")
            torch.save(model.state_dict(), ckpt_path)
        # revert model to training mode
        model.train()

trainer.set_callback('on_batch_end', batch_end_callback)
```
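The checkpointing logic above is a periodic evaluate-and-keep-best pattern: evaluate every 500 iterations (including iteration 0, since 0 % 500 == 0) and save only when the combined score strictly improves. The sketch below reproduces that control flow without PyTorch so it can run in isolation. FakeTrainer, make_batch_end_callback, the evaluate function and the in-memory saved list are hypothetical stand-ins for the script's trainer, eval_split and torch.save. It also swaps the module-level global top_score for a closure with nonlocal, which keeps the best score private to one callback.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class FakeTrainer:
    """Hypothetical stand-in exposing only iter_num."""
    iter_num: int = 0


def make_batch_end_callback(
    evaluate: Callable[[int], int],
    eval_every: int = 500,
) -> Tuple[Callable[[FakeTrainer], None], List[Tuple[int, int]]]:
    """Build a callback that evaluates every eval_every iterations and records
    a 'checkpoint' (iter_num, score) whenever the score strictly improves."""
    saved: List[Tuple[int, int]] = []  # stands in for torch.save(...) calls
    top_score = 0  # same starting value as the original script

    def batch_end_callback(trainer: FakeTrainer) -> None:
        nonlocal top_score
        if trainer.iter_num % eval_every != 0:
            return
        score = evaluate(trainer.iter_num)  # stands in for train + test eval_split
        if score > top_score:  # ties do not overwrite the checkpoint
            top_score = score
            saved.append((trainer.iter_num, score))

    return batch_end_callback, saved


# Usage: the score rises, dips, ties, then rises again.
scores = {0: 10, 500: 40, 1000: 35, 1500: 40, 2000: 90}
callback, saved = make_batch_end_callback(lambda it: scores[it])
trainer = FakeTrainer()
for it in range(2001):
    trainer.iter_num = it
    callback(trainer)
print(saved)  # [(0, 10), (500, 40), (2000, 90)]
```

Because the comparison is strict, the tie at iteration 1500 keeps the earlier checkpoint, matching the score > top_score test in the original.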

Callers

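nothing calls this directly; the last line of the excerpt registers it as the trainer's on_batch_end callback, so the trainer invokes it at the end of each batch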
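The indirect call path works through a small event registry: set_callback stores a function under an event name, and the training loop later fires every function registered for that name, passing the trainer itself. The sketch below shows that dispatch pattern with a hypothetical MiniTrainer; its callbacks dictionary, trigger_callbacks and run methods are assumptions modeled loosely on minGPT's Trainer, not a copy of it.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class MiniTrainer:
    """Hypothetical minimal trainer that dispatches named events to callbacks."""

    def __init__(self, max_iters: int) -> None:
        self.max_iters = max_iters
        self.iter_num = 0
        self.callbacks: DefaultDict[str, List[Callable[["MiniTrainer"], None]]] = defaultdict(list)

    def set_callback(self, onevent: str, callback: Callable[["MiniTrainer"], None]) -> None:
        # Replace any callbacks already registered for this event.
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str) -> None:
        # Each callback receives the trainer, so it can read iter_num, loss, etc.
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def run(self) -> None:
        while self.iter_num < self.max_iters:
            # (forward pass, backward pass and optimizer step would go here)
            self.trigger_callbacks("on_batch_end")
            self.iter_num += 1


# Usage: a callback that records every 10th iteration.
seen: List[int] = []


def log_every_10(t: MiniTrainer) -> None:
    if t.iter_num % 10 == 0:
        seen.append(t.iter_num)


trainer = MiniTrainer(max_iters=25)
trainer.set_callback("on_batch_end", log_every_10)
trainer.run()
print(seen)  # [0, 10, 20]
```

Since the event fires before iter_num is incremented, the first call sees iteration 0, which is why a "% 10 == 0" check also fires on the very first batch.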

Calls 1

eval_split · Function · 0.85
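batch_end_callback passes max_batches=None for the test split and, through the lookup {1: None, 2: None, 3: 5}[config.data.ndigit], either None or 5 for the train split; the source comment implies None means "evaluate the whole split". The sketch below shows one plausible way an evaluation loop could honor such a cap. count_correct, train_max_batches and the (predictions, targets) batch format are hypothetical and are not eval_split's actual code.

```python
from itertools import islice
from typing import Iterable, List, Optional, Sequence, Tuple

Batch = Tuple[Sequence[int], Sequence[int]]  # hypothetical (predictions, targets)


def count_correct(batches: Iterable[Batch], max_batches: Optional[int] = None) -> int:
    """Count exact matches over at most max_batches batches (None = all)."""
    total = 0
    # islice(iterable, None) yields every item, so None means "no cap".
    for preds, targets in islice(batches, max_batches):
        total += sum(int(p == t) for p, t in zip(preds, targets))
    return total


def train_max_batches(ndigit: int) -> Optional[int]:
    # Same lookup as the original: whole train set for 1- or 2-digit problems,
    # only 5 batches for 3-digit problems (any other ndigit raises KeyError).
    return {1: None, 2: None, 3: 5}[ndigit]


# Usage: 8 batches, each with 2 correct predictions.
batches: List[Batch] = [([1, 2], [1, 2])] * 8
print(count_correct(batches, train_max_batches(2)))  # 16
print(count_correct(batches, train_max_batches(3)))  # 10
```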

Tested by

no test coverage detected