hub / github.com/apache/tvm / test_execute_numeric

Function test_execute_numeric

tests/python/relax/test_training_trainer_numeric.py:81–106
(target, dev)

Source from the content-addressed store, hash-verified

@tvm.testing.parametrize_targets("llvm")
def test_execute_numeric(target, dev):
    backbone = _get_backbone()
    pred_sinfo = relax.TensorStructInfo((1, 5), "float32")

    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"),
        SGD(0.01),
        [pred_sinfo, pred_sinfo],
    )

    train_mod = setup_trainer(backbone)
    ex = tvm.compile(train_mod, target)
    vm = relax.VirtualMachine(ex, dev)

    trainer = Trainer(train_mod, vm, dev, False)
    trainer.zero_init_params()

    dataset = _make_dataset()
    for _ in range(2):
        for input, label in dataset:
            loss = trainer.update(input, label)
    tvm.testing.assert_allclose(loss.numpy(), 3.1974423e-14)

    result = trainer.predict(dataset[0][0])
    result_expected = np.array([[0, 0, 0.9999998, 0, 0]], np.float32)
    tvm.testing.assert_allclose(result.numpy(), result_expected)
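The test above trains with a sum-reduced MSE loss and SGD at a learning rate of 0.01, starting from zero-initialized parameters. As a rough illustration of what one such update computes, here is a minimal NumPy sketch. It assumes a hypothetical single-layer linear model and a hypothetical input; it does not reproduce `_get_backbone` or `_make_dataset`, whose bodies are not shown on this page, and it does not use any TVM API.

```python
import numpy as np


def mse_sum(pred, label):
    """Sum-reduced squared error, analogous to MSELoss(reduction="sum")."""
    return float(np.sum((pred - label) ** 2))


def sgd_step(w, b, x, label, lr=0.01):
    """One plain SGD update for a hypothetical linear model pred = x @ w + b."""
    pred = x @ w + b
    loss = mse_sum(pred, label)
    grad_pred = 2.0 * (pred - label)   # d(loss)/d(pred) under sum reduction
    grad_w = x.T @ grad_pred           # chain rule through x @ w
    grad_b = grad_pred.sum(axis=0)     # bias gradient, summed over the batch
    return w - lr * grad_w, b - lr * grad_b, loss


# Hypothetical data: one input row and a one-hot label of shape (1, 5),
# matching only the (1, 5) float32 prediction shape declared in the test.
x = np.linspace(-1.0, 1.0, 10, dtype=np.float32).reshape(1, 10)
label = np.zeros((1, 5), np.float32)
label[0, 2] = 1.0

# Zero-initialized parameters, in the spirit of trainer.zero_init_params().
w = np.zeros((10, 5), np.float32)
b = np.zeros((5,), np.float32)

losses = []
for _ in range(50):
    w, b, loss = sgd_step(w, b, x, label)
    losses.append(loss)

print(losses[0], losses[-1])
```

With zero parameters the first prediction is all zeros, so the first recorded loss equals the squared norm of the label; later losses shrink as the single sample is fitted. The near-zero expected loss in the test reflects the same kind of convergence on its own backbone and dataset.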

Callers

nothing calls this directly

Calls 11

zero_init_params · Method · 0.95
update · Method · 0.95
predict · Method · 0.95
SetupTrainer · Class · 0.90
MSELoss · Class · 0.90
SGD · Class · 0.90
Trainer · Class · 0.90
_get_backbone · Function · 0.85
_make_dataset · Function · 0.85
numpy · Method · 0.80
compile · Method · 0.45

Tested by

no test coverage detected
