github.com/hpcaitech/ColossalAI (branch main)

Function main

examples/tutorial/large_batch_optimizer/train.py, lines 38–100 (signature: main())

```python
def main():
    # initialize distributed setting
    parser = colossalai.legacy.get_default_parser()
    parser.add_argument(
        "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True
    )
    args = parser.parse_args()

    # launch from torch
    colossalai.legacy.launch_from_torch(config=args.config)

    # get logger
    logger = get_dist_logger()
    logger.info("initialized distributed environment", ranks=[0])

    # create synthetic dataloaders
    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

    # build model
    model = resnet18(num_classes=gpc.config.NUM_CLASSES)

    # create loss function
    criterion = nn.CrossEntropyLoss()

    # create optimizer
    if args.optimizer == "lars":
        optim_cls = Lars
    elif args.optimizer == "lamb":
        optim_cls = Lamb
    optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    # create lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS
    )

    # initialize
    engine, train_dataloader, test_dataloader, _ = colossalai.legacy.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )

    logger.info("Engine is built", ranks=[0])

    for epoch in range(gpc.config.NUM_EPOCHS):
        # training
        engine.train()
        data_iter = iter(train_dataloader)

        if gpc.get_global_rank() == 0:
            description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS)
            progress = tqdm(range(len(train_dataloader)), desc=description)
        else:
            progress = range(len(train_dataloader))
        # ...
```
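The training loop in this listing pulls batches through `data_iter = iter(train_dataloader)` and sizes the progress bar with `range(len(train_dataloader))`, so the loader must support both `iter()` and `len()`. `DummyDataloader` is defined elsewhere in train.py and its body is not shown here. The class below, `SyntheticLoader`, is a hypothetical pure-Python stand-in (random feature lists instead of image tensors) that sketches that contract under this assumption; it is not the tutorial's implementation.

```python
import random
from typing import Iterator, List, Tuple

# One batch: a list of feature vectors and a list of integer labels.
Batch = Tuple[List[List[float]], List[int]]


class SyntheticLoader:
    """Hypothetical stand-in for DummyDataloader: `length` random batches per epoch."""

    def __init__(self, length: int, batch_size: int, num_features: int = 4,
                 num_classes: int = 10, seed: int = 0) -> None:
        self.length = length
        self.batch_size = batch_size
        self.num_features = num_features
        self.num_classes = num_classes
        self.seed = seed

    def __len__(self) -> int:
        # Batches per epoch; range(len(loader)) in a training loop counts these.
        return self.length

    def __iter__(self) -> Iterator[Batch]:
        # A fresh generator (same seed) on every iter() call, so each epoch
        # restarts from the first batch.
        rng = random.Random(self.seed)
        for _ in range(self.length):
            inputs = [[rng.random() for _ in range(self.num_features)]
                      for _ in range(self.batch_size)]
            labels = [rng.randrange(self.num_classes) for _ in range(self.batch_size)]
            yield inputs, labels
```

Driving the loop with `range(len(loader))` and calling `next(data_iter)` once per iteration consumes exactly one epoch of batches, which is why the progress bar and the data iterator stay in step.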
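The `--optimizer` flag selects between `Lars` and `Lamb`, both imported outside this listing. What the two methods share is a layer-wise trust ratio that rescales the global learning rate separately for each parameter tensor, so a layer's step size tracks the size of its weights rather than the raw gradient scale. The functions below are a generic pure-Python sketch of one update step for a single layer, following the LARS and LAMB update rules as commonly described; they are not ColossalAI's `Lars`/`Lamb` implementations, and the defaults (`eta`, `eps`, the betas) are illustrative assumptions.

```python
import math
from typing import List, Tuple

Vector = List[float]


def l2_norm(xs: Vector) -> float:
    """Euclidean norm of one layer's flattened parameters or gradients."""
    return math.sqrt(sum(x * x for x in xs))


def lars_step(w: Vector, g: Vector, buf: Vector, lr: float = 0.1, momentum: float = 0.9,
              weight_decay: float = 1e-4, eta: float = 1e-3) -> Tuple[Vector, Vector]:
    """One LARS step for a single layer; returns (new_weights, new_momentum_buffer)."""
    w_norm, g_norm = l2_norm(w), l2_norm(g)
    # Layer-wise trust ratio; fall back to 1.0 when either norm is zero.
    if w_norm > 0 and g_norm > 0:
        trust = eta * w_norm / (g_norm + weight_decay * w_norm)
    else:
        trust = 1.0
    # Momentum on the trust-scaled, weight-decayed gradient.
    buf = [momentum * b + lr * trust * (gi + weight_decay * wi)
           for b, gi, wi in zip(buf, g, w)]
    return [wi - bi for wi, bi in zip(w, buf)], buf


def lamb_step(w: Vector, g: Vector, m: Vector, v: Vector, step: int, lr: float = 1e-3,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-6,
              weight_decay: float = 0.01) -> Tuple[Vector, Vector, Vector]:
    """One LAMB step (step counts from 1); returns (new_weights, new_m, new_v)."""
    # Adam-style first and second moments with bias correction.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    # Adam direction plus decoupled weight decay.
    update = [mh / (math.sqrt(vh) + eps) + weight_decay * wi
              for mh, vh, wi in zip(m_hat, v_hat, w)]
    # Trust ratio ||w|| / ||update||: the layer moves by lr * ||w|| in L2 norm.
    w_norm, u_norm = l2_norm(w), l2_norm(update)
    trust = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return [wi - lr * trust * ui for wi, ui in zip(w, update)], m, v
```

One consequence worth noting: on the first LARS step (zero momentum buffer, no weight decay), multiplying the gradient by any positive constant leaves the update unchanged, because the trust ratio divides the scale back out.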
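`CosineAnnealingWarmupLR` combines a warmup phase with cosine annealing. The function below is a generic sketch of that shape (linear warmup over `warmup_steps`, then cosine decay from `base_lr` to `eta_min` over the remaining steps); the exact warmup formula and step indexing in ColossalAI's scheduler may differ, so treat the values as illustrative. Because the listing passes `gpc.config.NUM_EPOCHS` and `gpc.config.WARMUP_EPOCHS` as `total_steps` and `warmup_steps`, the schedule spans the whole run only if `lr_scheduler.step()` is called once per epoch.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float, total_steps: int,
                     warmup_steps: int, eta_min: float = 0.0) -> float:
    """Learning rate at 0-based `step`: linear warmup, then cosine decay (sketch)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp: base_lr / warmup_steps, 2 * base_lr / warmup_steps, ..., base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine annealing over the remaining (total_steps - warmup_steps) steps,
    # clamped so the rate stays at eta_min after the schedule ends.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))
```

For example, with `base_lr=0.1`, `total_steps=10` and `warmup_steps=2`, the first two values are 0.05 and 0.1, after which the rate decays along a half cosine and reaches `eta_min` at step 10.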

Callers (1)

train.py · File · 0.70

Calls (11)

get_dist_logger · Function · 0.90
execute_schedule · Method · 0.80
DummyDataloader · Class · 0.70
info · Method · 0.45
parameters · Method · 0.45
initialize · Method · 0.45
train · Method · 0.45
get_global_rank · Method · 0.45
zero_grad · Method · 0.45
step · Method · 0.45

Tested by

no test coverage detected
