MCP · copy · Index your code
hub / github.com/apache/tvm / get_relax_attention_rewrite_module

Function get_relax_attention_rewrite_module

tests/python/relax/test_codegen_cutlass.py:888–1003  ·  view source on GitHub ↗
(
    q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
)

Source from the content-addressed store, hash-verified

886
887
888def get_relax_attention_rewrite_module(
889 q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
890):
891 from tvm.script.ir_builder import IRBuilder
892 from tvm.script.ir_builder import relax as relax_builder
893 from tvm.script.ir_builder import tirx as T
894
895 with IRBuilder() as builder:
896 with relax_builder.function():
897 R.func_name("main")
898 q = R.arg("q", R.Tensor(q_shape, dtype))
899 k = R.arg("k", R.Tensor(k_shape, dtype))
900 v = R.arg("v", R.Tensor(v_shape, dtype))
901 if bias_shape is not None:
902 bias = R.arg("bias", R.Tensor(bias_shape, dtype))
903 with R.dataflow() as frame:
904 if len(q_shape) == 4:
905 q = R.emit(R.permute_dims(q, axes=[0, 2, 1, 3]))
906 q = R.emit(R.reshape(q, [q_shape[0] * q_shape[2], q_shape[1], q_shape[3]]))
907
908 if len(k_shape) == 4:
909 k = R.emit(R.permute_dims(k, axes=[0, 2, 1, 3]))
910 k = R.emit(R.reshape(k, [k_shape[0] * k_shape[2], k_shape[1], k_shape[3]]))
911
912 if len(v_shape) == 4:
913 v = R.emit(R.permute_dims(v, axes=[0, 2, 1, 3]))
914 v = R.emit(R.reshape(v, [v_shape[0] * v_shape[2], v_shape[1], v_shape[3]]))
915
916 k = R.emit(R.permute_dims(k, axes=[0, 2, 1]))
917 qk = R.emit(R.matmul(q, k))
918 qk_scaled = R.emit(R.multiply(qk, R.const(scale, "float32")))
919 if bias_shape is not None:
920 if len(bias_shape) == 4:
921 bias = R.emit(
922 R.reshape(bias, [bias_shape[0] * bias_shape[1], *bias_shape[2:]])
923 )
924 qk_added = R.emit(R.add(qk_scaled, bias))
925 softmax = R.emit(R.nn.softmax(qk_added, axis=-1))
926 else:
927 softmax = R.emit(R.nn.softmax(qk_scaled, axis=-1))
928 out = R.emit(R.matmul(softmax, v))
929
930 if len(out_shape) == 4:
931 out = R.emit(
932 R.reshape(
933 out,
934 [out_shape[0], out_shape[2], out_shape[1], out_shape[3]],
935 )
936 )
937 out = R.emit(R.permute_dims(out, axes=[0, 2, 1, 3]))
938 R.output(out)
939
940 R.func_ret_value(frame.output_vars[0])
941
942 original_func = builder.get()
943
944 if scale is not None:
945 scale = T.FloatImm("float32", scale)

Callers 1

Calls 13

IRBuilder · Class · 0.90
Tensor · Method · 0.80
dataflow · Method · 0.80
output · Method · 0.80
function · Method · 0.45
emit · Method · 0.45
permute_dims · Method · 0.45
reshape · Method · 0.45
matmul · Method · 0.45
multiply · Method · 0.45
add · Method · 0.45
softmax · Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…