Main function demonstrating cross-function calls.
main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor
| 45 | |
| 46 | @I.pyfunc |
| 47 | def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: |
| 48 | """Main function demonstrating cross-function calls.""" |
| 49 | n = x.shape[0] |
| 50 | |
| 51 | # Call TIR function |
| 52 | lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32")) |
| 53 | |
| 54 | # Apply ReLU |
| 55 | lv1 = F.relu(lv) |
| 56 | |
| 57 | # Call packed function (will be added dynamically) |
| 58 | lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32")) |
| 59 | |
| 60 | # Call Python function |
| 61 | lv3 = self.my_identity_func(lv2) |
| 62 | |
| 63 | return lv3 |
| 64 | |
| 65 | @T.prim_func(s_tir=True) |
| 66 | def matmul( |
No test coverage detected.