hub / github.com/apache/tvm / test_conv2d_offload

Function test_conv2d_offload

tests/python/relax/test_codegen_cudnn.py:179–210
(data_shape, weight_shape, dtype, with_bias, activation)

    ],
)
def test_conv2d_offload(data_shape, weight_shape, dtype, with_bias, activation):
    input = np.random.randn(*data_shape).astype(dtype)
    weight = np.random.randn(*weight_shape).astype(dtype)

    if with_bias:
        oc = weight_shape[0]
        bias = np.random.randn(1, 1, 1, oc).astype(dtype)
        args = (input, weight, bias)
    else:
        bias = None
        args = (input, weight)

    activation = _activation_table[activation]

    mod = get_relax_conv2d_module(
        data_shape,
        weight_shape,
        dtype,
        with_bias=with_bias,
        activation=activation,
    )

    out = get_result_with_relax_cudnn_offload(mod, args)
    ref = build_and_run(mod, args, "llvm", legalize=True)
    if dtype == "float16":
        # FIXME(lei): currently raise into 3e-1 to prevent flaky test
        # see https://github.com/apache/tvm/pull/18319
        tvm.testing.assert_allclose(out, ref, rtol=3e-1, atol=3e-1)
    else:
        # Increased tolerance to 2.5e-2 to prevent flaky test due to numerical
        # differences between cuDNN and LLVM implementations
        tvm.testing.assert_allclose(out, ref, rtol=2.5e-2, atol=2.5e-2)


@pytest.mark.skip(reason="flaky test")
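
The two tolerance settings in the final comparison are easier to judge with concrete numbers. The sketch below is not part of the TVM test. It assumes that tvm.testing.assert_allclose applies the same elementwise rule as NumPy's np.testing.assert_allclose, namely |actual - desired| <= atol + rtol * |desired|. The helper within_tolerance and the sample arrays are hypothetical and exist only for illustration.

```python
import numpy as np


def within_tolerance(actual, desired, rtol, atol):
    """Hypothetical helper: elementwise |actual - desired| <= atol + rtol * |desired|."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    return np.abs(actual - desired) <= atol + rtol * np.abs(desired)


# Made-up reference and offloaded results, chosen to show both outcomes.
ref = np.array([1.0, 10.0, 0.0])
out = np.array([1.2, 10.2, 0.25])

# float16 branch of the test: rtol=3e-1, atol=3e-1
loose = within_tolerance(out, ref, rtol=3e-1, atol=3e-1)

# other dtypes: rtol=2.5e-2, atol=2.5e-2
tight = within_tolerance(out, ref, rtol=2.5e-2, atol=2.5e-2)

# NumPy's own assertion accepts the loose setting for these values.
np.testing.assert_allclose(out, ref, rtol=3e-1, atol=3e-1)
```

With these sample values every element passes under the float16 tolerances, while only the middle element passes under the tighter ones. The allowed error grows with the magnitude of the reference value, so with rtol=3e-1 a result can differ from the reference by roughly 30% plus 0.3 and still pass, which is why the FIXME marks that setting as a stopgap.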

Callers

nothing calls this directly

Calls 4

get_relax_conv2d_module · Function · 0.70
build_and_run · Function · 0.70
astype · Method · 0.45

Tested by

no test coverage detected
