(v: UOp)
| 131 | |
| 132 | def _minmax_reduce(is_max: bool, dt, *args: UOp) -> UOp: |
| 133 | def cast(v: UOp) -> UOp: return v.bitcast(dt) if dt == dtypes.float32 and v.dtype == dtypes.uint32 else v.cast(dt) |
| 134 | def minmax(a: UOp, b: UOp) -> UOp: |
| 135 | if dt in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64): return (a > b).where(a, b) if is_max else (a < b).where(a, b) |
| 136 | return a.maximum(b) if is_max else a.minimum(b) |
searching dependent graphs…