MCPcopy Index your code
hub / github.com/apache/tvm / test_rms_norm

Function test_rms_norm

tests/python/relax/test_codegen_cutlass.py:1666–1738  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1664
1665
1666def test_rms_norm():
1667 @I.ir_module(s_tir=True)
1668 class Module:
1669 @T.prim_func(s_tir=True)
1670 def rms_norm(
1671 A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
1672 B: T.Buffer((T.int64(4096),), "float16"),
1673 rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
1674 ):
1675 T.func_attr({"tirx.noalias": True})
1676 # with T.sblock("root"):
1677 Ared_temp = T.sblock_alloc_buffer((T.int64(1), T.int64(1)))
1678 for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
1679 with T.sblock("Ared_temp"):
1680 v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
1681 T.reads(A[v_bsz, v_i, v_k])
1682 T.writes(Ared_temp[v_bsz, v_i])
1683 with T.init():
1684 Ared_temp[v_bsz, v_i] = T.float32(0)
1685 Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
1686 "float32", A[v_bsz, v_i, v_k]
1687 ) * T.Cast("float32", A[v_bsz, v_i, v_k])
1688 for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
1689 with T.sblock("rms_norm"):
1690 v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
1691 T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
1692 T.writes(rms_norm[v_bsz, v_i, v_k])
1693 rms_norm[v_bsz, v_i, v_k] = T.Cast(
1694 "float16",
1695 T.Cast("float32", B[v_k])
1696 * (
1697 T.Cast("float32", A[v_bsz, v_i, v_k])
1698 / T.sqrt(
1699 Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
1700 + T.float32(9.9999999999999995e-07)
1701 )
1702 ),
1703 )
1704
1705 @R.function
1706 def main(
1707 input: R.Tensor((1, 1, 4096), dtype="float16"),
1708 weight: R.Tensor((4096,), dtype="float16"),
1709 ) -> R.Tensor((1, 1, 4096), dtype="float16"):
1710 cls = Module
1711 with R.dataflow():
1712 lv = R.call_tir(
1713 cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")
1714 )
1715 R.output(lv)
1716 return lv
1717
1718 data_shape = (1, 1, 4096)
1719 dtype = "float16"
1720 mod = partition_for_cutlass(Module)
1721
1722 # TODO(@tvm-team): This is a temporary patch. Currently, the remaining packed function triggers an error since it is not scheduled.
1723 # This is because RunCodegen does not support PrimFunc well yet.

Callers

nothing calls this directly

Calls 3

partition_for_cutlassFunction · 0.90
build_and_runFunction · 0.70
astypeMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…