MCP · copy · Index your code
hub / github.com/apache/tvm / test_invalid_mod

Function test_invalid_mod

tests/python/relax/test_training_setup_trainer.py:183–229  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

181
182
def test_invalid_mod():
    """Check that SetupTrainer rejects backbone modules that break its input contract.

    Two malformed modules are applied to the same trainer instance:

    * ``NoAttr`` defines a ``backbone`` function but declares no attributes;
      applying the trainer is expected to raise ``RuntimeError`` or ``ValueError``.
    * ``WrongFuncName`` defines its function as ``main`` instead of ``backbone``;
      applying the trainer is expected to raise ``ValueError``.
    """

    @I.ir_module
    class NoAttr:
        @R.function
        def backbone(
            w0: R.Tensor((10, 5), "float32"),
            b0: R.Tensor((5,), "float32"),
            x: R.Tensor((1, 10), "float32"),
        ):
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                gv = R.add(lv0, b0)
                out = R.nn.relu(gv)
                R.output(gv, out)
            return gv, out

    pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
    # A single trainer is shared by both negative cases below instead of
    # constructing a second, identically configured SetupTrainer.
    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"),
        SGD(0.001),
        [pred_sinfo, pred_sinfo],
    )

    # RuntimeError is accepted alongside ValueError — presumably because the
    # missing-attribute path may surface as a TVM runtime error; TODO confirm.
    with pytest.raises((RuntimeError, ValueError)):
        setup_trainer(NoAttr)

    @I.ir_module
    class WrongFuncName:
        @R.function
        def main(
            w0: R.Tensor((10, 5), "float32"),
            b0: R.Tensor((5,), "float32"),
            x: R.Tensor((1, 10), "float32"),
        ):
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                lv1 = R.add(lv0, b0)
                out = R.nn.relu(lv1)
                R.output(out)
            return out

    # NOTE(review): WrongFuncName also declares no attributes. If SetupTrainer
    # validates attributes before the function name, this ValueError may not
    # come from the name check. Consider adding the required attributes so this
    # case isolates the "no `backbone` function" path.
    with pytest.raises(ValueError):
        setup_trainer(WrongFuncName)
230
231
232if __name__ == "__main__":

Callers

nothing calls this directly

Calls 3

SetupTrainer (class) · 0.90
MSELoss (class) · 0.90
SGD (class) · 0.90

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…