| 527 | io_deps: Mapping[str, BlockwiseDep] |
| 528 | |
def __init__(
    self,
    output: str,
    output_indices: Iterable[str],
    task: Task,
    indices: Iterable[tuple[str | TaskRef | BlockwiseDep, Iterable[str] | None]],
    numblocks: Mapping[str, Sequence[int]],
    concatenate: bool | None = None,
    new_axes: Mapping[str, int] | None = None,
    output_blocks: set[tuple[int, ...]] | None = None,
    annotations: Mapping[str, Any] | None = None,
    io_deps: Mapping[str, BlockwiseDep] | None = None,
):
    """Store the layer's fields, normalizing ``indices`` along the way.

    Parameters
    ----------
    output, output_blocks
        Stored unchanged on the instance.
    output_indices
        Stored as a tuple.
    task
        Stored unchanged; asserted to be a ``Task`` instance.
    indices
        ``(dep, ind)`` pairs. A ``BlockwiseDep`` ``dep`` is replaced by
        ``tokenize(dep)`` and registered under that name in both
        ``io_deps`` and ``numblocks`` (using ``dep.numblocks``). A
        non-``None`` ``ind`` is converted to a tuple. A ``TaskRef``
        ``dep`` is only accepted together with ``ind=None``.
    numblocks
        Copied into a new dict before entries for ``BlockwiseDep``
        inputs are added.
    concatenate
        Stored as given, except that it is reset to ``None`` when every
        index of every input with a non-``None`` index also appears in
        ``output_indices``.
    new_axes
        Stored as given; a falsy value is replaced by an empty dict.
    annotations
        Forwarded to the parent class initializer.
    io_deps
        Copied into a new dict (empty when falsy) before entries for
        ``BlockwiseDep`` inputs are added.
    """
    super().__init__(annotations=annotations)
    self.output = output
    self.output_indices = tuple(output_indices)
    self.output_blocks = output_blocks
    self.task = task
    assert isinstance(task, Task)

    # Pull `BlockwiseDep` arguments out of the input indices and register
    # them in `self.io_deps` under their token.
    # TODO: Remove `io_deps` and handle indexable objects
    #       in `self.indices` throughout `Blockwise`.
    normalized = []
    block_counts = dict(numblocks)
    deps = dict(io_deps) if io_deps else {}
    for arg, arg_index in indices or ():
        is_broadcast = arg_index is None
        # FIXME: The Blockwise API is a little weird this way
        assert is_broadcast or not isinstance(
            arg, TaskRef
        ), "TaskRef objects are only allowed for broadcasted inputs with None as index."
        if isinstance(arg, BlockwiseDep):
            key = tokenize(arg)
            deps[key] = arg
            block_counts[key] = arg.numblocks
        else:
            key = arg  # type: ignore[assignment]
        normalized.append((key, None if is_broadcast else tuple(arg_index)))
    self.numblocks = block_counts
    self.io_deps = deps
    self.indices = tuple(normalized)

    # optimize_blockwise won't merge where `concatenate` doesn't match, so
    # enforce a canonical value if there are no axes for reduction.
    kept_axes = set(self.output_indices)
    if concatenate is not None:
        has_reduced_axis = any(
            axis not in kept_axes
            for _, input_index in self.indices
            if input_index is not None
            for axis in input_index
        )
        if not has_reduced_axis:
            concatenate = None
    self.concatenate = concatenate
    self.new_axes = new_axes if new_axes else {}
| 586 | |