MCPcopy Index your code
hub / github.com/apache/tvm / test_conv2d_cuda_graph

Function test_conv2d_cuda_graph

tests/python/relax/test_codegen_cutlass.py:1741–1796  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1739
1740
def test_conv2d_cuda_graph():
    """Offload a conv2d/layer_norm chain to CUTLASS and execute it with CUDA graph enabled.

    The module runs three NHWC/OHWI conv2d + ReLU layers, with a layer_norm over
    the last axis between the second and third. It is partitioned for CUTLASS and
    code-generated for sm 80 (``find_first_valid``). It is then lowered with the
    default Relax pipeline and given the default GPU schedule under a CUDA target.
    Finally it is run with ``cuda_graph=True``. The output is checked against the
    same module built for LLVM with ``legalize=True``.
    """

    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(
            data: R.Tensor((16, 32, 32, 16), "float16"),
            weight1: R.Tensor((16, 3, 3, 16), "float16"),
            weight2: R.Tensor((16, 3, 3, 16), "float16"),
            weight3: R.Tensor((16, 3, 3, 16), "float16"),
            gamma: R.Tensor((16,), "float16"),
            beta: R.Tensor((16,), "float16"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                conv1 = R.nn.relu(
                    R.nn.conv2d(
                        data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                conv2 = R.nn.relu(
                    R.nn.conv2d(
                        conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
                conv3 = R.nn.relu(
                    R.nn.conv2d(
                        ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                R.output(conv3)

            return conv3

    dtype = "float16"
    data_shape = (16, 32, 32, 16)
    weight_shape = (16, 3, 3, 16)
    channel_shape = (weight_shape[0],)
    low, high = -1, 1

    # NOTE(review): np.random.randint excludes `high`, so every tensor below is
    # drawn from {-1, 0}. With those signs conv1 >= 0, which means conv2
    # (weights <= 0) appears to be zeroed entirely by its ReLU. layer_norm then
    # reduces to `beta`. The compared output therefore seems to depend only on
    # beta, weight3 and padding. Consider a wider symmetric range once the fp16
    # tolerances are confirmed — TODO verify on GPU.
    def _random_tensor(shape):
        """Sample an integer-valued ``dtype`` array of ``shape`` from [low, high)."""
        return np.random.randint(low, high, size=shape).astype(dtype)

    # Sampling order must follow `main`'s parameters (data, weight1..3, gamma,
    # beta): it fixes which random draws land in which argument.
    argument_shapes = (
        data_shape,
        weight_shape,
        weight_shape,
        weight_shape,
        channel_shape,
        channel_shape,
    )
    inputs = [_random_tensor(shape) for shape in argument_shapes]

    partitioned = partition_for_cutlass(Conv2d)
    codegen_mod = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": True}}
    )(partitioned)
    lowered = relax.pipeline.get_pipeline()(codegen_mod)  # pylint: disable=no-value-for-parameter

    # DefaultGPUSchedule is applied inside a CUDA target scope, as in the original.
    with tvm.target.Target("cuda"):
        scheduled = tvm.s_tir.transform.DefaultGPUSchedule()(lowered)

    cuda_graph_out = build_and_run(scheduled, inputs, "cuda", cuda_graph=True)
    llvm_ref = build_and_run(Conv2d, inputs, "llvm", legalize=True)
    tvm.testing.assert_allclose(cuda_graph_out, llvm_ref, rtol=1e-2, atol=1e-2)
1797
1798

Callers

nothing calls this directly

Calls 3

partition_for_cutlass — Function · 0.90
build_and_run — Function · 0.70
astype — Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…