MCPcopy
hub / github.com/hpcaitech/ColossalAI / main

Function main

examples/tutorial/sequence_parallel/train.py:48–230  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

46
47
48def main():
49 # initialize
50 parse_args()
51 colossalai.legacy.launch_from_torch(config="./config.py", seed=1234, backend="nccl")
52
53 logger = get_dist_logger()
54
55 # build synthetic dataloader
56 BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
57 VOCAB_SIZE = 30528
58 trainloader = DummyDataloader(
59 batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
60 )
61 validloader = DummyDataloader(
62 batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
63 )
64
65 logger.info("Dataloaders are built", ranks=[0])
66
67 # build model
68 if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE:
69 is_naive_fp16 = True
70 else:
71 is_naive_fp16 = False
72
73 use_pipeline = is_using_pp()
74 kwargs = dict(
75 vocab_size=VOCAB_SIZE,
76 hidden_size=gpc.config.HIDDEN_SIZE,
77 max_sequence_length=gpc.config.SEQ_LENGTH,
78 num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
79 convert_fp16_to_fp32_in_softmax=True,
80 is_naive_fp16=is_naive_fp16,
81 add_binary_head=gpc.config.ADD_BINARY_HEAD,
82 )
83
84 if use_pipeline:
85 model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
86 else:
87 model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs)
88
89 model = model.half()
90 model.reset_parameters()
91 logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0])
92
93 total_numel = 0
94 for p in model.parameters():
95 total_numel += p.numel()
96 logger.info(f"This model has {total_numel} parameters")
97
98 # build criterion
99 criterion = BertLoss()
100 logger.info("Criterion is built", ranks=[0])
101
102 # layernorm and bias has no weight decay
103 weight_decay_params = {"params": []}
104 no_weight_decay_params = {"params": [], "weight_decay": 0.0}
105 for module_ in model.modules():

Callers 1

train.py — File · 0.70

Calls 15

reset_parameters — Method · 0.95
start — Method · 0.95
step — Method · 0.95
stop — Method · 0.95
get_dist_logger — Function · 0.90
DummyDataloader — Class · 0.90
is_using_pp — Function · 0.90
build_pipeline_bert — Function · 0.90
BertForPretrain — Class · 0.90
BertLoss — Class · 0.90
FusedAdam — Class · 0.90
AnnealingLR — Class · 0.90

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…