(idtype)
def _query_and_fill(cache, h, keys):
    """Read ``keys`` through ``cache``, back-filling misses from ``h``.

    ``cache.query`` returns a value buffer together with the positions
    (``m_idx``) and ids (``m_keys``) of the keys it could not serve. The
    missing rows are gathered from ``h``, written into the buffer at those
    positions, and then inserted into the cache with ``cache.replace``.

    Returns:
        tuple: ``(values, m_idx, m_keys)`` -- the completed value buffer and
        the miss positions/ids reported by ``cache.query``.
    """
    values, m_idx, m_keys = cache.query(keys)
    # Miss ids/positions are cast to int64 before being used as indices.
    m_values = h[F.tensor(m_keys, F.int64)]
    values[F.tensor(m_idx, F.int64)] = m_values
    cache.replace(m_keys, m_values)
    return values, m_idx, m_keys


@unittest.skipIf(not F.gpu_ctx(), reason="only necessary with GPU")
@parametrize_idtype
def test_gpu_cache(idtype):
    """Check GPUCache miss reporting and read-through results.

    A cache holding ``capacity`` entries is first populated with keys
    ``[0, capacity)``; a second, overlapping query for ``[3, 8)`` is then
    expected to miss on exactly three keys. In both rounds the back-filled
    values must equal the corresponding rows of the node features ``h``.
    """
    g = generate_graph(idtype)
    h = g.ndata["h"]
    capacity = 5
    cache = dgl.cuda.GPUCache(capacity, D, idtype)

    # Round 1: the cache starts empty, so every requested key should miss.
    keys = F.arange(0, capacity, dtype=idtype)
    values, m_idx, m_keys = _query_and_fill(cache, h, keys)
    assert m_keys.shape[0] == capacity and m_idx.shape[0] == capacity
    assert (values != h[F.tensor(keys, F.int64)]).sum().item() == 0

    # Round 2: keys 3 and 4 were inserted by round 1, so only the remaining
    # three keys of [3, 8) should miss.
    keys = F.arange(3, 8, dtype=idtype)
    values, m_idx, m_keys = _query_and_fill(cache, h, keys)
    assert m_keys.shape[0] == 3 and m_idx.shape[0] == 3
    assert (values != h[F.tensor(keys, F.int64)]).sum().item() == 0
| 74 | |
| 75 | if __name__ == "__main__": |
no test coverage detected