Memory-efficient AdamW optimizer that keeps parameters and gradients on the GPU but stores optimizer states on the CPU when enabled. When disabled, it behaves exactly like standard AdamW.
| 5 | |
| 6 | |
| 7 | class MemoryEfficientAdamW(AdamW): |
| 8 | """ |
| 9 | Memory Efficient AdamW optimizer that keeps parameters and gradients on GPU |
| 10 | but optimizer states on CPU when enabled. |
| 11 | When disabled, behaves exactly like standard AdamW. |
| 12 | """ |
| 13 | |
def __init__(
    self,
    params,
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2,
    amsgrad=False,
    pin_memory=True,
    enabled=True,
):
    """Build the optimizer and remember the CPU-offloading switches.

    The standard AdamW hyperparameters are forwarded unchanged to the
    ``AdamW`` base-class constructor; the two extra options are stored
    as plain instance attributes.

    Args:
        params: Iterable of tensors or param-group dicts, as accepted by
            ``AdamW``.
        lr: Learning rate.
        betas: Pair of coefficients for the running averages of the
            gradient and of its square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient.
        amsgrad: Whether to use the AMSGrad variant.
        pin_memory: Stored on ``self.pin_memory``. Presumably selects
            pinned host memory for CPU-resident optimizer state.
            NOTE(review): confirm against the full body of ``step``.
        enabled: Stored on ``self.enabled``. When falsy, ``step`` hands
            off to the stock ``AdamW.step`` implementation.
    """
    # Only the AdamW hyperparameters are passed to the base class; the
    # offloading switches are kept out of that call.
    adamw_kwargs = {
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
        "amsgrad": amsgrad,
    }
    super(MemoryEfficientAdamW, self).__init__(params, **adamw_kwargs)

    # self.enabled is checked at the start of step() to pick the code path.
    self.enabled = enabled
    self.pin_memory = pin_memory
| 35 | |
| 36 | @torch.no_grad() |
| 37 | def step(self, closure=None): |
| 38 | """Performs a single optimization step.""" |
| 39 | if not self.enabled: |
| 40 | # Use the parent AdamW implementation when disabled |
| 41 | return super(MemoryEfficientAdamW, self).step(closure) |
| 42 | |
| 43 | loss = None |
| 44 | if closure is not None: |
| 45 | with torch.enable_grad(): |
| 46 | loss = closure() |
| 47 | |
| 48 | for group in self.param_groups: |
| 49 | params_with_grad = [] |
| 50 | grads = [] |
| 51 | exp_avgs = [] |
| 52 | exp_avg_sqs = [] |
| 53 | max_exp_avg_sqs = [] |
| 54 | state_steps = [] |
| 55 | beta1, beta2 = group["betas"] |
| 56 | |
| 57 | for p in group["params"]: |
| 58 | if p.grad is None: |
| 59 | continue |
| 60 | |
| 61 | params_with_grad.append(p) |
| 62 | grads.append(p.grad) |
| 63 | |
| 64 | # Initialize state if needed |