hub / github.com/policy-gradient/GRPO-Zero / MemoryEfficientAdamW

Class MemoryEfficientAdamW

optimizer.py:7–170

Memory-efficient AdamW optimizer. When enabled, parameters and gradients stay on the GPU while the optimizer states (the per-parameter moment estimates AdamW maintains) are stored on the CPU. When disabled, it behaves exactly like standard AdamW.
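The offloading idea can be sketched for a single tensor. This is a minimal illustration under stated assumptions, not the repository's implementation. The helpers `init_cpu_state` and `offloaded_adamw_step` are hypothetical. The excerpt does not show whether the class computes the update on the GPU or on the CPU; this sketch copies the CPU-resident moments to the parameter's device, applies the standard AdamW update, and writes the moments back. Host buffers are pinned only when CUDA is available, loosely mirroring the class's `pin_memory` option.

```python
import math

import torch


def init_cpu_state(param: torch.Tensor) -> dict:
    """Hypothetical helper: allocate AdamW moment buffers on the CPU for one parameter."""
    exp_avg = torch.zeros_like(param, device="cpu")
    exp_avg_sq = torch.zeros_like(param, device="cpu")
    # Pinned (page-locked) host memory speeds up host<->GPU copies; it requires CUDA.
    if torch.cuda.is_available():
        exp_avg, exp_avg_sq = exp_avg.pin_memory(), exp_avg_sq.pin_memory()
    return {"step": 0, "exp_avg": exp_avg, "exp_avg_sq": exp_avg_sq}


@torch.no_grad()
def offloaded_adamw_step(param, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Hypothetical helper: one AdamW step for `param` whose moments live in CPU `state`."""
    beta1, beta2 = betas
    # Stage the CPU-resident moments on the parameter's device (no copy if already on CPU).
    exp_avg = state["exp_avg"].to(param.device, non_blocking=True)
    exp_avg_sq = state["exp_avg_sq"].to(param.device, non_blocking=True)
    state["step"] += 1
    step = state["step"]
    grad = param.grad

    # Decoupled weight decay: shrink the parameter directly, not through the gradient.
    param.mul_(1 - lr * weight_decay)
    # Exponential moving averages of the gradient and of the squared gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias-corrected update.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

    # If the moments were copied to another device, write them back to CPU storage.
    if exp_avg.device != state["exp_avg"].device:
        state["exp_avg"].copy_(exp_avg)
        state["exp_avg_sq"].copy_(exp_avg_sq)


# Usage: compare against torch.optim.AdamW on identical gradients.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
p = torch.nn.Parameter(torch.randn(4, 3, device=device))
ref = torch.nn.Parameter(p.detach().clone())
ref_opt = torch.optim.AdamW([ref], lr=1e-2)
state = init_cpu_state(p)

for _ in range(5):
    g = torch.randn_like(p)
    p.grad = g.clone()
    ref.grad = g.clone()
    offloaded_adamw_step(p, state, lr=1e-2)
    ref_opt.step()
```

Keeping the moments on the CPU trades GPU memory for a host-device transfer in each direction on every step. Pinning the host buffers is what allows the host-to-GPU copies to be issued asynchronously.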

```python
# Imports assumed; lines 1–4 of optimizer.py are not part of this excerpt.
import torch
from torch.optim import AdamW


class MemoryEfficientAdamW(AdamW):
    """
    Memory Efficient AdamW optimizer that keeps parameters and gradients on GPU
    but optimizer states on CPU when enabled.
    When disabled, behaves exactly like standard AdamW.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        pin_memory=True,
        enabled=True,
    ):
        super(MemoryEfficientAdamW, self).__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        self.pin_memory = pin_memory
        self.enabled = enabled

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        if not self.enabled:
            # Use the parent AdamW implementation when disabled
            return super(MemoryEfficientAdamW, self).step(closure)

        loss = None
        if closure is not None:
            # Re-enable autograd so the closure can recompute the loss and gradients
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]

            for p in group["params"]:
                # Parameters without gradients are skipped this step
                if p.grad is None:
                    continue

                params_with_grad.append(p)
                grads.append(p.grad)

                # Initialize state if needed
                ...  # excerpt ends here; the rest of step() is in optimizer.py
```
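For reference, this is the per-parameter update that `torch.optim.AdamW` documents. The mapping to the names in `step()` is an interpretation, not something the excerpt states: $\gamma$ = `lr`, $\lambda$ = `weight_decay`, $(\beta_1, \beta_2)$ = `betas`, $\epsilon$ = `eps`, $g$ = `grads`, $m$ = `exp_avgs`, $v$ = `exp_avg_sqs`, $v^{\max}$ = `max_exp_avg_sqs`, and $t$ = the entry in `state_steps`. The enabled branch presumably computes this same update with its state held on the CPU, but that part of `step()` is not shown here.

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat v_t = \frac{v_t}{1 - \beta_2^t} \\
&\text{with amsgrad: } v_t^{\max} = \max\left(v_{t-1}^{\max}, v_t\right), \quad \hat v_t = \frac{v_t^{\max}}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1}\,(1 - \gamma\lambda) - \gamma\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
$$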

Callers (1)

main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected