Tests random_fill function
()
| 105 | |
@pytest.mark.gpu
def test_random_fill():
    """Tests random_fill function.

    For each dtype in the list below, fills a 512x512 tensor with
    ``tvm.contrib.random.random_fill`` on every enabled local target and,
    when RPC and LLVM are enabled, through a local RPC server. Each filled
    tensor is checked to contain no zeros and to survive simple arithmetic.
    """

    def check_filled(value):
        """Assert that every element of ``value`` was written with a nonzero value."""
        np_values = value.numpy()
        assert np.count_nonzero(np_values) == 512 * 512

        # make sure arithmetic doesn't overflow too
        # NOTE(review): ``.any()`` only fails when *every* element overflows;
        # ``.all()`` would match the stated intent. Confirm the value range
        # random_fill produces (notably for float16) before tightening.
        assert np.isfinite(np_values * np_values + np_values).any()

    def test_local(dev, dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        value = tvm.runtime.empty((512, 512), dtype, dev)
        random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
        random_fill(value)
        check_filled(value)

    def test_rpc(dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
            return

        def check_remote(server):
            remote = rpc.connect(server.host, server.port)
            value = tvm.runtime.empty((512, 512), dtype, remote.cpu())
            random_fill = remote.get_function("tvm.contrib.random.random_fill")
            random_fill(value)
            check_filled(value)

        server = rpc.Server("127.0.0.1")
        try:
            check_remote(server)
        finally:
            # A fresh server is started for every dtype in the loop below;
            # shut each one down explicitly instead of waiting for GC.
            server.terminate()

    # Packed sub-byte dtypes (e.g. int4) are intentionally unsupported by
    # random_fill since #19714 and raise an error instead.
    for dtype in [
        "bool",
        "int8",
        "uint8",
        "int16",
        "uint16",
        "int32",
        "uint32",
        "int64",
        "uint64",
        "float16",
        "float32",
        "float64",
    ]:
        for _, dev in tvm.testing.enabled_targets():
            test_local(dev, dtype)
        test_rpc(dtype)
| 164 |
no test coverage detected
searching dependent graphs…