Function test_sort

tests/python/contrib/test_sort.py:28–63 in github.com/apache/tvm

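Tests the sort function (the tvm.contrib.sort.argsort_nms packed function), called on a 3-D input with a per-column count of entries to sort.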

Signature: test_sort() (takes no arguments)


```python
# Imports come from the top of tests/python/contrib/test_sort.py (outside the 28–63 excerpt).
import numpy as np

import tvm
import tvm.testing
from tvm import te


def test_sort():
    """Tests sort function"""
    # data is (batch n, length l, columns m); sort_num holds one count per (batch, column).
    n = 2
    l = 5
    m = 3
    data = te.placeholder((n, l, m), name="data")
    sort_num = te.placeholder((n, m), name="sort_num", dtype="int32")
    axis = 1
    is_ascend = False
    # Extern op whose body calls the packed function "tvm.contrib.sort.argsort_nms";
    # it writes int32 indices into an output with the same shape as data.
    out = te.extern(
        data.shape,
        [data, sort_num],
        lambda ins, outs: tvm.tirx.call_packed(
            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
        ),
        dtype="int32",
        name="sort_tensor",
    )
    input_data = [
        [[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
    ]
    sort_num_input = [[1, 2, 3], [4, 5, 5]]
    # Expected indices: in each column, only the first sort_num[b][j] entries are
    # sorted (descending); positions past that count keep their original index.
    sorted_index = [
        [[0, 1, 1], [1, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
        [[3, 4, 4], [2, 3, 3], [1, 2, 2], [0, 1, 1], [4, 0, 0]],
    ]

    # Build for CPU, run on device tensors, and compare with the expected indices.
    dev = tvm.cpu(0)
    target = "llvm"
    f = tvm.compile(te.create_prim_func([data, sort_num, out]), target=target)
    a = tvm.runtime.tensor(np.array(input_data).astype(data.dtype), dev)
    b = tvm.runtime.tensor(np.array(sort_num_input).astype(sort_num.dtype), dev)
    c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev)
    f(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), np.array(sorted_index).astype(out.dtype), rtol=1e-5)
```
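
The expected sorted_index values encode the behavior this test relies on: along axis 1, only the first sort_num[b][j] entries of each column are arg-sorted (descending here, because is_ascend=False), and every position past that count keeps its own index. The sketch below models that behavior in plain NumPy so the expected array can be checked by hand. It is a hypothetical reference inferred from the test data, not TVM's implementation of tvm.contrib.sort.argsort_nms; the name argsort_nms_reference and the stable tie-breaking are assumptions.

```python
import numpy as np


def argsort_nms_reference(data, valid_count, axis=1, is_ascend=False):
    """Hypothetical NumPy model of the indices test_sort expects.

    Assumption: for every 1-D slice along `axis`, only the first
    valid_count entries are arg-sorted; the rest keep their own index.
    valid_count has the shape of `data` with `axis` removed.
    """
    data = np.asarray(data)
    valid_count = np.asarray(valid_count)
    # Move the sort axis last so each slice is data_t[idx] for idx over the other axes.
    data_t = np.moveaxis(data, axis, -1)
    out_t = np.empty(data_t.shape, dtype=np.int32)
    length = data_t.shape[-1]
    for idx in np.ndindex(*data_t.shape[:-1]):
        # Clamp the count to [0, length].
        k = min(max(int(valid_count[idx]), 0), length)
        head = data_t[idx][:k]
        # Stable argsort (assumed tie-breaking); negate the keys for descending order.
        order = np.argsort(head if is_ascend else -head, kind="stable")
        row_out = out_t[idx]  # 1-D view into out_t
        row_out[:k] = order
        row_out[k:] = np.arange(k, length)
    # Restore the original axis order.
    return np.moveaxis(out_t, -1, axis)
```

Called with the test's input_data and sort_num_input (axis=1, is_ascend=False), this sketch returns the same array as sorted_index, which makes it a quick way to reason about new test cases before running them through TVM.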

Callers (1): test_sort.py (file)

Calls (8): placeholder, call_packed, numpy, f, cpu, compile, astype, zeros
