MCPcopy
hub / github.com/zai-org/CogView / train_step

Function train_step

pretrain_gpt2.py:406–448  ·  view source on GitHub ↗

Single training step.

(data_iterator, model, optimizer, lr_scheduler,
               args, timers, mems)

Source from the content-addressed store, hash-verified

404 print(" ")
405
def train_step(data_iterator, model, optimizer, lr_scheduler,
               args, timers, mems):
    """Run forward/backward passes until one optimizer update completes.

    Every pass of the loop calls ``forward_step`` and ``backward_step`` and
    then steps either the DeepSpeed engine (``model.step()``, when
    ``args.deepspeed`` is set) or the plain ``optimizer``. With DeepSpeed the
    loop repeats until ``model.is_gradient_accumulation_boundary()`` is true;
    without it the loop body runs exactly once.

    Returns:
        A 5-tuple ``(loss, skipped_iter, mems, img_loss, txt_loss)``.
        ``loss`` is the value returned by ``backward_step``, or
        ``img_loss + txt_loss`` when the step is abandoned because that sum
        contains NaN or Inf. ``skipped_iter`` is 1 when the step was abandoned
        or when ``args.fp16`` is set and ``optimizer.overflow`` is truthy at
        the update, and 0 otherwise. ``mems``, ``img_loss`` and ``txt_loss``
        are the values produced by the last ``forward_step`` call.
    """
    while True:
        # Forward pass.
        timers('forward').start()
        lm_loss, mems, img_loss, txt_loss = forward_step(
            data_iterator, model, args, timers, mems)
        timers('forward').stop()

        # Abandon the step before backward if the forward losses are not finite.
        combined_loss = img_loss + txt_loss
        if combined_loss.isnan().any() or combined_loss.isinf().any():
            print('Skipping backward and optimizer step for nan or inf in forwarding!')
            return combined_loss, 1, mems, img_loss, txt_loss

        # Backward pass.
        timers('backward').start()
        lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
        timers('backward').stop()

        # Parameter update.
        timers('optimizer').start()
        if args.deepspeed:
            # Ask about the boundary *before* stepping; the engine is stepped
            # on every pass regardless of the answer.
            at_boundary = model.is_gradient_accumulation_boundary()
            model.step()
        else:
            at_boundary = True
            optimizer.step()

        skipped_iter = 0
        if at_boundary:
            # The learning rate only advances when the update was not an
            # fp16 overflow; otherwise the iteration is reported as skipped.
            if args.fp16 and optimizer.overflow:
                skipped_iter = 1
            else:
                lr_scheduler.step()
        timers('optimizer').stop()

        if at_boundary:
            return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss
449
450
451def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss):

Callers 1

trainFunction · 0.85

Calls 5

forward_stepFunction · 0.85
backward_stepFunction · 0.85
startMethod · 0.80
stopMethod · 0.80
stepMethod · 0.45

Tested by

no test coverage detected