MCPcopy Index your code
hub / github.com/apache/tvm / test_keep_params

Function test_keep_params

tests/python/relax/test_frontend_from_exported_program.py:6589–6644  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

6587
6588
def test_keep_params():
    """Check that `keep_params_as_input=True` lifts module weights to inputs.

    A single-layer ``Conv2d`` model is exported with ``torch.export`` and
    translated with ``from_exported_program(..., keep_params_as_input=True)``.
    After ``detach_params`` the resulting module must be structurally equal
    to ``expected1`` (weight and bias appear as extra function parameters),
    and the detached parameters must match the PyTorch weight and bias in
    shape, dtype, and value.
    """

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        # NOTE(review): the parameter name `input` presumably determines the
        # `input_1` name in the expected module below -- confirm before
        # renaming it to avoid shadowing the builtin.
        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
            conv_bias: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
            R.func_attr({"num_input": 1})
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    conv_weight,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    from tvm.relax.frontend import detach_params

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    model = Conv2D1()

    exported_program = torch.export.export(model, example_args)
    mod = from_exported_program(exported_program, keep_params_as_input=True)
    mod, params = detach_params(mod)
    tvm.ir.assert_structural_equal(mod, expected1)
    func = mod["main"]
    params = params["main"]

    # The first function parameter is the real input; every remaining one
    # must be paired with exactly one detached parameter tensor.
    assert len(params) == len(func.params) - 1
    for param_var, param_tensor in zip(func.params[1:], params):
        assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape
        assert param_var.struct_info.dtype == param_tensor.dtype

    # A single detach() is enough to drop autograd tracking before numpy().
    tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
    tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
6645
6646

Callers

nothing calls this directly

Calls 5

from_exported_program — Function · 0.90
detach_params — Function · 0.90
tuple — Function · 0.85
numpy — Method · 0.80
Conv2D1 — Class · 0.70

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…