hub / github.com/apache/tvm / test_load_export_params

Function test_load_export_params

tests/python/relax/test_training_trainer_numeric.py:110–141
(target, dev)

Source

@tvm.testing.parametrize_targets("llvm")
def test_load_export_params(target, dev):
    backbone = _get_backbone()
    pred_sinfo = relax.TensorStructInfo((1, 5), "float32")

    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"),
        SGD(0.01),
        [pred_sinfo, pred_sinfo],
    )

    train_mod = setup_trainer(backbone)
    ex = tvm.compile(train_mod, target)
    vm = relax.VirtualMachine(ex, dev)

    trainer = Trainer(train_mod, vm, dev, False)
    trainer.xaiver_uniform_init_params()

    dataset = _make_dataset()
    for input, label in dataset:
        trainer.update(input, label)

    param_dict = trainer.export_params()
    assert "w0" in param_dict
    assert "b0" in param_dict

    trainer1 = Trainer(train_mod, vm, dev, False)
    trainer1.load_params(param_dict)

    x_sample = dataset[np.random.randint(len(dataset))][0]
    tvm.testing.assert_allclose(
        trainer.predict(x_sample).numpy(), trainer1.predict(x_sample).numpy()
    )
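The property this test checks is a parameter round trip: train one trainer, export its parameters, load them into a fresh trainer, and confirm both produce the same prediction. The following is a minimal NumPy-only sketch of that pattern, not the TVM implementation. `TinyTrainer`, its linear model, the input width of 3, and the `round_trip_matches` helper are all hypothetical stand-ins chosen for illustration; only the `"w0"`/`"b0"` key names, the `(1, 5)` prediction shape, the sum-reduced squared error, and the `0.01` learning rate are taken from the test above.

```python
import numpy as np


class TinyTrainer:
    """Hypothetical stand-in for a trainer; NOT the TVM Relax Trainer API."""

    def __init__(self, in_dim=3, out_dim=5, lr=0.01, seed=0):
        rng = np.random.default_rng(seed)
        # Key names mirror the "w0"/"b0" entries asserted in the test.
        self.params = {
            "w0": rng.uniform(-0.5, 0.5, size=(in_dim, out_dim)).astype("float32"),
            "b0": np.zeros(out_dim, dtype="float32"),
        }
        self.lr = lr

    def predict(self, x):
        # Single linear layer (an assumption; the real backbone is not shown).
        return x @ self.params["w0"] + self.params["b0"]

    def update(self, x, y):
        # One SGD step on the sum-reduced squared error.
        grad_out = 2.0 * (self.predict(x) - y)
        self.params["w0"] -= self.lr * (x.T @ grad_out)
        self.params["b0"] -= self.lr * grad_out.sum(axis=0)

    def export_params(self):
        # Copy so later updates do not mutate the exported snapshot.
        return {name: value.copy() for name, value in self.params.items()}

    def load_params(self, param_dict):
        # Copy on load so the two trainers never share buffers.
        self.params = {name: np.array(value, copy=True) for name, value in param_dict.items()}


def round_trip_matches(seed=0):
    """Train, export, load into a fresh trainer, and compare predictions."""
    rng = np.random.default_rng(seed)
    dataset = [
        (
            rng.standard_normal((1, 3)).astype("float32"),
            rng.standard_normal((1, 5)).astype("float32"),
        )
        for _ in range(4)
    ]

    trainer = TinyTrainer(seed=1)
    for x, y in dataset:
        trainer.update(x, y)

    param_dict = trainer.export_params()

    # The second trainer starts from different weights, so agreement
    # can only come from the loaded parameters.
    trainer1 = TinyTrainer(seed=2)
    trainer1.load_params(param_dict)

    x_sample = dataset[0][0]
    np.testing.assert_allclose(trainer.predict(x_sample), trainer1.predict(x_sample))
    return sorted(param_dict)
```

Calling `round_trip_matches()` raises an `AssertionError` if the loaded parameters do not reproduce the original predictions, and otherwise returns the sorted parameter names. Seeding the second trainer differently is a deliberate choice in this sketch: it keeps the comparison from passing by coincidence of identical initialization.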

Callers

nothing calls this directly

Calls 13

update · Method · 0.95
export_params · Method · 0.95
load_params · Method · 0.95
predict · Method · 0.95
SetupTrainer · Class · 0.90
MSELoss · Class · 0.90
SGD · Class · 0.90
Trainer · Class · 0.90
_get_backbone · Function · 0.85
_make_dataset · Function · 0.85
numpy · Method · 0.80

Tested by

no test coverage detected
