(optimizer)
| 103 | |
| 104 | |
def offload_fsdp_optimizer(optimizer: torch.optim.Optimizer) -> None:
    """Move every tensor held in the optimizer's per-parameter state to CPU.

    Walks all parameters in ``optimizer.param_groups`` and replaces each
    tensor-valued state entry with a copy on the CPU. Non-tensor state
    values are left untouched. Parameters that have no optimizer state yet
    are skipped, and no state entry is created for them.

    Args:
        optimizer: The optimizer whose state tensors should be offloaded.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            # ``optimizer.state`` is a defaultdict: indexing it with a
            # parameter that has no state would insert an empty dict as a
            # side effect. ``.get`` looks the entry up without creating it.
            state = optimizer.state.get(param)
            if not state:
                continue
            # Only existing keys are reassigned, so the dict size does not
            # change while it is being iterated.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    # NOTE(review): a non_blocking device-to-host copy may
                    # return before the data has landed on the host; callers
                    # that read the offloaded state right away presumably
                    # need a synchronization first - confirm.
                    state[key] = value.to("cpu", non_blocking=True)
    # Ask the CUDA caching allocator to release unused cached blocks.
    torch.cuda.empty_cache()
| 113 | |
| 114 | |
| 115 | def load_fsdp_optimizer(optimizer, device_id): |
no test coverage detected