AdamW with weight decay applied only to >=2D parameters (matrices), not to biases / norms / 1D params. Standard GPT recipe.
(
model: nn.Module,
lr: float,
weight_decay: float,
betas: tuple[float, float] = (0.9, 0.95),
)
| 15 | |
| 16 | |
def configure_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    betas: tuple[float, float] = (0.9, 0.95),
) -> torch.optim.AdamW:
    """Build an AdamW optimizer that decays only parameters with >= 2 dimensions.

    Weight decay is applied to matrices (any parameter whose ``dim()`` is at
    least 2) and disabled for biases / norms / other 1D params. Standard GPT
    recipe.

    Args:
        model: Module whose parameters are to be optimized. Parameters with
            ``requires_grad`` set to False are left out of the optimizer.
        lr: Learning rate passed to AdamW.
        weight_decay: Decay coefficient used for the >=2D parameter group.
        betas: AdamW beta coefficients.

    Returns:
        An AdamW optimizer with two parameter groups, in this order: the
        decayed (>=2D) parameters, then the non-decayed (<2D) parameters
        with ``weight_decay`` fixed at 0.0.
    """
    # Frozen parameters never receive gradients, so they are excluded up front.
    trainable = [p for p in model.parameters() if p.requires_grad]

    # Dimensionality alone decides which group a parameter belongs to.
    decay = [p for p in trainable if p.dim() >= 2]
    no_decay = [p for p in trainable if p.dim() < 2]

    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=betas)
| 38 | |
| 39 | |
| 40 | def cosine_lr(step: int, *, warmup_steps: int, max_steps: int, lr: float, min_lr: float) -> float: |