MCPcopy Index your code
hub / github.com/apache/tvm / verify_torch_dlpack

Function verify_torch_dlpack

tests/python/contrib/test_dlpack.py:25–61  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

23
24
def verify_torch_dlpack():
    """Check DLPack exchange between NumPy, TVM and PyTorch tensors.

    A NumPy array is first wrapped as a TVM tensor and round-tripped
    through ``tvm.runtime.from_dlpack``. If PyTorch can be imported, the
    function additionally:

    * converts a PyTorch tensor to TVM via DLPack and back again, and
    * compiles a TE matrix multiplication, wraps it with
      ``to_pytorch_func`` and compares its output on PyTorch tensors with
      ``torch.Tensor.mm``.

    If PyTorch cannot be imported, the PyTorch checks are silently skipped.
    """
    a = np.random.randn(1337)
    tvm_a = tvm.runtime.tensor(a)
    np.testing.assert_equal(tvm.runtime.from_dlpack(tvm_a).numpy(), a)

    # Guard only the imports: an ImportError raised by the checks further
    # down must surface as a failure instead of being swallowed.
    try:
        import torch
        import torch.utils.dlpack
    except ImportError:
        return

    # PyTorch -> TVM -> TVM -> PyTorch round trip.
    x = torch.rand(56, 56)
    tvm_x = tvm.runtime.from_dlpack(torch.utils.dlpack.to_dlpack(x))
    np.testing.assert_equal(x.numpy(), tvm_x.numpy())
    y = tvm.runtime.from_dlpack(tvm_x)
    np.testing.assert_equal(y.numpy(), tvm_x.numpy())
    np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y).numpy(), tvm_x.numpy())

    # Reference result computed by PyTorch.
    n = tvm.runtime.convert(137)
    xx = torch.rand(137, 137)
    yy = torch.rand(137, 137)
    zz = xx.mm(yy)

    # The same matrix multiplication expressed in TE.
    XX = te.placeholder((n, n), name="X")
    YY = te.placeholder((n, n), name="Y")

    k = te.reduce_axis((0, n), name="k")
    ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
    # No need to specify target_host if it's llvm
    # Otherwise you will need to specify the target and target_host
    f = tvm.compile(te.create_prim_func([XX, YY, ZZ]))

    # Run the compiled function on PyTorch tensors; zz2 is the output buffer
    # passed as the last argument.
    f_pytorch = to_pytorch_func(f)
    zz2 = torch.empty(137, 137)
    f_pytorch(xx, yy, zz2)
    tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-4, atol=1e-4)
62
63
64def test_torch_dlpack():

Callers 1

test_torch_dlpackFunction · 0.85

Calls 7

to_pytorch_funcFunction · 0.90
numpyMethod · 0.80
placeholderMethod · 0.80
convertMethod · 0.45
emptyMethod · 0.45
sumMethod · 0.45
compileMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…