(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM)
| 151 | - Paper: https://arxiv.org/abs/1904.00962 |
| 152 | """ |
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
  """Validate `weight_decay`, initialise the base optimizer, and allocate optimizer state.

  `params`, `lr`, `device` and `fused` are forwarded unchanged to the parent
  constructor. `b1`, `b2`, `eps`, `weight_decay` (stored as `self.wd`) and
  `adam` are stored on the instance as-is for later use.

  Raises:
    ValueError: if `weight_decay` is negative.
  """
  # Reject bad input before the parent constructor touches any state.
  if weight_decay < 0:
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")

  super().__init__(params, lr, device, fused)

  # Hyperparameters, stored verbatim.
  self.b1 = b1
  self.b2 = b2
  self.eps = eps
  self.wd = weight_decay
  self.adam = adam  # NOTE(review): presumably switches to a plain Adam-style update -- confirm in _step

  def _unit_scalar():
    # One-element float32 tensor of ones on the optimizer's device, flagged via is_param_(False).
    return Tensor.ones((1,), dtype=dtypes.float32, device=self.device).is_param_(False)

  # Two independent one-element tensors, one per beta.
  # NOTE(review): presumably running powers of b1/b2 for bias correction -- confirm in _step
  self.b1_t, self.b2_t = _unit_scalar(), _unit_scalar()

  # Two optimizer-state buffers from the base-class helper.
  # NOTE(review): presumably first/second moment estimates -- confirm in _step
  self.m = self._new_optim_param()
  self.v = self._new_optim_param()
| 160 | |
| 161 | def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: |
| 162 | ret = [] |
nothing calls this directly
no test coverage detected