()
| 39 | |
| 40 | |
| 41 | def main(): |
| 42 | torch.set_grad_enabled(False) |
| 43 | # ====================================================== |
| 44 | # configs & runtime variables |
| 45 | # ====================================================== |
| 46 | # == parse configs == |
| 47 | cfg = parse_configs(training=False) |
| 48 | |
| 49 | # == device and dtype == |
| 50 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 51 | cfg_dtype = cfg.get("dtype", "fp32") |
| 52 | assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" |
| 53 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 54 | torch.backends.cuda.matmul.allow_tf32 = True |
| 55 | torch.backends.cudnn.allow_tf32 = True |
| 56 | |
| 57 | # == init distributed env == |
| 58 | if is_distributed(): |
| 59 | colossalai.launch_from_torch({}) |
| 60 | coordinator = DistCoordinator() |
| 61 | enable_sequence_parallelism = coordinator.world_size > 1 |
| 62 | if enable_sequence_parallelism: |
| 63 | set_sequence_parallel_group(dist.group.WORLD) |
| 64 | else: |
| 65 | coordinator = None |
| 66 | enable_sequence_parallelism = False |
| 67 | set_random_seed(seed=cfg.get("seed", 1024)) |
| 68 | |
| 69 | # == init logger == |
| 70 | logger = create_logger() |
| 71 | logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) |
| 72 | verbose = cfg.get("verbose", 1) |
| 73 | progress_wrap = tqdm if verbose == 1 else (lambda x: x) |
| 74 | |
| 75 | # ====================================================== |
| 76 | # build model & load weights |
| 77 | # ====================================================== |
| 78 | logger.info("Building models...") |
| 79 | # == build text-encoder and vae == |
| 80 | text_encoder = build_module(cfg.text_encoder, MODELS, device=device) |
| 81 | vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() |
| 82 | |
| 83 | # == prepare video size == |
| 84 | image_size = cfg.get("image_size", None) |
| 85 | if image_size is None: |
| 86 | resolution = cfg.get("resolution", None) |
| 87 | aspect_ratio = cfg.get("aspect_ratio", None) |
| 88 | assert ( |
| 89 | resolution is not None and aspect_ratio is not None |
| 90 | ), "resolution and aspect_ratio must be provided if image_size is not provided" |
| 91 | image_size = get_image_size(resolution, aspect_ratio) |
| 92 | num_frames = get_num_frames(cfg.num_frames) |
| 93 | |
| 94 | # == build diffusion model == |
| 95 | input_size = (num_frames, *image_size) |
| 96 | latent_size = vae.get_latent_size(input_size) |
| 97 | model = ( |
| 98 | build_module( |
no test coverage detected