Queries the cache, then reads the keys that were not found by calling `reader_fn(missing_keys)` and inserts them into the cache, using the selected caching policy algorithm to evict old entries if the cache is full.
(self, keys, reader_fn, offset=0)
| 104 | return values, missing_index, missing_keys, missing_offsets |
| 105 | |
def query_and_replace(self, keys, reader_fn, offset=0):
    """Look up ``keys`` in the cache and fill in the misses via ``reader_fn``.

    Keys already present are served from the cache storage. All remaining
    keys are fetched with a single ``reader_fn`` call, and the fetched values
    are written into the cache slots handed out by the caching policy, which
    is responsible for evicting old entries when the cache is full.

    Parameters
    ----------
    keys : Tensor
        The keys to query the cache with.
    reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor
        Called once with the tensor of keys that missed the cache (possibly
        empty); must return their values.
    offset : int
        The offset to be added to the keys. Default is 0.

    Returns
    -------
    Tensor
        A tensor containing values corresponding to the keys. Should equal
        ``reader_fn(keys)``, computed in a faster way.
    """
    num_keys = keys.shape[0]
    self.total_queries += num_keys

    # The policy's per-key outputs are ordered hits-first: the leading
    # `num_hits` entries of positions/index/pointers belong to cached keys,
    # the trailing ones to the keys listed in `missing_keys`.
    (
        positions,
        index,
        pointers,
        missing_keys,
        found_offsets,
        missing_offsets,
    ) = self._policy.query_and_replace(keys, offset)
    num_hits = num_keys - missing_keys.size(0)

    # Copy the hits out of the cache, then report them to the policy as
    # read before calling reader_fn.
    values = self._cache.query(positions[:num_hits], index, num_keys)
    self._policy.reading_completed(pointers[:num_hits], found_offsets)

    self.total_miss += missing_keys.shape[0]

    # Fetch the misses, scatter them into the output, store them in the
    # slots reserved by the policy, and report those writes as completed.
    # NOTE(review): if reader_fn raises, writing_completed is never called
    # for the reserved slots - confirm the policy tolerates that.
    fetched = reader_fn(missing_keys)
    values[index[num_hits:]] = fetched
    self._cache.replace(positions[num_hits:], fetched)
    self._policy.writing_completed(pointers[num_hits:], missing_offsets)
    return values
| 151 | |
| 152 | def replace(self, keys, values, offsets=None, offset=0): |
| 153 | """Inserts key-value pairs into the cache using the selected caching |