(
q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
)
| 886 | |
| 887 | |
| 888 | def get_relax_attention_rewrite_module( |
| 889 | q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None |
| 890 | ): |
| 891 | from tvm.script.ir_builder import IRBuilder |
| 892 | from tvm.script.ir_builder import relax as relax_builder |
| 893 | from tvm.script.ir_builder import tirx as T |
| 894 | |
| 895 | with IRBuilder() as builder: |
| 896 | with relax_builder.function(): |
| 897 | R.func_name("main") |
| 898 | q = R.arg("q", R.Tensor(q_shape, dtype)) |
| 899 | k = R.arg("k", R.Tensor(k_shape, dtype)) |
| 900 | v = R.arg("v", R.Tensor(v_shape, dtype)) |
| 901 | if bias_shape is not None: |
| 902 | bias = R.arg("bias", R.Tensor(bias_shape, dtype)) |
| 903 | with R.dataflow() as frame: |
| 904 | if len(q_shape) == 4: |
| 905 | q = R.emit(R.permute_dims(q, axes=[0, 2, 1, 3])) |
| 906 | q = R.emit(R.reshape(q, [q_shape[0] * q_shape[2], q_shape[1], q_shape[3]])) |
| 907 | |
| 908 | if len(k_shape) == 4: |
| 909 | k = R.emit(R.permute_dims(k, axes=[0, 2, 1, 3])) |
| 910 | k = R.emit(R.reshape(k, [k_shape[0] * k_shape[2], k_shape[1], k_shape[3]])) |
| 911 | |
| 912 | if len(v_shape) == 4: |
| 913 | v = R.emit(R.permute_dims(v, axes=[0, 2, 1, 3])) |
| 914 | v = R.emit(R.reshape(v, [v_shape[0] * v_shape[2], v_shape[1], v_shape[3]])) |
| 915 | |
| 916 | k = R.emit(R.permute_dims(k, axes=[0, 2, 1])) |
| 917 | qk = R.emit(R.matmul(q, k)) |
| 918 | qk_scaled = R.emit(R.multiply(qk, R.const(scale, "float32"))) |
| 919 | if bias_shape is not None: |
| 920 | if len(bias_shape) == 4: |
| 921 | bias = R.emit( |
| 922 | R.reshape(bias, [bias_shape[0] * bias_shape[1], *bias_shape[2:]]) |
| 923 | ) |
| 924 | qk_added = R.emit(R.add(qk_scaled, bias)) |
| 925 | softmax = R.emit(R.nn.softmax(qk_added, axis=-1)) |
| 926 | else: |
| 927 | softmax = R.emit(R.nn.softmax(qk_scaled, axis=-1)) |
| 928 | out = R.emit(R.matmul(softmax, v)) |
| 929 | |
| 930 | if len(out_shape) == 4: |
| 931 | out = R.emit( |
| 932 | R.reshape( |
| 933 | out, |
| 934 | [out_shape[0], out_shape[2], out_shape[1], out_shape[3]], |
| 935 | ) |
| 936 | ) |
| 937 | out = R.emit(R.permute_dims(out, axes=[0, 2, 1, 3])) |
| 938 | R.output(out) |
| 939 | |
| 940 | R.func_ret_value(frame.output_vars[0]) |
| 941 | |
| 942 | original_func = builder.get() |
| 943 | |
| 944 | if scale is not None: |
| 945 | scale = T.FloatImm("float32", scale) |
no test coverage detected
searching dependent graphs…