hub / github.com/hpcaitech/ColossalAI / main

Function main()

examples/tutorial/hybrid_parallel/train.py:42–136


def main():
    # launch from torch
    parser = colossalai.legacy.get_default_parser()
    args = parser.parse_args()
    colossalai.legacy.launch_from_torch(config=args.config)

    # get logger
    logger = get_dist_logger()
    logger.info("initialized distributed environment", ranks=[0])

    if hasattr(gpc.config, "LOG_PATH"):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    use_pipeline = is_using_pp()

    # create model
    model_kwargs = dict(
        img_size=gpc.config.IMG_SIZE,
        patch_size=gpc.config.PATCH_SIZE,
        hidden_size=gpc.config.HIDDEN_SIZE,
        depth=gpc.config.DEPTH,
        num_heads=gpc.config.NUM_HEADS,
        mlp_ratio=gpc.config.MLP_RATIO,
        num_classes=10,
        init_method="jax",
        checkpoint=gpc.config.CHECKPOINT,
    )

    if use_pipeline:
        pipelinable = PipelinableContext()
        with pipelinable:
            model = _create_vit_model(**model_kwargs)
        pipelinable.to_layer_list()
        pipelinable.policy = "uniform"
        model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
    else:
        model = _create_vit_model(**model_kwargs)

    # count number of parameters
    total_numel = 0
    for p in model.parameters():
        total_numel += p.numel()
    if not gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_stage = 0
    else:
        pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
    logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")

    # use synthetic dataset
    # we train for 10 steps and eval for 5 steps per epoch
    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

    # create loss function
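
The pipeline branch of `main` sets `pipelinable.policy = "uniform"` and then asks `partition` for the slice of the model that belongs to the local pipeline rank. As a rough illustration of what an even split of a layer list across stages can look like, here is a minimal pure-Python sketch. The helper `uniform_partition` and the `block_i` layer names are hypothetical and are not taken from ColossalAI; the library's actual policy may assign layers differently.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def uniform_partition(layers: Sequence[T], num_stages: int) -> List[List[T]]:
    """Split `layers` into `num_stages` contiguous chunks of near-equal size.

    Hypothetical helper for illustration only; not ColossalAI's implementation.
    """
    if num_stages <= 0:
        raise ValueError("num_stages must be positive")
    base, extra = divmod(len(layers), num_stages)
    chunks: List[List[T]] = []
    start = 0
    for stage in range(num_stages):
        # The first `extra` stages each take one additional layer.
        size = base + (1 if stage < extra else 0)
        chunks.append(list(layers[start:start + size]))
        start += size
    return chunks


if __name__ == "__main__":
    # Assumed layer names, standing in for the entries of a layer list.
    layer_names = [f"block_{i}" for i in range(10)]
    for rank, chunk in enumerate(uniform_partition(layer_names, 4)):
        print(rank, chunk)
```

In this sketch each pipeline rank would keep only `chunks[rank]`, which mirrors how `main` passes the local pipeline rank to `partition` and then counts parameters per stage.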

Callers 1

train.py · File · 0.70

Calls 15

to_layer_list · Method · 0.95
partition · Method · 0.95
get_dist_logger · Function · 0.90
is_using_pp · Function · 0.90
PipelinableContext · Class · 0.90
CrossEntropyLoss · Class · 0.90
log_to_file · Method · 0.80
get_local_rank · Method · 0.80
execute_schedule · Method · 0.80
DummyDataloader · Class · 0.70
info · Method · 0.45

Tested by

no test coverage detected
