MCPcopy
hub / github.com/QData/TextAttack / cuda_

Method cuda_

textattack/attack.py:208–239  ·  view source on GitHub ↗

Move any `torch.nn.Module` models that are part of Attack to GPU.

(self)

Source from the content-addressed store, hash-verified

206 to_cpu(self)
207
def cuda_(self):
    """Move every ``torch.nn.Module`` reachable from this Attack onto the GPU.

    Performs a recursive depth-first walk starting at ``self``. When a
    ``torch.nn.Module`` is found, ``.to(textattack.shared.utils.device)`` is
    called on it and its internals are not walked further. Attack components
    (attack, goal function, transformation, search method, constraints,
    model wrapper) have all of their instance attributes searched. Lists and
    tuples are searched only for transformation/constraint elements. The
    ``id`` of every visited object is recorded, so objects that are shared
    or referenced cyclically are processed at most once.
    """
    seen_ids = set()

    # Component types whose instance attributes may hold models.
    container_types = (
        Attack,
        GoalFunction,
        Transformation,
        SearchMethod,
        Constraint,
        PreTransformationConstraint,
        ModelWrapper,
    )
    # Element types that are followed when the walk hits a list/tuple;
    # anything else inside a sequence is ignored.
    sequence_item_types = (Transformation, Constraint, PreTransformationConstraint)

    def _move(node):
        seen_ids.add(id(node))

        if isinstance(node, torch.nn.Module):
            # The return value of `.to` is discarded; the module itself is
            # expected to be updated.
            node.to(textattack.shared.utils.device)
            return

        if isinstance(node, container_types):
            for attr_value in node.__dict__.values():
                if id(attr_value) not in seen_ids:
                    _move(attr_value)
            return

        if isinstance(node, (list, tuple)):
            for element in node:
                if id(element) in seen_ids:
                    continue
                if isinstance(element, sequence_item_types):
                    _move(element)

    _move(self)
240
241 def get_indices_to_order(self, current_text, **kwargs):
242 """Applies ``pre_transformation_constraints`` to ``text`` to get all

Callers 2

_attack (Method) · 0.80
attack_from_queue (Function) · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected