()
| 29 | |
| 30 | |
| 31 | def test_simple(): |
| 32 | # fmt: off |
| 33 | @I.ir_module |
| 34 | class Backbone: |
| 35 | I.module_attrs({"param_num": 1, "state_num": 0}) |
| 36 | @R.function |
| 37 | def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64")): |
| 38 | with R.dataflow(): |
| 39 | x1 = x + y |
| 40 | R.output(x1) |
| 41 | return x1 |
| 42 | |
| 43 | @I.ir_module |
| 44 | class Expected: |
| 45 | I.module_attrs({"input_num": 1, "param_num": 1, "state_num": 0}) |
| 46 | @R.function |
| 47 | def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((2, 2), dtype="float64"): |
| 48 | with R.dataflow(): |
| 49 | x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y) |
| 50 | R.output(x1) |
| 51 | return x1 |
| 52 | |
| 53 | @R.function |
| 54 | def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((), dtype="float64"): |
| 55 | with R.dataflow(): |
| 56 | x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y) |
| 57 | lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets) |
| 58 | lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv) |
| 59 | gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False) |
| 60 | R.output(gv) |
| 61 | return gv |
| 62 | |
| 63 | @R.function |
| 64 | def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tuple(R.Tensor((2, 2), dtype="float64"))): |
| 65 | with R.dataflow(): |
| 66 | x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y) |
| 67 | lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets) |
| 68 | lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv) |
| 69 | gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False) |
| 70 | gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64") |
| 71 | lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2])) |
| 72 | lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) |
| 73 | lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) |
| 74 | lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1) |
| 75 | x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1 |
| 76 | y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint |
| 77 | y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint |
| 78 | R.output(gv, y_adjoint_out) |
| 79 | return (gv, (y_adjoint_out,)) |
| 80 | |
| 81 | @R.function |
| 82 | def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"))): |
| 83 | with R.dataflow(): |
| 84 | num_steps: R.Tensor((), dtype="int64") = optim_states[0] |
| 85 | num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64")) |
| 86 | y: R.Tensor((2, 2), dtype="float64") = params[0] |
| 87 | y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0] |
| 88 | lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_grad) |
nothing calls this directly
no test coverage detected
searching dependent graphs…