Move any `torch.nn.Module` models that are part of Attack to CPU.
Signature: `cpu_(self)`
| 173 | constraint.clear_cache() |
| 174 | |
| 175 | def cpu_(self): |
| 176 | """Move any `torch.nn.Module` models that are part of Attack to CPU.""" |
| 177 | visited = set() |
| 178 | |
| 179 | def to_cpu(obj): |
| 180 | visited.add(id(obj)) |
| 181 | if isinstance(obj, torch.nn.Module): |
| 182 | obj.cpu() |
| 183 | elif isinstance( |
| 184 | obj, |
| 185 | ( |
| 186 | Attack, |
| 187 | GoalFunction, |
| 188 | Transformation, |
| 189 | SearchMethod, |
| 190 | Constraint, |
| 191 | PreTransformationConstraint, |
| 192 | ModelWrapper, |
| 193 | ), |
| 194 | ): |
| 195 | for key in obj.__dict__: |
| 196 | s_obj = obj.__dict__[key] |
| 197 | if id(s_obj) not in visited: |
| 198 | to_cpu(s_obj) |
| 199 | elif isinstance(obj, (list, tuple)): |
| 200 | for item in obj: |
| 201 | if id(item) not in visited and isinstance( |
| 202 | item, (Transformation, Constraint, PreTransformationConstraint) |
| 203 | ): |
| 204 | to_cpu(item) |
| 205 | |
| 206 | to_cpu(self) |
| 207 | |
| 208 | def cuda_(self): |
| 209 | """Move any `torch.nn.Module` models that are part of Attack to GPU.""" |