MCPcopy
hub / github.com/dmlc/dgl / GPUCache

Class GPUCache

python/dgl/cuda/gpu_cache.py:24–83  ·  view source on GitHub ↗

High-level wrapper for GPU embedding cache

Source from the content-addressed store, hash-verified

22
23
24class GPUCache(object):
25 """High-level wrapper for GPU embedding cache"""
26
def __init__(self, num_items, num_feats, idtype=F.int64):
    """Create the native GPU cache backing this wrapper.

    Parameters
    ----------
    num_items : int
        Number of items the native cache is created with (presumably its
        capacity -- TODO confirm against the C API).
    num_feats : int
        Per-item feature count, forwarded unchanged to the C API.
    idtype : dtype, optional
        Dtype of the cache keys; only ``F.int32`` and ``F.int64`` are
        accepted. Defaults to ``F.int64``.

    Raises
    ------
    AssertionError
        If ``idtype`` is neither ``F.int32`` nor ``F.int64`` (only when
        assertions are enabled).
    """
    assert idtype in [F.int32, F.int64]
    # The C API takes the key width in bits rather than a dtype object.
    key_bits = 64
    if idtype == F.int32:
        key_bits = 32
    self._cache = _CAPI_DGLGpuCacheCreate(num_items, num_feats, key_bits)
    self.idtype = idtype
    # Running statistics, updated by query().
    self.total_miss = 0
    self.total_queries = 0
35
def query(self, keys):
    """Queries the GPU cache.

    Parameters
    ----------
    keys : Tensor
        The keys to query the GPU cache with. They are cast to
        ``self.idtype`` before the lookup.

    Returns
    -------
    tuple(Tensor, Tensor, Tensor)
        A tuple containing (values, missing_indices, missing_keys) where
        values[missing_indices] corresponds to cache misses that should be
        filled by querying another source with missing_keys.

    Notes
    -----
    ``total_queries`` grows by ``keys.shape[0]`` and ``total_miss`` by the
    number of missing keys, but only after the lookup has succeeded.
    """
    num_queried = keys.shape[0]
    keys = F.astype(keys, self.idtype)
    values, missing_index, missing_keys = _CAPI_DGLGpuCacheQuery(
        self._cache, F.to_dgl_nd(keys)
    )
    # Read the miss count from the native array before converting it.
    num_missed = missing_keys.shape[0]
    result = (
        F.from_dgl_nd(values),
        F.from_dgl_nd(missing_index),
        F.from_dgl_nd(missing_keys),
    )
    # Update both counters together, and only once every step above has
    # succeeded. Bumping total_queries before the cast / C API call meant a
    # failing query left total_queries and total_miss out of step.
    self.total_queries += num_queried
    self.total_miss += num_missed
    return result
62
def replace(self, keys, values):
    """Insert key-value pairs into the GPU cache.

    When the cache is full, old key-value pairs are evicted using the
    Least-Recently Used (LRU) algorithm.

    Parameters
    ----------
    keys : Tensor
        The keys to insert into the GPU cache; cast to ``self.idtype``.
    values : Tensor
        The values to insert into the GPU cache; cast to ``float32``.
    """
    # Cast both inputs first, then hand them to the C API as DGL NDArrays.
    cast_keys = F.astype(keys, self.idtype)
    cast_values = F.astype(values, F.float32)
    nd_keys, nd_values = F.to_dgl_nd(cast_keys), F.to_dgl_nd(cast_values)
    _CAPI_DGLGpuCacheReplace(self._cache, nd_keys, nd_values)
79
80 @property
81 def miss_rate(self):

Callers 1

_init_gpu_caches · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected