| 9 | Base class for all optimizers. |
| 10 | """ |
def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
    """
    Set up the state shared by every optimizer.

    `params` is split on each tensor's `is_param` flag: tensors with the flag
    set are stored (deduplicated) in `self.params`, the rest in `self.buffers`.

    Args:
        params: tensors handed to the optimizer; at least one must have
            `is_param` set.
        lr: learning rate, must not be negative.
        device: device for the optimizer's own tensors; when falsy, the
            device of the first parameter is used.
        fused: when truthy, cumulative element offsets of the parameters are
            precomputed in `self.pos_params`.

    Raises:
        ValueError: if `lr` is negative.
        AssertionError: if no tensor in `params` has `is_param` set.
    """
    if lr < 0:
        raise ValueError(f"Invalid learning rate: {lr}")

    trainable = [t for t in params if t.is_param]
    self.params: list[Tensor] = dedup(trainable)
    assert self.params, "optimizer must have at least one param"

    # buffers are still realized
    non_trainable = [t for t in params if not t.is_param]
    self.buffers: list[Tensor] = dedup(non_trainable)

    # an explicit device wins; otherwise follow the first parameter
    self.device = device or self.params[0].device
    # dtype name comes from the OPTIM_DTYPE environment setting ("float32" if unset)
    self.param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
    self.fused = fused

    # store lr in at least float32 precision; with CONST_LR set the value is
    # passed as a bare scalar, otherwise wrapped in a one-element list
    lr_data = lr if getenv("CONST_LR") else [lr]
    lr_dtype = least_upper_dtype(dtypes.default_float, dtypes.float32)
    self.lr = Tensor(lr_data, device=self.device, dtype=lr_dtype)

    if self.fused:
        # running totals of element counts: starts at 0, last entry is the
        # total number of elements across all params
        sizes = (p.numel() for p in self.params)
        self.pos_params = list(itertools.accumulate(sizes, initial=0))
| 23 | |
| 24 | def _new_optim_param(self) -> list[Tensor]: |
| 25 | if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=self.param_dtype, device=self.device)] |