MCPcopy Index your code
hub / github.com/apache/tvm / test_define_subroutine

Function test_define_subroutine

tests/python/relax/test_testing_nn.py:62–119  ·  view source on GitHub ↗

Define subroutines when nn.Module.define_subroutine is True

()

Source from the content-addressed store, hash-verified

60
61
def test_define_subroutine():
    """Modules flagged with ``define_subroutine = True`` become private Relax functions.

    ``Layer`` multiplies its input by a ``weights`` parameter and hands the result
    to ``Activation`` (a ReLU).  Both classes set ``define_subroutine``, so the
    built IRModule must match ``Expected``: ``main`` calls the private ``layer``
    function, which in turn calls the private ``activation`` function.
    """

    class Activation(nn.Module):
        # Expected to surface as the private ``activation`` function in ``Expected``.
        define_subroutine = True

        def forward(self, state: relax.Expr) -> relax.Var:
            return relax.op.nn.relu(state)

    class Layer(nn.Module):
        # Expected to surface as the private ``layer`` function, which calls ``activation``.
        define_subroutine = True

        def __init__(self, in_features, out_features):
            self.weights = nn.Parameter(
                (in_features, out_features), dtype="float32", name="weights"
            )
            self.activation = Activation()

        def forward(self, input: relax.Expr) -> relax.Var:
            projected = relax.op.matmul(input, self.weights)
            return self.activation(projected)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            state: R.Tensor(("batch_size", 64), dtype="float32"),
            weights: R.Tensor((64, 32), dtype="float32"),
        ) -> R.Tensor(("batch_size", 32), dtype="float32"):
            state = Expected.layer(state, weights)
            return state

        @R.function(private=True)
        def layer(
            state: R.Tensor(("batch_size", 64), dtype="float32"),
            weights: R.Tensor((64, 32), dtype="float32"),
        ) -> R.Tensor(("batch_size", 32), dtype="float32"):
            state = R.matmul(state, weights)
            state = Expected.activation(state)
            return state

        @R.function(private=True)
        def activation(
            state: R.Tensor(("batch_size", 32), dtype="float32"),
        ) -> R.Tensor(("batch_size", 32), dtype="float32"):
            state = R.nn.relu(state)
            return state

    # Symbolic batch dimension; the Placeholder keeps the runtime name "input".
    batch = tvm.tirx.Var("batch_size", "int64")
    data = nn.Placeholder((batch, 64), dtype="float32", name="input")
    module = Layer(64, 32)

    builder = relax.BlockBuilder()
    # "main" takes the input first, followed by every parameter owned by the module tree.
    func_params = [data, *module.parameters()]
    with builder.function("main", params=func_params):
        builder.emit_func_output(module(data))

    tvm.ir.assert_structural_equal(Expected, builder.get())

Callers

nothing calls this directly

Calls 5

function · Method · 0.95
emit_func_output · Method · 0.95
get · Method · 0.95
Layer · Class · 0.70
parameters · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…