Main training program.
()
| 479 | return val_dataloader |
| 480 | |
| 481 | def main(): |
| 482 | """Main evaluation program.""" |
| 483 | |
| 484 | print('Evaluate GPT2 model') |
| 485 | |
| 486 | # Disable CuDNN. |
| 487 | torch.backends.cudnn.enabled = False |
| 488 | |
| 489 | # Timer. |
| 490 | timers = Timers() |
| 491 | |
| 492 | # Arguments. |
| 493 | args = get_args() |
| 494 | |
| 495 | # Pytorch distributed. |
| 496 | initialize_distributed(args) |
| 497 | |
| 498 | # Random seeds for reproducibility. |
| 499 | set_random_seed(args.seed) |
| 500 | |
| 501 | # Data stuff. |
| 502 | eval_data = get_eval_data(args) |
| 503 | |
| 504 | # Model, optimizer, and learning rate. |
| 505 | if args.eval_hf: |
| 506 | from pytorch_pretrained_bert import GPT2LMHeadModel |
| 507 | from pytorch_pretrained_bert import GPT2Model as HFGPT2Model |
| 508 | if args.num_layers == 24: |
| 509 | model_path = args.load |
| 510 | #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M' |
| 511 | hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True).cuda() |
| 512 | model = GPT2LMHeadModel(hfmodel.config) |
| 513 | model.transformer.load_state_dict(hfmodel.state_dict()) |
| 514 | model.cuda() |
| 515 | else: |
| 516 | model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights').cuda() |
| 517 | else: |
| 518 | if args.load_openai: |
| 519 | from utils import move_weights |
| 520 | model_path = args.load |
| 521 | args.load = None |
| 522 | model = setup_model(args) |
| 523 | from pytorch_pretrained_bert import GPT2LMHeadModel |
| 524 | from pytorch_pretrained_bert import GPT2Model as HFGPT2Model |
| 525 | |
| 526 | model_path = 'gpt2' |
| 527 | from_tf = False |
| 528 | print('loading openai weights') |
| 529 | model.cpu() |
| 530 | if args.num_layers == 24: |
| 531 | #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M' |
| 532 | hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True) |
| 533 | gpt2model = GPT2LMHeadModel(hfmodel.config) |
| 534 | gpt2model.transformer.load_state_dict(hfmodel.state_dict()) |
| 535 | gpt2model |
| 536 | else: |
| 537 | gpt2model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights') |
| 538 | model2fill = model |
no test coverage detected