(kernel_spec)
| 1803 | |
| 1804 | |
| 1805 | def encode_name(kernel_spec): |
| 1806 | effective_sm, sm_name = get_effective_sm_and_name(kernel_spec) |
| 1807 | # Is it a kernel for the interleaved NC/32HW32 INT8 layout? |
| 1808 | il_tag = '_il' if kernel_spec.interleaved else '' |
| 1809 | # Is it using the quantization scaling factor as an approximation of the max in softmax? |
| 1810 | scale_max_tag = '_scale_max' if kernel_spec.has_scale_max else '' |
| 1811 | # Deal with multi-CTA kernels for which the sequence length is seq_len per CTA * # of CTAs. |
| 1812 | seqlen = kernel_spec.seq_len * kernel_spec.ctas_per_head |
| 1813 | # The qkv layout. |
| 1814 | qkv_layout_tag = '' |
| 1815 | if kernel_spec.input_layout == InputLayout.PACKED_QKV: |
| 1816 | qkv_layout_tag = '_qkv' |
| 1817 | elif kernel_spec.input_layout == InputLayout.Q_PAGED_KV: |
| 1818 | qkv_layout_tag = '_q_paged_kv' |
| 1819 | elif kernel_spec.input_layout == InputLayout.SEPARATE_Q_K_V: |
| 1820 | qkv_layout_tag = '_q_k_v' |
| 1821 | else: |
| 1822 | qkv_layout_tag = '_q_kv' |
| 1823 | # for SM90 kernels, let's also differentiate ldgsts and tma kernels |
| 1824 | feature_tags = '' |
| 1825 | if (effective_sm == 90): |
| 1826 | # let's think about where to insert tma/ldgsts in the string before MR. [Timmy] |
| 1827 | if (kernel_spec.ldgsts_q == True): |
| 1828 | tma_or_ldgsts = '_ldgsts' |
| 1829 | else: |
| 1830 | tma_or_ldgsts = '_tma' |
| 1831 | if kernel_spec.warp_specialization: |
| 1832 | warp_specialization_tag = '_ws' |
| 1833 | # hopper warp-specialized kernels has specialized optimization for cases without alibi. |
| 1834 | if kernel_spec.alibi: |
| 1835 | feature_tags += '_alibi' |
| 1836 | if kernel_spec.return_softmax_stats: |
| 1837 | feature_tags += '_softmax' |
| 1838 | else: |
| 1839 | warp_specialization_tag = '' |
| 1840 | else: |
| 1841 | tma_or_ldgsts = '' |
| 1842 | warp_specialization_tag = '' |
| 1843 | |
| 1844 | if kernel_spec.enable_attn_logit_softcapping: |
| 1845 | feature_tags += '_softcapping' |
| 1846 | if kernel_spec.enable_skip_softmax: |
| 1847 | feature_tags += '_skipSoftmax' |
| 1848 | if kernel_spec.sage_block_sizes: |
| 1849 | feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}" |
| 1850 | if kernel_spec.output_dtype: |
| 1851 | feature_tags += f"_output_{kernel_spec.output_dtype}" |
| 1852 | if kernel_spec.ctas_per_head > 1: |
| 1853 | fmt = 'fmha_v{version}{il_tag}_{dtype}_' + str( |
| 1854 | seqlen |
| 1855 | ) + '_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}' |
| 1856 | elif kernel_spec.flash_attention: |
| 1857 | fmt = 'fmha_v{version}{il_tag}_flash_attention_{dtype}_{loop_step}_{kv_loop_step}_S{qkv_layout_tag}_{head_size}{head_size_v_str}{attrib}{feature_tags}{scale_max_tag}{tma_or_ldgsts}{warp_specialization_tag}_sm{sm}' |
| 1858 | elif kernel_spec.cross_mha: |
| 1859 | fmt = 'fmha_mhca_{dtype}_{seq_len}_{head_size}{scale_max_tag}{tma_or_ldgsts}_sm{sm}' |
| 1860 | else: |
| 1861 | fmt = 'fmha_v{version}{il_tag}_{dtype}_{seq_len}_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}' |
| 1862 | head_size_v_str = "" if kernel_spec.head_size_v == 0 else f"x{kernel_spec.head_size_v}" |
no test coverage detected