Function test_sgd_simple()

github.com/apache/tvm · tests/python/relax/test_training_optimizer.py:58–93

```python
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.relax.training.optimizer import SGD
from tvm.script import relax as R


def test_sgd_simple():
    # Two parameters to optimize: a 3x3 matrix and a length-3 vector.
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))
    # Generate the Relax function that performs one SGD step with lr = 0.01.
    sgd = SGD(0.01).init([x, y]).get_function()

    # Hand-written IR that the generated function must match exactly.
    @R.function
    def sgd_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(R.Tensor((), "int64")),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(R.Tensor((), "int64")),
    ):
        R.func_attr({"global_symbol": "SGD"})
        # block 0
        with R.dataflow():
            # The only optimizer state is the step counter; increment it by one.
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))
            # x_new = x - lr * x_grad
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            lv: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), x_grad)
            x_new: R.Tensor((3, 3), "float32") = R.subtract(x, lv)
            # y_new = y - lr * y_grad
            y: R.Tensor((3,), "float32") = params[1]
            y_grad: R.Tensor((3,), "float32") = gradients[1]
            lv1: R.Tensor((3,), "float32") = R.multiply(R.const(0.01, "float32"), y_grad)
            y_new: R.Tensor((3,), "float32") = R.subtract(y, lv1)
            # Repack the updated parameters and optimizer state as tuples.
            params_new: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")) = (
                x_new,
                y_new,
            )
            optim_states_new: R.Tuple(R.Tensor((), "int64")) = (num_steps_new,)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

    assert_structural_equal(sgd, sgd_expected)
```
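The expected function encodes plain SGD with no momentum or weight decay: each parameter is updated as `p - lr * g`, and the single int64 entry in `optim_states` (the step counter) is incremented by one. The NumPy sketch below reproduces that arithmetic so the numbers can be checked by hand. It assumes NumPy arrays stand in for the Relax tensors; `sgd_step` is a hypothetical helper written for illustration and is not part of TVM's API.

```python
import numpy as np

LR = 0.01  # learning rate used in the test: SGD(0.01)


def sgd_step(params, gradients, optim_states, lr=LR):
    """Hypothetical reference for one plain-SGD step (not a TVM API).

    params, gradients: tuples of float32 arrays with matching shapes.
    optim_states: a 1-tuple holding the int64 step counter.
    Returns (params_new, optim_states_new), mirroring the Relax return value.
    """
    (num_steps,) = optim_states
    num_steps_new = num_steps + np.int64(1)
    # p_new = p - lr * g, computed in float32 like R.const(0.01, "float32")
    params_new = tuple(p - np.float32(lr) * g for p, g in zip(params, gradients))
    return params_new, (num_steps_new,)


# Usage: shapes match the test's x (3x3) and y (3,).
x = np.ones((3, 3), dtype=np.float32)
y = np.zeros((3,), dtype=np.float32)
gx = np.full((3, 3), 2.0, dtype=np.float32)
gy = np.full((3,), -1.0, dtype=np.float32)
(x_new, y_new), (steps,) = sgd_step((x, y), (gx, gy), (np.int64(0),))
print(x_new[0, 0], y_new[0], steps)
```

Keeping parameters and optimizer state as explicit tuples in and out, as the Relax function does, makes the step a pure function: the caller owns the state and feeds each step's outputs into the next call.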

Callers

nothing calls this directly

Calls (5)

SGD · class · 0.90
assert_structural_equal · function · 0.90
Tensor · method · 0.80
get_function · method · 0.45
init · method · 0.45

Tested by

no test coverage detected
