()
| 14 | |
| 15 | |
| 16 | def main(): |
| 17 | torch.set_grad_enabled(False) |
| 18 | # ====================================================== |
| 19 | # 1. configs & runtime variables |
| 20 | # ====================================================== |
| 21 | # == parse configs == |
| 22 | cfg = parse_configs(training=False) |
| 23 | |
| 24 | # == device and dtype == |
| 25 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 26 | cfg_dtype = cfg.get("dtype", "bf16") |
| 27 | assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" |
| 28 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 29 | |
| 30 | # == colossalai init distributed training == |
| 31 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 32 | cfg_dtype = cfg.get("dtype", "fp32") |
| 33 | assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" |
| 34 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 35 | torch.backends.cuda.matmul.allow_tf32 = True |
| 36 | torch.backends.cudnn.allow_tf32 = True |
| 37 | |
| 38 | colossalai.launch_from_torch({}) |
| 39 | set_data_parallel_group(dist.group.WORLD) |
| 40 | |
| 41 | # == init logger, tensorboard & wandb == |
| 42 | logger = create_logger() |
| 43 | logger.info("Configuration:\n %s", pformat(cfg.to_dict())) |
| 44 | |
| 45 | # ====================================================== |
| 46 | # 2. build dataset and dataloader |
| 47 | # ====================================================== |
| 48 | logger.info("Building dataset...") |
| 49 | # == build dataset == |
| 50 | dataset = build_module(cfg.dataset, DATASETS) |
| 51 | logger.info("Dataset contains %s samples.", len(dataset)) |
| 52 | |
| 53 | # == build dataloader == |
| 54 | dataloader_args = dict( |
| 55 | dataset=dataset, |
| 56 | batch_size=cfg.get("batch_size", None), |
| 57 | num_workers=cfg.get("num_workers", 4), |
| 58 | seed=cfg.get("seed", 1024), |
| 59 | shuffle=True, |
| 60 | drop_last=True, |
| 61 | pin_memory=True, |
| 62 | process_group=get_data_parallel_group(), |
| 63 | ) |
| 64 | dataloader, _ = prepare_dataloader( |
| 65 | bucket_config=cfg.get("bucket_config", None), |
| 66 | num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1), |
| 67 | **dataloader_args, |
| 68 | ) |
| 69 | num_steps_per_epoch = len(dataloader) |
| 70 | |
| 71 | # ====================================================== |
| 72 | # 3. build model |
| 73 | # ====================================================== |
no test coverage detected