MCPcopy Index your code
hub / github.com/dmlc/dgl / test_gpu_cache

Function test_gpu_cache

tests/python/common/cuda/test_gpu_cache.py:54–72  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

@unittest.skipIf(not F.gpu_ctx(), reason="only necessary with GPU")
@parametrize_idtype
def test_gpu_cache(idtype):
    """Exercise ``dgl.cuda.GPUCache`` with a cold fill followed by a partial hit.

    The cache is built with size 5, presumably its capacity (TODO confirm
    against the ``GPUCache`` signature). The test runs in two phases:

    1. Query ids ``0..4`` on the empty cache. All of them must miss. Fill the
       misses from the node feature ``h``, check the result, and insert them.
    2. Query ids ``3..7``. Ids 3 and 4 must hit and ids 5, 6 and 7 must miss.
       Fill the misses, check the assembled rows against ``h``, and insert the
       new rows.

    Parameters
    ----------
    idtype
        Id dtype supplied by ``parametrize_idtype``. It is used for the graph,
        the cache and the query keys.
    """
    g = generate_graph(idtype)
    # NOTE(review): D is a module-level constant defined outside this view,
    # presumably the feature width of g.ndata["h"]; confirm.
    cache = dgl.cuda.GPUCache(5, D, idtype)
    h = g.ndata["h"]

    def query_and_fill(keys):
        """Query ``cache`` for ``keys`` and fill the missing rows from ``h``.

        Returns ``(values, m_idx, m_keys, m_values)``:

        - ``values``: the assembled rows for ``keys``.
        - ``m_idx``: the positions of the misses within ``keys``.
        - ``m_keys``: the missing keys.
        - ``m_values``: the rows gathered for the missing keys, ready for
          ``cache.replace``.
        """
        values, m_idx, m_keys = cache.query(keys)
        # Gather the rows the cache could not serve and write them into the
        # slots that query() reported as misses.
        m_values = h[F.tensor(m_keys, F.int64)]
        values[F.tensor(m_idx, F.int64)] = m_values
        return values, m_idx, m_keys, m_values

    def assert_matches_features(values, keys):
        """Assert that the assembled ``values`` equal ``h`` rows for ``keys``."""
        assert (values != h[F.tensor(keys, F.int64)]).sum().item() == 0

    # Phase 1: cold cache, so every requested key must be reported missing.
    t = 5
    keys = F.arange(0, t, dtype=idtype)
    values, m_idx, m_keys, m_values = query_and_fill(keys)
    assert m_keys.shape[0] == t and m_idx.shape[0] == t
    assert_matches_features(values, keys)
    cache.replace(m_keys, m_values)

    # Phase 2: ids 3 and 4 were inserted above and should hit; 5, 6, 7 miss.
    keys = F.arange(3, 8, dtype=idtype)
    values, m_idx, m_keys, m_values = query_and_fill(keys)
    assert m_keys.shape[0] == 3 and m_idx.shape[0] == 3
    assert_matches_features(values, keys)
    cache.replace(m_keys, m_values)
73
74
75if __name__ == "__main__":

Callers 1

test_gpu_cache.py — File · 0.85

Calls 3

query — Method · 0.95
replace — Method · 0.95
generate_graph — Function · 0.70

Tested by

no test coverage detected