Tests normal function
()
| 80 | |
| 81 | |
def test_normal():
    """Check that the extern normal-distribution generator matches N(3, 4).

    Builds a PrimFunc whose single output tensor ``A`` of shape ``(m, n)`` is
    produced by ``random.normal(3, 4, size=(m, n))``. It runs that function on
    the device for the target and asserts that the sample mean is close to 3
    and the sample standard deviation is close to 4.

    When the target is not enabled, or the global function
    ``tvm.contrib.random.normal`` is not registered, ``verify`` prints a message
    and returns early. It does not raise ``pytest.skip``, so the test still
    reports as passed in that case.
    """
    m = 10240
    n = 10240
    # NOTE(review): `random` is presumably `tvm.contrib.random`, imported outside
    # this view. The assertions below treat its first two arguments as
    # (mean, standard deviation).
    A = random.normal(3, 4, size=(m, n))

    def verify(target="llvm"):
        """Compile and run the generator for ``target``, then check its statistics."""
        if not tvm.testing.device_enabled(target):
            print(f"skip because {target} is not enabled...")
            return
        if not tvm.get_global_func("tvm.contrib.random.normal", True):
            print("skip because extern function is not available")
            return
        # Derive the device from `target` rather than hard-coding tvm.cpu(0),
        # so the output buffer lives on the same device the kernel targets.
        # For the default "llvm" target this is still CPU device 0.
        dev = tvm.device(target, 0)
        f = tvm.compile(te.create_prim_func([A]), target=target)
        a = tvm.runtime.tensor(np.zeros((m, n), dtype=A.dtype), dev)
        f(a)
        na = a.numpy()
        # With m * n (about 1e8) samples, the sampling error of the mean and of
        # the standard deviation is far smaller than these tolerances.
        assert abs(np.mean(na) - 3) < 1e-1
        assert abs(np.std(na) - 4) < 1e-2

    verify()
| 104 | |
| 105 | |
| 106 | @pytest.mark.gpu |
no test coverage detected
searching dependent graphs…