hub / github.com/apache/tvm / SGD

Class SGD

python/tvm/relax/training/optimizer.py:244–353

Implements stochastic gradient descent. The function returned by `get_function()` is equivalent to the NumPy reference code given in the class docstring below.

Source (excerpt, lines 244–301)

class SGD(Optimizer):
    """Implements stochastic gradient descent.

    The returned function of `get_function()` is equivalent to the following numpy code:

    .. code-block:: python

        def SGD(param_tuple, grad_tuple, state_tuple):
            num_steps = state_tuple[0]
            param_tuple_new, state_tuple_new = [], []
            state_tuple_new.append(num_steps + 1)
            for i in range(len(param_tuple)):
                param = param_tuple[i]
                grad = grad_tuple[i]
                param_tuple_new.append(param - lr * (grad + weight_decay * param))
            return param_tuple_new, state_tuple_new

    Parameters
    ----------
    lr : float
        learning rate

    weight_decay : float
        weight decay (L2 penalty) (default: 0)
    """

    def __init__(self, lr: float, weight_decay: float = 0) -> None:
        super().__init__("SGD")
        self.lr = float(lr)
        self.weight_decay = float(weight_decay)

    def init(self, params: Var | list[Var]) -> "SGD":
        """Set the parameters, determine the dtype, and construct the initial state for the
        optimizer.

        The state of SGD is `(num_steps,)`.

        Parameters
        ----------
        params : Union[Var, List[Var]]
            The parameter or the list of parameters to optimize.

            Parameters should all be Vars of floating point Tensors, including float32, float64,
            float16, etc. Currently, all parameters should have the same dtype, and that dtype
            will be used as the dtype of the optimizer states.

        Returns
        -------
        self : SGD
            The SGD optimizer itself.
        """
        if not isinstance(params, list):
            params = [params]
        self._set_params_and_dtype(params)
        self.state = (
            # num_steps = 0
            tvm.runtime.tensor(np.zeros((), "int64")),
        )
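The update rule described in the class docstring can be tried out in plain NumPy. The following is a minimal sketch, not the TVM implementation: the function name `sgd_step` is hypothetical, and `lr` and `weight_decay` are passed as explicit arguments here, whereas the docstring's reference code treats them as free variables taken from the optimizer instance.

```python
import numpy as np


def sgd_step(param_tuple, grad_tuple, state_tuple, lr, weight_decay=0.0):
    """One SGD step mirroring the docstring's reference code.

    `sgd_step` is a hypothetical helper for illustration; lr and weight_decay
    are explicit arguments here rather than attributes of an optimizer object.
    """
    num_steps = state_tuple[0]
    param_tuple_new = []
    # The only optimizer state is the step counter.
    state_tuple_new = [num_steps + 1]
    for param, grad in zip(param_tuple, grad_tuple):
        # weight_decay * param is the gradient of (weight_decay / 2) * ||param||^2,
        # which is why the docstring calls it an L2 penalty.
        param_tuple_new.append(param - lr * (grad + weight_decay * param))
    return param_tuple_new, state_tuple_new


params = [np.array([1.0, 2.0], dtype="float32")]
grads = [np.array([0.5, 0.5], dtype="float32")]
state = [np.zeros((), "int64")]  # num_steps = 0, stored as a 0-d int64 array

new_params, new_state = sgd_step(params, grads, state, lr=0.1, weight_decay=0.0)
decayed_params, _ = sgd_step(params, grads, state, lr=0.1, weight_decay=0.1)
```

With `weight_decay=0.0` each parameter moves by `-lr * grad`; with a nonzero `weight_decay` the parameter is additionally pulled toward zero in proportion to its own value. The step counter is the single state entry, matching the `(num_steps,)` state that `init` constructs.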

Callers 8

test_execute_numeric (Function) · 0.90
test_load_export_params (Function) · 0.90
test_setting_error (Function) · 0.90
test_optimizer_error (Function) · 0.90
test_sgd_simple (Function) · 0.90
test_sgd_complex (Function) · 0.90
test_simple (Function) · 0.90
test_invalid_mod (Function) · 0.90

Calls

no outgoing calls

Tested by 8

test_execute_numeric (Function) · 0.72
test_load_export_params (Function) · 0.72
test_setting_error (Function) · 0.72
test_optimizer_error (Function) · 0.72
test_sgd_simple (Function) · 0.72
test_sgd_complex (Function) · 0.72
test_simple (Function) · 0.72
test_invalid_mod (Function) · 0.72