MCPcopy
hub / github.com/apache/tvm / check

Function check

tests/python/codegen/test_target_codegen_llvm.py:565–659  ·  view source on GitHub ↗
(start, end, dstart, dend, dtype, floor_div=False)

Source from the content-addressed store, hash-verified

563 """Check that the semantics of div and mod is correct"""
564
565 def check(start, end, dstart, dend, dtype, floor_div=False):
566 a_size = end - start + 1
567 b_size = dend - dstart + 1
568
569 div_fn = tvm.tirx.floordiv if floor_div else tvm.tirx.truncdiv
570 mod_fn = tvm.tirx.floormod if floor_div else tvm.tirx.truncmod
571
572 # Build clipping helpers — capture TIR const values from env
573 _start = tvm.tirx.const(start, dtype)
574 _end = tvm.tirx.const(end, dtype)
575 _dstart = tvm.tirx.const(dstart, dtype)
576 _dend = tvm.tirx.const(dend, dtype)
577
578 if start == end:
579 clipa = lambda x: _start
580 else:
581 clipa = lambda x: T.min(_end, T.max(_start, x))
582
583 if dstart == dend:
584 clipb = lambda x: _dstart
585 else:
586 clipb = lambda x: T.min(_dend, T.max(_dstart, x))
587
588 @I.ir_module(s_tir=True)
589 class Module:
590 @T.prim_func(s_tir=True)
591 def main(
592 A: T.Buffer((a_size,), dtype),
593 B: T.Buffer((b_size,), dtype),
594 D: T.Buffer((a_size, b_size), dtype),
595 M: T.Buffer((a_size, b_size), dtype),
596 ):
597 T.func_attr({"tirx.noalias": True})
598 for i, j in T.grid(a_size, b_size):
599 with T.sblock("D"):
600 v_i, v_j = T.axis.remap("SS", [i, j])
601 T.reads(A[v_i], B[v_j])
602 T.writes(D[v_i, v_j])
603 D[v_i, v_j] = div_fn(clipa(A[v_i]), clipb(B[v_j]))
604 with T.sblock("M"):
605 v_i, v_j = T.axis.remap("SS", [i, j])
606 T.reads(A[v_i], B[v_j])
607 T.writes(M[v_i, v_j])
608 M[v_i, v_j] = mod_fn(clipa(A[v_i]), clipb(B[v_j]))
609
610 f = tvm.compile(Module, target="llvm")
611
612 # Fill input arrays with values
613 A_arr = tvm.runtime.empty((a_size,), dtype)
614 B_arr = tvm.runtime.empty((b_size,), dtype)
615 A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype))
616 B_np = np.arange(dstart, dend + 1, dtype=dtype)
617 # If the range of the divisor contains 0, replace it with 1 to avoid division by zero
618 if dend >= 0 and dstart <= 0:
619 B_np[-dstart] = 1
620 B_arr.copyfrom(B_np)
621 D_arr = tvm.runtime.empty((a_size, b_size), dtype)
622 M_arr = tvm.runtime.empty((a_size, b_size), dtype)

Callers 1

test_llvm_div · Function · 0.70

Calls 9

_show_info · Function · 0.85
min · Method · 0.80
max · Method · 0.80
copyfrom · Method · 0.80
numpy · Method · 0.80
f · Function · 0.50
compile · Method · 0.45
empty · Method · 0.45
arange · Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…