Tests sort function
()
| 26 | |
| 27 | |
def test_sort():
    """Check ``tvm.contrib.sort.argsort_nms`` on a small hand-computed case.

    ``data`` has shape ``(batch, num_elems, num_cols)`` and is argsorted along
    ``axis=1`` in descending order (``is_ascend=False``). ``sort_num`` has shape
    ``(batch, num_cols)``; ``sort_num[b][c]`` is passed to the packed function
    alongside the data. The expected indices below sort only the leading
    ``sort_num[b][c]`` entries of each column and leave the remaining
    positions at their original index.
    """
    batch, num_elems, num_cols = 2, 5, 3
    data = te.placeholder((batch, num_elems, num_cols), name="data")
    sort_num = te.placeholder((batch, num_cols), name="sort_num", dtype="int32")
    axis = 1
    is_ascend = False
    out = te.extern(
        data.shape,
        [data, sort_num],
        lambda ins, outs: tvm.tirx.call_packed(
            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
        ),
        dtype="int32",
        name="sort_tensor",
    )
    input_data = [
        [[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
    ]
    sort_num_input = [[1, 2, 3], [4, 5, 5]]
    sorted_index = [
        [[0, 1, 1], [1, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
        [[3, 4, 4], [2, 3, 3], [1, 2, 2], [0, 1, 1], [4, 0, 0]],
    ]

    dev = tvm.cpu(0)
    target = "llvm"
    f = tvm.compile(te.create_prim_func([data, sort_num, out]), target=target)
    a = tvm.runtime.tensor(np.array(input_data).astype(data.dtype), dev)
    b = tvm.runtime.tensor(np.array(sort_num_input).astype(sort_num.dtype), dev)
    c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev)
    f(a, b, c)
    # The output holds integer indices, so compare exactly rather than with a
    # floating-point tolerance.
    np.testing.assert_array_equal(c.numpy(), np.array(sorted_index).astype(out.dtype))
| 64 | |
| 65 | |
| 66 | def test_sort_np(): |
no test coverage detected
searching dependent graphs…