High-level wrapper for GPU embedding cache
| 22 | |
| 23 | |
| 24 | class GPUCache(object): |
| 25 | """High-level wrapper for GPU embedding cache""" |
| 26 | |
def __init__(self, num_items, num_feats, idtype=F.int64):
    """Create the backing GPU cache and reset the query statistics.

    Parameters
    ----------
    num_items : int
        Forwarded unchanged to the backend cache constructor
        (presumably the cache capacity in entries -- TODO confirm).
    num_feats : int
        Forwarded unchanged to the backend cache constructor
        (presumably the feature width of each entry -- TODO confirm).
    idtype : F.int32 or F.int64, optional
        Integer type of the keys. Only these two types are accepted;
        anything else trips the assertion below.
    """
    assert idtype in (F.int32, F.int64)
    # The backend is told the key width in bits rather than a dtype.
    key_bits = 32 if idtype == F.int32 else 64
    self._cache = _CAPI_DGLGpuCacheCreate(num_items, num_feats, key_bits)
    self.idtype = idtype
    # Running counters, updated by ``query``.
    self.total_miss = 0
    self.total_queries = 0
| 35 | |
def query(self, keys):
    """Look up ``keys`` in the GPU cache and record hit/miss statistics.

    Parameters
    ----------
    keys : Tensor
        The keys to look up. They are cast to ``self.idtype`` before
        being handed to the backend.

    Returns
    -------
    tuple(Tensor, Tensor, Tensor)
        ``(values, missing_indices, missing_keys)``. The rows
        ``values[missing_indices]`` are cache misses; the caller is
        expected to fill them by querying another source with
        ``missing_keys``.
    """
    # Count every requested key, hit or miss.
    self.total_queries += keys.shape[0]
    typed_keys = F.astype(keys, self.idtype)
    result = _CAPI_DGLGpuCacheQuery(self._cache, F.to_dgl_nd(typed_keys))
    values, missing_index, missing_keys = result
    # One miss per key the backend reported as absent.
    self.total_miss += missing_keys.shape[0]
    # Convert the backend arrays back to framework tensors, in order.
    return tuple(
        F.from_dgl_nd(array)
        for array in (values, missing_index, missing_keys)
    )
| 62 | |
def replace(self, keys, values):
    """Insert key-value pairs into the GPU cache.

    When the cache is full, old key-value pairs are removed using the
    Least-Recently Used (LRU) algorithm.

    Parameters
    ----------
    keys : Tensor
        The keys to insert; cast to ``self.idtype`` first.
    values : Tensor
        The values to insert; cast to ``F.float32`` first.
    """
    # Cast both inputs before converting either one to a DGL array.
    typed_keys = F.astype(keys, self.idtype)
    float_values = F.astype(values, F.float32)
    key_array = F.to_dgl_nd(typed_keys)
    value_array = F.to_dgl_nd(float_values)
    _CAPI_DGLGpuCacheReplace(self._cache, key_array, value_array)
| 79 | |
| 80 | @property |
| 81 | def miss_rate(self): |