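Implements stochastic gradient descent.

The function returned by `get_function()` is equivalent to the following NumPy code:

.. code-block:: python

    def SGD(param_tuple, grad_tuple, state_tuple):
        num_steps = state_tuple[0]
        param_tuple_new, state_tuple_new = [], []
        state_tuple_new.append(num_steps + 1)
        for i in range(len(param_tuple)):
            param = param_tuple[i]
            grad = grad_tuple[i]
            param_tuple_new.append(param - lr * (grad + weight_decay * param))
        return param_tuple_new, state_tuple_new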
| 242 | |
| 243 | |
| 244 | class SGD(Optimizer): |
| 245 | """Implements stochastic gradient descent. |
| 246 | |
| 247 | The returned function of `get_function()` is equivalent to the following numpy code: |
| 248 | |
| 249 | .. code-block:: python |
| 250 | |
| 251 | def SGD(param_tuple, grad_tuple, state_tuple): |
| 252 | num_steps = state_tuple[0] |
| 253 | param_tuple_new, state_tuple_new = [], [] |
| 254 | state_tuple_new.append(num_steps + 1) |
| 255 | for i in range(len(param_tuple)): |
| 256 | param = param_tuple[i] |
| 257 | grad = grad_tuple[i] |
| 258 | param_tuple_new.append(param - lr * (grad + weight_decay * param)) |
| 259 | return param_tuple_new, state_tuple_new |
| 260 | |
| 261 | Parameters |
| 262 | ---------- |
| 263 | lr : float |
| 264 | learning rate |
| 265 | |
| 266 | weight_decay : float |
| 267 | weight decay (L2 penalty) (default: 0) |
| 268 | """ |
| 269 | |
def __init__(self, lr: float, weight_decay: float = 0) -> None:
    """Store the SGD hyperparameters on the optimizer instance.

    Parameters
    ----------
    lr : float
        Learning rate. Converted with ``float()`` before being stored.

    weight_decay : float
        Weight decay (L2 penalty) coefficient. Converted with ``float()``
        before being stored (default: 0).
    """
    # Hand the optimizer name to the base-class constructor.
    super().__init__("SGD")

    # Coerce eagerly so that ints and other float-convertible values are kept
    # as plain Python floats, and values that cannot be converted fail here.
    learning_rate = float(lr)
    self.lr = learning_rate

    decay = float(weight_decay)
    self.weight_decay = decay
| 274 | |
| 275 | def init(self, params: Var | list[Var]) -> "SGD": |
| 276 | """Set the parameters, determine the dtype, and construct the initial state for the |
| 277 | optimizer. |
| 278 | |
| 279 | The state of SGD is `(num_steps,)`. |
| 280 | |
| 281 | Parameters |
| 282 | ---------- |
| 283 | params : Union[Var, List[Var]] |
| 284 | The parameter or the list of parameters to optimize. |
| 285 | |
| 286 | Parameters should all be Vars of floating point Tensors, including float32, float64, |
| 287 | float16, etc. Currently, all parameters should have the same dtype, and that dtype |
| 288 | will be used as the dtype of the optimizer states. |
| 289 | |
| 290 | Returns |
| 291 | ------- |
| 292 | self : SGD |
| 293 | The SGD optimizer itself. |
| 294 | """ |
| 295 | if not isinstance(params, list): |
| 296 | params = [params] |
| 297 | self._set_params_and_dtype(params) |
| 298 | self.state = ( |
| 299 | # num_steps = 0 |
| 300 | tvm.runtime.tensor(np.zeros((), "int64")), |
| 301 | ) |
no outgoing calls