MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

scripts/inference.py:41–379  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

39
40
41def main():
42 torch.set_grad_enabled(False)
43 # ======================================================
44 # configs & runtime variables
45 # ======================================================
46 # == parse configs ==
47 cfg = parse_configs(training=False)
48
49 # == device and dtype ==
50 device = "cuda" if torch.cuda.is_available() else "cpu"
51 cfg_dtype = cfg.get("dtype", "fp32")
52 assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
53 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
54 torch.backends.cuda.matmul.allow_tf32 = True
55 torch.backends.cudnn.allow_tf32 = True
56
57 # == init distributed env ==
58 if is_distributed():
59 colossalai.launch_from_torch({})
60 coordinator = DistCoordinator()
61 enable_sequence_parallelism = coordinator.world_size > 1
62 if enable_sequence_parallelism:
63 set_sequence_parallel_group(dist.group.WORLD)
64 else:
65 coordinator = None
66 enable_sequence_parallelism = False
67 set_random_seed(seed=cfg.get("seed", 1024))
68
69 # == init logger ==
70 logger = create_logger()
71 logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
72 verbose = cfg.get("verbose", 1)
73 progress_wrap = tqdm if verbose == 1 else (lambda x: x)
74
75 # ======================================================
76 # build model & load weights
77 # ======================================================
78 logger.info("Building models...")
79 # == build text-encoder and vae ==
80 text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
81 vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
82
83 # == prepare video size ==
84 image_size = cfg.get("image_size", None)
85 if image_size is None:
86 resolution = cfg.get("resolution", None)
87 aspect_ratio = cfg.get("aspect_ratio", None)
88 assert (
89 resolution is not None and aspect_ratio is not None
90 ), "resolution and aspect_ratio must be provided if image_size is not provided"
91 image_size = get_image_size(resolution, aspect_ratio)
92 num_frames = get_num_frames(cfg.num_frames)
93
94 # == build diffusion model ==
95 input_size = (num_frames, *image_size)
96 latent_size = vae.get_latent_size(input_size)
97 model = (
98 build_module(

Callers 1

inference.py — File · 0.70

Calls 15

parse_configs — Function · 0.90
to_torch_dtype — Function · 0.90
is_distributed — Function · 0.90
create_logger — Function · 0.90
build_module — Function · 0.90
get_image_size — Function · 0.90
get_num_frames — Function · 0.90
load_prompts — Function · 0.90
collect_references_batch — Function · 0.90

Tested by

no test coverage detected