Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. For each high-level operator, we register the way of legalizing it as a function, which takes a context BlockBuilder and the Call being legalized as input, and returns the legalized call.
(
customize_legalize_map: dict[str, LegalizeFunc] | None = None,
skip_ops: list[str] | None = None,
enable_warning: bool = False,
)
def LegalizeOps(
    customize_legalize_map: dict[str, LegalizeFunc] | None = None,
    skip_ops: list[str] | None = None,
    enable_warning: bool = False,
):
    """Legalize high-level operator calls in Relax functions to call_tir
    with corresponding low-level TIR PrimFuncs.

    For each high-level operator, we register the way of legalizing it as a
    function, which takes a context BlockBuilder and the Call being legalized
    as input, and returns the legalized call. Here the input BlockBuilder is
    mainly used for adding the PrimFunc created by call_te into the context
    IRModule.

    The legalization function for each operator is registered as an attribute (with
    attribute key `FLegalize`) of the operator.

    This pass lets users supply their own legalization functions for
    operators. It takes an optional customized map whose keys are operator
    names (`str`) and whose values are legalization functions
    (`LegalizeFunc`). A customized function overrides the default one for
    the same operator.

    Parameters
    ----------
    customize_legalize_map : Optional[Dict[str, LegalizeFunc]]
        The customized operator legalization function map. The customized function will override
        the default one.

    skip_ops : Optional[List[str]]
        List of ops to skip during legalization.

    enable_warning : bool
        Whether to print warnings for CallNodes whose op's legalization
        function is not registered. By default, no warnings are printed.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass

    Examples
    --------
    The following code shows how to use this pass:

    .. code-block:: python

        # Define the pass input IRModule
        @tvm.script.ir_module
        class Module:
            @R.function
            def main(
                x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
            ) -> R.Tensor((2, 3), "float32"):
                z: R.Tensor((2, 3), "float32") = R.add(x, y)
                r: R.Tensor((2, 3), "float32") = R.multiply(y, z)
                return r
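The override, skip, and warning behaviour described in the docstring can be illustrated with a toy model in plain Python. This is a sketch only and does not use or reproduce TVM's implementation: `DEFAULT_LEGALIZE`, `legalize_calls`, and the `(op_name, args)` tuple that stands in for a `Call` are hypothetical names invented for illustration, the BlockBuilder argument is omitted, and the order in which skipping and overriding are checked is an assumption.

```python
import warnings
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical stand-in for a Relax Call: (operator name, argument names).
Call = Tuple[str, Tuple[str, ...]]
# Hypothetical stand-in for LegalizeFunc; the real one also takes a BlockBuilder.
ToyLegalizeFunc = Callable[[Call], str]


def _lower_to(kernel: str) -> ToyLegalizeFunc:
    """Build a toy legalizer that renders the call as a call_tir string."""
    return lambda call: f"call_tir({kernel}, {', '.join(call[1])})"


# Toy registry playing the role of the per-operator FLegalize attribute.
DEFAULT_LEGALIZE: Dict[str, ToyLegalizeFunc] = {
    "relax.add": _lower_to("add"),
    "relax.multiply": _lower_to("multiply"),
}


def legalize_calls(
    calls: List[Call],
    customize_legalize_map: Optional[Dict[str, ToyLegalizeFunc]] = None,
    skip_ops: Optional[List[str]] = None,
    enable_warning: bool = False,
) -> list:
    """Apply the toy legalization rules to a flat list of calls."""
    custom = customize_legalize_map or {}
    skipped = set(skip_ops or [])
    result = []
    for call in calls:
        op_name = call[0]
        if op_name in skipped:
            # Skipped ops are left untouched (assumed to take priority).
            result.append(call)
            continue
        # A customized function overrides the default one.
        func = custom[op_name] if op_name in custom else DEFAULT_LEGALIZE.get(op_name)
        if func is None:
            if enable_warning:
                warnings.warn(f"No legalization function registered for {op_name}")
            result.append(call)
            continue
        result.append(func(call))
    return result


calls = [
    ("relax.add", ("x", "y")),
    ("relax.multiply", ("y", "z")),
    ("relax.unknown", ("z",)),
]
default_result = legalize_calls(calls)
custom_result = legalize_calls(
    calls,
    customize_legalize_map={"relax.add": lambda call: "custom_add"},
    skip_ops=["relax.multiply"],
)
```

In this model, `default_result` lowers the two registered operators and leaves `relax.unknown` unchanged, while `custom_result` uses the customized function for `relax.add` and leaves the skipped `relax.multiply` call as it was.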