(conf: Dict,
pairs_path: Path,
match_path: Path,
feature_path_q: Path,
feature_path_ref: Path,
overwrite: bool = False)
| 199 | |
@torch.no_grad()
def match_from_paths(conf: Dict,
                     pairs_path: Path,
                     match_path: Path,
                     feature_path_q: Path,
                     feature_path_ref: Path,
                     overwrite: bool = False) -> Path:
    """Match local features for every (query, reference) pair in a pairs file.

    Args:
        conf: Matcher configuration. ``conf['model']['name']`` selects the
            matcher class (loaded via ``dynamic_load``) and ``conf['model']``
            is passed to its constructor.
        pairs_path: Pairs file, parsed with ``parse_retrieval`` into a
            mapping from each query name to its reference names.
        match_path: Output file the predicted matches are written to; its
            parent directory is created if missing.
        feature_path_q: File holding the query features.
        feature_path_ref: File holding the reference features.
        overwrite: When False, ``match_path`` is passed to
            ``find_unique_new_pairs`` so that pairs it reports as already
            present are not matched again.

    Returns:
        ``match_path`` -- both when new matches were written and when there
        was nothing new to match.

    Raises:
        FileNotFoundError: If either feature file or the pairs file does not
            exist.
    """
    logger.info('Matching local features with configuration:'
                f'\n{pprint.pformat(conf)}')

    # Validate every input before any side effect (directory creation).
    if not feature_path_q.exists():
        raise FileNotFoundError(f'Query feature file {feature_path_q}.')
    if not feature_path_ref.exists():
        raise FileNotFoundError(f'Reference feature file {feature_path_ref}.')
    # An explicit exception rather than ``assert``: asserts are stripped
    # when Python runs with ``-O``.
    if not pairs_path.exists():
        raise FileNotFoundError(f'Pairs file {pairs_path}.')
    match_path.parent.mkdir(exist_ok=True, parents=True)

    retrieval = parse_retrieval(pairs_path)
    pairs = [(q, r) for q, rs in retrieval.items() for r in rs]
    pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
    if len(pairs) == 0:
        logger.info('Skipping the matching.')
        return match_path

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Model = dynamic_load(matchers, conf['model']['name'])
    model = Model(conf['model']).eval().to(device)

    dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
    # batch_size=1 and shuffle=False are load-bearing: the loop recovers the
    # pair names with ``pairs[idx]``, which is only correct if the loader
    # yields exactly one item per pair, in dataset order.
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True)
    writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)

    try:
        for idx, data in enumerate(tqdm(loader, smoothing=.1)):
            # Entries whose key starts with 'image' are passed through
            # untouched; every other entry is moved to the target device.
            data = {k: v if k.startswith('image')
                    else v.to(device, non_blocking=True)
                    for k, v in data.items()}
            pred = model(data)
            pair = names_to_pair(*pairs[idx])
            writer_queue.put((pair, pred))
    finally:
        # Join even if matching raises, so already-queued results are not
        # abandoned. NOTE(review): assumes join() also stops the writer
        # workers -- confirm against WorkQueue.
        writer_queue.join()
    logger.info('Finished exporting matches.')
    return match_path
| 241 | |
| 242 | |
| 243 | if __name__ == '__main__': |
no test coverage detected