()
| 231 | @pytest.mark.skipif(not env.has_rpc(), reason="need rpc") |
| 232 | @pytest.mark.skipif(not env.has_llvm(), reason="need llvm") |
| 233 | def test_rpc_remote_module(): |
| 234 | # graph |
| 235 | n = tvm.runtime.convert(102) |
| 236 | A = te.placeholder((n,), name="A") |
| 237 | B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") |
| 238 | mod = tvm.ir.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd")) |
| 239 | |
| 240 | server0 = rpc.Server(key="x0") |
| 241 | server1 = rpc.Server(key="x1") |
| 242 | |
| 243 | client = rpc.connect( |
| 244 | "127.0.0.1", |
| 245 | server0.port, |
| 246 | key="x0", |
| 247 | session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False], |
| 248 | ) |
| 249 | |
def check_remote(remote):
    """Run the compiled module through ``remote`` and then through a downloaded copy.

    First half: build ``mod`` for llvm, export it as ``dev_lib.so``, upload it,
    load it on the remote side, time it on ``remote.cpu(0)`` and assert that
    the output tensor equals the input tensor plus one.

    Second half: export the same build as ``dev_lib.tar``, upload it, write the
    bytes returned by ``remote.download_linked_module`` to a local ``.so``,
    load that file locally and repeat the plus-one check on ``tvm.cpu(0)``.
    """
    workdir = utils.tempdir()
    remote_dev = remote.cpu(0)
    built = tvm.compile(mod, "llvm")

    # Ship the shared library to the remote and load it there.
    so_path = workdir.relpath("dev_lib.so")
    built.export_library(so_path)
    remote.upload(so_path)
    remote_mod = remote.load_module("dev_lib.so")

    src = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), remote_dev)
    dst = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), remote_dev)

    # The timed call is also what fills ``dst``.
    timer = remote_mod.time_evaluator(remote_mod.entry_name, remote.cpu(0), number=10)
    cost = timer(src, dst).mean
    print(f"{cost:g} secs/op")
    np.testing.assert_equal(dst.numpy(), src.numpy() + 1)

    # Download the file from the remote
    tar_path = workdir.relpath("dev_lib.tar")
    built.export_library(tar_path)
    remote.upload(tar_path)
    downloaded_so = workdir.relpath("dev_lib.download.so")
    with open(downloaded_so, "wb") as sink:
        linked_bytes = remote.download_linked_module("dev_lib.tar")
        sink.write(linked_bytes)

    # Run the downloaded library on the local CPU with fresh tensors.
    local_mod = tvm.runtime.load_module(downloaded_so)
    src = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), tvm.cpu(0))
    dst = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), tvm.cpu(0))
    local_mod(src, dst)
    np.testing.assert_equal(dst.numpy(), src.numpy() + 1)
| 277 | |
| 278 | def check_minrpc(): |
| 279 | if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None: |
| 280 | return |
| 281 | # export to minrpc |
| 282 | temp = utils.tempdir() |
| 283 | # system lib prefix will trigger system lib build |
| 284 | f = tvm.compile(mod.with_attr("system_lib_prefix", ""), "llvm") |
| 285 | path_minrpc = temp.relpath("dev_lib.minrpc") |
| 286 | f.export_library(path_minrpc, fcompile=rpc.with_minrpc(cc.create_executable)) |
| 287 | |
| 288 | with pytest.raises(RuntimeError): |
| 289 | rpc.PopenSession("filenotexist") |
| 290 |
nothing calls this directly
no test coverage detected
searching dependent graphs…