(module, device_id, load_grad=False)
| 93 | |
| 94 | |
def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    """Move the parameters of ``module`` (and optionally their gradients) to ``device_id``.

    For every parameter yielded by ``module.named_parameters()``, in order:

    * if the parameter carries a ``_local_shard`` attribute, that tensor is
      replaced by its copy on ``device_id``;
    * the parameter's ``.data`` is replaced by its copy on ``device_id``;
    * when ``load_grad`` is truthy and the parameter has a gradient, ``.grad``
      is replaced by its copy on ``device_id``.

    Every transfer uses ``Tensor.to(device_id, non_blocking=True)``. After all
    parameters are processed, ``torch.cuda.empty_cache()`` is called.

    Args:
        module: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        device_id: Target device, passed positionally to ``Tensor.to``.
        load_grad: When truthy, existing gradients are moved as well.
    """

    def _to_target(tensor):
        # Single place for the transfer call so every tensor is moved the same way.
        return tensor.to(device_id, non_blocking=True)

    for _name, weight in module.named_parameters():
        if hasattr(weight, "_local_shard"):
            weight._local_shard = _to_target(weight._local_shard)
        # Move the data before the gradient so the new grad's device matches the param's.
        weight.data = _to_target(weight.data)
        # Short-circuit: ``.grad`` is only inspected when gradients were requested.
        has_grad_to_move = load_grad and weight.grad is not None
        if has_grad_to_move:
            weight.grad = _to_target(weight.grad)

    torch.cuda.empty_cache()
| 103 | |
| 104 | |
| 105 | def offload_fsdp_optimizer(optimizer): |
no test coverage detected