(
arg: T_Alignable, exclude, dims_map, common_coords
)
| 1134 | |
| 1135 | |
def _broadcast_helper(
    arg: T_Alignable, exclude, dims_map, common_coords
) -> T_Alignable:
    """Rebuild a single DataArray or Dataset with its variables expanded via ``set_dims``.

    Parameters
    ----------
    arg
        The object to broadcast; must be a ``DataArray`` or a ``Dataset``.
    exclude
        Iterable of dimension names. For each of these that a variable
        already has, the variable's own length along that dimension is
        written into that variable's dims map before ``set_dims`` is called.
    dims_map
        Mapping handed to ``set_dims`` for every variable. It is copied
        per variable, so the caller's object is not modified.
    common_coords
        Coordinates merged on top of the object's own coordinates; entries
        here take precedence on key collisions.

    Returns
    -------
    A new object built with ``arg.__class__``.

    Raises
    ------
    ValueError
        If ``arg`` is neither a ``DataArray`` nor a ``Dataset``.
    """
    from xarray.core.dataarray import DataArray
    from xarray.core.dataset import Dataset

    def _expand(variable):
        # Work on a copy so the shared dims_map is never mutated.
        target_dims = dims_map.copy()
        for name in exclude:
            try:
                target_dims[name] = variable.shape[variable.dims.index(name)]
            except ValueError:
                # ``name`` is not one of this variable's dims; skip it.
                pass
        return variable.set_dims(target_dims)

    def _with_common_coords(obj):
        # Start from the object's own coords, then let common_coords override.
        merged = dict(obj.coords)
        merged.update(common_coords)
        return merged

    def _broadcast_array(array: T_DataArray) -> T_DataArray:
        expanded = _expand(array.variable)
        merged_coords = _with_common_coords(array)
        return array.__class__(
            expanded,
            merged_coords,
            expanded.dims,
            name=array.name,
            attrs=array.attrs,
        )

    def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
        expanded_vars = {}
        for key in ds.data_vars:
            expanded_vars[key] = _expand(ds.variables[key])
        merged_coords = _with_common_coords(ds)
        return ds.__class__(expanded_vars, merged_coords, ds.attrs)

    # remove casts once https://github.com/python/mypy/issues/12800 is resolved
    if isinstance(arg, DataArray):
        return _broadcast_array(arg)  # type: ignore[return-value,unused-ignore]
    if isinstance(arg, Dataset):
        return _broadcast_dataset(arg)  # type: ignore[return-value,unused-ignore]
    raise ValueError("all input must be Dataset or DataArray objects")
| 1173 | |
| 1174 | |
| 1175 | @overload |
no test coverage detected
searching dependent graphs…