From 8c9ac49d8d7c2333c8cdb03ffd00e558bd9e9f97 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 20 Jan 2026 18:13:13 -0500 Subject: [PATCH 1/2] Fix TRTLLM attention assert (#28) * fix * fix assert Signed-off-by: Lucas Wilkinson * Zero first element Signed-off-by: Matthew Bonanni --------- Signed-off-by: Lucas Wilkinson Signed-off-by: Matthew Bonanni Co-authored-by: Lucas Wilkinson Co-authored-by: Matthew Bonanni --- vllm/v1/attention/backends/flashinfer.py | 3 --- vllm/v1/worker/gpu_model_runner.py | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 7a0aff80ed19..a6e4776060f0 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1059,9 +1059,6 @@ def build( ## DECODE PATHWAY if num_decodes > 0: if decode_use_trtllm: - assert num_decode_tokens % num_decodes == 0, ( - "TRTLLM decode requires uniform query lengths per request." - ) attn_metadata.decode = TRTLLMDecode( block_tables=block_table_tensor[:num_decodes], seq_lens=seq_lens[:num_decodes], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 32a07d64ada3..23d5bac75d00 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4382,7 +4382,11 @@ def _dummy_run( self.seq_lens.copy_to_gpu() cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention require that + self.query_start_loc.np[num_reqs + 1 :].fill(cum_num_tokens[-1]) self.query_start_loc.copy_to_gpu() pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL From 40b862bf68159a9bf833f8cf4fc12a604f3418e4 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Sun, 15 Mar 2026 16:23:12 -0700 Subject: [PATCH 2/2] [Attention] Add FA4 (flash_attn.cute) 
support for ViT encoder on Blackwell Enable Flash Attention 4 (CUTLASS-based) as a multimodal encoder attention backend on Blackwell (SM100+). This replaces the FA v2 `flash_fwd_kernel` with FA4's optimized kernel, reducing vision encoder attention time by 37.5% and improving end-to-end QPS by 26.3% on Qwen3-VL-235B (MLPerf Offline). Changes: - Add `FLASH_ATTN_CUTE` to AttentionBackendEnum - Add `fa4_utils.py` with `flash_attn.cute` wrapper (handles tuple return) - Add `_forward_fa4()` path in MMEncoderAttention - Add `vit_fa4_flash_attn_wrapper` custom op in vit_attn_wrappers.py - Auto-select FA4 on Blackwell in cuda.py, with explicit opt-in via `--mm-encoder-attn-backend FLASH_ATTN_CUTE` - Block FLASH_ATTN_CUTE in KV-cache attention config (ViT-only) - Guard flash_attn.ops import in rotary_embedding for namespace compat Requires: flash_attn.cute (pip install from Dao-AILab/flash-attention) + quack-kernels, torch-c-dlpack-ext Benchmark (Qwen3-VL-235B, 4x GB300, MLPerf Offline): FA2 baseline: 38.41 QPS FA4: 48.53 QPS (+26.3%) Co-Authored-By: Claude Opus 4.6 (1M context) --- vllm/config/attention.py | 11 ++- .../layers/attention/mm_encoder_attention.py | 43 +++++++++- .../layers/rotary_embedding/common.py | 5 +- vllm/platforms/cuda.py | 38 ++++++++- vllm/v1/attention/backends/fa4_utils.py | 82 +++++++++++++++++++ vllm/v1/attention/backends/registry.py | 1 + vllm/v1/attention/ops/vit_attn_wrappers.py | 68 +++++++++++++++ 7 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 vllm/v1/attention/backends/fa4_utils.py diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 293045787a1c..0de92462ee73 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -67,7 +67,16 @@ def compute_hash(self) -> str: def validate_backend_before(cls, value: Any) -> Any: """Enable parsing of the `backend` enum type from string.""" if isinstance(value, str): - return AttentionBackendEnum[value.upper()] + value = 
AttentionBackendEnum[value.upper()] + + if value == AttentionBackendEnum.FLASH_ATTN_CUTE: + raise ValueError( + "AttentionConfig.backend does not support FLASH_ATTN_CUTE " + "(FA4 / flash_attn.cute). This is a ViT/MM-encoder-only " + "attention tag. Use --mm-encoder-attn-backend / " + "MultiModalConfig.mm_encoder_attn_backend instead." + ) + return value def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None: diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 44e990d29c16..aabf8c9b9427 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -11,6 +11,7 @@ from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.ops.vit_attn_wrappers import ( + vit_fa4_flash_attn_wrapper, vit_flash_attn_wrapper, vit_torch_sdpa_wrapper, ) @@ -79,6 +80,10 @@ def __init__( AttentionBackendEnum.ROCM_AITER_FA, } + self.is_fa4_backend = ( + self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE + ) + self._fa_version = ( get_flash_attn_version() if self.is_flash_attn_backend else None ) @@ -182,6 +187,40 @@ def _forward_fa( output = output.reshape(bsz, q_len, -1) return output + def _forward_fa4( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> torch.Tensor: + """FA4 (flash_attn.cute) attention for multimodal encoder.""" + assert (cu_seqlens is not None and max_seqlen is not None) or ( + cu_seqlens is None and max_seqlen is None + ), "cu_seqlens and max_seqlen should be both set or both None." 
+ + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + is_reshaped = query.dim() != 4 + + query, key, value = self.maybe_reshape_qkv_to_4d( + query, key, value, bsz, q_len, kv_len + ) + + output = vit_fa4_flash_attn_wrapper( + q=query, + k=key, + v=value, + batch_size=bsz, + scale=self.scale, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + if is_reshaped: + output = output.reshape(bsz, q_len, -1) + return output + def forward_native( self, query: torch.Tensor, @@ -200,7 +239,9 @@ def forward_cuda( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: - if self.is_flash_attn_backend: + if self.is_fa4_backend: + return self._forward_fa4(query, key, value, cu_seqlens, max_seqlen) + elif self.is_flash_attn_backend: return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: return self._forward_sdpa(query, key, value, cu_seqlens) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 34de1da561f5..531fbfc4415b 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -135,7 +135,10 @@ def __init__( self.apply_rotary_emb_flash_attn = None if find_spec("flash_attn") is not None: - from flash_attn.ops.triton.rotary import apply_rotary + try: + from flash_attn.ops.triton.rotary import apply_rotary + except (ImportError, ModuleNotFoundError): + apply_rotary = None self.apply_rotary_emb_flash_attn = apply_rotary diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 47d634416ae5..e6dce98c27ea 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -361,6 +361,7 @@ def get_attn_backend_cls( def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: return [ AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASH_ATTN_CUTE, 
AttentionBackendEnum.FLASH_ATTN, ] @@ -376,10 +377,45 @@ def get_vit_attn_backend( f"Backend {backend} is not supported for vit attention. " f"Supported backends are: {cls.get_supported_vit_attn_backends()}" ) + if backend == AttentionBackendEnum.FLASH_ATTN_CUTE: + cc = cls.get_device_capability() + if cc is None or cc.major < 10: + raise ValueError( + "FLASH_ATTN_CUTE (FA4) requires Blackwell (SM100+). " + f"Current device: SM{cc.major}{cc.minor}" if cc + else "No device found." + ) + from vllm.v1.attention.backends.fa4_utils import ( + is_flash_attn_cute_available, + ) + if not is_flash_attn_cute_available(): + raise ImportError( + "flash_attn.cute is not installed. " + "Install with: pip install " + "git+https://github.com/Dao-AILab/flash-attention.git" + "#subdirectory=flash_attn/cute" + ) logger.info_once(f"Using backend {backend} for vit attention") return backend - # Try FlashAttention first + # On Blackwell, try FA4 first + if (cc := cls.get_device_capability()) and cc.major >= 10: + try: + from vllm.v1.attention.backends.fa4_utils import ( + is_flash_attn_cute_available, + ) + if is_flash_attn_cute_available() and dtype in ( + torch.float16, torch.bfloat16 + ): + logger.info_once( + "Auto-selecting FLASH_ATTN_CUTE (FA4) for ViT on " + "Blackwell." 
+ ) + return AttentionBackendEnum.FLASH_ATTN_CUTE + except ImportError: + pass + + # Try FlashAttention (FA2) if (cc := cls.get_device_capability()) and cc.major >= 8: try: backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() diff --git a/vllm/v1/attention/backends/fa4_utils.py b/vllm/v1/attention/backends/fa4_utils.py new file mode 100644 index 000000000000..b1a6c1731c4a --- /dev/null +++ b/vllm/v1/attention/backends/fa4_utils.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for Flash Attention 4 (flash_attn.cute) on Blackwell.""" + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_FA4_AVAILABLE: bool | None = None +_FA4_FUNC = None + +# Head sizes optimized for FA4 on Blackwell +FA4_SUPPORTED_HEAD_SIZES = (64, 96, 128, 192) + + +def _import_fa4_fwd(): + """Try importing FA4. Prefer flash_attn_cute to avoid polluting the + flash_attn namespace which would break vllm's flash_attn.ops imports.""" + try: + from flash_attn_cute.interface import _flash_attn_fwd + return _flash_attn_fwd + except (ImportError, ModuleNotFoundError): + pass + try: + from flash_attn.cute.interface import _flash_attn_fwd + return _flash_attn_fwd + except (ImportError, ModuleNotFoundError): + pass + return None + + +def is_flash_attn_cute_available() -> bool: + global _FA4_AVAILABLE + if _FA4_AVAILABLE is not None: + return _FA4_AVAILABLE + _FA4_AVAILABLE = _import_fa4_fwd() is not None + return _FA4_AVAILABLE + + +def _get_fa4_func(): + global _FA4_FUNC + if _FA4_FUNC is None: + _FA4_FUNC = _import_fa4_fwd() + if _FA4_FUNC is None: + raise ImportError( + "flash_attn.cute is not available. " + "Install flash-attn-4 for Blackwell FA4 support." 
+ ) + return _FA4_FUNC + + +def flash_attn_cute_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float | None = None, + causal: bool = False, +) -> torch.Tensor: + """Wrapper around flash_attn.cute for varlen (variable-length) attention.""" + fa4_fwd = _get_fa4_func() + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + result = fa4_fwd( + q, k, v, + softmax_scale=softmax_scale, + causal=causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + ) + # _flash_attn_fwd returns (output, softmax_lse); we only need output + if isinstance(result, tuple): + return result[0] + return result diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index bd45702fa587..af1e44bd15b1 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -58,6 +58,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend" ) TORCH_SDPA = "" # this tag is only used for ViT + FLASH_ATTN_CUTE = "" # FA4 via flash_attn.cute, ViT/MM encoder only FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER_MLA = ( "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index f077a61c984f..9e4964091b88 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -183,3 +183,71 @@ def vit_torch_sdpa_wrapper( cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens) + + +# ---- FA4 (flash_attn.cute) wrappers ---- + +def fa4_flash_attn_maxseqlen_wrapper( + q: 
torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_size: int, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, +) -> torch.Tensor: + from vllm.v1.attention.backends.fa4_utils import flash_attn_cute_varlen_func + + q_len = q.size(1) + if cu_seqlens is None: + cu_seqlens = torch.arange( + 0, (batch_size + 1) * q_len, step=q_len, + dtype=torch.int32, device=q.device, + ) + max_seqlen_val = q_len if max_seqlen is None else max_seqlen.item() + + q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + output = flash_attn_cute_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_val, + max_seqlen_k=max_seqlen_val, + softmax_scale=scale, + causal=False, + ) + context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) + return context_layer + + +def fa4_flash_attn_maxseqlen_wrapper_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_size: int, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(q) + + +direct_register_custom_op( + op_name="fa4_flash_attn_maxseqlen_wrapper", + op_func=fa4_flash_attn_maxseqlen_wrapper, + fake_impl=fa4_flash_attn_maxseqlen_wrapper_fake, +) + + +def vit_fa4_flash_attn_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_size: int, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.ops.vllm.fa4_flash_attn_maxseqlen_wrapper( + q, k, v, batch_size, scale, cu_seqlens, max_seqlen, + )