MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

scripts/inference_i2v.py:42–386  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

40
41
42def main():
43 torch.set_grad_enabled(False)
44 # ======================================================
45 # configs & runtime variables
46 # ======================================================
47 # == parse configs ==
48 cfg = parse_configs(training=False)
49
50 # == device and dtype ==
51 device = "cuda" if torch.cuda.is_available() else "cpu"
52 cfg_dtype = cfg.get("dtype", "fp32")
53 assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
54 dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
55 torch.backends.cuda.matmul.allow_tf32 = True
56 torch.backends.cudnn.allow_tf32 = True
57
58 # == init distributed env ==
59 if is_distributed():
60 colossalai.launch_from_torch({})
61 coordinator = DistCoordinator()
62 enable_sequence_parallelism = coordinator.world_size > 1
63 if enable_sequence_parallelism:
64 set_sequence_parallel_group(dist.group.WORLD)
65 else:
66 coordinator = None
67 enable_sequence_parallelism = False
68 set_random_seed(seed=cfg.get("seed", 1024))
69
70 # == init logger ==
71 logger = create_logger()
72 logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
73 verbose = cfg.get("verbose", 1)
74 progress_wrap = tqdm if verbose == 1 else (lambda x: x)
75
76 # ======================================================
77 # build model & load weights
78 # ======================================================
79 logger.info("Building models...")
80 # == build text-encoder and vae ==
81 text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
82 vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
83
84 # == prepare video size ==
85 image_size = cfg.get("image_size", None)
86 if image_size is None:
87 resolution = cfg.get("resolution", None)
88 aspect_ratio = cfg.get("aspect_ratio", None)
89 assert (
90 resolution is not None and aspect_ratio is not None
91 ), "resolution and aspect_ratio must be provided if image_size is not provided"
92 image_size = get_image_size(resolution, aspect_ratio)
93 num_frames = get_num_frames(cfg.num_frames)
94
95 # == build diffusion model ==
96 input_size = (num_frames, *image_size)
97 latent_size = vae.get_latent_size(input_size)
98 model = (
99 build_module(

Callers 1

inference_i2v.py — File · 0.70

Calls 15

parse_configs — Function · 0.90
to_torch_dtype — Function · 0.90
is_distributed — Function · 0.90
create_logger — Function · 0.90
build_module — Function · 0.90
get_image_size — Function · 0.90
get_num_frames — Function · 0.90
load_prompts — Function · 0.90
collect_references_batch — Function · 0.90

Tested by

no test coverage detected