Gets predictions for a list of ``AttackedText`` objects. Gets a prediction from the cache if possible. If a prediction is not in the cache, queries the model and stores the prediction in the cache.
(self, attacked_text_list)
| 199 | return self._process_model_outputs(attacked_text_list, outputs) |
| 200 | |
| 201 | def _call_model(self, attacked_text_list): |
| 202 | """Gets predictions for a list of ``AttackedText`` objects. |
| 203 | |
| 204 | Gets prediction from cache if possible. If prediction is not in |
| 205 | the cache, queries model and stores prediction in cache. |
| 206 | """ |
| 207 | if not self.use_cache: |
| 208 | return self._call_model_uncached(attacked_text_list) |
| 209 | else: |
| 210 | uncached_list = [] |
| 211 | for text in attacked_text_list: |
| 212 | if text in self._call_model_cache: |
| 213 | # Re-write value in cache. This moves the key to the top of the |
| 214 | # LRU cache and prevents the unlikely event that the text |
| 215 | # is overwritten when we store the inputs from `uncached_list`. |
| 216 | self._call_model_cache[text] = self._call_model_cache[text] |
| 217 | else: |
| 218 | uncached_list.append(text) |
| 219 | uncached_list = [ |
| 220 | text |
| 221 | for text in attacked_text_list |
| 222 | if text not in self._call_model_cache |
| 223 | ] |
| 224 | outputs = self._call_model_uncached(uncached_list) |
| 225 | for text, output in zip(uncached_list, outputs): |
| 226 | self._call_model_cache[text] = output |
| 227 | all_outputs = [self._call_model_cache[text] for text in attacked_text_list] |
| 228 | return all_outputs |
| 229 | |
| 230 | def extra_repr_keys(self): |
| 231 | attrs = [] |
no test coverage detected