(
sch: s_tir.Schedule,
len_bx: int,
len_tx: int,
len_vec: int,
)
| 154 | |
| 155 | |
def _schedule(
    sch: s_tir.Schedule,
    len_bx: int,
    len_tx: int,
    len_vec: int,
):
    """Apply the split / bind / cache-read schedule to block ``"B"`` of ``sch``.

    The block must expose exactly three loops (they are unpacked into three
    names). The first loop is split using ``len_bx`` and the second using
    ``len_tx`` and ``len_vec``; the pieces are reordered, two of them are fused
    and bound to ``blockIdx.x``, one is bound to ``threadIdx.x``, and the
    block's read at index 0 is staged through a ``"local"`` cache whose
    innermost loop is vectorized. Finally the reduction is decomposed at the
    third loop.

    Parameters
    ----------
    sch
        Schedule to transform; every primitive is invoked on this object.
    len_bx
        Leading factor used when splitting the block's first loop.
    len_tx
        Middle factor used when splitting the block's second loop; the
        resulting loop is bound to ``threadIdx.x``.
    len_vec
        Trailing factor used when splitting the block's second loop.

    Returns
    -------
    None
        Nothing is returned; the work is done through calls on ``sch``.
    """
    target = sch.get_sblock("B")
    # NOTE(review): k_loop is presumably the reduction loop, since it is the
    # loop later handed to decompose_reduction -- confirm against block "B".
    outer, inner, k_loop = sch.get_loops(target)

    # First loop: ``len_bx`` is the leading factor; the ``None`` slot is left
    # for the scheduler to fill in.
    block_part, outer_rest = sch.split(outer, factors=[len_bx, None])
    # Second loop: leading slot left open, then ``len_tx`` and ``len_vec``.
    inner_rest, thread_part, vec_part = sch.split(inner, factors=[None, len_tx, len_vec])

    sch.reorder(block_part, inner_rest, thread_part, outer_rest, k_loop, vec_part)

    # The two outermost loops after the reorder become a single grid loop.
    grid_loop = sch.fuse(block_part, inner_rest)
    sch.bind(grid_loop, "blockIdx.x")
    sch.bind(thread_part, "threadIdx.x")

    # Stage the block's read at index 0 through a "local" cache placed under
    # k_loop, then vectorize that cache block's innermost loop.
    staged = sch.cache_read(target, 0, "local")
    sch.compute_at(staged, k_loop, preserve_unit_loops=True)
    staged_loops = sch.get_loops(staged)
    sch.vectorize(staged_loops[-1])

    sch.decompose_reduction(target, k_loop)
| 176 | |
| 177 | |
| 178 | def main(): # pylint: disable=too-many-locals |
no test coverage detected
searching dependent graphs…