()
| 16 | |
| 17 | |
| 18 | def main(): |
| 19 | torch.set_grad_enabled(False) |
| 20 | # ====================================================== |
| 21 | # configs & runtime variables |
| 22 | # ====================================================== |
| 23 | # == parse configs == |
| 24 | cfg = parse_configs(training=False) |
| 25 | |
| 26 | # == device and dtype == |
| 27 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 28 | cfg_dtype = cfg.get("dtype", "fp32") |
| 29 | assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" |
| 30 | dtype = to_torch_dtype(cfg.get("dtype", "bf16")) |
| 31 | torch.backends.cuda.matmul.allow_tf32 = True |
| 32 | torch.backends.cudnn.allow_tf32 = True |
| 33 | |
| 34 | # == init distributed env == |
| 35 | if is_distributed(): |
| 36 | colossalai.launch_from_torch({}) |
| 37 | set_random_seed(seed=cfg.get("seed", 1024)) |
| 38 | |
| 39 | # == init logger == |
| 40 | logger = create_logger() |
| 41 | logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) |
| 42 | verbose = cfg.get("verbose", 1) |
| 43 | |
| 44 | # ====================================================== |
| 45 | # build dataset and dataloader |
| 46 | # ====================================================== |
| 47 | logger.info("Building reconstruction dataset...") |
| 48 | dataset = build_module(cfg.dataset, DATASETS) |
| 49 | batch_size = cfg.get("batch_size", 1) |
| 50 | dataloader, _ = prepare_dataloader( |
| 51 | dataset, |
| 52 | batch_size=batch_size, |
| 53 | num_workers=cfg.get("num_workers", 4), |
| 54 | shuffle=False, |
| 55 | drop_last=False, |
| 56 | pin_memory=True, |
| 57 | process_group=get_data_parallel_group(), |
| 58 | ) |
| 59 | logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset)) |
| 60 | total_batch_size = batch_size * get_world_size() |
| 61 | logger.info("Total batch size: %s", total_batch_size) |
| 62 | |
| 63 | total_steps = len(dataloader) |
| 64 | if cfg.get("num_samples", None) is not None: |
| 65 | total_steps = min(int(cfg.num_samples // cfg.batch_size), total_steps) |
| 66 | logger.info("limiting test dataset to %s", int(cfg.num_samples // cfg.batch_size) * cfg.batch_size) |
| 67 | dataiter = iter(dataloader) |
| 68 | |
| 69 | # ====================================================== |
| 70 | # build model & loss |
| 71 | # ====================================================== |
| 72 | logger.info("Building models...") |
| 73 | model = build_module(cfg.model, MODELS).to(device, dtype).eval() |
| 74 | vae_loss_fn = VAELoss( |
| 75 | logvar_init=cfg.get("logvar_init", 0.0), |
no test coverage detected