hub / github.com/apache/tvm / test_sgd_complex

Function test_sgd_complex

tests/python/relax/test_training_optimizer.py:96–134
()

def test_sgd_complex():
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))
    sgd = SGD(0.01, 0.02).init([x, y]).get_function()

    @R.function
    def sgd_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(R.Tensor((), "int64")),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(R.Tensor((), "int64")),
    ):
        R.func_attr({"global_symbol": "SGD"})
        with R.dataflow():
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            lv: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.02, "float32"), x)
            x_grad_new: R.Tensor((3, 3), "float32") = R.add(lv, x_grad)
            lv1: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), x_grad_new)
            x_new: R.Tensor((3, 3), "float32") = R.subtract(x, lv1)
            y: R.Tensor((3,), "float32") = params[1]
            y_grad: R.Tensor((3,), "float32") = gradients[1]
            lv2: R.Tensor((3,), "float32") = R.multiply(R.const(0.02, "float32"), y)
            y_grad_new: R.Tensor((3,), "float32") = R.add(lv2, y_grad)
            lv3: R.Tensor((3,), "float32") = R.multiply(R.const(0.01, "float32"), y_grad_new)
            y_new: R.Tensor((3,), "float32") = R.subtract(y, lv3)
            params_new: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")) = (
                x_new,
                y_new,
            )
            optim_states_new: R.Tuple(R.Tensor((), "int64")) = (num_steps_new,)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

    assert_structural_equal(sgd, sgd_expected)
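The arithmetic that the expected function encodes can be sketched in plain Python. This is an illustration only, not TVM code: `sgd_step` is a hypothetical helper, tensors are flattened to lists of floats, and it assumes that 0.01 is the learning rate and 0.02 the weight decay, which is what the constants in the dataflow block suggest.

```python
def sgd_step(params, grads, num_steps, lr=0.01, weight_decay=0.02):
    """One SGD step with weight decay on flat lists of floats.

    Hypothetical helper for illustration; it follows the same order of
    operations as the dataflow block in sgd_expected.
    """
    new_params = []
    for p, g in zip(params, grads):
        # lv / x_grad_new: add the weight-decay term to the raw gradient
        g_new = weight_decay * p + g
        # lv1 / x_new: scale by the learning rate and subtract from the parameter
        new_params.append(p - lr * g_new)
    # num_steps_new: the optimizer state is just a step counter
    return new_params, num_steps + 1


# Usage: two scalar "parameters" standing in for x and y
params, steps = sgd_step([1.0, -2.0], [0.5, 0.25], num_steps=0)
print(params, steps)
```

Each parameter is updated independently, which is why the expected function repeats the same four operations once for `x` and once for `y`.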

Callers

nothing calls this directly

Calls 5

SGD (Class) · 0.90
assert_structural_equal (Function) · 0.90
Tensor (Method) · 0.80
get_function (Method) · 0.45
init (Method) · 0.45

Tested by

no test coverage detected
