Tests sort function using numpy
()
| 64 | |
| 65 | |
def test_sort_np():
    """Compare the ``tvm.contrib.sort.argsort_nms`` packed function with ``np.argsort``.

    A 6-D placeholder is argsorted along ``axis`` through a ``te.extern`` call
    compiled for the ``llvm`` target and run on CPU device 0. The ``sort_num``
    input has the data shape with ``axis`` removed and is filled with
    ``dshape[axis]``. The int32 output is checked against ``np.argsort`` of the
    same input array.
    """
    dshape = (1, 2, 3, 4, 5, 6)
    axis = 4
    # Derived rather than hard-coded so it stays consistent if dshape/axis change.
    reduced_shape = dshape[:axis] + dshape[axis + 1 :]
    is_ascend = True
    data = te.placeholder(dshape, name="data")
    sort_num = te.placeholder(reduced_shape, name="sort_num", dtype="int32")
    out = te.extern(
        data.shape,
        [data, sort_num],
        lambda ins, outs: tvm.tirx.call_packed(
            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
        ),
        dtype="int32",
        name="sort_tensor",
    )

    dev = tvm.cpu(0)
    target = "llvm"
    f = tvm.compile(te.create_prim_func([data, sort_num, out]), target=target)

    # Cast to the placeholder dtype *before* computing the reference: sorting
    # the uncast float64 samples could order values differently from the cast
    # array that is actually handed to the compiled function.
    np_data = np.random.uniform(size=dshape).astype(data.dtype)
    np_out = np.argsort(np_data, axis=axis)
    sort_num_input = np.full(reduced_shape, dshape[axis], dtype=sort_num.dtype)
    a = tvm.runtime.tensor(np_data, dev)
    b = tvm.runtime.tensor(sort_num_input, dev)
    c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev)
    f(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), np_out, rtol=1e-5)
| 96 | |
| 97 | |
| 98 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…