hub / github.com/NVIDIA/TensorRT-LLM / encode_name

Function encode_name

cpp/kernels/fmha_v2/setup.py:1805–1884
(kernel_spec)

Source from the content-addressed store, hash-verified

def encode_name(kernel_spec):
    effective_sm, sm_name = get_effective_sm_and_name(kernel_spec)
    # Is it a kernel for the interleaved NC/32HW32 INT8 layout?
    il_tag = '_il' if kernel_spec.interleaved else ''
    # Is it using the quantization scaling factor as an approximation of the max in softmax?
    scale_max_tag = '_scale_max' if kernel_spec.has_scale_max else ''
    # Deal with multi-CTA kernels for which the sequence length is seq_len per CTA * # of CTAs.
    seqlen = kernel_spec.seq_len * kernel_spec.ctas_per_head
    # The qkv layout.
    qkv_layout_tag = ''
    if kernel_spec.input_layout == InputLayout.PACKED_QKV:
        qkv_layout_tag = '_qkv'
    elif kernel_spec.input_layout == InputLayout.Q_PAGED_KV:
        qkv_layout_tag = '_q_paged_kv'
    elif kernel_spec.input_layout == InputLayout.SEPARATE_Q_K_V:
        qkv_layout_tag = '_q_k_v'
    else:
        qkv_layout_tag = '_q_kv'
    # for SM90 kernels, let's also differentiate ldgsts and tma kernels
    feature_tags = ''
    if (effective_sm == 90):
        # let's think about where to insert tma/ldgsts in the string before MR. [Timmy]
        if (kernel_spec.ldgsts_q == True):
            tma_or_ldgsts = '_ldgsts'
        else:
            tma_or_ldgsts = '_tma'
        if kernel_spec.warp_specialization:
            warp_specialization_tag = '_ws'
            # hopper warp-specialized kernels has specialized optimization for cases without alibi.
            if kernel_spec.alibi:
                feature_tags += '_alibi'
            if kernel_spec.return_softmax_stats:
                feature_tags += '_softmax'
        else:
            warp_specialization_tag = ''
    else:
        tma_or_ldgsts = ''
        warp_specialization_tag = ''

    if kernel_spec.enable_attn_logit_softcapping:
        feature_tags += '_softcapping'
    if kernel_spec.enable_skip_softmax:
        feature_tags += '_skipSoftmax'
    if kernel_spec.sage_block_sizes:
        feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
    if kernel_spec.output_dtype:
        feature_tags += f"_output_{kernel_spec.output_dtype}"
    if kernel_spec.ctas_per_head > 1:
        fmt = 'fmha_v{version}{il_tag}_{dtype}_' + str(
            seqlen
        ) + '_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
    elif kernel_spec.flash_attention:
        fmt = 'fmha_v{version}{il_tag}_flash_attention_{dtype}_{loop_step}_{kv_loop_step}_S{qkv_layout_tag}_{head_size}{head_size_v_str}{attrib}{feature_tags}{scale_max_tag}{tma_or_ldgsts}{warp_specialization_tag}_sm{sm}'
    elif kernel_spec.cross_mha:
        fmt = 'fmha_mhca_{dtype}_{seq_len}_{head_size}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
    else:
        fmt = 'fmha_v{version}{il_tag}_{dtype}_{seq_len}_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
    head_size_v_str = "" if kernel_spec.head_size_v == 0 else f"x{kernel_spec.head_size_v}"
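
The listing stops before the template in fmt is filled in, so the final name is not visible here. As a rough illustration of how the tags combine in the flash-attention branch, the sketch below rebuilds a reduced version of that template and fills it with str.format. Several things are assumptions, not taken from the source: the fill step itself, the toy spec built with SimpleNamespace, the made-up field values, the OTHER enum member, and treating spec.sm as the effective SM (the real code derives it through get_effective_sm_and_name). The il_tag, attrib and scale_max_tag placeholders, and the remaining feature tags, are left out for brevity.

```python
from enum import Enum, auto
from types import SimpleNamespace


class InputLayout(Enum):
    # Stand-in for the enum used in the listing; OTHER is hypothetical and
    # only exists to exercise the fallback branch.
    PACKED_QKV = auto()
    Q_PAGED_KV = auto()
    SEPARATE_Q_K_V = auto()
    OTHER = auto()


# Same mapping as the if/elif chain in the listing; anything else gets '_q_kv'.
LAYOUT_TAGS = {
    InputLayout.PACKED_QKV: '_qkv',
    InputLayout.Q_PAGED_KV: '_q_paged_kv',
    InputLayout.SEPARATE_Q_K_V: '_q_k_v',
}


def make_spec(**overrides):
    """Build a toy spec. Field names mirror the listing; the values are made up."""
    fields = dict(version=2, dtype='fp16', loop_step=64, kv_loop_step=128,
                  input_layout=InputLayout.Q_PAGED_KV, head_size=128,
                  head_size_v=0, sm=90, ldgsts_q=False,
                  warp_specialization=True, alibi=True,
                  return_softmax_stats=False,
                  enable_attn_logit_softcapping=False)
    fields.update(overrides)
    return SimpleNamespace(**fields)


def sketch_flash_name(spec):
    """Assemble a flash-attention kernel name from a toy spec (illustrative only)."""
    qkv_layout_tag = LAYOUT_TAGS.get(spec.input_layout, '_q_kv')
    tma_or_ldgsts = ''
    warp_specialization_tag = ''
    feature_tags = ''
    # Assumption: spec.sm stands in for the effective SM.
    if spec.sm == 90:
        tma_or_ldgsts = '_ldgsts' if spec.ldgsts_q else '_tma'
        if spec.warp_specialization:
            warp_specialization_tag = '_ws'
            if spec.alibi:
                feature_tags += '_alibi'
            if spec.return_softmax_stats:
                feature_tags += '_softmax'
    if spec.enable_attn_logit_softcapping:
        feature_tags += '_softcapping'
    head_size_v_str = '' if spec.head_size_v == 0 else f'x{spec.head_size_v}'
    # Reduced form of the flash-attention template from the listing.
    fmt = ('fmha_v{version}_flash_attention_{dtype}_{loop_step}_{kv_loop_step}'
           '_S{qkv_layout_tag}_{head_size}{head_size_v_str}{feature_tags}'
           '{tma_or_ldgsts}{warp_specialization_tag}_sm{sm}')
    # Assumption: the template is filled with str.format; unused keys are ignored.
    return fmt.format(qkv_layout_tag=qkv_layout_tag,
                      head_size_v_str=head_size_v_str,
                      feature_tags=feature_tags,
                      tma_or_ldgsts=tma_or_ldgsts,
                      warp_specialization_tag=warp_specialization_tag,
                      **vars(spec))


print(sketch_flash_name(make_spec()))
# fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90
print(sketch_flash_name(make_spec(sm=80, input_layout=InputLayout.PACKED_QKV,
                                  head_size_v=64)))
# fmha_v2_flash_attention_fp16_64_128_S_qkv_128x64_sm80
```

In the second call the alibi and warp-specialization settings leave no trace in the name, because in the listing those tags are only emitted when the effective SM is 90.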

Callers 1

enumerate_kernels · Function · 0.85

Calls 2

replace · Method · 0.80

Tested by

no test coverage detected