54 Commits

Author SHA1 Message Date
  Ferdinand Mom 48c69f7f68
Merge branch 'main' into v5-test_tensor_parallel_moe 6 hours ago
  Cyril Vallez 4d6516e256
Simplify tie weights logic (#42895) 7 hours ago
  Wang, Yi 24b311eead
fix FastSpeech2ConformerTokenizer crash in tokenize (#42888) 7 hours ago
  r0 0f89661972
Added kernels from kernel hub for Bamba model (#41540) 7 hours ago
  Avihu Dekel 5d2f82b530
Fix GraniteMoeHybrid in transformers v5 (#42872) 7 hours ago
  Anton Vlasjuk 4e7cecb24d
[`Ernie 4.5 Moe`] Fix routing, weights, and update expectations (#42653) 8 hours ago
  Julien Denize 252afd8968
Fix convert_tekken_tokenizer (#42592) 8 hours ago
  Taisei Yamamoto 89998bddca
Stop collecting all model parameters to save models when using DeepSpeed and LoRA (#41416) 8 hours ago
  Cyril Vallez 8d526c238a
[modular] Fix a weird renaming edge-case (#42844) 8 hours ago
  Patrick von Platen 7960b5ea40
[Devstral] Make sure FP8 conversion works correctly (#42715) 8 hours ago
  Tom Aarsen 6c7c992faf
Add missing ModelOutput subclass return type hints (#41219) 8 hours ago
  Preetam Chhimpa 0f97c688d5
Fix BLT training_ci overfit test (#42685) 8 hours ago
  Abubakar Abid 7f52a2a4ea
Add `.on_push_begin()` callback to Trainer and implement for `TrackioCallback` (#42850) 14 hours ago
  Steven Liu 31de95ef71
[docs] optimizations quickstart (#42538) 20 hours ago
  Yoni Gozlan 23394cc491
Simplify using custom resolution for sam3 and sam3_video inference (#42787) 23 hours ago
  Danny Waser 06378d40e6
fix: Initialize ApertusMLP's xielu activation using `torch_dtype` (#42864) 1 day ago
  Yoni Gozlan fc50bdc685
Remove null values from fast image processors dict (#42780) 1 day ago
  Cyril Vallez c7aec088a6
Enforce call to `post_init` and fix all of them (#42873) 1 day ago
  Rémi Ouazan f3d5f2558b
[CB] Easy optimizations for continuous batching (#42839) 1 day ago
  AYou0207 298d08dc36
typo (#42863) 1 day ago
  Cyril Vallez a187b857a7
Remove tied weights from internal attribute if they are not tied (#42871) 1 day ago
  Mohamed Mekkouri 64c12fdf5f
[docs] Improve contribution guidelines for Quantization (#42870) 1 day ago
  Sai-Suraj-27 f0d9cd1ff6
Fixes 2 failing tests from AMD CI (#42777) 1 day ago
  jiqing-feng 66623a1fd6
Fix speccht5_tts pipeline (#42830) 1 day ago
  YangKai0616 e17b1b85e3
[Fix] Fix FA2 kernels ut (#42803) 1 day ago
  Cyril Vallez 40dc11cd3e
Fix Gemma (#42847) 4 days ago
  Cyril Vallez c247063001
Reapply modular examples (#42846) 4 days ago
  Yoni Gozlan a61aba59fb
Improve BatchFeature: stack list and lists of torch tensors (#42750) 4 days ago
  Cyril Vallez 5b710c7542
Do not rely on config for inferring model dtype (#42838) 4 days ago
  Anton Vlasjuk 33c948e494
[`T5Gemma2`] Fix bidirectional mask for encoder (#42820) 4 days ago
  Cyril Vallez e6b9d06147
[saving] Simplify general logic (#42766) 4 days ago
  Wu, Ke 65dc261512
Add inputs_to_logits_ratio to LasrCTCConfig (#42720) 4 days ago
  Cyril Vallez 64a7cc82a6
Simplify dtype instantiation (#42825) 4 days ago
  Raushan Turganbay 37426b27bf
Fix a typo in MoE models (#42835) 4 days ago
  Rémi Ouazan aa495f62de
Fixes for the failures of AMD CI (#42718) 4 days ago
  jiqing-feng c24b51dd78
Fix xpu output check for Ministral3 tests (#42761) 4 days ago
  ZX-ModelCloud b19844eef6
Compatible with GPTQModel FORAMT.LLM_AWQ (#42833) 4 days ago
  Marc Sun 780cc65907
Fix deepspeed sp loss due to missing labels (#42812) 4 days ago
  Lysandre Debut 3fbd59b6f1
Add requires_backends to the main init (#42799) 4 days ago
  YangKai0616 f80b0485fe
[XPU] Fix UT errors in the sam3 and lfm series model. (#42798) 4 days ago
  Steven Liu 6d00f6b0a5
[docs] Chat content patterns (#42748) 5 days ago
  Arthur 6217adc6c8
Default auto (#42805) 5 days ago
  Lysandre Debut 8a2a83d574
Automatic release script (#42808) 5 days ago
  Cyril Vallez 464dfa0446
Raise conversion errors after loading (#42807) 5 days ago
  Marc Sun 6a93635e1d
update deprecation msg for `warmup_ratio` (#42813) 5 days ago
  Mohamed Mekkouri 0c18820c51
[kernels] adding RMSNorm kernel for mps devices (#42058) 5 days ago
  Yoni Gozlan dfe6e4c0ef
Fix integration test in Owlv2 image processing tests (#42783) 5 days ago
  Mohamed Mekkouri de055d6db0
[kernels] Final kernel removal 🥳 (#41664) 5 days ago
  Merve Noyan eaa3d4dd35
Vision docs 📝 (#42096) 5 days ago
  zhang-prog 8c84144bfc
[Model] Add PaddleOCR-VL Model Support (#42178) 5 days ago
  Mohamed Mekkouri 78b2992920
[CI] fix wav2vec test (#42810) 5 days ago
  Marc Sun f8e8ddb087
fix awq (#42776) 5 days ago
  Rémi Ouazan f8e5ae6a50
Better continuous batching tests (#42699) 5 days ago
  Mohamed Mekkouri 86644be479
[Quantization] FBgemm FP8 for XPU (#42773) 5 days ago
100 changed files with 2854 additions and 1985 deletions
Split View
  1. +60
    -0
      .github/workflows/release.yml
  2. +1
    -2
      Makefile
  3. +6
    -0
      docs/source/en/_toctree.yml
  4. +198
    -0
      docs/source/en/chat_content_patterns.md
  5. +2
    -0
      docs/source/en/internal/import_utils.md
  6. +248
    -0
      docs/source/en/model_doc/paddleocr_vl.md
  7. +15
    -0
      docs/source/en/model_doc/sam3.md
  8. +15
    -0
      docs/source/en/model_doc/sam3_video.md
  9. +178
    -0
      docs/source/en/optimization_overview.md
  10. +95
    -16
      docs/source/en/quantization/contribute.md
  11. +48
    -21
      docs/source/en/tasks/image_text_to_text.md
  12. +334
    -19
      docs/source/en/tasks/mask_generation.md
  13. +1
    -1
      docs/source/en/tasks/semantic_segmentation.md
  14. +92
    -74
      docs/source/en/tasks/video_text_to_text.md
  15. +3
    -10
      examples/modular-transformers/configuration_duplicated_method.py
  16. +3
    -10
      examples/modular-transformers/configuration_my_new_model.py
  17. +3
    -10
      examples/modular-transformers/configuration_my_new_model2.py
  18. +15
    -19
      examples/modular-transformers/configuration_new_model.py
  19. +3
    -0
      examples/modular-transformers/modeling_add_function.py
  20. +56
    -164
      examples/modular-transformers/modeling_dummy_bert.py
  21. +12
    -42
      examples/modular-transformers/modeling_from_uppercase_model.py
  22. +5
    -2
      examples/modular-transformers/modeling_global_indexing.py
  23. +29
    -177
      examples/modular-transformers/modeling_multimodal2.py
  24. +6
    -2
      examples/modular-transformers/modeling_my_new_model2.py
  25. +2
    -37
      examples/modular-transformers/modeling_new_task_model.py
  26. +56
    -164
      examples/modular-transformers/modeling_roberta.py
  27. +44
    -13
      examples/modular-transformers/modeling_super.py
  28. +5
    -2
      examples/modular-transformers/modeling_switch_function.py
  29. +31
    -26
      examples/modular-transformers/modeling_test_detr.py
  30. +250
    -0
      examples/modular-transformers/modeling_test_suffix.py
  31. +8
    -1
      examples/modular-transformers/modular_multimodal2.py
  32. +1
    -2
      examples/modular-transformers/modular_new_model.py
  33. +12
    -0
      examples/modular-transformers/modular_test_suffix.py
  34. +12
    -2
      examples/pytorch/continuous_batching.py
  35. +2
    -30
      setup.py
  36. +2
    -0
      src/transformers/__init__.py
  37. +6
    -2
      src/transformers/conversion_mapping.py
  38. +62
    -57
      src/transformers/core_model_loading.py
  39. +1
    -1
      src/transformers/dependency_versions_table.py
  40. +38
    -9
      src/transformers/feature_extraction_utils.py
  41. +20
    -23
      src/transformers/generation/continuous_batching/continuous_api.py
  42. +15
    -4
      src/transformers/image_processing_utils_fast.py
  43. +26
    -0
      src/transformers/integrations/accelerate.py
  44. +50
    -106
      src/transformers/integrations/awq.py
  45. +88
    -38
      src/transformers/integrations/fbgemm_fp8.py
  46. +14
    -2
      src/transformers/integrations/hub_kernels.py
  47. +35
    -0
      src/transformers/integrations/integration_utils.py
  48. +12
    -0
      src/transformers/integrations/mistral.py
  49. +0
    -15
      src/transformers/kernels/falcon_mamba/__init__.py
  50. +0
    -529
      src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py
  51. +239
    -270
      src/transformers/modeling_utils.py
  52. +1
    -0
      src/transformers/models/__init__.py
  53. +1
    -0
      src/transformers/models/align/modeling_align.py
  54. +3
    -1
      src/transformers/models/apertus/modeling_apertus.py
  55. +4
    -1
      src/transformers/models/apertus/modular_apertus.py
  56. +1
    -1
      src/transformers/models/auto/auto_factory.py
  57. +2
    -0
      src/transformers/models/auto/configuration_auto.py
  58. +1
    -0
      src/transformers/models/auto/image_processing_auto.py
  59. +1
    -0
      src/transformers/models/auto/modeling_auto.py
  60. +1
    -0
      src/transformers/models/auto/processing_auto.py
  61. +1
    -0
      src/transformers/models/auto/tokenization_auto.py
  62. +15
    -16
      src/transformers/models/bamba/modeling_bamba.py
  63. +15
    -15
      src/transformers/models/bamba/modular_bamba.py
  64. +1
    -0
      src/transformers/models/bart/modeling_bart.py
  65. +0
    -1
      src/transformers/models/beit/image_processing_beit_fast.py
  66. +1
    -1
      src/transformers/models/beit/modeling_beit.py
  67. +2
    -2
      src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
  68. +1
    -0
      src/transformers/models/blenderbot/modeling_blenderbot.py
  69. +1
    -0
      src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
  70. +2
    -0
      src/transformers/models/blip/modeling_blip_text.py
  71. +1
    -1
      src/transformers/models/blip_2/modeling_blip_2.py
  72. +152
    -1
      src/transformers/models/blt/modeling_blt.py
  73. +158
    -2
      src/transformers/models/blt/modular_blt.py
  74. +0
    -2
      src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
  75. +1
    -0
      src/transformers/models/bridgetower/modeling_bridgetower.py
  76. +1
    -0
      src/transformers/models/chameleon/modeling_chameleon.py
  77. +2
    -0
      src/transformers/models/clipseg/modeling_clipseg.py
  78. +0
    -1
      src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py
  79. +0
    -1
      src/transformers/models/convnext/image_processing_convnext_fast.py
  80. +4
    -4
      src/transformers/models/dac/modeling_dac.py
  81. +1
    -1
      src/transformers/models/data2vec/modeling_data2vec_vision.py
  82. +0
    -4
      src/transformers/models/decision_transformer/modeling_decision_transformer.py
  83. +1
    -1
      src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
  84. +1
    -0
      src/transformers/models/deepseek_v3/modular_deepseek_v3.py
  85. +0
    -1
      src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py
  86. +2
    -2
      src/transformers/models/deepseek_vl/modeling_deepseek_vl.py
  87. +0
    -4
      src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
  88. +2
    -2
      src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py
  89. +2
    -6
      src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py
  90. +0
    -1
      src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
  91. +4
    -0
      src/transformers/models/dia/modeling_dia.py
  92. +4
    -0
      src/transformers/models/dia/modular_dia.py
  93. +0
    -1
      src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py
  94. +0
    -1
      src/transformers/models/donut/image_processing_donut_fast.py
  95. +1
    -1
      src/transformers/models/dots1/modeling_dots1.py
  96. +1
    -2
      src/transformers/models/dpt/image_processing_dpt_fast.py
  97. +1
    -2
      src/transformers/models/dpt/modular_dpt.py
  98. +1
    -1
      src/transformers/models/edgetam/modeling_edgetam.py
  99. +1
    -2
      src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py
  100. +0
    -1
      src/transformers/models/efficientnet/image_processing_efficientnet_fast.py

+ 60
- 0
.github/workflows/release.yml View File

@@ -0,0 +1,60 @@
name: Release
on:
push:
tags:
- v*
branches:
- 'v*-release'

jobs:
build_and_test:
name: build release
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: set up python
uses: actions/setup-python@v5
with:
python-version: "3.13"

- run: pip install setuptools
- run: pip install -e .
- run: make build-release

- run: pip uninstall -y transformers
- run: pip install dist/*.whl

- run: python -c "from transformers import *"

- run: pip install -e .[torch]
- run: python -c "from transformers import pipeline; classifier = pipeline('text-classification'); assert classifier('What a nice release')[0]['score'] > 0"

- name: Upload build artifacts
uses: actions/upload-artifact@v4
with:
name: python-dist
path: |
dist/**
build/**

upload_package:
needs: build_and_test
if: startsWith(github.ref, 'refs/tags/')
runs-on: ubuntu-latest
environment: pypi-release
permissions:
id-token: write

steps:
- uses: actions/checkout@v4

- name: Download build artifacts
uses: actions/download-artifact@v4
with:
name: python-dist
path: .

- name: Publish package distributions to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1


+ 1
- 2
Makefile View File

@@ -45,7 +45,7 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_modeling_structure.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py
@@ -135,4 +135,3 @@ build-release:
rm -rf build
python setup.py bdist_wheel
python setup.py sdist
python utils/check_build.py

+ 6
- 0
docs/source/en/_toctree.yml View File

@@ -68,6 +68,8 @@
title: Perplexity of fixed-length models
title: Generate API
- sections:
- local: optimization_overview
title: Overview
- local: attention_interface
title: Attention backends
- local: continuous_batching
@@ -96,6 +98,8 @@
title: Chat basics
- local: chat_templating
title: Chat templates
- local: chat_content_patterns
title: Chat message patterns
- local: chat_templating_multimodal
title: Multimodal chat templates
- local: chat_extras
@@ -1119,6 +1123,8 @@
title: OWL-ViT
- local: model_doc/owlv2
title: OWLv2
- local: model_doc/paddleocr_vl
title: PaddleOCRVL
- local: model_doc/paligemma
title: PaliGemma
- local: model_doc/perceiver


+ 198
- 0
docs/source/en/chat_content_patterns.md View File

@@ -0,0 +1,198 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Chat message patterns

Chat models expect conversations as a list of dictionaries. Each dictionary uses `role` and `content` keys. The `content` key holds the user message passed to the model. Large language models accept text and tools, while multimodal models combine text with images, videos, and audio.

Transformers uses a unified format where each modality type is specified explicitly, making it straightforward to mix and match inputs in a single message.

This guide covers message formatting patterns for each modality, tools, batch inference, and multi-turn conversations.

## Text

Text is the most basic content type. It's the foundation for all other patterns. Pass your message to `"content"` as a string.

```py
message = [
{
"role": "user",
"content": "Explain the French Bread Law."
}
]
```

You could also use the explicit `"type": "text"` format to keep your code consistent when you add images, videos, or audio later.

```py
message = [
{
"role": "user",
"content": [{"type": "text", "text": "Explain the French Bread Law."}]
}
]
```

## Tools

[Tools](./chat_extras) are functions a chat model can call, like getting real-time weather data, instead of generating it on its own.

The `assistant` role handles the tool request. Set `"type": "function"` in the `"tool_calls"` key and provide your tool to the `"function"` key. Append the assistant's tool request to your message.

```py
weather = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
message.append(
{
"role": "assistant",
"tool_calls": [{"type": "function", "function": weather}]
}
)
```

The `tool` role handles the result. Pass it in the `"content"` key. This value should always be a string.

```py
message.append({"role": "tool", "content": "22"})
```

## Multimodal

Multimodal models extend this format to handle images, videos, and audio. Each input specifies its `"type"` and provides the media with `"url"` or `"path"`.

### Image

Set `"type": "image"` and use `"url"` for links or `"path"` for local files.

```py
message = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57ad4ebc53e63daf11a4ddc7/master/w_1280,c_limit/kouign-amann.jpg"},
{"type": "text", "text": "What pastry is shown in the image?"}
]
}
]
```

### Video

Set `"type": "video"` and use `"url"` for links or `"path"` for local files.

```py
message = [
{
"role": "user",
"content": [
{"type": "video", "url": "https://static01.nyt.com/images/2019/10/01/dining/01Sourdough-GIF-1/01Sourdough-GIF-1-superJumbo.gif"},
{"type": "text", "text": "What is shown in this video?"}
]
}
]
```

### Audio

Set `"type": "audio"` and use `"url"` for links or `"path"` for local files.

```py
message = [
{
"role": "user",
"content": [
{"type": "audio", "url": "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac"},
{"type": "text", "text": "Transcribe the speech."}
]
}
]
```

### Mixed multiple

The `content` list accepts any combination of types. The model processes all inputs together, enabling comparisons or cross-modal reasoning.

```py
message = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57ad4ebc53e63daf11a4ddc7/master/w_1280,c_limit/kouign-amann.jpg"},
{"type": "video", "url": "https://static01.nyt.com/images/2019/10/01/dining/01Sourdough-GIF-1/01Sourdough-GIF-1-superJumbo.gif"},
{"type": "text", "text": "What does the image and video share in common?"},
],
},
{
"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57ad4ebc53e63daf11a4ddc7/master/w_1280,c_limit/kouign-amann.jpg"},
{"type": "image", "url": "https://assets.bonappetit.com/photos/57e191f49f19b4610e6b7693/master/w_1600%2Cc_limit/undefined"},
{"type": "text", "text": "What type of pastries are these?"},
],
}
]
```

## Batched

Batched inference processes multiple conversations in a single forward pass to improve throughput and efficiency. Wrap each conversation in its own list, then pass them together as a list of lists.

```py
messages = [
[
{"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57ad4ebc53e63daf11a4ddc7/master/w_1280,c_limit/kouign-amann.jpg"},
{"type": "text", "text": "What type of pastry is this?"}
]
},
],
[
{"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57e191f49f19b4610e6b7693/master/w_1600%2Cc_limit/undefined"},
{"type": "text", "text": "What type of pastry is this?"}
]
},
],
]
```

## Multi-turn

Conversations span multiple exchanges, alternating between `"user"` and `"assistant"` roles. Each turn adds a new message to the list, giving the model access to the full conversation history. This context helps the model generate more appropriate responses.

```py
message = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://assets.bonappetit.com/photos/57ad4ebc53e63daf11a4ddc7/master/w_1280,c_limit/kouign-amann.jpg"},
{"type": "text", "text": "What pastry is shown in the image?"}
]
},
{
"role": "assistant",
"content": [{"type": "text", "text": "This is kouign amann, a laminated dough pastry (i.e., dough folded with layers of butter) that also incorporates sugar between layers so that during baking the sugar caramelizes."}]
},
{
"role": "user",
"content": [
{"type": "image", "url": "https://static01.nyt.com/images/2023/07/21/multimedia/21baguettesrex-hbkc/21baguettesrex-hbkc-videoSixteenByNineJumbo1600.jpg"},
{"type": "text", "text": "Compare it to this image now."}
]
}
]
```

+ 2
- 0
docs/source/en/internal/import_utils.md View File

@@ -97,3 +97,5 @@ You can specify the following operators: `==`, `>`, `>=`, `<`, `<=`, `!=`.
[[autodoc]] utils.import_utils.define_import_structure

[[autodoc]] utils.import_utils.requires

[[autodoc]] utils.import_utils.requires_backends

+ 248
- 0
docs/source/en/model_doc/paddleocr_vl.md View File

@@ -0,0 +1,248 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on 2025.10.16 and added to Hugging Face Transformers on 2025.12.10*

# PaddleOCR-VL

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

**Huggingface Hub**: [PaddleOCR-VL](https://huggingface.co/collections/PaddlePaddle/paddleocr-vl) | **Github Repo**: [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)

**Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528)

**PaddleOCR-VL** is a SOTA and resource-efficient model tailored for document parsing. Its core component is PaddleOCR-VL-0.9B, a compact yet powerful vision-language model (VLM) that integrates a NaViT-style dynamic resolution visual encoder with the ERNIE-4.5-0.3B language model to enable accurate element recognition. This innovative model efficiently supports 109 languages and excels in recognizing complex elements (e.g., text, tables, formulas, and charts), while maintaining minimal resource consumption. Through comprehensive evaluations on widely used public benchmarks and in-house benchmarks, PaddleOCR-VL achieves SOTA performance in both page-level document parsing and element-level recognition. It significantly outperforms existing solutions, exhibits strong competitiveness against top-tier VLMs, and delivers fast inference speeds. These strengths make it highly suitable for practical deployment in real-world scenarios.

<div align="center">
<img src="https://huggingface.co/datasets/PaddlePaddle/PaddleOCR-VL_demo/resolve/main/imgs/allmetric.png" width="800"/>
</div>

### **Core Features**

1. **Compact yet Powerful VLM Architecture:** We present a novel vision-language model that is specifically designed for resource-efficient inference, achieving outstanding performance in element recognition. By integrating a NaViT-style dynamic high-resolution visual encoder with the lightweight ERNIE-4.5-0.3B language model, we significantly enhance the model’s recognition capabilities and decoding efficiency. This integration maintains high accuracy while reducing computational demands, making it well-suited for efficient and practical document processing applications.

2. **SOTA Performance on Document Parsing:** PaddleOCR-VL achieves state-of-the-art performance in both page-level document parsing and element-level recognition. It significantly outperforms existing pipeline-based solutions and exhibits strong competitiveness against leading vision-language models (VLMs) in document parsing. Moreover, it excels in recognizing complex document elements, such as text, tables, formulas, and charts, making it suitable for a wide range of challenging content types, including handwritten text and historical documents. This makes it highly versatile across many document types and scenarios.

3. **Multilingual Support:** PaddleOCR-VL supports 109 languages, covering major global languages, including but not limited to Chinese, English, Japanese, Latin, and Korean, as well as languages with different scripts and structures, such as Russian (Cyrillic script), Arabic, Hindi (Devanagari script), and Thai. This broad language coverage substantially enhances the applicability of our system to multilingual and globalized document processing scenarios.

### **Model Architecture**

<div align="center">
<img src="https://huggingface.co/datasets/PaddlePaddle/PaddleOCR-VL_demo/resolve/main/imgs/paddleocrvl.png" width="800"/>
</div>

## Usage

### Usage tips

> [!IMPORTANT]
> We currently recommend using the [PaddleOCR official method for inference](https://www.paddleocr.ai/latest/en/version3.x/pipeline_usage/PaddleOCR-VL.html), as it is faster and supports page-level document parsing.
> The example code below only supports element-level recognition.

We have four types of element-level recognition:

- Text recognition, indicated by the prompt `OCR:`.
- Formula recognition, indicated by the prompt `Formula Recognition:`.
- Table recognition, indicated by the prompt `Table Recognition:`.
- Chart recognition, indicated by the prompt `Chart Recognition:`.

The following examples are all based on text recognition, with the prompt `OCR:`.

### Single input inference

The example below demonstrates how to generate text with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`].

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
{"type": "text", "text": "OCR:"},
]
}
]
result = pipe(text=messages)
print(result[0]["generated_text"])
```

</hfoption>

<hfoption id="AutoModel">

```py
from transformers import AutoProcessor, AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
{"type": "text", "text": "OCR:"},
]
}
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)
```

</hfoption>
</hfoptions>

### Batched inference

PaddleOCRVL also supports batched inference. We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Here is how you can do it with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`]:

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
{"type": "text", "text": "OCR:"},
]
}
]
result = pipe(text=[messages, messages])
print(result[0][0]["generated_text"])
print(result[1][0]["generated_text"])
```

</hfoption>

<hfoption id="AutoModel">

```py
from transformers import AutoProcessor, AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
{"type": "text", "text": "OCR:"},
]
}
]
batch_messages = [messages, messages]
inputs = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
padding_side='left',
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=100)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(result)
```

</hfoption>
</hfoptions>

### Using Flash Attention 2

Flash Attention 2 is a faster, optimized attention implementation. For details, please refer to the [FlashAttention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention) documentation.

For example:

```shell
pip install flash-attn --no-build-isolation
```

```python
from transformers import AutoModelForImageTextToText
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2")
```

## PaddleOCRVLForConditionalGeneration

[[autodoc]] PaddleOCRVLForConditionalGeneration
- forward

## PaddleOCRVLConfig

[[autodoc]] PaddleOCRVLConfig

## PaddleOCRVisionConfig

[[autodoc]] PaddleOCRVisionConfig

## PaddleOCRTextConfig

[[autodoc]] PaddleOCRTextConfig

## PaddleOCRTextModel

[[autodoc]] PaddleOCRTextModel

## PaddleOCRVisionModel

[[autodoc]] PaddleOCRVisionModel

## PaddleOCRVLImageProcessor

[[autodoc]] PaddleOCRVLImageProcessor

## PaddleOCRVLImageProcessorFast

[[autodoc]] PaddleOCRVLImageProcessorFast

## PaddleOCRVLModel

[[autodoc]] PaddleOCRVLModel

## PaddleOCRVLProcessor

[[autodoc]] PaddleOCRVLProcessor

## PaddleOCRVisionTransformer

[[autodoc]] PaddleOCRVisionTransformer

+ 15
- 0
docs/source/en/model_doc/sam3.md View File

@@ -354,6 +354,21 @@ When running the same text prompt on multiple images, pre-compute text embedding
... print(f"Image {i+1}: {len(results['masks'])} '{text_prompt}' objects found")
```

### Custom Resolution Inference

<div class="warning">
⚠️ **Performance Note**: Custom resolutions may degrade accuracy. The model is meant to be used at 1008px resolution.
</div>

For faster inference or lower memory usage:

```python
>>> config = Sam3Config.from_pretrained("facebook/sam3")
>>> config.image_size = 560
>>> model = Sam3Model.from_pretrained("facebook/sam3", config=config).to(device)
>>> processor = Sam3Processor.from_pretrained("facebook/sam3", size={"height": 560, "width": 560})
```

### Prompt Label Conventions

SAM3 uses the following label conventions:


+ 15
- 0
docs/source/en/model_doc/sam3_video.md View File

@@ -188,6 +188,21 @@ For real-time applications, SAM3 Video supports processing video frames as they
>>> print(f"Masks are at original video resolution: {frame_0_outputs['masks'].shape}")
```

#### Custom Resolution Inference

<div class="warning">
⚠️ **Performance Note**: Custom resolutions may degrade accuracy. The model is meant to be used at 1008px resolution.
</div>

For faster inference or lower memory usage:

```python
>>> config = Sam3VideoConfig.from_pretrained("facebook/sam3")
>>> config.image_size = 560
>>> model = Sam3VideoModel.from_pretrained("facebook/sam3", config=config).to(device, dtype=torch.bfloat16)
>>> processor = Sam3VideoProcessor.from_pretrained("facebook/sam3", size={"height": 560, "width": 560})
```

## Sam3VideoConfig

[[autodoc]] Sam3VideoConfig


+ 178
- 0
docs/source/en/optimization_overview.md View File

@@ -0,0 +1,178 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Overview

Transformers provides multiple inference optimization techniques to make models fast, affordable, and accessible. Options include alternative attention mechanisms for reduced memory traffic, code compilation for faster execution, and optimized kernels for throughput. Stack these techniques for maximum performance.

> [!NOTE]
> Memory and speed are closely related but not the same. Shrinking your memory footprint makes a model "faster" because there is less data to move around. Pure speed optimizations don't always reduce memory and sometimes increase usage. Choose the appropriate optimization based on your use case and hardware.

Use the table below to pick an optimization technique.

| Technique | Speed | Memory |
|---|:---:|:---:|
| [Compilation](#compilation) | ✅ | |
| [Attention backends](#attention-backends) | ✅ | ✅ |
| [Kernels](#kernels) | ✅ | ✅ |
| [Quantization](#quantization) | ✅ | ✅ |
| [Caching](#caching) | ✅ | ✅ |
| [Parallelism](#parallelism) | ✅ | |
| [Continuous batching](#continuous-batching) | ✅ | |

This guide gives you a quick start on Transformers optimizations.

## Compilation

[torch.compile](./perf_torch_compile) reduces Python overhead, fuses operations, and creates kernels tuned for your shapes and hardware. The first run warms it up and subsequent runs use the faster compiled path.

Pass a [fixed size cache](./kv_cache#fixed-size-cache) to [`~GenerationMixin.generate`] to trigger `torch.compile` automatically.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="auto")
input = tokenizer("The French Bread Law states", return_tensors="pt").to(model.device)

output = model.generate(**input, do_sample=False, max_new_tokens=20, cache_implementation="static")
tokenizer.batch_decode(output, skip_special_tokens=True)[0]
```

> [!WARNING]
> Avoid calling `torch.compile(model)` outside of [`~GenerationMixin.generate`] to prevent the model from recompiling every step.

## Attention backends

Alternative [attention backends](./attention_interface) lower memory traffic. For example, FlashAttention tiles attention computations and avoids large intermediate tensors to reduce memory footprint.

Set `attn_implementation` in [`~PreTrainedModel.from_pretrained`] to load an optimized attention backend.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", attn_implementation="flash_attention_2")
```

## Kernels

Kernels fuse operations to boost throughput and reduce memory usage. The [Kernels](https://huggingface.co/docs/kernels/en/index) library loads optimized compute kernels from the [Hub](https://huggingface.co/kernels-community) in a flexible and version-safe way.

The example below loads an optimized FlashAttention-2 kernel without installing the package.

```py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
```

## Quantization

[Quantization](./quantization/overview) shrinks the size of every parameter, which lowers the memory footprint and can increase speed because lower-precision data is faster to move and compute with.

Pass a quantization config to the `quantization_config` argument in [`~PreTrainedModel.from_pretrained`]. Each quantization backend has a different config with different arguments. The example below quantizes a model to 4-bits and configures the computation dtype with the [bitsandbytes](./quantization/bitsandbytes) backend.

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model = AutoModelForCausalLM.from_pretrained(
"allenai/Olmo-3-7B-Think", quantization_config=bnb_config
)
```

## Caching

[Caching](./kv_cache) speeds up generation by reusing past keys and values instead of recomputing them for every token. To offset and reduce the memory cost of storing past keys and values, Transformers
supports offloading the cache to the CPU. Only the current layer remains on the GPU.

Use the `cache_implementation` argument in [`~GenerationMixin.generate`] to set a cache strategy.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
inputs = tokenizer("The Le Décret Pain states that a baguette must,", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="offloaded")
```

## Parallelism

[Parallelism](./perf_infer_gpu_multi) distributes a model across devices so models too big for one device run fast. This approach uses more memory due to sharding overhead and communication to sync results.

[Tensor parallelism](./perf_infer_gpu_multi) splits a model layer across devices. Set `tp_plan="auto"` in [`~PreTrainedModel.from_pretrained`] to enable it.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", tp_plan="auto")
print(model._tp_plan)
```

## Continuous batching

[Continuous batching](./continuous_batching) maximizes throughput by keeping the GPU busy with dynamic scheduling and chunked prefill. [Serving](./serving) applications use it to process multiple incoming requests concurrently.

Use [`~ContinuousMixin.generate_batch`] to enable continuous batching.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
attn_implementation="paged|sdpa",
device_map="cuda",
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

prompts = [
"The Le Décret Pain states that a baguette must",
"Explain gravity in one sentence.",
"Name the capital of France.",
]
inputs = [tokenizer.encode(p) for p in prompts]

generation_config = GenerationConfig(
max_new_tokens=32,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
max_batch_tokens=512,
)

outputs = model.generate_batch(
inputs=inputs,
generation_config=generation_config,
)

for request_id, output in outputs.items():
text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
print(f"[{request_id}] {text}")
```

+ 95
- 16
docs/source/en/quantization/contribute.md View File

@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Contribute

Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.
Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.

This guide will show you how to integrate a new quantization method with [`~quantizers.HfQuantizer`].

@@ -28,16 +28,16 @@ Before integrating a new quantization method into Transformers, ensure the metho
- The method can run on commonly-used hardware (CPU, GPU, etc.).
- The method is wrapped in a [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) ([`~bitsandbytes.nn.Linear8bitLt`], [`~bitsandbytes.nn.Linear4bit`]), and the quantized linear layer should have the following definition.

```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```
```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.

- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package containing the quantization kernels/primitive is stable (no frequent breaking changes).
@@ -48,23 +48,23 @@ Some quantization methods may require "pre-quantizing" the model through data ca

0. The best starting point would be to have a look at another quantization method such as Finegrained Fp8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).

1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/__init__.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.
1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [\_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/\_\_init\_\_.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.

2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer`]. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).

3. Define the following class attributes and property methods for your quantization method:

- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).

4. Write the `validate_environment` and `update_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. Refer to other quantizers for an example of how it is implemented.

5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).

You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.
You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.

6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.
6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.

7. Add the `get_weight_conversions` method to the quantizer class if the quantization supports loading pre-quantized weights. In transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when we have tensors in the checkpoint that require to be regrouped to re-create the quantized tensors.

@@ -73,3 +73,82 @@ You can define module replacement logic or any other utility method by creating
9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.

10. You should add tests by adding the package in our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how it is implemented.

## Files overview

| File | Purpose |
| -------------------------------------------- | ------------------------------------------------------------------------------------------------ |
| `utils/quantization_config.py` | Define `YourMethodConfig` inheriting from `QuantizationConfigMixin` |
| `quantizers/quantizer_your_method.py` | Implement `YourMethodHfQuantizer` inheriting from `HfQuantizer` |
| `integrations/your_method.py` | Implement `ConversionOps` subclasses and helper functions |
| `quantizers/auto.py` | Register quantizer and config in `AUTO_QUANTIZER_MAPPING` and `AUTO_QUANTIZATION_CONFIG_MAPPING` |
| `docs/source/en/quantization/your_method.md` | Document usage for users |
| `tests/quantization/your_method/` | Add integration tests |

## Understanding `get_quantize_ops` vs `get_weight_conversions`

These two methods handle different scenarios for loading weights. Understanding when to use each is essential.

### `get_quantize_ops` — Quantize on the fly

Use this when loading a **non-quantized checkpoint** (e.g., float16/bfloat16 weights) and quantizing during load.

```
Checkpoint: model.safetensors (float16 weights for example)
get_quantize_ops → YourQuantize.convert()
Result: Quantized weights in memory
```

The `convert` method receives one tensor at a time, quantizes it, and can return a dictionary of quantized params, for example:

```py
class YourQuantize(ConversionOps):
def convert(self, input_dict, model, full_layer_name, missing_keys, **kwargs):
# input_dict = {"layer.weight": <float16 tensor>}
value = list(input_dict.values())[0]
module, tensor_name = get_module_from_name(model, full_layer_name)

# Quantize and assign
quantized, scale, zero_point = your_quantize_fn(value)
return {full_layer_name: quantized, full_layer_name + ".scale": scale, full_layer_name + ".zero_point": zero_point}
```

### `get_weight_conversions` — Load pre-quantized checkpoints

Use this when loading a **pre-quantized checkpoint** where the quantized weights are saved as several separate components (such as data, scale, and zero point), and these need to be combined into one tensor during loading. Not all quantization methods require this reconstruction step: for example, some methods like FP8 simply load weights and scales as-is, without combining them. Others, such as torchao, do require reassembling the quantized tensor from its multiple saved components.

```
Checkpoint: model.safetensors (quantized components)
- layer._weight_qdata
- layer._weight_scale
- layer._weight_zero_point
get_weight_conversions → WeightConverter + YourDeserialize.convert()
Result: Reconstructed quantized tensor → layer.weight
```

The `WeightConverter` collects related tensors based on `source_patterns`, then passes them to your `convert` method:

```py
def get_weight_conversions(self):
if self.pre_quantized:
return [
WeightConverter(
source_patterns=["_weight_qdata", "_weight_scale", "_weight_zero_point"],
target_patterns="weight",
operations=[YourDeserialize(self)],
),
]
return []


class YourDeserialize(ConversionOps):
def convert(self, input_dict, model, full_layer_name, **kwargs):
# input_dict contains all collected tensors
# Reconstruct the quantized tensor from components
reconstructed_tensor = reconstruct_from_components(input_dict)
return {full_layer_name: reconstructed_tensor}
```

+ 48
- 21
docs/source/en/tasks/image_text_to_text.md View File

@@ -33,7 +33,8 @@ This guide focuses on inference with an instruction-tuned model.
Let's begin installing the dependencies.

```bash
pip install -q transformers accelerate flash_attn
pip install -q transformers accelerate
pip install flash-attn --no-build-isolation
```

Let's initialize the model and the processor.
@@ -45,12 +46,12 @@ import torch

device = Accelerator().device
model = AutoModelForImageTextToText.from_pretrained(
"HuggingFaceM4/idefics2-8b",
"Qwen/Qwen3-VL-4B-Instruct",
dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
).to(device)

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
```

This model has a [chat template](./chat_templating) that helps user parse chat outputs. Moreover, the model can also accept multiple images as input in a single conversation or message. We will now prepare the inputs.
@@ -65,24 +66,29 @@ The image inputs look like the following.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" alt="A bee on a pink flower"/>
</div>

```python
from PIL import Image
import requests

img_urls =["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"]
images = [Image.open(requests.get(img_urls[0], stream=True).raw),
Image.open(requests.get(img_urls[1], stream=True).raw)]
Structure your conversation as shown below for a single prompt with image and text inputs.

```python
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"},
{"type": "text", "text": "What do we see in this image?"},
]
}
]
```

Below is an example of the chat template. We can feed conversation turns and the last message as an input by appending it at the end of the template.
Alternate between the `user` and `assistant` role to ground the model with prior context to generate better responses.

```python
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"},
{"type": "text", "text": "What do we see in this image?"},
]
},
@@ -95,7 +101,7 @@ messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "And how about this image?"},
]
},
@@ -105,19 +111,20 @@ messages = [
We will now call the processors' [`~ProcessorMixin.apply_chat_template`] method to preprocess its output along with the image inputs.

```python
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[images[0], images[1]], return_tensors="pt").to(device)
inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(device)
```

We can now pass the preprocessed inputs to the model.

```python
input_len = len(inputs.input_ids[0])

with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_ids = model.generate(**inputs, max_new_tokens=200)
generated_texts = processor.batch_decode(generated_ids[:, input_len:], skip_special_tokens=True)

print(generated_texts)
## ['User: What do we see in this image? \nAssistant: In this image we can see two cats on the nets. \nUser: And how about this image? \nAssistant: In this image we can see flowers, plants and insect.']
## ['In this image we can see flowers, plants and insect.']
```

## Pipeline
@@ -289,19 +296,38 @@ VLMs are often large and need to be optimized to fit on smaller hardware. Transf
First, install dependencies.

```bash
pip install -U quanto bitsandbytes
pip install -U optimum-quanto bitsandbytes
```

To quantize a model during loading, we need to first create [`QuantoConfig`]. Then load the model as usual, but pass `quantization_config` during model initialization.
To quantize a model during loading, we need to first create [`QuantoConfig`]. Then load the model as usual, but pass `quantization_config` during model initialization.

```python
from transformers import AutoModelForImageTextToText, QuantoConfig

model_id = "HuggingFaceM4/idefics2-8b"
model_id = "Qwen/Qwen3-VL-4B-Instruct"
quantization_config = QuantoConfig(weights="int8")
quantized_model = AutoModelForImageTextToText.from_pretrained(
model_id, device_map="auto", quantization_config=quantization_config
)

messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"},
{"type": "text", "text": "What do we see in this image?"},
]
},
]
inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device)
input_len = len(inputs.input_ids[0])

with torch.no_grad():
generated_ids = model.generate(**inputs, cache_implementation="static", max_new_tokens=100)
generated_texts = processor.batch_decode(generated_ids[:, input_len:], skip_special_tokens=True)

print(generated_texts[0])
## ['In this image, we see two tabby cats resting on a large, tangled pile of fishing nets. The nets are a mix of brown, orange, and red colors, with some blue and green ropes visible in the background. The cats appear relaxed and comfortable, nestled into the fibers of the nets. One cat is in the foreground, looking slightly to the side, while the other is positioned further back, looking directly at the camera. The scene suggests a coastal or fishing-related setting, possibly near']
```

And that's it, we can use the model the same way with no changes.
@@ -312,3 +338,4 @@ Here are some more resources for the image-text-to-text task.

- [Image-text-to-text task page](https://huggingface.co/tasks/image-text-to-text) covers model types, use cases, datasets, and more.
- [Vision Language Models Explained](https://huggingface.co/blog/vlms) is a blog post that covers everything about vision language models and supervised fine-tuning using [TRL](https://huggingface.co/docs/trl/en/index).
- [Learn how to fine-tune vision language models using TRL](https://huggingface.co/blog/trl-vlm-alignment)

+ 334
- 19
docs/source/en/tasks/mask_generation.md View File

@@ -24,8 +24,9 @@ Mask generation models are trained on large amounts of data and operate in two m
- Prompting mode: In this mode, the model takes in an image and a prompt, where a prompt can be a 2D point location (XY coordinates) in the image within an object or a bounding box surrounding an object. In prompting mode, the model only returns the mask over the object
that the prompt is pointing out.
- Segment Everything mode: In segment everything, given an image, the model generates every mask in the image. To do so, a grid of points is generated and overlaid on the image for inference.
- Video Inference: The model accepts a video, and a point or box prompt in a video frame, which is tracked throughout the video. You can get more information on how to do video inference by following [SAM 2 docs](../model_doc/sam2).

Mask generation task is supported by [Segment Anything Model (SAM)](model_doc/sam). It's a powerful model that consists of a Vision Transformer-based image encoder, a prompt encoder, and a two-way transformer mask decoder. Images and prompts are encoded, and the decoder takes these embeddings and generates valid masks.
Mask generation task is supported by [Segment Anything Model (SAM)](../model_doc/sam) and [Segment Anything Model 2 (SAM2)](../model_doc/sam2), while video inference is supported by [Segment Anything Model 2 (SAM2)](../model_doc/sam2). SAM is a powerful model that consists of a Vision Transformer-based image encoder, a prompt encoder, and a two-way transformer mask decoder. Images and prompts are encoded, and the decoder takes these embeddings and generates valid masks. Meanwhile, SAM 2 extends SAM by adding a memory module to track the masks.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sam.png" alt="SAM Architecture"/>
@@ -53,7 +54,7 @@ The easiest way to infer mask generation models is to use the `mask-generation`
```python
>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> checkpoint = "facebook/sam2-hiera-base-plus"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
```

@@ -80,20 +81,12 @@ masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
The `masks` looks like the following:

```bash
{'masks': [array([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]),
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
'scores': tensor([0.9972, 0.9917,
...,
}
{'masks': [tensor([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False], ..
'scores': tensor([0.9874, 0.9793, 0.9780, 0.9776, ... 0.9016])}
```

We can visualize them like this:
@@ -134,7 +127,7 @@ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

To do point prompting, pass the input point to the processor, then take the processor output
and pass it to the model for inference. To post-process the model output, pass the outputs and
`original_sizes` and `reshaped_input_sizes` we take from the processor's initial output. We need to pass these
`original_sizes`, which are taken from the processor's initial output. We need to pass these
since the processor resizes the image, and the output needs to be extrapolated.

```python
@@ -143,7 +136,7 @@ input_points = [[[2592, 1728]]] # point location of the bee
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu())
```

We can visualize the three masks in the `masks` output.
@@ -199,7 +192,6 @@ with torch.no_grad():
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
```

@@ -235,3 +227,326 @@ plt.show()
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/box_inference.png" alt="Visualized Inference"/>
</div>

## Fine-tuning for Mask Generation

We will fine-tune SAM2.1 on a small part of the MicroMat dataset for image matting. We need to install the [monai](https://github.com/Project-MONAI/MONAI) library to use DICE loss, and [trackio](https://huggingface.co/docs/trackio/index) for logging the masks during training.

```bash
pip install -q datasets monai trackio
``` 
We can now load our dataset and take a look.

```python
from datasets import load_dataset

dataset = load_dataset("merve/MicroMat-mini", split="train")
dataset
# Dataset({
# features: ['image', 'mask', 'prompt', 'image_id', 'object_id', 'sample_idx', 'granularity',
# 'image_path', 'mask_path', 'prompt_path'], num_rows: 94
#})
```
We need the image, mask and prompt columns. We split the dataset into train and test sets.

```python
dataset = dataset.train_test_split(test_size=0.1)
train_ds = dataset["train"]
val_ds = dataset["test"]
```

Let's take a look at a sample.
```python
train_ds[0]
```
```
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2040x1356>,
'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=2040x1356>,
'prompt': '{"point": [[137, 1165, 1], [77, 1273, 0], [58, 1351, 0]], "bbox": [0, 701, 251, 1356]}',
'image_id': '0034',
'object_id': '34',
'sample_idx': 1,
'granularity': 'fine',
'image_path': '/content/MicroMat-mini/img/0034.png',
'mask_path': '/content/MicroMat-mini/mask/0034_34.png',
'prompt_path': '/content/MicroMat-mini/prompt/0034_34.json'}
```
Prompts are JSON-encoded dictionaries, so you can parse them to get the bounding boxes as shown below.
```python
import json

json.loads(train_ds["prompt"][0])["bbox"]
# [0, 701, 251, 1356]
```

Visualize an example image, prompt and mask.

```python
import matplotlib.pyplot as plt
import numpy as np

def show_mask(mask, ax):
    """Overlay a semi-transparent blue mask and the prompt bbox on *ax*.

    Args:
        mask: 2D array-like binary mask (H, W).
        ax: matplotlib axes to draw on.
    """
    # RGBA: blue with 60% opacity.
    color = np.array([0.12, 0.56, 1.0, 0.6])
    mask = np.array(mask)
    h, w = mask.shape
    # Broadcast (H, W, 1) mask against (1, 1, 4) color into an RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, 4)
    ax.imshow(mask_image)
    # Prompts are JSON strings; parse with json.loads instead of eval on data.
    # NOTE(review): this reads the global train_ds[0] rather than a parameter —
    # fine for the tutorial, but only correct for the first sample.
    x0, y0, x1, y1 = json.loads(train_ds["prompt"][0])["bbox"]
    ax.add_patch(
        plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                      fill=False, edgecolor="lime", linewidth=2))

example = train_ds[0]
image = np.array(example["image"])
ground_truth_mask = np.array(example["mask"])

fig, ax = plt.subplots()
ax.imshow(image)
show_mask(ground_truth_mask, ax)
ax.set_title("Ground truth mask")
ax.set_axis_off()

plt.show()
```

Now we can define our dataset for loading the data. SAMDataset wraps our dataset and formats each sample the way the SAM processor expects. So instead of raw images and masks, you get processed images, bounding boxes, and ground-truth masks ready for training.

By default, the processor resizes images, so in addition to the images and masks it also returns the original sizes. We also need to binarize the mask since its values are in [0, 255].

```python
from torch.utils.data import Dataset
import torch

class SAMDataset(Dataset):
    """Wrap a HF dataset so each item is formatted the way the SAM processor expects.

    Each item yields the processed image tensors, the bbox prompt, the binarized
    ground-truth mask and the original image size.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        # Prompts are stored as JSON strings; json.loads is safer than eval.
        prompt = json.loads(item["prompt"])["bbox"]
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # Masks come in [0, 255]; binarize to {0.0, 1.0} for the loss.
        inputs["ground_truth_mask"] = (np.array(item["mask"]) > 0).astype(np.float32)
        # PIL .size is (W, H); store as (H, W).
        inputs["original_image_size"] = torch.tensor(image.size[::-1])
        return inputs
```

We can initialize the processor and the dataset with it.

```python
from transformers import Sam2Processor

processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-small")
train_dataset = SAMDataset(dataset=train_ds, processor=processor)
``` 

We need to define a data collator that reshapes the variable-size ground truth masks into batches of masks with a common shape. We resize them using nearest-neighbor interpolation. We also build batched tensors for the rest of the elements in the batch. If your masks all have the same size, feel free to skip this step.

```python
import torch.nn.functional as F

def collate_fn(batch, target_hw=(256, 256)):
    """Collate variable-size samples into fixed-shape batched tensors.

    Ground-truth masks of varying sizes are resized to ``target_hw`` with
    nearest-neighbor interpolation so they can be concatenated into one batch.

    Args:
        batch: list of items produced by ``SAMDataset``.
        target_hw: (H, W) the ground-truth masks are resized to.
    """
    pixel_values = torch.cat([item["pixel_values"] for item in batch], dim=0)
    original_sizes = torch.stack([item["original_sizes"] for item in batch])
    input_boxes = torch.cat([item["input_boxes"] for item in batch], dim=0)
    ground_truth_masks = torch.cat([
        F.interpolate(
            torch.as_tensor(x["ground_truth_mask"]).unsqueeze(0).unsqueeze(0).float(),
            size=target_hw,  # bug fix: was hard-coded (256, 256), ignoring the parameter
            mode="nearest",
        )
        for x in batch
    ], dim=0).long()

    return {
        "pixel_values": pixel_values,
        "original_sizes": original_sizes,
        "input_boxes": input_boxes,
        "ground_truth_mask": ground_truth_masks,
        "original_image_size": torch.stack([item["original_image_size"] for item in batch]),
    }

from torch.utils.data import DataLoader
train_dataloader = DataLoader(
train_dataset,
batch_size=4,
shuffle=True,
collate_fn=collate_fn,
)
```

Let's take a look at what the data loader yields.

```python
batch = next(iter(train_dataloader))
for k,v in batch.items():
print(k,v.shape)

# pixel_values torch.Size([4, 3, 1024, 1024])
# original_sizes torch.Size([4, 1, 2])
# input_boxes torch.Size([4, 1, 4])
# ground_truth_mask torch.Size([4, 1, 256, 256])
# original_image_size torch.Size([4, 2])
```
We will now load the model and freeze the vision and the prompt encoder to only train the mask decoder.

```python
from transformers import Sam2Model

model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-small")

for name, param in model.named_parameters():
if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
param.requires_grad_(False)
``` 

We can now define the optimizer and the loss function.
```python
from torch.optim import Adam
import monai

optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
```

Let's see how the model performs before training.

```python
import matplotlib.pyplot as plt

item = val_ds[1]
img = item["image"]
bbox = json.loads(item["prompt"])["bbox"]
inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(model.device)

with torch.no_grad():
outputs = model(**inputs)

masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
preds = masks.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

overlay = np.asarray(img, dtype=np.uint8).copy()
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.axis("off")
plt.show()
```

![SAM2 result after training](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sam2_before_training.png)

We need to log our predictions to trackio so we can monitor the model improvement in the middle of the training.

```python
from PIL import Image
import trackio
import json


@torch.no_grad()
def predict_fn(img, bbox):
    """Run SAM inference for one image and bbox prompt; return post-processed masks.

    Relies on the module-level ``processor`` and ``model``.
    """
    inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(model.device)
    # The @torch.no_grad() decorator already disables gradients for the whole
    # function; the original inner `with torch.no_grad():` was redundant.
    outputs = model(**inputs)
    masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
    return masks

def log_eval_masks_trackio(dataset, indices, step, predict_fn, project=None, sample_cap=8):
    """Log green mask overlays for a few eval samples to trackio.

    Args:
        dataset: dataset with "image" and "prompt" (JSON string) columns.
        indices: sample indices to visualize; truncated to ``sample_cap``.
        step: training step/epoch, recorded under "eval/step".
        predict_fn: callable ``(img, bbox) -> masks`` returning per-mask logits.
        project: NOTE(review): accepted but never used here — trackio.init()
            is called elsewhere with the project name; confirm it can be dropped.
        sample_cap: maximum number of samples to log.
    """
    logs = {"eval/step": int(step)}
    for idx in indices[:sample_cap]:
        item = dataset[idx]
        img = item["image"]
        # Prompts are JSON strings; extract the bbox part.
        bbox = json.loads(item["prompt"])["bbox"]
        preds = predict_fn(img, bbox)
        preds = preds.squeeze(0)
        # Binarize the first mask's logits: positive values = foreground.
        mask = (preds[0] > 0).cpu().numpy()

        # Blend 45% green into masked pixels (assignment into the uint8 array
        # truncates the float result back to uint8).
        overlay = np.asarray(img, dtype=np.uint8).copy()
        overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)
        logs[f"{idx}/overlay"] = trackio.Image(overlay, caption="overlay")
    trackio.log(logs)
```
We can now write our training loop and train!

Notice how we log our loss and evaluation masks with trackio.

```python
from tqdm import tqdm
from statistics import mean
import trackio
import torch

num_epochs = 30

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# NOTE(review): the model stays in train() mode during the eval logging below;
# predict_fn disables grads but train-mode layers (e.g. dropout) stay active —
# confirm whether model.eval()/model.train() toggling is wanted there.
model.train()
trackio.init(project="mask-eval")
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Box-only prompting; multimask_output=False yields one mask per box.
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        # DiceCELoss(sigmoid=True) applies the sigmoid to the raw logits itself.
        loss = seg_loss(predicted_masks, ground_truth_masks)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        epoch_losses.append(loss.item())
    # Log overlays for a fixed set of validation samples once per epoch.
    log_eval_masks_trackio(dataset=val_ds, indices=[0, 3, 6, 9], step=epoch, predict_fn=predict_fn, project="mask-eval")
    print(f'Epoch: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
    trackio.log({"loss": mean(epoch_losses)})

trackio.finish()
```


Let's put the trained model to test.

```python
import matplotlib.pyplot as plt

item = val_ds[1]
img = item["image"]
bbox = json.loads(item["prompt"])["bbox"]

inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(model.device)

with torch.no_grad():
outputs = model(**inputs)

preds = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]

preds = preds.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

overlay = np.asarray(img, dtype=np.uint8).copy()
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.axis("off")
plt.show()
```
Great improvement after only training for 30 epochs on a small dataset!

![SAM2 result after training](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sam2_after_training.png)

+ 1
- 1
docs/source/en/tasks/semantic_segmentation.md View File

@@ -219,7 +219,7 @@ Start by loading a smaller subset of the SceneParse150 dataset from the 🤗 Dat
```py
>>> from datasets import load_dataset

>>> ds = load_dataset("scene_parse_150", split="train[:50]")
>>> ds = load_dataset("merve/scene_parse_150", split="train[:50]")
```

Split the dataset's `train` split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:


+ 92
- 74
docs/source/en/tasks/video_text_to_text.md View File

@@ -18,9 +18,13 @@ rendered properly in your Markdown viewer.

[[open-in-colab]]

Video-text-to-text models, also known as video language models or vision language models with video input, are language models that take a video input. These models can tackle various tasks, from video question answering to video captioning.
Video-text-to-text, also known as video language models are models that can process video and output text. These models can tackle various tasks, from video question answering to video captioning.

These models have nearly the same architecture as [image-text-to-text](../image_text_to_text) models except for some changes to accept video data, since video data is essentially image frames with temporal dependencies. Some image-text-to-text models take in multiple images, but this alone is inadequate for a model to accept videos. Moreover, video-text-to-text models are often trained with all vision modalities. Each example might have videos, multiple videos, images and multiple images. Some of these models can also take interleaved inputs. For example, you can refer to a specific video inside a string of text by adding a video token in text like "What is happening in this video? `<video>`".
These models have nearly the same architecture as [image-text-to-text](../image_text_to_text) models except for some changes to accept video data, since video data is essentially image frames with temporal dependencies. Some image-text-to-text models take in multiple images, but this alone is inadequate for a model to accept videos.

Moreover, video-text-to-text models are often trained with all vision modalities. Each example might have videos, multiple videos, images and multiple images. Some of these models can also take interleaved inputs. For example, you can refer to a specific video inside a string of text by adding a video token in text like "What is happening in this video? `<video>`".

Note that these models process videos with no audio. [Any-to-any](../any-to-any) models on the other hand can process videos with audio in them.

In this guide, we provide a brief overview of video LMs and show how to use them with Transformers for inference.

@@ -30,81 +34,27 @@ To begin with, there are multiple types of video LMs:
- chat fine-tuned models for conversation
- instruction fine-tuned models

This guide focuses on inference with an instruction-tuned model, [llava-hf/llava-interleave-qwen-7b-hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) which can take in interleaved data. Alternatively, you can try [llava-interleave-qwen-0.5b-hf](https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf) if your hardware doesn't allow running a 7B model.
This guide focuses on inference with an instruction-tuned model, [llava-hf/llava-onevision-qwen2-0.5b-ov-hf](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) which can take in interleaved data. Alternatively, you can try [llava-interleave-qwen-0.5b-hf](https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf) if your hardware doesn't allow running a 7B model.

Let's begin installing the dependencies.

```bash
pip install -q transformers accelerate flash_attn
pip install -q transformers accelerate flash_attn torchcodec
```

Let's initialize the model and the processor.

```python
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"

processor = LlavaProcessor.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id, device="cuda")

model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", dtype=torch.float16)
```

Some models directly consume the `<video>` token, and others accept `<image>` tokens equal to the number of sampled frames. This model handles videos in the latter fashion. We will write a simple utility to handle image tokens, and another utility to get a video from a url and sample frames from it.

```python
import uuid
import requests
import cv2
from PIL import Image

def replace_video_with_images(text, frames):
    """Expand every ``<video>`` placeholder in *text* into *frames* ``<image>`` tokens."""
    image_tokens = "<image>" * frames
    return text.replace("<video>", image_tokens)

def sample_frames(url, num_frames):
    """Download a video from *url* and return up to *num_frames* evenly spaced PIL frames."""
    import os

    response = requests.get(url)
    path = f"./{uuid.uuid4()}.mp4"
    with open(path, "wb") as f:
        f.write(response.content)

    video = cv2.VideoCapture(path)
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against videos with fewer frames than requested: the original
        # `total_frames // num_frames` could be 0, crashing the modulo below.
        interval = max(total_frames // num_frames, 1)
        frames = []
        for i in range(total_frames):
            ret, frame = video.read()
            # Bug fix: check `ret` BEFORE converting — on a failed read `frame`
            # is None and cv2.cvtColor raises.
            if not ret:
                continue
            if i % interval == 0:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        video.release()
        os.remove(path)  # bug fix: the temp download was never cleaned up
    return frames[:num_frames]
```

Let's get our inputs. We will sample frames and concatenate them.

```python
video_1 = "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_1.mp4"
video_2 = "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_2.mp4"

video_1 = sample_frames(video_1, 6)
video_2 = sample_frames(video_2, 6)

videos = video_1 + video_2

videos

# [<PIL.Image.Image image mode=RGB size=1920x1080>,
# <PIL.Image.Image image mode=RGB size=1920x1080>,
# <PIL.Image.Image image mode=RGB size=1920x1080>, ...]
```

Both videos have cats.
We will infer with two videos, both have cats.

<div class="container">
<div class="video-container">
@@ -120,28 +70,96 @@ Both videos have cats.
</div>
</div>

Now we can preprocess the inputs.

This model has a prompt template that looks like following. First, we'll put all the sampled frames into one list. Since we have eight frames in each video, we will insert 12 `<image>` tokens to our prompt. Add `assistant` at the end of the prompt to trigger the model to give answers. Then we can preprocess.
Videos are series of image frames. Depending on the hardware limitations, downsampling is required. If the number of sampled frames is too low, prediction quality will suffer.


Video-text-to-text models have processors with video processor abstracted in them. You can pass video inference related arguments to [`~ProcessorMixin.apply_chat_template`] function.

> [!WARNING]
> You can learn more about video processors [here](../main_classes/video_processor).

We can define our chat history, passing in video with a URL like below.
```python
user_prompt = "Are these two cats in these two videos doing the same thing?"
toks = "<image>" * 12
prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
inputs = processor(text=prompt, images=videos, return_tensors="pt").to(model.device, model.dtype)
messages = [
{
"role": "user",
"content": [
{"type": "video", "video": "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_1.mp4"},
{"type": "text", "text": "Describe what is happening in this video."},
],
}
]
```

We can now call [`~GenerationMixin.generate`] for inference. The model outputs the question in our input and answer, so we only take the text after the prompt and `assistant` part from the model output.
You can preprocess the videos by passing in messages, setting `do_sample_frames` to True and passing in `num_frames`. Here we sample 10 frames.

```python
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True)[len(user_prompt)+10:])
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
num_frames=10,
do_sample_frames=True
)
inputs.to(model.device)
```
The inputs contain `input_ids` for the tokenized text, `pixel_values_videos` for the 10 sampled frames, and `attention_mask` indicating which tokens the model should attend to.

# The first cat is shown in a relaxed state, with its eyes closed and a content expression, while the second cat is shown in a more active state, with its mouth open wide, possibly in a yawn or a vocalization.
We can now infer with our preprocessed inputs and decode them.

```python
generated_ids = model.generate(**inputs, max_new_tokens=128)
input_length = len(inputs["input_ids"][0])
output_text = processor.batch_decode(
generated_ids[:, input_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
output_text = processor.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text[0])

#"The video features a fluffy, long-haired cat with a mix of brown and white fur, lying on a beige carpeted floor. The cat's eyes are wide open, and its whiskers are prominently visible. The cat appears to be in a relaxed state, with its head slightly"
```

You can also interleave multiple videos with text directly in chat template like below.

```python
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Here's a video."},
{"type": "video", "video": "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_1.mp4"},
{"type": "text", "text": "Here's another video."},
{"type": "video", "video": "https://huggingface.co/spaces/merve/llava-interleave/resolve/main/cats_2.mp4"},
{"type": "text", "text": "Describe similarities in these videos."},
],
}
]
```

And voila!
The inference remains the same as the previous example.

To learn more about chat templates and token streaming for video-text-to-text models, refer to the [image-text-to-text](../tasks/image_text_to_text) task guide because these models work similarly.
```python
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
num_frames=100,
do_sample_frames=True
)
inputs.to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=50)
input_length = len(inputs["input_ids"][0])
output_text = processor.batch_decode(
generated_ids[:, input_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
#['Both videos feature a cat with a similar appearance, characterized by a fluffy white coat with black markings, a pink nose, and a pink tongue. The cat\'s eyes are wide open, and it appears to be in a state of alertness or excitement. ']
```

+ 3
- 10
examples/modular-transformers/configuration_duplicated_method.py View File

@@ -8,7 +8,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class DuplicatedMethodConfig(PreTrainedConfig):
@@ -129,7 +129,7 @@ class DuplicatedMethodConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@@ -157,14 +157,7 @@ class DuplicatedMethodConfig(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 3
- 10
examples/modular-transformers/configuration_my_new_model.py View File

@@ -8,7 +8,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class MyNewModelConfig(PreTrainedConfig):
@@ -165,7 +165,7 @@ class MyNewModelConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias=True,
@@ -194,14 +194,7 @@ class MyNewModelConfig(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 3
- 10
examples/modular-transformers/configuration_my_new_model2.py View File

@@ -7,7 +7,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class MyNewModel2Config(PreTrainedConfig):
@@ -68,7 +68,7 @@ class MyNewModel2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@@ -96,14 +96,7 @@ class MyNewModel2Config(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 15
- 19
examples/modular-transformers/configuration_new_model.py View File

@@ -6,7 +6,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Example where we only want to overwrite the defaults of an init

from ...configuration_utils import PreTrainedConfig, layer_type_validation

from ...configuration_utils import PreTrainedConfig


class NewModelConfig(PreTrainedConfig):
@@ -59,14 +60,14 @@ class NewModelConfig(PreTrainedConfig):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
layer_types (`list`, *optional*):
Attention pattern for each layer.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.

@@ -116,20 +117,12 @@ class NewModelConfig(PreTrainedConfig):
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
use_bidirectional_attention=False,
layer_types=None,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@@ -142,15 +135,18 @@ class NewModelConfig(PreTrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_bidirectional_attention = use_bidirectional_attention
self.rope_parameters = rope_parameters

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)]
layer_type_validation(self.layer_types, self.num_hidden_layers)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

@property
def num_heads(self):


+ 3
- 0
examples/modular-transformers/modeling_add_function.py View File

@@ -10,6 +10,8 @@ from typing import Optional
import torch
from torch import nn

from ...integrations import use_kernel_func_from_hub


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -18,6 +20,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.



+ 56
- 164
examples/modular-transformers/modeling_dummy_bert.py View File

@@ -10,24 +10,20 @@ from typing import Optional, Union
import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from .configuration_dummy_bert import DummyBertConfig


if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask


class DummyBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -106,7 +102,7 @@ def eager_attention_forward(
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask

@@ -148,7 +144,7 @@ class DummyBertSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -160,14 +156,14 @@ class DummyBertSelfAttention(nn.Module):
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# decoder-only dummy_bert can have a simple dynamic cache for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
current_past_key_values = past_key_values
if isinstance(past_key_values, EncoderDecoderCache):
current_past_key_values = past_key_values.self_attention_cache

# save all key/value_layer to cache to be re-used for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer, value_layer = current_past_key_values.update(
key_layer,
value_layer,
self.layer_idx,
@@ -221,7 +217,7 @@ class DummyBertCrossAttention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
# determine input shapes
@@ -234,22 +230,22 @@ class DummyBertCrossAttention(nn.Module):
# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
if past_key_value is not None and is_updated:
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].values
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# save all states to the cache
key_layer, value_layer = past_key_value.cross_attention_cache.update(
key_layer, value_layer = past_key_values.cross_attention_cache.update(
key_layer, value_layer, self.layer_idx
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_value.is_updated[self.layer_idx] = True
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -290,25 +286,6 @@ class DummyBertAttention(nn.Module):
attention_class = DummyBertCrossAttention if is_cross_attention else DummyBertSelfAttention
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = DummyBertSelfOutput(config)
self.pruned_heads = set()

def prune_heads(self, heads):
    """Remove the given attention heads from this attention module.

    Shrinks the query/key/value projections (and the matching input
    columns of the output projection) in place, then records the pruned
    head indices so repeated pruning calls stay consistent.
    """
    if not heads:
        return

    attn = self.self
    heads, keep_index = find_pruneable_heads_and_indices(
        heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
    )

    # Shrink the projections: q/k/v lose output rows, the dense output
    # projection loses the corresponding input columns (dim=1).
    for proj_name in ("query", "key", "value"):
        setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), keep_index))
    self.output.dense = prune_linear_layer(self.output.dense, keep_index, dim=1)

    # Keep the head bookkeeping in sync with the new projection sizes.
    attn.num_attention_heads = attn.num_attention_heads - len(heads)
    attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
    self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
@@ -316,7 +293,7 @@ class DummyBertAttention(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -325,7 +302,7 @@ class DummyBertAttention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -388,14 +365,14 @@ class DummyBertLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
self_attention_output, _ = self.attention(
hidden_states,
attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -413,7 +390,7 @@ class DummyBertLayer(GradientCheckpointingLayer):
None, # attention_mask
encoder_hidden_states,
encoder_attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
**kwargs,
)
attention_output = cross_attention_output
@@ -452,7 +429,7 @@ class DummyBertEncoder(nn.Module):
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_values,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -503,7 +480,6 @@ class DummyBertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

def forward(self, hidden_states):
@@ -527,21 +503,12 @@ class DummyBertPreTrainedModel(PreTrainedModel):
"cross_attentions": DummyBertCrossAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.zero_()
super()._init_weights(module)
if isinstance(module, DummyBertLMPredictionHead):
init.zeros_(module.bias)


@auto_docstring(
@@ -582,14 +549,6 @@ class DummyBertModel(DummyBertPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
    """Prune attention heads of the model.

    Args:
        heads_to_prune: mapping of ``{layer_num: [heads to prune in this layer]}``;
            see the base ``PreTrainedModel`` class.
    """
    layers = self.encoder.layer
    for layer_idx in heads_to_prune:
        layers[layer_idx].attention.prune_heads(heads_to_prune[layer_idx])

@check_model_inputs
@auto_docstring
def forward(
@@ -615,19 +574,22 @@ class DummyBertModel(DummyBertPreTrainedModel):
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if input_ids is not None:
device = input_ids.device
input_shape = input_ids.shape
seq_length = input_ids.shape[1]
else:
device = inputs_embeds.device
input_shape = inputs_embeds.shape[:-1]
seq_length = inputs_embeds.shape[1]

seq_length = input_shape[1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
@@ -641,7 +603,6 @@ class DummyBertModel(DummyBertPreTrainedModel):
)

attention_mask, encoder_attention_mask = self._create_attention_masks(
input_shape=input_shape,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
embedding_output=embedding_output,
@@ -672,7 +633,6 @@ class DummyBertModel(DummyBertPreTrainedModel):

def _create_attention_masks(
self,
input_shape,
attention_mask,
encoder_attention_mask,
embedding_output,
@@ -680,95 +640,27 @@ class DummyBertModel(DummyBertPreTrainedModel):
cache_position,
past_key_values,
):
if attention_mask is not None and attention_mask.dim() == 2:
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = self._update_full_mask(
attention_mask,
embedding_output,
)
elif attention_mask is not None and attention_mask.dim() == 3:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

if encoder_attention_mask is not None:
if encoder_attention_mask.dim() == 2:
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
embedding_output.shape[:2],
embedding_output,
)
else:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
encoder_attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
)

return attention_mask, encoder_attention_mask

def _update_full_mask(
    self,
    attention_mask: Union[torch.Tensor, None],
    inputs_embeds: torch.Tensor,
):
    """Adapt a 2D padding mask to the form expected by the configured attention backend.

    Flash backends keep the 2D mask (or drop it entirely when nothing is
    masked); sdpa/eager expand it to a 4D additive mask; flex attention
    converts it to a block mask.
    """
    if attention_mask is not None:
        if "flash" in self.config._attn_implementation:
            # Drop the mask when there is no padding (no 0 entries) so the
            # flash kernel can take its unmasked fast path.
            attention_mask = attention_mask if 0 in attention_mask else None
        elif self.config._attn_implementation == "sdpa":
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
        elif self.config._attn_implementation == "flex_attention":
            # NOTE(review): non-tensor masks (already-built block masks,
            # presumably) are passed through unchanged — confirm with callers.
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
        else:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

    return attention_mask

def _update_cross_attn_mask(
    self,
    encoder_hidden_states: Union[torch.Tensor, None],
    encoder_attention_mask: Union[torch.Tensor, None],
    input_shape: torch.Size,
    inputs_embeds: torch.Tensor,
):
    """Adapt the encoder (cross-attention) padding mask to the configured backend.

    Only acts when both encoder states and a mask are provided; the target
    length of the expanded mask is the decoder query length
    (``input_shape[-1]``), while the source length stays the encoder's.
    """
    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        if "flash" in self.config._attn_implementation:
            # No padding at all -> no mask, letting the flash kernel take
            # its unmasked fast path.
            encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
        elif self.config._attn_implementation == "sdpa":
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                encoder_attention_mask,
                inputs_embeds.dtype,
                tgt_len=input_shape[-1],
            )
        elif self.config._attn_implementation == "flex_attention":
            # NOTE(review): non-tensor masks are passed through unchanged —
            # presumably already-built block masks; confirm with callers.
            if isinstance(encoder_attention_mask, torch.Tensor):
                encoder_attention_mask = make_flex_block_causal_mask(
                    encoder_attention_mask,
                    query_length=input_shape[-1],
                    is_causal=False,
                )
        else:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

    return encoder_attention_mask

+ 12
- 42
examples/modular-transformers/modeling_from_uppercase_model.py View File

@@ -4,6 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from collections.abc import Callable
from typing import Optional, Union

@@ -13,6 +14,8 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
from .configuration_from_uppercase_model import FromUppercaseModelTextConfig, FromUppercaseModelVisionConfig


@@ -24,8 +27,7 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
output_attentions: bool = True,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
@@ -35,8 +37,6 @@ def eager_attention_forward(

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -67,8 +67,7 @@ class FromUppercaseModelAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -81,15 +80,6 @@ class FromUppercaseModelAttention(nn.Module):
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -101,17 +91,14 @@ class FromUppercaseModelAttention(nn.Module):
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -143,27 +130,15 @@ class FromUppercaseModelEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
**kwargs: Unpack[TransformersKwargs],
) -> torch.FloatTensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -172,9 +147,4 @@ class FromUppercaseModelEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (attn_weights,)

return outputs
return hidden_states

+ 5
- 2
examples/modular-transformers/modeling_global_indexing.py View File

@@ -13,6 +13,7 @@ from torch import nn
from transformers.modeling_utils import AttentionInterface

from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
from .configuration_global_indexing import GlobalIndexingConfig
@@ -25,6 +26,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -100,6 +102,7 @@ ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex


@use_kernelized_func(apply_rotary_pos_emb)
class GlobalIndexingAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -129,8 +132,8 @@ class GlobalIndexingAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 29
- 177
examples/modular-transformers/modeling_multimodal2.py View File

@@ -17,7 +17,9 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, torch_int
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, torch_int
from ...utils.generic import check_model_inputs
from .configuration_multimodal2 import Multimodal2Config, Multimodal2TextConfig, Multimodal2VisionConfig


@@ -29,8 +31,7 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
output_attentions: bool = True,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
@@ -40,8 +41,6 @@ def eager_attention_forward(

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -72,8 +71,7 @@ class Multimodal2VisionAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -86,15 +84,6 @@ class Multimodal2VisionAttention(nn.Module):
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -106,17 +95,14 @@ class Multimodal2VisionAttention(nn.Module):
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -135,86 +121,11 @@ class Multimodal2VisionMLP(nn.Module):
return hidden_states


class Multimodal2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[Multimodal2VisionConfig, Multimodal2TextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Integer division above would silently truncate; fail loudly instead.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # [bsz, seq, embed] -> [bsz, heads, seq, head_dim]
        queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        # MULTIMODAL2 text model uses both `causal_attention_mask` and `attention_mask`
        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
        if self.config._attn_implementation == "flash_attention_2":
            # NOTE: mutates module state so the FA2 kernel below sees the flag.
            self.is_causal = causal_attention_mask is not None
        else:
            # For non-FA2 backends the two additive masks are merged into one.
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask

        # Dispatch to the configured backend; eager is the reference path.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            # Dropout only at training time.
            dropout=0.0 if not self.training else self.dropout,
            output_attentions=output_attentions,
        )

        # [bsz, seq, heads, head_dim] -> [bsz, seq, embed], then final projection.
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Multimodal2Attention(config)
self.self_attn = Multimodal2VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Multimodal2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -223,27 +134,15 @@ class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
**kwargs: Unpack[TransformersKwargs],
) -> torch.FloatTensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -252,12 +151,7 @@ class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (attn_weights,)

return outputs
return hidden_states


class Multimodal2VisionEncoder(nn.Module):
@@ -279,9 +173,7 @@ class Multimodal2VisionEncoder(nn.Module):
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Args:
@@ -296,53 +188,17 @@ class Multimodal2VisionEncoder(nn.Module):
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)


@@ -444,15 +300,9 @@ class Multimodal2VisionTransformer(nn.Module):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

if pixel_values is None:
raise ValueError("You have to specify pixel_values")

@@ -461,8 +311,7 @@ class Multimodal2VisionTransformer(nn.Module):

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

last_hidden_state = encoder_outputs.last_hidden_state
@@ -472,8 +321,6 @@ class Multimodal2VisionTransformer(nn.Module):
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -481,12 +328,18 @@ class Multimodal2VisionTransformer(nn.Module):
class Multimodal2VisionPreTrainedModel(PreTrainedModel):
config: Multimodal2Config
base_model_prefix = "multimodal2_vision"
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Multimodal2VisionEncoderLayer,
"attentions": Multimodal2VisionAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Multimodal2VisionMLP):
@@ -500,6 +353,7 @@ MULTIMODAL2_VISION_START_DOCSTRING = "doc"
class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
config: Multimodal2VisionConfig
main_input_name = "pixel_values"
input_modalities = ("image",)
_no_split_modules = ["Multimodal2VisionEncoderLayer"]

def __init__(self, config: Multimodal2VisionConfig):
@@ -511,14 +365,13 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

@can_return_tuple
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""
Example:
@@ -543,7 +396,6 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):

return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
**kwargs,
)

+ 6
- 2
examples/modular-transformers/modeling_my_new_model2.py View File

@@ -10,8 +10,10 @@ from typing import Optional
import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -62,6 +64,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -127,6 +130,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class MyNewModel2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -260,12 +264,12 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
"attentions": MyNewModel2Attention,
}

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.zero_()
init.zeros_(module.weight)


class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):


+ 2
- 37
examples/modular-transformers/modeling_new_task_model.py View File

@@ -87,27 +87,17 @@ class NewTaskModelMultiModalProjector(nn.Module):
@auto_docstring
class NewTaskModelPreTrainedModel(PreTrainedModel):
config: NewTaskModelConfig
base_model_prefix = ""
base_model_prefix = "model"
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_no_split_modules = ["NewTaskModelMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"

_can_compile_fullgraph = False
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True

def _init_weights(self, module):
# important: this ported version of NewTaskModel isn't meant for training from scratch - only
# inference and fine-tuning
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()


def token_type_ids_mask_function(
token_type_ids: Optional[torch.Tensor],
@@ -249,12 +239,6 @@ class NewTaskModelModel(NewTaskModelPreTrainedModel):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -457,28 +441,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(


+ 56
- 164
examples/modular-transformers/modeling_roberta.py View File

@@ -10,24 +10,20 @@ from typing import Optional, Union
import torch
import torch.nn as nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from .configuration_roberta import RobertaConfig


if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask


class RobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -109,7 +105,7 @@ def eager_attention_forward(
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask

@@ -151,7 +147,7 @@ class RobertaSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -163,14 +159,14 @@ class RobertaSelfAttention(nn.Module):
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# decoder-only roberta can have a simple dynamic cache for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
current_past_key_values = past_key_values
if isinstance(past_key_values, EncoderDecoderCache):
current_past_key_values = past_key_values.self_attention_cache

# save all key/value_layer to cache to be re-used for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer, value_layer = current_past_key_values.update(
key_layer,
value_layer,
self.layer_idx,
@@ -224,7 +220,7 @@ class RobertaCrossAttention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
# determine input shapes
@@ -237,22 +233,22 @@ class RobertaCrossAttention(nn.Module):
# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
if past_key_value is not None and is_updated:
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].values
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# save all states to the cache
key_layer, value_layer = past_key_value.cross_attention_cache.update(
key_layer, value_layer = past_key_values.cross_attention_cache.update(
key_layer, value_layer, self.layer_idx
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_value.is_updated[self.layer_idx] = True
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -293,25 +289,6 @@ class RobertaAttention(nn.Module):
attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = RobertaSelfOutput(config)
self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)

# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
@@ -319,7 +296,7 @@ class RobertaAttention(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -328,7 +305,7 @@ class RobertaAttention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -391,14 +368,14 @@ class RobertaLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
self_attention_output, _ = self.attention(
hidden_states,
attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -416,7 +393,7 @@ class RobertaLayer(GradientCheckpointingLayer):
None, # attention_mask
encoder_hidden_states,
encoder_attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
**kwargs,
)
attention_output = cross_attention_output
@@ -455,7 +432,7 @@ class RobertaEncoder(nn.Module):
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_values,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -506,7 +483,6 @@ class RobertaLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

def forward(self, hidden_states):
@@ -530,21 +506,12 @@ class RobertaPreTrainedModel(PreTrainedModel):
"cross_attentions": RobertaCrossAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.zero_()
super()._init_weights(module)
if isinstance(module, RobertaLMPredictionHead):
init.zeros_(module.bias)


@auto_docstring(
@@ -585,14 +552,6 @@ class RobertaModel(RobertaPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@check_model_inputs
@auto_docstring
def forward(
@@ -615,19 +574,22 @@ class RobertaModel(RobertaPreTrainedModel):
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if input_ids is not None:
device = input_ids.device
input_shape = input_ids.shape
seq_length = input_ids.shape[1]
else:
device = inputs_embeds.device
input_shape = inputs_embeds.shape[:-1]
seq_length = inputs_embeds.shape[1]

seq_length = input_shape[1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
@@ -641,7 +603,6 @@ class RobertaModel(RobertaPreTrainedModel):
)

attention_mask, encoder_attention_mask = self._create_attention_masks(
input_shape=input_shape,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
embedding_output=embedding_output,
@@ -672,7 +633,6 @@ class RobertaModel(RobertaPreTrainedModel):

def _create_attention_masks(
self,
input_shape,
attention_mask,
encoder_attention_mask,
embedding_output,
@@ -680,95 +640,27 @@ class RobertaModel(RobertaPreTrainedModel):
cache_position,
past_key_values,
):
if attention_mask is not None and attention_mask.dim() == 2:
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = self._update_full_mask(
attention_mask,
embedding_output,
)
elif attention_mask is not None and attention_mask.dim() == 3:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

if encoder_attention_mask is not None:
if encoder_attention_mask.dim() == 2:
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
embedding_output.shape[:2],
embedding_output,
)
else:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
encoder_attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
)

return attention_mask, encoder_attention_mask

def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if "flash" in self.config._attn_implementation:
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

return attention_mask

def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if "flash" in self.config._attn_implementation:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)

return encoder_attention_mask

+ 44
- 13
examples/modular-transformers/modeling_super.py View File

@@ -14,13 +14,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from ...utils.generic import check_model_inputs, maybe_autocast
from .configuration_super import SuperConfig


@@ -50,20 +50,49 @@ class SuperRotaryEmbedding(nn.Module):

def __init__(self, config: SuperConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_parameters") and isinstance(config.rope_parameters, dict):
self.rope_type = config.rope_parameters.get("rope_type", config.rope_parameters.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
self.original_inv_freq = inv_freq

@staticmethod
def compute_default_rope_parameters(
config: Optional[SuperConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
base = config.rope_parameters["rope_theta"]
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -72,7 +101,7 @@ class SuperRotaryEmbedding(nn.Module):
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
@@ -104,6 +133,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -169,6 +199,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class SuperAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -198,8 +229,8 @@ class SuperAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 5
- 2
examples/modular-transformers/modeling_switch_function.py View File

@@ -12,6 +12,7 @@ import torch
from torch import nn

from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
@@ -26,6 +27,7 @@ def rotate_half(x):
return rot_x


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -91,6 +93,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class SwitchFunctionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -120,8 +123,8 @@ class SwitchFunctionAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 31
- 26
examples/modular-transformers/modeling_test_detr.py View File

@@ -4,6 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_test_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

import math
import warnings
from dataclasses import dataclass
@@ -13,6 +14,7 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
@@ -203,10 +205,10 @@ def replace_batch_norm(model):
new_module = TestDetrFrozenBatchNorm2d(module.num_features)

if module.weight.device != torch.device("meta"):
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
new_module.weight.copy_(module.weight)
new_module.bias.copy_(module.bias)
new_module.running_mean.copy_(module.running_mean)
new_module.running_var.copy_(module.running_var)

model._modules[name] = new_module

@@ -810,6 +812,7 @@ class TestDetrPreTrainedModel(PreTrainedModel):
config: TestDetrConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = True
_no_split_modules = [
r"TestDetrConvEncoder",
@@ -817,14 +820,15 @@ class TestDetrPreTrainedModel(PreTrainedModel):
r"TestDetrDecoderLayer",
]

@torch.no_grad()
def _init_weights(self, module):
std = self.config.init_std

if isinstance(module, TestDetrLearnedPositionEmbedding):
nn.init.uniform_(module.row_embeddings.weight)
nn.init.uniform_(module.column_embeddings.weight)
init.uniform_(module.row_embeddings.weight)
init.uniform_(module.column_embeddings.weight)
elif isinstance(module, TestDetrMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
init.constant_(module.sampling_offsets.weight, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
2.0 * math.pi / module.n_heads
@@ -837,27 +841,28 @@ class TestDetrPreTrainedModel(PreTrainedModel):
)
for i in range(module.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(module.attention_weights.weight.data, 0.0)
nn.init.constant_(module.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(module.value_proj.weight.data)
nn.init.constant_(module.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
init.constant_(module.attention_weights.weight, 0.0)
init.constant_(module.attention_weights.bias, 0.0)
init.xavier_uniform_(module.value_proj.weight)
init.constant_(module.value_proj.bias, 0.0)
init.xavier_uniform_(module.output_proj.weight)
init.constant_(module.output_proj.bias, 0.0)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.normal_(mean=0.0, std=std)
init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()
init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
init.normal_(module.weight, mean=0.0, std=std)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
init.zeros_(module.weight[module.padding_idx])
if hasattr(module, "reference_points") and not self.config.two_stage:
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
init.xavier_uniform_(module.reference_points.weight, gain=1.0)
init.constant_(module.reference_points.bias, 0.0)
if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed)
init.normal_(module.level_embed)


class TestDetrEncoder(TestDetrPreTrainedModel):
@@ -924,6 +929,7 @@ class TestDetrEncoder(TestDetrPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
Args:
@@ -1046,6 +1052,7 @@ class TestDetrDecoder(TestDetrPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
Args:
@@ -1267,9 +1274,6 @@ class TestDetrModel(TestDetrPreTrainedModel):

self.post_init()

def get_encoder(self):
return self.encoder

def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
param.requires_grad_(False)
@@ -1379,6 +1383,7 @@ class TestDetrModel(TestDetrPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple[torch.FloatTensor], TestDetrModelOutput]:
r"""
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):


+ 250
- 0
examples/modular-transformers/modeling_test_suffix.py View File

@@ -0,0 +1,250 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_test_suffix.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_test_suffix.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
from .configuration_test_suffix import TestSuffixLlamaConfig


class TestSuffixDecoderLayer(nn.Module):
    """Placeholder decoder layer for the test-suffix modular example.

    Bug fix: the base class was written as ``nn.module`` (lowercase), which is
    not an attribute of ``torch.nn`` and raises ``AttributeError`` as soon as
    the module is imported. The correct base class is ``nn.Module``.
    """

    pass


@use_kernel_forward_from_hub("RMSNorm")
class TestSuffixLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        TestSuffixLlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Learnable per-channel scale; initialized to ones so the norm starts as identity scaling.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the RMS statistics in float32 for numerical stability,
        # then cast back to the caller's dtype before applying the scale.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class TestSuffixLlamaMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        # Two parallel projections feed the gated activation; one projects back down.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    # Swap the halves and negate the (new) leading half: (a, b) -> (-b, a).
    return torch.cat((-second_half, first_half), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they broadcast
            against `q` and `k`. Use 1 for tensors shaped [batch, heads, seq, head_dim]
            (cos/sin are [batch, seq, head_dim]); use 2 for [batch, seq, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Insert a broadcast axis so cos/sin line up with the head dimension.
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Fast path: nothing to repeat, return the input unchanged (same object, no copy).
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # expand creates a broadcast view (no copy); reshape then materializes the repeated heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (eager) scaled dot-product attention with grouped-query KV expansion.

    Returns `(attn_output, attn_weights)`; `attn_output` is transposed back to
    (batch, seq_len, heads, head_dim) layout and made contiguous.
    """
    # Expand key/value heads so each query head has a matching KV head.
    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the (causal) mask to the current key length before adding it to the scores.
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)
    attn_output = torch.matmul(probabilities, expanded_values).transpose(1, 2).contiguous()

    return attn_output, probabilities


@use_kernelized_func(apply_rotary_pos_emb)
class TestSuffixLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: TestSuffixLlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_attention_heads if the config does not set `head_dim` explicitly.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Grouped-query attention: number of query heads served by each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute self-attention over `hidden_states`; returns `(attn_output, attn_weights)`."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and reshape to per-head layout: (batch, heads, seq_len, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Apply RoPE before the cache update so cached keys are already rotated.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend; default to the eager reference implementation.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class TestSuffixLlamaDecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: pre-norm self-attention and a pre-norm gated MLP, each wrapped in a residual connection."""

    def __init__(self, config: TestSuffixLlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = TestSuffixLlamaAttention(config=config, layer_idx=layer_idx)

        self.mlp = TestSuffixLlamaMLP(config)
        self.input_layernorm = TestSuffixLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = TestSuffixLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Self-attention sub-block: normalize, attend, add back the residual stream.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: normalize, apply the gated MLP, add back the residual stream.
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states

+ 8
- 1
examples/modular-transformers/modular_multimodal2.py View File

@@ -35,6 +35,7 @@ class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
def __init__(self, config):
super().__init__()
self.mlp = Multimodal2VisionMLP(config)
self.self_attn = Multimodal2VisionAttention(config)


class Multimodal2VisionEncoder(CLIPEncoder):
@@ -43,7 +44,8 @@ class Multimodal2VisionEncoder(CLIPEncoder):
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])


# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder arg should use it as well
# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder and attn arg should
# use it as well
class Multimodal2VisionTransformer(CLIPVisionTransformer):
def __init__(self, config):
super().__init__(config)
@@ -51,6 +53,11 @@ class Multimodal2VisionTransformer(CLIPVisionTransformer):


class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel):
_can_record_outputs = {
"hidden_states": Multimodal2VisionEncoderLayer,
"attentions": Multimodal2VisionAttention,
}

def _init_weights(self, module):
if isinstance(module, Multimodal2VisionMLP):
pass


+ 1
- 2
examples/modular-transformers/modular_new_model.py View File

@@ -23,11 +23,10 @@ class NewModelConfig(GemmaConfig):
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
use_bidirectional_attention=False,
layer_types=None,
**kwargs,
):
super().__init__(self, **kwargs)


+ 12
- 0
examples/modular-transformers/modular_test_suffix.py View File

@@ -0,0 +1,12 @@
import torch.nn as nn

from transformers.models.llama.modeling_llama import LlamaDecoderLayer


class TestSuffixDecoderLayer(nn.module):
pass


# Here, we want to add "Llama" as a suffix to the base `TestModel` name for all required dependencies
class TestSuffixLlamaDecoderLayer(LlamaDecoderLayer):
pass

+ 12
- 2
examples/pytorch/continuous_batching.py View File

@@ -182,13 +182,16 @@ if __name__ == "__main__":

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument(
"--input-length", type=int, default=None, help="Length of input sequences. Leave to None to mimic real eval."
)
parser.add_argument("--max-new-tokens", type=int, default=512, help="Maximum number of new tokens to generate")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
@@ -251,6 +254,12 @@ if __name__ == "__main__":
else:
possible_prefixes = [None]

tokenizer_kwargs = {"add_generation_prompt": True}
if args.input_length is not None:
tokenizer_kwargs["max_length"] = args.input_length
tokenizer_kwargs["truncation"] = True
tokenizer_kwargs["padding"] = True

batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
@@ -261,7 +270,7 @@ if __name__ == "__main__":
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
inputs = tokenizer.apply_chat_template(messages, **tokenizer_kwargs)
inputs = inputs if isinstance(inputs, list) else inputs["input_ids"]
batched_inputs.append(inputs)

@@ -283,6 +292,7 @@ if __name__ == "__main__":
generation_cfg.compile_config = CompileConfig(
fullgraph=True,
mode="max-autotune-no-cudagraphs",
dynamic=True, # FIXME: if we warmup all graphs, this is not needed anymore
)

# If we need to compare, we need to generate the reference outputs


+ 2
- 30
setup.py View File

@@ -36,35 +36,7 @@ To create the package for pypi.
5. On the release branch, add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
Push the tag to git: git push --tags origin v<RELEASE>-release

6. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).

Run `make build-release`. This will build the release and do some sanity checks for you. If this ends with an error
message, you need to fix things before going further.

You should now have a /dist directory with both .whl and .tar.gz source versions.

7. Check that everything looks correct by uploading the package to the pypi test server:

twine upload dist/* -r testpypi
(pypi suggest using twine as other methods upload files via plaintext.)
You may have to specify the repository url, use the following command then:
twine upload dist/* -r testpypi --repository-url=https://test.pypi.org/legacy/

Check that you can install it in a virtualenv by running:
pip install -i https://test.pypi.org/simple/ transformers

Check you can run the following commands:
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from transformers import *"
python utils/check_build.py --check_lib

If making a patch release, double check the bug you are patching is indeed resolved.

8. Upload the final version to actual pypi:
twine upload dist/* -r pypi

9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
6. Have a core maintainer review and approve the deployment to pypi.
"""

import re
@@ -168,7 +140,7 @@ _deps = [
"tensorboard",
"timeout-decorator",
"tiktoken",
"timm<=1.0.19,!=1.0.18",
"timm>=1.0.20",
"tokenizers>=0.22.0,<=0.23.0",
"torch>=2.2",
"torchaudio",


+ 2
- 0
src/transformers/__init__.py View File

@@ -266,6 +266,7 @@ _import_structure = {
],
"video_utils": [],
"utils.kernel_config": ["KernelConfig"],
"utils.import_utils": ["requires_backends"],
}

# tokenizers-backed objects
@@ -750,6 +751,7 @@ if TYPE_CHECKING:
from .utils import is_torch_npu_available as is_torch_npu_available
from .utils import is_torch_xla_available as is_torch_xla_available
from .utils import is_torch_xpu_available as is_torch_xpu_available
from .utils.import_utils import requires_backends
from .utils.kernel_config import KernelConfig as KernelConfig

# Quantization config


+ 6
- 2
src/transformers/conversion_mapping.py View File

@@ -142,11 +142,11 @@ def _build_checkpoint_conversion_mapping():
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
mapping["legacy"] += [
WeightRenaming(
source_patterns=r"weight_g$",
source_patterns="weight_g",
target_patterns="parametrizations.weight.original0",
),
WeightRenaming(
source_patterns=r"weight_v$",
source_patterns="weight_v",
target_patterns="parametrizations.weight.original1",
),
]
@@ -166,6 +166,9 @@ def _build_checkpoint_conversion_mapping():
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dots1"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] += [
WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias")
]
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["longcat_flash"] = mapping["qwen2_moe"].copy()
@@ -225,6 +228,7 @@ VLMS = [
"sam3",
"sam3_tracker",
"sam3_tracker_video",
"paddleocrvl",
]




+ 62
- 57
src/transformers/core_model_loading.py View File

@@ -409,7 +409,7 @@ class WeightRenaming(WeightTransform):
config=None,
hf_quantizer=None,
missing_keys: Optional[MutableSet[str]] = None,
misc: Optional[MutableMapping[str, str]] = None,
conversion_errors: Optional[MutableMapping[str, str]] = None,
):
# Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
# attribute during the whole process
@@ -421,7 +421,9 @@ class WeightRenaming(WeightTransform):
collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

if hf_quantizer is not None and self.quantization_operation is not None:
with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
with log_conversion_errors(
layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
):
collected_tensors = self.quantization_operation.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -432,7 +434,7 @@ class WeightRenaming(WeightTransform):
missing_keys=missing_keys,
)

return collected_tensors, misc
return collected_tensors, conversion_errors


@dataclass(slots=True)
@@ -455,14 +457,14 @@ class WeightConverter(WeightTransform):
config=None,
hf_quantizer=None,
missing_keys: Optional[MutableSet[str]] = None,
misc: Optional[MutableMapping[str, str]] = None,
conversion_errors: Optional[MutableMapping[str, str]] = None,
):
# Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
# attribute during the whole process
collected_tensors = self.materialize_tensors()

for op in self.operations:
with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
collected_tensors = op.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -489,7 +491,9 @@ class WeightConverter(WeightTransform):
pass

if hf_quantizer is not None and self.quantization_operation is not None:
with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
with log_conversion_errors(
layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
):
collected_tensors = self.quantization_operation.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -499,7 +503,7 @@ class WeightConverter(WeightTransform):
model=model,
missing_keys=missing_keys,
)
return collected_tensors, misc
return collected_tensors, conversion_errors


# For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -560,13 +564,14 @@ def dot_natural_key(s: str):


@contextmanager
def log_to_misc(
def log_conversion_errors(
first_target_key: str,
misc: MutableMapping[str, str],
conversion_errors: MutableMapping[str, str],
extras: Any = None,
op: Union[list[ConversionOps], ConversionOps, None] = None,
):
# A simple helper to handle errors with contextual messages.
"""Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
that will be caught later to skip the parameters that raised the original Exception."""
try:
yield
except Exception as e:
@@ -585,17 +590,19 @@ def log_to_misc(
if isinstance(extras, tuple) and len(extras) == 2:
length, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[first_target_key] = (
conversion_errors[first_target_key] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
elif extras is None and op_name:
misc[first_target_key] = f"{op_name}: {e}"
conversion_errors[first_target_key] = f"{op_name}: {e}"
else:
misc[first_target_key] = f"{extras} |Error: {e}"
raise SkipLayer()
conversion_errors[first_target_key] = f"{extras} |Error: {e}"

# Raise a specific Exception that we can catch easily
raise SkipParameters()


def set_param_for_module(
@@ -604,44 +611,42 @@ def set_param_for_module(
param_value: torch.Tensor,
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
unexpected_keys: MutableSet[str],
distributed_operation: Optional[TensorParallelLayer],
hf_quantizer: HfQuantizer,
):
with log_to_misc(target_name, misc, target_name):
module_path, _, param_name = target_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
module_path, _, param_name = target_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model

ref = getattr(module_obj, param_name)
if ref is None:
unexpected_keys.add(target_name)
ref = getattr(module_obj, param_name)
if ref is None:
unexpected_keys.add(target_name)
else:
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation is not None:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
getattr(distributed_operation, "shard", Replicate()),
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
if not use_dtensor:
# we convert to local
param_value = param_value.to_local()
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(target_name)
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
mismatch_keys.add((target_name, param_value.shape, ref.shape))
else:
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation is not None:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
getattr(distributed_operation, "shard", Replicate()),
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
if not use_dtensor:
# we convert to local
param_value = param_value.to_local()
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(target_name)
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
mismatch_keys.add((target_name, param_value.shape, ref.shape))
else:
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
setattr(module_obj, param_name, param_value)
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
setattr(module_obj, param_name, param_value)


def offload_and_maybe_resave_param(
@@ -663,8 +668,9 @@ def offload_and_maybe_resave_param(
return disk_offload_index


class SkipLayer(Exception):
"""Control-flow sentinel: abort processing of the current layer only."""
class SkipParameters(Exception):
"""Control-flow sentinel: abort processing of the current parameters only (that were supposed to be created
by a WeightConverter)."""

pass

@@ -818,7 +824,7 @@ def convert_and_load_state_dict_in_model(
meta_model_state_dict = model.state_dict()
missing_keys = set(meta_model_state_dict.keys())

misc = {}
conversion_errors = {}
mismatch_keys = set()
unexpected_keys = set()

@@ -879,7 +885,7 @@ def convert_and_load_state_dict_in_model(
elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
if matched_dtype_pattern is not None:
_dtype = dtype_plan[matched_dtype_pattern.group()]
_dtype = dtype_plan[dtype_policy_by_group_name[matched_dtype_pattern.lastgroup]]
elif empty_param is not None and empty_param.dtype != _dtype:
_dtype = empty_param.dtype # usually correct when initializing

@@ -925,13 +931,13 @@ def convert_and_load_state_dict_in_model(
pbar.set_postfix({"Materializing param": first_param_name})
pbar.refresh()
try:
realized_value, misc = mapping.convert(
realized_value, conversion_errors = mapping.convert(
first_param_name,
model=model,
config=model.config,
hf_quantizer=hf_quantizer,
missing_keys=missing_keys,
misc=misc,
conversion_errors=conversion_errors,
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
@@ -949,7 +955,6 @@ def convert_and_load_state_dict_in_model(
param,
mismatch_keys,
missing_keys,
misc,
unexpected_keys,
mapping.distributed_operation,
hf_quantizer,
@@ -958,7 +963,7 @@ def convert_and_load_state_dict_in_model(
# Cleanup all the tensors that were gathered before next iteration
del realized_value

except SkipLayer:
except SkipParameters:
continue

# Close the pool, independently of whether the code was interrupted or finished successfully
@@ -969,7 +974,7 @@ def convert_and_load_state_dict_in_model(

# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
model._weight_conversions = weight_mapping
return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors


def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
@@ -1016,7 +1021,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
new_state_dict = {}
for first_param_name, reversed_converter in conversion_mapping.items():
# Apply the reverse converter
realized_value, misc = reversed_converter.convert(first_param_name, model=model, config=model.config)
realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
new_state_dict[target_name] = param


+ 1
- 1
src/transformers/dependency_versions_table.py View File

@@ -75,7 +75,7 @@ deps = {
"tensorboard": "tensorboard",
"timeout-decorator": "timeout-decorator",
"tiktoken": "tiktoken",
"timm": "timm<=1.0.19,!=1.0.18",
"timm": "timm>=1.0.20",
"tokenizers": "tokenizers>=0.22.0,<=0.23.0",
"torch": "torch>=2.2",
"torchaudio": "torchaudio",


+ 38
- 9
src/transformers/feature_extraction_utils.py View File

@@ -67,11 +67,18 @@ class BatchFeature(UserDict):
tensor_type (`Union[None, str, TensorType]`, *optional*):
You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
initialization.
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
"""

def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
def __init__(
self,
data: Optional[dict[str, Any]] = None,
tensor_type: Union[None, str, TensorType] = None,
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)

def __getitem__(self, item: str) -> Any:
"""
@@ -110,6 +117,14 @@ class BatchFeature(UserDict):
import torch

def as_tensor(value):
if torch.is_tensor(value):
return value

# stack list of tensors if tensor_type is PyTorch (# torch.tensor() does not support list of tensors)
if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
return torch.stack(value)

# convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
if isinstance(value, (list, tuple)) and len(value) > 0:
if isinstance(value[0], np.ndarray):
value = np.array(value)
@@ -138,7 +153,11 @@ class BatchFeature(UserDict):
is_tensor = is_numpy_array
return is_tensor, as_tensor

def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
def convert_to_tensors(
self,
tensor_type: Optional[Union[str, TensorType]] = None,
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
):
"""
Convert the inner content to tensors.

@@ -146,6 +165,8 @@ class BatchFeature(UserDict):
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
"""
if tensor_type is None:
return self
@@ -154,18 +175,26 @@ class BatchFeature(UserDict):

# Do the tensor conversion in batch
for key, value in self.items():
# Skip keys explicitly marked for no conversion
if skip_tensor_conversion and key in skip_tensor_conversion:
continue

try:
if not is_tensor(value):
tensor = as_tensor(value)

self[key] = tensor
except: # noqa E722
except Exception as e:
if key == "overflowing_values":
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
raise ValueError(
f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
f"Original error: {str(e)}"
) from e
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
)
f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
f"You can try:\n"
f" 1. Use padding=True to ensure all outputs have the same shape\n"
f" 2. Set return_tensors=None to return Python objects instead of tensors"
) from e

return self



+ 20
- 23
src/transformers/generation/continuous_batching/continuous_api.py View File

@@ -66,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
interval_size = max_value // nb_intervals
if interval_size == 0:
return max_value
padded = ceil(size / interval_size) * interval_size
padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
return min(padded, max_value)


@@ -259,7 +259,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
self.max_seqlen_q = 0
self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
self.output_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)

# For some kwargs, we have a dict of tensors with as many items as there are attention types
layer_types = getattr(self.config, "layer_types", None)
@@ -311,7 +311,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q[: b_size + 1].zero_()
self.max_seqlen_q = 0
self.logits_indices[:q_len].fill_(-1)
self.output_ids[:, :q_len].fill_(-1)
self.output_ids[:q_len].fill_(-1)

# Reset the attributes that are either tensors or dict of tensors
for layer_type in self.cumulative_seqlens_k:
@@ -447,7 +447,7 @@ class ContinuousBatchProcessor:
self.metrics.record_batch_metrics(self.requests_in_batch)

# Reset the static tensors used for storage
self.reset_static_tensors() # TODO: this might be unnecessary
self.reset_static_tensors() # FIXME: why does this make the generation faster?

# Prepare accumulators
self.actual_query_length = 0
@@ -557,13 +557,10 @@ class ContinuousBatchProcessor:
self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))

@traced
def _sync(self) -> list[int]:
if self.output_ids is not None:
try:
return self.output_ids.tolist()[0]
except Exception:
return [0, 1]
return [0, 0]
def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
indices = self.logits_indices[:num_new_tokens]
new_tokens = self.output_ids[indices]
return new_tokens.tolist()

@traced
def _maybe_send_output(self, state: RequestState) -> None:
@@ -574,13 +571,13 @@ class ContinuousBatchProcessor:
@traced
def update_batch(self) -> None:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
new_tokens = self._get_new_tokens(len(self.requests_in_batch))
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prefill_tokens) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
token = new_tokens[i]
state.tokens_to_process = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
@@ -713,6 +710,7 @@ class ContinuousBatchProcessor:
# Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
# but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
batch_size, seq_len, vocab_size = logits.shape
# NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
logits_2d = logits.view(batch_size * seq_len, vocab_size)
input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
# Process with 2D tensors
@@ -726,12 +724,11 @@ class ContinuousBatchProcessor:
probs = nn.functional.softmax(probs, dim=-1)
# probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
# Add batch dimension back to match argmax output
next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
else:
next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len]
tokens = next_tokens.size(1) # Get seq_len dimension
self.output_ids[:, :tokens].copy_(next_tokens)
next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
tokens = next_tokens.size(0) # Get seq_len dimension
self.output_ids[:tokens].copy_(next_tokens)


# Manager Class (User Interface)
@@ -869,7 +866,7 @@ class ContinuousBatchingManager:
logger.warning("\nBatch processor was not initialized.")
else:
if self.batch_processor.cache.use_prefix_sharing:
logger.warning(
logger.info(
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
)

@@ -949,6 +946,10 @@ class ContinuousBatchingManager:
streaming: bool = False,
record_timestamps: bool = False,
) -> None:
# If there is prefix sharing, we sort the inputs to maximize cache hits
if self._allow_prefix_sharing:
inputs = sorted(inputs, reverse=True)
# Add requests in order
for input_ids in inputs:
self.add_request(
input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
@@ -1079,10 +1080,6 @@ class ContinuousBatchingManager:
)

self._generation_step()

if torch.cuda.is_available():
torch.cuda.synchronize() # FIXME: why is this needed?
# Processor updates the batch after generation step is truly over
batch_processor.update_batch()

@traced


+ 15
- 4
src/transformers/image_processing_utils_fast.py View File

@@ -932,11 +932,22 @@ class BaseImageProcessorFast(BaseImageProcessor):
if do_pad:
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict

# Filter out None values that are class defaults, but preserve explicitly set None values
filtered_dict = {}
for key, value in encoder_dict.items():
if value is None:
class_default = getattr(type(self), key, "NOT_FOUND")
# Keep None if user explicitly set it (class default is non-None)
if class_default != "NOT_FOUND" and class_default is not None:
filtered_dict[key] = value
else:
filtered_dict[key] = value

filtered_dict.pop("_valid_processor_keys", None)
filtered_dict.pop("_valid_kwargs_names", None)
return filtered_dict

+ 26
- 0
src/transformers/integrations/accelerate.py View File

@@ -554,6 +554,32 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str |
return offload_index


def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
    """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
    inside `model`.

    This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
    then resave them to disk in the correct shard...).

    Args:
        model (`PreTrainedModel`): The model owning the (offloaded) parameter.
        param_name (`str`): Fully qualified parameter name, e.g. `"encoder.layer.0.weight"`.

    Returns:
        `torch.Tensor`: The parameter tensor, loaded from the offload map of the owning hook.

    Raises:
        ValueError: If no module on the path to the parameter carries an offload hook.
    """
    # Start from the most inner module, and try to find the hook that was used for offloading the param
    module_parts = param_name.split(".")
    modules_to_check = [".".join(module_parts[:-idx]) for idx in range(1, len(module_parts))] + [""]
    for parent_name in modules_to_check:
        parent = model.get_submodule(parent_name)
        if hasattr(parent, "_hf_hook"):
            weights_map = parent._hf_hook.weights_map
            # Strip only the leading `{parent_name}.` prefix. A plain `str.replace` would also remove any later
            # occurrence of the same substring (e.g. repeated module names such as `proj.proj.weight`), yielding a
            # wrong key into the offload map.
            prefix = f"{parent_name}." if parent_name != "" else ""
            truncated_param_name = param_name.removeprefix(prefix)
            break
    # If we did not break the loop, something is wrong
    else:
        raise ValueError(
            f"{param_name} is on the meta device because it was offloaded, but we could not find "
            "the corresponding hook for it"
        )

    # This call loads it from disk
    tensor = weights_map[truncated_param_name]
    return tensor


def _init_infer_auto_device_map(
model: nn.Module,
max_memory: dict[int | str, int | str] | None = None,


+ 50
- 106
src/transformers/integrations/awq.py View File

@@ -15,12 +15,13 @@

from typing import Optional, Union

from ..utils import is_gptqmodel_available, is_llm_awq_available, is_torch_available, logging
from ..utils.quantization_config import (
AwqBackend,
)
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn
@@ -61,120 +62,63 @@ def replace_with_awq_linear(
model,
modules_to_not_convert=None,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
device_map: Optional[Union[str, dict]] = None,
) -> bool:
"""
Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.

During the module replacement, we also infer the backend to use through the `quantization_config` object.
Public method that replaces the linear layers of the given model with awq quantized layers.

Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`AwqConfig`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list`, *optional*):
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
converted.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
device_map (`Union[str, dict]`, *optional*, defaults to `None`):
The device map that maps the parameters to the device
"""
if modules_to_not_convert is None:
modules_to_not_convert = []

backend = quantization_config.backend

if not is_gptqmodel_available() and not is_llm_awq_available():
raise ValueError(
"AWQ (either `llmawq`) is not available. Please install it with `pip install gptqmodel` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
)

if backend != AwqBackend.LLMAWQ:
from gptqmodel.quantization import METHOD
from gptqmodel.utils.importer import hf_select_quant_linear_v2

target_cls = hf_select_quant_linear_v2(
bits=quantization_config.bits,
group_size=quantization_config.group_size,
desc_act=False,
sym=False,
format=quantization_config.format,
backend=quantization_config.backend,
device_map=device_map,
quant_method=METHOD.AWQ,
zero_point=quantization_config.zero_point,
pack=False,
)
else:
from awq.quantize.qmodule import WQLinear

target_cls = WQLinear

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
in_features = module.in_features
out_features = module.out_features

if backend != AwqBackend.LLMAWQ:
model._modules[name] = target_cls(
bits=quantization_config.bits,
sym=quantization_config.sym,
desc_act=quantization_config.desc_act,
group_size=quantization_config.group_size,
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
dev=module.weight.device,
register_buffers=True,
)
else:
model._modules[name] = target_cls(
w_bit=quantization_config.bits,
group_size=quantization_config.group_size,
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
dev=module.weight.device,
)
from gptqmodel.quantization import METHOD
from gptqmodel.utils.importer import hf_select_quant_linear_v2

target_cls = hf_select_quant_linear_v2(
bits=quantization_config.bits,
group_size=quantization_config.group_size,
desc_act=False,
sym=False,
format=quantization_config.format,
backend=quantization_config.backend,
device_map=device_map,
quant_method=METHOD.AWQ,
zero_point=quantization_config.zero_point,
pack=False,
)

for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
if isinstance(module, nn.Linear):
new_module = target_cls(
bits=quantization_config.bits,
sym=quantization_config.sym,
desc_act=quantization_config.desc_act,
group_size=quantization_config.group_size,
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
dev=module.weight.device,
register_buffers=True,
)
new_module.requires_grad_(False)
model.set_submodule(module_name, new_module)
has_been_replaced = True

# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_awq_linear(
module,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
device_map=device_map,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced


def post_init_awq_ipex_modules(model):
"""
Runs post init for IPEX layers which performs:
- Weights packing, reordering and repacking
"""

from gptqmodel.quantization.awq.modules.linear.gemm_ipex import ipex_post_init

model = ipex_post_init(model)
if not has_been_replaced:
logger.warning(
"You are loading your model using eetq but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)

return model

+ 88
- 38
src/transformers/integrations/fbgemm_fp8.py View File

@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache
from typing import Optional

from ..activations import ACT2FN
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
from ..utils import (
is_accelerate_available,
is_fbgemm_gpu_available,
is_torch_available,
is_torch_xpu_available,
logging,
)


if is_torch_available():
@@ -27,7 +34,9 @@ if is_torch_available():
if is_accelerate_available():
from accelerate import init_empty_weights

if is_fbgemm_gpu_available():
_is_torch_xpu_available = is_torch_xpu_available()

if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

logger = logging.get_logger(__name__)
@@ -61,7 +70,7 @@ class FbgemmFp8Quantize(ConversionOps):
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per row instead of per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
@@ -77,14 +86,14 @@ class FbgemmFp8Quantize(ConversionOps):
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
else:
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
new_value, weight_scale = quantize_fp8_per_row(value)
weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))

return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
@@ -110,18 +119,26 @@ class FbgemmFp8Linear(torch.nn.Linear):
output_shape = (*x.shape[:-1], -1)
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
)
x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)

# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
weight_scale_float32 = self.weight_scale.to(torch.float32)
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
if _is_torch_xpu_available:
output = torch._scaled_mm(
x_quantized,
self.weight.t(),
scale_a=x_scale.unsqueeze(-1),
scale_b=weight_scale_float32.t(),
out_dtype=x.dtype,
bias=self.bias,
)
else:
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we have the output to the device of x
output = output.to(x.device)
output = output.reshape(output_shape)
@@ -173,48 +190,79 @@ class FbgemmFp8Llama4TextExperts(nn.Module):
expert_hidden = hidden_states[i]
expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
# Quantize for this expert
expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
expert_quantized, expert_scale = quantize_fp8_per_row(
expert_hidden_reshaped, num_tokens, self.input_scale_ub
)
sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
if _is_torch_xpu_available:
gate = torch._scaled_mm(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
scale_a=expert_scale.unsqueeze(-1),
scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
out_dtype=hidden_states.dtype,
)
up = torch._scaled_mm(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
scale_a=expert_scale.unsqueeze(-1),
scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
out_dtype=hidden_states.dtype,
)
else:
gate = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
use_fast_accum=True,
)

gate = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
use_fast_accum=True,
)

up = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
use_fast_accum=True,
)
up = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
use_fast_accum=True,
)

activated = up * self.act_fn(gate)

activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
activated, num_tokens, self.input_scale_ub
)
activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)

down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
activated_scale,
down_proj_scale_float32[i].view(-1, 1).contiguous(),
use_fast_accum=True,
)
if _is_torch_xpu_available:
expert_output = torch._scaled_mm(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
scale_a=activated_scale.unsqueeze(-1),
scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
out_dtype=hidden_states.dtype,
)
else:
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
activated_scale,
down_proj_scale_float32[i].view(-1, 1).contiguous(),
use_fast_accum=True,
)

next_states[i] = expert_output
next_states = next_states.to(hidden_states.device)
return next_states.view(-1, self.hidden_size)


@lru_cache(maxsize=1)
def get_quantize_fp8_per_row():
    """Resolve the row-wise FP8 quantization kernel for the current backend.

    On non-XPU devices the native fbgemm op is used. On XPU (where fbgemm_gpu is
    not imported) an equivalent kernel is fetched from the kernels hub. The
    result is cached so the lookup only happens once per process.
    """
    if not _is_torch_xpu_available:
        return torch.ops.fbgemm.quantize_fp8_per_row
    # XPU path: pull the replacement kernel from the Hub.
    from kernels import get_kernel

    hub_kernel = get_kernel("kernels-community/fp8-fbgemm")
    return hub_kernel.quantize_fp8_per_row


def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
):
@@ -232,6 +280,8 @@ def replace_with_fbgemm_fp8_linear(
pre_quantized (`book`, defaults to `False`):
Whether the model is pre-quantized or not
"""
global quantize_fp8_per_row
quantize_fp8_per_row = get_quantize_fp8_per_row()

has_been_replaced = False
module_kwargs = {} if pre_quantized else {"dtype": None}


+ 14
- 2
src/transformers/integrations/hub_kernels.py View File

@@ -111,6 +111,12 @@ try:
layer_name="RMSNorm",
)
},
"mps": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/mlx_rmsnorm",
layer_name="RMSNorm",
)
},
"npu": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
@@ -253,6 +259,8 @@ except ImportError:

_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
"mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
"falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
}

_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
@@ -336,11 +344,15 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _

try:
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
kernel = get_kernel(repo_id, version=version)
kernel = get_kernel(repo_id, revision=revision, version=version)
mapping[kernel_name] = kernel
except FileNotFoundError:
mapping[kernel_name] = None
except AssertionError:
# Happens when torch is built without an accelerator backend; fall back to slow path.
mapping[kernel_name] = None

else:
# Try to import is_{kernel_name}_available from ..utils
@@ -358,7 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
if callable(is_kernel_available) and is_kernel_available():
# Try to import the module "{kernel_name}" from parent package level
try:
module = importlib.import_module(f"{kernel_name}")
module = importlib.import_module(f"{new_kernel_name}")
mapping[kernel_name] = module
return module
except Exception:


+ 35
- 0
src/transformers/integrations/integration_utils.py View File

@@ -940,6 +940,8 @@ class TrackioCallback(TrainerCallback):
```
"""

SPACE_URL = "https://huggingface.co/spaces/{space_id}"

def __init__(self):
has_trackio = is_trackio_available()
if not has_trackio:
@@ -1058,6 +1060,39 @@ class TrackioCallback(TrainerCallback):
metrics = rewrite_logs(metrics)
self._trackio.log(metrics)

def on_push_begin(self, args, state, control, model, **kwargs):
    """Before a Hub push: sync the Trackio project to a Space, add a Trackio badge to the
    model card comment, and tag the model with the Space URL.

    Only runs on the main process, when Trackio is active and a project is set.
    """
    if not state.is_world_process_zero or self._trackio is None:
        return
    # No active Trackio project -> nothing to sync or advertise.
    if (current_project := self._trackio.context_vars.current_project.get()) is None:
        return
    trackio_version = packaging.version.parse(self._trackio.__version__)
    # NOTE(review): the check is strictly `< 0.13.0` but the warning text says "<=0.13.0" —
    # one of the two is off by an inclusive bound; confirm which is intended.
    if trackio_version < packaging.version.parse("0.13.0"):
        warnings.warn(
            "The version of `trackio` that is installed is <=0.13.0, so "
            "the local Trackio project will not be pushed to Hugging Face. Run "
            "`pip install --upgrade trackio` to fix this."
        )
        return

    # Reuse the already-synced Space if there is one; otherwise force a sync now.
    space_id = self._trackio.context_vars.current_space_id.get()
    if space_id is None:
        space_id = self._trackio.sync(current_project, force=True)
    space_url = self.SPACE_URL.format(space_id=space_id)

    # "Visualize in Trackio" badge appended once to the autogenerated model-card comment.
    badge_markdown = (
        f'<a href="{space_url}" target="_blank"><img src="https://raw.githubusercontent.com/gradio-app/trackio/refs/heads/main/trackio/assets/badge.png" alt="Visualize in Trackio"'
        ' title="Visualize in Trackio" style="height: 40px;"/></a>'
    )
    if badge_markdown not in modelcard.AUTOGENERATED_TRAINER_COMMENT:
        modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    # NOTE(review): the two branches use different tag formats ("trackio::{url}" vs
    # "trackio:{url}"), and the URL tag is appended unconditionally (possible duplicates
    # across repeated pushes) — confirm the intended format and dedup behavior.
    if getattr(model, "model_tags", None) is not None:
        if "trackio" not in model.model_tags:
            model.model_tags.append("trackio")
        model.model_tags.append(f"trackio::{space_url}")
    else:
        model.model_tags = ["trackio", f"trackio:{space_url}"]


class CometCallback(TrainerCallback):
"""


+ 12
- 0
src/transformers/integrations/mistral.py View File

@@ -77,6 +77,7 @@ def convert_tekken_tokenizer(tokenizer_file: str):
"""Convert a "tekken" tokenizer to a fast Tokenizer."""
# Tekken format -- need to use the Converter

from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Load directly using their lib
@@ -106,4 +107,15 @@ def convert_tekken_tokenizer(tokenizer_file: str):
# Post-process
tokenizer.add_special_tokens({"additional_special_tokens": all_special})

MAP_SPECAL = {
"bos_token": SpecialTokens.bos.value,
"eos_token": SpecialTokens.eos.value,
"pad_token": SpecialTokens.pad.value,
"unk_token": SpecialTokens.unk.value,
}

for special_key, special_token in MAP_SPECAL.items():
if special_token in all_special:
tokenizer.add_special_tokens({special_key: special_token})

return tokenizer

+ 0
- 15
src/transformers/kernels/falcon_mamba/__init__.py View File

@@ -1,15 +0,0 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .selective_scan_with_ln_interface import mamba_inner_fn

+ 0
- 529
src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py View File

@@ -1,529 +0,0 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.cuda.amp import custom_bwd, custom_fwd


try:
import causal_conv1d_cuda
except ImportError:
causal_conv1d_cuda = None

import mamba_ssm
import selective_scan_cuda


# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
if hasattr(mamba_ssm.ops.triton, "layernorm"):
from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
else:
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd


class SelectiveScanFn(torch.autograd.Function):
    """Autograd wrapper around the `selective_scan_cuda` fused forward/backward kernels.

    Normalizes tensor layouts for the kernel (contiguous last dim, 4D B/C), records
    what reshaping was done on `ctx` so `backward` can undo it, and saves the tensors
    the backward kernel needs.
    """

    @staticmethod
    def forward(
        ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
    ):
        # The CUDA kernel requires unit stride on the innermost dimension.
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        # Kernel expects 4D B/C (b, groups, dstate, l); insert a singleton group dim for
        # 3D inputs and remember it so backward can squeeze the gradient back.
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        # Final recurrent state, sliced out of the kernel's intermediate buffer `x`.
        # NOTE(review): the 1::2 stride presumably skips interleaved checkpoint entries
        # in the kernel's layout — confirm against selective_scan_cuda.
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            # With a gate `z`, the kernel also returns the gated output `out_z`, which is
            # what the caller receives; `out` is saved for the backward pass.
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        # Unpack in the same order as save_for_backward above; the gated and ungated
        # cases saved different tensor sets.
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        # Undo the singleton group dim inserted in forward, if any.
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        # One gradient per forward input; None for the non-tensor flags.
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    """Apply the fused norm forward kernel (RMSNorm by default) to `x`.

    `x` is a 2D tensor of shape ((b l), d); the returned tensor has the same shape.
    Inputs are made contiguous first, as required by the underlying kernel.
    """
    # The fused kernel assumes unit stride on the innermost dimension.
    contiguous_x = x if x.stride(-1) == 1 else x.contiguous()
    contiguous_weight = weight.contiguous()
    contiguous_bias = None if bias is None else bias.contiguous()
    result = _layer_norm_fwd(
        contiguous_x,
        contiguous_weight,
        contiguous_bias,
        eps,
        None,
        residual_dtype=None,
        is_rms_norm=is_rms_norm,
    )
    # _layer_norm_fwd returns a tuple; the normalized output is the first element.
    return result[0]


def selective_scan_fn(
    u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
    """Run the CUDA selective scan with autograd support.

    Returns the scan output, or the tuple ``(out, last_state)`` when
    `return_last_state` is True. ``last_state`` has shape (batch, dim, dstate)
    and does not receive a gradient in the backward pass.
    """
    scan_args = (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return SelectiveScanFn.apply(*scan_args)


def selective_scan_ref(
    u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
    """Pure-PyTorch reference implementation of the selective scan (sequential over L).

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    # Compute in float32 (or complex) and cast back to the input dtype at the end.
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    # "Variable" B/C are per-timestep (>= 3 dims); otherwise they are fixed (D, N) matrices.
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        # Complex A: variable B/C store real/imag interleaved along the last dim (2L);
        # reinterpret pairs as complex numbers.
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    # Recurrent state, zero-initialized; inherits A's (possibly complex) dtype.
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    # Discretized state transition: deltaA[b,d,l,n] = exp(delta[b,d,l] * A[d,n]).
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    # Discretized input term deltaB_u[b,d,l,n] = delta * B * u, for each B layout.
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            # Grouped B (B G N L): broadcast each group over its dim // G channels.
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        # Grouped C: same per-group broadcast as for B.
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    # Sequential scan over the time dimension: x_{l} = deltaA_l * x_{l-1} + deltaB_u_l.
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        # Output projection y_l = <x_l, C>, for each C layout.
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            # NOTE(review): complex outputs are collapsed as 2 * Re(y) — presumably because
            # parameters come in conjugate pairs; confirm against the CUDA kernel.
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    # Skip connection through D, then optional SiLU gating by z.
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
b_rms_weight=None,
c_rms_weight=None,
dt_rms_weight=None,
b_c_dt_rms_eps=1e-6,
):
"""
xz: (batch, dim, seqlen)
"""
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
if torch.is_autocast_enabled():
# NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
target_dtype = (
torch.get_autocast_dtype("cuda")
if hasattr(torch, "get_autocast_dtype")
else torch.get_autocast_gpu_dtype()
)
x_proj_weight = x_proj_weight.to(dtype=target_dtype)
delta_proj_weight = delta_proj_weight.to(dtype=target_dtype)
out_proj_weight = out_proj_weight.to(dtype=target_dtype)
out_proj_bias = out_proj_bias.to(dtype=target_dtype) if out_proj_bias is not None else None
if xz.stride(-1) != 1:
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
ctx.is_variable_B = B is None
ctx.is_variable_C = C is None
ctx.B_proj_bias_is_None = B_proj_bias is None
ctx.C_proj_bias_is_None = C_proj_bias is None
if B is None: # variable B
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if B.stride(-1) != 1:
B = B.contiguous()
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl dstate)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if C.stride(-1) != 1:
C = C.contiguous()
if D is not None:
D = D.contiguous()

if b_rms_weight is not None:
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if dt_rms_weight is not None:
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
ctx.delta_softplus = delta_softplus
ctx.out_proj_bias_is_None = out_proj_bias is None
ctx.checkpoint_lvl = checkpoint_lvl
ctx.b_rms_weight = b_rms_weight
ctx.c_rms_weight = c_rms_weight
ctx.dt_rms_weight = dt_rms_weight
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
conv1d_out, delta = None, None
ctx.save_for_backward(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
out,
)
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

@staticmethod
@custom_bwd
def backward(ctx, dout):
    """Backward pass of the fused Mamba inner autograd function.

    Recomputes the (optionally checkpointed) conv1d output and delta, runs the
    CUDA selective-scan backward kernel, then manually back-propagates through
    the input/delta/output projections that were fused into ``forward``.

    Args:
        ctx: autograd context populated by ``forward``. Reads the saved tensors
            plus ``delta_softplus``, ``checkpoint_lvl``, ``b_rms_weight``,
            ``c_rms_weight``, ``dt_rms_weight``, ``b_c_dt_rms_eps`` and the
            ``out_proj_bias_is_None`` flag (all set in the visible part of
            ``forward``), as well as ``is_variable_B`` / ``is_variable_C`` /
            ``B_proj_bias_is_None`` / ``C_proj_bias_is_None`` flags
            (presumably set earlier in ``forward`` — outside this view).
        dout: gradient w.r.t. the forward output, shape (batch, seqlen, dim).

    Returns:
        A 20-tuple of gradients, one per input of ``forward``; ``None`` for the
        non-tensor inputs (see the note just before the return statement).
    """
    # dout: (batch, seqlen, dim)
    assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
    (
        xz,
        conv1d_weight,
        conv1d_bias,
        x_dbl,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        conv1d_out,
        delta,
        A,
        B,
        C,
        D,
        delta_bias,
        scan_intermediates,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        out,
    ) = ctx.saved_tensors
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # Complex A stores (real, imag) interleaved, so the projected state size doubles.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    if ctx.checkpoint_lvl == 1:
        # conv1d_out and delta were dropped in forward to save memory; recompute them here.
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        if dt_rms_weight is not None:
            # Re-apply RMSNorm to the recomputed delta (mirrors forward).
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
        if b_rms_weight is not None:
            # Recompute & RMSNorm B
            # NOTE(review): the B saved in forward already went through rms_norm_forward,
            # so this looks like a second application — confirm against upstream mamba_ssm.
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            # Recompute & RMSNorm C (same caveat as for B above).
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan_cuda with the backward of chunk).
    dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
    dx, dz = dxz.chunk(2, dim=1)
    # Back-prop through the output projection: dout_y = W_out^T @ dout, reshaped to (b, d, l).
    dout = rearrange(dout, "b l e -> e (b l)")
    dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
    dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
        conv1d_out,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        dout_y,
        scan_intermediates,
        out,
        dz,
        ctx.delta_softplus,
        True,  # option to recompute out_z
    )
    dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
    dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
    dD = dD if D is not None else None
    # dx_dbl accumulates the gradient flowing back into the fused x_proj output
    # (layout: [delta_rank | d_state for B | d_state for C] along dim 1).
    dx_dbl = torch.empty_like(x_dbl)
    dB_proj_bias = None
    if ctx.is_variable_B:
        if not A.is_complex():
            dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
        else:
            dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
        dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
        dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
        dB = None
    dC_proj_bias = None
    if ctx.is_variable_C:
        if not A.is_complex():
            dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
        else:
            dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
        dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
        dx_dbl[:, -d_state:] = dC  # (bl d)
        dC = None
    # Back-prop through the delta projection.
    ddelta = rearrange(ddelta, "b d l -> d (b l)")
    ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
    dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
    # Back-prop through x_proj: accumulate the projection gradient into dconv1d_out in place.
    dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
    dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
    dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
    dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
    # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
    # backward of conv1d with the backward of chunk).
    dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
        x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
    )
    dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
    dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
    return (
        dxz,
        dconv1d_weight,
        dconv1d_bias,
        dx_proj_weight,
        ddelta_proj_weight,
        dout_proj_weight,
        dout_proj_bias,
        dA,
        dB,
        dC,
        dD,
        ddelta_bias if delta_bias is not None else None,
        dB_proj_bias,
        dC_proj_bias,
        # The trailing 6 Nones are for the non-tensor inputs of forward:
        # delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
        None,
        None,
        None,
        None,
        None,
        None,
    )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    """Functional entry point for the fused Mamba inner block.

    Thin wrapper that forwards every argument, in order, to
    ``MambaInnerFn.apply`` so the custom autograd function (fused conv1d +
    projections + selective scan) can be invoked like a plain function.
    """
    # Positional order must match MambaInnerFn.forward's signature exactly.
    return MambaInnerFn.apply(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias, A, B, C, D, delta_bias,
        B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl,
        b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps,
    )

+ 239
- 270
src/transformers/modeling_utils.py View File

@@ -16,7 +16,6 @@
import collections
import copy
import functools
import gc
import importlib.metadata
import inspect
import json
@@ -64,6 +63,7 @@ from .integrations.accelerate import (
check_and_set_device_map,
expand_device_map,
init_empty_weights,
load_offloaded_parameter,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
@@ -130,7 +130,6 @@ from .utils.quantization_config import QuantizationMethod
if is_accelerate_available():
from accelerate.hooks import add_hook_to_module
from accelerate.utils import extract_model_from_parallel
from accelerate.utils.modeling import get_state_dict_from_offload


_torch_distributed_available = torch.distributed.is_available()
@@ -156,6 +155,12 @@ _init_weights = True
_is_quantized = False
_is_ds_init_called = False

# Mapping from flash attention implementations to their kernel fallback repositories
FLASH_ATTN_KERNEL_FALLBACK = {
"flash_attention_2": "kernels-community/flash-attn2",
"flash_attention_3": "kernels-community/vllm-flash-attn3",
}


def is_local_dist_rank_0():
return (
@@ -233,23 +238,28 @@ def set_zero3_state():
_is_ds_init_called = False


def restore_default_dtype(func):
@contextmanager
def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
"""
Decorator to restore the default torch dtype
at the end of the function. Serves
as a backup in case calling the function raises
an error after the function has changed the default dtype but before it could restore it.
Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
"""
# Raise a more helpful error now, before the later call to `torch.set_default_dtype` which would crash in this case
if not dtype.is_floating_point:
if model_class_name is not None:
error_message = (
f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
)
else:
error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
raise ValueError(error_message)

@wraps(func)
def _wrapper(*args, **kwargs):
old_dtype = torch.get_default_dtype()
try:
return func(*args, **kwargs)
finally:
torch.set_default_dtype(old_dtype)

return _wrapper
original_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(original_dtype)


def get_torch_context_manager_or_global_device():
@@ -277,7 +287,9 @@ def get_state_dict_dtype(state_dict):
return t.dtype

# if no floating dtype was found return whatever the first dtype is
return next(state_dict.values()).dtype
if len(state_dict) == 0:
return torch.float32
return next(iter(state_dict.values())).dtype


str_to_torch_dtype = {
@@ -403,6 +415,86 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
return shared_tensors, identical


def remove_tied_weights_from_state_dict(
state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
) -> dict[str, torch.Tensor]:
"""
Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
"""
# To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
# of the Tensors themselves -> we are guaranteed to find all the actual tied weights
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
if not isinstance(tensor, torch.Tensor):
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
# In the non-tensor case, fall back to the pointer of the object itself
ptrs[id(tensor)].append(name)

elif tensor.device.type == "meta":
# In offloaded cases, there may be meta tensors in the state_dict.
# For these cases, key by the pointer of the original tensor object
# (state_dict tensors are detached and therefore no longer shared)
tensor = model.get_parameter(name)
ptrs[id(tensor)].append(name)

else:
ptrs[id_tensor_storage(tensor)].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

# Recursively descend to find tied weight keys
all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
error_names = []
to_delete_names = set()
# Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
# kept is consistent
if all_potential_tied_weights_keys is not None:
for names in shared_ptrs.values():
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
to_delete_names.add(name)
# We are entering a place where the weights and the transformers configuration do NOT match.
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
for name in disjoint_names:
state_dict[name] = state_dict[name].clone()

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
shared_names, identical_names = _find_identical(shared_names, state_dict)
# delete tensors that have identical storage
for inames in identical_names:
known = inames.intersection(to_delete_names)
for name in known:
del state_dict[name]
unknown = inames.difference(to_delete_names)
if len(unknown) > 1:
error_names.append(unknown)

if shared_names:
error_names.extend(shared_names)

if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
"This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
)

return state_dict


def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
"""Cast a single parameter `param_name` into the `model`, with value `tensor`."""
module, param_type = get_module_from_name(model, param_name)
@@ -694,23 +786,21 @@ def _get_resolved_checkpoint_files(


def _get_dtype(
cls,
dtype: Optional[Union[str, torch.dtype, dict]],
checkpoint_files: Optional[list[str]],
config: PreTrainedConfig,
sharded_metadata: Optional[dict],
state_dict: Optional[dict],
weights_only: bool,
) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
) -> tuple[PreTrainedConfig, torch.dtype]:
"""Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
inferred dtype. We do the following:
1. If dtype is not None, we use that dtype
2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
we also may have config.dtype available, but we won't rely on it till v5
1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
2. Else, use the dtype provided as a dict or str
"""
dtype_orig = None
is_sharded = sharded_metadata is not None
asked_dtype = dtype

if dtype is not None:
if isinstance(dtype, str):
@@ -734,43 +824,46 @@ def _get_dtype(
)
elif hasattr(torch, dtype):
dtype = getattr(torch, dtype)
config.dtype = dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
sub_config.dtype = dtype
elif isinstance(dtype, torch.dtype):
config.dtype = dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
sub_config.dtype = dtype
elif isinstance(dtype, dict):
for key, curr_dtype in dtype.items():
if hasattr(config, key):
value = getattr(config, key)
curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
value.dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config
dtype = dtype.get("")
dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
config.dtype = dtype
if dtype is None:
dtype = torch.float32
else:
else:
raise ValueError(
"`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
)

# cast it to a proper `torch.dtype` object
dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
elif not isinstance(dtype, (dict, torch.dtype)):
raise ValueError(
f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
f"for each sub-config in composite configs, but received {dtype}"
)
else:
# set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
dtype = torch.get_default_dtype()

dtype_orig = cls._set_default_dtype(dtype)
# Get the main dtype
if isinstance(dtype, dict):
main_dtype = dtype.get("", torch.get_default_dtype())
main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
else:
# set fp32 as the default dtype for BC
default_dtype = torch.get_default_dtype()
config.dtype = default_dtype
for key in config.sub_configs:
if (sub_config := getattr(config, key)) is not None:
sub_config.dtype = default_dtype
main_dtype = dtype

# Set it on the config and subconfigs
config.dtype = main_dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
# The dtype was "auto" -> try to read the subconfig dtype value if any
if asked_dtype == "auto":
sub_dtype = getattr(sub_config, "dtype", main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
# The dtype was provided as a dict, try to see if we match the subconfig name
elif isinstance(dtype, dict):
sub_dtype = dtype.get(sub_config_key, main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
else:
sub_dtype = main_dtype
sub_config.dtype = sub_dtype

return config, dtype, dtype_orig
return config, main_dtype


class PipelineParallel(Enum):
@@ -1391,7 +1484,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.model_tags.append(tag)

@classmethod
@restore_default_dtype
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
@@ -1400,9 +1492,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
dtype (`torch.dtype`, *optional*):
Override the default `dtype` and load the model under this dtype.
"""
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here same way as done in `from_pretrained`
# For BC on the old `torch_dtype`
dtype = kwargs.pop("dtype", config.dtype)
if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1412,61 +1501,27 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if isinstance(dtype, str):
dtype = getattr(torch, dtype)

# override default dtype if needed
dtype_orig = None
if dtype is not None:
dtype_orig = cls._set_default_dtype(dtype)

# If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")

init_contexts = []
if dtype is not None:
init_contexts.append(local_torch_dtype(dtype, cls.__name__))
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
import deepspeed

init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])

else:
# Instantiate the model
with ContextManagers(init_contexts):
model = cls(config, **kwargs)

# restore default dtype if it was modified
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

return model

@classmethod
def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
under specific dtype.

Args:
dtype (`torch.dtype`):
a floating dtype to set to.

Returns:
`torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
modified. If it wasn't, returns `None`.

Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
`torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
"""
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)

logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
return dtype_orig

@property
def base_model(self) -> nn.Module:
"""
@@ -1543,7 +1598,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return True

if is_torch_xpu_available():
logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
logger.info(
f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
)
return True

if importlib.util.find_spec("flash_attn") is None:
@@ -1775,14 +1832,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
and is_kernels_available()
and not is_torch_npu_available()
):
if attn_implementation.endswith("2"):
applicable_attn_implementation = "kernels-community/flash-attn2"
if is_torch_xpu_available():
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False
else:
applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]

if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False

if is_paged:
applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
@@ -2332,30 +2387,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

tied_keys = list(tied_keys.items())
for i, (target_param_name, source_param_name) in enumerate(tied_keys):
# Usually we tie a single target to a single source, but when both are missing we may later tie
# both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
# a list here
target_param_names = [target_param_name]

# This is `from_pretrained` -> let's check symmetrically in case the source key is not present
if missing_keys is not None:
remove_from_missing = True
source_is_there = source_param_name not in missing_keys
target_is_there = target_param_name not in missing_keys
# Both are already present -> it means the config is wrong and do not reflect the actual
# checkpoint -> let's raise a warning and do nothing
# checkpoint -> let's raise a warning and NOT tie them
if source_is_there and target_is_there:
logger.warning(
f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
"You should update the config with `tie_word_embeddings=False` to silence this warning"
)
# Remove from internal attribute to correctly reflect actual tied weights
self.all_tied_weights_keys.pop(target_param_name)
# Skip to next iteration
continue
# We're missing the source but we have the target -> we swap them, tying the parameter that exists
elif not source_is_there and target_is_there:
target_param_name, source_param_name = source_param_name, target_param_name
target_param_names = [target_param_name]
# Both are missing -> check other keys in case more than 2 keys are tied to the same weight
elif not source_is_there and not target_is_there:
for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2364,10 +2415,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if source_backup == source_param_name:
target_backup_is_there = target_backup not in missing_keys
# If the target is present, we found the correct weight to tie into (we know the source is missing)
# Note here that we do not tie the missing source right now as well, as it will be done anyway when
# the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
if target_backup_is_there:
source_param_name = target_backup
# Append the source as well, since both are missing we'll tie both
target_param_names.append(source_param_name)
break
# If we did not break from the loop, it was impossible to find a source key -> let's raise
else:
@@ -2383,19 +2434,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

# Perform the actual tying
source_param = self.get_parameter_or_buffer(source_param_name)
for target_param_name in target_param_names:
if "." in target_param_name:
parent_name, name = target_param_name.rsplit(".", 1)
parent = self.get_submodule(parent_name)
else:
name = target_param_name
parent = self
# Tie the weights
setattr(parent, name, source_param)
self._adjust_bias(parent, source_param)
# Remove from missing if necessary
if missing_keys is not None and remove_from_missing:
missing_keys.discard(target_param_name)
if "." in target_param_name:
parent_name, name = target_param_name.rsplit(".", 1)
parent = self.get_submodule(parent_name)
else:
name = target_param_name
parent = self
# Tie the weights
setattr(parent, name, source_param)
self._adjust_bias(parent, source_param)
# Remove from missing if necessary
if missing_keys is not None and remove_from_missing:
missing_keys.discard(target_param_name)

def _adjust_bias(self, output_embeddings, input_embeddings):
if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
@@ -3157,29 +3207,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory)

# for offloaded modules
module_map = {}

# Save the model
# Get the model state_dict
if state_dict is None:
# if any model parameters are offloaded, make module map
if (
hasattr(self, "hf_device_map")
and len(set(self.hf_device_map.values())) > 1
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
):
warnings.warn(
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
)
for name, module in model_to_save.named_modules():
if name == "":
continue
module_state_dict = module.state_dict()

for key in module_state_dict:
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()

# if any model parameters are offloaded, we need to know it for later
is_offloaded = False
if (
hasattr(self, "hf_device_map")
and len(set(self.hf_device_map.values())) > 1
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
):
is_offloaded = True
warnings.warn(
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
"exceeds the `shard_size` (50GB default)"
)

# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3196,76 +3240,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if self._tp_size is not None:
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

# Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
if not isinstance(tensor, torch.Tensor):
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
# In the non-tensor case, fall back to the pointer of the object itself
ptrs[id(tensor)].append(name)

elif tensor.device.type == "meta":
# In offloaded cases, there may be meta tensors in the state_dict.
# For these cases, key by the pointer of the original tensor object
# (state_dict tensors are detached and therefore no longer shared)
tensor = self.get_parameter(name)
ptrs[id(tensor)].append(name)

else:
ptrs[id_tensor_storage(tensor)].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

# Recursively descend to find tied weight keys
_tied_weights_keys = set(_get_tied_weight_keys(self))
error_names = []
to_delete_names = set()
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
if _tied_weights_keys is not None:
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
to_delete_names.add(name)
# We are entering a place where the weights and the transformers configuration do NOT match.
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
for name in disjoint_names:
state_dict[name] = state_dict[name].clone()

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
shared_names, identical_names = _find_identical(shared_names, state_dict)
# delete tensors that have identical storage
for inames in identical_names:
known = inames.intersection(to_delete_names)
for name in known:
del state_dict[name]
unknown = inames.difference(to_delete_names)
if len(unknown) > 1:
error_names.append(unknown)

if shared_names:
error_names.extend(shared_names)

if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
"This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
)
# Remove tied weights as safetensors do not handle them
state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)

# Revert all renaming and/or weight operations
if save_original_format:
state_dict = revert_weight_conversion(self, state_dict)
state_dict = revert_weight_conversion(model_to_save, state_dict)

# Shard the model if it is too big.
if not _hf_peft_config_loaded:
@@ -3305,47 +3285,39 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()
if module_map:
filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
for shard_file, tensor_names in logging.tqdm(
state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
):
filename = os.path.join(save_directory, shard_file)
shard_state_dict = {}
for tensor_name in tensor_names:
# Get the tensor, and remove it from state_dict to avoid keeping the ref
tensor = state_dict.pop(tensor_name)

# In case of TP, get the full parameter back
if _is_dtensor_available and isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]

# remake shard with onloaded parameters if necessary
if module_map:
# init state_dict for this shard
shard_state_dict = dict.fromkeys(shard, "")
for module_name in shard:
# note that get_state_dict_from_offload can update with meta tensors
# if both a parent module and its descendant are offloaded
tensor = shard_state_dict[module_name]
if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
# update state dict with onloaded parameters
module = module_map[module_name]
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

# assign shard to be the completed state dict
shard = shard_state_dict
del shard_state_dict
gc.collect()

# TODO: we should def parallelize this we are otherwise just waiting
# too much before scheduling the next write when its in a different file
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)

del state_dict
if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
tensor = repack_weights(tensor, -1, self._tp_size, 2)

# If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
# but it would otherwise not be contained in the saved shard if we were to simply move the file
# or something
if is_offloaded and tensor.device.type == "meta":
tensor = load_offloaded_parameter(model_to_save, tensor_name)

# only do contiguous after it's permuted correctly in case of TP
shard_state_dict[tensor_name] = tensor.contiguous()

# TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
# so it's not possible for now....
# Write the shard to disk
safe_save_file(shard_state_dict, filename, metadata=metadata)
# Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
del shard_state_dict

if index is None:
path_to_weights = os.path.join(save_directory, weights_name)
@@ -3522,11 +3494,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return super().float(*args)

@classmethod
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
# Need to instantiate with correct dtype
init_contexts = [local_torch_dtype(dtype, cls.__name__)]
if is_deepspeed_zero3_enabled():
import deepspeed

init_contexts = [no_init_weights()]
init_contexts.append(no_init_weights())
# We cannot initialize the model on meta device with deepspeed when not quantized
if not is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -3534,7 +3508,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
elif is_quantized:
init_contexts.extend([init_empty_weights(), set_quantized_state()])
else:
init_contexts = [no_init_weights(), init_empty_weights()]
init_contexts.extend([no_init_weights(), init_empty_weights()])

return init_contexts

@@ -3568,7 +3542,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.use_kernels = False

@classmethod
@restore_default_dtype
def from_pretrained(
cls: type[SpecificPreTrainedModelType],
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -3822,6 +3795,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# For BC on torch_dtype argument
if torch_dtype is not None:
dtype = dtype if dtype is not None else torch_dtype
if dtype is None:
dtype = "auto"

if is_offline_mode() and not local_files_only:
local_files_only = True
@@ -3946,12 +3921,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
]

# Find the correct dtype based on current state
config, dtype, dtype_orig = _get_dtype(
cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
)
config, dtype = _get_dtype(dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only)

config.name_or_path = pretrained_model_name_or_path
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
@@ -3980,10 +3953,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if device_map is not None:
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

# Finalize model weight initialization
model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
model,
@@ -4102,7 +4071,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
state_dict = merged_state_dict
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
# This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
missing_keys, unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set(), set()
else:
all_pointer = set()
# Checkpoints are safetensors
@@ -4124,7 +4093,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
convert_and_load_state_dict_in_model(
model,
merged_state_dict,
@@ -4198,7 +4167,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
missing_keys=missing_keys,
mismatched_keys=mismatched_keys,
mismatched_shapes=mismatched_keys,
misc=misc,
conversion_errors=conversion_errors,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)



+ 1
- 0
src/transformers/models/__init__.py View File

@@ -265,6 +265,7 @@ if TYPE_CHECKING:
from .ovis2 import *
from .owlv2 import *
from .owlvit import *
from .paddleocr_vl import *
from .paligemma import *
from .parakeet import *
from .patchtsmixer import *


+ 1
- 0
src/transformers/models/align/modeling_align.py View File

@@ -976,6 +976,7 @@ class AlignVisionModel(AlignPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = False
_no_split_modules = ["AlignVisionBlock"]

def __init__(self, config: AlignVisionConfig):
super().__init__(config)


+ 3
- 1
src/transformers/models/apertus/modeling_apertus.py View File

@@ -25,7 +25,7 @@ from typing import Optional, Union
import torch
from torch import nn

from ...activations import ACT2FN
from ...activations import ACT2CLS, ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
@@ -49,6 +49,8 @@ class ApertusMLP(nn.Module):
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
if config.hidden_act == "xielu":
self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)

def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))


+ 4
- 1
src/transformers/models/apertus/modular_apertus.py View File

@@ -19,6 +19,7 @@ from typing import Optional
import torch
from torch import nn

from ...activations import ACT2CLS
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters
@@ -192,9 +193,11 @@ class ApertusConfig(PreTrainedConfig):

class ApertusMLP(NemotronMLP):
def __init__(self, config):
super().__init__()
super().__init__(config)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_act == "xielu":
self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)


class ApertusRMSNorm(LlamaRMSNorm):


+ 1
- 1
src/transformers/models/auto/auto_factory.py View File

@@ -543,7 +543,7 @@ def add_generation_mixin_to_remote_model(model_class):

class _LazyAutoMapping(OrderedDict[type[PreTrainedConfig], _LazyAutoMappingValue]):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

Args:
- config_mapping: The map model type to config class


+ 2
- 0
src/transformers/models/auto/configuration_auto.py View File

@@ -300,6 +300,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("ovis2", "Ovis2Config"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paddleocr_vl", "PaddleOCRVLConfig"),
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
@@ -754,6 +755,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("ovis2", "Ovis2"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("paddleocr_vl", "PaddleOCRVL"),
("paligemma", "PaliGemma"),
("parakeet", "Parakeet"),
("parakeet_ctc", "Parakeet"),


+ 1
- 0
src/transformers/models/auto/image_processing_auto.py View File

@@ -153,6 +153,7 @@ else:
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
("perception_lm", (None, "PerceptionLMImageProcessorFast")),


+ 1
- 0
src/transformers/models/auto/modeling_auto.py View File

@@ -1026,6 +1026,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("ovis2", "Ovis2ForConditionalGeneration"),
("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("perception_lm", "PerceptionLMForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),


+ 1
- 0
src/transformers/models/auto/processing_auto.py View File

@@ -114,6 +114,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("ovis2", "Ovis2Processor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paddleocr_vl", "PaddleOCRVLProcessor"),
("paligemma", "PaliGemmaProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),


+ 1
- 0
src/transformers/models/auto/tokenization_auto.py View File

@@ -273,6 +273,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, Optional[str]](
("ovis2", "Qwen2TokenizerFast" if is_tokenizers_available() else None),
("owlv2", "CLIPTokenizerFast" if is_tokenizers_available() else None),
("owlvit", "CLIPTokenizerFast" if is_tokenizers_available() else None),
("paddleocr_vl", "TokenizersBackend" if is_tokenizers_available() else None),
("paligemma", "LlamaTokenizer" if is_tokenizers_available() else None),
("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None),
("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None),


+ 15
- 16
src/transformers/models/bamba/modeling_bamba.py View File

@@ -36,6 +36,7 @@ from ... import initialization as init
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,22 +45,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import maybe_autocast
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None


logger = logging.get_logger(__name__)


@@ -501,9 +489,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
return hidden_states


is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
class BambaMixer(nn.Module):
"""
@@ -575,6 +560,20 @@ class BambaMixer(nn.Module):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"


+ 15
- 15
src/transformers/models/bamba/modular_bamba.py View File

@@ -43,6 +43,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
)

from ... import initialization as init
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@ from ...utils import (
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None

if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


logger = logging.get_logger(__name__)


@@ -276,6 +262,20 @@ class BambaMixer(nn.Module):

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)

global is_fast_path_available
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"


+ 1
- 0
src/transformers/models/bart/modeling_bart.py View File

@@ -1463,6 +1463,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BartDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 0
- 1
src/transformers/models/beit/image_processing_beit_fast.py View File

@@ -163,7 +163,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 1
src/transformers/models/beit/modeling_beit.py View File

@@ -216,7 +216,7 @@ class BeitPatchEmbeddings(nn.Module):
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)

embeddings = self.projection(pixel_values)
embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)



+ 2
- 2
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py View File

@@ -1154,7 +1154,6 @@ class BigBirdPegasusEncoderAttention(nn.Module):
return outputs


# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@@ -1178,7 +1177,7 @@ def eager_attention_forward(
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
attn_output = torch.matmul(attn_weights.to(value.dtype), value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights
@@ -2583,6 +2582,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BigBirdPegasusDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot/modeling_blenderbot.py View File

@@ -1156,6 +1156,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py View File

@@ -1116,6 +1116,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotSmallDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 2
- 0
src/transformers/models/blip/modeling_blip_text.py View File

@@ -740,6 +740,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
self.cls = BlipTextOnlyMLMHead(config)
self.label_smoothing = config.label_smoothing

self.post_init()

def get_input_embeddings(self):
return self.bert.get_input_embeddings()



+ 1
- 1
src/transformers/models/blip_2/modeling_blip_2.py View File

@@ -603,7 +603,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
attention_probs_dropped = self.dropout(attention_probs).to(value_layer.dtype)

context_layer = torch.matmul(attention_probs_dropped, value_layer)



+ 152
- 1
src/transformers/models/blt/modeling_blt.py View File

@@ -444,6 +444,155 @@ class BltPreTrainedModel(PreTrainedModel):
"attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
}

@torch.no_grad()
def _init_weights(self, module):
"""
Initialize BLT weights following the original ByteLatentTransformer:

- Most weights are drawn from a truncated normal.
- Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
- Norm layers are set to weight = 1, bias = 0.
"""
class_name = module.__class__.__name__

# Norms: RMSNorm / LayerNorm
if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
if getattr(module, "weight", None) is not None:
nn.init.ones_(module.weight)
if getattr(module, "bias", None) is not None:
nn.init.zeros_(module.bias)
return

# Embeddings (encoder / patcher / hash embeddings)
if isinstance(module, nn.Embedding):
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "encoder_config"):
hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
if hidden_size is None:
hidden_size = module.embedding_dim

std = hidden_size**-0.5
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if module.padding_idx is not None:
nn.init.zeros_(module.weight[module.padding_idx])
return

# Self-attention / cross-attention projections
if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
"MllamaTextSelfAttention",
"MllamaTextCrossAttention",
):
dim = getattr(self.config, "hidden_size", None)
if dim is None and hasattr(module, "hidden_size"):
dim = module.hidden_size
if dim is None:
for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
proj = getattr(module, name, None)
if proj is not None and hasattr(proj, "weight"):
dim = proj.weight.shape[-1]
break
if dim is None:
return

std = dim**-0.5

# Input projections (q, k, v)
for proj_name in ("q_proj", "k_proj", "v_proj"):
proj = getattr(module, proj_name, None)
if proj is not None and hasattr(proj, "weight"):
nn.init.trunc_normal_(
proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(proj, "bias", None) is not None:
nn.init.zeros_(proj.bias)

# Output projection: o_proj or dense
o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
if o_proj is not None and hasattr(o_proj, "weight"):
nn.init.trunc_normal_(
o_proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(o_proj, "bias", None) is not None:
nn.init.zeros_(o_proj.bias)
return

# MLP / FFN blocks
if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "decoder_config"):
hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "encoder_config"):
hidden_size = getattr(self.config.encoder_config, "hidden_size", None)

# Input-side std
in_std = None
if hidden_size is not None:
in_std = hidden_size**-0.5

gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
up_proj = getattr(module, "up_proj", None)
down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))

# gate / input projections
for proj in (gate_proj, up_proj):
if proj is not None and hasattr(proj, "weight"):
std = in_std or (proj.weight.shape[1] ** -0.5)
nn.init.trunc_normal_(
proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(proj, "bias", None) is not None:
nn.init.zeros_(proj.bias)

# output/ down projections
if down_proj is not None and hasattr(down_proj, "weight"):
hidden_dim = down_proj.weight.shape[1]
out_std = hidden_dim**-0.5
nn.init.trunc_normal_(
down_proj.weight,
mean=0.0,
std=out_std,
a=-3 * out_std,
b=3 * out_std,
)
if getattr(down_proj, "bias", None) is not None:
nn.init.zeros_(down_proj.bias)
return

# Generic Linear layers (projections, lm_head, etc.)
if isinstance(module, nn.Linear):
fan_in = module.in_features
std = fan_in**-0.5
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if module.bias is not None:
nn.init.zeros_(module.bias)
return

return


class BltLocalEncoder(BltPreTrainedModel):
config: BltLocalEncoderConfig
@@ -753,6 +902,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -952,7 +1103,7 @@ def compute_hash_embeddings(
hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
# Apply offset to get the correct slice of the fused embedding
offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
embeddings += encoder_hash_tok_embedding(offset_hash_ids)
embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
embedding_idx += 1

return embeddings


+ 158
- 2
src/transformers/models/blt/modular_blt.py View File

@@ -133,7 +133,7 @@ def compute_hash_embeddings(
hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
# Apply offset to get the correct slice of the fused embedding
offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
embeddings += encoder_hash_tok_embedding(offset_hash_ids)
embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
embedding_idx += 1

return embeddings
@@ -360,8 +360,162 @@ class BltPreTrainedModel(MllamaPreTrainedModel):
"attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
}

# Weight initialization is adapted from:
# - https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py
# - https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py
#
# Both implementations use truncated normal initialization with std ~ 1 / sqrt(d_model)
# (or 1 / sqrt(hidden_dim) for FFN outputs), and unit initialization for normalization layers.
# We follow the same scheme here, but expressed in the Transformers APIs.

@torch.no_grad()
def _init_weights(self, module):
raise AttributeError("No need to inherit it!")
"""
Initialize BLT weights following the original ByteLatentTransformer:

- Most weights are drawn from a truncated normal.
- Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
- Norm layers are set to weight = 1, bias = 0.
"""
class_name = module.__class__.__name__

# Norms: RMSNorm / LayerNorm
if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
if getattr(module, "weight", None) is not None:
nn.init.ones_(module.weight)
if getattr(module, "bias", None) is not None:
nn.init.zeros_(module.bias)
return

# Embeddings (encoder / patcher / hash embeddings)
if isinstance(module, nn.Embedding):
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "encoder_config"):
hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
if hidden_size is None:
hidden_size = module.embedding_dim

std = hidden_size**-0.5
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if module.padding_idx is not None:
nn.init.zeros_(module.weight[module.padding_idx])
return

# Self-attention / cross-attention projections
if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
"MllamaTextSelfAttention",
"MllamaTextCrossAttention",
):
dim = getattr(self.config, "hidden_size", None)
if dim is None and hasattr(module, "hidden_size"):
dim = module.hidden_size
if dim is None:
for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
proj = getattr(module, name, None)
if proj is not None and hasattr(proj, "weight"):
dim = proj.weight.shape[-1]
break
if dim is None:
return

std = dim**-0.5

# Input projections (q, k, v)
for proj_name in ("q_proj", "k_proj", "v_proj"):
proj = getattr(module, proj_name, None)
if proj is not None and hasattr(proj, "weight"):
nn.init.trunc_normal_(
proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(proj, "bias", None) is not None:
nn.init.zeros_(proj.bias)

# Output projection: o_proj or dense
o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
if o_proj is not None and hasattr(o_proj, "weight"):
nn.init.trunc_normal_(
o_proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(o_proj, "bias", None) is not None:
nn.init.zeros_(o_proj.bias)
return

# MLP / FFN blocks
if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "decoder_config"):
hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "encoder_config"):
hidden_size = getattr(self.config.encoder_config, "hidden_size", None)

# Input-side std
in_std = None
if hidden_size is not None:
in_std = hidden_size**-0.5

gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
up_proj = getattr(module, "up_proj", None)
down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))

# gate / input projections
for proj in (gate_proj, up_proj):
if proj is not None and hasattr(proj, "weight"):
std = in_std or (proj.weight.shape[1] ** -0.5)
nn.init.trunc_normal_(
proj.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if getattr(proj, "bias", None) is not None:
nn.init.zeros_(proj.bias)

# output/ down projections
if down_proj is not None and hasattr(down_proj, "weight"):
hidden_dim = down_proj.weight.shape[1]
out_std = hidden_dim**-0.5
nn.init.trunc_normal_(
down_proj.weight,
mean=0.0,
std=out_std,
a=-3 * out_std,
b=3 * out_std,
)
if getattr(down_proj, "bias", None) is not None:
nn.init.zeros_(down_proj.bias)
return

# Generic Linear layers (projections, lm_head, etc.)
if isinstance(module, nn.Linear):
fan_in = module.in_features
std = fan_in**-0.5
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
if module.bias is not None:
nn.init.zeros_(module.bias)
return

return

def _update_causal_mask(self, module):
raise AttributeError("No need to inherit it!")
@@ -634,6 +788,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,


+ 0
- 2
src/transformers/models/bridgetower/image_processing_bridgetower_fast.py View File

@@ -251,10 +251,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
processed_images, processed_masks = self.pad(
processed_images, return_mask=True, disable_grouping=disable_grouping
)
processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
data["pixel_mask"] = processed_masks

processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
data["pixel_values"] = processed_images

return BatchFeature(data=data, tensor_type=return_tensors)


+ 1
- 0
src/transformers/models/bridgetower/modeling_bridgetower.py View File

@@ -955,6 +955,7 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.visual = BridgeTowerVisionTransformer(config)
self.post_init()

@property
def dtype(self):


+ 1
- 0
src/transformers/models/chameleon/modeling_chameleon.py View File

@@ -809,6 +809,7 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
self.post_init()

def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)


+ 2
- 0
src/transformers/models/clipseg/modeling_clipseg.py View File

@@ -1121,6 +1121,8 @@ class CLIPSegDecoder(CLIPSegPreTrainedModel):
decoder_config.hidden_act = "relu"
self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])

self.post_init()

def forward(
self,
hidden_states: tuple[torch.Tensor],


+ 0
- 1
src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py View File

@@ -263,7 +263,6 @@ class Cohere2VisionImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors


+ 0
- 1
src/transformers/models/convnext/image_processing_convnext_fast.py View File

@@ -162,7 +162,6 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 4
- 4
src/transformers/models/dac/modeling_dac.py View File

@@ -16,7 +16,7 @@

import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import numpy as np
import torch
@@ -583,7 +583,7 @@ class DacModel(DacPreTrainedModel):
input_values: torch.Tensor,
n_quantizers: Optional[int] = None,
return_dict: Optional[bool] = None,
):
) -> Union[tuple, DacEncoderOutput]:
r"""
input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
Input audio data to encode,
@@ -610,7 +610,7 @@ class DacModel(DacPreTrainedModel):
quantized_representation: Optional[torch.Tensor] = None,
audio_codes: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
) -> Union[tuple, DacDecoderOutput]:
r"""
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
Quantized continuous representation of input.
@@ -643,7 +643,7 @@ class DacModel(DacPreTrainedModel):
input_values: torch.Tensor,
n_quantizers: Optional[int] = None,
return_dict: Optional[bool] = None,
):
) -> Union[tuple, DacOutput]:
r"""
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
Audio data to encode.


+ 1
- 1
src/transformers/models/data2vec/modeling_data2vec_vision.py View File

@@ -216,7 +216,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)

embeddings = self.projection(pixel_values)
embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)



+ 0
- 4
src/transformers/models/decision_transformer/modeling_decision_transformer.py View File

@@ -367,12 +367,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
config: DecisionTransformerConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True

_can_compile_fullgraph = False

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 1
- 1
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py View File

@@ -157,7 +157,7 @@ class DeepseekV3NaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 1
- 0
src/transformers/models/deepseek_v3/modular_deepseek_v3.py View File

@@ -107,6 +107,7 @@ class DeepseekV3NaiveMoe(MixtralExperts):
def __init__(self, config):
super().__init__(config)
self.num_experts = config.num_local_experts
self.intermediate_dim = config.moe_intermediate_size


class DeepseekV3MoE(nn.Module):


+ 0
- 1
src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py View File

@@ -171,7 +171,6 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 2
- 2
src/transformers/models/deepseek_vl/modeling_deepseek_vl.py View File

@@ -196,7 +196,7 @@ class DeepseekVLModel(DeepseekVLPreTrainedModel):
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
):
) -> DeepseekVLBaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -268,7 +268,7 @@ class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMi
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
):
) -> DeepseekVLCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,


+ 0
- 4
src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py View File

@@ -207,9 +207,6 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
)
high_res_processed_images_grouped[shape] = stacked_high_res_images
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
high_res_processed_images = (
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
)

resized_images_grouped = {}
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -233,7 +230,6 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},


+ 2
- 2
src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py View File

@@ -314,7 +314,7 @@ class DeepseekVLHybridModel(DeepseekVLHybridPreTrainedModel):
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
):
) -> DeepseekVLHybridBaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -424,7 +424,7 @@ class DeepseekVLHybridForConditionalGeneration(DeepseekVLHybridPreTrainedModel,
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
):
) -> DeepseekVLHybridCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,


+ 2
- 6
src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py View File

@@ -297,7 +297,7 @@ class DeepseekVLHybridModel(DeepseekVLModel):
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
):
) -> DeepseekVLHybridBaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -361,7 +361,7 @@ class DeepseekVLHybridForConditionalGeneration(DeepseekVLForConditionalGeneratio
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
):
) -> DeepseekVLHybridCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -888,9 +888,6 @@ class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
)
high_res_processed_images_grouped[shape] = stacked_high_res_images
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
high_res_processed_images = (
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
)

resized_images_grouped = {}
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -914,7 +911,6 @@ class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},


+ 0
- 1
src/transformers/models/depth_pro/image_processing_depth_pro_fast.py View File

@@ -94,7 +94,6 @@ class DepthProImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 4
- 0
src/transformers/models/dia/modeling_dia.py View File

@@ -452,6 +452,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -578,6 +580,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 4
- 0
src/transformers/models/dia/modular_dia.py View File

@@ -241,6 +241,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -367,6 +369,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 0
- 1
src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py View File

@@ -88,7 +88,6 @@ class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 1
src/transformers/models/donut/image_processing_donut_fast.py View File

@@ -231,7 +231,6 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 1
src/transformers/models/dots1/modeling_dots1.py View File

@@ -315,7 +315,7 @@ class Dots1NaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 1
- 2
src/transformers/models/dpt/image_processing_dpt_fast.py View File

@@ -225,8 +225,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images})
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""


+ 1
- 2
src/transformers/models/dpt/modular_dpt.py View File

@@ -228,8 +228,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images})
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_depth_estimation(
self,


+ 1
- 1
src/transformers/models/edgetam/modeling_edgetam.py View File

@@ -393,7 +393,7 @@ class EdgeTamVisionNeck(nn.Module):
n = len(self.convs) - 1
for i in range(n, -1, -1):
lateral_features = hidden_states[i].permute(0, 3, 1, 2)
lateral_features = self.convs[n - i](lateral_features)
lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
if i not in self.fpn_top_down_levels or i == n:
prev_features = lateral_features
else:


+ 1
- 2
src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py View File

@@ -153,9 +153,8 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

return BatchFeature(data={"pixel_values": image_pairs})
return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

def post_process_keypoint_matching(
self,


+ 0
- 1
src/transformers/models/efficientnet/image_processing_efficientnet_fast.py View File

@@ -178,7 +178,6 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save
Baidu
map