()
| 1739 | |
| 1740 | |
def test_conv2d_cuda_graph():
    """Build a CUTLASS-partitioned conv2d/layer_norm network and run it with CUDA graphs.

    The Relax module below chains three ``relu(conv2d(...))`` stages (NHWC data,
    OHWI kernels, padding 1), with a ``layer_norm`` over the last axis between the
    second and third stage. The module is:

    1. partitioned with ``partition_for_cutlass``;
    2. run through ``RunCodegen`` for the ``"cutlass"`` target;
    3. lowered with the default Relax pipeline;
    4. scheduled with ``DefaultGPUSchedule`` under a CUDA target.

    It is then built and executed on CUDA with ``cuda_graph=True``. The result is
    compared, with rtol/atol of 1e-2, against the original module built for LLVM
    with ``legalize=True``.
    """

    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(
            data: R.Tensor((16, 32, 32, 16), "float16"),
            weight1: R.Tensor((16, 3, 3, 16), "float16"),
            weight2: R.Tensor((16, 3, 3, 16), "float16"),
            weight3: R.Tensor((16, 3, 3, 16), "float16"),
            gamma: R.Tensor((16,), "float16"),
            beta: R.Tensor((16,), "float16"),
        ):
            # NOTE(review): presumably marks only `data` as a runtime input and the
            # remaining parameters as weights -- confirm against Relax docs.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                conv1 = R.nn.relu(
                    R.nn.conv2d(
                        data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                conv2 = R.nn.relu(
                    R.nn.conv2d(
                        conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
                conv3 = R.nn.relu(
                    R.nn.conv2d(
                        ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                R.output(conv3)

            return conv3

    # np.random.randint excludes `high`, so every generated element is -1 or 0.
    low, high = -1, 1
    dtype = "float16"
    data_shape = (16, 32, 32, 16)
    weight_shape = (16, 3, 3, 16)
    # gamma/beta are sized by the conv output-channel count (O in OHWI), which is
    # the last axis of the NHWC activation normalized by layer_norm(axes=[-1]).
    channel_shape = (weight_shape[0],)

    def _random_tensor(shape):
        """Return a `dtype` array of the given shape drawn from randint(low, high)."""
        return np.random.randint(low, high, size=shape).astype(dtype)

    # Generated left to right in the order of main's parameters:
    # data, weight1, weight2, weight3, gamma, beta.
    inputs = [
        _random_tensor(data_shape),
        _random_tensor(weight_shape),
        _random_tensor(weight_shape),
        _random_tensor(weight_shape),
        _random_tensor(channel_shape),
        _random_tensor(channel_shape),
    ]

    partitioned = partition_for_cutlass(Conv2d)
    # "sm": 80 presumably selects the SM 8.0 CUTLASS kernels -- confirm.
    cutlass_codegen = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": True}}
    )
    generated = cutlass_codegen(partitioned)
    lowered = relax.pipeline.get_pipeline()(generated)  # pylint: disable=no-value-for-parameter

    with tvm.target.Target("cuda"):
        scheduled = tvm.s_tir.transform.DefaultGPUSchedule()(lowered)

    actual = build_and_run(scheduled, inputs, "cuda", cuda_graph=True)
    expected = build_and_run(Conv2d, inputs, "llvm", legalize=True)
    tvm.testing.assert_allclose(actual, expected, rtol=1e-2, atol=1e-2)
| 1797 | |
| 1798 |
nothing calls this directly
no test coverage detected
searching dependent graphs…