Compute mean average precision (mAP)
(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams)
| 77 | |
| 78 | |
def mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams):
    """Compute mean average precision (mAP) over all valid queries.

    For each query (row of ``dist_mat``) the gallery entries are ranked by
    ascending distance. Gallery entries that share both the identity and the
    camera of the query are discarded, and the average precision of the
    remaining ranking is computed. Queries with no correct match left after
    filtering are skipped.

    Args:
        dist_mat: 2-D query-by-gallery distance matrix (smaller means more
            similar). May be a tensor-like object exposing ``.cpu()`` and
            ``.numpy()``, or anything ``np.asarray`` can convert.
        query_ids: Identity label of each query, one per row of ``dist_mat``.
        gallery_ids: Identity label of each gallery entry, one per column.
        query_cams: Camera label of each query.
        gallery_cams: Camera label of each gallery entry.

    Returns:
        The mean of the per-query average precision scores.

    Raises:
        RuntimeError: If no query has a correct match after filtering.
    """
    if hasattr(dist_mat, "cpu"):
        # Tensor-like input: detach first, because ``.numpy()`` refuses
        # tensors that are still tracking gradients.
        if hasattr(dist_mat, "detach"):
            dist_mat = dist_mat.detach()
        dist_mat = dist_mat.cpu().numpy()
    else:
        dist_mat = np.asarray(dist_mat)
    # Unpacking two values keeps the implicit "must be 2-D" check.
    num_queries, _ = dist_mat.shape

    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)

    # Rank the gallery per query (closest first) and mark correct identities.
    indices = np.argsort(dist_mat, axis=1)
    matches = gallery_ids[indices] == query_ids[:, np.newaxis]

    # Compute AP for each query.
    aps = []
    for i in range(num_queries):
        ranked = indices[i]
        # Keep a gallery entry unless it has the same id AND the same camera
        # as the query.
        valid = (gallery_ids[ranked] != query_ids[i]) | (
            gallery_cams[ranked] != query_cams[i]
        )
        y_true = matches[i, valid]
        if not np.any(y_true):
            # No correct match survives the filter; AP is undefined here.
            continue
        # Negate distances so that a larger score means a better match.
        y_score = -dist_mat[i][ranked][valid]
        aps.append(average_precision_score(y_true, y_score))

    if not aps:
        raise RuntimeError("No valid query")
    return np.mean(aps)
| 103 | |
| 104 | |
| 105 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): |