(module)
| 69 | |
| 70 | |
def offload_fsdp_grad(module):
    """Move every populated gradient of ``module`` to the CPU.

    Walks ``module.named_parameters()`` and, for each parameter whose
    ``.grad`` is set, rebinds ``.grad`` to a copy produced by
    ``.to("cpu", non_blocking=True)``. Parameters without a gradient are
    left untouched. Once the walk is finished, ``torch.cuda.empty_cache()``
    is called.

    Args:
        module: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs (presumably an FSDP-wrapped
            ``torch.nn.Module`` -- confirm against callers).

    Returns:
        None. The parameters' ``.grad`` attributes are updated in place.
    """
    for _name, tensor in module.named_parameters():
        if tensor.grad is None:
            # Nothing to move for this parameter.
            continue
        # Rebind directly, without a local alias for the old gradient, so the
        # previous tensor is unreferenced by the time the cache is emptied.
        # NOTE(review): the copy is requested with non_blocking=True; confirm
        # callers synchronize before reading the CPU gradients.
        tensor.grad = tensor.grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
| 76 | |
| 77 | |
| 78 | def load_fsdp_grad(module, device_id): |
no test coverage detected