MCPcopy
hub / github.com/QData/TextAttack / _call_model

Method _call_model

textattack/goal_functions/goal_function.py:201–228  ·  view source on GitHub ↗

Gets predictions for a list of ``AttackedText`` objects. Gets prediction from cache if possible. If prediction is not in the cache, queries model and stores prediction in cache.

(self, attacked_text_list)

Source from the content-addressed store, hash-verified

199 return self._process_model_outputs(attacked_text_list, outputs)
200
201 def _call_model(self, attacked_text_list):
202 """Gets predictions for a list of ``AttackedText`` objects.
203
204 Gets prediction from cache if possible. If prediction is not in
205 the cache, queries model and stores prediction in cache.
206 """
207 if not self.use_cache:
208 return self._call_model_uncached(attacked_text_list)
209 else:
210 uncached_list = []
211 for text in attacked_text_list:
212 if text in self._call_model_cache:
213 # Re-write value in cache. This moves the key to the top of the
214 # LRU cache and prevents the unlikely event that the text
215 # is overwritten when we store the inputs from `uncached_list`.
216 self._call_model_cache[text] = self._call_model_cache[text]
217 else:
218 uncached_list.append(text)
219 uncached_list = [
220 text
221 for text in attacked_text_list
222 if text not in self._call_model_cache
223 ]
224 outputs = self._call_model_uncached(uncached_list)
225 for text, output in zip(uncached_list, outputs):
226 self._call_model_cache[text] = output
227 all_outputs = [self._call_model_cache[text] for text in attacked_text_list]
228 return all_outputs
229
230 def extra_repr_keys(self):
231 attrs = []

Callers 2

get_output (Method) · 0.95
get_results (Method) · 0.95

Calls 1

_call_model_uncached (Method) · 0.95

Tested by

no test coverage detected