MCPcopy
hub / github.com/idiap/fast-transformers / time_clustering

Function time_clustering

tests/clustering/hamming/time_python_api_gpu.py:28–56  ·  view source on GitHub ↗
(L, N, H, E, n_batches, n_attentions, k, n_buckets, n_iterations, verbose)

Source from the content-addressed store, hash-verified

26
27
def time_clustering(L, N, H, E,
                    n_batches, n_attentions,
                    k, n_buckets, n_iterations, verbose):
    """Benchmark the GPU ``cluster`` call on one fixed set of hash codes.

    Generates hash codes once with ``generate_hash`` and preallocates the
    buffers passed to ``cluster``. It then calls ``cluster``
    ``int(n_batches) * int(n_attentions)`` times on the same inputs and
    prints the total elapsed wall-clock time.

    Args:
        L: Last dimension of the hash tensor; every entry of
            ``sequence_lengths`` is set to ``L``.
        N: First dimension of all buffers and the length of
            ``sequence_lengths``.
        H: Second dimension of all buffers (presumably the number of
            attention heads -- TODO confirm).
        E: Forwarded to ``generate_hash`` (presumably the dimensionality of
            the points being hashed -- TODO confirm).
        n_batches: Outer loop count.
        n_attentions: Inner loop count.
        k: Sizes the ``counts``, ``centroids`` and ``cluster_bit_counts``
            buffers (presumably the number of clusters).
        n_buckets: Forwarded to ``generate_hash`` and passed to ``cluster``
            as ``bits``; also the last dimension of ``cluster_bit_counts``.
        n_iterations: Passed to ``cluster`` as ``iterations``.
        verbose: Currently unused; kept for signature compatibility.

    Returns:
        float: The elapsed time in seconds, which is also printed.
    """
    n_points = L * N * H
    device = "cuda"

    # Allocate directly on the GPU instead of allocating on the CPU and
    # copying with .cuda().
    hashes = torch.zeros(n_points, dtype=torch.int64, device=device)
    hashes = generate_hash(n_points, E, n_buckets, hashes).view(N, H, L)

    # Buffers handed to cluster() by keyword; presumably filled in place and
    # reused across every timed call.
    groups = torch.zeros((N, H, L), dtype=torch.int32, device=device)
    counts = torch.zeros((N, H, k), dtype=torch.int32, device=device)
    centroids = torch.zeros((N, H, k), dtype=torch.int64, device=device)
    distances = torch.zeros((N, H, L), dtype=torch.int32, device=device)
    cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                     dtype=torch.int32, device=device)
    sequence_lengths = torch.full((N,), L, dtype=torch.int32, device=device)

    # Work on CUDA tensors is enqueued asynchronously. Synchronize before
    # starting the clock so setup work is not counted, and again before
    # stopping it so all queued cluster() work is counted.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(int(n_batches)):
        for _ in range(int(n_attentions)):
            cluster(
                hashes, sequence_lengths,
                groups=groups, counts=counts, centroids=centroids,
                distances=distances, bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets
            )
    torch.cuda.synchronize()
    duration = time.perf_counter() - start
    print("Time Elapsed: {}".format(duration))
    return duration
57
58
59if __name__ == "__main__":

Callers 1

Calls 2

cluster — Function · 0.90
generate_hash — Function · 0.70

Tested by

no test coverage detected