MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

scripts/train_vae.py:34–387  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

32
33
34def main():
35 # ======================================================
36 # 1. configs & runtime variables
37 # ======================================================
38 # == parse configs ==
39 cfg = parse_configs(training=True)
40
41 # == device and dtype ==
42 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
43 cfg_dtype = cfg.get("dtype", "bf16")
44 assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
45 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
46
47 # == colossalai init distributed training ==
48 # NOTE: A very large timeout is set to avoid some processes exit early
49 dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
50 torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
51 set_seed(cfg.get("seed", 1024))
52 coordinator = DistCoordinator()
53 device = get_current_device()
54
55 # == init exp_dir ==
56 exp_name, exp_dir = define_experiment_workspace(cfg)
57 coordinator.block_all()
58 if coordinator.is_master():
59 os.makedirs(exp_dir, exist_ok=True)
60 save_training_config(cfg.to_dict(), exp_dir)
61 coordinator.block_all()
62
63 # == init logger, tensorboard & wandb ==
64 logger = create_logger(exp_dir)
65 logger.info("Experiment directory created at %s", exp_dir)
66 logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
67 if coordinator.is_master():
68 tb_writer = create_tensorboard_writer(exp_dir)
69 if cfg.get("wandb", False):
70 wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
71
72 # == init ColossalAI booster ==
73 plugin = create_colossalai_plugin(
74 plugin=cfg.get("plugin", "zero2"),
75 dtype=cfg_dtype,
76 grad_clip=cfg.get("grad_clip", 0),
77 sp_size=cfg.get("sp_size", 1),
78 )
79 booster = Booster(plugin=plugin)
80
81 # ======================================================
82 # 2. build dataset and dataloader
83 # ======================================================
84 logger.info("Building dataset...")
85 # == build dataset ==
86 assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for vae training"
87 dataset = build_module(cfg.dataset, DATASETS)
88 logger.info("Dataset contains %s samples.", len(dataset))
89
90 # == build dataloader ==
91 dataloader_args = dict(

Callers 1

train_vae.py — File · 0.70

Calls 15

parse_configs — Function · 0.90
to_torch_dtype — Function · 0.90
save_training_config — Function · 0.90
create_logger — Function · 0.90
create_colossalai_plugin — Function · 0.90
build_module — Function · 0.90
get_data_parallel_group — Function · 0.90
prepare_dataloader — Function · 0.90
get_model_numel — Function · 0.90
format_numel_str — Function · 0.90

Tested by

no test coverage detected