MCPcopy
hub / github.com/microsoft/Magma / cluster_traces

Method cluster_traces

data/utils/visual_tracker.py:82–110  ·  view source on GitHub ↗
(self, traces, n_clusters=3)

Source from the content-addressed store, hash-verified

80
@classmethod
def cluster_traces(self, traces, n_clusters=3, *, max_pts_per_cluster=1):
    """Select a few representative traces by k-means clustering them with faiss.

    Each trace in ``traces[0]`` is flattened into a single feature vector,
    the vectors are clustered, and a random subset of every cluster is
    marked in the returned mask.

    NOTE: the method is decorated with ``@classmethod`` although its first
    parameter is named ``self``; that parameter is not used in the body.

    Args:
        traces: Batched traces; only ``traces[0]`` is clustered. Presumably
            a tensor shaped (batch, time, num_traces, coords) -- TODO
            confirm against the caller.
        n_clusters: Requested number of clusters; capped at the number of
            traces available.
        max_pts_per_cluster: Upper bound on the number of traces sampled
            from one cluster. Roughly 20% of a cluster is sampled, but
            never fewer than one trace. The default of 1 keeps the count
            produced by the previous hard-coded clamp; the earlier comment
            described a cap of 2, so pass ``max_pts_per_cluster=2`` for
            that behaviour.

    Returns:
        A 1-D integer tensor with one entry per trace (1 = sampled,
        0 = not sampled) on the same device as the clustered traces, or
        ``None`` when clustering fails.
    """
    try:
        # Swap the first two dims of traces[0] and flatten the rest so that
        # each row is the feature vector of one trace.
        features = traces[0].transpose(0, 1).flatten(1)
        device = features.device
        num_clusters = min(n_clusters, features.shape[0])
        # Convert once; faiss works on numpy arrays.
        features_np = features.cpu().numpy()
        kmeans = faiss.Kmeans(
            features_np.shape[1],
            num_clusters,
            niter=50,
            verbose=False,
            min_points_per_centroid=1,
            max_points_per_centroid=10000000,
        )
        kmeans.train(features_np)
        _, nearest = kmeans.index.search(features_np, 1)
        # search(x, 1) returns labels shaped (N, 1). Flatten to (N,) so the
        # nonzero() below yields plain trace indices; with the 2-D labels it
        # yielded (row, col) pairs and the col value 0 was also used as an
        # index, which marked trace 0 as sampled on every call.
        cluster_ids = torch.from_numpy(nearest).reshape(-1).to(device)
    except Exception:
        # Best effort: any clustering failure is reported and signalled with
        # None, but KeyboardInterrupt / SystemExit are no longer swallowed.
        print("kmeans failed")
        return None

    # Sample ~20% of the ids of each cluster, at least 1 and at most
    # max_pts_per_cluster.
    sampled_ids = cluster_ids.new_zeros(cluster_ids.size(0))
    for cluster_id in range(num_clusters):
        member_idx = (cluster_ids == cluster_id).nonzero().squeeze(1)
        num_members = member_idx.size(0)
        if num_members == 0:
            # Empty cluster: nothing to sample.
            continue
        num_pts_to_sample = max(1, min(max_pts_per_cluster, int(0.2 * num_members)))
        # TODO: random sample is a bit dummy, need a better sampling algo here
        chosen = torch.randperm(num_members)[:num_pts_to_sample]
        sampled_ids[member_idx[chosen]] = 1
    return sampled_ids
111
112 def visualize(self, video, pred_tracks, pred_visibility, filename="visual_trace.mp4", mode="ranbow"):
113 if mode == "rainbow":

Callers

nothing calls this directly

Calls 1

flattenMethod · 0.80

Tested by

no test coverage detected