()
| 6587 | |
| 6588 | |
def test_keep_params():
    """Check that ``keep_params_as_input=True`` exposes module weights as inputs.

    A single ``Conv2d`` layer (with bias) is exported with ``torch.export`` and
    imported via ``from_exported_program(..., keep_params_as_input=True)``.

    After ``detach_params`` the test checks three things:

    * the IR is structurally equal to ``expected1``, where the conv weight and
      bias are extra parameters of ``main`` after the data input;
    * every detached tensor matches the shape and dtype of its Relax parameter;
    * the detached values equal the PyTorch module's weight and bias.
    """

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
            conv_bias: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
            # Only `input_1` is a data input; the remaining parameters are the
            # lifted conv weight and bias (the checks below rely on this split).
            R.func_attr({"num_input": 1})
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    conv_weight,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # Bias is reshaped to (1, C, 1, 1) so the add broadcasts over NCHW.
                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    from tvm.relax.frontend import detach_params

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    model = Conv2D1()

    exported_program = torch.export.export(model, example_args)
    mod = from_exported_program(exported_program, keep_params_as_input=True)
    # detach_params returns the module together with its parameter values,
    # grouped by function name (indexed with "main" below).
    mod, params = detach_params(mod)
    tvm.ir.assert_structural_equal(mod, expected1)
    func = mod["main"]
    params = params["main"]

    # Every parameter of `main` except the data input must have a detached value
    # whose shape and dtype agree with the Relax signature.
    assert len(params) == len(func.params) - 1
    for param_var, param_tensor in zip(func.params[1:], params):
        assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape
        assert param_var.struct_info.dtype == param_tensor.dtype

    # Detached values follow the signature order: conv weight first, then bias.
    # A single detach() suffices to drop autograd tracking before .numpy().
    tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
    tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
| 6645 | |
| 6646 |
nothing calls this directly
no test coverage detected
searching dependent graphs…