(specs_names)
| 2726 | |
| 2727 | |
| 2728 | def get_kernel_traits_code(specs_names): |
| 2729 | print_kernel_specs = [] |
| 2730 | |
| 2731 | for kspec, fname, lname, kname in specs_names: |
| 2732 | effective_sm, sm_name = get_effective_sm_and_name(kspec) |
| 2733 | if (effective_sm < 90): |
| 2734 | instruction_traits = sm_name.capitalize() + '_' + dtype2traits[ |
| 2735 | kspec.dtype] |
| 2736 | elif (effective_sm == 90): |
| 2737 | instruction_traits = sm_name.capitalize( |
| 2738 | ) + '_' + hopper_dtype2traits[kspec.dtype] |
| 2739 | instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits( |
| 2740 | instruction_traits, kspec) |
| 2741 | |
| 2742 | if (effective_sm < 90): |
| 2743 | kernel_traits = 'Kernel_traits_' |
| 2744 | elif (effective_sm == 90): |
| 2745 | kernel_traits = 'FMHA_kernel_traits_hopper_' |
| 2746 | |
| 2747 | if kspec.interleaved: |
| 2748 | kernel_traits += 'interleaved_v2' |
| 2749 | elif kspec.cross_mha: |
| 2750 | kernel_traits += 'fmhca' |
| 2751 | else: |
| 2752 | kernel_traits += 'v{}'.format(kspec.version) |
| 2753 | |
| 2754 | # needed by warpspec kernels. |
| 2755 | fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"] |
| 2756 | kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \ |
| 2757 | else f"fmha::ws::Kernel_traits<fmha::{instruction_traits}," |
| 2758 | |
| 2759 | flags = 0 |
| 2760 | if kspec.ldgsts_q: |
| 2761 | flags |= 1 |
| 2762 | if kspec.ldgsts_k: |
| 2763 | flags |= 2 |
| 2764 | if kspec.ldgsts_v: |
| 2765 | flags |= 4 |
| 2766 | if kspec.share_smem_k_v: |
| 2767 | flags |= 8 |
| 2768 | if kspec.has_scale_max: |
| 2769 | flags |= 16 |
| 2770 | if not kspec.head_interleaved: |
| 2771 | flags |= 32 |
| 2772 | if kspec.limit_qk_fragments: |
| 2773 | flags |= 128 |
| 2774 | if kspec.limit_qk_fragments: |
| 2775 | flags |= 256 |
| 2776 | if kspec.has_noloop: |
| 2777 | # NOTE do not use flags 512 = 0x200 as it is reserved; do not add to flags because it |
| 2778 | # will be selectively added to no-loop kernel trait upon generating .cu templates |
| 2779 | pass |
| 2780 | if kspec.enable_attn_logit_softcapping: |
| 2781 | flags |= 2048 |
| 2782 | if kspec.tiled: |
| 2783 | flags |= 4096 |
| 2784 | if kspec.is_mtp: |
| 2785 | flags |= 8192 |
no test coverage detected