| 18 | |
| 19 | |
def copy_to_device(obj, device):
    """Recursively copy every tensor found in ``obj`` onto ``device``.

    Lists, tuples and dicts are rebuilt with each element (or dict value)
    processed recursively; dict keys are reused as-is. Any other
    non-tensor object is returned unchanged.

    Args:
        obj: A tensor, a list/tuple/dict (possibly nested), or any other
            object.
        device: Target device, forwarded to ``Tensor.to``.

    Returns:
        For a tensor: a detached copy on ``device`` whose
        ``requires_grad`` flag matches the input tensor's. For a
        list/tuple/dict: a new container of the same kind holding the
        processed items. Otherwise: ``obj`` itself.
    """
    if torch.is_tensor(obj):
        # Notice:
        # When in a no_grad context, requires_grad is False after the
        # movement, so the flag is restored from the source tensor.
        return obj.to(device).detach().requires_grad_(obj.requires_grad)

    if isinstance(obj, list):
        return [copy_to_device(item, device) for item in obj]

    if isinstance(obj, tuple):
        return tuple(copy_to_device(item, device) for item in obj)

    if isinstance(obj, dict):
        return {
            key: copy_to_device(value, device) for key, value in obj.items()
        }

    # Not a tensor or a supported container: hand it back untouched.
    return obj
| 35 | |
| 36 | |
| 37 | class CheckpointFunction(torch.autograd.Function): |