Repository: github.com/apache/tvm

Function test_attention_rewrite_offload

tests/python/relax/test_codegen_cutlass.py:1017–1042
Signature: test_attention_rewrite_offload(attention_rewrite_size)
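The attention_rewrite_size tuple is unpacked as b, (s, s_kv), n, (h, h_v), bias_shape, scale, and every tensor shape is derived from it: 4-D [batch, seq, heads, head_dim] when n is a head count, and 3-D with the head axis dropped when n is "none". The sketch below restates that rule as one standalone function. attention_shapes is a hypothetical helper, not part of the TVM test file, and the sizes in the usage example are illustrative only.

```python
from typing import List, Tuple, Union

Dim = Union[int, str]


def attention_shapes(
    b: int, s: int, s_kv: int, n: Dim, h: int, h_v: int
) -> Tuple[List[Dim], ...]:
    """Hypothetical helper: return (q, k, v, out, bias) shapes.

    Mirrors the shape logic of test_attention_rewrite_offload:
    n == "none" drops the head axis and yields 3-D shapes.
    """
    if n != "none":
        return (
            [b, s, n, h],       # q:    [batch, seq, heads, head_dim]
            [b, s_kv, n, h],    # k:    [batch, kv_seq, heads, head_dim]
            [b, s_kv, n, h_v],  # v:    [batch, kv_seq, heads, v_head_dim]
            [b, s, n, h_v],     # out:  [batch, seq, heads, v_head_dim]
            [b, n, s, s_kv],    # bias: [batch, heads, seq, kv_seq]
        )
    return (
        [b, s, h],
        [b, s_kv, h],
        [b, s_kv, h_v],
        [b, s, h_v],
        [b, s, s_kv],
    )


# Illustrative sizes: batch 4, seq 16, kv_seq 8, 32 heads, head dims 8 and 16.
q_shape, k_shape, v_shape, out_shape, bias_shape = attention_shapes(4, 16, 8, 32, 8, 16)
```

Keeping all five shapes in one function makes the 3-D/4-D branch explicit once, instead of repeating the n != "none" check for each tensor as the test does.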


```python
def test_attention_rewrite_offload(attention_rewrite_size):
    b, (s, s_kv), n, (h, h_v), bias_shape, scale = attention_rewrite_size
    # 4-D [batch, seq, heads, head_dim] layout, or 3-D [batch, seq, head_dim]
    # when the head count n is "none".
    q_shape = [b, s, n, h] if n != "none" else [b, s, h]
    k_shape = [b, s_kv, n, h] if n != "none" else [b, s_kv, h]
    v_shape = [b, s_kv, n, h_v] if n != "none" else [b, s_kv, h_v]
    out_shape = [b, s, n, h_v] if n != "none" else [b, s, h_v]
    # Note: this overrides the bias_shape unpacked from attention_rewrite_size.
    bias_shape = [b, n, s, s_kv] if n != "none" else [b, s, s_kv]
    q, k, v, bias = get_numpy_attention_input(q_shape, k_shape, v_shape, bias_shape, "float32")
    # Build the original module and the expected (rewritten) module.
    original_mod, expected_mod = get_relax_attention_rewrite_module(
        q_shape, k_shape, v_shape, out_shape, "float32", bias_shape, scale
    )
    # After CUTLASS partitioning, both modules must be structurally identical
    # (the third argument lets free variables be mapped onto each other).
    original_mod = partition_for_cutlass(original_mod, True)
    expected_mod = partition_for_cutlass(expected_mod, True)
    tvm.ir.assert_structural_equal(original_mod, expected_mod, True)

    # Generate CUTLASS kernels for SM80 in both modules.
    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
    original_mod = codegen_pass(original_mod)
    expected_mod = codegen_pass(expected_mod)
    # Run both modules on CUDA and check that their outputs agree.
    if bias is None:
        original_out = build_and_run(original_mod, [q, k, v], "cuda")
        expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
    else:
        original_out = build_and_run(original_mod, [q, k, v, bias], "cuda", legalize=False)
        expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda", legalize=False)
        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
```
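The test only checks that the original and expected modules agree with each other after offloading. A NumPy reference can make the intended math explicit. The sketch below assumes the rewritten pattern is standard scaled dot-product attention, softmax(q @ k^T * scale + bias) @ v, computed per head with q, k, v in [batch, seq, heads, head_dim] layout (or 3-D without the head axis) and a bias of shape [batch, heads, s, s_kv] (or [batch, s, s_kv]). reference_attention is a hypothetical helper, and its 1/sqrt(head_dim) default scale is also an assumption.

```python
import numpy as np


def reference_attention(q, k, v, bias=None, scale=None):
    """Hypothetical reference (not part of the TVM test suite).

    Computes softmax(q @ k^T * scale + bias) @ v per head for
    BSNH (4-D) or BSH (3-D) inputs.
    """
    squeeze = q.ndim == 3
    if squeeze:  # [b, s, h] -> [b, s, 1, h]: add a singleton head axis
        q, k, v = q[:, :, None, :], k[:, :, None, :], v[:, :, None, :]
        if bias is not None:
            bias = bias[:, None, :, :]  # [b, s, s_kv] -> [b, 1, s, s_kv]
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])  # assumed default
    qt = q.transpose(0, 2, 1, 3)  # [b, n, s, h]
    kt = k.transpose(0, 2, 3, 1)  # [b, n, h, s_kv]
    vt = v.transpose(0, 2, 1, 3)  # [b, n, s_kv, h_v]
    scores = (qt @ kt) * scale  # [b, n, s, s_kv]
    if bias is not None:
        scores = scores + bias
    # Subtract the row max before exp for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = (weights @ vt).transpose(0, 2, 1, 3)  # [b, s, n, h_v]
    return out[:, :, 0, :] if squeeze else out
```

As an extra check that the upstream test does not perform, one could compare reference_attention(q, k, v, bias, scale) against original_out with tvm.testing.assert_allclose. Subtracting the row maximum before exp is the usual softmax stabilization and does not change the result.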

Callers

nothing calls this directly

Calls 4

partition_for_cutlass · Function · 0.90
build_and_run · Function · 0.70

Tested by

no test coverage detected
