Build the model.
(args)
| 39 | from utils import print_rank_0 |
| 40 | |
def get_model(args):
    """Build the GPT2 model, place it on the current GPU and wrap it.

    Args:
        args: Object exposing the attributes read here: ``num_layers``,
            ``vocab_size``, ``hidden_size``, ``num_attention_heads``,
            ``hidden_dropout``, ``attention_dropout``,
            ``max_position_embeddings``, ``checkpoint_activations``,
            ``checkpoint_num_layers`` and ``fp16``.

    Returns:
        The model wrapped in ``DDP``; when ``args.fp16`` is truthy the
        model is wrapped in ``FP16_Module`` before the ``DDP`` wrap.
    """
    print_rank_0('building GPT2 model ...')

    # ``hidden_dropout`` feeds both the embedding and the output dropout.
    gpt2_kwargs = dict(
        num_layers=args.num_layers,
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        parallel_output=False,
    )
    net = GPT2Model(**gpt2_kwargs)

    # Only processes whose data-parallel rank is 0 report the parameter
    # count, tagged with their model-parallel rank.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        param_count = sum(p.nelement() for p in net.parameters())
        report = ' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, param_count)
        print(report, flush=True)

    # GPU allocation.
    net.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        net = FP16_Module(net)

    # Wrap model for distributed training.
    return DDP(net)
| 73 | |
| 74 | def setup_model(args): |
| 75 | """Setup model and optimizer.""" |
no test coverage detected