26 Commits

Author SHA1 Message Date
  Ita Zaporozhets 5ff1f0ea8d
Merge branch 'main' into update_v5_guide_toks 15 hours ago
  Steven Liu 31de95ef71
[docs] optimizations quickstart (#42538) 19 hours ago
  Yoni Gozlan 23394cc491
Simplify using custom resolution for sam3 and sam3_video inference (#42787) 22 hours ago
  Danny Waser 06378d40e6
fix: Initialize ApertusMLP's xielu activation using `torch_dtype` (#42864) 1 day ago
  Yoni Gozlan fc50bdc685
Remove null values from fast image processors dict (#42780) 1 day ago
  Cyril Vallez c7aec088a6
Enforce call to `post_init` and fix all of them (#42873) 1 day ago
  Rémi Ouazan f3d5f2558b
[CB] Easy optimizations for continuous batching (#42839) 1 day ago
  AYou0207 298d08dc36
typo (#42863) 1 day ago
  Cyril Vallez a187b857a7
Remove tied weights from internal attribute if they are not tied (#42871) 1 day ago
  Mohamed Mekkouri 64c12fdf5f
[docs] Improve contribution guidelines for Quantization (#42870) 1 day ago
  Sai-Suraj-27 f0d9cd1ff6
Fixes 2 failing tests from AMD CI (#42777) 1 day ago
  jiqing-feng 66623a1fd6
Fix speccht5_tts pipeline (#42830) 1 day ago
  YangKai0616 e17b1b85e3
[Fix] Fix FA2 kernels ut (#42803) 1 day ago
  Cyril Vallez 40dc11cd3e
Fix Gemma (#42847) 3 days ago
  Cyril Vallez c247063001
Reapply modular examples (#42846) 4 days ago
  Yoni Gozlan a61aba59fb
Improve BatchFeature: stack list and lists of torch tensors (#42750) 4 days ago
  Cyril Vallez 5b710c7542
Do not rely on config for inferring model dtype (#42838) 4 days ago
  Anton Vlasjuk 33c948e494
[`T5Gemma2`] Fix bidirectional mask for encoder (#42820) 4 days ago
  Cyril Vallez e6b9d06147
[saving] Simplify general logic (#42766) 4 days ago
  Wu, Ke 65dc261512
Add inputs_to_logits_ratio to LasrCTCConfig (#42720) 4 days ago
  Cyril Vallez 64a7cc82a6
Simplify dtype instantiation (#42825) 4 days ago
  Raushan Turganbay 37426b27bf
Fix a typo in MoE models (#42835) 4 days ago
  Rémi Ouazan aa495f62de
Fixes for the failures of AMD CI (#42718) 4 days ago
  jiqing-feng c24b51dd78
Fix xpu output check for Ministral3 tests (#42761) 4 days ago
  ZX-ModelCloud b19844eef6
Compatible with GPTQModel FORAMT.LLM_AWQ (#42833) 4 days ago
  Marc Sun 780cc65907
Fix deepspeed sp loss due to missing labels (#42812) 4 days ago
100 changed files with 1019 additions and 1105 deletions
Split View
  1. +1
    -1
      Makefile
  2. +2
    -0
      docs/source/en/_toctree.yml
  3. +15
    -0
      docs/source/en/model_doc/sam3.md
  4. +15
    -0
      docs/source/en/model_doc/sam3_video.md
  5. +178
    -0
      docs/source/en/optimization_overview.md
  6. +95
    -16
      docs/source/en/quantization/contribute.md
  7. +3
    -10
      examples/modular-transformers/configuration_duplicated_method.py
  8. +3
    -10
      examples/modular-transformers/configuration_my_new_model.py
  9. +3
    -10
      examples/modular-transformers/configuration_my_new_model2.py
  10. +15
    -19
      examples/modular-transformers/configuration_new_model.py
  11. +3
    -0
      examples/modular-transformers/modeling_add_function.py
  12. +56
    -164
      examples/modular-transformers/modeling_dummy_bert.py
  13. +12
    -42
      examples/modular-transformers/modeling_from_uppercase_model.py
  14. +5
    -2
      examples/modular-transformers/modeling_global_indexing.py
  15. +29
    -177
      examples/modular-transformers/modeling_multimodal2.py
  16. +6
    -2
      examples/modular-transformers/modeling_my_new_model2.py
  17. +2
    -37
      examples/modular-transformers/modeling_new_task_model.py
  18. +56
    -164
      examples/modular-transformers/modeling_roberta.py
  19. +44
    -13
      examples/modular-transformers/modeling_super.py
  20. +5
    -2
      examples/modular-transformers/modeling_switch_function.py
  21. +31
    -26
      examples/modular-transformers/modeling_test_detr.py
  22. +8
    -1
      examples/modular-transformers/modular_multimodal2.py
  23. +1
    -2
      examples/modular-transformers/modular_new_model.py
  24. +12
    -2
      examples/pytorch/continuous_batching.py
  25. +1
    -1
      setup.py
  26. +1
    -1
      src/transformers/dependency_versions_table.py
  27. +38
    -9
      src/transformers/feature_extraction_utils.py
  28. +17
    -21
      src/transformers/generation/continuous_batching/continuous_api.py
  29. +15
    -4
      src/transformers/image_processing_utils_fast.py
  30. +26
    -0
      src/transformers/integrations/accelerate.py
  31. +219
    -260
      src/transformers/modeling_utils.py
  32. +1
    -0
      src/transformers/models/align/modeling_align.py
  33. +3
    -1
      src/transformers/models/apertus/modeling_apertus.py
  34. +4
    -1
      src/transformers/models/apertus/modular_apertus.py
  35. +1
    -1
      src/transformers/models/auto/auto_factory.py
  36. +1
    -0
      src/transformers/models/bart/modeling_bart.py
  37. +0
    -1
      src/transformers/models/beit/image_processing_beit_fast.py
  38. +1
    -0
      src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
  39. +1
    -0
      src/transformers/models/blenderbot/modeling_blenderbot.py
  40. +1
    -0
      src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
  41. +2
    -0
      src/transformers/models/blip/modeling_blip_text.py
  42. +3
    -1
      src/transformers/models/blt/modeling_blt.py
  43. +3
    -1
      src/transformers/models/blt/modular_blt.py
  44. +0
    -2
      src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
  45. +1
    -0
      src/transformers/models/bridgetower/modeling_bridgetower.py
  46. +1
    -0
      src/transformers/models/chameleon/modeling_chameleon.py
  47. +2
    -0
      src/transformers/models/clipseg/modeling_clipseg.py
  48. +0
    -1
      src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py
  49. +0
    -1
      src/transformers/models/convnext/image_processing_convnext_fast.py
  50. +0
    -4
      src/transformers/models/decision_transformer/modeling_decision_transformer.py
  51. +1
    -1
      src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
  52. +1
    -0
      src/transformers/models/deepseek_v3/modular_deepseek_v3.py
  53. +0
    -1
      src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py
  54. +0
    -4
      src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
  55. +0
    -4
      src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py
  56. +0
    -1
      src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
  57. +4
    -0
      src/transformers/models/dia/modeling_dia.py
  58. +4
    -0
      src/transformers/models/dia/modular_dia.py
  59. +0
    -1
      src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py
  60. +0
    -1
      src/transformers/models/donut/image_processing_donut_fast.py
  61. +1
    -1
      src/transformers/models/dots1/modeling_dots1.py
  62. +1
    -2
      src/transformers/models/dpt/image_processing_dpt_fast.py
  63. +1
    -2
      src/transformers/models/dpt/modular_dpt.py
  64. +1
    -2
      src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py
  65. +0
    -1
      src/transformers/models/efficientnet/image_processing_efficientnet_fast.py
  66. +1
    -1
      src/transformers/models/efficientnet/modeling_efficientnet.py
  67. +11
    -10
      src/transformers/models/eomt/image_processing_eomt_fast.py
  68. +3
    -0
      src/transformers/models/ernie/modeling_ernie.py
  69. +3
    -0
      src/transformers/models/ernie/modular_ernie.py
  70. +1
    -0
      src/transformers/models/evolla/modeling_evolla.py
  71. +1
    -0
      src/transformers/models/evolla/modular_evolla.py
  72. +0
    -3
      src/transformers/models/flaubert/modeling_flaubert.py
  73. +0
    -2
      src/transformers/models/flava/image_processing_flava_fast.py
  74. +1
    -1
      src/transformers/models/fuyu/image_processing_fuyu.py
  75. +9
    -11
      src/transformers/models/gemma/modeling_gemma.py
  76. +9
    -11
      src/transformers/models/gemma/modular_gemma.py
  77. +0
    -1
      src/transformers/models/gemma3/image_processing_gemma3_fast.py
  78. +1
    -0
      src/transformers/models/gemma3n/modeling_gemma3n.py
  79. +1
    -0
      src/transformers/models/gemma3n/modular_gemma3n.py
  80. +1
    -1
      src/transformers/models/glm4_moe/modeling_glm4_moe.py
  81. +1
    -1
      src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
  82. +0
    -1
      src/transformers/models/glpn/image_processing_glpn_fast.py
  83. +0
    -1
      src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py
  84. +1
    -0
      src/transformers/models/got_ocr2/modeling_got_ocr2.py
  85. +0
    -4
      src/transformers/models/gpt2/modeling_gpt2.py
  86. +0
    -8
      src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
  87. +4
    -0
      src/transformers/models/idefics2/modeling_idefics2.py
  88. +2
    -0
      src/transformers/models/idefics3/modeling_idefics3.py
  89. +1
    -5
      src/transformers/models/imagegpt/image_processing_imagegpt_fast.py
  90. +0
    -1
      src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
  91. +0
    -1
      src/transformers/models/internvl/video_processing_internvl.py
  92. +0
    -1
      src/transformers/models/janus/image_processing_janus_fast.py
  93. +0
    -2
      src/transformers/models/janus/modeling_janus.py
  94. +2
    -2
      src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py
  95. +4
    -0
      src/transformers/models/lasr/configuration_lasr.py
  96. +4
    -0
      src/transformers/models/lasr/modular_lasr.py
  97. +0
    -1
      src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py
  98. +0
    -1
      src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py
  99. +1
    -2
      src/transformers/models/lightglue/image_processing_lightglue_fast.py
  100. +1
    -2
      src/transformers/models/llama4/image_processing_llama4_fast.py

+ 1
- 1
Makefile View File

@@ -45,7 +45,7 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_modeling_structure.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py


+ 2
- 0
docs/source/en/_toctree.yml View File

@@ -68,6 +68,8 @@
title: Perplexity of fixed-length models
title: Generate API
- sections:
- local: optimization_overview
title: Overview
- local: attention_interface
title: Attention backends
- local: continuous_batching


+ 15
- 0
docs/source/en/model_doc/sam3.md View File

@@ -354,6 +354,21 @@ When running the same text prompt on multiple images, pre-compute text embedding
... print(f"Image {i+1}: {len(results['masks'])} '{text_prompt}' objects found")
```

### Custom Resolution Inference

<div class="warning">
⚠️ **Performance Note**: Custom resolutions may degrade accuracy. The model is meant to be used at 1008px resolution.
</div>

For faster inference or lower memory usage:

```python
>>> config = Sam3Config.from_pretrained("facebook/sam3")
>>> config.image_size = 560
>>> model = Sam3Model.from_pretrained("facebook/sam3", config=config).to(device)
>>> processor = Sam3Processor.from_pretrained("facebook/sam3", size={"height": 560, "width": 560})
```

### Prompt Label Conventions

SAM3 uses the following label conventions:


+ 15
- 0
docs/source/en/model_doc/sam3_video.md View File

@@ -188,6 +188,21 @@ For real-time applications, SAM3 Video supports processing video frames as they
>>> print(f"Masks are at original video resolution: {frame_0_outputs['masks'].shape}")
```

#### Custom Resolution Inference

<div class="warning">
⚠️ **Performance Note**: Custom resolutions may degrade accuracy. The model is meant to be used at 1008px resolution.
</div>

For faster inference or lower memory usage:

```python
>>> config = Sam3VideoConfig.from_pretrained("facebook/sam3")
>>> config.image_size = 560
>>> model = Sam3VideoModel.from_pretrained("facebook/sam3", config=config).to(device, dtype=torch.bfloat16)
>>> processor = Sam3VideoProcessor.from_pretrained("facebook/sam3", size={"height": 560, "width": 560})
```

## Sam3VideoConfig

[[autodoc]] Sam3VideoConfig


+ 178
- 0
docs/source/en/optimization_overview.md View File

@@ -0,0 +1,178 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Overview

Transformers provides multiple inference optimization techniques to make models fast, affordable, and accessible. Options include alternative attention mechanisms for reduced memory traffic, code compilation for faster execution, and optimized kernels for throughput. Stack these techniques for maximum performance.

> [!NOTE]
> Memory and speed are closely related but not the same. Shrinking your memory footprint makes a model "faster" because there is less data to move around. Pure speed optimizations don't always reduce memory and sometimes increase usage. Choose the appropriate optimization based on your use case and hardware.

Use the table below to pick an optimization technique.

| Technique | Speed | Memory |
|---|:---:|:---:|
| [Compilation](#compilation) | ✅ | |
| [Attention backends](#attention-backends) | ✅ | ✅ |
| [Kernels](#kernels) | ✅ | ✅ |
| [Quantization](#quantization) | ✅ | ✅ |
| [Caching](#caching) | ✅ | ✅ |
| [Parallelism](#parallelism) | ✅ | |
| [Continuous batching](#continuous-batching) | ✅ | |

This guide gives you a quick start on Transformers optimizations.

## Compilation

[torch.compile](./perf_torch_compile) reduces Python overhead, fuses operations, and creates kernels tuned for your shapes and hardware. The first run warms it up and subsequent runs use the faster compiled path.

Pass a [fixed size cache](./kv_cache#fixed-size-cache) to [`~GenerationMixin.generate`] to trigger `torch.compile` automatically.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="auto")
input = tokenizer("The French Bread Law states", return_tensors="pt").to(model.device)

output = model.generate(**input, do_sample=False, max_new_tokens=20, cache_implementation="static")
tokenizer.batch_decode(output, skip_special_tokens=True)[0]
```

> [!WARNING]
> Avoid calling `torch.compile(model)` outside of [`~GenerationMixin.generate`] to prevent the model from recompiling every step.

## Attention backends

Alternative [attention backends](./attention_interface) lower memory traffic. For example, FlashAttention tiles attention computations and avoids large intermediate tensors to reduce memory footprint.

Set `attn_implementation` in [`~PreTrainedModel.from_pretrained`] to load an optimized attention backend.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", attn_implementation="flash_attention_2")
```

## Kernels

Kernels fuse operations to boost throughput and reduce memory usage. The [Kernels](https://huggingface.co/docs/kernels/en/index) library loads optimized compute kernels from the [Hub](https://huggingface.co/kernels-community) in a flexible and version-safe way.

The example below loads an optimized FlashAttention-2 kernel without installing the package.

```py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
```

## Quantization

[Quantization](./quantization/overview) shrinks the size of every parameter which lowers memory footprint and increases speed because you can do more operations.

Pass a quantization config to the `quantization_config` argument in [`~PreTrainedModel.from_pretrained`]. Each quantization backend has a different config with different arguments. The example below quantizes a model to 4-bits and configures the computation dtype with the [bitsandbytes](./quantization/bitsandbytes) backend.

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model = AutoModelForCausalLM.from_pretrained(
"allenai/Olmo-3-7B-Think", quantization_config=bnb_config
)
```

## Caching

[Caching](./kv_cache) speeds up generation by reusing past keys and values instead of recomputing them for every token. To offset and reduce the memory cost of storing past keys and values, Transformers
supports offloading the cache to the CPU. Only the current layer remains on the GPU.

Use the `cache_implementation` argument in [`~GenerationMixin.generate`] to set a cache strategy.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
inputs = tokenizer("The Le Décret Pain states that a baguette must,", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="offloaded")
```

## Parallelism

[Parallelism](./perf_infer_gpu_multi) distributes a model across devices so models too big for one device run fast. This approach uses more memory due to sharding overhead and communication to sync results.

[Tensor parallelism](./perf_infer_gpu_multi) splits a model layer across devices. Set `tp_plan="auto"` in [`~PreTrainedModel.from_pretrained`] to enable it.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", tp_plan="auto")
print(model._tp_plan)
```

## Continuous batching

[Continuous batching](./continuous_batching) maximizes throughput by keeping the GPU busy with dynamic scheduling and chunked prefill. [Serving](./serving) applications use it to process multiple incoming requests concurrently.

Use [`~ContinuousMixin.generate_batch`] to enable continuous batching.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
attn_implementation="paged|sdpa",
device_map="cuda",
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

prompts = [
"The Le Décret Pain states that a baguette must",
"Explain gravity in one sentence.",
"Name the capital of France.",
]
inputs = [tokenizer.encode(p) for p in prompts]

generation_config = GenerationConfig(
max_new_tokens=32,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
max_batch_tokens=512,
)

outputs = model.generate_batch(
inputs=inputs,
generation_config=generation_config,
)

for request_id, output in outputs.items():
text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
print(f"[{request_id}] {text}")
```

+ 95
- 16
docs/source/en/quantization/contribute.md View File

@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Contribute

Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.
Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.

This guide will show you how to integrate a new quantization method with [`~quantizers.HfQuantizer`].

@@ -28,16 +28,16 @@ Before integrating a new quantization method into Transformers, ensure the metho
- The method can run on commonly-used hardware (CPU, GPU, etc.).
- The method is wrapped in a [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) ([`~bitsandbytes.nn.Linear8bitLt`], [`~bitsandbytes.nn.Linear4bit`]), and the quantized linear layer should have the following definition.

```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```
```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.

- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package containing the quantization kernels/primitive is stable (no frequent breaking changes).
@@ -48,23 +48,23 @@ Some quantization methods may require "pre-quantizing" the model through data ca

0. The best starting point would be to have a look at another quantization method such as Finegrained Fp8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).

1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/__init__.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.
1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [\_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/\_\_init\_\_.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.

2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer`]. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).

3. Define the following class attributes and property methods for your quantization method:

- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).

4. Write the `validate_environment` and `update_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. Refer to other quantizers for an example of how it is implemented.

5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).

You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.
You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.

6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.
6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.

7. Add the `get_weight_conversions` method to the quantizer class if the quantization supports loading pre-quantized weights. In transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when we have tensors in the checkpoint that require to be regrouped to re-create the quantized tensors.

@@ -73,3 +73,82 @@ You can define module replacement logic or any other utility method by creating
9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.

10. You should add tests by adding the package in our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how it is implemented.

## Files overview

| File | Purpose |
| -------------------------------------------- | ------------------------------------------------------------------------------------------------ |
| `utils/quantization_config.py` | Define `YourMethodConfig` inheriting from `QuantizationConfigMixin` |
| `quantizers/quantizer_your_method.py` | Implement `YourMethodHfQuantizer` inheriting from `HfQuantizer` |
| `integrations/your_method.py` | Implement `ConversionOps` subclasses and helper functions |
| `quantizers/auto.py` | Register quantizer and config in `AUTO_QUANTIZER_MAPPING` and `AUTO_QUANTIZATION_CONFIG_MAPPING` |
| `docs/source/en/quantization/your_method.md` | Document usage for users |
| `tests/quantization/your_method/` | Add integration tests |

## Understanding `get_quantize_ops` vs `get_weight_conversions`

These two methods handle different scenarios for loading weights. Understanding when to use each is essential.

### `get_quantize_ops` — Quantize on the fly

Use this when loading a **non-quantized checkpoint** (e.g., float16/bfloat16 weights) and quantizing during load.

```
Checkpoint: model.safetensors (float16 weights for example)
get_quantize_ops → YourQuantize.convert()
Result: Quantized weights in memory
```

The `convert` method receives one tensor at a time, quantizes it, and can return a dictionary of quantized params, for example:

```py
class YourQuantize(ConversionOps):
def convert(self, input_dict, model, full_layer_name, missing_keys, **kwargs):
# input_dict = {"layer.weight": <float16 tensor>}
value = list(input_dict.values())[0]
module, tensor_name = get_module_from_name(model, full_layer_name)

# Quantize and assign
quantized, scale, zero_point = your_quantize_fn(value)
return {full_layer_name: quantized, full_layer_name + ".scale": scale, full_layer_name + ".zero_point": zero_point}
```

### `get_weight_conversions` — Load pre-quantized checkpoints

Use this when loading a **pre-quantized checkpoint** where the quantized weights are saved as several separate components (such as data, scale, and zero point), and these need to be combined into one tensor during loading. Not all quantization methods require this reconstruction step: for example, some methods like FP8 simply load weights and scales as-is, without combining them. Others, such as torchao, do require reassembling the quantized tensor from its multiple saved components.

```
Checkpoint: model.safetensors (quantized components)
- layer._weight_qdata
- layer._weight_scale
- layer._weight_zero_point
get_weight_conversions → WeightConverter + YourDeserialize.convert()
Result: Reconstructed quantized tensor → layer.weight
```

The `WeightConverter` collects related tensors based on `source_patterns`, then passes them to your `convert` method:

```py
def get_weight_conversions(self):
    if self.pre_quantized:
        return [
            WeightConverter(
                source_patterns=["_weight_qdata", "_weight_scale", "_weight_zero_point"],
                target_patterns="weight",
                operations=[YourDeserialize(self)],
            ),
        ]
    return []


class YourDeserialize(ConversionOps):
    def convert(self, input_dict, model, full_layer_name, **kwargs):
        # input_dict contains all collected tensors
        # Reconstruct the quantized tensor from components
        reconstructed_tensor = reconstruct_from_components(input_dict)
        return {full_layer_name: reconstructed_tensor}
```

+ 3
- 10
examples/modular-transformers/configuration_duplicated_method.py View File

@@ -8,7 +8,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class DuplicatedMethodConfig(PreTrainedConfig):
@@ -129,7 +129,7 @@ class DuplicatedMethodConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@@ -157,14 +157,7 @@ class DuplicatedMethodConfig(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 3
- 10
examples/modular-transformers/configuration_my_new_model.py View File

@@ -8,7 +8,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class MyNewModelConfig(PreTrainedConfig):
@@ -165,7 +165,7 @@ class MyNewModelConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias=True,
@@ -194,14 +194,7 @@ class MyNewModelConfig(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 3
- 10
examples/modular-transformers/configuration_my_new_model2.py View File

@@ -7,7 +7,7 @@
from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
from ...modeling_rope_utils import RopeParameters


class MyNewModel2Config(PreTrainedConfig):
@@ -68,7 +68,7 @@ class MyNewModel2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@@ -96,14 +96,7 @@ class MyNewModel2Config(PreTrainedConfig):
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters

# Validate the correctness of rotary position embeddings parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
self.rope_parameters = rope_parameters

super().__init__(
pad_token_id=pad_token_id,


+ 15
- 19
examples/modular-transformers/configuration_new_model.py View File

@@ -6,7 +6,8 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Example where we only want to overwrite the defaults of an init

from ...configuration_utils import PreTrainedConfig, layer_type_validation

from ...configuration_utils import PreTrainedConfig


class NewModelConfig(PreTrainedConfig):
@@ -59,14 +60,14 @@ class NewModelConfig(PreTrainedConfig):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
layer_types (`list`, *optional*):
Attention pattern for each layer.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.

@@ -116,20 +117,12 @@ class NewModelConfig(PreTrainedConfig):
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
use_bidirectional_attention=False,
layer_types=None,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@@ -142,15 +135,18 @@ class NewModelConfig(PreTrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_bidirectional_attention = use_bidirectional_attention
self.rope_parameters = rope_parameters

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)]
layer_type_validation(self.layer_types, self.num_hidden_layers)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

@property
def num_heads(self):


+ 3
- 0
examples/modular-transformers/modeling_add_function.py View File

@@ -10,6 +10,8 @@ from typing import Optional
import torch
from torch import nn

from ...integrations import use_kernel_func_from_hub


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -18,6 +20,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.



+ 56
- 164
examples/modular-transformers/modeling_dummy_bert.py View File

@@ -10,24 +10,20 @@ from typing import Optional, Union
import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from .configuration_dummy_bert import DummyBertConfig


if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask


class DummyBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -106,7 +102,7 @@ def eager_attention_forward(
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask

@@ -148,7 +144,7 @@ class DummyBertSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -160,14 +156,14 @@ class DummyBertSelfAttention(nn.Module):
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# decoder-only dummy_bert can have a simple dynamic cache for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
current_past_key_values = past_key_values
if isinstance(past_key_values, EncoderDecoderCache):
current_past_key_values = past_key_values.self_attention_cache

# save all key/value_layer to cache to be re-used for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer, value_layer = current_past_key_values.update(
key_layer,
value_layer,
self.layer_idx,
@@ -221,7 +217,7 @@ class DummyBertCrossAttention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
# determine input shapes
@@ -234,22 +230,22 @@ class DummyBertCrossAttention(nn.Module):
# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
if past_key_value is not None and is_updated:
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].values
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# save all states to the cache
key_layer, value_layer = past_key_value.cross_attention_cache.update(
key_layer, value_layer = past_key_values.cross_attention_cache.update(
key_layer, value_layer, self.layer_idx
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_value.is_updated[self.layer_idx] = True
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -290,25 +286,6 @@ class DummyBertAttention(nn.Module):
attention_class = DummyBertCrossAttention if is_cross_attention else DummyBertSelfAttention
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = DummyBertSelfOutput(config)
self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)

# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
@@ -316,7 +293,7 @@ class DummyBertAttention(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -325,7 +302,7 @@ class DummyBertAttention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -388,14 +365,14 @@ class DummyBertLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
self_attention_output, _ = self.attention(
hidden_states,
attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -413,7 +390,7 @@ class DummyBertLayer(GradientCheckpointingLayer):
None, # attention_mask
encoder_hidden_states,
encoder_attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
**kwargs,
)
attention_output = cross_attention_output
@@ -452,7 +429,7 @@ class DummyBertEncoder(nn.Module):
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_values,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -503,7 +480,6 @@ class DummyBertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

def forward(self, hidden_states):
@@ -527,21 +503,12 @@ class DummyBertPreTrainedModel(PreTrainedModel):
"cross_attentions": DummyBertCrossAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.zero_()
super()._init_weights(module)
if isinstance(module, DummyBertLMPredictionHead):
init.zeros_(module.bias)


@auto_docstring(
@@ -582,14 +549,6 @@ class DummyBertModel(DummyBertPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@check_model_inputs
@auto_docstring
def forward(
@@ -615,19 +574,22 @@ class DummyBertModel(DummyBertPreTrainedModel):
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if input_ids is not None:
device = input_ids.device
input_shape = input_ids.shape
seq_length = input_ids.shape[1]
else:
device = inputs_embeds.device
input_shape = inputs_embeds.shape[:-1]
seq_length = inputs_embeds.shape[1]

seq_length = input_shape[1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
@@ -641,7 +603,6 @@ class DummyBertModel(DummyBertPreTrainedModel):
)

attention_mask, encoder_attention_mask = self._create_attention_masks(
input_shape=input_shape,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
embedding_output=embedding_output,
@@ -672,7 +633,6 @@ class DummyBertModel(DummyBertPreTrainedModel):

def _create_attention_masks(
self,
input_shape,
attention_mask,
encoder_attention_mask,
embedding_output,
@@ -680,95 +640,27 @@ class DummyBertModel(DummyBertPreTrainedModel):
cache_position,
past_key_values,
):
if attention_mask is not None and attention_mask.dim() == 2:
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = self._update_full_mask(
attention_mask,
embedding_output,
)
elif attention_mask is not None and attention_mask.dim() == 3:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

if encoder_attention_mask is not None:
if encoder_attention_mask.dim() == 2:
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
embedding_output.shape[:2],
embedding_output,
)
else:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
encoder_attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
)

return attention_mask, encoder_attention_mask

def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if "flash" in self.config._attn_implementation:
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

return attention_mask

def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if "flash" in self.config._attn_implementation:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)

return encoder_attention_mask

+ 12
- 42
examples/modular-transformers/modeling_from_uppercase_model.py View File

@@ -4,6 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from collections.abc import Callable
from typing import Optional, Union

@@ -13,6 +14,8 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
from .configuration_from_uppercase_model import FromUppercaseModelTextConfig, FromUppercaseModelVisionConfig


@@ -24,8 +27,7 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
output_attentions: bool = True,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
@@ -35,8 +37,6 @@ def eager_attention_forward(

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -67,8 +67,7 @@ class FromUppercaseModelAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -81,15 +80,6 @@ class FromUppercaseModelAttention(nn.Module):
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -101,17 +91,14 @@ class FromUppercaseModelAttention(nn.Module):
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -143,27 +130,15 @@ class FromUppercaseModelEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
**kwargs: Unpack[TransformersKwargs],
) -> torch.FloatTensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -172,9 +147,4 @@ class FromUppercaseModelEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (attn_weights,)

return outputs
return hidden_states

+ 5
- 2
examples/modular-transformers/modeling_global_indexing.py View File

@@ -13,6 +13,7 @@ from torch import nn
from transformers.modeling_utils import AttentionInterface

from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
from .configuration_global_indexing import GlobalIndexingConfig
@@ -25,6 +26,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -100,6 +102,7 @@ ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex


@use_kernelized_func(apply_rotary_pos_emb)
class GlobalIndexingAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -129,8 +132,8 @@ class GlobalIndexingAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 29
- 177
examples/modular-transformers/modeling_multimodal2.py View File

@@ -17,7 +17,9 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, torch_int
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, torch_int
from ...utils.generic import check_model_inputs
from .configuration_multimodal2 import Multimodal2Config, Multimodal2TextConfig, Multimodal2VisionConfig


@@ -29,8 +31,7 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
output_attentions: bool = True,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
@@ -40,8 +41,6 @@ def eager_attention_forward(

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -72,8 +71,7 @@ class Multimodal2VisionAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -86,15 +84,6 @@ class Multimodal2VisionAttention(nn.Module):
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -106,17 +95,14 @@ class Multimodal2VisionAttention(nn.Module):
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -135,86 +121,11 @@ class Multimodal2VisionMLP(nn.Module):
return hidden_states


class Multimodal2Attention(nn.Module):
    """Standard multi-head self-attention ('Attention Is All You Need').

    Shared between the vision and text configs; projects hidden states to
    queries/keys/values, dispatches to the configured attention backend, and
    projects the result back to the model dimension.
    """

    def __init__(self, config: Union[Multimodal2VisionConfig, Multimodal2TextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # The model dimension must split evenly across heads.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, hidden_dim = hidden_states.shape

        def _split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states))
        key_states = _split_heads(self.k_proj(hidden_states))
        value_states = _split_heads(self.v_proj(hidden_states))

        # MULTIMODAL2 text model uses both `causal_attention_mask` and `attention_mask`.
        # FA2 takes no additive mask here, so causality is signalled via `is_causal`;
        # other backends receive the two masks merged into one additive mask.
        if self.config._attn_implementation == "flash_attention_2":
            self.is_causal = causal_attention_mask is not None
        elif causal_attention_mask is not None:
            attention_mask = (
                causal_attention_mask if attention_mask is None else attention_mask + causal_attention_mask
            )

        attn_fn: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attn_fn(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
            output_attentions=output_attentions,
        )

        # Merge heads back and apply the output projection.
        attn_output = attn_output.reshape(bsz, q_len, hidden_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None)


class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Multimodal2Attention(config)
self.self_attn = Multimodal2VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Multimodal2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -223,27 +134,15 @@ class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
**kwargs: Unpack[TransformersKwargs],
) -> torch.FloatTensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -252,12 +151,7 @@ class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (attn_weights,)

return outputs
return hidden_states


class Multimodal2VisionEncoder(nn.Module):
@@ -279,9 +173,7 @@ class Multimodal2VisionEncoder(nn.Module):
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Args:
@@ -296,53 +188,17 @@ class Multimodal2VisionEncoder(nn.Module):
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)


@@ -444,15 +300,9 @@ class Multimodal2VisionTransformer(nn.Module):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

if pixel_values is None:
raise ValueError("You have to specify pixel_values")

@@ -461,8 +311,7 @@ class Multimodal2VisionTransformer(nn.Module):

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)

last_hidden_state = encoder_outputs.last_hidden_state
@@ -472,8 +321,6 @@ class Multimodal2VisionTransformer(nn.Module):
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -481,12 +328,18 @@ class Multimodal2VisionTransformer(nn.Module):
class Multimodal2VisionPreTrainedModel(PreTrainedModel):
config: Multimodal2Config
base_model_prefix = "multimodal2_vision"
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Multimodal2VisionEncoderLayer,
"attentions": Multimodal2VisionAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Multimodal2VisionMLP):
@@ -500,6 +353,7 @@ MULTIMODAL2_VISION_START_DOCSTRING = "doc"
class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
config: Multimodal2VisionConfig
main_input_name = "pixel_values"
input_modalities = ("image",)
_no_split_modules = ["Multimodal2VisionEncoderLayer"]

def __init__(self, config: Multimodal2VisionConfig):
@@ -511,14 +365,13 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

@can_return_tuple
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""
Example:
@@ -543,7 +396,6 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):

return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
**kwargs,
)

+ 6
- 2
examples/modular-transformers/modeling_my_new_model2.py View File

@@ -10,8 +10,10 @@ from typing import Optional
import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -62,6 +64,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -127,6 +130,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class MyNewModel2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -260,12 +264,12 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
"attentions": MyNewModel2Attention,
}

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.zero_()
init.zeros_(module.weight)


class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):


+ 2
- 37
examples/modular-transformers/modeling_new_task_model.py View File

@@ -87,27 +87,17 @@ class NewTaskModelMultiModalProjector(nn.Module):
@auto_docstring
class NewTaskModelPreTrainedModel(PreTrainedModel):
config: NewTaskModelConfig
base_model_prefix = ""
base_model_prefix = "model"
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_no_split_modules = ["NewTaskModelMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"

_can_compile_fullgraph = False
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True

def _init_weights(self, module):
# important: this ported version of NewTaskModel isn't meant for training from scratch - only
# inference and fine-tuning
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()


def token_type_ids_mask_function(
token_type_ids: Optional[torch.Tensor],
@@ -249,12 +239,6 @@ class NewTaskModelModel(NewTaskModelPreTrainedModel):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -457,28 +441,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(


+ 56
- 164
examples/modular-transformers/modeling_roberta.py View File

@@ -10,24 +10,20 @@ from typing import Optional, Union
import torch
import torch.nn as nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from .configuration_roberta import RobertaConfig


if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask


class RobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -109,7 +105,7 @@ def eager_attention_forward(
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask

@@ -151,7 +147,7 @@ class RobertaSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -163,14 +159,14 @@ class RobertaSelfAttention(nn.Module):
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# decoder-only roberta can have a simple dynamic cache for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
current_past_key_values = past_key_values
if isinstance(past_key_values, EncoderDecoderCache):
current_past_key_values = past_key_values.self_attention_cache

# save all key/value_layer to cache to be re-used for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer, value_layer = current_past_key_values.update(
key_layer,
value_layer,
self.layer_idx,
@@ -224,7 +220,7 @@ class RobertaCrossAttention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
# determine input shapes
@@ -237,22 +233,22 @@ class RobertaCrossAttention(nn.Module):
# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
if past_key_value is not None and is_updated:
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_value.cross_attention_cache.layers[self.layer_idx].values
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)

if past_key_value is not None:
if past_key_values is not None:
# save all states to the cache
key_layer, value_layer = past_key_value.cross_attention_cache.update(
key_layer, value_layer = past_key_values.cross_attention_cache.update(
key_layer, value_layer, self.layer_idx
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_value.is_updated[self.layer_idx] = True
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -293,25 +289,6 @@ class RobertaAttention(nn.Module):
attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = RobertaSelfOutput(config)
self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)

# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
@@ -319,7 +296,7 @@ class RobertaAttention(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
@@ -328,7 +305,7 @@ class RobertaAttention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -391,14 +368,14 @@ class RobertaLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
self_attention_output, _ = self.attention(
hidden_states,
attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -416,7 +393,7 @@ class RobertaLayer(GradientCheckpointingLayer):
None, # attention_mask
encoder_hidden_states,
encoder_attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
**kwargs,
)
attention_output = cross_attention_output
@@ -455,7 +432,7 @@ class RobertaEncoder(nn.Module):
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_values,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
@@ -506,7 +483,6 @@ class RobertaLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

def forward(self, hidden_states):
@@ -530,21 +506,12 @@ class RobertaPreTrainedModel(PreTrainedModel):
"cross_attentions": RobertaCrossAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.zero_()
super()._init_weights(module)
if isinstance(module, RobertaLMPredictionHead):
init.zeros_(module.bias)


@auto_docstring(
@@ -585,14 +552,6 @@ class RobertaModel(RobertaPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@check_model_inputs
@auto_docstring
def forward(
@@ -615,19 +574,22 @@ class RobertaModel(RobertaPreTrainedModel):
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if input_ids is not None:
device = input_ids.device
input_shape = input_ids.shape
seq_length = input_ids.shape[1]
else:
device = inputs_embeds.device
input_shape = inputs_embeds.shape[:-1]
seq_length = inputs_embeds.shape[1]

seq_length = input_shape[1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
@@ -641,7 +603,6 @@ class RobertaModel(RobertaPreTrainedModel):
)

attention_mask, encoder_attention_mask = self._create_attention_masks(
input_shape=input_shape,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
embedding_output=embedding_output,
@@ -672,7 +633,6 @@ class RobertaModel(RobertaPreTrainedModel):

def _create_attention_masks(
self,
input_shape,
attention_mask,
encoder_attention_mask,
embedding_output,
@@ -680,95 +640,27 @@ class RobertaModel(RobertaPreTrainedModel):
cache_position,
past_key_values,
):
if attention_mask is not None and attention_mask.dim() == 2:
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = self._update_full_mask(
attention_mask,
embedding_output,
)
elif attention_mask is not None and attention_mask.dim() == 3:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

if encoder_attention_mask is not None:
if encoder_attention_mask.dim() == 2:
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
embedding_output.shape[:2],
embedding_output,
)
else:
if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
raise ValueError(
"Passing attention mask with a 3D/4D shape does not work with type "
f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
encoder_attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
)

return attention_mask, encoder_attention_mask

def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if "flash" in self.config._attn_implementation:
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

return attention_mask

def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if "flash" in self.config._attn_implementation:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)

return encoder_attention_mask

+ 44
- 13
examples/modular-transformers/modeling_super.py View File

@@ -14,13 +14,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from ...utils.generic import check_model_inputs, maybe_autocast
from .configuration_super import SuperConfig


@@ -50,20 +50,49 @@ class SuperRotaryEmbedding(nn.Module):

def __init__(self, config: SuperConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_parameters") and isinstance(config.rope_parameters, dict):
self.rope_type = config.rope_parameters.get("rope_type", config.rope_parameters.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
self.original_inv_freq = inv_freq

@staticmethod
def compute_default_rope_parameters(
    config: Optional[SuperConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
) -> tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation
    Args:
        config ([`~transformers.PreTrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    base = config.rope_parameters["rope_theta"]
    # Prefer an explicit `head_dim` from the config when present; otherwise derive it
    # from the hidden size and the number of attention heads.
    dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies: base^(-2i/dim) for even indices i in [0, dim),
    # evaluated in float32 on the target device.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
    )
    return inv_freq, attention_factor

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -72,7 +101,7 @@ class SuperRotaryEmbedding(nn.Module):
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
@@ -104,6 +133,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -169,6 +199,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class SuperAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -198,8 +229,8 @@ class SuperAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 5
- 2
examples/modular-transformers/modeling_switch_function.py View File

@@ -12,6 +12,7 @@ import torch
from torch import nn

from ...cache_utils import Cache
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs
@@ -26,6 +27,7 @@ def rotate_half(x):
return rot_x


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -91,6 +93,7 @@ def eager_attention_forward(
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class SwitchFunctionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -120,8 +123,8 @@ class SwitchFunctionAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],


+ 31
- 26
examples/modular-transformers/modeling_test_detr.py View File

@@ -4,6 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_test_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

import math
import warnings
from dataclasses import dataclass
@@ -13,6 +14,7 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
@@ -203,10 +205,10 @@ def replace_batch_norm(model):
new_module = TestDetrFrozenBatchNorm2d(module.num_features)

if module.weight.device != torch.device("meta"):
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
new_module.weight.copy_(module.weight)
new_module.bias.copy_(module.bias)
new_module.running_mean.copy_(module.running_mean)
new_module.running_var.copy_(module.running_var)

model._modules[name] = new_module

@@ -810,6 +812,7 @@ class TestDetrPreTrainedModel(PreTrainedModel):
config: TestDetrConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = True
_no_split_modules = [
r"TestDetrConvEncoder",
@@ -817,14 +820,15 @@ class TestDetrPreTrainedModel(PreTrainedModel):
r"TestDetrDecoderLayer",
]

@torch.no_grad()
def _init_weights(self, module):
std = self.config.init_std

if isinstance(module, TestDetrLearnedPositionEmbedding):
nn.init.uniform_(module.row_embeddings.weight)
nn.init.uniform_(module.column_embeddings.weight)
init.uniform_(module.row_embeddings.weight)
init.uniform_(module.column_embeddings.weight)
elif isinstance(module, TestDetrMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
init.constant_(module.sampling_offsets.weight, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
2.0 * math.pi / module.n_heads
@@ -837,27 +841,28 @@ class TestDetrPreTrainedModel(PreTrainedModel):
)
for i in range(module.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(module.attention_weights.weight.data, 0.0)
nn.init.constant_(module.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(module.value_proj.weight.data)
nn.init.constant_(module.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
init.constant_(module.attention_weights.weight, 0.0)
init.constant_(module.attention_weights.bias, 0.0)
init.xavier_uniform_(module.value_proj.weight)
init.constant_(module.value_proj.bias, 0.0)
init.xavier_uniform_(module.output_proj.weight)
init.constant_(module.output_proj.bias, 0.0)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.normal_(mean=0.0, std=std)
init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()
init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
init.normal_(module.weight, mean=0.0, std=std)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
init.zeros_(module.weight[module.padding_idx])
if hasattr(module, "reference_points") and not self.config.two_stage:
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
init.xavier_uniform_(module.reference_points.weight, gain=1.0)
init.constant_(module.reference_points.bias, 0.0)
if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed)
init.normal_(module.level_embed)


class TestDetrEncoder(TestDetrPreTrainedModel):
@@ -924,6 +929,7 @@ class TestDetrEncoder(TestDetrPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
Args:
@@ -1046,6 +1052,7 @@ class TestDetrDecoder(TestDetrPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
Args:
@@ -1267,9 +1274,6 @@ class TestDetrModel(TestDetrPreTrainedModel):

self.post_init()

def get_encoder(self):
return self.encoder

def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
param.requires_grad_(False)
@@ -1379,6 +1383,7 @@ class TestDetrModel(TestDetrPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple[torch.FloatTensor], TestDetrModelOutput]:
r"""
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):


+ 8
- 1
examples/modular-transformers/modular_multimodal2.py View File

@@ -35,6 +35,7 @@ class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
def __init__(self, config):
super().__init__()
self.mlp = Multimodal2VisionMLP(config)
self.self_attn = Multimodal2VisionAttention(config)


class Multimodal2VisionEncoder(CLIPEncoder):
@@ -43,7 +44,8 @@ class Multimodal2VisionEncoder(CLIPEncoder):
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])


# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder arg should use it as well
# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder and attn arg should
# use it as well
class Multimodal2VisionTransformer(CLIPVisionTransformer):
def __init__(self, config):
super().__init__(config)
@@ -51,6 +53,11 @@ class Multimodal2VisionTransformer(CLIPVisionTransformer):


class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel):
_can_record_outputs = {
"hidden_states": Multimodal2VisionEncoderLayer,
"attentions": Multimodal2VisionAttention,
}

def _init_weights(self, module):
if isinstance(module, Multimodal2VisionMLP):
pass


+ 1
- 2
examples/modular-transformers/modular_new_model.py View File

@@ -23,11 +23,10 @@ class NewModelConfig(GemmaConfig):
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
use_bidirectional_attention=False,
layer_types=None,
**kwargs,
):
super().__init__(self, **kwargs)


+ 12
- 2
examples/pytorch/continuous_batching.py View File

@@ -182,13 +182,16 @@ if __name__ == "__main__":

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument(
"--input-length", type=int, default=None, help="Length of input sequences. Leave to None to mimic real eval."
)
parser.add_argument("--max-new-tokens", type=int, default=512, help="Maximum number of new tokens to generate")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
@@ -251,6 +254,12 @@ if __name__ == "__main__":
else:
possible_prefixes = [None]

tokenizer_kwargs = {"add_generation_prompt": True}
if args.input_length is not None:
tokenizer_kwargs["max_length"] = args.input_length
tokenizer_kwargs["truncation"] = True
tokenizer_kwargs["padding"] = True

batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
@@ -261,7 +270,7 @@ if __name__ == "__main__":
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
inputs = tokenizer.apply_chat_template(messages, **tokenizer_kwargs)
inputs = inputs if isinstance(inputs, list) else inputs["input_ids"]
batched_inputs.append(inputs)

@@ -283,6 +292,7 @@ if __name__ == "__main__":
generation_cfg.compile_config = CompileConfig(
fullgraph=True,
mode="max-autotune-no-cudagraphs",
dynamic=True, # FIXME: if we warmup all graphs, this is not needed anymore
)

# If we need to compare, we need to generate the reference outputs


+ 1
- 1
setup.py View File

@@ -140,7 +140,7 @@ _deps = [
"tensorboard",
"timeout-decorator",
"tiktoken",
"timm<=1.0.19,!=1.0.18",
"timm>=1.0.20",
"tokenizers>=0.22.0,<=0.23.0",
"torch>=2.2",
"torchaudio",


+ 1
- 1
src/transformers/dependency_versions_table.py View File

@@ -75,7 +75,7 @@ deps = {
"tensorboard": "tensorboard",
"timeout-decorator": "timeout-decorator",
"tiktoken": "tiktoken",
"timm": "timm<=1.0.19,!=1.0.18",
"timm": "timm>=1.0.20",
"tokenizers": "tokenizers>=0.22.0,<=0.23.0",
"torch": "torch>=2.2",
"torchaudio": "torchaudio",


+ 38
- 9
src/transformers/feature_extraction_utils.py View File

@@ -67,11 +67,18 @@ class BatchFeature(UserDict):
tensor_type (`Union[None, str, TensorType]`, *optional*):
You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
initialization.
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
"""

def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
def __init__(
self,
data: Optional[dict[str, Any]] = None,
tensor_type: Union[None, str, TensorType] = None,
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)

def __getitem__(self, item: str) -> Any:
"""
@@ -110,6 +117,14 @@ class BatchFeature(UserDict):
import torch

def as_tensor(value):
if torch.is_tensor(value):
return value

# stack list of tensors if tensor_type is PyTorch (# torch.tensor() does not support list of tensors)
if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
return torch.stack(value)

# convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
if isinstance(value, (list, tuple)) and len(value) > 0:
if isinstance(value[0], np.ndarray):
value = np.array(value)
@@ -138,7 +153,11 @@ class BatchFeature(UserDict):
is_tensor = is_numpy_array
return is_tensor, as_tensor

def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
def convert_to_tensors(
self,
tensor_type: Optional[Union[str, TensorType]] = None,
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
):
"""
Convert the inner content to tensors.

@@ -146,6 +165,8 @@ class BatchFeature(UserDict):
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
"""
if tensor_type is None:
return self
@@ -154,18 +175,26 @@ class BatchFeature(UserDict):

# Do the tensor conversion in batch
for key, value in self.items():
# Skip keys explicitly marked for no conversion
if skip_tensor_conversion and key in skip_tensor_conversion:
continue

try:
if not is_tensor(value):
tensor = as_tensor(value)

self[key] = tensor
except: # noqa E722
except Exception as e:
if key == "overflowing_values":
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
raise ValueError(
f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
f"Original error: {str(e)}"
) from e
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
)
f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
f"You can try:\n"
f" 1. Use padding=True to ensure all outputs have the same shape\n"
f" 2. Set return_tensors=None to return Python objects instead of tensors"
) from e

return self



+ 17
- 21
src/transformers/generation/continuous_batching/continuous_api.py View File

@@ -259,7 +259,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
self.max_seqlen_q = 0
self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
self.output_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)

# For some kwargs, we have a dict of tensors with as many items as there are attention types
layer_types = getattr(self.config, "layer_types", None)
@@ -311,7 +311,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q[: b_size + 1].zero_()
self.max_seqlen_q = 0
self.logits_indices[:q_len].fill_(-1)
self.output_ids[:, :q_len].fill_(-1)
self.output_ids[:q_len].fill_(-1)

# Reset the attributes that are either tensors or dict of tensors
for layer_type in self.cumulative_seqlens_k:
@@ -447,7 +447,7 @@ class ContinuousBatchProcessor:
self.metrics.record_batch_metrics(self.requests_in_batch)

# Reset the static tensors used for storage
self.reset_static_tensors() # TODO: this might be unnecessary
self.reset_static_tensors() # FIXME: why does this make the generation faster?

# Prepare accumulators
self.actual_query_length = 0
@@ -557,13 +557,10 @@ class ContinuousBatchProcessor:
self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))

@traced
def _sync(self) -> list[int]:
if self.output_ids is not None:
try:
return self.output_ids.tolist()[0]
except Exception:
return [0, 1]
return [0, 0]
def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
indices = self.logits_indices[:num_new_tokens]
new_tokens = self.output_ids[indices]
return new_tokens.tolist()

@traced
def _maybe_send_output(self, state: RequestState) -> None:
@@ -574,13 +571,13 @@ class ContinuousBatchProcessor:
@traced
def update_batch(self) -> None:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
new_tokens = self._get_new_tokens(len(self.requests_in_batch))
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prefill_tokens) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
token = new_tokens[i]
state.tokens_to_process = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
@@ -727,12 +724,11 @@ class ContinuousBatchProcessor:
probs = nn.functional.softmax(probs, dim=-1)
# probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
# Add batch dimension back to match argmax output
next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
else:
next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len]
tokens = next_tokens.size(1) # Get seq_len dimension
self.output_ids[:, :tokens].copy_(next_tokens)
next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
tokens = next_tokens.size(0) # Get seq_len dimension
self.output_ids[:tokens].copy_(next_tokens)


# Manager Class (User Interface)
@@ -950,6 +946,10 @@ class ContinuousBatchingManager:
streaming: bool = False,
record_timestamps: bool = False,
) -> None:
# If there is prefix sharing, we sort the inputs to maximize cache hits
if self._allow_prefix_sharing:
inputs = sorted(inputs, reverse=True)
# Add requests in order
for input_ids in inputs:
self.add_request(
input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
@@ -1080,10 +1080,6 @@ class ContinuousBatchingManager:
)

self._generation_step()

if torch.cuda.is_available():
torch.cuda.synchronize() # FIXME: why is this needed?
# Processor updates the batch after generation step is truly over
batch_processor.update_batch()

@traced


+ 15
- 4
src/transformers/image_processing_utils_fast.py View File

@@ -932,11 +932,22 @@ class BaseImageProcessorFast(BaseImageProcessor):
if do_pad:
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict

# Filter out None values that are class defaults, but preserve explicitly set None values
filtered_dict = {}
for key, value in encoder_dict.items():
if value is None:
class_default = getattr(type(self), key, "NOT_FOUND")
# Keep None if user explicitly set it (class default is non-None)
if class_default != "NOT_FOUND" and class_default is not None:
filtered_dict[key] = value
else:
filtered_dict[key] = value

filtered_dict.pop("_valid_processor_keys", None)
filtered_dict.pop("_valid_kwargs_names", None)
return filtered_dict

+ 26
- 0
src/transformers/integrations/accelerate.py View File

@@ -554,6 +554,32 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str |
return offload_index


def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
    """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
    inside `model`.
    This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
    then resave them to disk in the correct shard...)."""
    parts = param_name.split(".")
    # Candidate parent modules, walking from the innermost submodule up to the root module ("").
    candidates = [".".join(parts[:-depth]) for depth in range(1, len(parts))]
    candidates.append("")

    hooked_parent = None
    for candidate in candidates:
        submodule = model.get_submodule(candidate)
        if hasattr(submodule, "_hf_hook"):
            hooked_parent = candidate
            hook = submodule._hf_hook
            break

    # No module on the path carried an offload hook - something is wrong
    if hooked_parent is None:
        raise ValueError(
            f"{param_name} is on the meta device because it was offloaded, but we could not find "
            "the corresponding hook for it"
        )

    # Strip the parent prefix so the name matches the hook's local weights map
    prefix = f"{hooked_parent}." if hooked_parent != "" else hooked_parent
    local_name = param_name.replace(prefix, "")
    # Indexing the hook's weights map performs the actual load from disk
    return hook.weights_map[local_name]


def _init_infer_auto_device_map(
model: nn.Module,
max_memory: dict[int | str, int | str] | None = None,


+ 219
- 260
src/transformers/modeling_utils.py View File

@@ -16,7 +16,6 @@
import collections
import copy
import functools
import gc
import importlib.metadata
import inspect
import json
@@ -64,6 +63,7 @@ from .integrations.accelerate import (
check_and_set_device_map,
expand_device_map,
init_empty_weights,
load_offloaded_parameter,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
@@ -130,7 +130,6 @@ from .utils.quantization_config import QuantizationMethod
if is_accelerate_available():
from accelerate.hooks import add_hook_to_module
from accelerate.utils import extract_model_from_parallel
from accelerate.utils.modeling import get_state_dict_from_offload


_torch_distributed_available = torch.distributed.is_available()
@@ -156,6 +155,12 @@ _init_weights = True
_is_quantized = False
_is_ds_init_called = False

# Mapping from flash attention implementations to their kernel fallback repositories
FLASH_ATTN_KERNEL_FALLBACK = {
"flash_attention_2": "kernels-community/flash-attn2",
"flash_attention_3": "kernels-community/vllm-flash-attn3",
}


def is_local_dist_rank_0():
return (
@@ -233,23 +238,28 @@ def set_zero3_state():
_is_ds_init_called = False


def restore_default_dtype(func):
@contextmanager
def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
"""
Decorator to restore the default torch dtype
at the end of the function. Serves
as a backup in case calling the function raises
an error after the function has changed the default dtype but before it could restore it.
Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
"""
# Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
if not dtype.is_floating_point:
if model_class_name is not None:
error_message = (
f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
)
else:
error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
raise ValueError(error_message)

@wraps(func)
def _wrapper(*args, **kwargs):
old_dtype = torch.get_default_dtype()
try:
return func(*args, **kwargs)
finally:
torch.set_default_dtype(old_dtype)

return _wrapper
original_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(original_dtype)


def get_torch_context_manager_or_global_device():
@@ -405,6 +415,86 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
return shared_tensors, identical


def remove_tied_weights_from_state_dict(
    state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
) -> dict[str, torch.Tensor]:
    """
    Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
    will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
    This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.

    Args:
        state_dict (`dict[str, torch.Tensor]`): the state dict to deduplicate (mutated in place and also returned).
        model (`PreTrainedModel`): the model the state dict comes from, used to resolve meta tensors and to
            enumerate the declared tied-weight key patterns.
    Returns:
        The same `state_dict`, with aliased entries removed and disjoint-but-sharing tensors cloned.
    Raises:
        RuntimeError: if tensors share storage without matching any declared tied-weight pattern, so we cannot
            tell which name should be kept.
    """
    # Group entry names by the identity of their underlying storage, so true aliases end up in the same bucket.
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            # Sometimes in the state_dict we have non-tensor objects.
            # e.g. in bitsandbytes we have some `str` objects in the state_dict
            # In the non-tensor case, fall back to the pointer of the object itself
            ptrs[id(tensor)].append(name)

        elif tensor.device.type == "meta":
            # In offloaded cases, there may be meta tensors in the state_dict.
            # For these cases, key by the pointer of the original tensor object
            # (state_dict tensors are detached and therefore no longer shared)
            tensor = model.get_parameter(name)
            ptrs[id(tensor)].append(name)

        else:
            ptrs[id_tensor_storage(tensor)].append(name)

    # Only buckets with more than one name correspond to (potentially) tied weights.
    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

    # Recursively descend to find tied weight keys
    all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
    error_names = []
    to_delete_names = set()
    # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
    # kept is consistent
    # NOTE(review): `all_potential_tied_weights_keys` is a set built just above, so this `is not None` check
    # always passes - presumably a leftover guard; confirm before removing.
    if all_potential_tied_weights_keys is not None:
        for names in shared_ptrs.values():
            found = 0
            for name in sorted(names):
                # Patterns are treated as regexes against the full parameter name.
                matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
                if matches_pattern and name in state_dict:
                    found += 1
                    # Keep at least one name per sharing group (the last match in sorted order survives).
                    if found < len(names):
                        to_delete_names.add(name)
    # We are entering a place where the weights and the transformers configuration do NOT match.
    shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
    # Those are actually tensor sharing but disjoint from each other, we can safely clone them
    # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
    for name in disjoint_names:
        state_dict[name] = state_dict[name].clone()

    # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
    # If the link between tensors was done at runtime then `from_pretrained` will not get
    # the key back leading to random tensor. A proper warning will be shown
    # during reload (if applicable), but since the file is not necessarily compatible with
    # the config, better show a proper warning.
    shared_names, identical_names = _find_identical(shared_names, state_dict)
    # delete tensors that have identical storage
    for inames in identical_names:
        known = inames.intersection(to_delete_names)
        for name in known:
            del state_dict[name]
        # Names sharing storage that were NOT declared tied: if more than one remains, saving would alias.
        unknown = inames.difference(to_delete_names)
        if len(unknown) > 1:
            error_names.append(unknown)

    if shared_names:
        error_names.extend(shared_names)

    if len(error_names) > 0:
        raise RuntimeError(
            f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
            f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
            "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
        )

    return state_dict


def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
"""Cast a single parameter `param_name` into the `model`, with value `tensor`."""
module, param_type = get_module_from_name(model, param_name)
@@ -696,23 +786,21 @@ def _get_resolved_checkpoint_files(


def _get_dtype(
cls,
dtype: Optional[Union[str, torch.dtype, dict]],
checkpoint_files: Optional[list[str]],
config: PreTrainedConfig,
sharded_metadata: Optional[dict],
state_dict: Optional[dict],
weights_only: bool,
) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
) -> tuple[PreTrainedConfig, torch.dtype]:
"""Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
inferred dtype. We do the following:
1. If dtype is not None, we use that dtype
2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
we also may have config.dtype available, but we won't rely on it till v5
1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
2. Else, use the dtype provided as a dict or str
"""
dtype_orig = None
is_sharded = sharded_metadata is not None
asked_dtype = dtype

if dtype is not None:
if isinstance(dtype, str):
@@ -736,43 +824,46 @@ def _get_dtype(
)
elif hasattr(torch, dtype):
dtype = getattr(torch, dtype)
config.dtype = dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
sub_config.dtype = dtype
elif isinstance(dtype, torch.dtype):
config.dtype = dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
sub_config.dtype = dtype
elif isinstance(dtype, dict):
for key, curr_dtype in dtype.items():
if hasattr(config, key):
value = getattr(config, key)
curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
value.dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config
dtype = dtype.get("")
dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
config.dtype = dtype
if dtype is None:
dtype = torch.float32
else:
else:
raise ValueError(
"`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
)

# cast it to a proper `torch.dtype` object
dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
elif not isinstance(dtype, (dict, torch.dtype)):
raise ValueError(
f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
f"for each sub-config in composite configs, but received {dtype}"
)
else:
# set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
dtype = torch.get_default_dtype()

dtype_orig = cls._set_default_dtype(dtype)
# Get the main dtype
if isinstance(dtype, dict):
main_dtype = dtype.get("", torch.get_default_dtype())
main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
else:
# set fp32 as the default dtype for BC
default_dtype = torch.get_default_dtype()
config.dtype = default_dtype
for key in config.sub_configs:
if (sub_config := getattr(config, key)) is not None:
sub_config.dtype = default_dtype
dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
return config, dtype, dtype_orig
main_dtype = dtype

# Set it on the config and subconfigs
config.dtype = main_dtype
for sub_config_key in config.sub_configs:
if (sub_config := getattr(config, sub_config_key)) is not None:
# The dtype was "auto" -> try to read the subconfig dtype value if any
if asked_dtype == "auto":
sub_dtype = getattr(sub_config, "dtype", main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
# The dtype was provided as a dict, try to see if we match the subconfig name
elif isinstance(dtype, dict):
sub_dtype = dtype.get(sub_config_key, main_dtype)
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
else:
sub_dtype = main_dtype
sub_config.dtype = sub_dtype

return config, main_dtype


class PipelineParallel(Enum):
@@ -798,11 +889,7 @@ class ModuleUtilsMixin:
"""
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
dtype = self._dtype or next(param.dtype for param in self.parameters() if param.is_floating_point())
if isinstance(dtype, str):
if hasattr(torch, dtype):
dtype = getattr(torch, dtype)
return dtype
return next(param.dtype for param in self.parameters() if param.is_floating_point())

def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""
@@ -1081,7 +1168,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_keep_in_fp32_modules_strict = None

dtype_plan: Optional[dict[str, torch.dtype]] = None
_dtype: Optional[Union[str, torch.dtype]] = torch.get_default_dtype()

# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1226,8 +1312,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
default_dtype = torch.get_default_dtype()
self._dtype = default_dtype

# Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
# setting it recursively)
@@ -1400,7 +1484,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.model_tags.append(tag)

@classmethod
@restore_default_dtype
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
@@ -1409,9 +1492,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
dtype (`torch.dtype`, *optional*):
Override the default `dtype` and load the model under this dtype.
"""
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here same way as done in `from_pretrained`
# For BC on the old `torch_dtype`
dtype = kwargs.pop("dtype", config.dtype)
if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1421,67 +1501,27 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if isinstance(dtype, str):
dtype = getattr(torch, dtype)

# override default dtype if needed
dtype_orig = None
if dtype is not None:
dtype_orig = cls._set_default_dtype(dtype)

# If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")

init_contexts = []
if dtype is not None:
init_contexts.append(local_torch_dtype(dtype, cls.__name__))
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
import deepspeed

init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])

else:
# Instantiate the model
with ContextManagers(init_contexts):
model = cls(config, **kwargs)

# restore default dtype if it was modified
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

return model

@classmethod
def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
under specific dtype.

Args:
dtype (`torch.dtype`):
a floating dtype to set to.

Returns:
`torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
modified. If it wasn't, returns `None`.

Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
`torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
"""
if isinstance(dtype, str):
if hasattr(torch, dtype):
dtype = getattr(torch, dtype)
else:
raise ValueError(f"Received an invalid string dtype: {dtype}")
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)

logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
cls._dtype = dtype
return dtype_orig

@property
def base_model(self) -> nn.Module:
"""
@@ -1558,7 +1598,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return True

if is_torch_xpu_available():
logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
logger.info(
f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
)
return True

if importlib.util.find_spec("flash_attn") is None:
@@ -1790,14 +1832,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
and is_kernels_available()
and not is_torch_npu_available()
):
if attn_implementation.endswith("2"):
applicable_attn_implementation = "kernels-community/flash-attn2"
if is_torch_xpu_available():
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False
else:
applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]

if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False

if is_paged:
applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
@@ -2358,13 +2398,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
source_is_there = source_param_name not in missing_keys
target_is_there = target_param_name not in missing_keys
# Both are already present -> it means the config is wrong and do not reflect the actual
# checkpoint -> let's raise a warning and do nothing
# checkpoint -> let's raise a warning and NOT tie them
if source_is_there and target_is_there:
logger.warning(
f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
"You should update the config with `tie_word_embeddings=False` to silence this warning"
)
# Remove from internal attribute to correctly reflect actual tied weights
self.all_tied_weights_keys.pop(target_param_name)
# Skip to next iteration
continue
# We're missing the source but we have the target -> we swap them, tying the parameter that exists
@@ -3172,29 +3214,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory)

# for offloaded modules
module_map = {}

# Save the model
# Get the model state_dict
if state_dict is None:
# if any model parameters are offloaded, make module map
if (
hasattr(self, "hf_device_map")
and len(set(self.hf_device_map.values())) > 1
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
):
warnings.warn(
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
)
for name, module in model_to_save.named_modules():
if name == "":
continue
module_state_dict = module.state_dict()

for key in module_state_dict:
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()

# if any model parameters are offloaded, we need to know it for later
is_offloaded = False
if (
hasattr(self, "hf_device_map")
and len(set(self.hf_device_map.values())) > 1
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
):
is_offloaded = True
warnings.warn(
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
"exceeds the `shard_size` (50GB default)"
)

# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3211,76 +3247,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if self._tp_size is not None:
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

# Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
if not isinstance(tensor, torch.Tensor):
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
# In the non-tensor case, fall back to the pointer of the object itself
ptrs[id(tensor)].append(name)

elif tensor.device.type == "meta":
# In offloaded cases, there may be meta tensors in the state_dict.
# For these cases, key by the pointer of the original tensor object
# (state_dict tensors are detached and therefore no longer shared)
tensor = self.get_parameter(name)
ptrs[id(tensor)].append(name)

else:
ptrs[id_tensor_storage(tensor)].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

# Recursively descend to find tied weight keys
_tied_weights_keys = set(_get_tied_weight_keys(self))
error_names = []
to_delete_names = set()
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
if _tied_weights_keys is not None:
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
to_delete_names.add(name)
# We are entering a place where the weights and the transformers configuration do NOT match.
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
for name in disjoint_names:
state_dict[name] = state_dict[name].clone()

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
shared_names, identical_names = _find_identical(shared_names, state_dict)
# delete tensors that have identical storage
for inames in identical_names:
known = inames.intersection(to_delete_names)
for name in known:
del state_dict[name]
unknown = inames.difference(to_delete_names)
if len(unknown) > 1:
error_names.append(unknown)

if shared_names:
error_names.extend(shared_names)

if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
"This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
)
# Remove tied weights as safetensors do not handle them
state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)

# Revert all renaming and/or weight operations
if save_original_format:
state_dict = revert_weight_conversion(self, state_dict)
state_dict = revert_weight_conversion(model_to_save, state_dict)

# Shard the model if it is too big.
if not _hf_peft_config_loaded:
@@ -3320,47 +3292,39 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()
if module_map:
filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
for shard_file, tensor_names in logging.tqdm(
state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
):
filename = os.path.join(save_directory, shard_file)
shard_state_dict = {}
for tensor_name in tensor_names:
# Get the tensor, and remove it from state_dict to avoid keeping the ref
tensor = state_dict.pop(tensor_name)

# In case of TP, get the full parameter back
if _is_dtensor_available and isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]

# remake shard with onloaded parameters if necessary
if module_map:
# init state_dict for this shard
shard_state_dict = dict.fromkeys(shard, "")
for module_name in shard:
# note that get_state_dict_from_offload can update with meta tensors
# if both a parent module and its descendant are offloaded
tensor = shard_state_dict[module_name]
if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
# update state dict with onloaded parameters
module = module_map[module_name]
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

# assign shard to be the completed state dict
shard = shard_state_dict
del shard_state_dict
gc.collect()

# TODO: we should def parallelize this we are otherwise just waiting
# too much before scheduling the next write when its in a different file
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)

del state_dict
if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
tensor = repack_weights(tensor, -1, self._tp_size, 2)

# If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
# but it would otherwise not be contained in the saved shard if we were to simply move the file
# or something
if is_offloaded and tensor.device.type == "meta":
tensor = load_offloaded_parameter(model_to_save, tensor_name)

# only do contiguous after it's permuted correctly in case of TP
shard_state_dict[tensor_name] = tensor.contiguous()

# TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
# so it's not possible for now....
# Write the shard to disk
safe_save_file(shard_state_dict, filename, metadata=metadata)
# Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
del shard_state_dict

if index is None:
path_to_weights = os.path.join(save_directory, weights_name)
@@ -3537,11 +3501,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return super().float(*args)

@classmethod
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
# Need to instantiate with correct dtype
init_contexts = [local_torch_dtype(dtype, cls.__name__)]
if is_deepspeed_zero3_enabled():
import deepspeed

init_contexts = [no_init_weights()]
init_contexts.append(no_init_weights())
# We cannot initialize the model on meta device with deepspeed when not quantized
if not is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -3549,7 +3515,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
elif is_quantized:
init_contexts.extend([init_empty_weights(), set_quantized_state()])
else:
init_contexts = [no_init_weights(), init_empty_weights()]
init_contexts.extend([no_init_weights(), init_empty_weights()])

return init_contexts

@@ -3583,7 +3549,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.use_kernels = False

@classmethod
@restore_default_dtype
def from_pretrained(
cls: type[SpecificPreTrainedModelType],
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -3963,12 +3928,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
]

# Find the correct dtype based on current state
config, dtype, dtype_orig = _get_dtype(
cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
)
config, dtype = _get_dtype(dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only)

config.name_or_path = pretrained_model_name_or_path
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
@@ -3997,10 +3960,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if device_map is not None:
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

# Finalize model weight initialization
model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
model,


+ 1
- 0
src/transformers/models/align/modeling_align.py View File

@@ -976,6 +976,7 @@ class AlignVisionModel(AlignPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = False
_no_split_modules = ["AlignVisionBlock"]

def __init__(self, config: AlignVisionConfig):
super().__init__(config)


+ 3
- 1
src/transformers/models/apertus/modeling_apertus.py View File

@@ -25,7 +25,7 @@ from typing import Optional, Union
import torch
from torch import nn

from ...activations import ACT2FN
from ...activations import ACT2CLS, ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
@@ -49,6 +49,8 @@ class ApertusMLP(nn.Module):
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
if config.hidden_act == "xielu":
self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)

def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))


+ 4
- 1
src/transformers/models/apertus/modular_apertus.py View File

@@ -19,6 +19,7 @@ from typing import Optional
import torch
from torch import nn

from ...activations import ACT2CLS
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters
@@ -192,9 +193,11 @@ class ApertusConfig(PreTrainedConfig):

class ApertusMLP(NemotronMLP):
def __init__(self, config):
super().__init__()
super().__init__(config)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_act == "xielu":
self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)


class ApertusRMSNorm(LlamaRMSNorm):


+ 1
- 1
src/transformers/models/auto/auto_factory.py View File

@@ -543,7 +543,7 @@ def add_generation_mixin_to_remote_model(model_class):

class _LazyAutoMapping(OrderedDict[type[PreTrainedConfig], _LazyAutoMappingValue]):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

Args:
- config_mapping: The map model type to config class


+ 1
- 0
src/transformers/models/bart/modeling_bart.py View File

@@ -1463,6 +1463,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BartDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 0
- 1
src/transformers/models/beit/image_processing_beit_fast.py View File

@@ -163,7 +163,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 0
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py View File

@@ -2582,6 +2582,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BigBirdPegasusDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot/modeling_blenderbot.py View File

@@ -1156,6 +1156,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py View File

@@ -1116,6 +1116,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotSmallDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 2
- 0
src/transformers/models/blip/modeling_blip_text.py View File

@@ -740,6 +740,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
self.cls = BlipTextOnlyMLMHead(config)
self.label_smoothing = config.label_smoothing

self.post_init()

def get_input_embeddings(self):
return self.bert.get_input_embeddings()



+ 3
- 1
src/transformers/models/blt/modeling_blt.py View File

@@ -753,6 +753,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -952,7 +954,7 @@ def compute_hash_embeddings(
hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
# Apply offset to get the correct slice of the fused embedding
offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
embeddings += encoder_hash_tok_embedding(offset_hash_ids)
embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
embedding_idx += 1

return embeddings


+ 3
- 1
src/transformers/models/blt/modular_blt.py View File

@@ -133,7 +133,7 @@ def compute_hash_embeddings(
hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
# Apply offset to get the correct slice of the fused embedding
offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
embeddings += encoder_hash_tok_embedding(offset_hash_ids)
embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
embedding_idx += 1

return embeddings
@@ -634,6 +634,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,


+ 0
- 2
src/transformers/models/bridgetower/image_processing_bridgetower_fast.py View File

@@ -251,10 +251,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
processed_images, processed_masks = self.pad(
processed_images, return_mask=True, disable_grouping=disable_grouping
)
processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
data["pixel_mask"] = processed_masks

processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
data["pixel_values"] = processed_images

return BatchFeature(data=data, tensor_type=return_tensors)


+ 1
- 0
src/transformers/models/bridgetower/modeling_bridgetower.py View File

@@ -955,6 +955,7 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.visual = BridgeTowerVisionTransformer(config)
self.post_init()

@property
def dtype(self):


+ 1
- 0
src/transformers/models/chameleon/modeling_chameleon.py View File

@@ -809,6 +809,7 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
self.post_init()

def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)


+ 2
- 0
src/transformers/models/clipseg/modeling_clipseg.py View File

@@ -1121,6 +1121,8 @@ class CLIPSegDecoder(CLIPSegPreTrainedModel):
decoder_config.hidden_act = "relu"
self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])

self.post_init()

def forward(
self,
hidden_states: tuple[torch.Tensor],


+ 0
- 1
src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py View File

@@ -263,7 +263,6 @@ class Cohere2VisionImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors


+ 0
- 1
src/transformers/models/convnext/image_processing_convnext_fast.py View File

@@ -162,7 +162,6 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 4
src/transformers/models/decision_transformer/modeling_decision_transformer.py View File

@@ -367,12 +367,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
config: DecisionTransformerConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True

_can_compile_fullgraph = False

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 1
- 1
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py View File

@@ -157,7 +157,7 @@ class DeepseekV3NaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 1
- 0
src/transformers/models/deepseek_v3/modular_deepseek_v3.py View File

@@ -107,6 +107,7 @@ class DeepseekV3NaiveMoe(MixtralExperts):
def __init__(self, config):
super().__init__(config)
self.num_experts = config.num_local_experts
self.intermediate_dim = config.moe_intermediate_size


class DeepseekV3MoE(nn.Module):


+ 0
- 1
src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py View File

@@ -171,7 +171,6 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 4
src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py View File

@@ -207,9 +207,6 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
)
high_res_processed_images_grouped[shape] = stacked_high_res_images
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
high_res_processed_images = (
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
)

resized_images_grouped = {}
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -233,7 +230,6 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},


+ 0
- 4
src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py View File

@@ -888,9 +888,6 @@ class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
)
high_res_processed_images_grouped[shape] = stacked_high_res_images
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
high_res_processed_images = (
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
)

resized_images_grouped = {}
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -914,7 +911,6 @@ class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},


+ 0
- 1
src/transformers/models/depth_pro/image_processing_depth_pro_fast.py View File

@@ -94,7 +94,6 @@ class DepthProImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 4
- 0
src/transformers/models/dia/modeling_dia.py View File

@@ -452,6 +452,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -578,6 +580,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 4
- 0
src/transformers/models/dia/modular_dia.py View File

@@ -241,6 +241,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -367,6 +369,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 0
- 1
src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py View File

@@ -88,7 +88,6 @@ class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 1
src/transformers/models/donut/image_processing_donut_fast.py View File

@@ -231,7 +231,6 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 1
src/transformers/models/dots1/modeling_dots1.py View File

@@ -315,7 +315,7 @@ class Dots1NaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 1
- 2
src/transformers/models/dpt/image_processing_dpt_fast.py View File

@@ -225,8 +225,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images})
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""


+ 1
- 2
src/transformers/models/dpt/modular_dpt.py View File

@@ -228,8 +228,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images})
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_depth_estimation(
self,


+ 1
- 2
src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py View File

@@ -153,9 +153,8 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

return BatchFeature(data={"pixel_values": image_pairs})
return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

def post_process_keypoint_matching(
self,


+ 0
- 1
src/transformers/models/efficientnet/image_processing_efficientnet_fast.py View File

@@ -178,7 +178,6 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 1
src/transformers/models/efficientnet/modeling_efficientnet.py View File

@@ -435,7 +435,7 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
base_model_prefix = "efficientnet"
main_input_name = "pixel_values"
input_modalities = ("image",)
_no_split_modules = []
_no_split_modules = ["EfficientNetBlock"]

@torch.no_grad()
def _init_weights(self, module: nn.Module):


+ 11
- 10
src/transformers/models/eomt/image_processing_eomt_fast.py View File

@@ -162,8 +162,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
)
ignore_index = kwargs.pop("ignore_index", None)
images_kwargs = kwargs.copy()
processed_images, patch_offsets = self._preprocess(images, **images_kwargs)
outputs = BatchFeature({"pixel_values": processed_images})
outputs = self._preprocess(images, **images_kwargs)

if segmentation_maps is not None:
processed_segmentation_maps = self._prepare_image_like_inputs(
@@ -183,9 +182,9 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
}
)

processed_segmentation_maps, _ = self._preprocess(
processed_segmentation_maps = self._preprocess(
images=processed_segmentation_maps, **segmentation_maps_kwargs
)
).pixel_values
processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(torch.int64)
# Convert to list of binary masks and labels
mask_labels, class_labels = [], []
@@ -208,8 +207,8 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
outputs["mask_labels"] = mask_labels
outputs["class_labels"] = class_labels

if patch_offsets:
outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
if outputs.patch_offsets:
outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in outputs.patch_offsets]

return outputs

@@ -274,11 +273,13 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = reorder_images(processed_images_grouped, grouped_images_index)

processed_images = torch.stack(images, dim=0) if return_tensors else images

return processed_images, patch_offsets
return BatchFeature(
data={"pixel_values": processed_images, "patch_offsets": patch_offsets},
tensor_type=return_tensors,
skip_tensor_conversion=["patch_offsets"],
)

def merge_image_patches(
self,


+ 3
- 0
src/transformers/models/ernie/modeling_ernie.py View File

@@ -113,6 +113,9 @@ class ErnieEmbeddings(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

# .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
embeddings = inputs_embeds + token_type_embeddings

position_embeddings = self.position_embeddings(position_ids)


+ 3
- 0
src/transformers/models/ernie/modular_ernie.py View File

@@ -107,6 +107,9 @@ class ErnieEmbeddings(BertEmbeddings):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

# .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
embeddings = inputs_embeds + token_type_embeddings

position_embeddings = self.position_embeddings(position_ids)


+ 1
- 0
src/transformers/models/evolla/modeling_evolla.py View File

@@ -524,6 +524,7 @@ class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
super().__init__(config)
self.embeddings = EvollaSaProtEmbeddings(config)
self.encoder = EvollaSaProtEncoder(config)
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings


+ 1
- 0
src/transformers/models/evolla/modular_evolla.py View File

@@ -209,6 +209,7 @@ class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
super().__init__(config)
self.embeddings = EvollaSaProtEmbeddings(config)
self.encoder = EvollaSaProtEncoder(config)
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings


+ 0
- 3
src/transformers/models/flaubert/modeling_flaubert.py View File

@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
config: FlaubertConfig
base_model_prefix = "transformer"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@property
def dummy_inputs(self):
inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])


+ 0
- 2
src/transformers/models/flava/image_processing_flava_fast.py View File

@@ -306,7 +306,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return processed_images

@@ -397,7 +396,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
)
masks = [mask_generator() for _ in range(len(images))]
masks = torch.stack(masks, dim=0) if return_tensors else masks
data["bool_masked_pos"] = masks

return BatchFeature(data=data, tensor_type=return_tensors)


+ 1
- 1
src/transformers/models/fuyu/image_processing_fuyu.py View File

@@ -94,7 +94,7 @@ class FuyuBatchFeature(BatchFeature):
The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
"""

def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs):
"""
Convert the inner content to tensors.



+ 9
- 11
src/transformers/models/gemma/modeling_gemma.py View File

@@ -410,16 +410,14 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
causal_mask_mapping = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

# embed positions
hidden_states = inputs_embeds
@@ -434,7 +432,7 @@ class GemmaModel(GemmaPreTrainedModel):
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,


+ 9
- 11
src/transformers/models/gemma/modular_gemma.py View File

@@ -267,16 +267,14 @@ class GemmaModel(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
causal_mask_mapping = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

# embed positions
hidden_states = inputs_embeds
@@ -291,7 +289,7 @@ class GemmaModel(LlamaModel):
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,


+ 0
- 1
src/transformers/models/gemma3/image_processing_gemma3_fast.py View File

@@ -231,7 +231,6 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
)


+ 1
- 0
src/transformers/models/gemma3n/modeling_gemma3n.py View File

@@ -919,6 +919,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.post_init()

def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs


+ 1
- 0
src/transformers/models/gemma3n/modular_gemma3n.py View File

@@ -1472,6 +1472,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.post_init()

def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs


+ 1
- 1
src/transformers/models/glm4_moe/modeling_glm4_moe.py View File

@@ -339,7 +339,7 @@ class Glm4MoeNaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 1
- 1
src/transformers/models/glm4v_moe/modeling_glm4v_moe.py View File

@@ -402,7 +402,7 @@ class Glm4vMoeTextNaiveMoe(nn.Module):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]


+ 0
- 1
src/transformers/models/glpn/image_processing_glpn_fast.py View File

@@ -107,7 +107,6 @@ class GLPNImageProcessorFast(BaseImageProcessorFast):
processed_groups[shape] = stacked_images

processed_images = reorder_images(processed_groups, grouped_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_depth_estimation(self, outputs, target_sizes=None):


+ 0
- 1
src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py View File

@@ -189,7 +189,6 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors


+ 1
- 0
src/transformers/models/got_ocr2/modeling_got_ocr2.py View File

@@ -433,6 +433,7 @@ class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
self.neck = GotOcr2VisionNeck(config)

self.gradient_checkpointing = False
self.post_init()

def get_input_embeddings(self):
return self.patch_embed


+ 0
- 4
src/transformers/models/gpt2/modeling_gpt2.py View File

@@ -476,12 +476,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_attention_backend = True

_can_compile_fullgraph = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 0
- 8
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py View File

@@ -26,7 +26,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -43,10 +42,6 @@ from ...utils import (
from .configuration_gpt_bigcode import GPTBigCodeConfig


if is_flash_attn_available():
pass


logger = logging.get_logger(__name__)


@@ -360,9 +355,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 4
- 0
src/transformers/models/idefics2/modeling_idefics2.py View File

@@ -452,6 +452,8 @@ class Idefics2VisionTransformer(Idefics2PreTrainedModel):
self.encoder = Idefics2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

def get_input_embeddings(self):
return self.embeddings

@@ -711,6 +713,8 @@ class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

self.post_init()

@auto_docstring
def forward(
self,


+ 2
- 0
src/transformers/models/idefics3/modeling_idefics3.py View File

@@ -458,6 +458,8 @@ class Idefics3VisionTransformer(Idefics3PreTrainedModel):
self.patch_size = config.patch_size
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings


+ 1
- 5
src/transformers/models/imagegpt/image_processing_imagegpt_fast.py View File

@@ -164,12 +164,8 @@ class ImageGPTImageProcessorFast(BaseImageProcessorFast):

input_ids = reorder_images(input_ids_grouped, grouped_images_index)

return BatchFeature(
data={"input_ids": torch.stack(input_ids, dim=0) if return_tensors else input_ids},
tensor_type=return_tensors,
)
return BatchFeature(data={"input_ids": input_ids}, tensor_type=return_tensors)

pixel_values = torch.stack(pixel_values, dim=0) if return_tensors else pixel_values
return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

def to_dict(self):


+ 0
- 1
src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py View File

@@ -84,7 +84,6 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
processed_videos_grouped[shape] = stacked_videos

processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos

return BatchFeature(data={"pixel_values": processed_videos}, tensor_type=return_tensors)



+ 0
- 1
src/transformers/models/internvl/video_processing_internvl.py View File

@@ -140,7 +140,6 @@ class InternVLVideoProcessor(BaseVideoProcessor):
processed_videos_grouped[shape] = stacked_videos

processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos

return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)



+ 0
- 1
src/transformers/models/janus/image_processing_janus_fast.py View File

@@ -180,7 +180,6 @@ class JanusImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 2
src/transformers/models/janus/modeling_janus.py View File

@@ -973,8 +973,6 @@ class JanusVQVAE(JanusPreTrainedModel):
self.eval() # Janus's VQ model is frozen
self.decoder = JanusVQVAEDecoder(config)
self.gradient_checkpointing = False

# Initialize the VQVAE model.
self.post_init()

def encode(self, pixel_values: torch.LongTensor):


+ 2
- 2
src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py View File

@@ -264,8 +264,8 @@ class Kosmos2_5ImageProcessorFast(BaseImageProcessorFast):

encoded_outputs = BatchFeature(
data={
"flattened_patches": torch.stack(flattened_patches, dim=0) if return_tensors else flattened_patches,
"attention_mask": torch.stack(attention_masks, dim=0) if return_tensors else attention_masks,
"flattened_patches": flattened_patches,
"attention_mask": attention_masks,
"width": width,
"height": height,
"rows": rows,


+ 4
- 0
src/transformers/models/lasr/configuration_lasr.py View File

@@ -240,5 +240,9 @@ class LasrCTCConfig(PreTrainedConfig):

return cls(encoder_config=encoder_config.to_dict(), **kwargs)

@property
def inputs_to_logits_ratio(self):
return self.encoder_config.subsampling_conv_stride**2


__all__ = ["LasrEncoderConfig", "LasrCTCConfig"]

+ 4
- 0
src/transformers/models/lasr/modular_lasr.py View File

@@ -291,6 +291,10 @@ class LasrCTCConfig(ParakeetCTCConfig):
**kwargs,
)

@property
def inputs_to_logits_ratio(self):
return self.encoder_config.subsampling_conv_stride**2


class LasrEncoderSubsampling(nn.Module):
def __init__(self, config: LasrEncoderConfig):


+ 0
- 1
src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py View File

@@ -101,7 +101,6 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 0
- 1
src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py View File

@@ -115,7 +115,6 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)



+ 1
- 2
src/transformers/models/lightglue/image_processing_lightglue_fast.py View File

@@ -174,9 +174,8 @@ class LightGlueImageProcessorFast(BaseImageProcessorFast):
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

return BatchFeature(data={"pixel_values": image_pairs})
return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

def post_process_keypoint_matching(
self,


+ 1
- 2
src/transformers/models/llama4/image_processing_llama4_fast.py View File

@@ -419,10 +419,9 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
)
grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
processed_images = reorder_images(grouped_processed_images, grouped_images_index)
aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)
aspect_ratios = reorder_images(grouped_aspect_ratios, grouped_images_index)

processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
return BatchFeature(
data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
)


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save
Baidu
map