()
| 23 | |
| 24 | |
def verify_torch_dlpack():
    """Check DLPack interchange between NumPy, TVM and (if installed) PyTorch.

    Always round-trips a NumPy array through ``tvm.runtime.tensor`` and
    ``tvm.runtime.from_dlpack``. When ``torch`` can be imported, additionally
    checks torch -> TVM -> torch conversion via DLPack, and that a TVM-compiled
    matrix multiply wrapped with ``to_pytorch_func`` can be called directly on
    torch tensors and agrees with ``torch.Tensor.mm``.

    Returns ``None``; failures surface as assertion errors from the
    ``np.testing`` / ``tvm.testing`` helpers.
    """
    a = np.random.randn(1337)
    tvm_a = tvm.runtime.tensor(a)
    np.testing.assert_equal(tvm.runtime.from_dlpack(tvm_a).numpy(), a)

    # Guard only the optional imports. Wrapping the whole body would let an
    # ImportError raised later (e.g. while compiling or wrapping the function)
    # silently turn a real failure into a pass.
    try:
        import torch
        import torch.utils.dlpack
    except ImportError:
        return

    # torch -> TVM -> TVM -> torch, comparing contents at every hop.
    x = torch.rand(56, 56)
    tvm_x = tvm.runtime.from_dlpack(torch.utils.dlpack.to_dlpack(x))
    np.testing.assert_equal(x.numpy(), tvm_x.numpy())
    y = tvm.runtime.from_dlpack(tvm_x)
    np.testing.assert_equal(y.numpy(), tvm_x.numpy())
    np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y).numpy(), tvm_x.numpy())

    n = tvm.runtime.convert(137)
    xx = torch.rand(137, 137)
    yy = torch.rand(137, 137)
    # Reference result computed by torch itself.
    zz = xx.mm(yy)
    XX = te.placeholder((n, n), name="X")
    YY = te.placeholder((n, n), name="Y")

    k = te.reduce_axis((0, n), name="k")
    ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
    # No need to specify target_host if it's llvm
    # Otherwise you will need to specify the target and target_host
    f = tvm.compile(te.create_prim_func([XX, YY, ZZ]))

    f_pytorch = to_pytorch_func(f)
    # Output buffer that the wrapped TVM function fills in.
    zz2 = torch.empty(137, 137)
    f_pytorch(xx, yy, zz2)
    tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-4, atol=1e-4)
| 62 | |
| 63 | |
| 64 | def test_torch_dlpack(): |
no test coverage detected
searching dependent graphs…