MCP · copy
hub / github.com/dmlc/dgl / test_feature_cache

Function test_feature_cache

tests/python/pytorch/graphbolt/impl/test_feature_cache.py:63–172  ·  view source on GitHub ↗  (note: the listing below shows only lines 61–120 and is truncated)
(offsets, dtype, feature_size, num_parts, policy, offset)

Source from the content-addressed store, hash-verified

61@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"])
62@pytest.mark.parametrize("offset", [0, 1111111])
63def test_feature_cache(offsets, dtype, feature_size, num_parts, policy, offset):
64 cache_size = 32 * (
65 torch.get_num_threads() if num_parts is None else num_parts
66 )
67 a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
68 cache = gb.impl.CPUFeatureCache(
69 (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
70 )
71 cache2 = gb.impl.CPUFeatureCache(
72 (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
73 )
74 policy1 = gb.impl.CPUFeatureCache(
75 (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
76 )._policy
77 policy2 = gb.impl.CPUFeatureCache(
78 (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
79 )._policy
80 reader_fn = lambda keys: a[keys]
81
82 keys = torch.tensor([0, 1])
83 values, missing_index, missing_keys, missing_offsets = cache.query(
84 keys, offset
85 )
86 if not offsets:
87 missing_offsets = None
88 assert torch.equal(
89 missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0],
90 keys,
91 )
92
93 missing_values = a[missing_keys]
94 cache.replace(missing_keys, missing_values, missing_offsets, offset)
95 values[missing_index] = missing_values
96 assert torch.equal(values, a[keys])
97 assert torch.equal(
98 cache2.query_and_replace(keys, reader_fn, offset), a[keys]
99 )
100
101 _test_query_and_replace(policy1, policy2, keys, offset)
102
103 pin_memory = F._default_context_str == "gpu"
104
105 keys = torch.arange(1, 33, pin_memory=pin_memory)
106 values, missing_index, missing_keys, missing_offsets = cache.query(
107 keys, offset
108 )
109 if not offsets:
110 missing_offsets = None
111 assert torch.equal(
112 missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0],
113 torch.arange(2, 33),
114 )
115 assert not pin_memory or values.is_pinned()
116
117 missing_values = a[missing_keys]
118 cache.replace(missing_keys, missing_values, missing_offsets, offset)
119 values[missing_index] = missing_values
120 assert torch.equal(values, a[keys])

Callers

nothing calls this directly

Calls 9

query · method · 0.95
replace · method · 0.95
query_and_replace · method · 0.95
_test_query_and_replace · function · 0.85
index_select · method · 0.80
is_pinned · method · 0.45
replace · method · 0.45
to · method · 0.45
ctx · method · 0.45

Tested by

no test coverage detected