8 Commits

Author SHA1 Message Date
  Cyril Vallez c7aec088a6
Enforce call to `post_init` and fix all of them (#42873) 5 hours ago
  Rémi Ouazan f3d5f2558b
[CB] Easy optimizations for continuous batching (#42839) 6 hours ago
  AYou0207 298d08dc36
typo (#42863) 7 hours ago
  Cyril Vallez a187b857a7
Remove tied weights from internal attribute if they are not tied (#42871) 8 hours ago
  Mohamed Mekkouri 64c12fdf5f
[docs] Improve contribution guidelines for Quantization (#42870) 10 hours ago
  Sai-Suraj-27 f0d9cd1ff6
Fixes 2 failing tests from AMD CI (#42777) 10 hours ago
  jiqing-feng 66623a1fd6
Fix speccht5_tts pipeline (#42830) 11 hours ago
  YangKai0616 e17b1b85e3
[Fix] Fix FA2 kernels ut (#42803) 11 hours ago
78 changed files with 485 additions and 197 deletions
Split View
  1. +1
    -1
      Makefile
  2. +95
    -16
      docs/source/en/quantization/contribute.md
  3. +12
    -2
      examples/pytorch/continuous_batching.py
  4. +17
    -21
      src/transformers/generation/continuous_batching/continuous_api.py
  5. +18
    -10
      src/transformers/modeling_utils.py
  6. +1
    -1
      src/transformers/models/auto/auto_factory.py
  7. +1
    -0
      src/transformers/models/bart/modeling_bart.py
  8. +1
    -0
      src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
  9. +1
    -0
      src/transformers/models/blenderbot/modeling_blenderbot.py
  10. +1
    -0
      src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
  11. +2
    -0
      src/transformers/models/blip/modeling_blip_text.py
  12. +2
    -0
      src/transformers/models/blt/modeling_blt.py
  13. +2
    -0
      src/transformers/models/blt/modular_blt.py
  14. +1
    -0
      src/transformers/models/bridgetower/modeling_bridgetower.py
  15. +1
    -0
      src/transformers/models/chameleon/modeling_chameleon.py
  16. +2
    -0
      src/transformers/models/clipseg/modeling_clipseg.py
  17. +0
    -4
      src/transformers/models/decision_transformer/modeling_decision_transformer.py
  18. +4
    -0
      src/transformers/models/dia/modeling_dia.py
  19. +4
    -0
      src/transformers/models/dia/modular_dia.py
  20. +1
    -0
      src/transformers/models/evolla/modeling_evolla.py
  21. +1
    -0
      src/transformers/models/evolla/modular_evolla.py
  22. +0
    -3
      src/transformers/models/flaubert/modeling_flaubert.py
  23. +1
    -0
      src/transformers/models/gemma3n/modeling_gemma3n.py
  24. +1
    -0
      src/transformers/models/gemma3n/modular_gemma3n.py
  25. +1
    -0
      src/transformers/models/got_ocr2/modeling_got_ocr2.py
  26. +0
    -4
      src/transformers/models/gpt2/modeling_gpt2.py
  27. +0
    -8
      src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
  28. +4
    -0
      src/transformers/models/idefics2/modeling_idefics2.py
  29. +2
    -0
      src/transformers/models/idefics3/modeling_idefics3.py
  30. +0
    -2
      src/transformers/models/janus/modeling_janus.py
  31. +1
    -0
      src/transformers/models/marian/modeling_marian.py
  32. +1
    -0
      src/transformers/models/mbart/modeling_mbart.py
  33. +2
    -0
      src/transformers/models/moshi/modeling_moshi.py
  34. +1
    -0
      src/transformers/models/mvp/modeling_mvp.py
  35. +2
    -0
      src/transformers/models/ovis2/modeling_ovis2.py
  36. +2
    -0
      src/transformers/models/ovis2/modular_ovis2.py
  37. +2
    -0
      src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py
  38. +2
    -0
      src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py
  39. +1
    -0
      src/transformers/models/pegasus/modeling_pegasus.py
  40. +1
    -0
      src/transformers/models/pegasus_x/modeling_pegasus_x.py
  41. +1
    -0
      src/transformers/models/plbart/modeling_plbart.py
  42. +8
    -0
      src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
  43. +6
    -0
      src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
  44. +2
    -0
      src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
  45. +2
    -0
      src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
  46. +2
    -0
      src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
  47. +4
    -0
      src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
  48. +2
    -0
      src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
  49. +2
    -0
      src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
  50. +2
    -0
      src/transformers/models/qwen3_vl/modular_qwen3_vl.py
  51. +2
    -0
      src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
  52. +6
    -0
      src/transformers/models/rag/modeling_rag.py
  53. +1
    -0
      src/transformers/models/sam/modeling_sam.py
  54. +2
    -0
      src/transformers/models/sam2/modeling_sam2.py
  55. +2
    -0
      src/transformers/models/sam2/modular_sam2.py
  56. +6
    -0
      src/transformers/models/sam3/modeling_sam3.py
  57. +2
    -2
      src/transformers/models/sam3_video/modeling_sam3_video.py
  58. +1
    -0
      src/transformers/models/sam_hq/modeling_sam_hq.py
  59. +2
    -2
      src/transformers/models/segformer/modeling_segformer.py
  60. +1
    -0
      src/transformers/models/shieldgemma2/modeling_shieldgemma2.py
  61. +2
    -0
      src/transformers/models/siglip/modeling_siglip.py
  62. +2
    -0
      src/transformers/models/siglip2/modeling_siglip2.py
  63. +2
    -0
      src/transformers/models/smolvlm/modeling_smolvlm.py
  64. +2
    -0
      src/transformers/models/timm_backbone/modeling_timm_backbone.py
  65. +2
    -4
      src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
  66. +1
    -0
      src/transformers/models/trocr/modeling_trocr.py
  67. +1
    -0
      src/transformers/models/whisper/modeling_whisper.py
  68. +0
    -3
      src/transformers/models/xlm/modeling_xlm.py
  69. +2
    -2
      src/transformers/pipelines/text_to_audio.py
  70. +2
    -2
      src/transformers/testing_utils.py
  71. +11
    -1
      tests/generation/test_continuous_batching.py
  72. +1
    -1
      tests/models/helium/test_modeling_helium.py
  73. +1
    -1
      tests/models/openai/test_modeling_openai.py
  74. +33
    -0
      tests/pipelines/test_pipelines_text_to_audio.py
  75. +17
    -3
      tests/test_modeling_common.py
  76. +8
    -3
      tests/utils/test_modeling_utils.py
  77. +0
    -101
      utils/check_init_weights_data.py
  78. +150
    -0
      utils/check_modeling_structure.py

+ 1
- 1
Makefile View File

@@ -45,7 +45,7 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_modeling_structure.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py


+ 95
- 16
docs/source/en/quantization/contribute.md View File

@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Contribute

Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.
Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.

This guide will show you how to integrate a new quantization method with [`~quantizers.HfQuantizer`].

@@ -28,16 +28,16 @@ Before integrating a new quantization method into Transformers, ensure the metho
- The method can run on commonly-used hardware (CPU, GPU, etc.).
- The method is wrapped in a [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) ([`~bitsandbytes.nn.Linear8bitLt`], [`~bitsandbytes.nn.Linear4bit`]), and the quantized linear layer should have the following definition.

```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```
```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.
def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.

- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package containing the quantization kernels/primitive is stable (no frequent breaking changes).
@@ -48,23 +48,23 @@ Some quantization methods may require "pre-quantizing" the model through data ca

0. The best starting point would be to have a look at another quantization method such as Finegrained Fp8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).

1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/__init__.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.
1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [\_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/\_\_init\_\_.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.

2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer`]. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).

3. Define the following class attributes and property methods for your quantization method:

- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).

4. Write the `validate_environment` and `update_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. Refer to other quantizers for an example of how it is implemented.

5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).

You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.
You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.

6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.
6. Add the `get_quantize_ops` method to the quantizer class if the quantization supports quantizing on the fly. In transformers, we materialize each tensor and apply a sequence of different operations on it. In our case, the quantization operation happens at the end. You need to create a `XXXQuantize`, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.

7. Add the `get_weight_conversions` method to the quantizer class if the quantization supports loading pre-quantized weights. In transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when we have tensors in the checkpoint that require to be regrouped to re-create the quantized tensors.

@@ -73,3 +73,82 @@ You can define module replacement logic or any other utility method by creating
9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.

10. You should add tests by adding the package in our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how it is implemented.

## Files overview

| File | Purpose |
| -------------------------------------------- | ------------------------------------------------------------------------------------------------ |
| `utils/quantization_config.py` | Define `YourMethodConfig` inheriting from `QuantizationConfigMixin` |
| `quantizers/quantizer_your_method.py` | Implement `YourMethodHfQuantizer` inheriting from `HfQuantizer` |
| `integrations/your_method.py` | Implement `ConversionOps` subclasses and helper functions |
| `quantizers/auto.py` | Register quantizer and config in `AUTO_QUANTIZER_MAPPING` and `AUTO_QUANTIZATION_CONFIG_MAPPING` |
| `docs/source/en/quantization/your_method.md` | Document usage for users |
| `tests/quantization/your_method/` | Add integration tests |

## Understanding `get_quantize_ops` vs `get_weight_conversions`

These two methods handle different scenarios for loading weights. Understanding when to use each is essential.

### `get_quantize_ops` — Quantize on the fly

Use this when loading a **non-quantized checkpoint** (e.g., float16/bfloat16 weights) and quantizing during load.

```
Checkpoint: model.safetensors (float16 weights for example)
get_quantize_ops → YourQuantize.convert()
Result: Quantized weights in memory
```

The `convert` method receives one tensor at a time, quantizes it, and can return a dictionary of quantized params, for example:

```py
class YourQuantize(ConversionOps):
def convert(self, input_dict, model, full_layer_name, missing_keys, **kwargs):
# input_dict = {"layer.weight": <float16 tensor>}
value = list(input_dict.values())[0]
module, tensor_name = get_module_from_name(model, full_layer_name)

# Quantize and assign
quantized, scale, zero_point = your_quantize_fn(value)
return {full_layer_name: quantized, full_layer_name + ".scale": scale, full_layer_name + ".zero_point": zero_point}
```

### `get_weight_conversions` — Load pre-quantized checkpoints

Use this when loading a **pre-quantized checkpoint** where the quantized weights are saved as several separate components (such as data, scale, and zero point), and these need to be combined into one tensor during loading. Not all quantization methods require this reconstruction step: for example, some methods like FP8 simply load weights and scales as-is, without combining them. Others, such as torchao, do require reassembling the quantized tensor from its multiple saved components.

```
Checkpoint: model.safetensors (quantized components)
- layer._weight_qdata
- layer._weight_scale
- layer._weight_zero_point
get_weight_conversions → WeightConverter + YourDeserialize.convert()
Result: Reconstructed quantized tensor → layer.weight
```

The `WeightConverter` collects related tensors based on `source_patterns`, then passes them to your `convert` method:

```py
def get_weight_conversions(self):
if self.pre_quantized:
return [
WeightConverter(
source_patterns=["_weight_qdata", "_weight_scale", "_weight_zero_point"],
target_patterns="weight",
operations=[YourDeserialize(self)],
),
]
return []


class YourDeserialize(ConversionOps):
def convert(self, input_dict, model, full_layer_name, **kwargs):
# input_dict contains all collected tensors
# Reconstruct the quantized tensor from components
reconstructed_tensor = reconstruct_from_components(input_dict)
return {full_layer_name: reconstructed_tensor}
```

+ 12
- 2
examples/pytorch/continuous_batching.py View File

@@ -182,13 +182,16 @@ if __name__ == "__main__":

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument(
"--input-length", type=int, default=None, help="Length of input sequences. Leave to None to mimic real eval."
)
parser.add_argument("--max-new-tokens", type=int, default=512, help="Maximum number of new tokens to generate")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
@@ -251,6 +254,12 @@ if __name__ == "__main__":
else:
possible_prefixes = [None]

tokenizer_kwargs = {"add_generation_prompt": True}
if args.input_length is not None:
tokenizer_kwargs["max_length"] = args.input_length
tokenizer_kwargs["truncation"] = True
tokenizer_kwargs["padding"] = True

batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
@@ -261,7 +270,7 @@ if __name__ == "__main__":
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
inputs = tokenizer.apply_chat_template(messages, **tokenizer_kwargs)
inputs = inputs if isinstance(inputs, list) else inputs["input_ids"]
batched_inputs.append(inputs)

@@ -283,6 +292,7 @@ if __name__ == "__main__":
generation_cfg.compile_config = CompileConfig(
fullgraph=True,
mode="max-autotune-no-cudagraphs",
dynamic=True, # FIXME: if we warmup all graphs, this is not needed anymore
)

# If we need to compare, we need to generate the reference outputs


+ 17
- 21
src/transformers/generation/continuous_batching/continuous_api.py View File

@@ -259,7 +259,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
self.max_seqlen_q = 0
self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
self.output_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)

# For some kwargs, we have a dict of tensors with as many items as there are attention types
layer_types = getattr(self.config, "layer_types", None)
@@ -311,7 +311,7 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_q[: b_size + 1].zero_()
self.max_seqlen_q = 0
self.logits_indices[:q_len].fill_(-1)
self.output_ids[:, :q_len].fill_(-1)
self.output_ids[:q_len].fill_(-1)

# Reset the attributes that are either tensors or dict of tensors
for layer_type in self.cumulative_seqlens_k:
@@ -447,7 +447,7 @@ class ContinuousBatchProcessor:
self.metrics.record_batch_metrics(self.requests_in_batch)

# Reset the static tensors used for storage
self.reset_static_tensors() # TODO: this might be unnecessary
self.reset_static_tensors() # FIXME: why does this make the generation faster?

# Prepare accumulators
self.actual_query_length = 0
@@ -557,13 +557,10 @@ class ContinuousBatchProcessor:
self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))

@traced
def _sync(self) -> list[int]:
if self.output_ids is not None:
try:
return self.output_ids.tolist()[0]
except Exception:
return [0, 1]
return [0, 0]
def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
indices = self.logits_indices[:num_new_tokens]
new_tokens = self.output_ids[indices]
return new_tokens.tolist()

@traced
def _maybe_send_output(self, state: RequestState) -> None:
@@ -574,13 +571,13 @@ class ContinuousBatchProcessor:
@traced
def update_batch(self) -> None:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
new_tokens = self._get_new_tokens(len(self.requests_in_batch))
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prefill_tokens) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
token = new_tokens[i]
state.tokens_to_process = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
@@ -727,12 +724,11 @@ class ContinuousBatchProcessor:
probs = nn.functional.softmax(probs, dim=-1)
# probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
# Add batch dimension back to match argmax output
next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
else:
next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len]
tokens = next_tokens.size(1) # Get seq_len dimension
self.output_ids[:, :tokens].copy_(next_tokens)
next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
tokens = next_tokens.size(0) # Get seq_len dimension
self.output_ids[:tokens].copy_(next_tokens)


# Manager Class (User Interface)
@@ -950,6 +946,10 @@ class ContinuousBatchingManager:
streaming: bool = False,
record_timestamps: bool = False,
) -> None:
# If there is prefix sharing, we sort the inputs to maximize cache hits
if self._allow_prefix_sharing:
inputs = sorted(inputs, reverse=True)
# Add requests in order
for input_ids in inputs:
self.add_request(
input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
@@ -1080,10 +1080,6 @@ class ContinuousBatchingManager:
)

self._generation_step()

if torch.cuda.is_available():
torch.cuda.synchronize() # FIXME: why is this needed?
# Processor updates the batch after generation step is truly over
batch_processor.update_batch()

@traced


+ 18
- 10
src/transformers/modeling_utils.py View File

@@ -155,6 +155,12 @@ _init_weights = True
_is_quantized = False
_is_ds_init_called = False

# Mapping from flash attention implementations to their kernel fallback repositories
FLASH_ATTN_KERNEL_FALLBACK = {
"flash_attention_2": "kernels-community/flash-attn2",
"flash_attention_3": "kernels-community/vllm-flash-attn3",
}


def is_local_dist_rank_0():
return (
@@ -1592,7 +1598,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return True

if is_torch_xpu_available():
logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
logger.info(
f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
)
return True

if importlib.util.find_spec("flash_attn") is None:
@@ -1824,14 +1832,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
and is_kernels_available()
and not is_torch_npu_available()
):
if attn_implementation.endswith("2"):
applicable_attn_implementation = "kernels-community/flash-attn2"
if is_torch_xpu_available():
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False
else:
applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]

if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
# On XPU, kernels library is the native implementation
# Disabling this flag to avoid giving wrong fallbacks on errors and warnings
requested_original_flash_attn = False

if is_paged:
applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
@@ -2392,13 +2398,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
source_is_there = source_param_name not in missing_keys
target_is_there = target_param_name not in missing_keys
# Both are already present -> it means the config is wrong and do not reflect the actual
# checkpoint -> let's raise a warning and do nothing
# checkpoint -> let's raise a warning and NOT tie them
if source_is_there and target_is_there:
logger.warning(
f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
"You should update the config with `tie_word_embeddings=False` to silence this warning"
)
# Remove from internal attribute to correctly reflect actual tied weights
self.all_tied_weights_keys.pop(target_param_name)
# Skip to next iteration
continue
# We're missing the source but we have the target -> we swap them, tying the parameter that exists


+ 1
- 1
src/transformers/models/auto/auto_factory.py View File

@@ -543,7 +543,7 @@ def add_generation_mixin_to_remote_model(model_class):

class _LazyAutoMapping(OrderedDict[type[PreTrainedConfig], _LazyAutoMappingValue]):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

Args:
- config_mapping: The map model type to config class


+ 1
- 0
src/transformers/models/bart/modeling_bart.py View File

@@ -1463,6 +1463,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BartDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py View File

@@ -2582,6 +2582,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BigBirdPegasusDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot/modeling_blenderbot.py View File

@@ -1156,6 +1156,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py View File

@@ -1116,6 +1116,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotSmallDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 2
- 0
src/transformers/models/blip/modeling_blip_text.py View File

@@ -740,6 +740,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
self.cls = BlipTextOnlyMLMHead(config)
self.label_smoothing = config.label_smoothing

self.post_init()

def get_input_embeddings(self):
return self.bert.get_input_embeddings()



+ 2
- 0
src/transformers/models/blt/modeling_blt.py View File

@@ -753,6 +753,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,


+ 2
- 0
src/transformers/models/blt/modular_blt.py View File

@@ -634,6 +634,8 @@ class BltPatcher(BltPreTrainedModel):
bias=False,
)

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,


+ 1
- 0
src/transformers/models/bridgetower/modeling_bridgetower.py View File

@@ -955,6 +955,7 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.visual = BridgeTowerVisionTransformer(config)
self.post_init()

@property
def dtype(self):


+ 1
- 0
src/transformers/models/chameleon/modeling_chameleon.py View File

@@ -809,6 +809,7 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
self.post_init()

def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)


+ 2
- 0
src/transformers/models/clipseg/modeling_clipseg.py View File

@@ -1121,6 +1121,8 @@ class CLIPSegDecoder(CLIPSegPreTrainedModel):
decoder_config.hidden_act = "relu"
self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])

self.post_init()

def forward(
self,
hidden_states: tuple[torch.Tensor],


+ 0
- 4
src/transformers/models/decision_transformer/modeling_decision_transformer.py View File

@@ -367,12 +367,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
config: DecisionTransformerConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True

_can_compile_fullgraph = False

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 4
- 0
src/transformers/models/dia/modeling_dia.py View File

@@ -452,6 +452,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -578,6 +580,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 4
- 0
src/transformers/models/dia/modular_dia.py View File

@@ -241,6 +241,8 @@ class DiaEncoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(
@@ -367,6 +369,8 @@ class DiaDecoder(DiaPreTrainedModel):
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = DiaRotaryEmbedding(config=config)

self.post_init()

@auto_docstring
@can_return_tuple
def forward(


+ 1
- 0
src/transformers/models/evolla/modeling_evolla.py View File

@@ -524,6 +524,7 @@ class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
super().__init__(config)
self.embeddings = EvollaSaProtEmbeddings(config)
self.encoder = EvollaSaProtEncoder(config)
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings


+ 1
- 0
src/transformers/models/evolla/modular_evolla.py View File

@@ -209,6 +209,7 @@ class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
super().__init__(config)
self.embeddings = EvollaSaProtEmbeddings(config)
self.encoder = EvollaSaProtEncoder(config)
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings


+ 0
- 3
src/transformers/models/flaubert/modeling_flaubert.py View File

@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
config: FlaubertConfig
base_model_prefix = "transformer"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@property
def dummy_inputs(self):
inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])


+ 1
- 0
src/transformers/models/gemma3n/modeling_gemma3n.py View File

@@ -919,6 +919,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.post_init()

def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs


+ 1
- 0
src/transformers/models/gemma3n/modular_gemma3n.py View File

@@ -1472,6 +1472,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.post_init()

def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs


+ 1
- 0
src/transformers/models/got_ocr2/modeling_got_ocr2.py View File

@@ -433,6 +433,7 @@ class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
self.neck = GotOcr2VisionNeck(config)

self.gradient_checkpointing = False
self.post_init()

def get_input_embeddings(self):
return self.patch_embed


+ 0
- 4
src/transformers/models/gpt2/modeling_gpt2.py View File

@@ -476,12 +476,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_attention_backend = True

_can_compile_fullgraph = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 0
- 8
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py View File

@@ -26,7 +26,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -43,10 +42,6 @@ from ...utils import (
from .configuration_gpt_bigcode import GPTBigCodeConfig


if is_flash_attn_available():
pass


logger = logging.get_logger(__name__)


@@ -360,9 +355,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""


+ 4
- 0
src/transformers/models/idefics2/modeling_idefics2.py View File

@@ -452,6 +452,8 @@ class Idefics2VisionTransformer(Idefics2PreTrainedModel):
self.encoder = Idefics2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

def get_input_embeddings(self):
return self.embeddings

@@ -711,6 +713,8 @@ class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

self.post_init()

@auto_docstring
def forward(
self,


+ 2
- 0
src/transformers/models/idefics3/modeling_idefics3.py View File

@@ -458,6 +458,8 @@ class Idefics3VisionTransformer(Idefics3PreTrainedModel):
self.patch_size = config.patch_size
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings


+ 0
- 2
src/transformers/models/janus/modeling_janus.py View File

@@ -973,8 +973,6 @@ class JanusVQVAE(JanusPreTrainedModel):
self.eval() # Janus's VQ model is frozen
self.decoder = JanusVQVAEDecoder(config)
self.gradient_checkpointing = False

# Initialize the VQVAE model.
self.post_init()

def encode(self, pixel_values: torch.LongTensor):


+ 1
- 0
src/transformers/models/marian/modeling_marian.py View File

@@ -1248,6 +1248,7 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = MarianDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/mbart/modeling_mbart.py View File

@@ -1442,6 +1442,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = MBartDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 2
- 0
src/transformers/models/moshi/modeling_moshi.py View File

@@ -869,6 +869,8 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
self.gradient_checkpointing = False
self.config = config

self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,


+ 1
- 0
src/transformers/models/mvp/modeling_mvp.py View File

@@ -1509,6 +1509,7 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = MvpDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 2
- 0
src/transformers/models/ovis2/modeling_ovis2.py View File

@@ -457,6 +457,8 @@ class Ovis2VisionModel(Ovis2PreTrainedModel):
)
self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)

self.post_init()

def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
outputs = self.transformer(pixel_values, **kwargs)
last_hidden_state = outputs[0]


+ 2
- 0
src/transformers/models/ovis2/modular_ovis2.py View File

@@ -176,6 +176,8 @@ class Ovis2VisionModel(Ovis2PreTrainedModel):
)
self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)

self.post_init()

def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
outputs = self.transformer(pixel_values, **kwargs)
last_hidden_state = outputs[0]


+ 2
- 0
src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py View File

@@ -919,6 +919,8 @@ class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
self.encoder = PaddleOCRVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

def forward(
self,
pixel_values: torch.FloatTensor,


+ 2
- 0
src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py View File

@@ -1037,6 +1037,8 @@ class PaddleOCRVisionTransformer(PaddleOCRVLPreTrainedModel):
self.encoder = PaddleOCRVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

def forward(
self,
pixel_values: torch.FloatTensor,


+ 1
- 0
src/transformers/models/pegasus/modeling_pegasus.py View File

@@ -1220,6 +1220,7 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = PegasusDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/pegasus_x/modeling_pegasus_x.py View File

@@ -1476,6 +1476,7 @@ class PegasusXDecoderWrapper(PegasusXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = PegasusXDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/plbart/modeling_plbart.py View File

@@ -1273,6 +1273,7 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = PLBartDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 8
- 0
src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py View File

@@ -1105,6 +1105,8 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
)
self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
@@ -3441,6 +3443,8 @@ class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
)

self.post_init()

def normalize_spectrogram(self, spectrogram, max_value, min_db):
return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)

@@ -3568,6 +3572,8 @@ class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

self.post_init()

def _create_block_diff(self, hidden_states):
batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length]
@@ -3720,6 +3726,8 @@ class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
config.bigvgan_config, attn_implementation=attn_impl
)

self.post_init()

def forward(
self,
code,


+ 6
- 0
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py View File

@@ -3600,6 +3600,8 @@ class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
)

self.post_init()

def normalize_spectrogram(self, spectrogram, max_value, min_db):
return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)

@@ -3727,6 +3729,8 @@ class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

self.post_init()

def _create_block_diff(self, hidden_states):
batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length]
@@ -3879,6 +3883,8 @@ class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
config.bigvgan_config, attn_implementation=attn_impl
)

self.post_init()

def forward(
self,
code,


+ 2
- 0
src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py View File

@@ -336,6 +336,8 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
)
self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:


+ 2
- 0
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py View File

@@ -207,6 +207,8 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
)
self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:


+ 2
- 0
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py View File

@@ -693,6 +693,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
)
self.gradient_checkpointing = False

self.post_init()

def get_dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype



+ 4
- 0
src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py View File

@@ -1073,6 +1073,8 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):

self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size

@@ -3716,6 +3718,8 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

self.block = nn.ModuleList(block)

self.post_init()

def forward(self, hidden, **kwargs):
for block in self.block:
hidden = block(hidden)


+ 2
- 0
src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py View File

@@ -2337,6 +2337,8 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

self.block = nn.ModuleList(block)

self.post_init()

def forward(self, hidden, **kwargs):
for block in self.block:
hidden = block(hidden)


+ 2
- 0
src/transformers/models/qwen3_vl/modeling_qwen3_vl.py View File

@@ -632,6 +632,8 @@ class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):

self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size



+ 2
- 0
src/transformers/models/qwen3_vl/modular_qwen3_vl.py View File

@@ -528,6 +528,8 @@ class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):

self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size



+ 2
- 0
src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py View File

@@ -646,6 +646,8 @@ class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel):

self.gradient_checkpointing = False

self.post_init()

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size



+ 6
- 0
src/transformers/models/rag/modeling_rag.py View File

@@ -422,6 +422,8 @@ class RagModel(RagPreTrainedModel):
self.ctx_encoder = None
self.context_encoder_training = False

self.post_init()

@auto_docstring
def forward(
self,
@@ -690,6 +692,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
# instantiate model
self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

self.post_init()

def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever

@@ -1126,6 +1130,8 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
# instantiate model
self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

self.post_init()

def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever



+ 1
- 0
src/transformers/models/sam/modeling_sam.py View File

@@ -1048,6 +1048,7 @@ class SamVisionEncoder(SamPreTrainedModel):
self.neck = SamVisionNeck(config)

self.gradient_checkpointing = False
self.post_init()

def get_input_embeddings(self):
return self.patch_embed


+ 2
- 0
src/transformers/models/sam2/modeling_sam2.py View File

@@ -600,6 +600,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
self.blocks.append(block)
total_block_idx += 1

self.post_init()

def get_input_embeddings(self):
return self.patch_embed



+ 2
- 0
src/transformers/models/sam2/modular_sam2.py View File

@@ -716,6 +716,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
self.blocks.append(block)
total_block_idx += 1

self.post_init()

def get_input_embeddings(self):
return self.patch_embed



+ 6
- 0
src/transformers/models/sam3/modeling_sam3.py View File

@@ -1338,6 +1338,8 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):

self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

self.post_init()

def _prepare_multilevel_features(
self,
vision_features: list[torch.Tensor],
@@ -1617,6 +1619,8 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):

self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

self.post_init()

@compile_compatible_method_lru_cache(maxsize=1)
def _get_coords(
self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
@@ -1987,6 +1991,8 @@ class Sam3MaskDecoder(Sam3PreTrainedModel):
self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

self.post_init()

@check_model_inputs
def forward(
self,


+ 2
- 2
src/transformers/models/sam3_video/modeling_sam3_video.py View File

@@ -505,8 +505,6 @@ class Sam3VideoPreTrainedModel(PreTrainedModel):

@auto_docstring
class Sam3VideoModel(Sam3VideoPreTrainedModel):
all_tied_weights_keys = {}

def __init__(self, config: Sam3VideoConfig):
super().__init__(config)
self.config = config
@@ -542,6 +540,8 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):

self.tracker_neck = Sam3VisionNeck(config.detector_config.vision_config)

self.post_init()

def get_vision_features_for_tracker(self, vision_embeds: torch.Tensor):
hidden_states = vision_embeds.last_hidden_state
batch_size = hidden_states.shape[0]


+ 1
- 0
src/transformers/models/sam_hq/modeling_sam_hq.py View File

@@ -525,6 +525,7 @@ class SamHQVisionEncoder(SamHQPreTrainedModel):
self.neck = SamHQVisionNeck(config)

self.gradient_checkpointing = False
self.post_init()

def get_input_embeddings(self):
return self.patch_embed


+ 2
- 2
src/transformers/models/segformer/modeling_segformer.py View File

@@ -549,9 +549,9 @@ class SegformerMLP(nn.Module):
return hidden_states


class SegformerDecodeHead(SegformerPreTrainedModel):
class SegformerDecodeHead(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
mlps = []
for i in range(config.num_encoder_blocks):


+ 1
- 0
src/transformers/models/shieldgemma2/modeling_shieldgemma2.py View File

@@ -57,6 +57,7 @@ class ShieldGemma2ForImageClassification(PreTrainedModel):
self.yes_token_index = getattr(config, "yes_token_index", 10_784)
self.no_token_index = getattr(config, "no_token_index", 3771)
self.model = AutoModelForImageTextToText.from_config(config=config)
self.post_init()

def get_input_embeddings(self):
return self.model.language_model.get_input_embeddings()


+ 2
- 0
src/transformers/models/siglip/modeling_siglip.py View File

@@ -631,6 +631,8 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(config)

self.post_init()

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(


+ 2
- 0
src/transformers/models/siglip2/modeling_siglip2.py View File

@@ -501,6 +501,8 @@ class Siglip2VisionTransformer(Siglip2PreTrainedModel):
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)

self.post_init()

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(


+ 2
- 0
src/transformers/models/smolvlm/modeling_smolvlm.py View File

@@ -330,6 +330,8 @@ class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
self.patch_size = config.patch_size
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.post_init()

def get_input_embeddings(self):
return self.embeddings



+ 2
- 0
src/transformers/models/timm_backbone/modeling_timm_backbone.py View File

@@ -84,6 +84,8 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
super()._init_backbone(config)

self.post_init()

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])


+ 2
- 4
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py View File

@@ -90,10 +90,6 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
# used in Trainer to avoid passing `loss_kwargs` to model forward
accepts_loss_kwargs = False

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision", "timm"])
super().__init__(*args, **kwargs)

def post_init(self):
self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
super().post_init()
@@ -143,6 +139,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
"""

def __init__(self, config: TimmWrapperConfig):
requires_backends(self, ["vision", "timm"])
super().__init__(config)
# using num_classes=0 to avoid creating classification head
extra_init_kwargs = config.model_args or {}
@@ -265,6 +262,7 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
"""

def __init__(self, config: TimmWrapperConfig):
requires_backends(self, ["vision", "timm"])
super().__init__(config)

if config.num_labels == 0:


+ 1
- 0
src/transformers/models/trocr/modeling_trocr.py View File

@@ -636,6 +636,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.decoder = TrOCRDecoder(config)
self.post_init()

def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


+ 1
- 0
src/transformers/models/whisper/modeling_whisper.py View File

@@ -1247,6 +1247,7 @@ class WhisperDecoderWrapper(WhisperPreTrainedModel):
super().__init__(config)
config.is_encoder_decoder = False
self.decoder = WhisperDecoder(config)
self.post_init()

def get_input_embeddings(self):
return self.decoder.embed_tokens


+ 0
- 3
src/transformers/models/xlm/modeling_xlm.py View File

@@ -603,9 +603,6 @@ class XLMPreTrainedModel(PreTrainedModel):
config: XLMConfig
base_model_prefix = "transformer"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@property
def dummy_inputs(self):
inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])


+ 2
- 2
src/transformers/pipelines/text_to_audio.py View File

@@ -117,8 +117,8 @@ class TextToAudioPipeline(Pipeline):
else vocoder
)

if self.model.config.model_type in ["musicgen"]:
# MusicGen expect to use the tokenizer
if self.model.config.model_type in ["musicgen", "speecht5"]:
# MusicGen and SpeechT5 expect to use their tokenizer instead
self.processor = None

self.sampling_rate = sampling_rate


+ 2
- 2
src/transformers/testing_utils.py View File

@@ -221,7 +221,7 @@ if is_torch_available():
import torch
from safetensors.torch import load_file

from .modeling_utils import PreTrainedModel
from .modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, PreTrainedModel

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
@@ -620,7 +620,7 @@ def require_flash_attn(test_case):
try:
from kernels import get_kernel

get_kernel("kernels-community/flash-attn2")
get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
except Exception as _:
kernels_available = False



+ 11
- 1
tests/generation/test_continuous_batching.py View File

@@ -251,7 +251,17 @@ class ContinuousBatchingGenerationTest(unittest.TestCase):
generate_outputs = model.generate(**inputs.to(torch_device), generation_config=model.generation_config)

for i, user_message in enumerate(user_messages):
continuous_batching_output = continuous_batching_outputs[f"req_{i}"].generated_tokens
# Find the corresponding request in the continuous batching outputs
input_tokens = inputs.input_ids[i][inputs.attention_mask[i] == 1].tolist()
key_to_pop = None
for key, state in continuous_batching_outputs.items():
if state.prompt_ids == input_tokens:
key_to_pop = key
break
if key_to_pop is None:
self.fail(f"Request {i} not found in continuous batching outputs")
continuous_batching_output = continuous_batching_outputs.pop(key_to_pop).generated_tokens

generate_output = generate_outputs[i][num_input_tokens:].tolist()
while generate_output[-1] == model.generation_config.pad_token_id:
generate_output.pop()


+ 1
- 1
tests/models/helium/test_modeling_helium.py View File

@@ -56,7 +56,7 @@ class HeliumIntegrationTest(unittest.TestCase):
model_id = "kyutai/helium-1-preview"
expected_texts = Expectations(
{
("rocm", (9, 5)): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now, and I"],
("rocm", (9, 5)): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now, and I have"],
(None, None): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"],
("cuda", 8): ['Hello, today is a great day to start a new project. I have been working on a new project for a while now, and I'],
}


+ 1
- 1
tests/models/openai/test_modeling_openai.py View File

@@ -300,5 +300,5 @@ class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
481,
] # the president is a very good man. " \n " i\'m sure he is, " said the

output_ids = model.generate(input_ids, do_sample=False)
output_ids = model.generate(input_ids, do_sample=False, max_length=20)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

+ 33
- 0
tests/pipelines/test_pipelines_text_to_audio.py View File

@@ -15,6 +15,7 @@
import unittest

import numpy as np
import torch

from transformers import (
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
@@ -40,6 +41,38 @@ class TextToAudioPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
# for now only test text_to_waveform and not text_to_spectrogram

@require_torch
def test_small_speecht5_pt(self):
audio_generator = pipeline(task="text-to-audio", model="microsoft/speecht5_tts")
num_channels = 1 # model generates mono audio
forward_params = {
"do_sample": True,
"semantic_max_new_tokens": 5,
"speaker_embeddings": torch.rand(1, 512) * 0.2 - 0.1,
}

outputs = audio_generator("This is a test", forward_params=forward_params)
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, outputs)
self.assertEqual(len(outputs["audio"].shape), num_channels)

# test two examples side-by-side
outputs = audio_generator(["This is a test", "This is a second test"], forward_params=forward_params)
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

# test batching, this time with parameterization in the forward pass
audio_generator = pipeline(task="text-to-audio", model="microsoft/speecht5_tts")
forward_params = {
"do_sample": False,
"max_new_tokens": 5,
"speaker_embeddings": torch.rand(1, 512) * 0.2 - 0.1,
}
outputs = audio_generator(
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
)
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

@require_torch
def test_small_musicgen_pt(self):
music_generator = pipeline(


+ 17
- 3
tests/test_modeling_common.py View File

@@ -52,7 +52,7 @@ from transformers.integrations.deepspeed import (
unset_hf_deepspeed_config,
)
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_utils import _get_tied_weight_keys
from transformers.modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, _get_tied_weight_keys
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -3243,6 +3243,20 @@ class ModelTesterMixin:
self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")

dtype = torch.bfloat16

def _expected_attn_implementations(attention_implementation: str) -> set[str]:
# Allow kernels fallbacks for flash attention tests.
requested = attention_implementation
base = requested.removeprefix("paged|")
prefix = "paged|" if requested.startswith("paged|") else ""

expected = {requested}
if base in FLASH_ATTN_KERNEL_FALLBACK:
expected.add(f"{prefix}{FLASH_ATTN_KERNEL_FALLBACK[base]}")
return expected

expected_attn_implementations = _expected_attn_implementations(attn_implementation)

for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3275,7 +3289,7 @@ class ModelTesterMixin:
for key in model_fa.config:
if isinstance(getattr(model_fa.config, key), PreTrainedConfig):
sub_config = getattr(model_fa.config, key)
self.assertTrue(sub_config._attn_implementation == attn_implementation)
self.assertIn(sub_config._attn_implementation, expected_attn_implementations)

has_fa = False
for name, submodule in model_fa.named_modules():
@@ -3283,7 +3297,7 @@ class ModelTesterMixin:
if (
"Attention" in class_name
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == attn_implementation
and submodule.config._attn_implementation in expected_attn_implementations
):
has_fa = True
break


+ 8
- 3
tests/utils/test_modeling_utils.py View File

@@ -129,7 +129,11 @@ if is_torch_available():
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from transformers.modeling_utils import _find_disjoint, _find_identical
from transformers.modeling_utils import (
FLASH_ATTN_KERNEL_FALLBACK,
_find_disjoint,
_find_identical,
)
from transformers.pytorch_utils import isin_mps_friendly

# Fake pretrained models for tests
@@ -3028,7 +3032,7 @@ class TestAttentionImplementation(unittest.TestCase):
)

self.assertTrue(
"You do not have `flash_attn` installed, using `kernels-community/flash-attn2` from the `kernels` library instead!"
f"You do not have `flash_attn` installed, using `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}` from the `kernels` library instead!"
in cl.out
)

@@ -3040,7 +3044,8 @@ class TestAttentionImplementation(unittest.TestCase):

with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn2"
"hf-tiny-model-private/tiny-random-MCTCTModel",
attn_implementation=FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"],
)

self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))


+ 0
- 101
utils/check_init_weights_data.py View File

@@ -1,101 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that ensures `_init_weights(self, module)` implementations do not use `.data`.

Direct `.data` access breaks the lazy-initialization safeguards handled by `HFParameter`, so the library forbids it.
"""

import ast
import sys
from pathlib import Path


MODELING_ROOT = Path("src/transformers/models")
MODELING_PATTERNS = ("modeling_*.py", "modular_*.py")


def iter_modeling_files():
    """Yield every `modeling_*.py` and `modular_*.py` file under the models tree."""
    return (
        path
        for pattern in MODELING_PATTERNS
        for path in MODELING_ROOT.rglob(pattern)
    )


def full_name(node):
    """
    Return full dotted name (e.g. `a.b.c`) from an Attribute or Name node.

    Raises `ValueError` if the chain is rooted in anything other than a plain name.
    """
    # Walk down the attribute chain iteratively, collecting attribute names.
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        raise ValueError("Not a Name or Attribute node")
    parts.append(node.id)
    return ".".join(reversed(parts))


def function_has_forbidden_usage(fn: ast.FunctionDef) -> int | None:
"""
Returns the first offending line number if we detect an in-place operation on a module's weight, otherwise `None`.
"""

args = fn.args.args
if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module":
return None

for node in ast.walk(fn):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
is_inplace_ops = node.func.attr.endswith("_")
# We allow in-place ops on tensors that are not part of the module itself (see e.g. modeling_qwen3_next.py L997)
is_on_module_weight = isinstance(node.func.value, (ast.Name, ast.Attribute)) and "module." in full_name(
node.func.value
)
if is_inplace_ops and is_on_module_weight:
return node.lineno

return None


def main() -> int:
    """Scan every modeling file for forbidden `_init_weights` usage; return a process exit code."""
    violations: list[str] = []

    for file_path in iter_modeling_files():
        # Files that fail to parse are reported as violations rather than aborting the run.
        try:
            source = file_path.read_text(encoding="utf-8")
            module_tree = ast.parse(source, filename=str(file_path))
        except Exception as exc:
            violations.append(f"{file_path}: failed to parse ({exc}).")
            continue

        for candidate in ast.walk(module_tree):
            if not (isinstance(candidate, ast.FunctionDef) and candidate.name == "_init_weights"):
                continue
            offending_line = function_has_forbidden_usage(candidate)
            if offending_line is not None:
                violations.append(
                    f"{file_path}:{offending_line}: `_init_weights(self, module)` uses an in-place operation on a "
                    "module's weight. Please use the `init` functions primitives instead, usually imported as "
                    "`from ... import initialization as init`."
                )
            # Only the first `_init_weights` definition per file is inspected.
            break

    if violations:
        print("Found forbidden usage inside `_init_weights(self, module)`:\n", file=sys.stderr)
        print("\n".join(violations), file=sys.stderr)
        return 1

    return 0


# Script entry point: exit status is 1 when violations were found, 0 otherwise.
if __name__ == "__main__":
    sys.exit(main())

+ 150
- 0
utils/check_modeling_structure.py View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that ensures that modeling (and modular) files respect some important conventions we have in Transformers.
"""

import ast
import sys
from pathlib import Path

from rich import print


MODELS_ROOT = Path("src/transformers/models")
MODELING_PATTERNS = ("modeling_*.py", "modular_*.py")


def iter_modeling_files():
    """Yield every `modeling_*.py` and `modular_*.py` file under the models root."""
    return (
        candidate
        for glob_pattern in MODELING_PATTERNS
        for candidate in MODELS_ROOT.rglob(glob_pattern)
    )


def colored_error_message(file_path: str, line_number: int, message: str) -> str:
    """Build a rich-markup error line: red file path, yellow `L<line>` marker, then the message."""
    location = f"[bold red]{file_path}[/bold red]:[bold yellow]L{line_number}[/bold yellow]"
    return location + ": " + message


def full_name(node: ast.AST):
    """
    Return full dotted name (e.g. `a.b.c`) from an Attribute or Name node.

    Raises `ValueError` when the chain is rooted in any other node type.
    """
    # Unwind the attribute chain without recursion.
    segments = []
    while isinstance(node, ast.Attribute):
        segments.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        raise ValueError("Not a Name or Attribute node")
    segments.append(node.id)
    return ".".join(reversed(segments))


def check_init_weights(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `_init_weights` correctly uses `init.(...)` patterns to init the weights in-place. This is very
    important, as we rely on the internal flag set on the parameters themselves to check if they need to be re-init
    or not.

    Appends a colored error message to `violations` for each offending call and returns the (same) list.
    """
    if isinstance(node, ast.FunctionDef) and node.name == "_init_weights":
        args = node.args.args
        if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module":
            return violations

        for sub_node in ast.walk(node):
            if isinstance(sub_node, ast.Call) and isinstance(sub_node.func, ast.Attribute):
                # Trailing underscore is the torch convention for in-place ops.
                if not sub_node.func.attr.endswith("_"):
                    continue
                # We allow in-place ops on tensors that are not part of the module itself (see e.g. modeling_qwen3_next.py L997)
                try:
                    target = full_name(sub_node.func.value)
                except ValueError:
                    # Attribute chain rooted in a non-name (e.g. `module.get_submodule("x").weight.fill_()`):
                    # `full_name` cannot resolve it -- skip instead of crashing the whole check.
                    continue
                if "module." in target:
                    error_msg = (
                        "`_init_weights(self, module)` uses an in-place operation on a module's weight. Please use the "
                        "`init` functions primitives instead, usually imported as `from ... import initialization as init`"
                    )
                    violations.append(colored_error_message(file_path, sub_node.lineno, error_msg))

    return violations


def is_self_method_call(node: ast.AST, method: str) -> bool:
    """Check if `node` is a method call on `self`, such as `self.method(...)`"""
    if not isinstance(node, ast.Call):
        return False
    func = node.func
    if not isinstance(func, ast.Attribute) or func.attr != method:
        return False
    return isinstance(func.value, ast.Name) and func.value.id == "self"


def is_super_method_call(node: ast.AST, method: str) -> bool:
    """Check if `node` is a call to `super().method(...)`"""
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
        return False
    if node.func.attr != method:
        return False
    # The receiver itself must be a call to the bare name `super`.
    receiver = node.func.value
    return (
        isinstance(receiver, ast.Call)
        and isinstance(receiver.func, ast.Name)
        and receiver.func.id == "super"
    )


def check_post_init(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `self.post_init()` is correctly called at the end of `__init__` for all `PreTrainedModel`s. This is
    very important as we need to do some processing there.

    Appends a colored error message to `violations` when `post_init` is missing and returns the (same) list.
    """
    if not isinstance(node, ast.ClassDef):
        return violations

    # Check if it's a PreTrainedModel class definition. Bases that are not plain dotted
    # names (e.g. subscripted generics like `Foo[T]`) cannot be resolved by `full_name`
    # and previously raised ValueError, crashing the whole run -- skip them instead.
    is_pretrained_model = False
    for parent in node.bases:
        try:
            if full_name(parent).endswith("PreTrainedModel"):
                is_pretrained_model = True
                break
        except ValueError:
            continue
    if not is_pretrained_model:
        return violations

    for sub_node in node.body:
        # Check that we are in __init__
        if isinstance(sub_node, ast.FunctionDef) and sub_node.name == "__init__":
            for statement in ast.walk(sub_node):
                # This means it's correctly called verbatim
                if is_self_method_call(statement, method="post_init"):
                    break
                # This means `super().__init__` is called in a modular, so it is already called in the parent
                elif "modular_" in str(file_path) and is_super_method_call(statement, method="__init__"):
                    break
            # If we did not break, `post_init` was never called
            else:
                error_msg = f"`__init__` of {node.name} does not call `self.post_init`"
                violations.append(colored_error_message(file_path, sub_node.lineno, error_msg))
            break

    return violations


def main():
    """Run all structural checks over every modeling file; raise `ValueError` if any violation is found."""
    violations: list[str] = []

    for file_path in iter_modeling_files():
        # Unparseable files are recorded as violations instead of aborting the run.
        try:
            source = file_path.read_text(encoding="utf-8")
            module_tree = ast.parse(source, filename=str(file_path))
        except Exception as exc:
            violations.append(f"{file_path}: failed to parse ({exc}).")
            continue

        for candidate in ast.walk(module_tree):
            violations = check_init_weights(candidate, violations, file_path)
            violations = check_post_init(candidate, violations, file_path)

    if violations:
        print("\n".join(sorted(violations)), file=sys.stderr)
        raise ValueError("Some errors in modelings. Check the above message")


# Script entry point: `main` raises ValueError when any convention is violated.
if __name__ == "__main__":
    main()

Loading…
Cancel
Save
Baidu
map