MCP · Copy · Index your code
hub / github.com/apache/tvm / test_rpc

Function test_rpc

web/tests/python/relax_rpc_test.py:52–88  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

50
51
def test_rpc():
    """End-to-end check of a Relax module executed over RPC in a WebGPU/wasm session.

    Builds the module returned by ``get_model()`` for a ``webgpu`` target with an
    LLVM ``wasm32-unknown-unknown-wasm`` host, exports it to a wasm binary with
    ``tvmjs.create_tvmjs_wasm``, and connects to the RPC proxy at
    ``proxy_host``/``proxy_port`` (names defined outside this function) with key
    ``"wasm"``, passing the binary to an ``rpc.WasmSession``. It then runs
    ``main`` on two random float32 vectors of length 1024 on the remote WebGPU
    device and asserts the output equals their elementwise sum.

    Returns immediately, without checking anything, when the runtime reports
    that RPC support is not enabled.
    """
    if not tvm.runtime.enabled("rpc"):
        return
    n = 1024
    dtype = "float32"
    temp = utils.tempdir()
    wasm_path = temp.relpath("relax.wasm")
    target = tvm.target.Target(
        "webgpu", host={"kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm"}
    )

    mod = get_model()
    ex = relax.build(mod, target)
    ex.export_library(wasm_path, fcompile=tvmjs.create_tvmjs_wasm)
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the anonymous file object is garbage-collected.
    with open(wasm_path, "rb") as wasm_file:
        wasm_binary = wasm_file.read()

    remote = rpc.connect(
        proxy_host,
        proxy_port,
        key="wasm",
        session_constructor_args=["rpc.WasmSession", wasm_binary],
    )

    def check(remote):
        """Run ``main`` on the remote WebGPU device and verify its output."""
        dev = remote.webgpu(0)
        # Instantiate the VM from the remote session's system library
        # (presumably the module compiled into the wasm binary above).
        vm = relax.VirtualMachine(remote.system_lib(), device=dev)
        adata = np.random.uniform(size=n).astype(dtype)
        bdata = np.random.uniform(size=n).astype(dtype)
        a = tvm.runtime.tensor(adata, dev)
        b = tvm.runtime.tensor(bdata, dev)
        vm.set_input("main", a, b)
        vm.invoke_stateful("main")
        c = vm.get_outputs("main")
        # a and b were uploaded from adata/bdata, so use the host arrays as the
        # reference instead of copying both inputs back from the device.
        np.testing.assert_equal(c.numpy(), adata + bdata)

    check(remote)


# Runs the test whenever this module is executed (or imported).
test_rpc()

Callers 1

relax_rpc_test.py · File · 0.70

Calls 8

get_model · Function · 0.85
relpath · Method · 0.80
check · Function · 0.70
read · Method · 0.65
enabled · Method · 0.45
build · Method · 0.45
export_library · Method · 0.45
connect · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…