MCP · Copy · Index your code
hub / github.com/apache/tvm / test_rpc_remote_module

Function test_rpc_remote_module

tests/python/runtime/test_runtime_rpc.py:233–339  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

231@pytest.mark.skipif(not env.has_rpc(), reason="need rpc")
232@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
233def test_rpc_remote_module():
234 # graph
235 n = tvm.runtime.convert(102)
236 A = te.placeholder((n,), name="A")
237 B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
238 mod = tvm.ir.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd"))
239
240 server0 = rpc.Server(key="x0")
241 server1 = rpc.Server(key="x1")
242
243 client = rpc.connect(
244 "127.0.0.1",
245 server0.port,
246 key="x0",
247 session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False],
248 )
249
250 def check_remote(remote):
251 temp = utils.tempdir()
252 dev = remote.cpu(0)
253 f = tvm.compile(mod, "llvm")
254 path_dso = temp.relpath("dev_lib.so")
255 f.export_library(path_dso)
256 remote.upload(path_dso)
257 f1 = remote.load_module("dev_lib.so")
258 a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), dev)
259 b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), dev)
260 time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
261 cost = time_f(a, b).mean
262 print(f"{cost:g} secs/op")
263 np.testing.assert_equal(b.numpy(), a.numpy() + 1)
264
265 # Download the file from the remote
266 path_tar = temp.relpath("dev_lib.tar")
267 f.export_library(path_tar)
268 remote.upload(path_tar)
269 local_download_path = temp.relpath("dev_lib.download.so")
270 with open(local_download_path, "wb") as fo:
271 fo.write(remote.download_linked_module("dev_lib.tar"))
272 fupdated = tvm.runtime.load_module(local_download_path)
273 a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), tvm.cpu(0))
274 b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), tvm.cpu(0))
275 fupdated(a, b)
276 np.testing.assert_equal(b.numpy(), a.numpy() + 1)
277
278 def check_minrpc():
279 if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None:
280 return
281 # export to minrpc
282 temp = utils.tempdir()
283 # system lib prefix will trigger system lib build
284 f = tvm.compile(mod.with_attr("system_lib_prefix", ""), "llvm")
285 path_minrpc = temp.relpath("dev_lib.minrpc")
286 f.export_library(path_minrpc, fcompile=rpc.with_minrpc(cc.create_executable))
287
288 with pytest.raises(RuntimeError):
289 rpc.PopenSession("filenotexist")
290

Callers

nothing calls this directly — it is a pytest test function, collected and run by the test runner

Calls 8

check_minrpc · Function · 0.85
placeholder · Method · 0.80
from_expr · Method · 0.80
Server · Method · 0.80
check_remote · Function · 0.70
convert · Method · 0.45
with_attr · Method · 0.45
connect · Method · 0.45

Tested by

no test coverage detected (this function is itself a test)

Used in the wild real call sites across dependent graphs

searching dependent graphs…