MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

scripts/misc/extract_feat.py:16–170  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

14
15
16def main():
17 torch.set_grad_enabled(False)
18 # ======================================================
19 # 1. configs & runtime variables
20 # ======================================================
21 # == parse configs ==
22 cfg = parse_configs(training=False)
23
24 # == device and dtype ==
25 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
26 cfg_dtype = cfg.get("dtype", "bf16")
27 assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
28 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
29
30 # == colossalai init distributed training ==
31 device = "cuda" if torch.cuda.is_available() else "cpu"
32 cfg_dtype = cfg.get("dtype", "fp32")
33 assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
34 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
35 torch.backends.cuda.matmul.allow_tf32 = True
36 torch.backends.cudnn.allow_tf32 = True
37
38 colossalai.launch_from_torch({})
39 set_data_parallel_group(dist.group.WORLD)
40
41 # == init logger, tensorboard & wandb ==
42 logger = create_logger()
43 logger.info("Configuration:\n %s", pformat(cfg.to_dict()))
44
45 # ======================================================
46 # 2. build dataset and dataloader
47 # ======================================================
48 logger.info("Building dataset...")
49 # == build dataset ==
50 dataset = build_module(cfg.dataset, DATASETS)
51 logger.info("Dataset contains %s samples.", len(dataset))
52
53 # == build dataloader ==
54 dataloader_args = dict(
55 dataset=dataset,
56 batch_size=cfg.get("batch_size", None),
57 num_workers=cfg.get("num_workers", 4),
58 seed=cfg.get("seed", 1024),
59 shuffle=True,
60 drop_last=True,
61 pin_memory=True,
62 process_group=get_data_parallel_group(),
63 )
64 dataloader, _ = prepare_dataloader(
65 bucket_config=cfg.get("bucket_config", None),
66 num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
67 **dataloader_args,
68 )
69 num_steps_per_epoch = len(dataloader)
70
71 # ======================================================
72 # 3. build model
73 # ======================================================

Callers 1

extract_feat.py · File · 0.70

Calls 15

update · Method · 0.95
parse_configs · Function · 0.90
to_torch_dtype · Function · 0.90
set_data_parallel_group · Function · 0.90
create_logger · Function · 0.90
build_module · Function · 0.90
get_data_parallel_group · Function · 0.90
prepare_dataloader · Function · 0.90
get_model_numel · Function · 0.90
format_numel_str · Function · 0.90
save_training_config · Function · 0.90
FeatureSaver · Class · 0.90

Tested by

no test coverage detected