(offsets, dtype, feature_size, num_parts, policy, offset)
| 61 | @pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"]) |
| 62 | @pytest.mark.parametrize("offset", [0, 1111111]) |
| 63 | def test_feature_cache(offsets, dtype, feature_size, num_parts, policy, offset): |
| 64 | cache_size = 32 * ( |
| 65 | torch.get_num_threads() if num_parts is None else num_parts |
| 66 | ) |
| 67 | a = torch.randint(0, 2, [1024, feature_size], dtype=dtype) |
| 68 | cache = gb.impl.CPUFeatureCache( |
| 69 | (cache_size,) + a.shape[1:], a.dtype, policy, num_parts |
| 70 | ) |
| 71 | cache2 = gb.impl.CPUFeatureCache( |
| 72 | (cache_size,) + a.shape[1:], a.dtype, policy, num_parts |
| 73 | ) |
| 74 | policy1 = gb.impl.CPUFeatureCache( |
| 75 | (cache_size,) + a.shape[1:], a.dtype, policy, num_parts |
| 76 | )._policy |
| 77 | policy2 = gb.impl.CPUFeatureCache( |
| 78 | (cache_size,) + a.shape[1:], a.dtype, policy, num_parts |
| 79 | )._policy |
| 80 | reader_fn = lambda keys: a[keys] |
| 81 | |
| 82 | keys = torch.tensor([0, 1]) |
| 83 | values, missing_index, missing_keys, missing_offsets = cache.query( |
| 84 | keys, offset |
| 85 | ) |
| 86 | if not offsets: |
| 87 | missing_offsets = None |
| 88 | assert torch.equal( |
| 89 | missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], |
| 90 | keys, |
| 91 | ) |
| 92 | |
| 93 | missing_values = a[missing_keys] |
| 94 | cache.replace(missing_keys, missing_values, missing_offsets, offset) |
| 95 | values[missing_index] = missing_values |
| 96 | assert torch.equal(values, a[keys]) |
| 97 | assert torch.equal( |
| 98 | cache2.query_and_replace(keys, reader_fn, offset), a[keys] |
| 99 | ) |
| 100 | |
| 101 | _test_query_and_replace(policy1, policy2, keys, offset) |
| 102 | |
| 103 | pin_memory = F._default_context_str == "gpu" |
| 104 | |
| 105 | keys = torch.arange(1, 33, pin_memory=pin_memory) |
| 106 | values, missing_index, missing_keys, missing_offsets = cache.query( |
| 107 | keys, offset |
| 108 | ) |
| 109 | if not offsets: |
| 110 | missing_offsets = None |
| 111 | assert torch.equal( |
| 112 | missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], |
| 113 | torch.arange(2, 33), |
| 114 | ) |
| 115 | assert not pin_memory or values.is_pinned() |
| 116 | |
| 117 | missing_values = a[missing_keys] |
| 118 | cache.replace(missing_keys, missing_values, missing_offsets, offset) |
| 119 | values[missing_index] = missing_values |
| 120 | assert torch.equal(values, a[keys]) |
nothing calls this directly
no test coverage detected