(optimizer, device_id)
| 113 | |
| 114 | |
def load_fsdp_optimizer(optimizer, device_id):
    """Move every tensor in ``optimizer``'s per-parameter state onto ``device_id``.

    Each tensor-valued entry of ``optimizer.state[param]`` is replaced in place
    by a copy on ``device_id``; non-tensor entries are left untouched.
    Parameters that have no optimizer state yet (e.g. before the first
    ``step()``, or parameters that never received a gradient) are skipped
    rather than being given a new, empty state entry.

    Args:
        optimizer: Optimizer whose ``param_groups`` / ``state`` are relocated
            (presumably a ``torch.optim.Optimizer`` wrapping FSDP params).
        device_id: Target device in any form accepted by ``torch.Tensor.to``
            (e.g. an int CUDA ordinal, ``"cuda:0"`` or a ``torch.device``).

    Returns:
        None. The optimizer state is modified in place.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            # For torch.optim.Optimizer, ``state`` is a defaultdict(dict):
            # ``optimizer.state[param]`` would insert an empty entry for a
            # stateless param. ``.get()`` keeps this function from changing
            # which params the optimizer considers stateful.
            state = optimizer.state.get(param)
            if not state:
                continue
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    # Reassigning an existing key does not resize the dict,
                    # so updating it while iterating items() is safe.
                    # Per the PyTorch docs, a host-to-device copy is only
                    # asynchronous when the source lives in pinned memory.
                    # NOTE(review): this also moves scalar entries such as
                    # Adam's ``step`` tensor, which PyTorch normally keeps on
                    # CPU for non-fused/non-capturable optimizers — confirm
                    # this is intended.
                    state[key] = value.to(device_id, non_blocking=True)
    # Release cached, unreferenced allocator blocks back to the driver
    # (a no-op when CUDA has never been initialized).
    torch.cuda.empty_cache()
no test coverage detected