Build the model.
(args)
| 49 | |
| 50 | |
def get_model(args):
    """Construct the BERT model, move it to the GPU and wrap it for training.

    Steps, in order:
      1. Instantiate ``BertModel(args)``.
      2. On data-parallel rank 0 only, print how many parameter elements
         this model-parallel rank holds.
      3. Move the model to the current CUDA device.
      4. If ``args.fp16`` is set, wrap the model in ``FP16_Module`` and then
         cast selected submodules back to fp32:
           - ``args.fp32_embedding``: word, position and token-type embeddings;
           - ``args.fp32_tokentypes``: only the token-type embeddings
             (overlaps with ``fp32_embedding``; casting twice is harmless);
           - ``args.fp32_layernorm``: every submodule whose qualified name
             contains ``'LayerNorm'``.
         These three flags are only consulted when ``args.fp16`` is set.
      5. Wrap the result in ``DDP``. When ``USE_TORCH_DDP`` is true, the
         current device and the data-parallel process group from ``mpu``
         are passed. Otherwise ``DDP`` is called with the model alone
         (presumably a non-torch DDP implementation bound to the same name;
         confirm at the import site).

    Args:
        args: configuration namespace. This function reads ``fp16``,
            ``fp32_embedding``, ``fp32_tokentypes`` and ``fp32_layernorm``;
            ``BertModel`` may read more.

    Returns:
        The ``DDP``-wrapped model.
    """
    print_rank_0('building BERT model ...')
    net = BertModel(args)

    # Only one replica per model-parallel group reports its size.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        param_count = sum(p.nelement() for p in net.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, param_count), flush=True)

    # Place the parameters on this process's GPU before any dtype changes.
    net.cuda(torch.cuda.current_device())

    if args.fp16:
        net = FP16_Module(net)
        # FP16_Module stores the wrapped model as `.module`, hence the
        # `net.module.model.bert...` path used below.
        if args.fp32_embedding:
            emb = net.module.model.bert.embeddings
            for table_name in ('word_embeddings',
                               'position_embeddings',
                               'token_type_embeddings'):
                getattr(emb, table_name).float()
        if args.fp32_tokentypes:
            net.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for qualified_name, submodule in net.named_modules():
                if 'LayerNorm' in qualified_name:
                    submodule.float()

    if not USE_TORCH_DDP:
        return DDP(net)

    device = torch.cuda.current_device()
    return DDP(net, device_ids=[device], output_device=device,
               process_group=mpu.get_data_parallel_group())
| 88 | |
| 89 | |
| 90 | def get_optimizer(model, args): |
no test coverage detected