(
shape: tuple[int, ...], mapper_in: dict[int, int], one_dims: list[int]
)
| 574 | |
| 575 | |
| 576 | def _convert_to_shape( |
| 577 | shape: tuple[int, ...], mapper_in: dict[int, int], one_dims: list[int] |
| 578 | ): |
| 579 | # Used to map the input dimensions onto output dimensions. The mapper tracks |
| 580 | # which input dimensions are mapped onto what output dimensions. |
| 581 | # i.e. {0: 0, 1:0, 2: 1} means that dimensions 0 and 1 are mapped into 0 and 2 into 1 |
| 582 | output_shape: list[list[int]] = [[]] * ( |
| 583 | len(set(mapper_in.values())) + len(one_dims) |
| 584 | ) |
| 585 | output_shape = list(map(lambda x: x.copy(), output_shape)) |
| 586 | for i in one_dims: |
| 587 | output_shape[i] = [1] |
| 588 | |
| 589 | for k, v in mapper_in.items(): |
| 590 | output_shape[v].append(shape[k]) |
| 591 | |
| 592 | return tuple(map(lambda x: reduce(mul, x), output_shape)) |
no test coverage detected
searching dependent graphs…