Define subroutines when nn.Module.define_subroutine is True
()
| 60 | |
| 61 | |
def test_define_subroutine():
    """Define subroutines when nn.Module.define_subroutine is True.

    Builds a two-level module -- ``Layer`` (a matmul against its ``weights``
    parameter) that calls ``Activation`` (a relu) -- with
    ``define_subroutine = True`` on both classes, traces it into a
    ``relax.BlockBuilder`` function named ``main``, and checks the result is
    structurally equal to ``Expected``: ``main`` calls a private ``layer``
    function, which in turn calls a private ``activation`` function, rather
    than having the matmul and relu inlined into ``main``.
    """

    class Activation(nn.Module):
        define_subroutine = True

        def forward(self, state: relax.Expr) -> relax.Var:
            return relax.op.nn.relu(state)

    class Layer(nn.Module):
        define_subroutine = True

        def __init__(self, in_features, out_features):
            # NOTE(review): attribute order presumably determines the order of
            # model.parameters(), i.e. main's trailing params -- keep weights first.
            self.weights = nn.Parameter(
                (in_features, out_features), dtype="float32", name="weights"
            )
            self.activation = Activation()

        def forward(self, state: relax.Expr) -> relax.Var:
            # Parameter named ``state`` (not ``input``) to avoid shadowing the
            # builtin and to mirror the ``layer`` function in ``Expected``.
            state = relax.op.matmul(state, self.weights)
            return self.activation(state)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            state: R.Tensor(("batch_size", 64), dtype="float32"),
            weights: R.Tensor((64, 32), dtype="float32"),
        ) -> R.Tensor(("batch_size", 32), dtype="float32"):
            state = Expected.layer(state, weights)
            return state

        @R.function(private=True)
        def layer(
            state: R.Tensor(("batch_size", 64), dtype="float32"),
            weights: R.Tensor((64, 32), dtype="float32"),
        ) -> R.Tensor(("batch_size", 32), dtype="float32"):
            state = R.matmul(state, weights)
            state = Expected.activation(state)
            return state

        @R.function(private=True)
        def activation(state: R.Tensor(("batch_size", 32), dtype="float32")) -> R.Tensor(
            ("batch_size", 32), dtype="float32"
        ):
            state = R.nn.relu(state)
            return state

    model = Layer(64, 32)
    batch_size = tvm.tirx.Var("batch_size", "int64")
    # Local named ``input_var`` to avoid shadowing the builtin ``input``; the
    # Relax-level name hint is still "input".
    input_var = nn.Placeholder((batch_size, 64), dtype="float32", name="input")

    bb = relax.BlockBuilder()
    with bb.function("main", params=[input_var, *model.parameters()]):
        output = model(input_var)
        bb.emit_func_output(output)

    tvm.ir.assert_structural_equal(Expected, bb.get())
nothing calls this directly
no test coverage detected
searching dependent graphs…