(
x: R.Tensor((128, 128), "float32"),
W1: R.Tensor((128, 128), "float32"),
W2: R.Tensor((128, 128), "float32"),
)
class MLP:  # pylint: disable=too-few-public-methods
    # Script module with a single function, `main`, that chains
    # matmul -> GELU -> matmul over 128x128 float32 tensors.
    # NOTE(review): `R` is presumably the TVMScript Relax namespace imported at
    # the top of the file, and the module decorator is expected on the line
    # directly above this class -- confirm both against the full file.
    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        W1: R.Tensor((128, 128), "float32"),
        W2: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        # Tag the function with the attribute "global_symbol" = "main".
        R.func_attr({"global_symbol": "main"})
        with R.dataflow():
            # First projection: x @ W1.
            lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
            # Element-wise GELU on the projected activations.
            lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
            # Second projection: gelu(x @ W1) @ W2.
            lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
            # Only lv2 is declared as an output of the dataflow block;
            # lv0 and lv1 are not passed to R.output.
            R.output(lv2)
        return lv2
| 458 | |
| 459 | @tvm.script.ir_module |
| 460 | class ShardedMLP: # pylint: disable=too-few-public-methods |
no test coverage detected