Create a constant value. Parameters: `value` (bool | int | float | numpy.ndarray | tvm.runtime.Tensor) — the constant value; `dtype` (Optional[str]) — the data type of the resulting constant. Note: when `dtype` is None, the following rule is used:
(
value: bool | int | float | _np.ndarray | tvm.runtime.Tensor, dtype: str | None = None
)
| 1156 | |
| 1157 | |
def const(
    value: bool | int | float | _np.ndarray | tvm.runtime.Tensor, dtype: str | None = None
) -> Constant:
    """Create a constant value.

    Parameters
    ----------
    value: bool | int | float | numpy.ndarray | tvm.runtime.Tensor
        The constant value. Python scalars (any ``numbers.Number``, including
        ``bool``) and lists are first converted with ``numpy.array``.

    dtype: Optional[str]
        The data type of the resulting constant.

    Returns
    -------
    Constant
        A constant wrapping ``value`` converted to a ``tvm.runtime.Tensor``.

    Raises
    ------
    ValueError
        If ``value`` is not a scalar, a list, a numpy array/scalar, or a
        ``tvm.runtime.Tensor``.

    Note
    ----
    When dtype is None, we use the following rule:

    - int maps to "int32"
    - float maps to "float32"
    - bool maps to "bool"
    - other values follow numpy's default rules.

    The int/float narrowing applies to any numpy value whose dtype is "int64"
    or "float64", including numpy arrays passed in directly. It is done with
    ``ndarray.astype`` under numpy's default (unsafe) casting, so integers
    outside the int32 range are not range-checked.

    When ``value`` is already a ``tvm.runtime.Tensor`` it is wrapped as-is and
    ``dtype`` is not applied to it.
    """
    # Needed for bf16 and fp8 support (does not come with numpy)
    import ml_dtypes  # pylint: disable=unused-import,import-outside-toplevel

    # Python scalars (bool is itself a Number) and lists go through numpy first.
    if isinstance(value, Number | bool | list):
        value = _np.array(value, dtype=dtype)

    if not dtype:
        # When dtype is None: int maps to "int32", float maps to "float32".
        # getattr() keeps inputs without a ``dtype`` (e.g. tuples) from raising
        # AttributeError here; they are rejected by the ValueError below, the
        # same way they are when an explicit dtype is given.
        dtype = {  # type: ignore
            _np.dtype("int64"): _np.int32,  # type: ignore
            _np.dtype("float64"): _np.float32,  # type: ignore
        }.get(getattr(value, "dtype", None), None)  # type: ignore

    if isinstance(value, _np.ndarray | _np.generic):
        if dtype is not None:
            value = value.astype(dtype)
        value = tvm.runtime.tensor(value)

    if not isinstance(value, tvm.runtime.Tensor):
        raise ValueError("value has to be scalar or Tensor")

    return Constant(value)
| 1205 | |
| 1206 | |
| 1207 | @tvm_ffi.register_object("relax.TEPlaceholderOp") |
searching dependent graphs…