Move any `torch.nn.Module` models that are part of the Attack to the GPU.
Signature: `cuda_(self)`
| 206 | to_cpu(self) |
| 207 | |
def cuda_(self):
    """Move any `torch.nn.Module` models that are part of Attack to GPU.

    Performs a depth-first walk starting at ``self``:

    * a ``torch.nn.Module`` is moved with
      ``.to(textattack.shared.utils.device)`` and not descended into;
    * an Attack component (Attack, GoalFunction, Transformation,
      SearchMethod, Constraint, PreTransformationConstraint, ModelWrapper)
      has each value in its ``__dict__`` visited;
    * a list or tuple has only its Transformation / Constraint /
      PreTransformationConstraint items visited.

    Object ids are recorded so that an object reachable through several
    paths (or through a cycle) is processed at most once.
    """
    seen_ids = set()

    # Types whose instance attributes are walked recursively.
    component_types = (
        Attack,
        GoalFunction,
        Transformation,
        SearchMethod,
        Constraint,
        PreTransformationConstraint,
        ModelWrapper,
    )
    # Only these element types are followed inside lists/tuples.
    sequence_item_types = (Transformation, Constraint, PreTransformationConstraint)

    def _move(node):
        seen_ids.add(id(node))

        if isinstance(node, torch.nn.Module):
            node.to(textattack.shared.utils.device)
            return

        if isinstance(node, component_types):
            for child in node.__dict__.values():
                if id(child) not in seen_ids:
                    _move(child)
            return

        if isinstance(node, (list, tuple)):
            for child in node:
                if id(child) not in seen_ids and isinstance(
                    child, sequence_item_types
                ):
                    _move(child)

    _move(self)
| 240 | |
| 241 | def get_indices_to_order(self, current_text, **kwargs): |
| 242 | """Applies ``pre_transformation_constraints`` to ``text`` to get all |
no outgoing calls
no test coverage detected