(args_others, mp_size=1)
| 22 | |
| 23 | |
def get_gpt2_model(args_others, mp_size=1):
    """Build a Megatron GPT-2 model wrapped in torch DistributedDataParallel.

    Args:
        args_others: Megatron argument overrides, merged over the default
            vocab file, merge file and tokenizer type (overrides win).
        mp_size: Model-parallel size, forwarded as ``--model-parallel-size``.

    Returns:
        The ``GPT2Model`` moved to the accelerator's device and wrapped in
        ``torch.nn.parallel.DistributedDataParallel``. The wrapper uses
        Megatron's data-parallel process group.

    Note:
        Every call appends CLI flags to the process-global ``sys.argv``.
        This function does not remove them afterwards.
    """
    from megatron.model import GPT2Model
    from megatron.initialize import initialize_megatron

    defaults = {
        'vocab_file': get_test_path('gpt2-vocab.json'),
        'merge_file': get_test_path('gpt2-merges.txt'),
        'tokenizer_type': 'GPT2BPETokenizer',
    }
    defaults.update(args_others)

    # Pin "make-vocab-size-divisible-by" to 1 so the word-embedding size does
    # not change when checkpoint resizing is tested.
    cli_flags = [
        '--model-parallel-size', str(mp_size),
        '--make-vocab-size-divisible-by', '1',
    ]
    sys.argv.extend(cli_flags)

    initialize_megatron(args_defaults=defaults, ignore_unknown_args=True)

    gpt2 = GPT2Model(num_tokentypes=0, parallel_output=False)
    gpt2.to(get_accelerator().device_name())

    from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
    from megatron import mpu

    device = get_accelerator().current_device_name()
    return torchDDP(
        gpt2,
        device_ids=[device],
        output_device=device,
        process_group=mpu.get_data_parallel_group(),
    )
| 48 | |
| 49 | |
| 50 | class MockGPT2ModelPipe(PipelineModule): |
searching dependent graphs…