MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / main

Function main

Megatron-LM/evaluate_gpt2.py:481–549  ·  view source on GitHub ↗

Main evaluation program for GPT2 (the source docstring reads "Main training program.").

()

Source from the content-addressed store, hash-verified

479 return val_dataloader
480
481def main():
482 """Main training program."""
483
484 print('Evaluate GPT2 model')
485
486 # Disable CuDNN.
487 torch.backends.cudnn.enabled = False
488
489 # Timer.
490 timers = Timers()
491
492 # Arguments.
493 args = get_args()
494
495 # Pytorch distributed.
496 initialize_distributed(args)
497
498 # Random seeds for reproducability.
499 set_random_seed(args.seed)
500
501 # Data stuff.
502 eval_data = get_eval_data(args)
503
504 # Model, optimizer, and learning rate.
505 if args.eval_hf:
506 from pytorch_pretrained_bert import GPT2LMHeadModel
507 from pytorch_pretrained_bert import GPT2Model as HFGPT2Model
508 if args.num_layers == 24:
509 model_path = args.load
510 #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M'
511 hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True).cuda()
512 model = GPT2LMHeadModel(hfmodel.config)
513 model.transformer.load_state_dict(hfmodel.state_dict())
514 model.cuda()
515 else:
516 model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights').cuda()
517 else:
518 if args.load_openai:
519 from utils import move_weights
520 model_path = args.load
521 args.load = None
522 model = setup_model(args)
523 from pytorch_pretrained_bert import GPT2LMHeadModel
524 from pytorch_pretrained_bert import GPT2Model as HFGPT2Model
525
526 model_path = 'gpt2'
527 from_tf = False
528 print('loading openai weights')
529 model.cpu()
530 if args.num_layers == 24:
531 #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M'
532 hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True)
533 gpt2model = GPT2LMHeadModel(hfmodel.config)
534 gpt2model.transformer.load_state_dict(hfmodel.state_dict())
535 gpt2model
536 else:
537 gpt2model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights')
538 model2fill = model

Callers 1

evaluate_gpt2.py · File · 0.70

Calls 11

Timers · Class · 0.90
get_args · Function · 0.90
move_weights · Function · 0.90
get_eval_data · Function · 0.85
initialize_distributed · Function · 0.70
set_random_seed · Function · 0.70
setup_model · Function · 0.70
from_pretrained · Method · 0.45
load_state_dict · Method · 0.45
state_dict · Method · 0.45

Tested by

no test coverage detected