(
variables: Iterable[Variable], exclude_dims: AbstractSet = frozenset()
)
| 619 | |
| 620 | |
| 621 | def unified_dim_sizes( |
| 622 | variables: Iterable[Variable], exclude_dims: AbstractSet = frozenset() |
| 623 | ) -> dict[Hashable, int]: |
| 624 | dim_sizes: dict[Hashable, int] = {} |
| 625 | |
| 626 | for var in variables: |
| 627 | if len(set(var.dims)) < len(var.dims): |
| 628 | raise ValueError( |
| 629 | "broadcasting cannot handle duplicate " |
| 630 | f"dimensions on a variable: {list(var.dims)}" |
| 631 | ) |
| 632 | for dim, size in zip(var.dims, var.shape, strict=True): |
| 633 | if dim not in exclude_dims: |
| 634 | if dim not in dim_sizes: |
| 635 | dim_sizes[dim] = size |
| 636 | elif dim_sizes[dim] != size: |
| 637 | raise ValueError( |
| 638 | "operands cannot be broadcast together " |
| 639 | "with mismatched lengths for dimension " |
| 640 | f"{dim}: {dim_sizes[dim]} vs {size}" |
| 641 | ) |
| 642 | return dim_sizes |
| 643 | |
| 644 | |
# Full slice (the object form of ``:``), selecting every element along an axis.
SLICE_NONE = slice(None)
no outgoing calls
searching dependent graphs…