4 Commits

Author SHA1 Message Date
  Marc Sun 188de42f49 fix quanto 11 hours ago
  Marc Sun 54c3719ca9 added 11 hours ago
  Marc Sun 62470e4844 Fix 11 hours ago
  Marc Sun e520f7d9b3 torchao 11 hours ago
5 changed files with 155 additions and 47 deletions
Split View
  1. +0
    -13
      src/transformers/quantizers/quantizer_mxfp4.py
  2. +2
    -1
      src/transformers/quantizers/quantizer_torchao.py
  3. +53
    -33
      tests/quantization/mxfp4/test_mxfp4.py
  4. +48
    -0
      tests/quantization/quanto_integration/test_quanto.py
  5. +52
    -0
      tests/quantization/torchao_integration/test_torchao.py

+ 0
- 13
src/transformers/quantizers/quantizer_mxfp4.py View File

@@ -215,19 +215,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
)
return config

def get_param_name(self, param_name: str) -> str:
    """Translate a checkpoint parameter name for loading.

    When dequantizing, the quantized ``_blocks``/``_scales`` suffixes are
    stripped so names match the dequantized parameters. When loading
    weights that are not pre-quantized, expert projection names gain a
    ``_blocks`` suffix to match the quantized layout. Every other name is
    returned unchanged.
    """
    if self.quantization_config.dequantize:
        # Quantized checkpoint name -> dequantized parameter name.
        for marker in ("_blocks", "_scales"):
            if marker in param_name:
                return param_name.replace(marker, "")
    elif not self.pre_quantized:
        # Unquantized checkpoint name -> quantized parameter name.
        for proj in ("gate_up_proj", "down_proj"):
            if param_name.endswith(proj):
                return param_name.replace(proj, proj + "_blocks")
    return param_name

def get_state_dict_and_metadata(self, model):
from ..integrations import Mxfp4GptOssExperts



+ 2
- 1
src/transformers/quantizers/quantizer_torchao.py View File

@@ -237,7 +237,8 @@ class TorchAoHfQuantizer(HfQuantizer):
"""
super().preprocess_model(model, config, dtype, checkpoint_files, **kwargs)
# Torchao needs access to all metadata later
self.set_metadata(checkpoint_files)
if checkpoint_files is not None:
self.set_metadata(checkpoint_files)

def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""


+ 53
- 33
tests/quantization/mxfp4/test_mxfp4.py View File

@@ -26,6 +26,7 @@ from transformers.testing_utils import (
require_torch_large_accelerator,
require_triton,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
@@ -240,39 +241,6 @@ class Mxfp4QuantizerTest(unittest.TestCase):
result_dtype = quantizer.update_dtype(torch.float32)
self.assertEqual(result_dtype, torch.float32)

def test_get_param_name_dequantize(self):
    """Parameter names are remapped when dequantizing: the quantized
    ``_blocks``/``_scales`` suffixes are dropped; other names are untouched."""
    from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

    quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=True))

    # input name -> expected name after remapping
    cases = {
        "model.layers.0.mlp.experts.gate_up_proj_blocks": "model.layers.0.mlp.experts.gate_up_proj",
        "model.layers.0.mlp.experts.down_proj_scales": "model.layers.0.mlp.experts.down_proj",
        "model.embed_tokens.weight": "model.embed_tokens.weight",
    }
    for original_name, expected_name in cases.items():
        self.assertEqual(quantizer.get_param_name(original_name), expected_name)

def test_get_param_name_no_dequantize(self):
    """Without dequantization, pre-quantized parameter names pass through unchanged."""
    from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

    quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=False))

    name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
    self.assertEqual(quantizer.get_param_name(name), name)

def test_is_trainable(self):
"""Test trainability"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -501,3 +469,55 @@ class Mxfp4ModelTest(unittest.TestCase):
device_map="auto",
)
self.check_inference_correctness_quantized(loaded_model, tokenizer)

def test_compute_module_sizes(self):
    r"""
    Test if we compute the right module sizes needed to generate the device map.
    Also test if we get the right values for `total_byte_count` in `caching_allocator_warmup`.
    Note that `compute_module_sizes` is being used in `get_total_byte_count`
    """
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.integrations import Mxfp4GptOssExperts
    from transformers.integrations.accelerate import compute_module_sizes
    from transformers.modeling_utils import expand_device_map, get_total_byte_count
    from transformers.quantizers import AutoHfQuantizer

    # we need to preprocess the model like that because device_map calculation happens before we load the weights inside the model.
    # For normal weights, it's fine but for quantized weights, the tensors dtype might change during loading.
    with torch.device("meta"):
        config = AutoConfig.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_config(config, dtype=torch.bfloat16)
    # Reference sizes for the unquantized (bf16) model.
    model_size, _ = compute_module_sizes(model, only_modules=False)

    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    total_byte_count = list(get_total_byte_count(model, expanded_device_map).values())[0]

    # testing prequantized = False should be enough, the shape should be the same whether it is pre-quantized or not
    hf_quantizer = AutoHfQuantizer.from_config(Mxfp4Config(), pre_quantized=False)
    hf_quantizer.preprocess_model(model=model, config=model.config)
    quantized_model_size, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

    # recompute the keys: preprocessing may have replaced modules/parameters
    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    quantized_total_byte_count = list(get_total_byte_count(model, expanded_device_map, hf_quantizer).values())[
        0
    ]
    for name, module in model.named_modules():
        if isinstance(module, Mxfp4GptOssExperts):
            # from 16 bits to 4 bits
            assert int(model_size[f"{name}.gate_up_proj"] // 4) == int(
                quantized_model_size[f"{name}.gate_up_proj"]
            )
            assert int(model_size[f"{name}.down_proj"] // 4) == int(quantized_model_size[f"{name}.down_proj"])

    # check that we get the same value, as we use `compute_module_sizes` in `get_total_byte_count`
    assert total_byte_count == model_size[""]
    assert quantized_total_byte_count == quantized_model_size[""]

    # we should at least have 3 times memory reduction in total for this model
    assert model_size[""] > quantized_model_size[""] * 3

+ 48
- 0
tests/quantization/quanto_integration/test_quanto.py View File

@@ -238,6 +238,54 @@ class QuantoQuantizationTest(unittest.TestCase):
self.check_same_model(model, self.quantized_model)
self.check_inference_correctness(model, device=torch_device)

def test_compute_module_sizes(self):
    r"""
    Test if we compute the right module sizes needed to generate the device map.
    Also test if we get the right values for `total_byte_count` in `caching_allocator_warmup`.
    Note that `compute_module_sizes` is being used in `get_total_byte_count`
    """
    from transformers.integrations.accelerate import compute_module_sizes
    from transformers.modeling_utils import expand_device_map, get_total_byte_count
    from transformers.quantizers import AutoHfQuantizer

    # we need to preprocess the model like that because device_map calculation happens before we load the weights inside the model.
    # For normal weights, it's fine but for quantized weights, the tensors dtype might change during loading.
    with torch.device("meta"):
        config = AutoConfig.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_config(config, dtype=torch.bfloat16)
    # Reference sizes for the unquantized (bf16) model.
    model_size, _ = compute_module_sizes(model, only_modules=False)

    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    total_byte_count = list(get_total_byte_count(model, expanded_device_map).values())[0]

    # testing prequantized = False should be enough, the shape should be the same whether it is pre-quantized or not
    hf_quantizer = AutoHfQuantizer.from_config(QuantoConfig(weights="int4"), pre_quantized=False)
    hf_quantizer.preprocess_model(model=model, config=model.config)
    quantized_model_size, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

    # recompute the keys: preprocessing may have replaced modules/parameters
    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    quantized_total_byte_count = list(get_total_byte_count(model, expanded_device_map, hf_quantizer).values())[
        0
    ]

    for name, module in model.named_modules():
        # lm_head is kept unquantized by quanto
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            # from 16 bits to 4 bits
            assert int(model_size[f"{name}.weight"] // 4) == int(quantized_model_size[f"{name}.weight"])

    # check that we get the same value, as we use `compute_module_sizes` in `get_total_byte_count`
    assert total_byte_count == model_size[""]
    assert quantized_total_byte_count == quantized_model_size[""]

    # we should at least have 1.5 times memory reduction in total
    assert model_size[""] > quantized_model_size[""] * 1.5


class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
EXPECTED_OUTPUTS = "Hello my name is joe and i am a little girl\n\n"


+ 52
- 0
tests/quantization/torchao_integration/test_torchao.py View File

@@ -540,6 +540,58 @@ class TorchAoTest(unittest.TestCase):
)
self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))

def test_compute_module_sizes(self):
    r"""
    Test if we compute the right module sizes needed to generate the device map.
    Also test if we get the right values for `total_byte_count` in `caching_allocator_warmup`.
    Note that `compute_module_sizes` is being used in `get_total_byte_count`
    """
    from transformers import AutoConfig
    from transformers.integrations.accelerate import compute_module_sizes
    from transformers.modeling_utils import expand_device_map, get_total_byte_count
    from transformers.quantizers import AutoHfQuantizer

    # we need to preprocess the model like that because device_map calculation happens before we load the weights inside the model.
    # For normal weights, it's fine but for quantized weights, the tensors dtype might change during loading.
    with torch.device("meta"):
        config = AutoConfig.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_config(config, dtype=torch.bfloat16)
    # Reference sizes for the unquantized (bf16) model.
    model_size, _ = compute_module_sizes(model, only_modules=False)

    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    total_byte_count = list(get_total_byte_count(model, expanded_device_map).values())[0]

    # testing prequantized = False should be enough, the shape should be the same whether it is pre-quantized or not
    hf_quantizer = AutoHfQuantizer.from_config(
        TorchAoConfig(quant_type=Int4WeightOnlyConfig(**self.quant_scheme_kwargs)), pre_quantized=False
    )
    hf_quantizer.preprocess_model(model=model, config=model.config)
    quantized_model_size, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

    # recompute the keys: preprocessing may have replaced modules/parameters
    expected_keys = [name for name, _ in model.named_parameters()] + [
        name for name, _ in model.named_buffers()
    ]
    expanded_device_map = expand_device_map({"": torch_device}, expected_keys)
    quantized_total_byte_count = list(get_total_byte_count(model, expanded_device_map, hf_quantizer).values())[
        0
    ]

    for name, module in model.named_modules():
        # modules are not replaced when using torchao
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            # from 16 bits to 4 bits
            assert int(model_size[f"{name}.weight"] // 4) == int(quantized_model_size[f"{name}.weight"])

    # check that we get the same value, as we use `compute_module_sizes` in `get_total_byte_count`
    assert total_byte_count == model_size[""]
    assert quantized_total_byte_count == quantized_model_size[""]

    # we should at least have 2 times memory reduction in total
    assert model_size[""] > quantized_model_size[""] * 2


@require_torch_accelerator
class TorchAoAcceleratorTest(TorchAoTest):


Loading…
Cancel
Save
Baidu
map