(module, device_id)
| 76 | |
| 77 | |
def load_fsdp_grad(module, device_id):
    """Move every existing gradient of ``module`` onto ``device_id``.

    Parameters whose ``.grad`` is ``None`` are skipped. Each present gradient
    is replaced by the result of ``grad.to(device_id, non_blocking=True)``,
    after which ``torch.cuda.empty_cache()`` is called.

    Args:
        module: Object exposing ``named_parameters()``; presumably a
            ``torch.nn.Module`` / FSDP-wrapped module -- confirm with callers.
        device_id: Destination device, forwarded unchanged to ``Tensor.to``.

    Returns:
        None. ``param.grad`` is reassigned on the parameters in place.
    """
    for _name, param in module.named_parameters():
        grad = param.grad
        if grad is None:
            # Nothing to move for parameters that have no gradient.
            continue
        param.grad = grad.to(device_id, non_blocking=True)
    # NOTE(review): indentation was lost in the captured source; the cache
    # flush is placed once after the loop rather than per parameter --
    # confirm against the upstream file.
    torch.cuda.empty_cache()
| 83 | |
| 84 | |
| 85 | def offload_fsdp_param_and_grad(module, offload_grad=False): |
nothing calls this directly
no test coverage detected