(target, dev)
| 108 | |
@tvm.testing.parametrize_targets("llvm")
def test_load_export_params(target, dev):
    """Check that exported Trainer parameters round-trip through load_params.

    A Trainer is initialized and trained for one pass over the dataset, and its
    parameters are exported. Those parameters are then loaded into a fresh
    Trainer built on the same module and VM. Both trainers must produce the
    same predictions.
    """
    backbone = _get_backbone()
    pred_sinfo = relax.TensorStructInfo((1, 5), "float32")

    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"),
        SGD(0.01),
        [pred_sinfo, pred_sinfo],
    )

    train_mod = setup_trainer(backbone)
    ex = tvm.compile(train_mod, target)
    vm = relax.VirtualMachine(ex, dev)

    trainer = Trainer(train_mod, vm, dev, False)
    trainer.xaiver_uniform_init_params()

    dataset = _make_dataset()
    for data, label in dataset:
        trainer.update(data, label)

    param_dict = trainer.export_params()
    assert "w0" in param_dict
    assert "b0" in param_dict

    # A fresh Trainer whose parameters come only from the exported dict.
    loaded_trainer = Trainer(train_mod, vm, dev, False)
    loaded_trainer.load_params(param_dict)

    # Compare on every sample rather than one unseeded random pick, so the
    # check is deterministic and any failure is reproducible. This adds one
    # extra pass over the dataset, the same size as the training loop above.
    for data, _ in dataset:
        tvm.testing.assert_allclose(
            trainer.predict(data).numpy(), loaded_trainer.predict(data).numpy()
        )
| 142 | |
| 143 | |
| 144 | @tvm.testing.parametrize_targets("llvm") |
nothing calls this directly
no test coverage detected
searching dependent graphs…