| 9 | """High-level wrapper for GPU embedding cache""" |
| 10 | |
def __init__(self, cache_shape, dtype):
    """Create the underlying GPU cache and reset the query statistics.

    Parameters
    ----------
    cache_shape : iterable of int
        Shape forwarded to ``torch.ops.graphbolt.gpu_cache``. The product
        of its entries, multiplied by the element size of ``dtype``, is
        stored as ``max_size_in_bytes``.
    dtype : torch.dtype
        Data type forwarded to ``torch.ops.graphbolt.gpu_cache``; also
        used to determine the per-element size in bytes.

    Raises
    ------
    AssertionError
        If the major compute capability reported by
        ``torch.cuda.get_device_capability()`` is below 7.
    """
    major, _ = torch.cuda.get_device_capability()
    # Raise explicitly instead of using an ``assert`` statement so that the
    # check is still enforced when Python runs with -O (which strips
    # asserts). The exception type and message are kept unchanged.
    if major < 7:
        raise AssertionError(
            "GPUFeatureCache is supported only on CUDA compute capability >= 70 (Volta)."
        )
    self._cache = torch.ops.graphbolt.gpu_cache(cache_shape, dtype)
    # An empty tensor of the requested dtype is enough to look up the
    # per-element width in bytes.
    element_size = torch.tensor([], dtype=dtype).element_size()
    self.max_size_in_bytes = reduce(mul, cache_shape) * element_size
    # Running counters, both starting at zero.
    self.total_miss = 0
    self.total_queries = 0
| 21 | |
| 22 | def query(self, keys, async_op=False): |
| 23 | """Queries the GPU cache. |