MCPcopy Index your code
hub / github.com/apache/tvm / test_random_fill

Function test_random_fill

tests/python/contrib/test_random.py:107–163  ·  view source on GitHub ↗

Tests random_fill function

()

Source from the content-addressed store, hash-verified

105
@pytest.mark.gpu
def test_random_fill():
    """Tests random_fill function.

    For each dtype listed below, fills a (512, 512) tensor with the
    ``tvm.contrib.random.random_fill`` global function:
    - on every target returned by ``tvm.testing.enabled_targets()``;
    - on ``remote.cpu()`` through a local RPC server.

    It then checks the filled values. Each sub-check prints a message and
    returns early when the extern function is not registered. The RPC
    sub-check also returns early when RPC or LLVM is not enabled.
    """

    def check_random_values(np_values):
        # Every element of the filled tensor must be nonzero.
        assert np.count_nonzero(np_values) == 512 * 512

        # make sure arithmetic doesn't overflow too
        # NOTE(review): `.any()` only requires a single finite result. This is
        # presumably deliberate because random bit patterns can produce
        # inf/nan for float dtypes. Confirm before tightening it to `.all()`.
        assert np.isfinite(np_values * np_values + np_values).any()

    def test_local(dev, dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        value = tvm.runtime.empty((512, 512), dtype, dev)
        random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
        random_fill(value)
        check_random_values(value.numpy())

    def test_rpc(dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
            return

        def check_remote(server):
            remote = rpc.connect(server.host, server.port)
            value = tvm.runtime.empty((512, 512), dtype, remote.cpu())
            random_fill = remote.get_function("tvm.contrib.random.random_fill")
            random_fill(value)
            check_random_values(value.numpy())

        server = rpc.Server("127.0.0.1")
        try:
            check_remote(server)
        finally:
            # A new server is started for every dtype. Shut each one down
            # explicitly instead of leaving it to garbage collection.
            server.terminate()

    # Packed sub-byte dtypes (e.g. int4) are intentionally unsupported by
    # random_fill since #19714 and raise an error instead.
    for dtype in [
        "bool",
        "int8",
        "uint8",
        "int16",
        "uint16",
        "int32",
        "uint32",  # previously a duplicated "int32", which left uint32 untested
        "int64",
        "uint64",
        "float16",
        "float32",
        "float64",
    ]:
        for _, dev in tvm.testing.enabled_targets():
            test_local(dev, dtype)
        test_rpc(dtype)
164

Callers 1

test_random.py · File · 0.85

Calls 2

test_local · Function · 0.85
test_rpc · Function · 0.70

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…