()
| 139 | |
| 140 | @torch.no_grad() |
| 141 | def main(): |
| 142 | args = parse_args() |
| 143 | |
| 144 | meta_path = args.meta_path |
| 145 | if not os.path.exists(meta_path): |
| 146 | print(f"Meta file '{meta_path}' not found. Exit.") |
| 147 | exit() |
| 148 | |
| 149 | wo_ext, ext = os.path.splitext(meta_path) |
| 150 | out_path = f"{wo_ext}_flow{ext}" |
| 151 | if args.skip_if_existing and os.path.exists(out_path): |
| 152 | print(f"Output meta file '{out_path}' already exists. Exit.") |
| 153 | exit() |
| 154 | |
| 155 | torch.backends.cudnn.deterministic = True |
| 156 | torch.backends.cudnn.benchmark = False |
| 157 | dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) |
| 158 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) |
| 159 | |
| 160 | # build model |
| 161 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
| 162 | model = UniMatch( |
| 163 | feature_channels=128, |
| 164 | num_scales=2, |
| 165 | upsample_factor=4, |
| 166 | num_head=1, |
| 167 | ffn_dim_expansion=4, |
| 168 | num_transformer_layers=6, |
| 169 | reg_refine=True, |
| 170 | task="flow", |
| 171 | ) |
| 172 | ckpt = torch.load("./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth") |
| 173 | model.load_state_dict(ckpt["model"]) |
| 174 | model = model.to(device) |
| 175 | |
| 176 | # build dataset |
| 177 | NUM_FRAMES = 10 |
| 178 | frames_inds = [15 * i for i in range(0, NUM_FRAMES)] |
| 179 | dataset = VideoTextDataset(meta_path=meta_path, frame_inds=frames_inds) |
| 180 | dataloader = DataLoader( |
| 181 | dataset, |
| 182 | batch_size=args.bs, |
| 183 | num_workers=args.num_workers, |
| 184 | sampler=DistributedSampler( |
| 185 | dataset, |
| 186 | num_replicas=dist.get_world_size(), |
| 187 | rank=dist.get_rank(), |
| 188 | shuffle=False, |
| 189 | drop_last=False, |
| 190 | ), |
| 191 | ) |
| 192 | |
| 193 | # compute optical flow scores |
| 194 | indices_list = [] |
| 195 | scores_list = [] |
| 196 | model.eval() |
| 197 | for batch in tqdm(dataloader, disable=dist.get_rank() != 0): |
| 198 | indices = batch["index"] |
no test coverage detected