| 138 | |
# Relax function `entry_b`: calls `fused_relax_nn_attention_cutlass1` on
# (q, k, v) together with a scratch workspace, then adds a float16 constant 1.
#
# Parameters:
#   q, k, v: each annotated as R.Tensor((32, 8, 16, 8), dtype="float16").
# Returns:
#   `gv`, annotated as R.Tensor((32, 8, 16, 8), dtype="float16").
#
# NOTE(review): `Expected` is presumably the enclosing IRModule class whose
# header lies outside this chunk, and `fused_relax_nn_attention_cutlass1` a
# sibling function of that module -- confirm against the full file.
| 139 | @R.function |
| 140 | def entry_b( |
| 141 | q: R.Tensor((32, 8, 16, 8), dtype="float16"), |
| 142 | k: R.Tensor((32, 8, 16, 8), dtype="float16"), |
| 143 | v: R.Tensor((32, 8, 16, 8), dtype="float16"), |
| 144 | ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): |
# Alias used below to reference another function through `Expected`.
| 145 | cls = Expected |
| 146 | with R.dataflow(): |
# Scratch buffer of 65536 uint8 elements; it is passed as the fourth argument
# of the fused call below. The trailing R.prim_value(0) is presumably the
# runtime device index -- TODO confirm against R.builtin.alloc_tensor.
| 147 | workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor( |
| 148 | R.shape([65536]), R.dtype("uint8"), R.prim_value(0) |
| 149 | ) |
# The fused call's result and the `+ R.const(1, ...)` are written as a single
# expression bound to `gv`.
| 150 | gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1( |
| 151 | q, k, v, workspace_main |
| 152 | ) + R.const(1, dtype="float16") |
# `gv` is declared as the dataflow block's output and then returned.
| 153 | R.output(gv) |
| 154 | return gv |
| 155 | |
| 156 | |
| 157 | def test_single_attention(): |