Three-argument where() with better dtype promotion rules.
(condition, x, y)
| 405 | |
| 406 | |
def where(condition, x, y):
    """Pick from ``x`` or ``y`` according to ``condition``, promoting dtypes first.

    The array namespace is resolved from all three arguments. ``condition``
    is converted to that namespace's boolean dtype, ``x`` and ``y`` are passed
    through ``as_shared_dtype``, and the final selection is delegated to the
    namespace's own ``where``.
    """
    namespace = get_array_namespace(condition, x, y)

    # Use ``bool_`` when the namespace exposes it; otherwise fall back to
    # ``bool``. The fallback attribute is only touched when actually needed.
    if hasattr(namespace, "bool_"):
        bool_dtype = namespace.bool_
    else:
        bool_dtype = namespace.bool

    # Duck arrays are cast in place; anything else is first turned into an
    # array of the boolean dtype.
    if is_duck_array(condition):
        mask = astype(condition, dtype=bool_dtype, xp=namespace)
    else:
        mask = asarray(condition, dtype=bool_dtype, xp=namespace)

    when_true, when_false = as_shared_dtype([x, y], xp=namespace)

    return namespace.where(mask, when_true, when_false)
| 420 | |
| 421 | |
| 422 | def where_method(data, cond, other=dtypes.NA): |
searching dependent graphs…