github.com/NVIDIA/TensorRT-LLM

Function get_kernel_code

cpp/kernels/fmha_v2/setup.py, lines 2014–2244
(kspec, kname, lname)


```python
def get_kernel_code(kspec, kname, lname):
    min_cuda_version = 0  # no restriction

    # The architecture that determines the instruction.
    effective_sm, sm_name = get_effective_sm_and_name(kspec)

    if effective_sm >= 80:
        min_cuda_version = 11000  # 11000 == CUDA 11.0 in CUDA_VERSION-style encoding

    launcher_name = lname
    # kname carries a '__placeholder__' token; substituting it yields one
    # kernel name per attention-mask variant.
    causal_kernel_name = kname.replace('__placeholder__', '_causal')
    custom_mask_kernel_name = kname.replace('__placeholder__', '_custom_mask')
    sliding_or_chunked_causal_kernel_name = kname.replace(
        '__placeholder__', '_sliding_or_chunked_causal')
    kernel_name = kname.replace('__placeholder__', '')

    # FIXME: use separate parameters when generating cubins for trtllm.
    if not kspec.cross_mha:
        params_type = 'bert::Fused_multihead_attention_params_v{}'.format(
            kspec.version)
    else:
        params_type = 'bert::Fused_multihead_attention_params_mhca'

    # Instruction traits: "<Sm name>_<dtype traits>", using a Hopper-specific
    # dtype table when effective_sm == 90.
    if (effective_sm < 90):
        instruction_traits = sm_name.capitalize() + '_' + dtype2traits[
            kspec.dtype]
    elif (effective_sm == 90):
        instruction_traits = sm_name.capitalize() + '_' + hopper_dtype2traits[
            kspec.dtype]
        # for hopper, we differentiate instruction_traits_o and instruction_traits_p
        instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits(
            instruction_traits, kspec)
        #print(instruction_traits_p, instruction_traits_o)

    # Kernel variant: warp layout (or flash attention) before Hopper,
    # 4xN / 4x1 on Hopper.
    if (effective_sm < 90):
        if kspec.flash_attention:
            kernel_variant = 'flash_attention'
        else:
            kernel_variant = '1xN' if kspec.warps_m == 1 else '2x2'
    elif (effective_sm == 90):
        if kspec.warps_n > 1:
            # for hopper we slice the problem along the M dim.
            kernel_variant = '4xN' + '_hopper'
        else:
            kernel_variant = '4x1' + '_hopper'

    # Kernel traits: an architecture-specific prefix plus a layout/version suffix.
    # Only effective_sm < 90 and effective_sm == 90 are handled in this excerpt.
    if (effective_sm < 90):
        kernel_traits = 'Kernel_traits_'
    elif (effective_sm == 90):
        kernel_traits = 'FMHA_kernel_traits_hopper_'

    if kspec.interleaved:
        kernel_traits += 'interleaved_v2'
    elif kspec.cross_mha:
        kernel_traits += 'fmhca'
    else:
        kernel_traits += 'v{}'.format(kspec.version)

    # ... (function continues; setup.py lines 2071-2244 are not shown here)
```
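Everything in this excerpt is string selection driven by a few `kspec` fields and the effective SM, so the logic can be exercised in isolation. The following is a minimal standalone sketch of that branching, not the repository's code: `KernelSpec`, `mask_kernel_names`, and `select_names` are hypothetical names, `effective_sm` is passed in directly instead of calling `get_effective_sm_and_name`, and `instruction_traits` is omitted because the `dtype2traits` tables are not part of the excerpt.

```python
from dataclasses import dataclass


@dataclass
class KernelSpec:
    """Hypothetical stand-in for kspec: only the fields the excerpt reads."""
    version: int
    cross_mha: bool = False
    interleaved: bool = False
    flash_attention: bool = False
    warps_m: int = 1
    warps_n: int = 1


def mask_kernel_names(kname):
    """Expand the '__placeholder__' token into one kernel name per mask variant."""
    suffixes = {
        'default': '',  # hypothetical key for the plain kernel_name
        'causal': '_causal',
        'custom_mask': '_custom_mask',
        'sliding_or_chunked_causal': '_sliding_or_chunked_causal',
    }
    return {mask: kname.replace('__placeholder__', s) for mask, s in suffixes.items()}


def select_names(kspec, effective_sm):
    """Mirror the excerpt's choice of CUDA version, params type, variant and traits."""
    if effective_sm > 90:
        # The excerpt has no branch for effective_sm > 90; this sketch rejects it.
        raise ValueError('excerpt only covers effective_sm <= 90')

    min_cuda_version = 11000 if effective_sm >= 80 else 0

    if not kspec.cross_mha:
        params_type = 'bert::Fused_multihead_attention_params_v{}'.format(kspec.version)
    else:
        params_type = 'bert::Fused_multihead_attention_params_mhca'

    if effective_sm < 90:
        if kspec.flash_attention:
            kernel_variant = 'flash_attention'
        else:
            kernel_variant = '1xN' if kspec.warps_m == 1 else '2x2'
        kernel_traits = 'Kernel_traits_'
    else:  # effective_sm == 90 (Hopper)
        kernel_variant = ('4xN' if kspec.warps_n > 1 else '4x1') + '_hopper'
        kernel_traits = 'FMHA_kernel_traits_hopper_'

    # The suffix is chosen the same way on every architecture.
    if kspec.interleaved:
        kernel_traits += 'interleaved_v2'
    elif kspec.cross_mha:
        kernel_traits += 'fmhca'
    else:
        kernel_traits += 'v{}'.format(kspec.version)

    return min_cuda_version, params_type, kernel_variant, kernel_traits


if __name__ == '__main__':
    # Hypothetical kernel name; real names are built by the caller.
    print(mask_kernel_names('my_kernel' + '__placeholder__' + '_sm80'))
    print(select_names(KernelSpec(version=2, flash_attention=True), 80))
    # -> (11000, 'bert::Fused_multihead_attention_params_v2', 'flash_attention', 'Kernel_traits_v2')
```

The excerpt only assigns `kernel_variant` and `kernel_traits` for `effective_sm <= 90`; the sketch makes that assumption explicit by raising `ValueError` for larger values rather than failing later on an unbound name.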

Callers (1)

generate_files (function)

Calls (7)

selected_mask_types (function)
enable_mutex (function)
get_reg_count (function)
enable_tma_store (function)
replace (method)

Tested by

No test coverage detected.