Repository: github.com/idiap/fast-transformers

Function cluster

Defined in fast_transformers/clustering/hamming/__init__.py, lines 19–114.

Cluster hashes using a few iterations of K-Means with Hamming distance. All tensor arguments that default to None are optional buffers that let the caller avoid memory allocations. distances and bitcounts are only used by the CUDA version of this call. clusters is ignored if centroids is provided.
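To make the summary concrete, here is a minimal pure-Python sketch of K-Means under Hamming distance for a single (sequence, head) slice. It is not the library's implementation: the names `hamming` and `hamming_kmeans`, the random initialization, and the per-bit majority-vote centroid update (suggested by the description of `bitcounts`, but an assumption here) are all illustrative. Only the first `length` hashes are clustered and only the `bits` least-significant bits are compared, mirroring the `lengths` and `bits` arguments.

```python
import random
from typing import List, Optional, Tuple


def hamming(a: int, b: int, bits: int = 32) -> int:
    """Hamming distance between the `bits` least-significant bits of a and b."""
    mask = (1 << bits) - 1
    return bin((a ^ b) & mask).count("1")


def hamming_kmeans(
    hashes: List[int],
    length: int,
    clusters: int = 30,
    iterations: int = 10,
    bits: int = 32,
    centroids: Optional[List[int]] = None,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: cluster hashes[:length]; return (groups, counts)."""
    valid = hashes[:length]
    if centroids is None:
        rng = random.Random(seed)
        # Pick `clusters` distinct positions (raises ValueError if clusters > length).
        centroids = rng.sample(valid, clusters)
    centroids = list(centroids)
    k = len(centroids)  # as in the docstring, `clusters` is ignored if centroids is given
    groups = [0] * length
    counts = [0] * k
    for _ in range(iterations):
        # Assignment step: nearest centroid under Hamming distance (ties -> lowest index).
        for i, h in enumerate(valid):
            groups[i] = min(range(k), key=lambda c: hamming(h, centroids[c], bits))
        # Count members per cluster and, per cluster, how many members have each bit set.
        counts = [0] * k
        bitcounts = [[0] * bits for _ in range(k)]
        for i, h in enumerate(valid):
            g = groups[i]
            counts[g] += 1
            for b in range(bits):
                bitcounts[g][b] += (h >> b) & 1
        # Update step (assumed): a centroid bit is 1 if a strict majority of members have it.
        for c in range(k):
            if counts[c] == 0:
                continue  # leave the centroid of an empty cluster unchanged
            centroids[c] = sum(
                1 << b for b in range(bits) if 2 * bitcounts[c][b] > counts[c]
            )
    return groups, counts
```

For example, with 4-bit codes and initial centroids 0b0000 and 0b1111, the hashes 0b0000 and 0b0001 end up in one cluster and 0b1110 and 0b1111 in the other.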

(
    hashes,
    lengths,
    groups=None,
    counts=None,
    centroids=None,
    distances=None,
    bitcounts=None,
    clusters=30,
    iterations=10,
    bits=32
)

Source (excerpt):

import numpy as np
import torch


def cluster(
    hashes,
    lengths,
    groups=None,
    counts=None,
    centroids=None,
    distances=None,
    bitcounts=None,
    clusters=30,
    iterations=10,
    bits=32
):
    """Cluster hashes using a few iterations of K-Means with hamming distance.

    All the tensors default initialized to None are optional buffers to avoid
    memory allocations. distances and bitcounts are only used by the CUDA
    version of this call. clusters will be ignored if centroids is provided.

    Arguments
    ---------
        hashes: A long tensor of shape (N, H, L) containing a hashcode for each
                query.
        lengths: An int tensor of shape (N,) containing the sequence length for
                 each sequence in hashes.
        groups: An int tensor buffer of shape (N, H, L) containing the cluster
                to which the corresponding hash belongs.
        counts: An int tensor buffer of shape (N, H, K) containing the number
                of elements in each cluster.
        centroids: A long tensor buffer of shape (N, H, K) containing the
                   centroid for each cluster.
        distances: An int tensor of shape (N, H, L) containing the distance to
                   the closest centroid for each hash.
        bitcounts: An int tensor of shape (N, H, K, bits) containing the number
                   of elements that have 1 for a given bit.
        clusters: The number of clusters to use for each sequence. It is
                  ignored if centroids is not None.
        iterations: How many k-means iterations to perform.
        bits: How many of the least-significant bits in hashes to consider.

    Returns
    -------
        groups and counts as defined above.
    """
    device = hashes.device
    N, H, L = hashes.shape

    # Unfortunately cpu and gpu have different APIs so the entire call must be
    # surrounded by an if-then-else
    if device.type == "cpu":
        if groups is None:
            groups = torch.empty((N, H, L), dtype=torch.int32)
        if centroids is None:
            # Note: this buffer is overwritten on the next line, which
            # initializes the centroids from `clusters` randomly sampled
            # hash positions (the same positions for every sequence and head).
            centroids = torch.empty((N, H, clusters), dtype=torch.int64)
            centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
        # K is taken from the centroids, so `clusters` is ignored when
        # centroids are passed in.
        K = centroids.shape[2]
        if counts is None:
            counts = torch.empty((N, H, K), dtype=torch.int32)

    # ... (source lines 77-114 are not shown in this excerpt)
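The CPU branch above picks one set of `clusters` distinct positions with `np.random.choice(L, size=[clusters], replace=False)` and uses it to index the last axis, so every sequence and head starts from the hashes at the same positions, drawn from all L positions regardless of `lengths` (for padded sequences, some initial centroids may therefore come from padding). The pure-Python sketch below mirrors that logic on nested lists; `init_centroids` is a hypothetical helper, and `random.sample` stands in for NumPy's sampler.

```python
import random
from typing import List, Optional, Tuple

# hashes is modeled as a nested list hashes[n][h][l] instead of an (N, H, L) tensor.
Hashes = List[List[List[int]]]


def init_centroids(
    hashes: Hashes,
    clusters: int = 30,
    centroids: Optional[Hashes] = None,
    seed: Optional[int] = None,
) -> Tuple[Hashes, int]:
    """Hypothetical mirror of the CPU branch's centroid initialization."""
    if centroids is None:
        L = len(hashes[0][0])
        rng = random.Random(seed)
        # Like np.random.choice(L, size=[clusters], replace=False): distinct
        # positions in range(L), raising ValueError if clusters > L.
        idx = rng.sample(range(L), clusters)
        # One index list shared by every sequence n and head h.
        centroids = [[[row[i] for i in idx] for row in seq] for seq in hashes]
    # K comes from the centroids, so `clusters` is ignored when they are given.
    K = len(centroids[0][0])
    return centroids, K
```

As with the NumPy call, requesting more clusters than there are positions raises ValueError, so on the CPU path without preallocated centroids, `clusters` must not exceed L.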

Callers (15; 7 listed)

_create_query_groups (method) · 0.90
_create_query_groups (method) · 0.90
_create_query_groups (method) · 0.90
test_clustering (method) · 0.90
time_clustering (function) · 0.90
cluster_queries (function) · 0.90
cluster_queries (function) · 0.90

Calls

No outgoing calls recorded.

Tested by (15; 7 listed)

test_clustering (method) · 0.72
cluster_queries (function) · 0.72
cluster_queries (function) · 0.72
cluster_queries (function) · 0.72
cluster_queries (function) · 0.72
cluster_queries (function) · 0.72
cluster_queries (function) · 0.72