Function get_kernel_traits_code(specs_names)

cpp/kernels/fmha_v2/setup.py, lines 2728–3077 (github.com/NVIDIA/TensorRT-LLM)

def get_kernel_traits_code(specs_names):
    print_kernel_specs = []

    for kspec, fname, lname, kname in specs_names:
        # Pick the MMA instruction traits for the target architecture.
        effective_sm, sm_name = get_effective_sm_and_name(kspec)
        if (effective_sm < 90):
            instruction_traits = sm_name.capitalize() + '_' + dtype2traits[
                kspec.dtype]
        elif (effective_sm == 90):
            instruction_traits = sm_name.capitalize(
            ) + '_' + hopper_dtype2traits[kspec.dtype]
            instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits(
                instruction_traits, kspec)

        # Base kernel-traits template name: pre-Hopper vs. Hopper.
        if (effective_sm < 90):
            kernel_traits = 'Kernel_traits_'
        elif (effective_sm == 90):
            kernel_traits = 'FMHA_kernel_traits_hopper_'

        # Variant suffix: interleaved, cross-attention (fmhca), or version.
        if kspec.interleaved:
            kernel_traits += 'interleaved_v2'
        elif kspec.cross_mha:
            kernel_traits += 'fmhca'
        else:
            kernel_traits += 'v{}'.format(kspec.version)

        # needed by warpspec kernels.
        fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
        kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
            else f"fmha::ws::Kernel_traits<fmha::{instruction_traits},"

        # Pack the boolean kernel-spec options into a bitmask.
        flags = 0
        if kspec.ldgsts_q:
            flags |= 1
        if kspec.ldgsts_k:
            flags |= 2
        if kspec.ldgsts_v:
            flags |= 4
        if kspec.share_smem_k_v:
            flags |= 8
        if kspec.has_scale_max:
            flags |= 16
        if not kspec.head_interleaved:
            flags |= 32
        if kspec.limit_qk_fragments:
            flags |= 128
        if kspec.limit_v_fragments:
            flags |= 256
        if kspec.has_noloop:
            # NOTE do not use flags 512 = 0x200 as it is reserved; do not add to flags because it
            # will be selectively added to no-loop kernel trait upon generating .cu templates
            pass
        if kspec.enable_attn_logit_softcapping:
            flags |= 2048
        if kspec.tiled:
            flags |= 4096
        if kspec.is_mtp:
            flags |= 8192

        # ... (remainder of the function not shown in this excerpt)
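The name-building steps in the listing can be sketched in isolation. This is a minimal, hedged reconstruction, not the project's code. `get_effective_sm_and_name` and the `dtype2traits`/`hopper_dtype2traits` tables are not shown in this excerpt, so the sketch takes the SM number, SM name, and lookup tables as parameters. The `Spec` dataclass, the helper function names, and the table contents in the usage lines are hypothetical. The excerpt has no branch for an effective SM above 90, so the sketch raises explicitly in that case.

```python
from dataclasses import dataclass


@dataclass
class Spec:
    """Hypothetical stand-in for the kspec fields the naming logic reads."""
    dtype: str
    version: int = 2
    interleaved: bool = False
    cross_mha: bool = False


def instruction_traits_name(sm_name, effective_sm, dtype,
                            dtype2traits, hopper_dtype2traits):
    # Pre-Hopper (< 90) and Hopper (== 90) use separate dtype -> traits tables.
    if effective_sm < 90:
        table = dtype2traits
    elif effective_sm == 90:
        table = hopper_dtype2traits
    else:
        raise ValueError(f"unsupported effective SM: {effective_sm}")
    # "sm80".capitalize() -> "Sm80", then joined to the table entry with "_".
    return sm_name.capitalize() + "_" + table[dtype]


def kernel_traits_name(spec, effective_sm):
    # Template prefix depends on the architecture family.
    if effective_sm < 90:
        name = "Kernel_traits_"
    elif effective_sm == 90:
        name = "FMHA_kernel_traits_hopper_"
    else:
        raise ValueError(f"unsupported effective SM: {effective_sm}")
    # Variant suffix: interleaved takes precedence, then cross-attention,
    # otherwise the kernel version number.
    if spec.interleaved:
        return name + "interleaved_v2"
    if spec.cross_mha:
        return name + "fmhca"
    return name + f"v{spec.version}"


def ws_traits_header(spec, instruction_traits):
    # FP8 (e4m3) warp-specialized kernels use a dedicated traits template.
    if spec.dtype in ("e4m3", "e4m3_fp32"):
        return "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<"
    return f"fmha::ws::Kernel_traits<fmha::{instruction_traits},"


# Hypothetical lookup table; the real dtype2traits contents are not shown here.
demo_table = {"fp16": "fp16_placeholder"}
print(instruction_traits_name("sm80", 80, "fp16", demo_table, {}))  # Sm80_fp16_placeholder
print(kernel_traits_name(Spec("fp16", version=2), 80))              # Kernel_traits_v2
print(kernel_traits_name(Spec("fp16", cross_mha=True), 90))         # FMHA_kernel_traits_hopper_fmhca
```

Passing the tables in as arguments keeps the sketch testable without importing the rest of `setup.py`.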

Callers (1): generate_files

Calls (6 reported, 4 listed): enable_mutex, selected_mask_types, replace (method), append (method)

Tested by: no test coverage detected
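No tests are reported for this function. As a hedged illustration, the flag bit layout in the listing can be recorded as a table and checked with an encode/decode round trip. This is a standalone sketch, not project code. `FlagSpec`, `FLAG_LAYOUT`, `encode_flags`, and `decode_flags` are hypothetical names. Bit 256 is assumed here to belong to a `limit_v_fragments` field, the counterpart of `limit_qk_fragments` (bit 128). Bit 32 is set when `head_interleaved` is False, and bit 512 is never set, following the NOTE in the listing.

```python
from dataclasses import dataclass, asdict


@dataclass
class FlagSpec:
    """Hypothetical stand-in for the kspec fields that feed the flags bitmask."""
    ldgsts_q: bool = False
    ldgsts_k: bool = False
    ldgsts_v: bool = False
    share_smem_k_v: bool = False
    has_scale_max: bool = False
    head_interleaved: bool = True
    limit_qk_fragments: bool = False
    limit_v_fragments: bool = False  # assumed field name for bit 256
    enable_attn_logit_softcapping: bool = False
    tiled: bool = False
    is_mtp: bool = False


# (field name, bit value, bit is set when the field is False?)
FLAG_LAYOUT = [
    ("ldgsts_q", 1, False),
    ("ldgsts_k", 2, False),
    ("ldgsts_v", 4, False),
    ("share_smem_k_v", 8, False),
    ("has_scale_max", 16, False),
    ("head_interleaved", 32, True),
    ("limit_qk_fragments", 128, False),
    ("limit_v_fragments", 256, False),
    ("enable_attn_logit_softcapping", 2048, False),
    ("tiled", 4096, False),
    ("is_mtp", 8192, False),
]
RESERVED_NOLOOP_BIT = 512  # reserved for no-loop kernels; never set here


def encode_flags(spec) -> int:
    flags = 0
    for name, bit, inverted in FLAG_LAYOUT:
        # Normal fields set the bit when True; inverted ones when False.
        if bool(getattr(spec, name)) != inverted:
            flags |= bit
    return flags


def decode_flags(flags: int) -> dict:
    if flags & RESERVED_NOLOOP_BIT:
        raise ValueError("bit 512 is reserved for no-loop kernels")
    return {name: bool(flags & bit) != inverted
            for name, bit, inverted in FLAG_LAYOUT}


spec = FlagSpec(ldgsts_q=True, head_interleaved=False, tiled=True)
print(encode_flags(spec))  # 4129 (1 + 32 + 4096)
```

A table-driven layout keeps encoding and decoding in sync from a single source of truth, and makes gaps such as 64, 512, and 1024 easy to spot.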