MCPcopy Index your code
hub / github.com/apache/tvm / test_llvm_div

Function test_llvm_div

tests/python/codegen/test_target_codegen_llvm.py:562–698  ·  view source on GitHub ↗

Check that the semantics of div and mod are correct

()

Source from the content-addressed store, hash-verified

560
561@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
562def test_llvm_div():
563 """Check that the semantics of div and mod is correct"""
564
565 def check(start, end, dstart, dend, dtype, floor_div=False):
566 a_size = end - start + 1
567 b_size = dend - dstart + 1
568
569 div_fn = tvm.tirx.floordiv if floor_div else tvm.tirx.truncdiv
570 mod_fn = tvm.tirx.floormod if floor_div else tvm.tirx.truncmod
571
572 # Build clipping helpers — capture TIR const values from env
573 _start = tvm.tirx.const(start, dtype)
574 _end = tvm.tirx.const(end, dtype)
575 _dstart = tvm.tirx.const(dstart, dtype)
576 _dend = tvm.tirx.const(dend, dtype)
577
578 if start == end:
579 clipa = lambda x: _start
580 else:
581 clipa = lambda x: T.min(_end, T.max(_start, x))
582
583 if dstart == dend:
584 clipb = lambda x: _dstart
585 else:
586 clipb = lambda x: T.min(_dend, T.max(_dstart, x))
587
588 @I.ir_module(s_tir=True)
589 class Module:
590 @T.prim_func(s_tir=True)
591 def main(
592 A: T.Buffer((a_size,), dtype),
593 B: T.Buffer((b_size,), dtype),
594 D: T.Buffer((a_size, b_size), dtype),
595 M: T.Buffer((a_size, b_size), dtype),
596 ):
597 T.func_attr({"tirx.noalias": True})
598 for i, j in T.grid(a_size, b_size):
599 with T.sblock("D"):
600 v_i, v_j = T.axis.remap("SS", [i, j])
601 T.reads(A[v_i], B[v_j])
602 T.writes(D[v_i, v_j])
603 D[v_i, v_j] = div_fn(clipa(A[v_i]), clipb(B[v_j]))
604 with T.sblock("M"):
605 v_i, v_j = T.axis.remap("SS", [i, j])
606 T.reads(A[v_i], B[v_j])
607 T.writes(M[v_i, v_j])
608 M[v_i, v_j] = mod_fn(clipa(A[v_i]), clipb(B[v_j]))
609
610 f = tvm.compile(Module, target="llvm")
611
612 # Fill input arrays with values
613 A_arr = tvm.runtime.empty((a_size,), dtype)
614 B_arr = tvm.runtime.empty((b_size,), dtype)
615 A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype))
616 B_np = np.arange(dstart, dend + 1, dtype=dtype)
617 # If the range of the divisor contains 0, replace it with 1 to avoid division by zero
618 if dend >= 0 and dstart <= 0:
619 B_np[-dstart] = 1

Callers

nothing calls this directly

Calls 1

check (Function) · 0.70

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…