()
| 46 | |
| 47 | |
| 48 | def main(): |
| 49 | # initialize |
| 50 | parse_args() |
| 51 | colossalai.legacy.launch_from_torch(config="./config.py", seed=1234, backend="nccl") |
| 52 | |
| 53 | logger = get_dist_logger() |
| 54 | |
| 55 | # build synthetic dataloader |
| 56 | BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) |
| 57 | VOCAB_SIZE = 30528 |
| 58 | trainloader = DummyDataloader( |
| 59 | batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH |
| 60 | ) |
| 61 | validloader = DummyDataloader( |
| 62 | batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH |
| 63 | ) |
| 64 | |
| 65 | logger.info("Dataloaders are built", ranks=[0]) |
| 66 | |
| 67 | # build model |
| 68 | if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE: |
| 69 | is_naive_fp16 = True |
| 70 | else: |
| 71 | is_naive_fp16 = False |
| 72 | |
| 73 | use_pipeline = is_using_pp() |
| 74 | kwargs = dict( |
| 75 | vocab_size=VOCAB_SIZE, |
| 76 | hidden_size=gpc.config.HIDDEN_SIZE, |
| 77 | max_sequence_length=gpc.config.SEQ_LENGTH, |
| 78 | num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, |
| 79 | convert_fp16_to_fp32_in_softmax=True, |
| 80 | is_naive_fp16=is_naive_fp16, |
| 81 | add_binary_head=gpc.config.ADD_BINARY_HEAD, |
| 82 | ) |
| 83 | |
| 84 | if use_pipeline: |
| 85 | model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) |
| 86 | else: |
| 87 | model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs) |
| 88 | |
| 89 | model = model.half() |
| 90 | model.reset_parameters() |
| 91 | logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0]) |
| 92 | |
| 93 | total_numel = 0 |
| 94 | for p in model.parameters(): |
| 95 | total_numel += p.numel() |
| 96 | logger.info(f"This model has {total_numel} parameters") |
| 97 | |
| 98 | # build criterion |
| 99 | criterion = BertLoss() |
| 100 | logger.info("Criterion is built", ranks=[0]) |
| 101 | |
| 102 | # layernorm and bias have no weight decay |
| 103 | weight_decay_params = {"params": []} |
| 104 | no_weight_decay_params = {"params": [], "weight_decay": 0.0} |
| 105 | for module_ in model.modules(): |
no test coverage detected
searching dependent graphs…