(
len_xo: int,
len_k: int,
len_xi: int,
dtype: str,
)
| 136 | |
| 137 | |
def _workload(
    len_xo: int,
    len_k: int,
    len_xi: int,
    dtype: str,
):
    """Build a TE workload that sums a 3-D placeholder over its middle axis.

    The input placeholder ``A`` has shape ``(len_xo, len_k, len_xi)``.
    The computed tensor ``B`` has shape ``(len_xo, len_xi)``, where
    ``B[i, j] = sum over k of A[i, k, j]``.

    Args:
        len_xo: Extent of the leading axis, which is kept in the output.
        len_k: Extent of the middle axis, which is reduced away.
        len_xi: Extent of the trailing axis, which is kept in the output.
        dtype: Dtype given to the input placeholder.

    Returns:
        Whatever ``te.create_prim_func`` builds from ``[A, B]``.
    """
    data = te.placeholder((len_xo, len_k, len_xi), dtype=dtype, name="A")
    red = te.reduce_axis((0, len_k), "k")

    # Parameter names must stay ``i``/``j``: te.compute reads them from the
    # function's code object to name the output iteration variables.
    def _row_sum(i, j):
        return te.sum(data[i, red, j], axis=red)

    result = te.compute((len_xo, len_xi), _row_sum, name="B")
    return te.create_prim_func([data, result])
| 154 | |
| 155 | |
| 156 | def _schedule( |
no test coverage detected
searching dependent graphs…