(kspec, kname, lname)
| 2012 | |
| 2013 | |
| 2014 | def get_kernel_code(kspec, kname, lname): |
| 2015 | min_cuda_version = 0 # no restriction |
| 2016 | |
| 2017 | # The architecture that determines the instruction. |
| 2018 | effective_sm, sm_name = get_effective_sm_and_name(kspec) |
| 2019 | |
| 2020 | if effective_sm >= 80: |
| 2021 | min_cuda_version = 11000 |
| 2022 | |
| 2023 | launcher_name = lname |
| 2024 | causal_kernel_name = kname.replace('__placeholder__', '_causal') |
| 2025 | custom_mask_kernel_name = kname.replace('__placeholder__', '_custom_mask') |
| 2026 | sliding_or_chunked_causal_kernel_name = kname.replace( |
| 2027 | '__placeholder__', '_sliding_or_chunked_causal') |
| 2028 | kernel_name = kname.replace('__placeholder__', '') |
| 2029 | |
| 2030 | # FIXME: use separate parameters when generating cubins for trtllm. |
| 2031 | if not kspec.cross_mha: |
| 2032 | params_type = 'bert::Fused_multihead_attention_params_v{}'.format( |
| 2033 | kspec.version) |
| 2034 | else: |
| 2035 | params_type = 'bert::Fused_multihead_attention_params_mhca' |
| 2036 | |
| 2037 | if (effective_sm < 90): |
| 2038 | instruction_traits = sm_name.capitalize() + '_' + dtype2traits[ |
| 2039 | kspec.dtype] |
| 2040 | elif (effective_sm == 90): |
| 2041 | instruction_traits = sm_name.capitalize() + '_' + hopper_dtype2traits[ |
| 2042 | kspec.dtype] |
| 2043 | # for hopper, we differentiate instruction_traits_o and instruction_traits_p |
| 2044 | instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits( |
| 2045 | instruction_traits, kspec) |
| 2046 | #print(instruction_traits_p, instruction_traits_o) |
| 2047 | |
| 2048 | if (effective_sm < 90): |
| 2049 | if kspec.flash_attention: |
| 2050 | kernel_variant = 'flash_attention' |
| 2051 | else: |
| 2052 | kernel_variant = '1xN' if kspec.warps_m == 1 else '2x2' |
| 2053 | elif (effective_sm == 90): |
| 2054 | if kspec.warps_n > 1: |
| 2055 | # for hopper we slice the problem along the M dim. |
| 2056 | kernel_variant = '4xN' + '_hopper' |
| 2057 | else: |
| 2058 | kernel_variant = '4x1' + '_hopper' |
| 2059 | |
| 2060 | if (effective_sm < 90): |
| 2061 | kernel_traits = 'Kernel_traits_' |
| 2062 | elif (effective_sm == 90): |
| 2063 | kernel_traits = 'FMHA_kernel_traits_hopper_' |
| 2064 | |
| 2065 | if kspec.interleaved: |
| 2066 | kernel_traits += 'interleaved_v2' |
| 2067 | elif kspec.cross_mha: |
| 2068 | kernel_traits += 'fmhca' |
| 2069 | else: |
| 2070 | kernel_traits += 'v{}'.format(kspec.version) |
| 2071 |
no test coverage detected