(model_outputs, ground_truth_detections)
| 50 | |
| 51 | |
def test_association_costs(model_outputs, ground_truth_detections):
    """Compare predicted association costs against stored ground-truth costs.

    Local peaks are extracted from the first model output, then
    ``compute_peaks_and_costs`` is run over a fully connected graph of 12
    nodes. The resulting per-edge costs are checked against the ``"costs"``
    entry of the first ground-truth detection.
    """
    expected_costs = ground_truth_detections[0]["costs"]

    peak_op = predict_multianimal.find_local_peak_indices_maxpool_nms(
        model_outputs[0],
        RADIUS,
        THRESHOLD,
    )
    with tf.compat.v1.Session() as session:
        peak_indices = session.run(peak_op)

    # Fully connected graph: one [i, j] edge for every node pair with i < j.
    n_nodes = 12
    edges = []
    for src in range(n_nodes):
        for dst in range(src + 1, n_nodes):
            edges.append([src, dst])

    all_outputs = predict_multianimal.compute_peaks_and_costs(
        *model_outputs,
        peak_indices,
        graph=edges,
        paf_inds=np.arange(len(edges)),
        n_id_channels=0,
        stride=STRIDE,
    )
    prediction = all_outputs[0]
    for required_key in ("coordinates", "confidence", "costs"):
        assert required_key in prediction

    predicted_costs = prediction["costs"]
    assert len(predicted_costs) == len(expected_costs)

    # Count the edges whose column-wise argmax of "m1" matches ground truth.
    n_matching = 0
    for edge_key, edge_costs in predicted_costs.items():
        predicted_best = np.argmax(edge_costs["m1"], axis=0)
        expected_best = np.argmax(expected_costs[edge_key]["m1"], axis=0)
        if np.array_equal(predicted_best, expected_best):
            n_matching += 1
    assert n_matching == 60  # 6 arrays are unequal as cost computation was corrected

    # Distances only need to agree within an absolute tolerance of 1.5.
    for edge_key, edge_costs in predicted_costs.items():
        assert np.allclose(
            edge_costs["distance"], expected_costs[edge_key]["distance"], atol=1.5
        )
| 78 | |
| 79 | |
| 80 | def test_compute_peaks_and_costs_no_graph(model_outputs): |
nothing calls this directly
no test coverage detected