Skip to content

Upgrade cudnn_frontend to 1.24 and enable cuDNN SDPA for MHA/GQA#28849

Draft
tianleiwu wants to merge 2 commits into
mainfrom
tlwu/20260607/upgrade_cudnn_frontend
Draft

Upgrade cudnn_frontend to 1.24 and enable cuDNN SDPA for MHA/GQA#28849
tianleiwu wants to merge 2 commits into
mainfrom
tlwu/20260607/upgrade_cudnn_frontend

Conversation

@tianleiwu
Copy link
Copy Markdown
Contributor

Description

Upgrades the cudnn_frontend dependency from 1.12.0 → 1.24.0 and wires the updated cuDNN SDPA (scaled dot-product attention) kernels into the CUDA MultiHeadAttention and GroupQueryAttention operators. On SM≥90 (Hopper/Blackwell), cuDNN SDPA is auto-preferred for FP16/BF16 ahead of Flash Attention / cutlass FMHA, which significantly improves GQA prefill throughput.

Key Changes

Area Change
Dependency cmake/deps.txt: cudnn_frontend 1.12.0 → 1.24.0.
Build cmake/external/cudnn_frontend.cmake: mark cudnn_frontend headers as SYSTEM includes so v1.24's unused static helper does not trip -Werror=unused-function.
SDPA wrapper cudnn_fmha/cudnn_flash_attention.cc: migrate to the v1.24 API — set_generate_stats(false) (replaces deprecated set_is_inference), diagonal-band causal masking (set_diagonal_alignment + set_diagonal_band_right_bound/set_diagonal_band_left_bound), and synthesize the missing seq_len_q/seq_len_kv side that v1.24 now requires when a padding mask is used.
MHA multihead_attention.{cc,h}: enable cuDNN SDPA for FP16 and BF16; compute cuDNN eligibility before Flash and prefer it on SM≥90 unless the user pinned a kernel.
GQA group_query_attention.{cc,h}, group_query_attention_impl.cu, attention_data.h: add a cuDNN SDPA path (non-quantized FP16/BF16, no softcap/smooth-softmax/head-sink/local-window, BNSH KV cache), dispatched after XQA and before Flash/MEA/unfused.
Kernel selection attention_kernel_options.{cc,h}: track explicit sdpa_kernel selection and honor an explicit ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 so it disables the SM≥90 auto path.

Kernel Priority

  • SM≥90, FP16/BF16: cuDNN SDPA is auto-preferred unless the user explicitly selects a kernel via the sdpa_kernel provider option or sets ORT_ENABLE_CUDNN_FLASH_ATTENTION=0.
  • GQA decode: XQA remains highest priority where eligible; cuDNN SDPA outranks Flash/MEA/unfused for the remaining eligible cases.
  • ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 disables cuDNN entirely (including the auto path); =1 force-enables it; the sdpa_kernel provider option overrides env vars.

Benchmark Results

Measured with onnxruntime/test/python/transformers/benchmark_gqa.py on NVIDIA H200 (SM 9.0), CUDA 13.0 / cuDNN 9.19, Llama3-8B-shaped GQA (b1, 32 query heads, 8 KV heads, head size 128, FP16).

  • Baseline = Flash Attention (ORT_ENABLE_CUDNN_FLASH_ATTENTION=0)
  • This PR = cuDNN SDPA (default on SM≥90)

The prefill (prompt) phase shows the largest gains for the dense variants:

ORT-GQA-Dense — prompt latency (ms, lower is better)

seq_len Baseline (Flash) This PR (cuDNN) Speedup
64 0.082 0.054 1.52×
128 0.186 0.056 3.33×
256 0.184 0.065 2.83×
512 0.247 0.074 3.32×
1024 0.295 0.122 2.42×
2048 0.681 0.250 2.72×
4096 1.294 0.699 1.85×
8192 3.864 1.256 3.08×

ORT-GQA-Dense-PackedQKV — prompt latency (ms, lower is better)

seq_len Baseline (Flash) This PR (cuDNN) Speedup
64 0.197 0.053 3.69×
128 0.160 0.055 2.93×
256 0.213 0.060 3.53×
512 0.226 0.073 3.09×
1024 0.333 0.291 1.15×
2048 0.595 0.252 2.36×
4096 1.312 0.697 1.88×
8192 5.014 1.259 3.98×

Prefill is ~1.5×–4× faster across sequence lengths for both dense variants. Decode (token) latency is unchanged within run-to-run noise.

Testing

  • onnxruntime_provider_test --gtest_filter='GroupQueryAttentionTest.*:MultiHeadAttentionTest.*' — GQA 44/44, MHA 18/18 pass on H200.
  • Broader attention regression (AttentionTest.*:PackedMultiHeadAttentionTest.*:DecoderMaskedMultiHeadAttentionTest.*) — 143/143 pass.
  • Verified the new path on cuDNN 9.8 and 9.19 (decode s_q==1 causal-mask edge case handled for cuDNN ≤ 9.9).
  • Verified ORT_ENABLE_CUDNN_FLASH_ATTENTION=0 disables the SM≥90 auto path (0 cuDNN selections) while the default run selects cuDNN.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants