()
| 1664 | |
| 1665 | |
| 1666 | def test_rms_norm(): |
| 1667 | @I.ir_module(s_tir=True) |
| 1668 | class Module: |
| 1669 | @T.prim_func(s_tir=True) |
| 1670 | def rms_norm( |
| 1671 | A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), |
| 1672 | B: T.Buffer((T.int64(4096),), "float16"), |
| 1673 | rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), |
| 1674 | ): |
| 1675 | T.func_attr({"tirx.noalias": True}) |
| 1676 | # with T.sblock("root"): |
| 1677 | Ared_temp = T.sblock_alloc_buffer((T.int64(1), T.int64(1))) |
| 1678 | for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): |
| 1679 | with T.sblock("Ared_temp"): |
| 1680 | v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) |
| 1681 | T.reads(A[v_bsz, v_i, v_k]) |
| 1682 | T.writes(Ared_temp[v_bsz, v_i]) |
| 1683 | with T.init(): |
| 1684 | Ared_temp[v_bsz, v_i] = T.float32(0) |
| 1685 | Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast( |
| 1686 | "float32", A[v_bsz, v_i, v_k] |
| 1687 | ) * T.Cast("float32", A[v_bsz, v_i, v_k]) |
| 1688 | for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): |
| 1689 | with T.sblock("rms_norm"): |
| 1690 | v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) |
| 1691 | T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) |
| 1692 | T.writes(rms_norm[v_bsz, v_i, v_k]) |
| 1693 | rms_norm[v_bsz, v_i, v_k] = T.Cast( |
| 1694 | "float16", |
| 1695 | T.Cast("float32", B[v_k]) |
| 1696 | * ( |
| 1697 | T.Cast("float32", A[v_bsz, v_i, v_k]) |
| 1698 | / T.sqrt( |
| 1699 | Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) |
| 1700 | + T.float32(9.9999999999999995e-07) |
| 1701 | ) |
| 1702 | ), |
| 1703 | ) |
| 1704 | |
| 1705 | @R.function |
| 1706 | def main( |
| 1707 | input: R.Tensor((1, 1, 4096), dtype="float16"), |
| 1708 | weight: R.Tensor((4096,), dtype="float16"), |
| 1709 | ) -> R.Tensor((1, 1, 4096), dtype="float16"): |
| 1710 | cls = Module |
| 1711 | with R.dataflow(): |
| 1712 | lv = R.call_tir( |
| 1713 | cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16") |
| 1714 | ) |
| 1715 | R.output(lv) |
| 1716 | return lv |
| 1717 | |
| 1718 | data_shape = (1, 1, 4096) |
| 1719 | dtype = "float16" |
| 1720 | mod = partition_for_cutlass(Module) |
| 1721 | |
| 1722 | # TODO(@tvm-team): This is temporary patch.Currently, the remaining packed function triggers error since it is not scheduled. |
| 1723 | # This is because RunCodegen does not support PrimFunc well yet. |
nothing calls this directly
no test coverage detected
searching dependent graphs…