()
| 32 | |
| 33 | |
| 34 | def main(): |
| 35 | # ====================================================== |
| 36 | # 1. configs & runtime variables |
| 37 | # ====================================================== |
| 38 | # == parse configs == |
| 39 | cfg = parse_configs(training=True) |
| 40 | |
| 41 | # == device and dtype == |
| 42 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 43 | cfg_dtype = cfg.get("dtype", "bf16") |
| 44 | assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" |
| 45 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 46 | |
| 47 | # == colossalai init distributed training == |
| 48 | # NOTE: A very large timeout is set to prevent some processes from exiting early |
| 49 | dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) |
| 50 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) |
| 51 | set_seed(cfg.get("seed", 1024)) |
| 52 | coordinator = DistCoordinator() |
| 53 | device = get_current_device() |
| 54 | |
| 55 | # == init exp_dir == |
| 56 | exp_name, exp_dir = define_experiment_workspace(cfg) |
| 57 | coordinator.block_all() |
| 58 | if coordinator.is_master(): |
| 59 | os.makedirs(exp_dir, exist_ok=True) |
| 60 | save_training_config(cfg.to_dict(), exp_dir) |
| 61 | coordinator.block_all() |
| 62 | |
| 63 | # == init logger, tensorboard & wandb == |
| 64 | logger = create_logger(exp_dir) |
| 65 | logger.info("Experiment directory created at %s", exp_dir) |
| 66 | logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) |
| 67 | if coordinator.is_master(): |
| 68 | tb_writer = create_tensorboard_writer(exp_dir) |
| 69 | if cfg.get("wandb", False): |
| 70 | wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb") |
| 71 | |
| 72 | # == init ColossalAI booster == |
| 73 | plugin = create_colossalai_plugin( |
| 74 | plugin=cfg.get("plugin", "zero2"), |
| 75 | dtype=cfg_dtype, |
| 76 | grad_clip=cfg.get("grad_clip", 0), |
| 77 | sp_size=cfg.get("sp_size", 1), |
| 78 | ) |
| 79 | booster = Booster(plugin=plugin) |
| 80 | |
| 81 | # ====================================================== |
| 82 | # 2. build dataset and dataloader |
| 83 | # ====================================================== |
| 84 | logger.info("Building dataset...") |
| 85 | # == build dataset == |
| 86 | assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for vae training" |
| 87 | dataset = build_module(cfg.dataset, DATASETS) |
| 88 | logger.info("Dataset contains %s samples.", len(dataset)) |
| 89 | |
| 90 | # == build dataloader == |
| 91 | dataloader_args = dict( |
no test coverage detected