MCPcopy
hub / github.com/QData/TextAttack / cpu_

Method cpu_

textattack/attack.py:175–206  ·  view source on GitHub ↗

Move any `torch.nn.Module` models that are part of the Attack to the CPU.

(self)

Source from the content-addressed store, hash-verified

173 constraint.clear_cache()
174
def cpu_(self):
    """Move every ``torch.nn.Module`` reachable from this Attack to CPU.

    The attack's object graph is walked depth-first starting at ``self``:

    * a ``torch.nn.Module`` has ``.cpu()`` called on it and is not
      descended into further;
    * an ``Attack``, ``GoalFunction``, ``Transformation``, ``SearchMethod``,
      ``Constraint``, ``PreTransformationConstraint`` or ``ModelWrapper``
      has each of its instance attributes (``__dict__`` values) visited;
    * a ``list`` or ``tuple`` has only its ``Transformation``,
      ``Constraint`` and ``PreTransformationConstraint`` elements visited.

    Anything else (dicts, other containers, plain values) is ignored.
    Object ids already seen are skipped, so shared or cyclic references
    are handled once.
    """
    # Component types whose attributes may hold models and are searched.
    component_types = (
        Attack,
        GoalFunction,
        Transformation,
        SearchMethod,
        Constraint,
        PreTransformationConstraint,
        ModelWrapper,
    )
    # Only elements of these types are followed inside a list/tuple.
    sequence_member_types = (
        Transformation,
        Constraint,
        PreTransformationConstraint,
    )
    seen_ids = set()

    def _move(node):
        seen_ids.add(id(node))

        if isinstance(node, torch.nn.Module):
            node.cpu()
            return

        if isinstance(node, component_types):
            for child in node.__dict__.values():
                # Checked lazily: a sibling's subtree may already have
                # reached this child.
                if id(child) in seen_ids:
                    continue
                _move(child)
            return

        if isinstance(node, (list, tuple)):
            for member in node:
                if id(member) in seen_ids:
                    continue
                if isinstance(member, sequence_member_types):
                    _move(member)

    _move(self)
207
208 def cuda_(self):
209 """Move any `torch.nn.Module` models that are part of Attack to GPU."""

Callers 1

_attack_parallel · Method · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected