Wrapper function that receives Datasets in `args`, converts them to DataArrays where necessary, passes them to the user function `func`, and checks the returned object for the expected dimensions, sizes, indexes, and variables.
(
func: Callable,
args: list,
kwargs: dict,
arg_is_array: Iterable[bool],
expected: ExpectedDict,
expected_indexes: dict[Hashable, Index],
)
| 331 | """ |
| 332 | |
| 333 | def _wrapper( |
| 334 | func: Callable, |
| 335 | args: list, |
| 336 | kwargs: dict, |
| 337 | arg_is_array: Iterable[bool], |
| 338 | expected: ExpectedDict, |
| 339 | expected_indexes: dict[Hashable, Index], |
| 340 | ): |
| 341 | """ |
| 342 | Wrapper function that receives datasets in args; converts to dataarrays when necessary; |
| 343 | passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. |
| 344 | """ |
| 345 | |
| 346 | converted_args = [ |
| 347 | dataset_to_dataarray(arg) if is_array else arg |
| 348 | for is_array, arg in zip(arg_is_array, args, strict=True) |
| 349 | ] |
| 350 | |
| 351 | result = func(*converted_args, **kwargs) |
| 352 | |
| 353 | merged_coordinates = merge( |
| 354 | [arg.coords for arg in args if isinstance(arg, Dataset | DataArray)], |
| 355 | join="exact", |
| 356 | compat="override", |
| 357 | ).coords |
| 358 | |
| 359 | # check all dims are present |
| 360 | missing_dimensions = set(expected["shapes"]) - set(result.sizes) |
| 361 | if missing_dimensions: |
| 362 | raise ValueError( |
| 363 | f"Dimensions {missing_dimensions} missing on returned object." |
| 364 | ) |
| 365 | |
| 366 | # check that index lengths and values are as expected |
| 367 | for name, index in result._indexes.items(): |
| 368 | if ( |
| 369 | name in expected["shapes"] |
| 370 | and result.sizes[name] != expected["shapes"][name] |
| 371 | ): |
| 372 | raise ValueError( |
| 373 | f"Received dimension {name!r} of length {result.sizes[name]}. " |
| 374 | f"Expected length {expected['shapes'][name]}." |
| 375 | ) |
| 376 | |
| 377 | # ChainMap wants MutableMapping, but xindexes is Mapping |
| 378 | merged_indexes = collections.ChainMap( |
| 379 | expected_indexes, |
| 380 | merged_coordinates.xindexes, # type: ignore[arg-type] |
| 381 | ) |
| 382 | expected_index = merged_indexes.get(name, None) |
| 383 | if expected_index is not None and not index.equals(expected_index): |
| 384 | raise ValueError( |
| 385 | f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." |
| 386 | ) |
| 387 | |
| 388 | # check that all expected variables were returned |
| 389 | check_result_variables(result, expected, "coords") |
| 390 | if isinstance(result, Dataset): |
nothing calls this directly
no test coverage detected
searching dependent graphs…