(descriptors, output, num_matched,
query_prefix=None, query_list=None,
db_prefix=None, db_list=None, db_model=None, db_descriptors=None)
| 68 | |
| 69 | |
def main(descriptors, output, num_matched,
         query_prefix=None, query_list=None,
         db_prefix=None, db_list=None, db_model=None, db_descriptors=None):
    """Pair each query image with its most similar database images.

    Query and database global descriptors are compared with pairwise dot
    products. The selected pairs are written to ``output`` as
    ``"<query_name> <db_name>"`` lines joined by ``'\\n'``.

    Args:
        descriptors: Path to the descriptor file of the query images (read via
            ``list_h5_names`` / ``get_descriptors``; presumably HDF5). Also used
            as the database descriptor file when ``db_descriptors`` is None.
        output: Path of the text file to write, one pair per line.
        num_matched: Forwarded to ``pairs_from_score_matrix``; presumably the
            number of database images retained per query.
        query_prefix, query_list: Forwarded to ``parse_names`` to select the
            query images among the names found in ``descriptors``.
        db_prefix, db_list: Forwarded to ``parse_names`` to select the
            database images; ignored when ``db_model`` is given.
        db_model: Directory (``str`` or ``Path``) containing ``images.bin``;
            when given, the database images are the ones listed in it.
        db_descriptors: A single path or a list of paths holding database
            descriptors. Image names are assumed unique across these files.

    Raises:
        ValueError: If no database image or no query image is selected.
    """
    logger.info('Extracting image pairs from a retrieval database.')

    # We handle multiple reference feature files.
    # We only assume that names are unique among them and map each name to
    # the index of the file that contains it.
    if db_descriptors is None:
        db_descriptors = descriptors
    if isinstance(db_descriptors, (Path, str)):
        db_descriptors = [db_descriptors]
    name2db = {n: i for i, p in enumerate(db_descriptors)
               for n in list_h5_names(p)}
    db_names_h5 = list(name2db.keys())
    query_names_h5 = list_h5_names(descriptors)

    if db_model:
        # Wrap in Path so plain string arguments work, like db_descriptors.
        images = read_images_binary(Path(db_model) / 'images.bin')
        db_names = [i.name for i in images.values()]
    else:
        db_names = parse_names(db_prefix, db_list, db_names_h5)
    if len(db_names) == 0:
        raise ValueError('Could not find any database image.')
    query_names = parse_names(query_prefix, query_list, query_names_h5)
    if len(query_names) == 0:
        # Fail early with an explicit message instead of deep inside
        # descriptor loading / scoring on an empty selection.
        raise ValueError('Could not find any query image.')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    db_desc = get_descriptors(db_names, db_descriptors, name2db)
    query_desc = get_descriptors(query_names, descriptors)
    # sim[i, j] = dot(query row i, database row j); presumably one descriptor
    # row per name, in the order of query_names / db_names.
    sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device))

    # Avoid self-matching: boolean (num_queries, num_db) mask that is True
    # where a query and a database image have the same name.
    same_image = np.array(query_names)[:, None] == np.array(db_names)[None]
    pairs = pairs_from_score_matrix(sim, same_image, num_matched, min_score=0)
    pairs = [(query_names[i], db_names[j]) for i, j in pairs]

    logger.info(f'Found {len(pairs)} pairs.')
    with open(output, 'w') as f:
        f.write('\n'.join(' '.join([i, j]) for i, j in pairs))
| 108 | |
| 109 | |
| 110 | if __name__ == "__main__": |
no test coverage detected