()
| 56 | |
| 57 | |
def test_sgd_simple():
    """Check the Relax function generated by ``SGD(0.01)`` for two parameters.

    Builds the optimizer function for a (3, 3) and a (3,) float32 parameter
    and asserts that it is structurally equal to the hand-written
    ``sgd_expected`` TVMScript function below.
    """
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))
    sgd = SGD(0.01).init([x, y]).get_function()

    # NOTE: the binding order inside the dataflow block is part of what is
    # being compared, so keep the statements below in this exact order.
    @R.function
    def sgd_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(R.Tensor((), "int64")),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(R.Tensor((), "int64")),
    ):
        R.func_attr({"global_symbol": "SGD"})
        # block 0
        with R.dataflow():
            # The only optimizer state is a scalar int64 counter, bumped by 1.
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))
            # Per-parameter update: new = param - 0.01 * grad
            # (0.01 is the value passed to SGD(...) above).
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            lv: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), x_grad)
            x_new: R.Tensor((3, 3), "float32") = R.subtract(x, lv)
            y: R.Tensor((3,), "float32") = params[1]
            y_grad: R.Tensor((3,), "float32") = gradients[1]
            lv1: R.Tensor((3,), "float32") = R.multiply(R.const(0.01, "float32"), y_grad)
            y_new: R.Tensor((3,), "float32") = R.subtract(y, lv1)
            params_new: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")) = (
                x_new,
                y_new,
            )
            optim_states_new: R.Tuple(R.Tensor((), "int64")) = (num_steps_new,)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

    assert_structural_equal(sgd, sgd_expected)
| 94 | |
| 95 | |
| 96 | def test_sgd_complex(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…