
Method __init__

tinygrad/nn/optim.py:11–22
(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM)

Source from the content-addressed store, hash-verified

 9    Base class for all optimizers.
10    """
11    def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
12      if lr < 0: raise ValueError(f"Invalid learning rate: {lr}")
13      self.params: list[Tensor] = dedup([x for x in params if x.is_param])
14      assert len(self.params) != 0, "optimizer must have at least one param"
15      self.buffers: list[Tensor] = dedup([x for x in params if not x.is_param])  # buffers are still realized
16      self.device = device or self.params[0].device
17      self.param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
18      self.fused = fused
19      # store lr in at least float32 precision
20      self.lr = Tensor(lr if getenv("CONST_LR") else [lr], device=self.device,
21                       dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
22      if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
23
24    def _new_optim_param(self) -> list[Tensor]:
25      if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=self.param_dtype, device=self.device)]
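
The bookkeeping on lines 13 to 15 and line 22 can be tried out without tinygrad installed. What follows is a minimal sketch, not the library's code: `FakeTensor` is a hypothetical stand-in exposing only `is_param` and `numel()`, `split_params` and `fused_offsets` are illustrative names, and the order-preserving dedup via `dict.fromkeys` is an assumption about how `dedup` behaves.

```python
import itertools
from dataclasses import dataclass

@dataclass(eq=False)  # eq=False keeps identity-based hashing, so instances can be dict keys
class FakeTensor:
  """Hypothetical stand-in for Tensor: only the attributes __init__ touches."""
  name: str
  size: int
  is_param: bool = True
  def numel(self) -> int: return self.size

def split_params(tensors):
  """Mirror lines 13-15: dedup, then separate trainable params from buffers."""
  # Assumption: dedup keeps first occurrences in order, as dict.fromkeys does.
  params = list(dict.fromkeys(t for t in tensors if t.is_param))
  if len(params) == 0: raise ValueError("optimizer must have at least one param")
  buffers = list(dict.fromkeys(t for t in tensors if not t.is_param))
  return params, buffers

def fused_offsets(params):
  """Mirror line 22: running element counts, starting at 0."""
  return list(itertools.accumulate(params, lambda acc, p: acc + p.numel(), initial=0))

w, b, stats = FakeTensor("w", 6), FakeTensor("b", 3), FakeTensor("stats", 4, is_param=False)
params, buffers = split_params([w, b, w, stats])  # w is passed twice on purpose
offsets = fused_offsets(params)
print([p.name for p in params], [t.name for t in buffers], offsets)
# prints: ['w', 'b'] ['stats'] [0, 6, 9]
```

The last offset is the total element count, which is the length `_new_optim_param` passes to `Tensor.zeros` when `fused` is set. Consecutive offsets presumably delimit each parameter's slice of that flat buffer, though the excerpt does not show where they are consumed.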

Callers

nothing calls this directly

Calls 6

dedup · Function · 0.90
to_dtype · Function · 0.90
getenv · Function · 0.90
Tensor · Class · 0.90
least_upper_dtype · Function · 0.90
numel · Method · 0.80
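
Of the callees listed above, `getenv` drives two choices in `__init__`: the optimizer state dtype (`OPTIM_DTYPE`) and whether the learning rate is stored as a scalar or a one-element list (`CONST_LR`). The sketch below illustrates that branching with the standard library only. `getenv_sketch` and `lr_payload` are hypothetical names, and the cast-to-default-type behaviour is an assumption about `getenv`, not something the excerpt confirms.

```python
import os

def getenv_sketch(key: str, default=0):
  """Assumed behaviour: read an environment variable and cast it to the default's type."""
  return type(default)(os.getenv(key, default))

def lr_payload(lr: float):
  """Mirror lines 12 and 20: reject negative rates, then pick scalar or 1-element list."""
  if lr < 0: raise ValueError(f"Invalid learning rate: {lr}")
  return lr if getenv_sketch("CONST_LR") else [lr]

os.environ.pop("CONST_LR", None)   # unset: falls back to the default 0
print(lr_payload(0.01))            # prints: [0.01]
os.environ["CONST_LR"] = "1"       # set: the string "1" is cast to int 1
print(lr_payload(0.01))            # prints: 0.01
```

In the real constructor the chosen value is wrapped in a `Tensor` whose dtype is at least float32, per the comment on line 19; that step is left out here because it needs tinygrad itself.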

Tested by

no test coverage detected