25 Commits

Author | SHA1 | Message | Date
Wentao Ye | c9b968a349 | Merge branch 'main' into wentao-small-refactor | 17 hours ago
yjc9696 | 855b101d75 | [Frontend] add tools for dsv32 developer role (#30040) | 17 hours ago
Robert Shaw | d0502b4928 | [MoE][Refactor 1/N] Separate Online Quantization (#30627) | 17 hours ago
Max Hu | 3f175f18a2 | [Bugfix] Fix multimodal configuration for Qwen3VL MOE model (#30670) | 18 hours ago
Cyrus Leung | ed586e7724 | [Refactor] [3/N] Move tool parser tests and run on CPU (#30693) | 18 hours ago
Chauncey | 2a1776b7ac | [Refactor] [2/N] Move tool parsers into the vLLM main directory (#30675) | 19 hours ago
Nicolò Lucchesi | 185c22bf2f | [Misc][Hybrid allocator + kv connector] Optionally enable hybrid allocator + KV cache connector (#29805) | 21 hours ago
duke | e4806d973a | [BugFix] Add embed_input_ids method to make QWenLMHeadModel a vllm model (#30674) | 21 hours ago
wang.yuqi | 4429d934de | [Model] Automatic conversion of TokenClassification model (#30666) | 1 day ago
ゆり | 33278073d6 | typing: Add type hints to TurnMetrics class in context.py (#30552) | 1 day ago
汪志鹏 | 1adeb3b84c | [New Model] BAGEL support (AR only) (#28439) | 1 day ago
Kunshang Ji | e3a1cd1c59 | [XPU] fix Dockerfile.xpu, avoid wheel conflicts (#30662) | 1 day ago
Wentao Ye | 3778673ea8 | [Feat] Refactor for `parallel_config` in `FusedMoEModularKernel` (#30282) | 1 day ago
Seokhyun An | b337647aa0 | [Bugfix] Drop empty tool_calls lists to keep assistant replies in chat template (#30648) | 1 day ago
Jee Jee Li | a524d1ba0a | [Bugfix] Fix deepseek_v32 tokenizer_mode (#30658) | 1 day ago
Shanshan Shen | 87b4d1557d | [CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125) | 1 day ago
Wenqi Glantz | 84e23d103d | additional protection for CVE-2025-62164 (#30649) | 1 day ago
Shanshan Shen | 738648fb81 | [CustomOp] Support object-level enable for CustomOp (#30547) | 1 day ago
Boyuan Feng | 917fdae5b2 | [Log] Skip piecewise cudagraph warn when using full cudagraph (#30657) | 1 day ago
Robert Shaw | e2ed238885 | Revert "[Fix]Load kv-cache dtype from hf_quant_config.json automatically" (#30653) | 1 day ago
Or Ozeri | 174e39ead7 | CPU KV Offloading: Use more CUDA streams (#29013) | 1 day ago
RioS | 9ccbf6b692 | [responsesAPI]add extra body parameters (#30532) | 1 day ago
Chendi.Xue | ae2e503dda | [NIXL][BUG FIX] Fix a bug for PD with host_buffer after merging 29665 (#30420) | 1 day ago
Tsukasa OI | 9e33a1a75b | [Model][Quantization] Override HF defaults to GGUF ones (incl. Qwen3 MoE) (#30118) | 1 day ago
Vensen | add4b0ca44 | [Bugfix][benchmarks] Fix input token calculation for rerank benchmark metrics (#30596) | 1 day ago
100 changed files with 3081 additions and 1331 deletions
1. .buildkite/test-amd.yaml (+5, -15)
2. .buildkite/test-pipeline.yaml (+5, -12)
3. .buildkite/test_areas/misc.yaml (+3, -1)
4. .buildkite/test_areas/tool_use.yaml (+1, -11)
5. docker/Dockerfile.xpu (+3, -0)
6. docs/features/tool_calling.md (+2, -2)
7. docs/models/supported_models.md (+1, -0)
8. examples/offline_inference/vision_language.py (+27, -0)
9. tests/entrypoints/openai/test_serving_chat.py (+1, -1)
10. tests/entrypoints/openai/test_sparse_tensor_validation.py (+342, -0)
11. tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py (+1, -1)
12. tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py (+1, -1)
13. tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py (+1, -1)
14. tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py (+1, -1)
15. tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py (+1, -1)
16. tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py (+1, -1)
17. tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py (+1, -1)
18. tests/entrypoints/openai/tool_parsers/utils.py (+1, -1)
19. tests/kernels/moe/modular_kernel_tools/common.py (+2, -1)
20. tests/kernels/moe/test_flashinfer.py (+14, -0)
21. tests/models/language/generation/test_mistral.py (+3, -3)
22. tests/models/language/pooling/test_token_classification.py (+31, -0)
23. tests/models/multimodal/generation/test_vit_backend_functionality.py (+434, -0)
24. tests/models/registry.py (+2, -0)
25. tests/multimodal/test_sparse_tensor_validation_unit.py (+134, -0)
26. tests/tool_parsers/__init__.py (+0, -0)
27. tests/tool_parsers/test_deepseekv31_tool_parser.py (+2, -2)
28. tests/tool_parsers/test_ernie45_moe_tool_parser.py (+1, -1)
29. tests/tool_parsers/test_glm4_moe_tool_parser.py (+2, -4)
30. tests/tool_parsers/test_jamba_tool_parser.py (+1, -3)
31. tests/tool_parsers/test_kimi_k2_tool_parser.py (+1, -3)
32. tests/tool_parsers/test_minimax_tool_parser.py (+1, -3)
33. tests/tool_parsers/test_mistral_tool_parser.py (+1, -1)
34. tests/tool_parsers/test_openai_tool_parser.py (+1, -1)
35. tests/tool_parsers/test_qwen3coder_tool_parser.py (+4, -6)
36. tests/tool_parsers/test_seed_oss_tool_parser.py (+1, -3)
37. tests/tool_parsers/test_xlam_tool_parser.py (+1, -3)
38. tests/tool_use/test_tool_choice_required.py (+1, -1)
39. tests/v1/kv_connector/unit/test_nixl_connector.py (+6, -6)
40. tests/v1/kv_offload/test_cpu_gpu.py (+10, -12)
41. vllm/attention/layer.py (+5, -68)
42. vllm/attention/layers/mm_encoder_attention.py (+284, -0)
43. vllm/attention/ops/vit_attn_wrappers.py (+3, -8)
44. vllm/benchmarks/serve.py (+3, -1)
45. vllm/config/compilation.py (+7, -3)
46. vllm/config/model.py (+1, -0)
47. vllm/config/scheduler.py (+3, -1)
48. vllm/config/vllm.py (+60, -36)
49. vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (+56, -39)
50. vllm/engine/arg_utils.py (+1, -1)
51. vllm/entrypoints/chat_utils.py (+20, -6)
52. vllm/entrypoints/context.py (+8, -8)
53. vllm/entrypoints/openai/api_server.py (+1, -1)
54. vllm/entrypoints/openai/cli_args.py (+1, -1)
55. vllm/entrypoints/openai/parser/responses_parser.py (+1, -1)
56. vllm/entrypoints/openai/protocol.py (+9, -0)
57. vllm/entrypoints/openai/serving_chat.py (+2, -2)
58. vllm/entrypoints/openai/serving_engine.py (+2, -2)
59. vllm/entrypoints/openai/tool_parsers/__init__.py (+23, -140)
60. vllm/entrypoints/pooling/score/protocol.py (+1, -0)
61. vllm/entrypoints/pooling/score/serving.py (+3, -1)
62. vllm/entrypoints/renderer.py (+14, -11)
63. vllm/model_executor/custom_op.py (+7, -2)
64. vllm/model_executor/layers/fused_moe/cutlass_moe.py (+0, -2)
65. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py (+1, -1)
66. vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (+1, -6)
67. vllm/model_executor/layers/fused_moe/modular_kernel.py (+13, -8)
68. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+0, -3)
69. vllm/model_executor/layers/quantization/fp8.py (+154, -89)
70. vllm/model_executor/layers/quantization/utils/flashinfer_utils.py (+1, -6)
71. vllm/model_executor/models/adapters.py (+12, -0)
72. vllm/model_executor/models/bagel.py (+584, -0)
73. vllm/model_executor/models/dots_ocr.py (+46, -83)
74. vllm/model_executor/models/ernie45_vl.py (+28, -80)
75. vllm/model_executor/models/glm4_1v.py (+46, -91)
76. vllm/model_executor/models/keye.py (+27, -80)
77. vllm/model_executor/models/opencua.py (+1, -7)
78. vllm/model_executor/models/ovis2_5.py (+6, -16)
79. vllm/model_executor/models/paddleocr_vl.py (+30, -75)
80. vllm/model_executor/models/qwen.py (+3, -0)
81. vllm/model_executor/models/qwen2.py (+32, -0)
82. vllm/model_executor/models/qwen2_5_omni_thinker.py (+1, -0)
83. vllm/model_executor/models/qwen2_5_vl.py (+48, -76)
84. vllm/model_executor/models/qwen2_vl.py (+47, -96)
85. vllm/model_executor/models/qwen3_omni_moe_thinker.py (+12, -8)
86. vllm/model_executor/models/qwen3_vl.py (+24, -22)
87. vllm/model_executor/models/qwen3_vl_moe.py (+5, -1)
88. vllm/model_executor/models/registry.py (+1, -0)
89. vllm/model_executor/models/siglip2navit.py (+45, -84)
90. vllm/model_executor/models/vision.py (+8, -5)
91. vllm/multimodal/audio.py (+10, -2)
92. vllm/multimodal/image.py (+10, -2)
93. vllm/platforms/cuda.py (+36, -18)
94. vllm/platforms/interface.py (+38, -7)
95. vllm/platforms/rocm.py (+38, -19)
96. vllm/platforms/tpu.py (+27, -1)
97. vllm/platforms/xpu.py (+29, -7)
98. vllm/tokenizers/deepseek_v32.py (+0, -0)
99. vllm/tokenizers/registry.py (+1, -1)
100. vllm/tool_parsers/__init__.py (+150, -0)

.buildkite/test-amd.yaml (+5, -15)

@@ -61,8 +61,8 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_

- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
timeout_in_minutes: 20
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min
timeout_in_minutes: 30
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1
grade: Blocking
@@ -73,6 +73,7 @@ steps:
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -82,6 +83,7 @@ steps:
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
- pytest -v -s config

@@ -759,19 +761,7 @@ steps:
- vllm/
- tests/tool_use
commands:
- pytest -v -s -m 'not cpu_test' tool_use

- label: OpenAI-Compatible Tool Use (CPU) # 5 mins
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
timeout_in_minutes: 10
source_file_dependencies:
- vllm/
- tests/tool_use
no_gpu: true
commands:
- pytest -v -s -m 'cpu_test' tool_use
- pytest -v -s tool_use

##### models test #####



.buildkite/test-pipeline.yaml (+5, -12)

@@ -57,8 +57,8 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_

- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
timeout_in_minutes: 20
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min
timeout_in_minutes: 30
source_file_dependencies:
- vllm/
- tests/test_inputs.py
@@ -66,6 +66,7 @@ steps:
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -75,6 +76,7 @@ steps:
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
- pytest -v -s config

@@ -672,16 +674,7 @@ steps:
- vllm/
- tests/tool_use
commands:
- pytest -v -s -m 'not cpu_test' tool_use

- label: OpenAI-Compatible Tool Use (CPU) # 5 mins
timeout_in_minutes: 10
source_file_dependencies:
- vllm/
- tests/tool_use
no_gpu: true
commands:
- pytest -v -s -m 'cpu_test' tool_use
- pytest -v -s tool_use

##### models test #####



.buildkite/test_areas/misc.yaml (+3, -1)

@@ -115,7 +115,7 @@ steps:

- label: Async Engine, Inputs, Utils, Worker, Config (CPU)
depends_on: ~
timeout_in_minutes: 20
timeout_in_minutes: 30
source_file_dependencies:
- vllm/
- tests/test_inputs.py
@@ -123,6 +123,7 @@ steps:
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -132,6 +133,7 @@ steps:
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
- pytest -v -s config



.buildkite/test_areas/tool_use.yaml (+1, -11)

@@ -10,14 +10,4 @@ steps:
- vllm/
- tests/tool_use
commands:
- pytest -v -s -m 'not cpu_test' tool_use

- label: OpenAI-Compatible Tool Use (CPU)
depends_on: ~
timeout_in_minutes: 10
source_file_dependencies:
- vllm/
- tests/tool_use
no_gpu: true
commands:
- pytest -v -s -m 'cpu_test' tool_use
- pytest -v -s tool_use

docker/Dockerfile.xpu (+3, -0)

@@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils
ENV NIXL_VERSION=0.7.0
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py

# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts
RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf

# remove torch bundled oneccl to avoid conflicts
RUN --mount=type=cache,target=/root/.cache/pip \
pip uninstall oneccl oneccl-devel -y


docs/features/tool_calling.md (+2, -2)

@@ -420,7 +420,7 @@ Flags: `--tool-call-parser pythonic --chat-template {see_above}`

## How to Write a Tool Parser Plugin

A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py](../../vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py).
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/tool_parsers/hermes_tool_parser.py](../../vllm/tool_parsers/hermes_tool_parser.py).

Here is a summary of a plugin file:

@@ -468,7 +468,7 @@ Here is a summary of a plugin file:
# register the tool parser to ToolParserManager
ToolParserManager.register_lazy_module(
name="example",
module_path="vllm.entrypoints.openai.tool_parsers.example",
module_path="vllm.tool_parsers.example",
class_name="ExampleToolParser",
)
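
For context on the new layout, here is a minimal sketch of what such a plugin file might look like. Only the import path and the register_lazy_module call are taken from the docs hunk above; the file name, class body, and method name are illustrative assumptions, not part of this diff:

# example.py -- hypothetical tool parser plugin (sketch only)
from vllm.tool_parsers import ToolParser, ToolParserManager

class ExampleToolParser(ToolParser):
    # A real parser overrides the extraction hooks of ToolParser; see
    # Hermes2ProToolParser in vllm/tool_parsers/hermes_tool_parser.py for a
    # complete reference implementation.
    def extract_tool_calls(self, model_output, request):
        # Parse `model_output` and return the extracted tool calls; the exact
        # return type follows the ToolParser base class and is left out here.
        raise NotImplementedError

# register the tool parser to ToolParserManager (as in the docs snippet above)
# so it can be selected with --tool-call-parser example
ToolParserManager.register_lazy_module(
    name="example",
    module_path="vllm.tool_parsers.example",
    class_name="ExampleToolParser",
)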



docs/models/supported_models.md (+1, -0)

@@ -661,6 +661,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |


examples/offline_inference/vision_language.py (+27, -0)

@@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData:
)


def run_bagel(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "ByteDance-Seed/BAGEL-7B-MoT"

engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
)

prompts = [
(
f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1832,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
"aria": run_aria,
"aya_vision": run_aya_vision,
"bagel": run_bagel,
"bee": run_bee,
"blip-2": run_blip2,
"chameleon": run_chameleon,


tests/entrypoints/openai/test_serving_chat.py (+1, -1)

@@ -19,9 +19,9 @@ from vllm.entrypoints.openai.protocol import (
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM

from ...utils import RemoteOpenAIServer


tests/entrypoints/openai/test_sparse_tensor_validation.py (+342, -0)

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.

Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""

import base64
import io

import pytest
import torch

from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.audio import AudioEmbeddingMediaIO
from vllm.multimodal.image import ImageEmbeddingMediaIO


def _encode_tensor(tensor: torch.Tensor) -> bytes:
"""Helper to encode a tensor as base64 bytes."""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return base64.b64encode(buffer.read())


def _create_malicious_sparse_tensor() -> torch.Tensor:
"""
Create a malicious sparse COO tensor with out-of-bounds indices.

This tensor has indices that point beyond the declared shape, which would
cause an out-of-bounds write when converted to dense format without
validation.
"""
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape
values = torch.tensor([1.0])
shape = (3, 3)

# Create sparse tensor (this will be invalid)
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor


def _create_valid_sparse_tensor() -> torch.Tensor:
"""Create a valid sparse COO tensor for baseline testing."""
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
values = torch.tensor([1.0, 2.0, 3.0])
shape = (3, 3)

sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor


def _create_valid_dense_tensor() -> torch.Tensor:
"""Create a valid dense tensor for baseline testing."""
return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size)


class TestPromptEmbedsValidation:
"""Test sparse tensor validation in prompt embeddings (Completions API)."""

def test_valid_dense_tensor_accepted(self, model_config):
"""Baseline: Valid dense tensors should work normally."""
renderer = CompletionRenderer(model_config)

valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)

# Should not raise any exception
result = renderer.load_prompt_embeds(encoded)
assert len(result) == 1
assert result[0]["prompt_embeds"].shape == valid_tensor.shape

def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = ImageEmbeddingMediaIO()

valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)

# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape

def test_malicious_sparse_tensor_rejected(self, model_config):
"""Security: Malicious sparse tensors should be rejected."""
renderer = CompletionRenderer(model_config)

malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)

# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
renderer.load_prompt_embeds(encoded)

# Error should indicate sparse tensor validation failure
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

def test_extremely_large_indices_rejected(self, model_config):
"""Security: Sparse tensors with extremely large indices should be rejected."""
renderer = CompletionRenderer(model_config)

# Create tensor with indices far beyond reasonable bounds
indices = torch.tensor([[999999], [999999]])
values = torch.tensor([1.0])
shape = (10, 10)

malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)

with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)

def test_negative_indices_rejected(self, model_config):
"""Security: Sparse tensors with negative indices should be rejected."""
renderer = CompletionRenderer(model_config)

# Create tensor with negative indices
indices = torch.tensor([[-1], [-1]])
values = torch.tensor([1.0])
shape = (10, 10)

malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)

with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)


class TestImageEmbedsValidation:
"""Test sparse tensor validation in image embeddings (Chat API)."""

def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = ImageEmbeddingMediaIO()

valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)

# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape

def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = AudioEmbeddingMediaIO()

valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)

# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape

def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = ImageEmbeddingMediaIO()

malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)

# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))

error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = ImageEmbeddingMediaIO()

malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)

with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())


class TestAudioEmbedsValidation:
"""Test sparse tensor validation in audio embeddings (Chat API)."""

def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = AudioEmbeddingMediaIO()

valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)

# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape

def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should be converted successfully."""
io_handler = AudioEmbeddingMediaIO()

valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)

# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.is_sparse is False

def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = AudioEmbeddingMediaIO()

malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)

# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))

error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = AudioEmbeddingMediaIO()

malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)

with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())


class TestSparseTensorValidationIntegration:
"""
These tests verify the complete attack chain is blocked at all entry points.
"""

def test_attack_scenario_completions_api(self, model_config):
"""
Simulate a complete attack through the Completions API.

Attack scenario:
1. Attacker crafts malicious sparse tensor
2. Encodes it as base64
3. Sends to /v1/completions with prompt_embeds parameter
4. Server should reject before memory corruption occurs
"""
renderer = CompletionRenderer(model_config)

# Step 1-2: Attacker creates malicious payload
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

# Step 3-4: Server processes and should reject
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(attack_payload)

def test_attack_scenario_chat_api_image(self):
"""
Simulate attack through Chat API with image_embeds.

Verifies the image embeddings path is protected.
"""
io_handler = ImageEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))

def test_attack_scenario_chat_api_audio(self):
"""
Simulate attack through Chat API with audio_embeds.

Verifies the audio embeddings path is protected.
"""
io_handler = AudioEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))

def test_multiple_valid_embeddings_in_batch(self, model_config):
"""
Regression test: Multiple valid embeddings should still work.

Ensures the fix doesn't break legitimate batch processing.
"""
renderer = CompletionRenderer(model_config)

valid_tensors = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
]

# Should process all without error
result = renderer.load_prompt_embeds(valid_tensors)
assert len(result) == 3

def test_mixed_valid_and_malicious_rejected(self, model_config):
"""
Security: Batch with one malicious tensor should be rejected.

Even if most tensors are valid, a single malicious one should
cause rejection of the entire batch.
"""
renderer = CompletionRenderer(model_config)

mixed_batch = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
_encode_tensor(_create_valid_dense_tensor()),
]

# Should fail on the malicious tensor
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(mixed_batch)


# Pytest fixtures
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig

return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)

tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py (+1, -1)

@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager

SIMPLE_ARGS_DICT = {
"action": "create",


tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py (+1, -1)

@@ -6,8 +6,8 @@ import json
import pytest

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser

from ....utils import RemoteOpenAIServer



tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py (+1, -1)

@@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers import ToolParser, ToolParserManager


def make_tool_call(name, arguments):


tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py (+1, -1)

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest

from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser


@pytest.fixture


tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py (+1, -1)

@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager

# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"


tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py (+1, -1)

@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"


tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py (+1, -1)

@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"


tests/entrypoints/openai/tool_parsers/utils.py (+1, -1)

@@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser


class StreamingToolReconstructor:


tests/kernels/moe/modular_kernel_tools/common.py (+2, -1)

@@ -594,7 +594,8 @@ def make_modular_kernel(
)

modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts
prepare_finalize=prepare_finalize,
fused_experts=fused_experts,
)

return modular_kernel


tests/kernels/moe/test_flashinfer.py (+14, -0)

@@ -5,6 +5,7 @@ from dataclasses import dataclass
import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -107,6 +108,19 @@ class TestData:
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
# Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig(
tp_size=1,
pcp_size=1,
dp_size=1,
ep_size=1,
tp_rank=1,
pcp_rank=1,
dp_rank=1,
ep_rank=1,
use_ep=False,
all2all_backend="naive",
)

register_moe_scaling_factors(layer)



tests/models/language/generation/test_mistral.py (+3, -3)

@@ -5,12 +5,12 @@ import json

import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
from vllm.sampling_params import SamplingParams
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.sampling_params import SamplingParams
from vllm.tokenizers.mistral import MistralTokenizer

from ...utils import check_logprobs_close



tests/models/language/pooling/test_token_classification.py (+31, -0)

@@ -68,3 +68,34 @@ def test_modernbert_models(
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, atol=1e-2)


@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_auto_conversion(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.token_classify(example_prompts)

with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
) as hf_model:
tokenizer = hf_model.tokenizer
hf_outputs = []
for prompt in example_prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = hf_model.wrap_device(inputs)
output = hf_model.model(**inputs)
hf_outputs.append(softmax(output.logits[0]))

# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, atol=1e-2)

tests/models/multimodal/generation/test_vit_backend_functionality.py (+434, -0)

@@ -0,0 +1,434 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Consolidated test for ViT attention backend functionality across multiple models.

This test validates that each multimodal model can successfully generate outputs
using different ViT attention backends. Tests are parametrized by model and backend.
"""

from dataclasses import asdict
from typing import Any

import pytest
from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform

from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides

# Dots.OCR prompt from official repository
# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3
# ruff: noqa: E501
DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""

VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"


# Model configurations
MODEL_CONFIGS: dict[str, dict[str, Any]] = {
"dots_ocr": {
"model_name": "rednote-hilab/dots.ocr",
"interface": "llm_chat",
"max_model_len": 32768,
"max_num_seqs": 1,
"limit_mm_per_prompt": {"image": 1},
"sampling_params": {
"temperature": 0.1,
"max_tokens": 16384,
"top_p": 0.9,
"stop_token_ids": None,
},
"use_specific_image": "stop_sign",
"prompt_builder": "build_dots_ocr_prompt",
"output_validator": lambda x: len(x) > 10 and "stop" in x.lower(),
},
"ernie45_vl": {
"model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT",
"interface": "llm_generate",
"max_model_len": 16384,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"glm4_1v": {
"model_name": "zai-org/GLM-4.1V-9B-Thinking",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"keye_vl": {
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
"interface": "llm_generate",
"max_model_len": 8192,
"max_num_seqs": 5,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"supported_backends": {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"ovis2_5": {
"model_name": "AIDC-AI/Ovis2.5-2B",
"interface": "llm_generate",
"max_model_len": 8192,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"prompt_builder": "build_ovis_prompt",
"question": "What is the content of each image?",
},
"qwen2_5_vl": {
"model_name": "Qwen/Qwen2.5-VL-3B-Instruct",
"interface": "vllm_runner",
"media_type": "video",
"max_model_len": 4000,
"max_num_seqs": 1,
"limit_mm_per_prompt": {"video": 1},
"sampling_params": {
"max_tokens": 128,
},
"runner_kwargs": {
"runner": "generate",
"dtype": "bfloat16",
},
"video_params": {
"num_frames": 16,
"pruning_rates": [0.0, 0.75],
},
},
"qwen2_5_omni": {
"model_name": "Qwen/Qwen2.5-Omni-3B",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
"sampling_params": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 20,
"max_tokens": 16384,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"qwen3_omni": {
"model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
"sampling_params": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 20,
"max_tokens": 16384,
},
"use_processor": True,
"question": "What is the content of each image?",
},
}


# Prompt builder functions
def build_dots_ocr_prompt(images, config):
"""Build Dots.OCR specific prompt with OCR instructions."""
# Use only stop_sign image for Dots.OCR
image = images[0] # Already filtered to stop_sign

image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"

placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{
"type": "text",
"text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}",
},
],
},
]

return messages


def build_processor_prompt(images, config):
"""Build prompt using AutoProcessor.apply_chat_template()."""
processor = AutoProcessor.from_pretrained(
config["model_name"], trust_remote_code=True
)

image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": config["question"]},
],
},
]

return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)


def build_ovis_prompt(images, config):
"""Build Ovis2.5 specific prompt with custom format."""
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]

placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)

return (
f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n"
"<|im_start|>assistant\n"
)


def build_qwen2_5_video_prompt():
"""Build Qwen2.5-VL video prompt with EVS placeholder."""
return (
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{VIDEO_PLACEHOLDER}"
"Describe this video with a short sentence (no more than 20 words)"
"<|im_end|><|im_start|>assistant\n"
)


# Handler functions
def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets):
"""Standard LLM.generate() interface handler."""
images = [asset.pil_image for asset in image_assets]

# Build prompt
if config.get("use_processor"):
prompt = build_processor_prompt(images, config)
else:
prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt")
prompt_builder = globals()[prompt_builder_name]
prompt = prompt_builder(images, config)

# Determine limit_mm_per_prompt
limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)})

# Create engine
engine_args = EngineArgs(
model=config["model_name"],
trust_remote_code=True,
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=limit_mm_per_prompt,
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
)

engine_dict = asdict(engine_args) | {"seed": 42}
llm = LLM(**engine_dict)

# Generate
sampling_params = SamplingParams(**config["sampling_params"])
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {"image": images},
},
sampling_params=sampling_params,
)

# Validate
for o in outputs:
generated_text = o.outputs[0].text
validator = config.get("output_validator", lambda x: len(x) > 10)
assert validator(generated_text), (
f"Validation failed for {config['model_name']}: {generated_text}"
)


def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets):
"""LLM.chat() interface handler for Dots.OCR."""
# Filter to stop_sign image only
stop_sign_image = [
asset.pil_image for asset in image_assets if asset.name == "stop_sign"
][0]

# Build messages
messages = build_dots_ocr_prompt([stop_sign_image], config)

# Create engine
engine_args = EngineArgs(
model=config["model_name"],
trust_remote_code=True,
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=config["limit_mm_per_prompt"],
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
)

engine_dict = asdict(engine_args) | {"seed": 42}
llm = LLM(**engine_dict)

# Generate using chat
sampling_params = SamplingParams(**config["sampling_params"])
outputs = llm.chat(messages=messages, sampling_params=sampling_params)

# Validate
for o in outputs:
generated_text = o.outputs[0].text
validator = config.get("output_validator", lambda x: len(x) > 10)
assert validator(generated_text), (
f"Validation failed for {config['model_name']}: {generated_text}"
)


def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"""Video test with EVS (Efficient Video Sampling) handler."""
for pruning_rate in config["video_params"]["pruning_rates"]:
num_frames = config["video_params"]["num_frames"]

# Sample frames from video
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]

# Build prompt and prepare video
prompt = build_qwen2_5_video_prompt()
prompts = [prompt]
videos = [sampled_vids[0]]

# Run with vllm_runner context manager
with vllm_runner(
config["model_name"],
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=config["limit_mm_per_prompt"],
tensor_parallel_size=1,
video_pruning_rate=pruning_rate,
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
**config["runner_kwargs"],
) as vllm_model:
outputs = vllm_model.generate_greedy(
prompts,
config["sampling_params"]["max_tokens"],
videos=videos,
)

# Validate output
assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}"
output_ids, output_text = outputs[0]
assert len(output_ids) > 0, "Generated no output IDs"
assert len(output_text) > 0, "Generated empty text"
assert isinstance(output_text, str), (
f"Output is not string: {type(output_text)}"
)


# Main test function
@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys()))
@pytest.mark.parametrize(
"mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(),
)
@create_new_process_for_each_test()
def test_vit_backend_functionality(
model_key: str,
mm_encoder_attn_backend: AttentionBackendEnum | None,
image_assets,
video_assets,
vllm_runner,
request,
):
"""Test ViT attention backend functionality for multimodal models.

This test validates that each model can successfully generate outputs
using different ViT attention backends. The test:
1. Filters unsupported backends per model
2. Applies appropriate GPU marks
3. Routes to the correct test handler based on interface
4. Validates output meets minimum requirements
"""
config = MODEL_CONFIGS[model_key]

# Step 1: Backend filtering
if (
"supported_backends" in config
and mm_encoder_attn_backend is not None
and mm_encoder_attn_backend not in config["supported_backends"]
):
pytest.skip(
f"{model_key} does not support {mm_encoder_attn_backend} backend now."
)

# Step 2: Apply GPU marks dynamically
if "gpu_marks" in config:
for mark in config["gpu_marks"]:
request.applymarker(mark)

# Step 3: Route to appropriate handler
if config.get("media_type") == "video":
run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner)
elif config["interface"] == "llm_chat":
run_llm_chat_test(config, mm_encoder_attn_backend, image_assets)
elif config["interface"] == "llm_generate":
run_llm_generate_test(config, mm_encoder_attn_backend, image_assets)
else:
raise ValueError(f"Unknown interface: {config['interface']}")

tests/models/registry.py (+2, -0)

@@ -573,6 +573,7 @@ _AUTOMATIC_CONVERTED_MODELS = {
"Qwen3ForSequenceClassification": _HfExamplesInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
),
"Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
}

_MULTIMODAL_EXAMPLE_MODELS = {
@@ -582,6 +583,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
"BeeForConditionalGeneration": _HfExamplesInfo(
"Open-Bee/Bee-8B-RL",
trust_remote_code=True,


tests/multimodal/test_sparse_tensor_validation_unit.py (+134, -0)

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for sparse tensor validation.

Simple, fast unit tests that can run without server fixtures.
Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v
"""

import io

import pytest
import torch


class TestSparseTensorValidationContextManager:
"""Test that torch.sparse.check_sparse_tensor_invariants() works as expected."""

def test_valid_sparse_tensor_passes(self):
"""Valid sparse tensors should pass validation."""
indices = torch.tensor([[0, 1], [0, 1]])
values = torch.tensor([1.0, 2.0])
shape = (2, 2)

with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.sparse_coo_tensor(indices, values, shape)
dense = tensor.to_dense()

assert dense.shape == shape

def test_out_of_bounds_indices_rejected(self):
"""Sparse tensors with out-of-bounds indices should be rejected."""
indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2
values = torch.tensor([1.0])
shape = (2, 2)

with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.sparse_coo_tensor(indices, values, shape)
tensor.to_dense()

assert (
"index" in str(exc_info.value).lower()
or "bound" in str(exc_info.value).lower()
)

def test_negative_indices_rejected(self):
"""Sparse tensors with negative indices should be rejected."""
indices = torch.tensor([[-1], [0]])
values = torch.tensor([1.0])
shape = (2, 2)

with pytest.raises(RuntimeError): # noqa: SIM117
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.sparse_coo_tensor(indices, values, shape)
tensor.to_dense()

def test_without_context_manager_allows_invalid(self):
"""
WITHOUT validation, invalid tensors may not immediately error.

This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate
by default, which can lead to memory corruption.
"""
indices = torch.tensor([[100], [100]]) # Way out of bounds
values = torch.tensor([1.0])
shape = (2, 2)

# Without validation context, this might create an invalid tensor
# (actual behavior depends on PyTorch version)
tensor = torch.sparse_coo_tensor(indices, values, shape)

# The tensor object is created, but it's invalid
assert tensor.is_sparse


class TestTorchLoadWithValidation:
"""Test torch.load() with sparse tensor validation."""

def test_load_valid_sparse_tensor_with_validation(self):
"""Valid sparse tensors should load successfully with validation."""
# Create and save a valid sparse tensor
indices = torch.tensor([[0, 1], [0, 1]])
values = torch.tensor([1.0, 2.0])
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))

buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)

# Load with validation
with torch.sparse.check_sparse_tensor_invariants():
loaded = torch.load(buffer, weights_only=True)
dense = loaded.to_dense()

assert dense.shape == (2, 2)

def test_load_invalid_sparse_tensor_rejected(self):
"""Invalid sparse tensors should be caught when loaded with validation."""
# Create an invalid sparse tensor (out of bounds)
indices = torch.tensor([[10], [10]])
values = torch.tensor([1.0])
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))

buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)

# Load with validation - should fail on to_dense()
with pytest.raises(RuntimeError): # noqa: SIM117
with torch.sparse.check_sparse_tensor_invariants():
loaded = torch.load(buffer, weights_only=True)
loaded.to_dense()

def test_load_dense_tensor_unaffected(self):
"""Dense tensors should work normally with the validation context."""
# Create and save a dense tensor
tensor = torch.randn(10, 20)

buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)

# Load with validation (should have no effect on dense tensors)
with torch.sparse.check_sparse_tensor_invariants():
loaded = torch.load(buffer, weights_only=True)

assert loaded.shape == (10, 20)
assert not loaded.is_sparse


if __name__ == "__main__":
# Allow running directly for quick testing
pytest.main([__file__, "-v", "--tb=short"])

tests/tool_parsers/__init__.py (+0, -0)


tests/tool_use/test_deepseekv31_tool_parser.py → tests/tool_parsers/test_deepseekv31_tool_parser.py

@@ -3,10 +3,10 @@

import pytest

from vllm.entrypoints.openai.tool_parsers.deepseekv31_tool_parser import (
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.deepseekv31_tool_parser import (
DeepSeekV31ToolParser,
)
from vllm.tokenizers import get_tokenizer

MODEL = "deepseek-ai/DeepSeek-V3.1"


tests/tool_use/test_ernie45_moe_tool_parser.py → tests/tool_parsers/test_ernie45_moe_tool_parser.py

@@ -13,9 +13,9 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.ernie45_tool_parser import Ernie45ToolParser

# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"

tests/tool_use/test_glm4_moe_tool_parser.py → tests/tool_parsers/test_glm4_moe_tool_parser.py

@@ -7,12 +7,10 @@ import json
import pytest

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.glm4_moe_tool_parser import (
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.glm4_moe_tool_parser import (
Glm4MoeModelToolParser,
)
from vllm.tokenizers import get_tokenizer

pytestmark = pytest.mark.cpu_test

pytest.skip("skip glm4_moe parser test", allow_module_level=True)
# Use a common model that is likely to be available

tests/tool_use/test_jamba_tool_parser.py → tests/tool_parsers/test_jamba_tool_parser.py

@@ -9,11 +9,9 @@ import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.jamba_tool_parser import JambaToolParser

MODEL = "ai21labs/Jamba-tiny-dev"


tests/tool_use/test_kimi_k2_tool_parser.py → tests/tool_parsers/test_kimi_k2_tool_parser.py

@@ -7,10 +7,8 @@ import json
import pytest

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser
from vllm.tokenizers import get_tokenizer

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser

# Use a common model that is likely to be available
MODEL = "moonshotai/Kimi-K2-Instruct"

tests/tool_use/test_minimax_tool_parser.py → tests/tool_parsers/test_minimax_tool_parser.py

@@ -12,10 +12,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.minimax_tool_parser import MinimaxToolParser
from vllm.tokenizers import get_tokenizer

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.minimax_tool_parser import MinimaxToolParser

# Use a common model that is likely to be available
MODEL = "MiniMaxAi/MiniMax-M1-40k"

tests/tool_use/test_mistral_tool_parser.py → tests/tool_parsers/test_mistral_tool_parser.py

@@ -12,10 +12,10 @@ from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser


@pytest.fixture(scope="module")

tests/tool_use/test_openai_tool_parser.py → tests/tool_parsers/test_openai_tool_parser.py

@@ -15,8 +15,8 @@ from openai_harmony import (
)

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.openai_tool_parser import OpenAIToolParser
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.openai_tool_parser import OpenAIToolParser

MODEL = "gpt2"


tests/tool_use/test_qwen3coder_tool_parser.py → tests/tool_parsers/test_qwen3coder_tool_parser.py

@@ -13,14 +13,12 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
)
from vllm.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser

MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"


tests/tool_use/test_seed_oss_tool_parser.py → tests/tool_parsers/test_seed_oss_tool_parser.py

@@ -14,11 +14,9 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.seed_oss_tool_parser import SeedOssToolParser

# Use a common model that is likely to be available
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"

tests/tool_use/test_xlam_tool_parser.py → tests/tool_parsers/test_xlam_tool_parser.py

@@ -12,11 +12,9 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally

pytestmark = pytest.mark.cpu_test
from vllm.tool_parsers.xlam_tool_parser import xLAMToolParser

# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"

+1 -1 tests/tool_use/test_tool_choice_required.py

@@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
from vllm.tool_parsers.utils import get_json_schema_from_tools

pytestmark = pytest.mark.cpu_test



+6 -6 tests/v1/kv_connector/unit/test_nixl_connector.py

@@ -461,7 +461,7 @@ class TestNixlHandshake:
metadata = NixlConnectorMetadata()
if num_xfers > 0:
num_xfers -= 1
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
kv_transfer_params={
@@ -532,7 +532,7 @@ class TestNixlHandshake:
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
@@ -588,7 +588,7 @@ class TestNixlHandshake:
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
@@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init):
# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
@@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init):

request_id = "test_handshake_fail"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
@@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):

request_id = "test_transfer_fail"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[7, 8, 9],
kv_transfer_params={


+10 -12 tests/v1/kv_offload/test_cpu_gpu.py

@@ -9,7 +9,7 @@ import torch
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers

BACKENDS_TO_TEST = [FlashAttentionBackend]

@@ -82,7 +82,7 @@ def test_transfer(

# create handler
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
handler = CpuGpuOffloadingHandler(
handlers = CpuGpuOffloadingHandlers(
attn_backends=attn_backends,
gpu_block_size=gpu_block_size,
cpu_block_size=cpu_block_size,
@@ -112,8 +112,7 @@ def test_transfer(

# set transfer direction
if gpu_to_cpu:
src_kv_caches = handler.gpu_tensors
dst_kv_caches = handler.cpu_tensors
handler = handlers.gpu_to_cpu_handler
src_spec_class = GPULoadStoreSpec
dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks
@@ -122,8 +121,7 @@ def test_transfer(
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
else:
src_kv_caches = handler.cpu_tensors
dst_kv_caches = handler.gpu_tensors
handler = handlers.cpu_to_gpu_handler
src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks
@@ -144,12 +142,12 @@ def test_transfer(
dst_spec = dst_spec_class(dst_blocks)

# clone src and dst tensors before transfer
orig_src_caches = [x.clone() for x in src_kv_caches]
orig_dst_caches = [x.clone() for x in dst_kv_caches]
orig_src_caches = [x.clone() for x in handler.src_tensors]
orig_dst_caches = [x.clone() for x in handler.dst_tensors]

# call transfer function
assert handler.transfer_async(1, (src_spec, dst_spec))
assert set(handler.transfer_events.keys()) == {1}
assert set({x[0] for x in handler._transfers}) == {1}

# wait for transfer to complete
end_time = time.time() + 10
@@ -161,15 +159,15 @@ def test_transfer(
time.sleep(0.1)

# verify src tensors did not change
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
assert torch.equal(orig_tensor, tensor)

# verify dst tensors
for dst_block in range(dst_size_in_gpu_blocks):
src_block_candidate = dst_to_src.get(dst_block)
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
src_kv_caches,
dst_kv_caches,
handler.src_tensors,
handler.dst_tensors,
orig_dst_caches,
handler.kv_dim_before_num_blocks,
):


+5 -68 vllm/attention/layer.py

@@ -3,7 +3,6 @@
"""Attention layer."""

import functools
from collections.abc import Callable
from typing import cast

import torch
@@ -17,6 +16,7 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@@ -49,58 +49,9 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
)

if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
else:
on_gfx9 = lambda *args, **kwargs: False


FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif (
attn_backend_override is None
and on_gfx9()
and attn_backend == AttentionBackendEnum.FLASH_ATTN
):
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda():
pass
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None

if attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else:
flash_attn_varlen_func = None

return attn_backend, flash_attn_varlen_func


def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
@@ -496,29 +447,15 @@ class MultiHeadAttention(nn.Module):
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(

self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)

self.attn_backend = (
backend
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
}
else AttentionBackendEnum.TORCH_SDPA
)

self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)

self.is_flash_attn_backend = self.attn_backend in {


+284 -0 vllm/attention/layers/mm_encoder_attention.py

@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend

logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    # At this point the attn_backend has already been resolved; any override
    # logic lives in the platform-specific implementation, so no backend
    # selection happens here. This helper only returns the matching
    # flash_attn_varlen_func (or None for backends without one).

if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None

# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func


@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
"""
Args:
num_heads: number of attention heads per partition.
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
            prefix: This has no effect; it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
            multimodal_config: multi-modal configuration.
"""
super().__init__()

self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix

assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()

# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend

# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)

logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

@classmethod
def enabled(cls) -> bool:
return True

def reshape_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 4D tensors:
(batch_size, seq_len, num_heads, head_size)
"""
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

return query, key, value

def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)

return query, key, value

def _forward_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None

bsz, q_len = query.size()[:2]
kv_len = key.size(1)

query, key, value = self.reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)

output = vit_torch_sdpa_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
)
return output

def _forward_fa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, (
"Flash attention function is not set."
)
        # TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None and max_seqlen is not None

bsz = query.shape[0]

output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
)
return output

def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
f"{self.attn_backend}."
)

def forward_cpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention

out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. "
            "Falling back to SDPA implementation."
        )
return self._forward_sdpa(query, key, value, cu_seqlens)
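
For orientation (illustration only, not part of the diff), here is a minimal plain-PyTorch sketch of the packed varlen convention that MMEncoderAttention's SDPA path and the reworked vit_torch_sdpa_wrapper implement; the head counts and per-image sequence lengths below are made-up values.

import torch
import torch.nn.functional as F

num_heads, head_size = 4, 8
seq_lens = [3, 5]                          # per-image patch counts (assumed)
cu_seqlens = torch.tensor([0, 3, 8])       # prefix sums, as the op expects

q = torch.randn(1, sum(seq_lens), num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Attend within each image independently, then concatenate along the sequence
# dimension; note the wrapper now returns (b, s, h, d) as well.
outs = []
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    qi, ki, vi = (t[:, start:end].transpose(1, 2) for t in (q, k, v))  # b h s d
    outs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(1, 2))
out = torch.cat(outs, dim=1)               # (1, sum(seq_lens), num_heads, head_size)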

+3 -8 vllm/attention/ops/vit_attn_wrappers.py

@@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer


@@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)


direct_register_custom_op(
@@ -106,7 +103,6 @@ def torch_sdpa_wrapper(
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer


@@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)


direct_register_custom_op(


+3 -1 vllm/benchmarks/serve.py

@@ -235,7 +235,9 @@ async def get_request(


def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float]
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float],
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.



+7 -3 vllm/config/compilation.py

@@ -932,9 +932,13 @@ class CompilationConfig:
self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
)
if (
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
):
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"


+1 -0 vllm/config/model.py

@@ -1796,6 +1796,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")),
("ForTokenClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),


+3 -1 vllm/config/scheduler.py

@@ -122,10 +122,12 @@ class SchedulerConfig:
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""

disable_hybrid_kv_cache_manager: bool = False
disable_hybrid_kv_cache_manager: bool | None = None
"""If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple types of attention layers
like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
"""

async_scheduling: bool = False


+60 -36 vllm/config/vllm.py

@@ -887,17 +887,48 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]

if not self.scheduler_config.disable_hybrid_kv_cache_manager:
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
# Hybrid KV cache manager (HMA) runtime rules:
        # - Explicit enable (--no-disable-hybrid-kv-cache-manager): raise an
        #   error if the runtime configuration requires disabling it
        # - No preference: auto-disable for unsupported features (e.g. kv connector)
        # - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it
need_disable_hybrid_kv_cache_manager = False
        # The warning should only be printed for hybrid models. Since we can't
        # yet tell whether the model is hybrid, we skip logging here and emit
        # the warning later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
need_disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
need_disable_hybrid_kv_cache_manager = True

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector for now.
# TODO(Kuntai): have a more elegent solution to check and
# turn off HMA for connector that does not support HMA.
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
@@ -905,33 +936,26 @@ class VllmConfig:
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py."
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
elif (
self.scheduler_config.disable_hybrid_kv_cache_manager is False
and need_disable_hybrid_kv_cache_manager
):
raise ValueError(
"Hybrid KV cache manager was explicitly enabled but is not "
"supported in this configuration. Consider omitting the "
"--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
" automatically."
)

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False

if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (

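To summarize the tri-state resolution above (illustration only, not part of the diff), the same decision written as a standalone helper; `user_pref` stands for the CLI tri-state (--disable-hybrid-kv-cache-manager / --no-disable-hybrid-kv-cache-manager / unset, i.e. None):

def resolve_hma_disabled(user_pref: bool | None, need_disable: bool) -> bool:
    # No preference: let vLLM decide based on platform/feature support.
    if user_pref is None:
        return need_disable
    # Explicit enable in an unsupported configuration: fail fast.
    if user_pref is False and need_disable:
        raise ValueError("Hybrid KV cache manager explicitly enabled but unsupported")
    # Otherwise the explicit choice is respected.
    return user_pref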

+56 -39 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
return compat_hash


@dataclass
class RemoteMeta:
block_ids: list[int]
host: str
port: int
engine_id: str
request_id: str


@dataclass
class ReqMeta:
local_block_ids: list[int]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: int
remote_engine_id: str
remote_request_id: str
tp_size: int
remote: RemoteMeta | None = None


class NixlConnectorMetadata(KVConnectorMetadata):
@@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()

def add_new_req(
def _add_new_req(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
save_to_host: bool = False,
):
# save and load are mutually exclusive
assert load_remote_cache ^ save_to_host
_req = ReqMeta(
) -> ReqMeta:
return ReqMeta(
local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
)
if save_to_host:
self.reqs_to_save[request_id] = _req
if load_remote_cache:
self.reqs_to_recv[request_id] = _req

def add_new_req_to_save(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
self.reqs_to_save[request_id] = self._add_new_req(
local_block_ids, kv_transfer_params
)

def add_new_req_to_recv(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
req = self._add_new_req(local_block_ids, kv_transfer_params)
req.remote = RemoteMeta(
block_ids=kv_transfer_params["remote_block_ids"],
engine_id=kv_transfer_params["remote_engine_id"],
request_id=kv_transfer_params["remote_request_id"],
host=kv_transfer_params["remote_host"],
port=kv_transfer_params["remote_port"],
)
self.reqs_to_recv[request_id] = req


class NixlConnector(KVConnectorBase_V1):
@@ -666,22 +683,18 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
meta.add_new_req_to_recv(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
)

for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=False,
save_to_host=True,
)

meta.reqs_to_send = self._reqs_need_send
@@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
# Do NIXL handshake in background and add to _ready_requests when done.
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
assert meta.remote is not None
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake,
meta.remote_host,
meta.remote_port,
meta.remote.host,
meta.remote.port,
meta.tp_size,
remote_engine_id,
)
@@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
assert meta is not None, f"{req_id} not found in recving_metadata list"
assert meta.remote is not None
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
@@ -1781,7 +1796,7 @@ class NixlConnectorWorker:

# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote_engine_id
meta.remote.engine_id
)
if (
not self.use_mla
@@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
assert meta.remote is not None
meta.remote.block_ids = self._logical_to_kernel_block_ids(
meta.remote.block_ids
)
remote_engine_id = meta.remote_engine_id
remote_engine_id = meta.remote.engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id,
remote_engine_id,
len(meta.local_physical_block_ids),
len(meta.remote_block_ids),
len(meta.remote.block_ids),
)
# always store metadata for failure recovery
self._recving_metadata[req_id] = meta
@@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time

def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote_engine_id,
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
remote_request_id=meta.remote_request_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_block_ids=meta.remote.block_ids,
)

def _read_blocks(

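For reference (illustrative, not part of the diff), this is how a scheduler-side caller exercises the split API after this change; the host, port, and block ids are placeholder values, and the kv_transfer_params keys mirror the ones read in add_new_req_to_recv above.

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorMetadata,
)

meta = NixlConnectorMetadata()

# D-side request that pulls KV blocks from a remote prefill worker.
meta.add_new_req_to_recv(
    request_id="req-0",
    local_block_ids=[1, 2, 3],
    kv_transfer_params={
        "remote_block_ids": [4, 5, 6],
        "remote_engine_id": "prefill-0",
        "remote_request_id": "req-0-prefill",
        "remote_host": "10.0.0.1",
        "remote_port": 5600,
    },
)
assert meta.reqs_to_recv["req-0"].remote is not None

# P-side request that only saves KV to the host buffer; no RemoteMeta attached.
meta.add_new_req_to_save(
    request_id="req-1",
    local_block_ids=[7, 8, 9],
    kv_transfer_params={},
)
assert meta.reqs_to_save["req-1"].remote is None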

+1 -1 vllm/engine/arg_utils.py

@@ -491,7 +491,7 @@ class EngineArgs:
enable_chunked_prefill: bool | None = None
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

disable_hybrid_kv_cache_manager: bool = (
disable_hybrid_kv_cache_manager: bool | None = (
SchedulerConfig.disable_hybrid_kv_cache_manager
)



+20 -6 vllm/entrypoints/chat_utils.py

@@ -24,6 +24,7 @@ from openai.types.chat import (
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam,
ChatCompletionFunctionToolParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
@@ -269,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
reasoning: str | None
"""The reasoning content for interleaved thinking."""

tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""


ChatCompletionMessageParam: TypeAlias = (
OpenAIChatCompletionMessageParam
@@ -300,6 +304,9 @@ class ConversationMessage(TypedDict, total=False):
reasoning_content: str | None
"""Deprecated: The reasoning content for interleaved thinking."""

tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""


# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
@@ -1619,6 +1626,8 @@ def _parse_chat_message_content(
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]

if role == "developer":
result_msg["tools"] = message.get("tools", None)
return result


@@ -1629,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (
message["role"] == "assistant"
and "tool_calls" in message
and isinstance(message["tool_calls"], list)
):
for item in message["tool_calls"]:
if message["role"] == "assistant" and "tool_calls" in message:
tool_calls = message.get("tool_calls")
if not isinstance(tool_calls, list):
continue

if len(tool_calls) == 0:
# Drop empty tool_calls to keep templates on the normal assistant path.
message.pop("tool_calls", None)
continue

for item in tool_calls:
# if arguments is None or empty string, set to {}
if content := item["function"].get("arguments"):
if not isinstance(content, (dict, list)):

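A small before/after illustration of the new behaviour (not part of the diff; _postprocess_messages is the internal helper shown above):

messages = [
    {"role": "assistant", "content": "Hi there!", "tool_calls": []},
]
# After _postprocess_messages(messages), the empty list is removed so chat
# templates treat this as a plain assistant reply:
#   [{"role": "assistant", "content": "Hi there!"}]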

+8 -8 vllm/entrypoints/context.py

@@ -34,13 +34,13 @@ from vllm.entrypoints.openai.protocol import (
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

@@ -74,24 +74,24 @@ class TurnMetrics:

def __init__(
self,
input_tokens=0,
output_tokens=0,
cached_input_tokens=0,
tool_output_tokens=0,
):
input_tokens: int = 0,
output_tokens: int = 0,
cached_input_tokens: int = 0,
tool_output_tokens: int = 0,
) -> None:
self.input_tokens = input_tokens
self.output_tokens = output_tokens
self.cached_input_tokens = cached_input_tokens
self.tool_output_tokens = tool_output_tokens

def reset(self):
def reset(self) -> None:
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
self.cached_input_tokens = 0
self.tool_output_tokens = 0

def copy(self):
def copy(self) -> "TurnMetrics":
"""Create a copy of this turn's token counts."""
return TurnMetrics(
self.input_tokens,

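A short usage sketch of the newly annotated API (illustrative, not part of the diff; assumes vLLM's OpenAI entrypoint dependencies are installed so the module imports cleanly):

from vllm.entrypoints.context import TurnMetrics

m = TurnMetrics(input_tokens=128, output_tokens=32)
snapshot = m.copy()      # -> TurnMetrics, per the new return annotation
m.reset()                # -> None; counters go back to zero
assert (m.input_tokens, snapshot.input_tokens) == (0, 128)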

+1 -1 vllm/entrypoints/openai/api_server.py

@@ -72,7 +72,6 @@ from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
@@ -95,6 +94,7 @@ from vllm.entrypoints.utils import (
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS
from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.gc_utils import freeze_gc_heap


+1 -1 vllm/entrypoints/openai/cli_args.py

@@ -27,8 +27,8 @@ from vllm.entrypoints.constants import (
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.tool_parsers import ToolParserManager
from vllm.utils.argparse_utils import FlexibleArgumentParser

logger = init_logger(__name__)


+1 -1 vllm/entrypoints/openai/parser/responses_parser.py

@@ -12,10 +12,10 @@ from openai.types.responses.response_reasoning_item import (
)

from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid



+9 -0 vllm/entrypoints/openai/protocol.py

@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel):
max_tool_calls: int | None = None
metadata: Metadata | None = None
model: str | None = None
logit_bias: dict[str, float] | None = None
parallel_tool_calls: bool | None = True
previous_response_id: str | None = None
prompt: ResponsePrompt | None = None
@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel):
tools: list[Tool] = Field(default_factory=list)
top_logprobs: int | None = 0
top_p: float | None = None
top_k: int | None = None
truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None

@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel):
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
}

def to_sampling_params(
@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel):
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
stop_token_ids = default_sampling_params.get("stop_token_ids")

# Structured output
@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids,
@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel):
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
)

def is_include_output_logprobs(self) -> bool:

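As a usage note (illustrative; the model name and values are placeholders), the two new knobs are accepted on the Responses API request body and forwarded into SamplingParams by to_sampling_params:

request_body = {
    "model": "my-model",              # placeholder
    "top_k": 20,                      # new: forwarded to SamplingParams.top_k
    "logit_bias": {"1234": -100.0},   # new: token-id -> bias, forwarded as-is
}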

+2 -2 vllm/entrypoints/openai/serving_chat.py

@@ -57,8 +57,6 @@ from vllm.entrypoints.openai.serving_engine import (
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
@@ -73,6 +71,8 @@ from vllm.tokenizers.mistral import (
truncate_tool_call_ids,
validate_request_params,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters



+2 -2 vllm/entrypoints/openai/serving_engine.py

@@ -59,7 +59,6 @@ from vllm.entrypoints.openai.protocol import (
TranslationRequest,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
@@ -102,8 +101,9 @@ from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.deepseekv32 import DeepseekV32Tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,


+23 -140 vllm/entrypoints/openai/tool_parsers/__init__.py

@@ -1,150 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
import warnings

__all__ = ["ToolParser", "ToolParserManager"]

def __getattr__(name: str):
if name == "ToolParser":
from vllm.tool_parsers import ToolParser

"""
Register a lazy module mapping.
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to "
"`vllm.tool_parsers.ToolParser`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)

Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
return ToolParser
if name == "ToolParserManager":
from vllm.tool_parsers import ToolParserManager

warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParserManager` "
"has been moved to `vllm.tool_parsers.ToolParserManager`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)

_TOOL_PARSERS_TO_REGISTER = {
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}
return ToolParserManager


def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)


register_lazy_tool_parsers()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
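
A quick check of the shim's behaviour (illustrative, not part of the diff): the old import path still resolves through the module-level __getattr__ above, but now emits a DeprecationWarning.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated location, kept working by the shim.
    from vllm.entrypoints.openai.tool_parsers import ToolParserManager
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New canonical location.
from vllm.tool_parsers import ToolParserManager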

+1 -0 vllm/entrypoints/pooling/score/protocol.py

@@ -120,6 +120,7 @@ class RerankResult(BaseModel):


class RerankUsage(BaseModel):
prompt_tokens: int
total_tokens: int




+3 -1 vllm/entrypoints/pooling/score/serving.py

@@ -502,5 +502,7 @@ class ServingScores(OpenAIServing):
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens),
usage=RerankUsage(
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
),
)

+14 -11 vllm/entrypoints/renderer.py

@@ -167,17 +167,20 @@ class BaseRenderer(ABC):
)

def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
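
For context (illustrative, plain PyTorch, not part of the diff), this is the class of payload the check_sparse_tensor_invariants guard rejects; the exact exception type and message may vary across PyTorch versions.

import torch

with torch.sparse.check_sparse_tensor_invariants():
    try:
        # Index 10 is out of bounds for a tensor of size (2,). With invariant
        # checks enabled, construction fails up front instead of allowing an
        # out-of-bounds write when the maliciously crafted tensor is densified.
        torch.sparse_coo_tensor(
            indices=torch.tensor([[10]]),
            values=torch.tensor([1.0]),
            size=(2,),
        )
    except Exception as exc:
        print(f"rejected malformed sparse tensor: {exc}")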


+7 -2 vllm/model_executor/custom_op.py

@@ -38,8 +38,9 @@ class CustomOp(nn.Module):
)
return super().__new__(op_cls_to_instantiate)

def __init__(self):
def __init__(self, enforce_enable: bool = False):
super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward()

def forward(self, *args, **kwargs):
@@ -84,7 +85,11 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_cached_compilation_config()
enabled = self.enabled()

# CustomOp object can be enforce enabled, e.g., enable device-specific
# kernels in ViT models when enabling graph mode. By default, it will
# follow the compilation_config to determine whether enable itself.
enabled = self._enforce_enable or self.enabled()
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
else:

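A minimal sketch (hypothetical op, not part of the diff) of how a CustomOp subclass opts into the new flag:

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("demo_scale")  # hypothetical op name
class DemoScale(CustomOp):
    """Toy op that doubles its input."""

    def __init__(self):
        # New: force-enable so the device-specific forward is dispatched even
        # when the compilation config would otherwise disable this custom op.
        super().__init__(enforce_enable=True)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2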

+0 -2 vllm/model_executor/layers/fused_moe/cutlass_moe.py

@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
parallel_config=None,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
c_strides2=c_strides2,
quant_config=quant_config,
),
parallel_config=parallel_config,
)

return fn(


+1 -1 vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input=False,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer


+1 -6 vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py

@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None,
) -> "FusedMoEModularMethod":
parallel_config = getattr(
getattr(moe_layer, "vllm_config", None),
"parallel_config",
None,
)
return FusedMoEModularMethod(
old_quant_method,
FusedMoEModularKernel(
@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
getattr(moe_layer, "shared_experts_stream", None),
parallel_config=parallel_config,
moe_parallel_config=moe_layer.moe_parallel_config,
),
)



+13 -8 vllm/model_executor/layers/fused_moe/modular_kernel.py

@@ -10,10 +10,12 @@ from typing import final
import torch

import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
count_expert_num_tokens,
@@ -681,7 +683,7 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None,
parallel_config: ParallelConfig | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
@@ -689,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream

# cache whether this worker is using DP+EP
if parallel_config is None:
parallel_config = get_current_vllm_config().parallel_config
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
# running in a non-DP+EP context
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = (
parallel_config.data_parallel_size > 1
and parallel_config.enable_expert_parallel
moe_parallel_config is not None
and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep
)

self._post_init_setup()
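
Restated in isolation (illustration only, not part of the diff): the kernel no longer consults the global ParallelConfig and instead derives DP+EP purely from the FusedMoEParallelConfig handed down by the layer, defaulting to non-DP+EP when none is given.

def infer_is_dp_ep(moe_parallel_config) -> bool:
    return (
        moe_parallel_config is not None
        and moe_parallel_config.dp_size > 1
        and moe_parallel_config.use_ep
    )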


+0 -3 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
parallel_config=getattr(
getattr(layer, "vllm_config", None), "parallel_config", None
),
)

else:


+154 -89 vllm/model_executor/layers/quantization/fp8.py

@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
moe_quant_method = Fp8MoEMethod(self, layer)
if self.is_checkpoint_fp8_serialized:
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.orig_dtype = params_dtype
layer.weight_block_size = None

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
assert self.quant_config.is_checkpoint_fp8_serialized
params_dtype = torch.float8_e4m3fn

if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"weight quantization block_k = {block_k}."
)

# if we are doing online quantization, patch the weight
# loaded to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded
if not self.quant_config.is_checkpoint_fp8_serialized:
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]

# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)

w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)

# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
replace_parameter(
layer,
"w13_weight_scale",
torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device,
),
)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)

if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)

replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return result


class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.

Args:
quant_config: The quantization config.
"""

def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(quant_config, layer)
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
assert quant_config.weight_block_size is None
assert self.flashinfer_moe_backend is None

def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None

# We are doing online quantization, patch the weight loaded
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]

# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

layer.w13_input_scale = None
layer.w2_input_scale = None

self.rocm_aiter_moe_enabled = False

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# Resolved lazily here to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

# The checkpoint is FP16/BF16; quantize the weights to FP8 here.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)

# Reshuffle weights for AITER if needed.
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)

# Repack weights for MARLIN if needed.
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
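
Note: the patched loader above triggers per-expert FP8 quantization as soon as the
final weight chunk lands, instead of waiting for a separate post-loading pass. A
minimal sketch of the same counter-based trigger, using hypothetical helper names
(this is not a vLLM API):

import torch


def make_streaming_loader(layer, base_loader, finalize):
    """Wrap base_loader so finalize(layer) fires once every element has loaded.

    Hypothetical helper mirroring the counter-based trigger above.
    """
    target = layer.w13_weight.numel() + layer.w2_weight.numel()

    def loader(param, chunk: torch.Tensor, *args, **kwargs):
        base_loader(param, chunk, *args, **kwargs)
        layer._loaded = getattr(layer, "_loaded", 0) + chunk.numel()
        if layer._loaded == target:
            del layer._loaded
            finalize(layer)           # e.g. per-expert ops.scaled_fp8_quant
            layer._finalized = True   # lets the usual post-loading hook no-op

    return loader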


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.


+1 -6  vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
assert quant_config is not None

# Construct modular kernel with block-scale support when requested.
parallel_config = getattr(
getattr(layer, "vllm_config", None),
"parallel_config",
None,
)
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8(
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
parallel_config=parallel_config,
moe_parallel_config=layer.moe_parallel_config,
)

return fused_experts(


+12 -0  vllm/model_executor/models/adapters.py

@@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
tokens = getattr(text_config, "classifier_from_token", None)
method = getattr(text_config, "method", None)

def auto_set_score_bias(weights):
for name, weight in weights:
if name == "score.bias":
device = self.score.weight.device
dtype = self.score.weight.dtype
bias = weight.to(device).to(dtype)
self.score.bias = torch.nn.Parameter(bias)
self.score.skip_bias_add = False
else:
yield name, weight

weights = auto_set_score_bias(weights)
if tokens is None and method is None:
return super().load_weights(weights)
else:
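
Note: auto_set_score_bias is a filtering generator; it consumes the score.bias
entry (attaching it directly to the classification head) and forwards every other
(name, weight) pair untouched, so the regular loader never sees the bias. A generic
sketch of the pattern, with illustrative names:

from collections.abc import Iterable, Iterator

import torch


def intercept(
    weights: Iterable[tuple[str, torch.Tensor]],
    target: str,
    on_hit,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield every (name, tensor) pair except `target`, which goes to on_hit."""
    for name, tensor in weights:
        if name == target:
            on_hit(tensor)
        else:
            yield name, tensor

# usage: weights = intercept(weights, "score.bias", handle_bias)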


+584 -0  vllm/model_executor/models/bagel.py

@@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.

BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema

from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)

logger = init_logger(__name__)


class BagelImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""

type: Literal["pixel_values"]
pixel_values: torch.Tensor # Shape: (bn, 3, h, w)


BagelImageInputs: TypeAlias = BagelImagePixelInputs


class BagelVisionMLP(nn.Module):
"""MLP connector for vision features."""

def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
act_layer: str = "gelu_pytorch_tanh",
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.act = get_act_fn(act_layer)
self.fc2 = RowParallelLinear(
hidden_features,
out_features,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
x = self.act(x)
x, _ = self.fc2(x)
return x


class PositionEmbedding(nn.Module):
"""2D position embedding for vision tokens using sin-cos embeddings."""

def __init__(self, max_num_patch_per_side: int, hidden_size: int):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size

# Create fixed (non-learnable) 2D sin-cos position embeddings
pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
self.register_buffer(
"pos_embed",
torch.from_numpy(pos_embed).float(),
persistent=False,
)

@staticmethod
def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
"""Generate 2D sin-cos position embeddings."""
import numpy as np

grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid(
embed_dim, grid
)
return pos_embed

@staticmethod
def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
"""Generate 2D sin-cos position embeddings from grid."""
import numpy as np

assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0]
)
emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1]
)
emb = np.concatenate([emb_h, emb_w], axis=1)
return emb

@staticmethod
def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
"""Generate 1D sin-cos position embeddings."""
import numpy as np

assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega

pos = pos.reshape(-1)
out = np.einsum("m,d->md", pos, omega)

emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb

def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
"""
Args:
position_ids: Flattened position IDs, shape (N,) where each ID
corresponds to a position in the flattened grid
Returns:
Position embeddings of shape (N, hidden_size)
"""
# Ensure position_ids are on the same device as pos_embed
position_ids = position_ids.to(self.pos_embed.device)
return self.pos_embed[position_ids]
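
Note: the frozen table built above is the standard 2D sin-cos construction; for
feature index i and position p, omega_i = 1 / 10000^(i / (D/2)), and the embedding
concatenates sin(p * omega) with cos(p * omega), applied once to the row coordinate
and once to the column coordinate. A compact sketch of the 1D case, mirroring
_get_1d_sincos_pos_embed_from_grid:

import numpy as np


def sincos_1d(embed_dim: int, positions: np.ndarray) -> np.ndarray:
    """(N,) positions -> (N, embed_dim) sin-cos table; embed_dim must be even."""
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    angles = np.outer(positions.reshape(-1), omega)                   # (N, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)   # (N, D)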


class BagelProcessingInfo(BaseProcessingInfo):
"""Processing information for BAGEL model."""

def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
from vllm.transformers_utils.processor import cached_get_image_processor

image_processor = cached_get_image_processor(
self.ctx.model_config.model,
revision=self.ctx.model_config.revision,
trust_remote_code=self.ctx.model_config.trust_remote_code,
)

tokenizer = self.get_tokenizer()

return BagelProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
**kwargs,
)

def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}

def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
# Calculate max tokens per image
# For BAGEL: (vit_max_num_patch_per_side) ** 2
max_num_patches = hf_config.vit_max_num_patch_per_side**2
return {"image": max_num_patches}

def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vit_config = hf_config.vit_config
patch_size = vit_config.patch_size

# Calculate number of patches
num_patches_h = image_height // patch_size
num_patches_w = image_width // patch_size
return num_patches_h * num_patches_w
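
Note: the per-image token count is simply the patch grid size, while the profiling
maximum uses the squared maximum grid side. The worked numbers below are assumptions
for illustration, not BAGEL's actual configuration:

# Illustrative arithmetic only; these values are assumed, not BAGEL's config.
patch_size = 16
image_height = image_width = 448
num_tokens = (image_height // patch_size) * (image_width // patch_size)  # 28 * 28 = 784

vit_max_num_patch_per_side = 70
max_tokens_per_image = vit_max_num_patch_per_side**2                     # 70 * 70 = 4900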


class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
"""Build dummy inputs for BAGEL model profiling."""

def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
# Use a simple placeholder for each image
return "<|image_pad|>" * num_images

def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
hf_config = self.info.get_hf_config()
vit_config = hf_config.vit_config

# Use the configured image size
image_size = vit_config.image_size
image_overrides = mm_options.get("image") if mm_options else None

return {
"image": self._get_dummy_images(
width=image_size,
height=image_size,
num_images=num_images,
overrides=image_overrides,
),
}


class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
"""Multimodal processor for BAGEL model."""

def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False

def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptReplacement]:
"""Replace image placeholders with the correct number of tokens."""
hf_config = self.info.get_hf_config()

# Get the tokenizer to look up the image token ID
tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.get_vocab().get("<|image_pad|>")
if image_token_id is None:
raise ValueError(
"Image token '<|image_pad|>' not found in tokenizer vocabulary"
)

def get_replacement_bagel(item_idx: int):
# For BAGEL, calculate number of tokens based on max patch size
num_tokens = hf_config.vit_max_num_patch_per_side**2
# Use the image token ID from tokenizer
return [image_token_id] * num_tokens

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_bagel,
)
]

def _get_mm_fields_config(
self,
hf_inputs: Any,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return {
"pixel_values": MultiModalFieldConfig.batched("image"),
}


@MULTIMODAL_REGISTRY.register_processor(
BagelMultiModalProcessor,
info=BagelProcessingInfo,
dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
"""
BAGEL: A unified multimodal model for image understanding and generation.

For vLLM, we focus on the image understanding (vision-to-text) capabilities.
The image generation part is not supported in vLLM.
"""

# Weight mapping from HF to vLLM
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.": "language_model.",
"vit_model.": "vit_model.",
"connector.": "connector.",
"vit_pos_embed.": "vit_pos_embed.",
}
)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

# Ensure we have a BagelConfig (check by name to handle trust_remote_code)
# When trust_remote_code=True, the config comes from transformers_modules
if type(config).__name__ != "BagelConfig":
raise ValueError(
f"Expected BagelConfig, got {type(config).__name__}. "
"Make sure the model config is properly loaded."
)

self.config = config
self.multimodal_config = multimodal_config

# Initialize language model (Qwen2)
# Pass the llm_config from BagelConfig to initialize Qwen2 properly
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.llm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)

# Initialize vision model (SigLIP) if visual understanding is enabled
if config.visual_und:
# Fix vit_config: checkpoint has 26 layers (0-25) but config says 27
# Also disable head as it's not in checkpoint
vit_config = config.vit_config
if vit_config.num_hidden_layers == 27:
logger.warning(
"Overriding vit_config.num_hidden_layers from 27 to 26 "
"to match the Bagel model checkpoint."
)
vit_config.num_hidden_layers = 26
if not hasattr(vit_config, "vision_use_head"):
logger.warning(
"Setting vit_config.vision_use_head to False as it is not "
"present in the Bagel model checkpoint."
)
vit_config.vision_use_head = False

self.vit_model = SiglipVisionModel(
config=vit_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vit_model"),
)

# Initialize connector (MLP)
vit_hidden_size = config.vit_config.hidden_size
llm_hidden_size = config.llm_config.hidden_size

self.connector = BagelVisionMLP(
in_features=vit_hidden_size,
hidden_features=llm_hidden_size,
out_features=llm_hidden_size,
act_layer=config.connector_act,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "connector"),
)

# Position embedding for vision tokens
self.vit_pos_embed = PositionEmbedding(
max_num_patch_per_side=config.vit_max_num_patch_per_side,
hidden_size=llm_hidden_size,
)
else:
self.vit_model = None
self.connector = None
self.vit_pos_embed = None

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)

def _parse_and_validate_image_input(
self, **kwargs: object
) -> BagelImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)

if pixel_values is None:
return None

return BagelImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
)

def _process_image_input(
self, image_input: BagelImageInputs
) -> tuple[torch.Tensor, ...]:
"""Process image inputs through vision encoder and connector."""
pixel_values = image_input["pixel_values"]

# Handle potential extra batch dimension
# Expected shape: (batch_size * num_images, 3, H, W)
# But might receive: (batch_size, num_images, 3, H, W)
if pixel_values.ndim == 5:
# Flatten batch and num_images dimensions
batch_size, num_images, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_images, channels, height, width
)

# Get vision features from SigLIP
# pixel_values shape: (batch_size * num_images, 3, H, W)
vision_features = self.vit_model(pixel_values)

# Pass through connector
vision_embeds = self.connector(vision_features)

# Add position embeddings
batch_size, num_patches, hidden_size = vision_embeds.shape
patch_size = self.config.vit_config.patch_size
image_size = self.config.vit_config.image_size

# Calculate grid dimensions
num_patches_per_side = image_size // patch_size

# Create flattened position IDs (0 to num_patches-1)
# For BAGEL, we use extrapolate mode by default
h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
w_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
position_ids = (
h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords
).flatten()
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten()

# Add position embeddings
pos_embeds = self.vit_pos_embed(position_ids)
pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size)
# Ensure pos_embeds are on the same device as vision_embeds
pos_embeds = pos_embeds.to(vision_embeds.device)
vision_embeds = vision_embeds + pos_embeds

# Split by image
return tuple(vision_embeds)
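
Note: the position IDs above index into a square (max_per_side x max_per_side) table
using the maximum grid width as the row stride, so a smaller image grid maps onto the
top-left block of that table. Illustrative arithmetic with assumed sizes:

import torch

max_per_side = 70                      # assumed vit_max_num_patch_per_side
grid = 4                               # patches per side of a hypothetical small image
coords = torch.arange(grid)
position_ids = (coords[:, None] * max_per_side + coords).flatten()
# -> 0..3, 70..73, 140..143, 210..213: the top-left 4x4 block of the 70x70 table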

def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
"""Get multimodal embeddings from input."""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []

return self._process_image_input(image_input)

def get_language_model(self) -> nn.Module:
return self.language_model

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for BAGEL.

Args:
input_ids: Flattened (concatenated) input_ids corresponding to a batch.
positions: Flattened (concatenated) position ids corresponding to a batch.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""
if intermediate_tensors is not None:
inputs_embeds = None

hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights from checkpoint."""
skip_prefixes = []
# Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
skip_prefixes.append("vit_pos_embed.pos_embed")

# If visual understanding is disabled, skip vision-related weights
if self.vit_model is None:
skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])

# Skip generation-related weights since we only support text2text and image2text
# Filter out all image generation components:
# - 'moe_gen': MoE generation weights
# - 'latent_pos_embed': Latent position embeddings for VAE
# - 'llm2vae', 'vae2llm': LLM-VAE projections
# - 'time_embedder': Timestep embeddings for diffusion
# - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder
generation_keywords = [
"moe_gen",
"latent_pos_embed",
"llm2vae",
"vae2llm",
"time_embedder",
]
vae_prefixes = [
"decoder.",
"encoder.",
] # VAE encoder/decoder, not vision encoder
filtered_weights = []
for name, tensor in weights:
# Skip generation-related keywords
if any(skip in name for skip in generation_keywords):
continue
if any(name.startswith(prefix) for prefix in vae_prefixes):
continue

if "patch_embedding.weight" in name and tensor.ndim == 2:
out_channels = tensor.shape[0]
in_features = tensor.shape[1]
patch_size = self.config.vit_config.patch_size
in_channels = self.config.vit_config.num_channels
if in_features == in_channels * patch_size * patch_size:
tensor = tensor.reshape(
out_channels, patch_size, patch_size, in_channels
)
tensor = tensor.permute(0, 3, 1, 2).contiguous()

filtered_weights.append((name, tensor))

loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)
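
Note: the patch_embedding.weight branch converts a flattened linear projection of
shape (out_channels, in_channels * p * p), whose second axis is stored as
(p, p, in_channels), into the (out_channels, in_channels, p, p) layout a Conv2d patch
embedding expects. A standalone sketch with hypothetical shapes:

import torch


def linear_to_conv_patch_weight(w: torch.Tensor, patch: int, in_ch: int) -> torch.Tensor:
    """(out, in_ch*patch*patch) stored as (patch, patch, in_ch) -> (out, in_ch, patch, patch)."""
    out_ch = w.shape[0]
    assert w.shape[1] == in_ch * patch * patch
    return w.reshape(out_ch, patch, patch, in_ch).permute(0, 3, 1, 2).contiguous()


w = torch.randn(1152, 14 * 14 * 3)                   # hypothetical SigLIP-like shapes
print(linear_to_conv_patch_weight(w, 14, 3).shape)   # torch.Size([1152, 3, 14, 14])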

+46 -83  vllm/model_executor/models/dots_ocr.py

@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
@@ -254,11 +253,15 @@ class DotsVisionAttention(nn.Module):
bias: bool = True,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)

self.embed_dim = dim
self.tp_size = (
@@ -287,31 +290,13 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}"
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def forward(
self,
@@ -319,7 +304,7 @@ class DotsVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@@ -336,41 +321,13 @@ class DotsVisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
output = self.flash_attn_varlen_func(
q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = output.view(
bs,
-1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1])
e = int(cu_seqlens[i])
q_i = q[:, s:e].permute(0, 2, 1, 3)
k_i = k[:, s:e].permute(0, 2, 1, 3)
v_i = v[:, s:e].permute(0, 2, 1, 3)
out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
else:
raise RuntimeError("Unsupported attention backend")
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

# [B,S,H,D] -> [S,B,H*D] -> [S, C]
context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
@@ -385,14 +342,19 @@ class DotsSwiGLUFFN(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.embed_dim
bias = config.use_bias

use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
# Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear(
in_features,
@@ -498,9 +460,8 @@ class DotsVisionBlock(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

@@ -510,16 +471,15 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads,
bias=config.use_bias,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

@@ -546,12 +506,11 @@ class DotsVisionTransformer(nn.Module):
self,
config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.config = config
@@ -561,6 +520,11 @@ class DotsVisionTransformer(nn.Module):

head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -578,9 +542,8 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for i in range(num_layers)
]
@@ -592,6 +555,11 @@ class DotsVisionTransformer(nn.Module):
else:
self.post_trunk_norm = None

use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger(
dim=config.hidden_size,
context_dim=config.embed_dim,
@@ -647,7 +615,7 @@ class DotsVisionTransformer(nn.Module):
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
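
Note: dropping .item() keeps max_seqlen as a 0-dim tensor on the same device as
cu_seqlens, which presumably avoids a host-device synchronization per call; the
updated signatures accordingly take torch.Tensor | None. A tiny illustration:

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])         # illustrative cumulative lengths
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]      # tensor([3, 4, 5])
max_seqlen = seq_lens.max()                      # 0-dim tensor(5); stays on device
max_seqlen_int = seq_lens.max().item()           # Python int 5; syncs to host on GPU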

def forward(
@@ -733,17 +701,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)

self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,


+28 -80  vllm/model_executor/models/ernie45_vl.py

@@ -37,10 +37,10 @@ from einops import rearrange, repeat
from transformers import BatchFeature

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -163,8 +163,8 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -193,33 +193,13 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -253,14 +233,13 @@ class Ernie4_5_VisionAttention(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
@@ -268,43 +247,14 @@ class Ernie4_5_VisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)

context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []

lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output
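
Note: this hunk is representative of the whole series (dots_ocr, ernie45_vl,
glm4_1v, keye, paddleocr_vl): each vision attention module drops its hand-rolled
backend selection and per-sequence SDPA/flash dispatch in favor of a single
MMEncoderAttention call. The kind of loop being deleted looks roughly like the
generic sketch below (illustrative, not the exact removed code):

import torch
import torch.nn.functional as F


def sdpa_varlen(q, k, v, cu_seqlens):
    """Per-sequence SDPA over packed (b, s, h, d) tensors."""
    outs = []
    for i in range(1, len(cu_seqlens)):
        s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        qi, ki, vi = (t[:, s:e].transpose(1, 2) for t in (q, k, v))   # (b, h, s, d)
        outs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(1, 2))
    return torch.cat(outs, dim=1)                                     # (b, s, h, d)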
@@ -350,8 +300,8 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -366,8 +316,8 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)

self.mlp = Ernie4_5_VisionMLP(
@@ -383,7 +333,7 @@ class Ernie4_5_VisionBlock(nn.Module):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
@@ -441,8 +391,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
@@ -477,8 +427,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -489,6 +439,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
)
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -535,13 +488,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb

def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def forward(
@@ -1304,17 +1257,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
)

self.language_model = Ernie4_5_VLMoeForCausalLM(


+46 -91  vllm/model_executor/models/glm4_1v.py

@@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
@@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
@@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
@@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module):
disable_tp=use_data_parallel,
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
multimodal_config=multimodal_config,
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now."
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
@@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)

context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []

lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output
@@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Glm4vVisionMLP(
dim,
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module):
d_model: int,
context_dim: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = d_model
self.proj = ColumnParallelLinear(
self.hidden_size,
@@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

assert multimodal_config is not None, "multimodal_config must be provided"

patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel

self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
@@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False,
prefix=f"{prefix}.merger",
use_data_parallel=self.use_data_parallel,
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)

@@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
attn_backend_override=multimodal_config.mm_encoder_attn_backend,
)

@property
@@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> int | None:
) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def forward(
@@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)

if config.model_type == "glm4v":


+27 -80  vllm/model_executor/models/keye.py

@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -80,7 +78,6 @@ from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import get_vit_attn_backend

logger = init_logger(__name__)

@@ -369,8 +366,8 @@ class KeyeSiglipAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -408,34 +405,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj",
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def forward(
self,
hidden_states: torch.Tensor,
@@ -450,8 +427,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1,
)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = q.shape[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
@@ -482,38 +458,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]

context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")

output, _ = self.out_proj(context_layer)
return output
@@ -547,8 +499,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -556,8 +508,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -601,8 +553,8 @@ class KeyeSiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -614,8 +566,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -696,8 +648,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -707,8 +659,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -779,16 +731,16 @@ class KeyeSiglipVisionModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config

@@ -1329,16 +1281,11 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)

self.mlp_AR = self._build_projector(


+1 -7  vllm/model_executor/models/opencua.py

@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)

if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None


+6 -16  vllm/model_executor/models/ovis2_5.py

@@ -10,8 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
visual_vocab_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix,
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

@@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "llm"),
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
attn_backend_override=attn_backend_override,
)

self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)


+30 -75  vllm/model_executor/models/paddleocr_vl.py

@@ -22,7 +22,6 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
@@ -32,13 +31,10 @@ from transformers.modeling_outputs import (
from transformers.utils import torch_int

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -578,9 +574,8 @@ class SiglipAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -608,18 +603,12 @@ class SiglipAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)

self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape
@@ -665,44 +654,16 @@ class SiglipAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
if max_seqlen is None:
raise ValueError("Flash attention backend requires max_seqlen.")
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(tensor, "b s h d -> b h s d")
for tensor in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")

output, _ = self.out_proj(context_layer)
output = rearrange(output, "s b d -> b s d")
return output


@@ -774,10 +735,8 @@ class SiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -787,9 +746,8 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -832,14 +790,18 @@ class SiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -858,9 +820,8 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -941,8 +902,8 @@ class SiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -952,8 +913,8 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -991,16 +952,16 @@ class SiglipVisionModel(nn.Module):
self,
config,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config

@@ -1119,17 +1080,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)

self.visual = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = Projector(config, config.vision_config)
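
For reference, the per-model TORCH_SDPA branch removed above (and now owned by the shared MMEncoderAttention layer) amounts to splitting the packed batch by cu_seqlens and running scaled_dot_product_attention per sequence. A minimal standalone sketch in plain PyTorch, with illustrative shapes and a hypothetical function name, not the vLLM API itself:

import torch
import torch.nn.functional as F

def sdpa_varlen(q, k, v, cu_seqlens):
    # q, k, v: (1, total_tokens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,)
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    outputs = []
    for q_i, k_i, v_i in zip(*(t.split(lens, dim=1) for t in (q, k, v))):
        # (b, s, h, d) -> (b, h, s, d) as expected by scaled_dot_product_attention
        q_i, k_i, v_i = (x.transpose(1, 2) for x in (q_i, k_i, v_i))
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out.transpose(1, 2))  # back to (b, s, h, d)
    return torch.cat(outputs, dim=1)         # (1, total_tokens, num_heads, head_dim)

# Two packed sequences of length 3 and 5, 4 heads, head_dim 8.
q = k = v = torch.randn(1, 8, 4, 8)
print(sdpa_varlen(q, k, v, torch.tensor([0, 3, 8])).shape)  # torch.Size([1, 8, 4, 8])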



+ 3
- 0
vllm/model_executor/models/qwen.py View File

@@ -281,6 +281,9 @@ class QWenBaseModel(nn.Module):
self.transformer.make_empty_intermediate_tensors
)

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids)

def compute_logits(
self,
hidden_states: torch.Tensor,


+ 32
- 0
vllm/model_executor/models/qwen2.py View File

@@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module):
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
qk_norm: bool = False,
rms_norm_eps: float = 1e-6,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qk_norm = qk_norm

self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -162,6 +165,11 @@ class Qwen2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)

# QK Normalization support (used in BAGEL and some other models)
if self.qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

self.rotary_emb = get_rope(
self.head_dim,
max_position=max_position,
@@ -197,6 +205,23 @@ class Qwen2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# Apply QK normalization if enabled (before RoPE)
if self.qk_norm:
# Reshape to apply per-head normalization
# q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim)
total_tokens = q.shape[0]
q = q.view(total_tokens, self.num_heads, self.head_dim)
k = k.view(total_tokens, self.num_kv_heads, self.head_dim)

# Apply normalization
q = self.q_norm(q)
k = self.k_norm(k)

# Reshape back
q = q.view(total_tokens, self.q_size)
k = k.view(total_tokens, self.kv_size)

q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
@@ -227,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module):
else:
attn_type = AttentionType.ENCODER_ONLY

# Check if QK normalization is enabled (used in BAGEL and some other models)
qk_norm = getattr(config, "qk_norm", False)

self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -238,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
qk_norm=qk_norm,
rms_norm_eps=config.rms_norm_eps,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@@ -480,6 +510,8 @@ class Qwen2Model(nn.Module):
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
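
The reshape around q_norm/k_norm above is the important detail: RMSNorm is applied per attention head over head_dim, not over the flattened q_size, and it happens before RoPE. A standalone sketch of that ordering with a minimal RMSNorm stand-in (sizes are illustrative, not taken from any particular checkpoint):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):  # minimal stand-in for vLLM's fused RMSNorm
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

num_heads, head_dim = 8, 64
q_norm = RMSNorm(head_dim)

q = torch.randn(10, num_heads * head_dim)     # (total_tokens, q_size)
q = q_norm(q.view(-1, num_heads, head_dim))   # normalize each head independently
q = q.view(-1, num_heads * head_dim)          # flatten back before RoPE
print(q.shape)                                # torch.Size([10, 512])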


+ 1
- 0
vllm/model_executor/models/qwen2_5_omni_thinker.py View File

@@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None


+ 48
- 76
vllm/model_executor/models/qwen2_5_vl.py View File

@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
)

def forward(
self,
@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
else:
q, k, v = qkv.unbind(dim=2)

if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform

# Never remove the next contiguous logic
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)", b=batch_size
).contiguous()

output, _ = self.proj(context_layer)
return output
@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
dim,
@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size

# args for get_window_index_thw
@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5},
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)

@property
@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
else:
self.visual = None
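
The constructor changes in this file (and in the Qwen2-VL/Qwen3-VL files below) all follow one pattern: the old use_data_parallel and attn_backend_override arguments are dropped, and each module derives both settings from MultiModalConfig. A condensed sketch of that derivation using a stub config class rather than the real vllm.config.MultiModalConfig:

from dataclasses import dataclass

@dataclass
class MultiModalConfigStub:
    mm_encoder_tp_mode: str = "weights"            # "data" enables encoder DP
    mm_encoder_attn_backend: str | None = None     # per-encoder backend override

def resolve_encoder_settings(mm_cfg: MultiModalConfigStub | None):
    use_data_parallel = mm_cfg.mm_encoder_tp_mode == "data" if mm_cfg else False
    attn_backend_override = mm_cfg.mm_encoder_attn_backend if mm_cfg else None
    return use_data_parallel, attn_backend_override

print(resolve_encoder_settings(MultiModalConfigStub(mm_encoder_tp_mode="data")))
# (True, None)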


+ 47
- 96
vllm/model_executor/models/qwen2_vl.py View File

@@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
@@ -45,10 +44,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
@@ -251,10 +248,15 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int,
act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@@ -295,12 +297,16 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@@ -329,34 +335,12 @@ class Qwen2VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
multimodal_config=multimodal_config,
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now."
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -398,7 +382,6 @@ class Qwen2VisionAttention(nn.Module):

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

@@ -409,49 +392,15 @@ class Qwen2VisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform

if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []

lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output
@@ -466,9 +415,8 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -482,17 +430,16 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2VisionMLP(
dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -552,10 +499,15 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -599,9 +551,8 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -615,7 +566,11 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio

self.use_data_parallel = use_data_parallel
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.out_hidden_size = vision_config.hidden_size

self.spatial_merge_size = spatial_merge_size
@@ -647,8 +602,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
for layer_idx in range(depth)
]
@@ -659,7 +613,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
@@ -720,7 +677,7 @@ class Qwen2VisionTransformer(nn.Module):
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def forward(
@@ -1324,18 +1281,12 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None


+ 12
- 8
vllm/model_executor/models/qwen3_omni_moe_thinker.py View File

@@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
@@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
@@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
@@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
@@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module):
]
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)

self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.quant_config = quant_config



+ 24
- 22
vllm/model_executor/models/qwen3_vl.py View File

@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
@@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
) -> None:
super().__init__()
if norm_layer is None:
@@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
)
self.mlp = Qwen3_VisionMLP(
dim,
@@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)

self.use_postshuffle_norm = use_postshuffle_norm
@@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5)

# NOTE: This is used for creating empty tensor for all_gather for
@@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)

self.deepstack_merger_list = nn.ModuleList(
@@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel,
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
)
for layer_idx in range(vision_config.depth)
]
@@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration(
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)

self.language_model = Qwen3LLMForCausalLM(


+ 5
- 1
vllm/model_executor/models/qwen3_vl_moe.py View File

@@ -419,6 +419,10 @@ class Qwen3VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)

if not multimodal_config.get_limit_per_prompt(
"image"
@@ -429,8 +433,8 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3MoeLLMForCausalLM(


+ 1
- 0
vllm/model_executor/models/registry.py View File

@@ -272,6 +272,7 @@ _MULTIMODAL_MODELS = {
"aya_vision",
"AyaVisionForConditionalGeneration",
),
"BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
"BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": (


+ 45
- 84
vllm/model_executor/models/siglip2navit.py View File

@@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
@@ -190,7 +189,7 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and not current_platform.is_xpu():
if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

apply_rotary_emb_func = apply_rotary_emb
@@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
@@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module):
self.dropout = config.attention_dropout
self.is_causal = False

# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)

self.tp_size = (
@@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads_per_partition,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def forward(
self,
hidden_states: torch.Tensor,
@@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module):
keys.unsqueeze(0),
cos,
sin,
self.is_flash_attn_backend,
self.attn.is_flash_attn_backend,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
attn_output = self.flash_attn_varlen_func(
queries,
keys,
values,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
).reshape(seq_length, -1)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
cu = cu_seqlens.tolist()
for i in range(batch_size):
start_idx = cu[i]
end_idx = cu[i + 1]

# Each sequence is processed independently.
q_i = queries[start_idx:end_idx].unsqueeze(0)
k_i = keys[start_idx:end_idx].unsqueeze(0)
v_i = values[start_idx:end_idx].unsqueeze(0)

# (1, seq_len, num_heads, head_dim) ->
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
outputs.append(output_i)

attn_output = torch.cat(outputs, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output = self.attn(
query=queries.unsqueeze(0),
key=keys.unsqueeze(0),
value=values.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.reshape(
seq_length, self.num_heads_per_partition * self.head_dim
)

attn_output, _ = self.out_proj(attn_output)
return attn_output

@@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for idx in range(config.num_hidden_layers)
]
@@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)

def forward(


+ 8
- 5
vllm/model_executor/models/vision.py View File

@@ -88,14 +88,17 @@ def get_vit_attn_backend(
"""
Get the available attention backend for Vision Transformer.
"""
if attn_backend_override is not None:
return attn_backend_override
attn_backend = attn_backend_override

selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None:
return selected_backend
if attn_backend is None:
attn_backend = selected_backend

return current_platform.get_vit_attn_backend(head_size, dtype)
return current_platform.get_vit_attn_backend(
head_size,
dtype,
backend=attn_backend,
)
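
After this change the backend resolution order is: explicit mm-encoder override, then the global attention_config.backend, then the platform default. A standalone sketch of that precedence with stub types (not the vLLM helpers themselves):

from enum import Enum, auto

class Backend(Enum):  # stand-in for AttentionBackendEnum
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()

def pick_vit_backend(override, global_backend, platform_default=Backend.TORCH_SDPA):
    backend = override              # 1. per-model mm-encoder override, if any
    if backend is None:
        backend = global_backend    # 2. globally configured attention backend
    if backend is None:
        backend = platform_default  # 3. platform default selection
    return backend

print(pick_vit_backend(None, None))                # Backend.TORCH_SDPA
print(pick_vit_backend(Backend.FLASH_ATTN, None))  # Backend.FLASH_ATTN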


def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:


+ 10
- 2
vllm/multimodal/audio.py View File

@@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):

def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(buffer, weights_only=True)
return tensor.to_dense()

def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))

def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath, weights_only=True)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(filepath, weights_only=True)
return tensor.to_dense()

def encode_base64(self, media: torch.Tensor) -> str:
return tensor2base64(media)

+ 10
- 2
vllm/multimodal/image.py View File

@@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):

def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(buffer, weights_only=True)
return tensor.to_dense()

def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))

def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath, weights_only=True)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(filepath, weights_only=True)
return tensor.to_dense()

def encode_base64(self, media: torch.Tensor) -> str:
return pybase64.b64encode(media.numpy()).decode("utf-8")
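
The same guarded-load pattern is applied to both AudioEmbeddingMediaIO and ImageEmbeddingMediaIO above: the invariant check makes torch.load reject sparse tensors whose indices fall outside their declared shape, and to_dense() guarantees downstream code only ever sees dense tensors. A minimal standalone sketch of the pattern:

from io import BytesIO
import torch

def load_embedding(data: bytes) -> torch.Tensor:
    buffer = BytesIO(data)
    # Reject maliciously crafted sparse tensors with out-of-bounds indices.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(buffer, weights_only=True)
    return tensor.to_dense()  # no-op for ordinary dense tensors

buf = BytesIO()
torch.save(torch.randn(4, 8), buf)
print(load_embedding(buf.getvalue()).shape)  # torch.Size([4, 8])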

+ 36
- 18
vllm/platforms/cuda.py View File

@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar

import torch
from typing_extensions import ParamSpec
@@ -255,23 +255,6 @@ class CudaPlatformBase(Platform):
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)

@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass

return AttentionBackendEnum.TORCH_SDPA

@classmethod
def get_valid_backends(
cls,
@@ -418,6 +401,41 @@ class CudaPlatformBase(Platform):

return selected_backend.get_path()

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN,
]

@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend

# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass

return AttentionBackendEnum.TORCH_SDPA

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"


+ 38
- 7
vllm/platforms/interface.py View File

@@ -7,7 +7,7 @@ import platform
import random
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Optional

import numpy as np
import torch
@@ -222,12 +222,6 @@ class Platform:
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401

@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.TORCH_SDPA

@classmethod
def get_attn_backend_cls(
cls,
@@ -245,6 +239,43 @@ class Platform:
"""Get the attention backend class of a device."""
return ""

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
]

@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
"""
Get the vision attention backend class of a device.

NOTE: The ViT attention backend should be checked and overridden in the
platform-specific implementation. It should not be overridden anywhere
else, such as in model_executor/models/<model_name>.py.

The `backend` argument is handled as follows:
1. If it is not None, verify that the platform supports it and return it.
2. If it is None, fall back to the default selection logic.
"""
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend

logger.info_once(
f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
)
return AttentionBackendEnum.TORCH_SDPA

@classmethod
def get_device_capability(
cls,


+ 38
- 19
vllm/platforms/rocm.py View File

@@ -3,7 +3,7 @@

import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch

@@ -187,24 +187,6 @@ class RocmPlatform(Platform):
if not on_gfx9():
supported_quantization += ["bitsandbytes"]

@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec

from vllm._aiter_ops import rocm_aiter_ops

if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA

if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN

return AttentionBackendEnum.TORCH_SDPA

@classmethod
def get_attn_backend_cls(
cls,
@@ -322,6 +304,43 @@ class RocmPlatform(Platform):
"ROCm. Note that V0 attention backends have been removed."
)

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TORCH_SDPA,
]

@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend

from importlib.util import find_spec

from vllm._aiter_ops import rocm_aiter_ops

if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA

if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN

return AttentionBackendEnum.TORCH_SDPA

@classmethod
def set_device(cls, device: torch.device) -> None:
"""


+ 27
- 1
vllm/platforms/tpu.py View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast

import torch
from tpu_info import device
@@ -75,6 +75,32 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.PALLAS,
]

@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention.")
return backend

logger.info_once(
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
)
return AttentionBackendEnum.PALLAS

@classmethod
def set_device(cls, device: torch.device) -> None:
"""


+ 29
- 7
vllm/platforms/xpu.py View File

@@ -3,7 +3,7 @@

import contextlib
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch

@@ -77,6 +77,34 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
# XPU only supports FLASH_ATTN for vision attention.
return [
AttentionBackendEnum.FLASH_ATTN,
]

@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: "
f"{cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend

logger.info_once(
f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
)
return AttentionBackendEnum.FLASH_ATTN

@classmethod
def set_device(cls, device: torch.device) -> None:
"""
@@ -110,12 +138,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory

@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.FLASH_ATTN

@classmethod
def inference_mode(cls):
return torch.no_grad()
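
Every platform above (base interface, CUDA, ROCm, TPU, XPU) now follows the same shape: an allow-list in get_supported_vit_attn_backends() plus a get_vit_attn_backend() that validates an explicit request before falling back to its default. A hedged sketch of that contract using stub types, showing how a hypothetical out-of-tree platform would widen the allow-list:

from enum import Enum, auto

class Backend(Enum):  # stand-in for AttentionBackendEnum
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()

class PlatformStub:
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[Backend]:
        return [Backend.TORCH_SDPA]

    @classmethod
    def get_vit_attn_backend(cls, head_size, dtype, backend=None) -> Backend:
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            return backend
        return Backend.TORCH_SDPA  # default selection logic goes here

class MyAcceleratorPlatform(PlatformStub):  # hypothetical platform
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list[Backend]:
        return [Backend.TORCH_SDPA, Backend.FLASH_ATTN]

print(MyAcceleratorPlatform.get_vit_attn_backend(64, None, Backend.FLASH_ATTN))
# Backend.FLASH_ATTN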


vllm/tokenizers/deepseekv32.py → vllm/tokenizers/deepseek_v32.py View File


+ 1
- 1
vllm/tokenizers/registry.py View File

@@ -30,7 +30,7 @@ logger = init_logger(__name__)


_VLLM_TOKENIZERS = {
"deepseekv32": ("deepseekv32", "DeepseekV32Tokenizer"),
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
}
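
The registry maps a tokenizer name to a (module filename, class name) pair that is imported only on demand. A hedged sketch of how such a table can be resolved lazily (illustrative code, not the actual registry implementation; it assumes vLLM is installed so the module path resolves):

import importlib

_VLLM_TOKENIZERS = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
    "hf": ("hf", "CachedHfTokenizer"),
    "mistral": ("mistral", "MistralTokenizer"),
}

def resolve_tokenizer_class(name: str):
    module_name, class_name = _VLLM_TOKENIZERS[name]
    # Import happens only when the tokenizer is actually requested.
    module = importlib.import_module(f"vllm.tokenizers.{module_name}")
    return getattr(module, class_name)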


+ 150
- 0
vllm/tool_parsers/__init__.py View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)

__all__ = ["ToolParser", "ToolParserManager"]


"""
Register a lazy module mapping.

Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""


_TOOL_PARSERS_TO_REGISTER = {
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}


def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)


register_lazy_tool_parsers()
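
A plugin parser can register itself the same way the built-in table does, with one register_lazy_module call (mirroring the example shown earlier in this file); the module path and class name below are hypothetical:

from vllm.tool_parsers import ToolParserManager

# The module is imported only when the parser name is requested,
# e.g. via --tool-call-parser my_parser.
ToolParserManager.register_lazy_module(
    name="my_parser",
    module_path="my_plugin.my_tool_parser",
    class_name="MyToolParser",
)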

Some files were not shown because too many files changed in this diff
