| 53 | return dtype[0:index] |
| 54 | |
def _auto_broadcast(a, b, op):
    """Apply the binary ``op`` to ``a`` and ``b`` after promoting Python scalars.

    Python ``int``/``float`` operands are wrapped in ``IntImm``/``FloatImm``
    using the element type of the other operand where possible. If the two
    operands then have different lane counts and one of them has a single
    lane, that operand is wrapped in ``tirx.Broadcast`` to match the other.

    Parameters
    ----------
    a : Union[PrimExpr, int, float]
        Left operand.
    b : Union[PrimExpr, int, float]
        Right operand.
    op : Callable
        Binary constructor invoked as ``op(a, b)`` on the prepared operands.

    Returns
    -------
    The value returned by ``op``.

    Raises
    ------
    AssertionError
        If ``a`` is still not a ``tirx.PrimExpr`` after scalar promotion.
    TypeError
        If the lane counts differ and neither operand has exactly one lane.
    """
    integral_codes = (DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL)

    if isinstance(a, int):
        if hasattr(b, "dtype"):
            b_code = DataType(b.dtype).type_code
            if b_code in integral_codes:
                a = IntImm(_get_type_str(b.dtype), a)
            elif b_code == DataTypeCode.FLOAT:
                a = FloatImm(_get_type_str(b.dtype), a)
            # Any other type code leaves ``a`` as a Python int, which the
            # assertion below rejects.
        elif isinstance(b, float):
            a = FloatImm("float32", a)
        else:
            a = IntImm("int32", a)
    elif isinstance(a, float):
        # Guard with hasattr like the int branch above: ``b`` may be a plain
        # Python scalar, in which case reading ``b.dtype`` would raise
        # AttributeError instead of falling back to float32.
        if hasattr(b, "dtype") and DataType(b.dtype).type_code == DataTypeCode.FLOAT:
            a = FloatImm(_get_type_str(b.dtype), a)
        else:
            a = FloatImm("float32", a)

    assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr."
    if isinstance(b, int):
        a_code = DataType(a.dtype).type_code
        if a_code in integral_codes:
            b = IntImm(_get_type_str(a.dtype), b)
        elif a_code == DataTypeCode.FLOAT:
            b = FloatImm(_get_type_str(a.dtype), b)
    elif isinstance(b, float):
        # NOTE(review): this uses a's element type even when it is not a float
        # type -- confirm that is intended.
        b = FloatImm(_get_type_str(a.dtype), b)

    a_lanes = DataType(a.dtype).lanes
    b_lanes = DataType(b.dtype).lanes
    if a_lanes == b_lanes:
        return op(a, b)
    # From here on the lane counts are known to differ.
    if a_lanes == 1:
        return op(tirx.Broadcast(a, b_lanes), b)
    if b_lanes == 1:
        return op(a, tirx.Broadcast(b, a_lanes))
    raise TypeError("do not know how to deal with it.")
| 99 | |
def _eq(a, b):
    """Return the result of ``_auto_broadcast`` for ``a`` and ``b`` with ``tirx.EQ`` as the op."""
    comparison = _auto_broadcast(a, b, op=tirx.EQ)
    return comparison