Apply a Dataset- or DataArray-level function over GroupBy, Dataset, DataArray, Variable and/or ndarray objects.
(func, *args)
| 572 | |
| 573 | |
def apply_groupby_func(func, *args):
    """Apply a Dataset- or DataArray-level function over GroupBy, Dataset,
    DataArray, Variable and/or ndarray objects.

    Each ``GroupBy`` argument is iterated group by group. Other arguments are
    handled as follows:

    - Arguments that carry the grouped dimension are iterated via
      ``_iter_over_selections`` along that dimension.
    - All remaining arguments are passed unchanged to every call.

    ``func`` is called once per group. The per-group results are then
    recombined with the first GroupBy's ``_combine`` method.

    Parameters
    ----------
    func : callable
        Called as ``func(*per_group_args)`` once per group.
    *args
        Positional arguments for ``func``; at least one must be a ``GroupBy``.

    Returns
    -------
    The combined result. If the first result of ``func`` is a tuple, a tuple
    is returned containing one combined object per tuple position.

    Raises
    ------
    ValueError
        If no argument is a ``GroupBy``, if the first ``GroupBy`` does not have
        exactly one grouper, if the ``GroupBy`` arguments are not all grouped
        the same way, or if a ``Variable`` argument has the grouped dimension.
    """
    from xarray.core.groupby import GroupBy, peek_at

    groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
    # Explicit raise instead of ``assert``: the check must survive
    # ``python -O``. Otherwise ``groupbys[0]`` below fails with a bare
    # IndexError.
    if not groupbys:
        raise ValueError("must have at least one groupby to iterate over")
    first_groupby = groupbys[0]
    (grouper,) = first_groupby.groupers
    # All GroupBy objects are consumed in lockstep below, so they must share
    # the same single grouper. Comparing only ``groupers[0]`` would let a
    # GroupBy with extra groupers through unnoticed.
    if any(
        len(gb.groupers) != 1
        or not grouper.group.equals(gb.groupers[0].group)  # type: ignore[union-attr]
        for gb in groupbys[1:]
    ):
        raise ValueError(
            "apply_ufunc can only perform operations over "
            "multiple GroupBy objects at once if they are all "
            "grouped the same way"
        )

    grouped_dim = grouper.name
    unique_values = grouper.unique_coord.values

    iterators = []
    for arg in args:
        iterator: Iterator[Any]
        if isinstance(arg, GroupBy):
            # Keep only the per-group values; the group labels are dropped.
            iterator = (value for _, value in arg)
        elif hasattr(arg, "dims") and grouped_dim in arg.dims:
            if isinstance(arg, Variable):
                raise ValueError(
                    "groupby operations cannot be performed with "
                    "xarray.Variable objects that share a dimension with "
                    "the grouped dimension"
                )
            iterator = _iter_over_selections(arg, grouped_dim, unique_values)
        else:
            # Arguments without the grouped dimension are passed unchanged
            # to every call.
            iterator = itertools.repeat(arg)
        iterators.append(iterator)

    # strict=False is required: the ``itertools.repeat`` iterators above are
    # infinite, so iteration must stop when the finite per-group iterators
    # run out.
    applied: Iterator = itertools.starmap(func, zip(*iterators, strict=False))
    # Peek at the first result to decide whether ``func`` returns a tuple of
    # outputs.
    applied_example, applied = peek_at(applied)
    combine = first_groupby._combine  # type: ignore[attr-defined]
    if isinstance(applied_example, tuple):
        # Transpose the per-group tuples into one stream per output position,
        # then combine each stream separately.
        combined = tuple(combine(output) for output in zip(*applied, strict=True))
    else:
        combined = combine(applied)
    return combined
| 619 | |
| 620 | |
| 621 | def unified_dim_sizes( |
no test coverage detected
searching dependent graphs…