MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

scripts/train.py:45–516  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

43
44
45def main():
46 # ======================================================
47 # 1. configs & runtime variables
48 # ======================================================
49 # == parse configs ==
50 cfg = parse_configs(training=True)
51 record_time = cfg.get("record_time", False)
52
53 # == device and dtype ==
54 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
55 cfg_dtype = cfg.get("dtype", "bf16")
56 assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
57 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
58
59 # == colossalai init distributed training ==
60 # NOTE: A very large timeout is set to prevent some processes from exiting early
61 dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
62 torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
63 set_seed(cfg.get("seed", 1024))
64 coordinator = DistCoordinator()
65 device = get_current_device()
66
67 # == init exp_dir ==
68 exp_name, exp_dir = define_experiment_workspace(cfg)
69 coordinator.block_all()
70 if coordinator.is_master():
71 os.makedirs(exp_dir, exist_ok=True)
72 save_training_config(cfg.to_dict(), exp_dir)
73 coordinator.block_all()
74
75 # == init logger, tensorboard & wandb ==
76 logger = create_logger(exp_dir)
77 logger.info("Experiment directory created at %s", exp_dir)
78 logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
79 if coordinator.is_master():
80 tb_writer = create_tensorboard_writer(exp_dir)
81 if cfg.get("wandb", False):
82 wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir=exp_dir)
83
84 # == init ColossalAI booster ==
85 plugin = create_colossalai_plugin(
86 plugin=cfg.get("plugin", "zero2"),
87 dtype=cfg_dtype,
88 grad_clip=cfg.get("grad_clip", 0),
89 sp_size=cfg.get("sp_size", 1),
90 reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
91 )
92 booster = Booster(plugin=plugin)
93 torch.set_num_threads(1)
94
95 # == build text-encoder ==
96 text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
97 if text_encoder is not None:
98 text_encoder_output_dim = text_encoder.output_dim
99 text_encoder_model_max_length = text_encoder.model_max_length
100 cfg.dataset.tokenize_fn = text_encoder.tokenize_fn
101 else:
102 text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)

Callers 1

train.py File · 0.70

Calls 15

get_masks Method · 0.95
parse_configs Function · 0.90
to_torch_dtype Function · 0.90
save_training_config Function · 0.90
create_logger Function · 0.90
create_colossalai_plugin Function · 0.90
build_module Function · 0.90
get_data_parallel_group Function · 0.90
prepare_dataloader Function · 0.90
get_model_numel Function · 0.90

Tested by

no test coverage detected