
# Minimal mixed TIR / Python IRModule.
# - `add_tir` is a TensorIR kernel computing the element-wise sum C = A + B
#   over 4-element float32 buffers.
# - `forward` is a Python function (I.pyfunc) that converts its PyTorch inputs
#   to TVM tensors, runs `add_tir` through `call_tir`, and converts the result
#   back to PyTorch.
# Only `#` comments are used inside the class: the class body and the
# prim_func body are parsed by TVMScript.
@I.ir_module
class MyFirstModule(BasePyModule):
    @T.prim_func(s_tir=True)
    def add_tir(
        A: T.Buffer((4,), "float32"),
        B: T.Buffer((4,), "float32"),
        C: T.Buffer((4,), "float32"),
    ):
        # Element-wise add; the result is written into C.
        # NOTE(review): C appears to be the output supplied by call_tir,
        # shaped by the `out_sinfo` passed in `forward`. Confirm against
        # BasePyModule.call_tir.
        for i in range(4):
            C[i] = A[i] + B[i]

    @I.pyfunc
    def forward(self, x, y):
        """Add two PyTorch tensors with the ``add_tir`` kernel.

        Args:
            x: PyTorch tensor matching ``add_tir``'s input buffers
                (4 elements, float32).
            y: PyTorch tensor with the same shape and dtype as ``x``.

        Returns:
            A PyTorch tensor holding ``x + y``.
        """
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        # out_sinfo must match the shape/dtype of add_tir's output buffer C.
        result = self.call_tir(
            self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32")
        )
        return self._convert_tvm_to_pytorch(result)


# TIR functions are JIT-compiled at instantiation
mod = MyFirstModule(device=tvm.cpu(0))
# no outgoing calls
# no test coverage detected
# searching dependent graphs…