| 68 | |
@classmethod
def cluster_traces(cls, traces, n_clusters=3, *, sample_ratio=0.2, max_per_cluster=1):
    """Cluster traces with faiss k-means and mark a few members of each cluster.

    ``traces[0]`` is transposed on dims 0/1 and flattened from dim 1 on, so
    each row of the result is one trace's feature vector.  Those vectors are
    clustered with ``faiss.Kmeans``; then, for every non-empty cluster,
    ``clamp(int(sample_ratio * cluster_size), 1, max_per_cluster)`` randomly
    chosen members are marked.

    Args:
        traces: tensor indexed as ``traces[0]``. Presumably shaped
            (batch, time, num_points, dims) so that rows become points -- TODO
            confirm with callers.
        n_clusters: requested number of clusters; capped at the number of traces.
        sample_ratio: fraction of a cluster's members to mark before clamping.
        max_per_cluster: upper bound on marked members per cluster. The default
            of 1 preserves the historical behaviour.

    Returns:
        A 1-D integer tensor on ``traces.device`` with one entry per trace
        (1 = sampled, 0 = not sampled), or ``None`` if k-means failed.
    """
    try:
        # One row per trace, all remaining dims flattened into features.
        features = traces[0].transpose(0, 1).flatten(1)
        num_traces = features.shape[0]
        k = min(n_clusters, num_traces)
        # faiss consumes float32 numpy arrays; detach so .numpy() works even
        # when the tensor is part of an autograd graph.
        features_np = features.detach().float().cpu().numpy()
        kmeans = faiss.Kmeans(
            features_np.shape[1],
            k,
            niter=50,
            verbose=False,
            min_points_per_centroid=1,
            max_points_per_centroid=10000000,
        )
        kmeans.train(features_np)
        # search() returns (distances, labels), each shaped (num_traces, 1).
        _, labels_np = kmeans.index.search(features_np, 1)
    except Exception as err:
        print(f"kmeans failed: {err!r}")
        return None

    # Drop the trailing k=1 axis.  Keeping it made nonzero() return (row, col)
    # pairs, and the col index (always 0) wrongly marked trace 0 as sampled.
    cluster_ids = torch.from_numpy(labels_np).reshape(-1).to(traces.device)
    sampled_ids = torch.zeros_like(cluster_ids)
    for cluster_id in range(k):
        members = (cluster_ids == cluster_id).nonzero(as_tuple=True)[0]
        num_members = members.numel()
        if num_members == 0:
            continue
        # At least 1 and at most `max_per_cluster` members per cluster.
        num_pts_to_sample = max(
            1, min(max_per_cluster, int(sample_ratio * num_members))
        )
        # TODO: random sampling is crude; a better sampling strategy is wanted.
        chosen = torch.randperm(num_members)[:num_pts_to_sample]
        sampled_ids[members[chosen]] = 1
    return sampled_ids
| 99 | |
| 100 | @classmethod |
| 101 | def cluster_traces_kmeans(self, traces, n_clusters=3, positive=False): |