github.com/hpcaitech/Open-Sora / main

Function main()

scripts/inference_vae.py, lines 18–168

```python
# Imports assumed from the top of the file (not part of this excerpt):
import colossalai
import torch
from pprint import pformat
# ...plus the Open-Sora helpers used below (parse_configs, to_torch_dtype,
# build_module, DATASETS, MODELS, prepare_dataloader, VAELoss, ...).


def main():
    torch.set_grad_enabled(False)
    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Read the dtype once so the validated value is the one actually converted
    # (bf16 when the config does not set it).
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    if is_distributed():
        colossalai.launch_from_torch({})
    set_random_seed(seed=cfg.get("seed", 1024))

    # == init logger ==
    logger = create_logger()
    logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
    verbose = cfg.get("verbose", 1)

    # ======================================================
    # build dataset and dataloader
    # ======================================================
    logger.info("Building reconstruction dataset...")
    dataset = build_module(cfg.dataset, DATASETS)
    batch_size = cfg.get("batch_size", 1)
    dataloader, _ = prepare_dataloader(
        dataset,
        batch_size=batch_size,
        num_workers=cfg.get("num_workers", 4),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
    total_batch_size = batch_size * get_world_size()
    logger.info("Total batch size: %s", total_batch_size)

    total_steps = len(dataloader)
    if cfg.get("num_samples", None) is not None:
        # Use the local batch_size (which has a default) instead of cfg.batch_size,
        # which is not guaranteed to exist in the config.
        num_steps = int(cfg.num_samples // batch_size)
        total_steps = min(num_steps, total_steps)
        logger.info("limiting test dataset to %s", num_steps * batch_size)
    dataiter = iter(dataloader)

    # ======================================================
    # build model & loss
    # ======================================================
    logger.info("Building models...")
    model = build_module(cfg.model, MODELS).to(device, dtype).eval()
    vae_loss_fn = VAELoss(
        logvar_init=cfg.get("logvar_init", 0.0),
        # ... (excerpt ends here; main() continues to line 168 of the file)
```
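The optional num_samples cap at the end of the dataloader section is plain integer arithmetic and can be checked on its own. The sketch below restates it outside the script. limited_steps and samples_reported are hypothetical helpers, and the sketch assumes len(dataloader) is the number of batches one rank will iterate over.

```python
from typing import Optional


def limited_steps(num_batches: int, batch_size: int, num_samples: Optional[int] = None) -> int:
    """Number of dataloader steps to run (hypothetical helper).

    Without num_samples every batch is used; otherwise the count is
    num_samples // batch_size, capped at the batches actually available.
    """
    if num_samples is None:
        return num_batches
    return min(num_samples // batch_size, num_batches)


def samples_reported(num_samples: int, batch_size: int) -> int:
    """Sample count shown in the 'limiting test dataset to ...' log line:
    num_samples rounded down to a whole number of batches (before the cap)."""
    return (num_samples // batch_size) * batch_size


if __name__ == "__main__":
    print(limited_steps(10, 4, 18))  # 18 // 4 = 4 steps
    print(samples_reported(18, 4))   # 16 samples
    print(limited_steps(3, 4, 100))  # only 3 batches exist, so 3 steps
```

Two details follow from the arithmetic. A num_samples that is not a multiple of the batch size is rounded down. The logged figure is computed before the min() cap, so when num_samples exceeds what the loader holds, the log can report more samples than are actually processed. If each rank runs total_steps steps (as assumed above), a multi-GPU run processes up to total_steps × total_batch_size videos.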

Callers 1

inference_vae.py (file, 0.70)

Calls 15

parse_configs (function, 0.90)
to_torch_dtype (function, 0.90)
is_distributed (function, 0.90)
create_logger (function, 0.90)
build_module (function, 0.90)
prepare_dataloader (function, 0.90)
get_data_parallel_group (function, 0.90)
get_world_size (function, 0.90)
VAELoss (class, 0.90)
is_main_process (function, 0.90)
save_sample (function, 0.90)
tqdm (function, 0.85)
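main() calls prepare_dataloader with shuffle=False, drop_last=False and the data-parallel process group. It then logs batch_size * get_world_size() as the total batch size. How videos are divided among ranks depends on prepare_dataloader, which is not shown here. The sketch below assumes it behaves like PyTorch's DistributedSampler with those two flags. shard_indices is a hypothetical helper, not part of Open-Sora.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices seen by one rank, mimicking torch.utils.data.DistributedSampler
    with shuffle=False and drop_last=False (hypothetical helper)."""
    indices = list(range(dataset_len))
    per_rank = math.ceil(dataset_len / num_replicas)
    total_size = per_rank * num_replicas
    padding = total_size - len(indices)
    # Pad by wrapping around to the start so every rank gets per_rank indices.
    if padding <= len(indices):
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every num_replicas-th index, starting at position r.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    world_size, batch_size = 4, 2
    for rank in range(world_size):
        print(rank, shard_indices(10, world_size, rank))
    print("total batch size:", batch_size * world_size)
```

Under this assumption, 10 videos on 4 ranks give each rank 3 indices. Videos 0 and 1 are served twice to fill the last slots, so metrics averaged across all ranks would count them twice. Because shuffle=False, the duplicated items are always the first ones in the dataset.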

Tested by

no test coverage detected
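At the top of main(), the dtype string is checked against fp16, bf16 and fp32 and then converted with to_torch_dtype. Reading the key once, with one default, guarantees that the converted string is the one that was validated. The following stdlib-only sketch uses a plain dict in place of the parsed config. validated_dtype and SUPPORTED_DTYPES are hypothetical names. The bf16 default matches the value the script falls back to when it converts the dtype.

```python
SUPPORTED_DTYPES = ("fp16", "bf16", "fp32")


def validated_dtype(cfg: dict, default: str = "bf16") -> str:
    """Read the 'dtype' key once, check it, and return it (hypothetical helper).

    Returning the checked value means the string passed on for conversion
    is exactly the one that was validated.
    """
    cfg_dtype = cfg.get("dtype", default)
    if cfg_dtype not in SUPPORTED_DTYPES:
        # An explicit exception still fires under `python -O`, unlike assert.
        raise ValueError(f"Unknown mixed precision {cfg_dtype}")
    return cfg_dtype


if __name__ == "__main__":
    print(validated_dtype({}))
    print(validated_dtype({"dtype": "fp16"}))
```

In the script itself, this would amount to something like dtype = to_torch_dtype(validated_dtype(cfg)). Choosing between assert and an explicit exception is a style decision. The exception form simply keeps the check active when Python runs with optimizations enabled.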