MCPcopy
hub / github.com/apache/tvm / _auto_broadcast

Function _auto_broadcast

python/tvm/tirx/script/parser/operation.py:55–98  ·  view source on GitHub ↗
(a, b, op)

Source from the content-addressed store, hash-verified

53 return dtype[0:index]
54
def _auto_broadcast(a, b, op):
    """Coerce Python scalar operands to TIR immediates, broadcast lanes, then apply ``op``.

    A Python ``int``/``float`` operand is converted to an ``IntImm``/``FloatImm`` whose dtype
    is derived (via ``_get_type_str``) from the other operand's ``dtype`` when that operand
    has one:

    * an ``int`` next to an INT/UINT/BOOL operand becomes an ``IntImm``;
    * an ``int`` next to any other typed operand becomes a ``FloatImm`` of that type;
    * a ``float`` next to a non-integer typed operand becomes a ``FloatImm`` of that type.

    Without a usable typed partner, ints default to ``int32`` and floats to ``float32``.

    After coercion, if the lane counts differ and one side has a single lane, that side is
    wrapped in ``tirx.Broadcast`` to the other side's lane count before calling ``op``.

    NOTE(review): operands presumably come from the script parser's operator dispatch
    (PrimExpr or Python scalar) — confirm against callers.

    Parameters
    ----------
    a, b : tirx.PrimExpr, int or float
        Left and right operands. After coercion ``a`` must be a ``tirx.PrimExpr``.
    op : callable
        Binary node constructor such as ``tirx.EQ``; called as ``op(lhs, rhs)``.

    Returns
    -------
    The object returned by ``op``.

    Raises
    ------
    AssertionError
        If ``a`` is still not a ``tirx.PrimExpr`` after coercion.
    TypeError
        If both operands have more than one lane and the lane counts differ.
    """
    int_like_codes = (DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL)

    def _is_int_like(dtype):
        # True for integer-valued dtypes, which take an IntImm for a Python int.
        return DataType(dtype).type_code in int_like_codes

    # --- Coerce a Python scalar on the left, using `b`'s dtype as the hint. ---
    if isinstance(a, int):  # note: also true for bool
        if hasattr(b, "dtype"):
            scalar_dtype = _get_type_str(b.dtype)
            # Non-integer partners (float, bfloat16, ...) get a FloatImm of their own
            # type, mirroring how a Python float on the right is handled below.
            if _is_int_like(b.dtype):
                a = IntImm(scalar_dtype, a)
            else:
                a = FloatImm(scalar_dtype, a)
        elif isinstance(b, float):
            a = FloatImm("float32", a)
        else:
            a = IntImm("int32", a)
    elif isinstance(a, float):
        # Guard with hasattr like the int branch: `b` may itself be a Python scalar.
        if hasattr(b, "dtype") and not _is_int_like(b.dtype):
            a = FloatImm(_get_type_str(b.dtype), a)
        else:
            a = FloatImm("float32", a)

    assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr."

    # --- Coerce a Python scalar on the right, using `a`'s dtype as the hint. ---
    if isinstance(b, (int, float)):
        scalar_dtype = _get_type_str(a.dtype)
        if isinstance(b, int) and _is_int_like(a.dtype):
            b = IntImm(scalar_dtype, b)
        else:
            b = FloatImm(scalar_dtype, b)

    # --- Match lane counts, splatting a single-lane side up to the other's width. ---
    a_lanes = DataType(a.dtype).lanes
    b_lanes = DataType(b.dtype).lanes
    if a_lanes == b_lanes:
        return op(a, b)
    if a_lanes == 1:
        return op(tirx.Broadcast(a, b_lanes), b)
    if b_lanes == 1:
        return op(a, tirx.Broadcast(b, a_lanes))
    raise TypeError(
        f"do not know how to deal with it: cannot broadcast {a_lanes} lanes "
        f"({a.dtype}) against {b_lanes} lanes ({b.dtype})."
    )
99
def _eq(a, b):
    """Return ``a == b`` as a ``tirx.EQ`` node.

    Scalar coercion and lane broadcasting of the operands are delegated to
    ``_auto_broadcast``.
    """
    equal_node = tirx.EQ
    return _auto_broadcast(a, b, op=equal_node)

Callers 6

_eq · Function · 0.85
_ne · Function · 0.85
_lt · Function · 0.85
_le · Function · 0.85
_gt · Function · 0.85
_ge · Function · 0.85

Calls 4

IntImm · Class · 0.90
FloatImm · Class · 0.90
DataType · Class · 0.85
_get_type_str · Function · 0.85

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…