()
| 94 | |
| 95 | |
def test_sgd_complex():
    """Check the Relax function generated by ``SGD(0.01, 0.02)`` for two parameters.

    An optimizer is built for a (3, 3) and a (3,) float32 parameter, and the
    function it produces is compared structurally with a hand-written
    TVMScript reference.
    """
    param_x = relax.Var("x", R.Tensor((3, 3), "float32"))
    param_y = relax.Var("y", R.Tensor((3,), "float32"))
    initialized = SGD(0.01, 0.02).init([param_x, param_y])
    actual = initialized.get_function()

    # NOTE: the body below is TVMScript; the binding order is part of the
    # expected IR, so the statements must stay in this exact sequence.
    @R.function
    def sgd_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(R.Tensor((), "int64")),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(R.Tensor((), "int64")),
    ):
        R.func_attr({"global_symbol": "SGD"})
        with R.dataflow():
            # The only optimizer state is a scalar int64 counter, bumped by one.
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))

            # First parameter: x_new = x - 0.01 * (0.02 * x + x_grad).
            # 0.01 and 0.02 are the two positional arguments given to SGD above
            # (presumably learning rate and weight decay -- TODO confirm).
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            lv: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.02, "float32"), x)
            x_grad_new: R.Tensor((3, 3), "float32") = R.add(lv, x_grad)
            lv1: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), x_grad_new)
            x_new: R.Tensor((3, 3), "float32") = R.subtract(x, lv1)

            # Second parameter: the same update applied to the (3,) tensor.
            y: R.Tensor((3,), "float32") = params[1]
            y_grad: R.Tensor((3,), "float32") = gradients[1]
            lv2: R.Tensor((3,), "float32") = R.multiply(R.const(0.02, "float32"), y)
            y_grad_new: R.Tensor((3,), "float32") = R.add(lv2, y_grad)
            lv3: R.Tensor((3,), "float32") = R.multiply(R.const(0.01, "float32"), y_grad_new)
            y_new: R.Tensor((3,), "float32") = R.subtract(y, lv3)

            # Pack the updated parameters and the new state for output.
            params_new: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")) = (
                x_new,
                y_new,
            )
            optim_states_new: R.Tuple(R.Tensor((), "int64")) = (num_steps_new,)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

    assert_structural_equal(actual, sgd_expected)
| 135 | |
| 136 | |
| 137 | def test_momentum_sgd_simple(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…