()
| 43 | |
| 44 | |
| 45 | def main(): |
| 46 | # ====================================================== |
| 47 | # 1. configs & runtime variables |
| 48 | # ====================================================== |
| 49 | # == parse configs == |
| 50 | cfg = parse_configs(training=True) |
| 51 | record_time = cfg.get("record_time", False) |
| 52 | |
| 53 | # == device and dtype == |
| 54 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 55 | cfg_dtype = cfg.get("dtype", "bf16") |
| 56 | assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" |
| 57 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 58 | |
| 59 | # == colossalai init distributed training == |
| 60 | # NOTE: A very large timeout is set to prevent some processes from exiting early |
| 61 | dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) |
| 62 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) |
| 63 | set_seed(cfg.get("seed", 1024)) |
| 64 | coordinator = DistCoordinator() |
| 65 | device = get_current_device() |
| 66 | |
| 67 | # == init exp_dir == |
| 68 | exp_name, exp_dir = define_experiment_workspace(cfg) |
| 69 | coordinator.block_all() |
| 70 | if coordinator.is_master(): |
| 71 | os.makedirs(exp_dir, exist_ok=True) |
| 72 | save_training_config(cfg.to_dict(), exp_dir) |
| 73 | coordinator.block_all() |
| 74 | |
| 75 | # == init logger, tensorboard & wandb == |
| 76 | logger = create_logger(exp_dir) |
| 77 | logger.info("Experiment directory created at %s", exp_dir) |
| 78 | logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) |
| 79 | if coordinator.is_master(): |
| 80 | tb_writer = create_tensorboard_writer(exp_dir) |
| 81 | if cfg.get("wandb", False): |
| 82 | wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir=exp_dir) |
| 83 | |
| 84 | # == init ColossalAI booster == |
| 85 | plugin = create_colossalai_plugin( |
| 86 | plugin=cfg.get("plugin", "zero2"), |
| 87 | dtype=cfg_dtype, |
| 88 | grad_clip=cfg.get("grad_clip", 0), |
| 89 | sp_size=cfg.get("sp_size", 1), |
| 90 | reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20), |
| 91 | ) |
| 92 | booster = Booster(plugin=plugin) |
| 93 | torch.set_num_threads(1) |
| 94 | |
| 95 | # == build text-encoder == |
| 96 | text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype) |
| 97 | if text_encoder is not None: |
| 98 | text_encoder_output_dim = text_encoder.output_dim |
| 99 | text_encoder_model_max_length = text_encoder.model_max_length |
| 100 | cfg.dataset.tokenize_fn = text_encoder.tokenize_fn |
| 101 | else: |
| 102 | text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096) |
no test coverage detected