@@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.

BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)

logger = init_logger(__name__)


class BagelImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor  # Shape: (bn, 3, h, w)


BagelImageInputs: TypeAlias = BagelImagePixelInputs
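
# For now, pixel values are the only accepted image input; keeping the alias
# leaves room to widen it into a union (e.g. precomputed image embeddings)
# later without touching the signatures that take BagelImageInputs.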


class BagelVisionMLP(nn.Module):
    """MLP connector for vision features."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        act_layer: str = "gelu_pytorch_tanh",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn(act_layer)
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc1(x)
        x = self.act(x)
        x, _ = self.fc2(x)
        return x


class PositionEmbedding(nn.Module):
    """2D position embedding for vision tokens using sin-cos embeddings."""

    def __init__(self, max_num_patch_per_side: int, hidden_size: int):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        # Precompute fixed (non-learnable) 2D sin-cos position embeddings
        pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
        self.register_buffer(
            "pos_embed",
            torch.from_numpy(pos_embed).float(),
            persistent=False,
        )

    @staticmethod
    def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
        """Generate 2D sin-cos position embeddings."""
        grid_h = np.arange(grid_size, dtype=np.float32)
        grid_w = np.arange(grid_size, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # w goes first
        grid = np.stack(grid, axis=0)
        grid = grid.reshape([2, 1, grid_size, grid_size])
        pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid(
            embed_dim, grid
        )
        return pos_embed

    @staticmethod
    def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
        """Generate 2D sin-cos position embeddings from grid."""
        assert embed_dim % 2 == 0
        # Use half of the dimensions to encode grid_h, the other half
        # for grid_w
        emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
            embed_dim // 2, grid[0]
        )
        emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
            embed_dim // 2, grid[1]
        )
        emb = np.concatenate([emb_h, emb_w], axis=1)
        return emb

    @staticmethod
    def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
        """Generate 1D sin-cos position embeddings."""
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float64)
        omega /= embed_dim / 2.0
        omega = 1.0 / 10000**omega

        pos = pos.reshape(-1)
        out = np.einsum("m,d->md", pos, omega)

        emb_sin = np.sin(out)
        emb_cos = np.cos(out)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)
        return emb
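
    # The 1D table built above is the standard transformer sin-cos encoding:
    # for position m and frequency index d in [0, embed_dim // 2),
    #   emb[m, d]                  = sin(m / 10000 ** (2 * d / embed_dim))
    #   emb[m, d + embed_dim // 2] = cos(m / 10000 ** (2 * d / embed_dim))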

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            position_ids: Flattened position IDs, shape (N,), where each ID
                corresponds to a position in the flattened grid.

        Returns:
            Position embeddings of shape (N, hidden_size).
        """
        # Ensure position_ids are on the same device as pos_embed
        position_ids = position_ids.to(self.pos_embed.device)
        return self.pos_embed[position_ids]
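
    # Lookup semantics: pos_embed holds max_num_patch_per_side ** 2 rows in
    # row-major order, so the patch at grid cell (h, w) is addressed by the
    # flat ID h * max_num_patch_per_side + w. _process_image_input builds its
    # position_ids with exactly this encoding.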


class BagelProcessingInfo(BaseProcessingInfo):
    """Processing information for BAGEL model."""

    def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
        from vllm.transformers_utils.processor import cached_get_image_processor

        image_processor = cached_get_image_processor(
            self.ctx.model_config.model,
            revision=self.ctx.model_config.revision,
            trust_remote_code=self.ctx.model_config.trust_remote_code,
        )

        tokenizer = self.get_tokenizer()

        return BagelProcessor(
            image_processor=image_processor,
            tokenizer=tokenizer,
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        hf_config = self.get_hf_config()
        # Maximum tokens per image: one token per ViT patch on the largest
        # supported grid, i.e. vit_max_num_patch_per_side ** 2
        max_num_patches = hf_config.vit_max_num_patch_per_side**2
        return {"image": max_num_patches}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vit_config = hf_config.vit_config
        patch_size = vit_config.patch_size

        # Number of patches along each axis, one token per patch
        num_patches_h = image_height // patch_size
        num_patches_w = image_width // patch_size
        return num_patches_h * num_patches_w
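
    # Illustrative arithmetic (values hypothetical, not taken from a real
    # BAGEL config): a 224x224 image with patch_size=14 yields
    # (224 // 14) ** 2 = 16 * 16 = 256 image tokens.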


class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
    """Build dummy inputs for BAGEL model profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        # Use a simple placeholder for each image
        return "<|image_pad|>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        hf_config = self.info.get_hf_config()
        vit_config = hf_config.vit_config

        # Use the configured image size
        image_size = vit_config.image_size
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=image_size,
                height=image_size,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }
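
    # NOTE: The dummy resolution above only sizes the pixel tensors used for
    # memory profiling; the number of image tokens per item is fixed by the
    # prompt replacement below at vit_max_num_patch_per_side ** 2, independent
    # of the image resolution.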


class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
    """Multimodal processor for BAGEL model."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptReplacement]:
        """Replace image placeholders with the correct number of tokens."""
        hf_config = self.info.get_hf_config()

        # Look up the image token ID in the tokenizer vocabulary
        tokenizer = self.info.get_tokenizer()
        image_token_id = tokenizer.get_vocab().get("<|image_pad|>")
        if image_token_id is None:
            raise ValueError(
                "Image token '<|image_pad|>' not found in tokenizer vocabulary"
            )

        def get_replacement_bagel(item_idx: int):
            # For BAGEL, the number of tokens per image is based on the
            # maximum patch grid
            num_tokens = hf_config.vit_max_num_patch_per_side**2
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_bagel,
            )
        ]
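
    # Each <|image_pad|> occurrence in the prompt therefore expands into
    # vit_max_num_patch_per_side ** 2 copies of the image token. This count
    # must line up with the number of embeddings _process_image_input emits
    # per image, i.e. the configs are expected to satisfy
    # image_size // patch_size == vit_max_num_patch_per_side.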

    def _get_mm_fields_config(
        self,
        hf_inputs: Any,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
        }


@MULTIMODAL_REGISTRY.register_processor(
    BagelMultiModalProcessor,
    info=BagelProcessingInfo,
    dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
    """
    BAGEL: A unified multimodal model for image understanding and generation.

    For vLLM, we focus on the image understanding (vision-to-text)
    capabilities. The image generation part is not supported in vLLM.
    """

    # Weight mapping from HF to vLLM (identity for the supported submodules;
    # kept explicit for clarity)
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.": "language_model.",
            "vit_model.": "vit_model.",
            "connector.": "connector.",
            "vit_pos_embed.": "vit_pos_embed.",
        }
    )
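
    # Understanding-path dataflow, as wired up below (shapes illustrative):
    #   pixel_values (bn, 3, H, W)
    #     -> vit_model (SigLIP)      -> (bn, num_patches, vit_hidden)
    #     -> connector (MLP)         -> (bn, num_patches, llm_hidden)
    #     -> + vit_pos_embed lookup  -> (bn, num_patches, llm_hidden)
    #     -> merged into the Qwen2 token stream at <|image_pad|> positions.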

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # Ensure we have a BagelConfig (check by name to handle
        # trust_remote_code, where the config class is loaded from
        # transformers_modules rather than vLLM)
        if type(config).__name__ != "BagelConfig":
            raise ValueError(
                f"Expected BagelConfig, got {type(config).__name__}. "
                "Make sure the model config is properly loaded."
            )

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the language model (Qwen2), passing the llm_config from
        # BagelConfig so the backbone is constructed with the right shapes
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.llm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        # Initialize the vision model (SigLIP) if visual understanding
        # is enabled
        if config.visual_und:
            # Fix vit_config: the checkpoint has 26 layers (0-25) but the
            # config says 27. Also disable the head, which is absent from
            # the checkpoint.
            vit_config = config.vit_config
            if vit_config.num_hidden_layers == 27:
                logger.warning(
                    "Overriding vit_config.num_hidden_layers from 27 to 26 "
                    "to match the Bagel model checkpoint."
                )
                vit_config.num_hidden_layers = 26
            if not hasattr(vit_config, "vision_use_head"):
                logger.warning(
                    "Setting vit_config.vision_use_head to False as it is not "
                    "present in the Bagel model checkpoint."
                )
                vit_config.vision_use_head = False

            self.vit_model = SiglipVisionModel(
                config=vit_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vit_model"),
            )

            # Initialize the connector (MLP) that projects ViT features to
            # the LLM hidden size
            vit_hidden_size = config.vit_config.hidden_size
            llm_hidden_size = config.llm_config.hidden_size

            self.connector = BagelVisionMLP(
                in_features=vit_hidden_size,
                hidden_features=llm_hidden_size,
                out_features=llm_hidden_size,
                act_layer=config.connector_act,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "connector"),
            )

            # Position embedding for vision tokens
            self.vit_pos_embed = PositionEmbedding(
                max_num_patch_per_side=config.vit_max_num_patch_per_side,
                hidden_size=llm_hidden_size,
            )
        else:
            self.vit_model = None
            self.connector = None
            self.vit_pos_embed = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> BagelImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        return BagelImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
        )

    def _process_image_input(
        self, image_input: BagelImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Process image inputs through the vision encoder and connector."""
        pixel_values = image_input["pixel_values"]

        # Handle a potential extra batch dimension.
        # Expected shape: (batch_size * num_images, 3, H, W),
        # but we might receive (batch_size, num_images, 3, H, W).
        if pixel_values.ndim == 5:
            # Flatten the batch and num_images dimensions
            batch_size, num_images, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.reshape(
                batch_size * num_images, channels, height, width
            )

        # Get vision features from SigLIP
        # pixel_values shape: (batch_size * num_images, 3, H, W)
        vision_features = self.vit_model(pixel_values)

        # Pass through the connector
        vision_embeds = self.connector(vision_features)

        # Add position embeddings
        batch_size, num_patches, hidden_size = vision_embeds.shape
        patch_size = self.config.vit_config.patch_size
        image_size = self.config.vit_config.image_size

        # Calculate grid dimensions
        num_patches_per_side = image_size // patch_size

        # Build flattened position IDs; BAGEL uses extrapolate mode by
        # default, so grid cell (h, w) maps to the flat ID
        # h * vit_max_num_patch_per_side + w
        h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
        w_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
        position_ids = (
            h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords
        ).flatten()
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten()

        # Look up and add the position embeddings
        pos_embeds = self.vit_pos_embed(position_ids)
        pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size)
        # Ensure pos_embeds are on the same device as vision_embeds
        pos_embeds = pos_embeds.to(vision_embeds.device)
        vision_embeds = vision_embeds + pos_embeds

        # Split into one tensor per image
        return tuple(vision_embeds)
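
    # NOTE: iterating a (bn, P, llm_hidden) tensor via tuple() yields bn views
    # of shape (P, llm_hidden), i.e. one embedding tensor per image, which is
    # the per-item format get_multimodal_embeddings returns to the runner.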

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Get multimodal embeddings from the input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_language_model(self) -> nn.Module:
        return self.language_model
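
    # In vLLM's execution flow, the model runner is expected to call
    # get_multimodal_embeddings first and, when images are present, pass the
    # merged sequence to forward() via inputs_embeds, so forward itself never
    # touches pixel data.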

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the forward pass for BAGEL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to
                a batch.
            intermediate_tensors: Intermediate tensors from the prior forward
                pass (pipeline parallelism).
            inputs_embeds: Optional tensor of input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights from the checkpoint."""
        skip_prefixes = []
        # Skip vit_pos_embed.pos_embed: it is recomputed at init time and
        # registered as a non-persistent buffer by the PositionEmbedding
        # module, so there is nothing to load
        skip_prefixes.append("vit_pos_embed.pos_embed")

        # If visual understanding is disabled, skip vision-related weights
        if self.vit_model is None:
            skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])

        # Skip generation-related weights since we only support text-to-text
        # and image-to-text. Filter out all image generation components:
        # - 'moe_gen': MoE generation weights
        # - 'latent_pos_embed': latent position embeddings for the VAE
        # - 'llm2vae', 'vae2llm': LLM-VAE projections
        # - 'time_embedder': timestep embeddings for diffusion
        # - VAE encoder/decoder: matched as leading prefixes so that the
        #   vision encoder ('vit_model.encoder.*') is not caught
        generation_keywords = [
            "moe_gen",
            "latent_pos_embed",
            "llm2vae",
            "vae2llm",
            "time_embedder",
        ]
        vae_prefixes = [
            "decoder.",
            "encoder.",
        ]  # VAE encoder/decoder, not the vision encoder
        filtered_weights = []
        for name, tensor in weights:
            # Skip generation-related keywords
            if any(skip in name for skip in generation_keywords):
                continue
            if any(name.startswith(prefix) for prefix in vae_prefixes):
                continue
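
            # Some Bagel checkpoints store the SigLIP patch embedding as a
            # flattened linear weight of shape (out, in * p * p); reshape it
            # back to the (out, in, p, p) conv layout that vLLM's
            # SiglipVisionModel expects before handing it to the loader.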
            if "patch_embedding.weight" in name and tensor.ndim == 2:
                out_channels = tensor.shape[0]
                in_features = tensor.shape[1]
                patch_size = self.config.vit_config.patch_size
                in_channels = self.config.vit_config.num_channels
                if in_features == in_channels * patch_size * patch_size:
                    tensor = tensor.reshape(
                        out_channels, patch_size, patch_size, in_channels
                    )
                    tensor = tensor.permute(0, 3, 1, 2).contiguous()

            filtered_weights.append((name, tensor))

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)