Create a matmul op followed by dequantize operations.
(
x_shape,
y_shape,
in_dtype,
out_dtype,
transposed_y=False,
scale_const=1.0,
zero_point_const=0.0,
)
| 106 | |
| 107 | |
def get_relax_matmul_dequantize_module(
    x_shape,
    y_shape,
    in_dtype,
    out_dtype,
    transposed_y=False,
    scale_const=1.0,
    zero_point_const=0.0,
):
    """Build a Relax module that computes ``dequantize(matmul(x, y))``.

    The generated ``main`` function takes two tensor arguments, ``x`` and ``y``,
    both of dtype ``in_dtype``. It optionally swaps the last two axes of ``y``,
    multiplies the operands with a float32 result dtype, and then dequantizes the
    product along the last axis using float16 constant scale and zero point.

    Parameters
    ----------
    x_shape, y_shape
        Shapes of the ``x`` and ``y`` arguments.
    in_dtype
        Dtype shared by both arguments.
    out_dtype
        Dtype produced by the dequantize step.
    transposed_y
        When True, ``y`` is permuted so that its last two axes are swapped
        before the matmul.
    scale_const, zero_point_const
        Values of the float16 scale and zero-point constants passed to dequantize.

    Returns
    -------
    tvm.IRModule
        A module containing a single function named ``main``.
    """
    with IRBuilder() as ib:
        with relax_builder.function():
            R.func_name("main")
            # Argument order here defines the function signature: x first, then y.
            lhs = R.arg("x", R.Tensor(x_shape, in_dtype))
            rhs = R.arg("y", R.Tensor(y_shape, in_dtype))

            with R.dataflow() as df:
                if transposed_y:
                    # Leave every leading axis in place and swap only the final two.
                    perm = [*range(len(y_shape) - 2), -1, -2]
                    rhs = R.emit(R.permute_dims(rhs, axes=perm))
                product = R.emit(R.matmul(lhs, rhs, out_dtype="float32"))
                scale = R.const(scale_const, "float16")
                zero_point = R.const(zero_point_const, "float16")
                dequantized = R.emit(
                    R.dequantize(
                        product,
                        scale=scale,
                        zero_point=zero_point,
                        axis=-1,
                        out_dtype=out_dtype,
                    )
                )
                R.output(dequantized)
            R.func_ret_value(df.output_vars[0])

    # The builder is finalized only after its context has been exited.
    return tvm.IRModule({"main": ib.get()})
| 143 | |
| 144 | |
| 145 | def get_relax_matmul_multiply_module( |
no test coverage detected
searching dependent graphs…