()
| 50 | |
| 51 | |
def test_rpc():
    """Build a Relax model for WebGPU, run it over a WASM RPC session, and check it.

    The module returned by ``get_model()`` is compiled with ``relax.build`` for
    the ``webgpu`` target using an LLVM ``wasm32-unknown-unknown-wasm`` host,
    exported to a WASM binary through ``tvmjs.create_tvmjs_wasm``, and passed to
    the remote end via ``rpc.connect(..., key="wasm")`` as the constructor
    argument of an ``rpc.WasmSession``. The remote ``main`` function is then
    invoked on two random float32 vectors of length ``n``, and its output is
    compared exactly against their elementwise sum.

    Returns immediately (without failing) when this TVM build has RPC support
    disabled.

    NOTE(review): this presumably requires an RPC proxy already listening at
    ``proxy_host:proxy_port`` with a WebGPU-capable WASM client registered under
    key "wasm". Confirm against the module-level setup.
    """
    if not tvm.runtime.enabled("rpc"):
        return
    n = 1024
    dtype = "float32"
    temp = utils.tempdir()
    wasm_path = temp.relpath("relax.wasm")
    target = tvm.target.Target(
        "webgpu", host={"kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm"}
    )

    # NOTE(review): the assertion in check() assumes get_model()'s "main" adds
    # two length-n float32 tensors. Confirm against get_model's definition.
    mod = get_model()
    ex = relax.build(mod, target)
    ex.export_library(wasm_path, fcompile=tvmjs.create_tvmjs_wasm)
    # Read the exported binary inside a context manager so the file handle is
    # closed deterministically rather than left for the garbage collector.
    with open(wasm_path, "rb") as wasm_file:
        wasm_binary = wasm_file.read()

    # The WASM binary travels as a session-constructor argument, so the remote
    # side builds an "rpc.WasmSession" from it.
    remote = rpc.connect(
        proxy_host,
        proxy_port,
        key="wasm",
        session_constructor_args=["rpc.WasmSession", wasm_binary],
    )

    def check(remote):
        """Run ``main`` on the remote WebGPU device and assert ``c == a + b``."""
        dev = remote.webgpu(0)
        # Create a Relax VM over the remote session's system library on that device.
        vm = relax.VirtualMachine(remote.system_lib(), device=dev)
        adata = np.random.uniform(size=n).astype(dtype)
        bdata = np.random.uniform(size=n).astype(dtype)
        a = tvm.runtime.tensor(adata, dev)
        b = tvm.runtime.tensor(bdata, dev)
        # Stage the inputs, invoke "main" statefully, then fetch its outputs.
        vm.set_input("main", a, b)
        vm.invoke_stateful("main")
        c = vm.get_outputs("main")
        np.testing.assert_equal(c.numpy(), a.numpy() + b.numpy())

    check(remote)
| 89 | |
| 90 | |
# Run the RPC test only when executed as a script. Without this guard,
# importing the module (e.g. during pytest collection) would run the whole
# RPC round-trip at import time and then again as a collected test.
if __name__ == "__main__":
    test_rpc()
no test coverage detected
searching dependent graphs…