()
| 36 | |
| 37 | |
| 38 | def main(): |
| 39 | # initialize distributed setting |
| 40 | parser = colossalai.legacy.get_default_parser() |
| 41 | parser.add_argument( |
| 42 | "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True |
| 43 | ) |
| 44 | args = parser.parse_args() |
| 45 | |
| 46 | # launch from torch |
| 47 | colossalai.legacy.launch_from_torch(config=args.config) |
| 48 | |
| 49 | # get logger |
| 50 | logger = get_dist_logger() |
| 51 | logger.info("initialized distributed environment", ranks=[0]) |
| 52 | |
| 53 | # create synthetic dataloaders |
| 54 | train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) |
| 55 | test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE) |
| 56 | |
| 57 | # build model |
| 58 | model = resnet18(num_classes=gpc.config.NUM_CLASSES) |
| 59 | |
| 60 | # create loss function |
| 61 | criterion = nn.CrossEntropyLoss() |
| 62 | |
| 63 | # create optimizer |
| 64 | if args.optimizer == "lars": |
| 65 | optim_cls = Lars |
| 66 | elif args.optimizer == "lamb": |
| 67 | optim_cls = Lamb |
| 68 | optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) |
| 69 | |
| 70 | # create lr scheduler |
| 71 | lr_scheduler = CosineAnnealingWarmupLR( |
| 72 | optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS |
| 73 | ) |
| 74 | |
| 75 | # initialize |
| 76 | engine, train_dataloader, test_dataloader, _ = colossalai.legacy.initialize( |
| 77 | model=model, |
| 78 | optimizer=optimizer, |
| 79 | criterion=criterion, |
| 80 | train_dataloader=train_dataloader, |
| 81 | test_dataloader=test_dataloader, |
| 82 | ) |
| 83 | |
| 84 | logger.info("Engine is built", ranks=[0]) |
| 85 | |
| 86 | for epoch in range(gpc.config.NUM_EPOCHS): |
| 87 | # training |
| 88 | engine.train() |
| 89 | data_iter = iter(train_dataloader) |
| 90 | |
| 91 | if gpc.get_global_rank() == 0: |
| 92 | description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) |
| 93 | progress = tqdm(range(len(train_dataloader)), desc=description) |
| 94 | else: |
| 95 | progress = range(len(train_dataloader)) |
no test coverage detected
searching dependent graphs…