(start, end, dstart, dend, dtype, floor_div=False)
| 563 | """Check that the semantics of div and mod is correct""" |
| 564 | |
| 565 | def check(start, end, dstart, dend, dtype, floor_div=False): |
| 566 | a_size = end - start + 1 |
| 567 | b_size = dend - dstart + 1 |
| 568 | |
| 569 | div_fn = tvm.tirx.floordiv if floor_div else tvm.tirx.truncdiv |
| 570 | mod_fn = tvm.tirx.floormod if floor_div else tvm.tirx.truncmod |
| 571 | |
| 572 | # Build clipping helpers — capture TIR const values from env |
| 573 | _start = tvm.tirx.const(start, dtype) |
| 574 | _end = tvm.tirx.const(end, dtype) |
| 575 | _dstart = tvm.tirx.const(dstart, dtype) |
| 576 | _dend = tvm.tirx.const(dend, dtype) |
| 577 | |
| 578 | if start == end: |
| 579 | clipa = lambda x: _start |
| 580 | else: |
| 581 | clipa = lambda x: T.min(_end, T.max(_start, x)) |
| 582 | |
| 583 | if dstart == dend: |
| 584 | clipb = lambda x: _dstart |
| 585 | else: |
| 586 | clipb = lambda x: T.min(_dend, T.max(_dstart, x)) |
| 587 | |
| 588 | @I.ir_module(s_tir=True) |
| 589 | class Module: |
| 590 | @T.prim_func(s_tir=True) |
| 591 | def main( |
| 592 | A: T.Buffer((a_size,), dtype), |
| 593 | B: T.Buffer((b_size,), dtype), |
| 594 | D: T.Buffer((a_size, b_size), dtype), |
| 595 | M: T.Buffer((a_size, b_size), dtype), |
| 596 | ): |
| 597 | T.func_attr({"tirx.noalias": True}) |
| 598 | for i, j in T.grid(a_size, b_size): |
| 599 | with T.sblock("D"): |
| 600 | v_i, v_j = T.axis.remap("SS", [i, j]) |
| 601 | T.reads(A[v_i], B[v_j]) |
| 602 | T.writes(D[v_i, v_j]) |
| 603 | D[v_i, v_j] = div_fn(clipa(A[v_i]), clipb(B[v_j])) |
| 604 | with T.sblock("M"): |
| 605 | v_i, v_j = T.axis.remap("SS", [i, j]) |
| 606 | T.reads(A[v_i], B[v_j]) |
| 607 | T.writes(M[v_i, v_j]) |
| 608 | M[v_i, v_j] = mod_fn(clipa(A[v_i]), clipb(B[v_j])) |
| 609 | |
| 610 | f = tvm.compile(Module, target="llvm") |
| 611 | |
| 612 | # Fill input arrays with values |
| 613 | A_arr = tvm.runtime.empty((a_size,), dtype) |
| 614 | B_arr = tvm.runtime.empty((b_size,), dtype) |
| 615 | A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype)) |
| 616 | B_np = np.arange(dstart, dend + 1, dtype=dtype) |
| 617 | # If the range of the divisor contains 0, replace it with 1 to avoid division by zero |
| 618 | if dend >= 0 and dstart <= 0: |
| 619 | B_np[-dstart] = 1 |
| 620 | B_arr.copyfrom(B_np) |
| 621 | D_arr = tvm.runtime.empty((a_size, b_size), dtype) |
| 622 | M_arr = tvm.runtime.empty((a_size, b_size), dtype) |
no test coverage detected
searching dependent graphs…