Takes PyTorch tensors, calls TIR, returns PyTorch tensors.
(self, x, y)
| 96 | |
| 97 | @I.pyfunc |
| 98 | def forward(self, x, y): |
| 99 | """Takes PyTorch tensors, calls TIR, returns PyTorch tensors.""" |
| 100 | x_tvm = self._convert_pytorch_to_tvm(x) |
| 101 | y_tvm = self._convert_pytorch_to_tvm(y) |
| 102 | result = self.call_tir( |
| 103 | self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32") |
| 104 | ) |
| 105 | return self._convert_tvm_to_pytorch(result) |
| 106 | |
| 107 | # TIR functions are JIT-compiled at instantiation |
| 108 | mod = MyFirstModule(device=tvm.cpu(0)) |
nothing calls this directly
no test coverage detected