MCPcopy
hub / github.com/microsoft/Magma / cluster_traces

Method cluster_traces

data/utils/visual_trace.py:70–98  ·  view source on GitHub ↗
(self, traces, n_clusters=3)

Source from the content-addressed store, hash-verified

68
@classmethod
def cluster_traces(
    self,
    traces,
    n_clusters=3,
    sample_ratio=0.2,
    max_samples_per_cluster=1,
):
    """Cluster traces with k-means and mark a few sampled traces per cluster.

    Only ``traces[0]`` is used. Its first two dimensions are swapped and all
    remaining dimensions are flattened, so each row handed to k-means is one
    trace. (Presumably ``traces`` is laid out as
    ``(batch, time, points, coords)`` -- TODO confirm against the callers.)

    Note: this is a classmethod, so the first parameter receives the class;
    it keeps its original name ``self`` and is not used.

    Args:
        traces: Tensor of traces; ``traces.device`` decides where the
            result lives.
        n_clusters: Requested number of clusters, capped at the number of
            rows being clustered.
        sample_ratio: Fraction of each cluster to sample before clamping.
        max_samples_per_cluster: Upper bound on the samples drawn from one
            cluster. The default of 1 reproduces the historical clamp
            ``max(1, min(1, ...))``, which always selected exactly one
            trace per non-empty cluster.

    Returns:
        A 1-D tensor on ``traces.device`` with one entry per clustered row:
        1 for sampled rows and 0 otherwise. ``None`` if clustering fails
        for any reason (a message is printed in that case).
    """
    try:
        features = traces[0].transpose(0, 1).flatten(1)
        num_clusters = min(n_clusters, features.shape[0])
        # faiss operates on float32 numpy arrays; convert once and reuse the
        # same array for both training and assignment.
        features_np = features.float().cpu().numpy()
        kmeans = faiss.Kmeans(
            features_np.shape[1],
            num_clusters,
            niter=50,
            verbose=False,
            min_points_per_centroid=1,
            max_points_per_centroid=10000000,
        )
        kmeans.train(features_np)
        _, cluster_ids_np = kmeans.index.search(features_np, 1)
        # search(x, 1) returns labels of shape (num_rows, 1). Flatten them so
        # nonzero() below yields plain row indices. With 2-D labels every
        # match also carried a column index of 0, and using those pairs as
        # indices marked row 0 as sampled regardless of its cluster.
        cluster_ids = (
            torch.from_numpy(cluster_ids_np).reshape(-1).to(features.device)
        )
    except Exception:
        # Best effort: callers get None instead of an exception, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print("kmeans failed")
        return None

    # Sample `sample_ratio` of the ids of each cluster, clamped to at least 1
    # and at most `max_samples_per_cluster`.
    sampled_ids = cluster_ids.new_zeros(cluster_ids.size(0)).to(traces.device)
    for cluster_id in range(num_clusters):
        members = (cluster_ids == cluster_id).nonzero().squeeze(1)
        num_pts_to_sample = max(
            1,
            min(max_samples_per_cluster, int(sample_ratio * members.size(0))),
        )
        # TODO: random sample is a bit dummy, need a better sampling algo here
        # An empty cluster yields an empty selection, so nothing is marked.
        picked = torch.randperm(members.size(0))[:num_pts_to_sample]
        sampled_ids[members[picked]] = 1
    return sampled_ids
99
100 @classmethod
101 def cluster_traces_kmeans(self, traces, n_clusters=3, positive=False):

Callers

nothing calls this directly

Calls 1

flattenMethod · 0.80

Tested by

no test coverage detected