Upgrade cudnn_frontend to 1.24 and enable cuDNN SDPA for MHA/GQA #28849
Draft
tianleiwu wants to merge 2 commits into
Description
Upgrades the cudnn_frontend dependency from 1.12.0 to 1.24.0 and wires the updated cuDNN SDPA (scaled dot-product attention) kernels into the CUDA MultiHeadAttention and GroupQueryAttention operators. On SM≥90 (Hopper/Blackwell), cuDNN SDPA is auto-preferred for FP16/BF16 ahead of Flash Attention / cutlass FMHA, which significantly improves GQA prefill throughput.
Key Changes
- cmake/deps.txt: cudnn_frontend 1.12.0 → 1.24.0.
- cmake/external/cudnn_frontend.cmake: mark cudnn_frontend headers as SYSTEM includes so v1.24's unused static helper does not trip -Werror=unused-function.
- cudnn_fmha/cudnn_flash_attention.cc: migrate to the v1.24 API: set_generate_stats(false) (replaces deprecated set_is_inference), diagonal-band causal masking (set_diagonal_alignment + set_diagonal_band_right_bound / set_diagonal_band_left_bound), and synthesize the missing seq_len_q / seq_len_kv side that v1.24 now requires when a padding mask is used.
- multihead_attention.{cc,h}: enable cuDNN SDPA for FP16 and BF16; compute cuDNN eligibility before Flash and prefer it on SM≥90 unless the user pinned a kernel.
- group_query_attention.{cc,h}, group_query_attention_impl.cu, attention_data.h: add a cuDNN SDPA path (non-quantized FP16/BF16, no softcap/smooth-softmax/head-sink/local-window, BNSH KV cache), dispatched after XQA and before Flash/MEA/unfused.
- attention_kernel_options.{cc,h}: track explicit sdpa_kernel selection and honor an explicit ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 so it disables the SM≥90 auto path.
Kernel Priority
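The GQA dispatch order and cuDNN eligibility conditions listed under Key Changes can be sketched roughly as follows. This is an illustrative Python model only, not the operator's C++ code: `GqaRequest`, `cudnn_sdpa_eligible`, `pick_gqa_kernel`, and the `*_ok` flags are hypothetical names, and the real checks for XQA, Flash, and MEA are reduced to booleans.

```python
from dataclasses import dataclass


@dataclass
class GqaRequest:
    """Hypothetical summary of one GroupQueryAttention call."""
    dtype: str = "fp16"          # "fp16", "bf16", "fp32", ...
    quantized: bool = False      # quantized KV cache or inputs
    softcap: float = 0.0         # 0.0 means softcap is off
    smooth_softmax: bool = False
    head_sink: bool = False
    local_window: int = -1       # -1 means no local (sliding) window
    kv_layout: str = "BNSH"      # KV cache layout


def cudnn_sdpa_eligible(req: GqaRequest) -> bool:
    """Conditions taken from the PR description for the cuDNN SDPA path."""
    return (
        req.dtype in ("fp16", "bf16")
        and not req.quantized
        and req.softcap == 0.0
        and not req.smooth_softmax
        and not req.head_sink
        and req.local_window < 0
        and req.kv_layout == "BNSH"
    )


def pick_gqa_kernel(req, xqa_ok=False, cudnn_enabled=True,
                    flash_ok=True, mea_ok=True) -> str:
    """Dispatch order: XQA, then cuDNN SDPA, then Flash, MEA, unfused."""
    if xqa_ok:
        return "xqa"
    if cudnn_enabled and cudnn_sdpa_eligible(req):
        return "cudnn_sdpa"
    if flash_ok:
        return "flash"
    if mea_ok:
        return "mea"
    return "unfused"
```

For example, `pick_gqa_kernel(GqaRequest(softcap=30.0))` falls through to `"flash"` in this model, because softcap makes the request ineligible for cuDNN SDPA.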
[...] sdpa_kernel provider option or sets ORT_ENABLE_CUDNN_FLASH_ATTENTION=0.
ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 disables cuDNN entirely (including the auto path); =1 force-enables it; the sdpa_kernel provider option overrides env vars.
Benchmark Results
Measured with onnxruntime/test/python/transformers/benchmark_gqa.py on NVIDIA H200 (SM 9.0), CUDA 13.0 / cuDNN 9.19, Llama3-8B-shaped GQA (b1, 32 query heads, 8 KV heads, head size 128, FP16).
[...] ORT_ENABLE_CUDNN_FLASH_ATTENTION=0)
The prefill (prompt) phase shows the largest gains for the dense variants:
ORT-GQA-Dense — prompt latency (ms, lower is better)
ORT-GQA-Dense-PackedQKV — prompt latency (ms, lower is better)
Prefill is ~1.5×–4× faster across sequence lengths for both dense variants. Decode (token) latency is unchanged within run-to-run noise.
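To make the benchmarked shape concrete, the sketch below works out what 32 query heads over 8 KV heads with head size 128 in FP16 implies. The contiguous grouping (`q_head // group_size`) is the usual GQA convention and is assumed here rather than taken from the PR; the function names are hypothetical.

```python
NUM_Q_HEADS = 32
NUM_KV_HEADS = 8
HEAD_SIZE = 128
BYTES_PER_ELEM = 2  # FP16

# Number of query heads that share one KV head.
group_size = NUM_Q_HEADS // NUM_KV_HEADS


def kv_head_for(q_head: int) -> int:
    """Assumed contiguous grouping: consecutive query heads share a KV head."""
    return q_head // group_size


def kv_cache_bytes_per_layer(seq_len: int, batch: int = 1) -> int:
    """Bytes for K plus V in one layer (the factor 2 covers both tensors)."""
    return 2 * batch * NUM_KV_HEADS * seq_len * HEAD_SIZE * BYTES_PER_ELEM
```

With these numbers each KV head serves 4 query heads, and a 4096-token sequence at batch 1 needs 16 MiB of KV cache per layer.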
Testing
- onnxruntime_provider_test --gtest_filter='GroupQueryAttentionTest.*:MultiHeadAttentionTest.*': GQA 44/44, MHA 18/18 pass on H200.
- [...] AttentionTest.*:PackedMultiHeadAttentionTest.*:DecoderMaskedMultiHeadAttentionTest.*): 143/143 pass.
- [...] s_q==1 causal-mask edge case handled for cuDNN ≤ 9.9).
- ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 disables the SM≥90 auto path (0 cuDNN selections) while the default run selects cuDNN.
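The override behavior checked in the last item can be modeled as a small precedence function. This is a sketch under assumptions, not the code in attention_kernel_options.cc: the function name and the `sdpa_kernel_selects_cudnn` parameter are hypothetical, and "enabled" here only means cuDNN SDPA may be chosen, subject to the per-operator eligibility checks.

```python
import os
from typing import Optional


def cudnn_sdpa_enabled(sm: int, dtype: str,
                       env: Optional[str] = None,
                       sdpa_kernel_selects_cudnn: Optional[bool] = None) -> bool:
    """Model of the precedence described in this PR.

    sm: compute capability as an integer (90 for SM 9.0).
    env: value of ORT_ENABLE_CUDNN_FLASH_ATTENTION, or None if unset.
    sdpa_kernel_selects_cudnn: None if the sdpa_kernel provider option was
        not set; otherwise whether the explicit selection includes cuDNN.
    """
    # 1. An explicit sdpa_kernel provider option overrides env vars.
    if sdpa_kernel_selects_cudnn is not None:
        return sdpa_kernel_selects_cudnn
    # 2. An explicit env var: 0 disables (auto path included), 1 force-enables.
    if env == "0":
        return False
    if env == "1":
        return True
    # 3. Otherwise the auto path: SM>=90 with FP16/BF16.
    return sm >= 90 and dtype in ("fp16", "bf16")


if __name__ == "__main__":
    env_value = os.environ.get("ORT_ENABLE_CUDNN_FLASH_ATTENTION")
    print(cudnn_sdpa_enabled(90, "fp16", env=env_value))
```

Keeping the provider option as a tri-state (unset, includes cuDNN, excludes cuDNN) is what lets an explicit selection win over both the environment variable and the SM≥90 default in this model.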