
Function test_sort_np

tests/python/contrib/test_sort.py:66–95

Tests sort function using numpy

()

Source

def test_sort_np():
    """Tests sort function using numpy"""
    dshape = (1, 2, 3, 4, 5, 6)
    axis = 4
    reduced_shape = (1, 2, 3, 4, 6)
    is_ascend = True
    data = te.placeholder(dshape, name="data")
    sort_num = te.placeholder(reduced_shape, name="sort_num", dtype="int32")
    out = te.extern(
        data.shape,
        [data, sort_num],
        lambda ins, outs: tvm.tirx.call_packed(
            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
        ),
        dtype="int32",
        name="sort_tensor",
    )

    dev = tvm.cpu(0)
    target = "llvm"
    f = tvm.compile(te.create_prim_func([data, sort_num, out]), target=target)

    np_data = np.random.uniform(size=dshape)
    np_out = np.argsort(np_data, axis=axis)
    sort_num_input = np.full(reduced_shape, dshape[axis])
    a = tvm.runtime.tensor(np.array(np_data).astype(data.dtype), dev)
    b = tvm.runtime.tensor(np.array(sort_num_input).astype(sort_num.dtype), dev)
    c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev)
    f(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), np_out, rtol=1e-5)


if __name__ == "__main__":
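The NumPy side of the test can be reproduced without TVM. The sketch below is illustrative only and is not part of the test file: it derives reduced_shape from dshape by dropping the sorted axis, builds the per-slice count tensor the same way the test does, and checks that the np.argsort reference really orders the data in ascending order. The fixed random seed and the take_along_axis check are additions made here for repeatability and are assumptions, not something the original test does.

```python
import numpy as np

# Same shapes as the test: sort along axis 4 of a 6-D tensor.
dshape = (1, 2, 3, 4, 5, 6)
axis = 4

# Dropping the sorted axis gives the shape of the per-slice count tensor.
reduced_shape = dshape[:axis] + dshape[axis + 1:]

# Each count equals the full axis length, as in the test's sort_num_input.
sort_num_input = np.full(reduced_shape, dshape[axis])

# Fixed seed is an addition for repeatability; the test uses np.random.uniform.
rng = np.random.default_rng(0)
np_data = rng.uniform(size=dshape)

# Reference result: indices that sort np_data along `axis`.
np_out = np.argsort(np_data, axis=axis)

# Gathering with those indices must give non-decreasing values along `axis`.
gathered = np.take_along_axis(np_data, np_out, axis=axis)
is_sorted = bool(np.all(np.diff(gathered, axis=axis) >= 0))
```

In the test itself, the data and count arrays are additionally cast with astype to the placeholder dtypes before being passed to the compiled function.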

Callers 1

test_sort.py · File · 0.85

Calls 10

placeholder · Method · 0.80
call_packed · Method · 0.80
uniform · Method · 0.80
numpy · Method · 0.80
f · Function · 0.50
cpu · Method · 0.45
compile · Method · 0.45
full · Method · 0.45
astype · Method · 0.45
zeros · Method · 0.45

Tested by

no test coverage detected
