Repository: github.com/apache/tvm

Function get_relax_matmul_dequantize_module

Defined in tests/python/relax/test_codegen_cublas.py, lines 108–142.

Create a Relax IRModule whose main function performs a matmul (producing a float32 result) followed by a dequantize operation.
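The computation described above can be approximated by a small NumPy reference, which is handy for checking a compiled module's output. This is a sketch, not TVM's implementation: it assumes Relax's dequantize follows the usual formula clamp(scale * (input - zero_point)) to the range of out_dtype, and the function name reference_matmul_dequantize is hypothetical. The module stores scale and zero_point as float16 constants; the sketch simply keeps them as Python floats.

```python
import numpy as np


def reference_matmul_dequantize(x, y, out_dtype="float16", transposed_y=False,
                                scale=1.0, zero_point=0.0):
    """Hypothetical NumPy reference: float32 matmul followed by dequantize.

    Assumes dequantize computes clamp(scale * (acc - zero_point)) to the
    representable range of out_dtype.
    """
    if transposed_y:
        # y arrives as (..., n, k); swap the last two axes to get (..., k, n).
        y = np.swapaxes(y, -1, -2)
    # Accumulate in float32, mirroring R.matmul(..., out_dtype="float32").
    acc = np.matmul(x.astype(np.float32), y.astype(np.float32))
    deq = scale * (acc - zero_point)
    # Clamp to the output dtype's finite range before casting.
    info = np.finfo(out_dtype)
    return np.clip(deq, info.min, info.max).astype(out_dtype)
```

For example, with x = [[1, 2], [3, 4]], y = [[5, 6], [7, 8]] and scale=0.5, the float32 product [[19, 22], [43, 50]] is halved to [[9.5, 11], [21.5, 25]] in float16.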

(
    x_shape,
    y_shape,
    in_dtype,
    out_dtype,
    transposed_y=False,
    scale_const=1.0,
    zero_point_const=0.0,
)


```python
import tvm
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder


def get_relax_matmul_dequantize_module(
    x_shape,
    y_shape,
    in_dtype,
    out_dtype,
    transposed_y=False,
    scale_const=1.0,
    zero_point_const=0.0,
):
    """Create a matmul op followed by dequantize operations."""
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            x = R.arg("x", R.Tensor(x_shape, in_dtype))
            y = R.arg("y", R.Tensor(y_shape, in_dtype))

            with R.dataflow() as frame:
                if transposed_y:
                    # Swap the last two axes of y; leading batch axes keep their order.
                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
                    y = R.emit(R.permute_dims(y, axes=axes))
                # The matmul result is produced in float32.
                result = R.emit(R.matmul(x, y, out_dtype="float32"))
                # Dequantize with scalar float16 scale and zero point along the last axis.
                result = R.emit(
                    R.dequantize(
                        result,
                        scale=R.const(scale_const, "float16"),
                        zero_point=R.const(zero_point_const, "float16"),
                        axis=-1,
                        out_dtype=out_dtype,
                    )
                )
                R.output(result)
            R.func_ret_value(frame.output_vars[0])

    func = builder.get()
    return tvm.IRModule({"main": func})
```
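When transposed_y is set, the function permutes y with list(range(len(y_shape) - 2)) + [-1, -2], so y supplied as (..., n, k) becomes (..., k, n) and can be multiplied by x of shape (..., m, k). The pure-Python sketch below reproduces that permutation and the resulting shape. The helper names transposed_y_axes and permuted_shape are hypothetical, and resolving negative axes modulo the rank is an assumption about how permute_dims interprets them.

```python
def transposed_y_axes(y_shape):
    """Permutation used for transposed_y=True: keep leading (batch) axes
    in order and swap the last two."""
    return list(range(len(y_shape) - 2)) + [-1, -2]


def permuted_shape(y_shape, axes):
    """Hypothetical helper: shape after permuting, resolving negative axes
    modulo the rank (assumed to match permute_dims semantics)."""
    rank = len(y_shape)
    return tuple(y_shape[a % rank] for a in axes)


# A 3-D y of shape (batch, n, k) = (4, 64, 32) becomes (4, 32, 64).
axes = transposed_y_axes((4, 64, 32))
print(axes, permuted_shape((4, 64, 32), axes))
```

For a 2-D y the leading range is empty, so the permutation is just [-1, -2], a plain transpose.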

Calls (10)

IRBuilder (class) · 0.90
Tensor (method) · 0.80
dataflow (method) · 0.80
output (method) · 0.80
function (method) · 0.45
emit (method) · 0.45
permute_dims (method) · 0.45
matmul (method) · 0.45
dequantize (method) · 0.45
get (method) · 0.45

Tested by

no test coverage detected
