(L, N, H, E,
n_batches, n_attentions,
k, n_buckets, n_iterations, verbose)
| 26 | |
| 27 | |
def time_clustering(L, N, H, E,
                    n_batches, n_attentions,
                    k, n_buckets, n_iterations, verbose):
    """Time repeated calls to ``cluster`` on one fixed set of GPU hashes.

    Hashes are generated once, all output/scratch buffers are allocated
    once, and then ``cluster`` is invoked ``n_batches * n_attentions``
    times with the same inputs. The elapsed wall time of those calls is
    printed and returned.

    Args:
        L: Length of every sequence; also the last dimension of the hash,
            group and distance tensors.
        N: Number of sequences (leading dimension of every tensor).
        H: Second dimension of every tensor (presumably attention heads --
            confirm against callers).
        E: Forwarded to ``generate_hash``; its meaning is defined there
            (presumably the query/embedding dimension -- confirm).
        n_batches: Outer loop count; converted with ``int()``.
        n_attentions: Inner loop count; converted with ``int()``.
        k: Number of centroid slots allocated per (N, H) slice.
        n_buckets: Number of hash bits; passed to ``generate_hash`` and as
            ``bits=`` to ``cluster``, and sizes the bit-count buffer.
        n_iterations: Passed to ``cluster`` as ``iterations=``.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        float: Elapsed seconds for all ``cluster`` calls (also printed).
    """
    device = "cuda"
    n_points = L * N * H
    hashes = torch.zeros(n_points, dtype=torch.int64, device=device)
    hashes = generate_hash(n_points, E, n_buckets, hashes).view(N, H, L)

    # Buffers are allocated once, before the timed region, and handed to
    # every call so allocation cost is not part of the measurement.
    groups = torch.zeros((N, H, L), dtype=torch.int32, device=device)
    counts = torch.zeros((N, H, k), dtype=torch.int32, device=device)
    centroids = torch.zeros((N, H, k), dtype=torch.int64, device=device)
    distances = torch.zeros((N, H, L), dtype=torch.int32, device=device)
    cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                     dtype=torch.int32, device=device)
    sequence_lengths = torch.full((N,), L, dtype=torch.int32, device=device)

    # CUDA work is launched asynchronously: drain all setup work before the
    # clock starts, and wait for the clustering kernels before it stops;
    # otherwise the measurement captures little more than launch overhead.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(int(n_batches)):
        for _ in range(int(n_attentions)):
            cluster(
                hashes, sequence_lengths,
                groups=groups, counts=counts, centroids=centroids,
                distances=distances, bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets
            )
    torch.cuda.synchronize()
    duration = time.perf_counter() - start
    print("Time Elapsed: {}".format(duration))
    return duration
| 57 | |
| 58 | |
| 59 | if __name__ == "__main__": |
no test coverage detected