| 80 | |
@classmethod
def cluster_traces(cls, traces, n_clusters=3, sample_ratio=0.2, max_samples_per_cluster=1):
    """Cluster traces with faiss k-means and randomly mark a few per cluster.

    Only the first element along dim 0 of ``traces`` (``traces[0]``) is used.
    Its dims 0 and 1 are swapped and everything after dim 0 is flattened, so
    each row fed to k-means is one trace. Presumably ``traces`` is shaped
    (B, T, N, D), giving N rows of T*D features (TODO: confirm with callers).

    Args:
        traces: Tensor of traces (see above).
        n_clusters: Requested number of clusters; capped at the number of rows.
        sample_ratio: Fraction of each cluster's members to sample.
        max_samples_per_cluster: Upper bound on samples taken per cluster.
            Every non-empty cluster contributes at least one sample. The
            default of 1 matches what the previous implementation used.

    Returns:
        A 1-D int64 tensor on ``traces.device`` with one entry per trace,
        1 for sampled traces and 0 otherwise, or ``None`` if preprocessing or
        clustering raised an exception (best-effort; the error is printed).
    """
    try:
        # Swap dims 0/1 of traces[0] and flatten the rest: one row per trace.
        traces_for_clustering = traces[0].transpose(0, 1).flatten(1)
        num_clusters = min(n_clusters, traces_for_clustering.shape[0])
        # faiss needs a C-contiguous float32 array; detach so tensors that
        # require grad can be converted. Convert once and reuse.
        features = traces_for_clustering.detach().float().contiguous().cpu().numpy()
        kmeans = faiss.Kmeans(
            features.shape[1],
            num_clusters,
            niter=50,
            verbose=False,
            min_points_per_centroid=1,
            max_points_per_centroid=10000000,
        )
        kmeans.train(features)
        _, labels = kmeans.index.search(features, 1)
        # search(..., 1) returns labels shaped (N, 1). Flatten to (N,) so that
        # nonzero() below yields plain trace indices; with the (N, 1) shape
        # each index row was [i, 0] and trace 0 was always marked as sampled.
        cluster_ids = torch.from_numpy(labels).reshape(-1).to(traces.device)
    except Exception as err:  # best-effort: report and let the caller handle None
        print(f"kmeans failed: {err}")
        return None

    sampled_ids = cluster_ids.new_zeros(cluster_ids.size(0))
    for cluster_id in range(num_clusters):
        cluster_idx = (cluster_ids == cluster_id).nonzero().squeeze(1)
        cluster_size = cluster_idx.size(0)
        if cluster_size == 0:
            continue  # a centroid may end up with no assigned traces
        # sample_ratio of the cluster, clamped to [1, max_samples_per_cluster]
        num_pts_to_sample = max(1, min(max_samples_per_cluster, int(sample_ratio * cluster_size)))
        # TODO: random sampling is naive; a better selection strategy may help.
        sampled_idx = torch.randperm(cluster_size)[:num_pts_to_sample]
        sampled_ids[cluster_idx[sampled_idx]] = 1
    return sampled_ids
| 111 | |
| 112 | def visualize(self, video, pred_tracks, pred_visibility, filename="visual_trace.mp4", mode="ranbow"): |
| 113 | if mode == "rainbow": |