12 Commits

Author        SHA1        Message                                                                 Date
i-robot       b768b5c1e5  !7846 [master] Refactor part of the weight 2.0 functions                5 days ago
i-robot       a767dadbfb  !7841 [master][CI gate] Update the ms package                           5 days ago
zyw_hw        eb8ddefc51  update ms pkg                                                           1 week ago
i-robot       34a6ef54c4  !7852 [master] Update third-party dependency version info               5 days ago
i-robot       679e210ddd  !7847 [master][infer] Add a reduced-layer Glm4Moe full-network ST case  5 days ago
i-robot       c6fccd9d7b  !7818 [master][bugfix] Fix the error raised when the hostname cannot be obtained on environments that have none  5 days ago
i-robot       1fccbe3130  !7838 add gmm-linear quant ut                                           5 days ago
zyw_hw        769525e198  update third party info                                                 6 days ago
senzhen       ea7fbaa5e4  Rework the weight 2.0 functions                                         1 week ago
hangangqiang  efa6537770  add gmm-linear quant ut                                                 6 days ago
pengjingyou   a5df635ab7  [master][infer] Add a reduced-layer Glm4Moe full-network ST case        1 week ago
Yule100       afce36cace  bugfix: hostname could not be obtained                                  1 week ago
26 changed files with 6123 additions and 4483 deletions
  1. +1 -1  .jenkins/test/config/dependent_packages.yaml
  2. +4414 -2939  Third_Party_Open_Source_Software_Notice
  3. +0 -5  mindformers/checkpoint/broadcast.py
  4. +14 -14  mindformers/checkpoint/checkpoint.py
  5. +103 -81  mindformers/checkpoint/fully_parallel.py
  6. +21 -18  mindformers/checkpoint/metadata.py
  7. +185 -192  mindformers/checkpoint/sharded_tensor.py
  8. +4 -2  mindformers/models/glm4_moe/modeling_glm4_moe_infer.py
  9. +3 -7  mindformers/trainer/base_trainer.py
  10. +1 -1  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/__init__.py
  11. +15 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/__init__.py
  12. +40 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/glm4_moe_infer.yaml
  13. +90 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/run_glm4_moe.py
  14. +58 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/test_glm4_moe_infer.py
  15. +4 -1  tests/st/test_ut/base_schema.json
  16. +35 -59  tests/st/test_ut/test_checkpoint/test_fully_parallel.py
  17. +0 -149  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/gpt_model_for_test.py
  18. +381 -85  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/numpy_quantizer.py
  19. +0 -168  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/run_parallel_linear.py
  20. +172 -125  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/simple_gpt_model.py
  21. +412 -0  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/simple_mcore.py
  22. +108 -0  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_configs.yaml
  23. +62 -28  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_parallel_linear.py
  24. +0 -224  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py
  25. +0 -242  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py
  26. +0 -142  tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/test_parallel_linear.py

+1 -1  .jenkins/test/config/dependent_packages.yaml

@@ -1,4 +1,4 @@
mindspore:
'https://repo.mindspore.cn/mindspore/mindspore/version/202511/20251124/r2.7.2_20251124170006_89f847825c34199c5561608a6a2b6290dcf822f9_newest/'
'https://repo.mindspore.cn/mindspore/mindspore/version/202512/20251208/r2.7.2_20251208020007_3c91386a5d0d898f8ebd32011c649f78a9b2918b_newest/'
ms_custom_ops:
'https://repo.mindspore.cn/mindspore/ms_custom_ops/version/202510/20251029/master_20251029031507_da84a03f2297b5499e9a47a8e294078229230087_newest/'

+4414 -2939  Third_Party_Open_Source_Software_Notice
File diff suppressed because it is too large


+0 -5  mindformers/checkpoint/broadcast.py

@@ -81,8 +81,6 @@ def _create_allreduce_input(params, group, net_param_dict, total_param_loaded, p
for param in params:
if param not in net_param_dict:
continue
if param.startswith("accu_grads") or param.endswith("expert_load"):
continue
param_rank_index = _get_param_index_in_group(total_param_loaded, group, param)
if not param_rank_index:
continue
@@ -172,9 +170,6 @@ def single_parameter_broadcast(net, param_redundancy, param_not_load, param_load
total_param_loaded = [None] * total_num
synchronize()
all_gather_object(total_param_loaded, param_loaded)
logger.debug("Total params loaded:")
for param_loaded_ in total_param_loaded:
logger.debug(param_loaded_)

group_map = _get_sorted_group_map()
for group, params in param_redundancy.items():


+14 -14  mindformers/checkpoint/checkpoint.py

@@ -62,9 +62,12 @@ from mindformers.checkpoint.metadata import (
get_total_params_file_mapping_info,
)
from mindformers.checkpoint.sharded_tensor import (
convert_sharded_tensor_list_to_dict,
get_strategy_info_from_sharded_tensor,
ShardedTensor, get_sharded_tensor_list_from_cell, get_cur_sharded_tensor
ShardedTensor,
get_sharded_tensor_from_cell,
get_cur_sharded_tensor,
get_cur_sharded_tensor_after_balanced,
get_param_redundancy_after_balanced
)
from mindformers.checkpoint.broadcast import single_parameter_broadcast

@@ -264,7 +267,7 @@ class AsyncSaveManager:
def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
async_save_manager: AsyncSaveManager = None, common_info: CommonInfo = None,
keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None,
sharded_tensor_metas: list = None, remove_redundancy: bool = False):
sharded_tensor_metas: Dict = None, remove_redundancy: bool = False):
"""
Saves the current state of the training process,
including the model, optimizer, and learning rate scheduler, to a checkpoint file.
@@ -280,7 +283,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
save_checkpoint_path (str): The user can specify the path to save the weights.
If None, the default path is 'output_dir/checkpoint'.
And 'output_dir' is configured in yaml and defaults to './output' in the execution script path.
sharded_tensor_metas (List): The ShardedTensor metas of this network.
sharded_tensor_metas (Dict): The ShardedTensor metas of this network.
remove_redundancy (bool): Whether to remove redundancy of saving checkpoint.
"""
logger.info('....... Start to save checkpoint as new format .......')
@@ -932,16 +935,12 @@ def load_checkpoint(

param_redundancy = None
if balanced_load:
dst_sharded_tensor_metas, param_redundancy = apply_balance_shard_strategy(network, filter_func)[2:]
rank_id_to_sharded_tensors = apply_balance_shard_strategy(network, filter_func)
dst_sharded_tensor_metas = get_cur_sharded_tensor_after_balanced(rank_id_to_sharded_tensors)
param_redundancy = get_param_redundancy_after_balanced(rank_id_to_sharded_tensors)
else:
if get_real_group_size() > 1:
cur_rank_sharded_tensors = get_cur_sharded_tensor(network, filter_func)
else:
# Fallback: Get sharded tensors directly from network and optimizer
cur_rank_sharded_tensors = get_sharded_tensor_list_from_cell(network, optimizer)

# Convert list of sharded tensors to dictionary for lookup
dst_sharded_tensor_metas = convert_sharded_tensor_list_to_dict(cur_rank_sharded_tensors)
dst_sharded_tensor_metas = get_cur_sharded_tensor(network, filter_func) \
if get_real_group_size() > 1 else get_sharded_tensor_from_cell(network, optimizer)

# Categorize parameters based on sharding strategies
_, need_concat_params, no_shard_params, online_shard_params = categorize_params(
@@ -1116,7 +1115,8 @@ def load_parameters(

# Separate network and optimizer parameters
if balanced_load and param_redundancy is None:
param_redundancy = apply_balance_shard_strategy(network)[-1]
rank_id_to_sharded_tensors = apply_balance_shard_strategy(network)
param_redundancy = get_param_redundancy_after_balanced(rank_id_to_sharded_tensors)

network_param_names, _, state_dict, state_dict_opt = \
split_state_dict(network, state_dict, optimizer, state_dict_opt)


+103 -81  mindformers/checkpoint/fully_parallel.py

@@ -18,11 +18,11 @@ from collections import defaultdict
from typing import Callable

from mindspore import save_checkpoint
from mindspore.communication import get_rank
from mindspore.nn import Cell

from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor
from mindformers.tools.logger import logger
from mindformers.tools.utils import get_real_rank
from mindformers.checkpoint.metadata import save_metadata, load_metadata
from mindformers.checkpoint.utils import (
_reverse_sharded_tensor_shard_id,
@@ -33,7 +33,6 @@ from mindformers.checkpoint.utils import (
_get_shard_size,
sharded_tensor_shard_id
)
from mindformers.tools.utils import get_real_rank


class BalancedSaveStrategy():
@@ -70,7 +69,7 @@ class BalancedSaveStrategy():
self.total_files_num = None
self.cur_rank_file_id = None
self.cached_distribution = None
self.rank_id = get_rank()
self.rank_id = get_real_rank()
self.checkpoint_path = checkpoint_path
self.network = network
self.ckpt_format = "safetensors"
@@ -88,8 +87,8 @@ class BalancedSaveStrategy():
The total number of checkpoint files.
"""
if self.total_files_num is None:
shared_distribution, id_to_tensor = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
shared_distribution = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution)
self.total_files_num = self._get_total_files_num(rank_params_mappings)

return self.total_files_num
@@ -105,8 +104,8 @@ class BalancedSaveStrategy():
The identifier for the current rank's checkpoint file.
"""
if self.cur_rank_file_id is None:
shared_distribution, id_to_tensor = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
shared_distribution = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution)
self.cur_rank_file_id = self._get_cur_rank_file_id(rank_params_mappings)

return self.cur_rank_file_id
@@ -120,8 +119,8 @@ class BalancedSaveStrategy():
and saves the selected parameters in the specified format.
It also saves metadata about the checkpoint files if the current rank is 0.
"""
shared_distribution, id_to_tensor = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
shared_distribution = self.apply_saving_parallelization()
rank_params_mappings = self._get_rank_params_mappings(shared_distribution)

if self.total_files_num is None:
self.total_files_num = self._get_total_files_num(rank_params_mappings)
@@ -160,13 +159,18 @@ class BalancedSaveStrategy():
Otherwise, it will retrieve the current distribution and cache it if caching is enabled.

Returns:
A tuple containing the shared distribution dictionary and the shard-to-name mapping dictionary.
shared_distribution (Dict[int, Dict[str, Tuple]]): A nested dictionary where:
- Outer keys: Target rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Corresponding `ShardedTensor` object with complete shard metadata (shape, dtype, global offset, etc.).
2. Rank group (tuple of ints): Ranks that have redundant copies of the shard (supports fault tolerance or
parallel access).
"""
if self.do_cache_distribution and self.cached_distribution is not None:
shared_distribution = self.cached_distribution
else:
shard_id_to_ranks, shard_id_to_tensor = apply_balance_shard_strategy(self.network, self.filter_func)[:2]
shared_distribution = (shard_id_to_ranks, shard_id_to_tensor)
rank_id_to_sharded_tensors = apply_balance_shard_strategy(self.network, self.filter_func)
shared_distribution = rank_id_to_sharded_tensors

if self.do_cache_distribution:
self.cached_distribution = shared_distribution
@@ -221,8 +225,12 @@ class BalancedSaveStrategy():
and saves this mapping along with the shard metadata to the specified checkpoint path.

Args:
shared_distribution (dict): A dictionary where keys are parameter IDs and values are rank IDs indicating
which rank is responsible for a particular parameter.
shared_distribution (Dict[int, Dict[str, Tuple]]): A nested dictionary where:
- Outer keys: Target rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Corresponding `ShardedTensor` object with complete shard metadata (shape, dtype, global offset, etc.).
2. Rank group (tuple of ints): Ranks that have redundant copies of the shard (supports fault tolerance or
parallel access).
iteration (int): The current iteration number.
"""
if self.rank_id == 0:
@@ -240,13 +248,13 @@ class BalancedSaveStrategy():
(save_file_name + ".safetensors", rank_id, _reverse_sharded_tensor_shard_id(param_id)))
cur_rank_id += 1

shard_to_metadata = get_all_sharded_tensor(self.network, self.filter_func)
sharded_tensor_metas = get_all_sharded_tensor(self.network, self.filter_func)
origin_metadata_file = get_metadata_filename(self.checkpoint_path, iteration)

if os.path.exists(origin_metadata_file):
origin_shard_metadata, origin_param_file_mapping = load_metadata(
get_metadata_filename(self.checkpoint_path, iteration))
shard_to_metadata.extend(list(origin_shard_metadata.values()))
sharded_tensor_metas.update({"origin": origin_shard_metadata})
for param_id, storage in origin_param_file_mapping.items():
for storage_item in storage:
param_file_mapping.append((
@@ -256,31 +264,37 @@ class BalancedSaveStrategy():
))

metadata_file_path = get_metadata_filename(self.checkpoint_path, iteration)
save_metadata(shard_to_metadata, param_file_mapping, metadata_file_path)
save_metadata(sharded_tensor_metas, param_file_mapping, metadata_file_path)
if self.rank_id == 0:
logger.info(
f"The 'metadata.json' of non-redundancy weight saved successfully at '{metadata_file_path}'."
)

def _get_rank_params_mappings(self, shared_distribution, id_to_tensor):
def _get_rank_params_mappings(self, shared_distribution):
"""
Create a mapping from rank IDs to lists of parameter names based on the shared distribution and
shard-to-name mapping.
Create a mapping from rank IDs to lists of parameter names based on the shared distribution.

Args:
shared_distribution (dict): A dictionary where keys are parameter IDs and values are rank IDs indicating
which rank is responsible for a particular parameter.
id_to_tensor (dict): A dictionary that maps parameter IDs to their corresponding ShardTensor.
shared_distribution (Dict[int, Dict[str, Tuple]]): A nested dictionary where:
- Outer keys: Target rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Corresponding `ShardedTensor` object with complete shard metadata (shape, dtype, global offset, etc.).
2. Rank group (tuple of ints): Ranks that have redundant copies of the shard (supports fault tolerance or
parallel access).

Returns:
A dictionary where keys are rank IDs and values are lists of parameter names assigned to that rank.
Dict[int, Optional[List[str]]]: A sorted dictionary where:
- Outer keys: Rank IDs (int) sorted in ascending numerical order.
- Outer values: List of param name (str) assigned to the rank.
"""
rank_params_mappings = {}
for param_id, rank_id in shared_distribution.items():
if rank_id not in rank_params_mappings:
rank_params_mappings[rank_id] = [id_to_tensor[param_id].key]
else:
rank_params_mappings[rank_id].append(id_to_tensor[param_id].key)
for rank_id, sharded_tensors in shared_distribution.items():
rank_params_mappings[rank_id] = []
for _, shard_id_info in sharded_tensors.items():
sharded_tensor, _ = shard_id_info
param_name = sharded_tensor.key
rank_params_mappings[rank_id].append(param_name)

sorted_rank_params_mappings = {
k: rank_params_mappings.get(k, None)
for k in sorted(rank_params_mappings)
@@ -292,18 +306,20 @@ class BalancedSaveStrategy():
Create a mapping from rank IDs to lists of parameter IDs based on the shared distribution.

Args:
shared_distribution (dict): A dictionary where keys are parameter IDs and values are rank IDs indicating
which rank is responsible for a particular parameter.
shared_distribution (Dict[int, Dict[str, Tuple]]): A nested dictionary where:
- Outer keys: Target rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Corresponding `ShardedTensor` object with complete shard metadata (shape, dtype, global offset, etc.).
2. Rank group (tuple of ints): Ranks that have redundant copies of the shard (supports fault tolerance or
parallel access).

Returns:
A dictionary where keys are rank IDs and values are lists of parameter IDs assigned to that rank.
Dict[int, Optional[List[str]]]: A sorted dictionary where:
- Outer keys: Rank IDs (int) sorted in ascending numerical order.
- Outer values: List of parameter IDs (str) assigned to the rank.
"""
rank_params_mappings = {}
for param_id, rank_id in shared_distribution.items():
if rank_id not in rank_params_mappings:
rank_params_mappings[rank_id] = [param_id]
else:
rank_params_mappings[rank_id].append(param_id)
rank_params_mappings = {rank_id: list(sharded_tensors.keys()) \
for rank_id, sharded_tensors in shared_distribution.items()}

sorted_rank_params_mappings = {
k: rank_params_mappings.get(k, None)
@@ -326,7 +342,11 @@ def distribute_shards(shard_coverage, shard_sizes, total_ranks):
total_ranks (int): The total number of ranks.

Returns:
A dictionary mapping shard IDs to the rank that will save the shard.
Dict[str, Tuple[int, Tuple[int, ...]]]: A dictionary where each key is a unique shard ID,
and the corresponding value is a 2-element tuple:
1. Selected target rank (int): The rank assigned to store the shard (chosen to minimize current load).
2. Rank group (Tuple[int, ...]): Ranks that originally contain the shard (from `shard_coverage`),
representing redundant copies for fault tolerance or parallel access.
"""
coverage_map = {
k: tuple(sorted(v))
@@ -344,7 +364,7 @@ def distribute_shards(shard_coverage, shard_sizes, total_ranks):

for shard_id, available_ranks in sorted_shards:
selected_rank = min(available_ranks, key=lambda rank: rank_loads[rank])
shard_assignment[shard_id] = selected_rank
shard_assignment[shard_id] = (selected_rank, available_ranks)
rank_loads[selected_rank] += shard_sizes[shard_id]

return shard_assignment
@@ -352,33 +372,39 @@ def distribute_shards(shard_coverage, shard_sizes, total_ranks):

def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], bool] = None):
"""
Distributes and balances sharded tensor metadata across all ranks in a parallel group.

This function aggregates sharded tensor metadata from the input network (filtered by the provided
`filter_func`), calculates shard sizes, maps shards to their associated ranks, and distributes
shards to ranks in a load-balanced manner. It also identifies redundant shard copies across ranks
and returns key metadata for shard management.

Core workflow:
1. Collect all sharded tensor metadata from the network, filtered by `filter_func`.
2. Generate unique shard IDs for each tensor shard and track which ranks own each shard.
3. Calculate the size (in bytes) of each unique shard based on its local shape and data type.
4. Distribute shards to ranks using a load-balanced strategy via the `distribute_shards` function.
5. Compile metadata for the current rank, including assigned shards and redundant shard copies.
Distributes and balances sharded tensor storage across ranks in a parallel group,
generating rank-specific shard assignments.

Collects sharded tensor metadata from the input MindSpore Network Cell (filtered by an optional function),
computes unique shard identifiers and their sizes, and distributes shards to ranks using a load-balanced strategy.
The result maps each target rank to its assigned shards along with the group of ranks that share redundant copies
of those shards.

Core Workflow:
1. Extract all sharded tensor metadata from the network using `get_all_sharded_tensor`, applying the `filter_func`
to select target tensors (e.g., exclude non-trainable parameters).
2. Generate unique shard IDs for each tensor shard (via `sharded_tensor_shard_id`) by combining the tensor key
and global offset, then track which ranks originally own each shard.
3. Calculate the byte size of each unique shard using its local shape and data type (via `_get_shard_size`),
avoiding redundant size computations for identical shards.
4. Distribute shards to ranks for storage using the `distribute_shards` function, which implements a load-balanced
algorithm to evenly distribute the storage load across the parallel group.
5. Compile a rank-to-shard mapping: for each rank, store its assigned shards and the corresponding rank group
(ranks with redundant copies of the same shard).

Args:
network (Cell): MindSpore Network Cell containing parameters and their sharding information.
filter_func: A filtering function to select specific tensors from the network (e.g., exclude
non-trainable parameters). Applied during sharded tensor collection via `get_all_sharded_tensor`.
network (Cell): A MindSpore Network Cell containing parameters and their associated sharding metadata.
filter_func (Optional[Callable[[str], bool]]): An optional filtering function that takes a tensor key (str)
and returns a boolean. Only tensors for which the function returns `True` are included in the shard
distribution. Defaults to `None` (all sharded tensors in the network are included).

Returns:
tuple: A 4-element tuple containing:
1. shard_to_saving_rank (dict): Mapping of unique shard IDs to the rank responsible for storing the shard.
2. shard_id_to_tensor (dict): Mapping of unique shard IDs to their corresponding sharded tensor metadata.
3. dst_sharded_tensor_metas (dict): Mapping of original tensor keys to sharded tensor metadata
assigned to the current local rank.
4. param_redundancy (dict): Mapping of rank groups (tuples of ranks) to lists of tensor keys.
Represents shards that exist redundantly across multiple ranks in the group.
Dict[int, Dict[str, Tuple]]: A nested dictionary where:
- Outer keys: Target rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Corresponding `ShardedTensor` object with complete shard metadata (local shape, dtype, global offset, etc.).
2. Rank group (tuple of ints): Ranks that have redundant copies of the shard (supports fault tolerance or
parallel access).
"""
total_shard_metadata = get_all_sharded_tensor(network, filter_func)
shard_id_to_ranks = defaultdict(list)
@@ -386,14 +412,14 @@ def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], boo
shards_in_this_parallelization_group = set()
shard_id_to_tensor = {}

for rank, sharded_tensor_metas in enumerate(total_shard_metadata):
for tensor_meta in sharded_tensor_metas:
shard_id = sharded_tensor_shard_id(tensor_meta.key, tensor_meta.global_offset)
for rank, sharded_tensor_metas in total_shard_metadata.items():
for sharded_tensor in sharded_tensor_metas.values():
shard_id = sharded_tensor_shard_id(sharded_tensor.key, sharded_tensor.global_offset)
shard_id_to_ranks[shard_id].append(rank)

if shard_id not in shard_to_size:
shard_to_size[shard_id] = _get_shard_size(tensor_meta.local_shape, tensor_meta.dtype)
shard_id_to_tensor[shard_id] = tensor_meta
shard_to_size[shard_id] = _get_shard_size(sharded_tensor.local_shape, sharded_tensor.dtype)
shard_id_to_tensor[shard_id] = sharded_tensor
shards_in_this_parallelization_group.add(shard_id)

shard_id_to_ranks = {
@@ -406,17 +432,13 @@ def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], boo
shard_id_to_ranks, shard_to_size, len(total_shard_metadata)
)

dst_sharded_tensor_metas = {} # {shard_name: ShardTensor}
local_rank = get_real_rank()
for shard_id, rank_id in shard_to_saving_rank.items():
if rank_id == local_rank:
dst_sharded_tensor_metas[_reverse_sharded_tensor_shard_id(shard_id)[0]] = shard_id_to_tensor[shard_id]

param_redundancy = {}
for shard_id, rank_group in shard_id_to_ranks.items():
if len(rank_group) == 1:
continue
if local_rank in rank_group:
param_redundancy.setdefault(tuple(rank_group), []).append(_reverse_sharded_tensor_shard_id(shard_id)[0])
rank_id_to_sharded_tensors = {}
for shard_id, rank_info in shard_to_saving_rank.items():
selected_rank_id, rank_group = rank_info
sharded_tensor = shard_id_to_tensor[shard_id]
if selected_rank_id in rank_id_to_sharded_tensors:
rank_id_to_sharded_tensors[selected_rank_id][shard_id] = (sharded_tensor, rank_group)
else:
rank_id_to_sharded_tensors[selected_rank_id] = {shard_id: (sharded_tensor, rank_group)}

return shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas, param_redundancy
return rank_id_to_sharded_tensors
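
Note: the distribute_shards and apply_balance_shard_strategy hunks above change the assignment result from a bare rank ID to a (selected_rank, rank_group) tuple grouped per rank. Below is a minimal, standalone sketch of the greedy load-balancing idea, assuming made-up shard IDs, sizes, and coverage; the helper name assign_shards is hypothetical and not part of MindFormers.

from collections import defaultdict

def assign_shards(shard_coverage, shard_sizes):
    """Greedily map each shard to the least-loaded rank among the ranks that cover it."""
    rank_loads = defaultdict(int)
    assignment = {}
    # Place larger shards first so the heavy items get spread out early.
    for shard_id in sorted(shard_coverage, key=lambda s: shard_sizes[s], reverse=True):
        rank_group = tuple(sorted(shard_coverage[shard_id]))
        selected = min(rank_group, key=lambda r: rank_loads[r])
        assignment[shard_id] = (selected, rank_group)  # (saving rank, redundant rank group)
        rank_loads[selected] += shard_sizes[shard_id]
    return assignment

if __name__ == "__main__":
    coverage = {"w1@(0,)": [0, 1], "w2@(0,)": [0, 1], "w3@(0,)": [1]}  # hypothetical shard IDs
    sizes = {"w1@(0,)": 4096, "w2@(0,)": 2048, "w3@(0,)": 1024}        # bytes, illustrative only
    print(assign_shards(coverage, sizes))
    # {'w1@(0,)': (0, (0, 1)), 'w2@(0,)': (1, (0, 1)), 'w3@(0,)': (1, (1,))}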

+21 -18  mindformers/checkpoint/metadata.py

@@ -109,28 +109,31 @@ def save_metadata(sharded_tensor_metas, param_file_mappings, meta_data_path):
due to file system errors, permission issues, or disk space problems.
"""
state_dict_metadata = {}
sharded_tensor_metas = [
item
for sublist in sharded_tensor_metas
for item in sublist
]
for _, tensor_meta in enumerate(sharded_tensor_metas):
param_name = tensor_meta.key
sharded_tensor_list = []
for _, cur_rank_sharded_tensor_metas in sharded_tensor_metas.items():
for _, sharded_tensor in cur_rank_sharded_tensor_metas.items():
if isinstance(sharded_tensor, list):
sharded_tensor_list.extend(sharded_tensor)
else:
sharded_tensor_list.append(sharded_tensor)

for sharded_tensor in sharded_tensor_list:
param_name = sharded_tensor.key
new_chunk = {
"global_offset": tensor_meta.global_offset,
"local_shape": tensor_meta.local_shape
"global_offset": sharded_tensor.global_offset,
"local_shape": sharded_tensor.local_shape
}
if param_name not in state_dict_metadata:
state_dict_metadata[param_name] = {
"properties": {
"dtype": str(tensor_meta.dtype),
"replica_id": tensor_meta.replica_id,
"allow_shape_mismatch": tensor_meta.allow_shape_mismatch,
"allow_to_save": tensor_meta.allow_to_save
"dtype": str(sharded_tensor.dtype),
"replica_id": sharded_tensor.replica_id,
"allow_shape_mismatch": sharded_tensor.allow_shape_mismatch,
"allow_to_save": sharded_tensor.allow_to_save
},
"global_shape": tensor_meta.global_shape,
"axis_fragmentations": tensor_meta.axis_fragmentations,
"layout": _serialize_sharded_tensor_layout(tensor_meta.layout),
"global_shape": sharded_tensor.global_shape,
"axis_fragmentations": sharded_tensor.axis_fragmentations,
"layout": _serialize_sharded_tensor_layout(sharded_tensor.layout),
"chunk": [new_chunk]
}
elif param_name in state_dict_metadata:
@@ -344,9 +347,9 @@ def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_

npu_nums = get_group_size()
param_file_mappings = []
for cur_npu_rank, cur_rank_sharded_tensor_list in enumerate(sharded_tensor_metas):
for cur_npu_rank, cur_rank_sharded_tensors in sharded_tensor_metas.items():
# Get mappings of parameter file of current rank.
for sharded_tensor in cur_rank_sharded_tensor_list:
for sharded_tensor in cur_rank_sharded_tensors.values():
if model_keys and sharded_tensor.key not in list(model_keys):
ckpt_name = get_checkpoint_name(None, user_prefix, cur_npu_rank, npu_nums, FileType.OPTIMIZER)
else:
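
For reference, here is a rough sketch of what one entry of the state_dict_metadata dict assembled in the save_metadata hunk above might look like, written as a plain Python dict. All field values are invented for illustration, and the "layout" payload is left as a placeholder since it depends on _serialize_sharded_tensor_layout.

example_entry = {
    "model.layers.0.attention.wq.weight": {          # hypothetical parameter name
        "properties": {
            "dtype": "Float32",                      # str(sharded_tensor.dtype)
            "replica_id": [0],
            "allow_shape_mismatch": False,
            "allow_to_save": True,
        },
        "global_shape": (4096, 4096),
        "axis_fragmentations": (2, 1),               # assumed split on the first axis across 2 ranks
        "layout": None,                              # placeholder for the serialized layout
        "chunk": [                                   # one chunk per recorded shard
            {"global_offset": (0,), "local_shape": (2048, 4096)},
        ],
    }
}
print(example_entry)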


+185 -192  mindformers/checkpoint/sharded_tensor.py

@@ -220,182 +220,113 @@ def _rank_id_with_slice_id(alias_rank_stride):
return rank_slice_table, cur_global_offset


def get_param_name_from_layout(param_infos: List[Dict]) -> List[str]:
"""Extract parameter names."""
names = []

for param_dict in param_infos:
for param_name, _ in param_dict.items():
names.append(param_name)

return names


def get_value_type_from_layout(param_infos: List[Dict]) -> List[type]:
"""Extract parameter types."""
types = []

for param_dict in param_infos:
for _, (_, param_type, _) in param_dict.items():
types.append(param_type)

return types


def get_local_shape_from_layout(param_infos: List[Dict]) -> List[Tuple[int, ...]]:
"""Compute local (sharded) shape on current device."""
shapes = []

for param_dict in param_infos:
for _, (cur_layout, _, cur_shape) in param_dict.items():
distributed_info = _DistributedTensorInfo(cur_layout)
cur_stra = distributed_info.sharding_strategy
shapes.append(tuple(int(s // c) for s, c in zip(cur_shape, cur_stra)))

return shapes


def get_global_shape_from_layout(param_infos: List[Dict]) -> List[Tuple[int, ...]]:
"""Extract global shapes."""
shapes = []

for param_dict in param_infos:
for _, (_, _, cur_shape) in param_dict.items():
shapes.append(cur_shape)

return shapes


def get_axis_fragmentations_from_layout(param_infos: List[Dict]) -> List[Tuple[int, ...]]:
"""Extract axis fragmentations (effective sharding strategies per tensor axis)."""
fragmentations = []

for param_dict in param_infos:
for _, (cur_layout, _, _) in param_dict.items():
distributed_info = _DistributedTensorInfo(cur_layout)
cur_stra = distributed_info.sharding_strategy
fragmentations.append(tuple(int(x) for x in cur_stra))

return fragmentations


def get_global_offset_from_layout(param_infos: List[Dict]) -> List[Tuple[int, ...]]:
"""Calculate global offsets (starting index in full tensor) for current device."""
offsets = []

for param_dict in param_infos:
for _, (cur_layout, _, _) in param_dict.items():
cur_layout_dict = cur_layout.to_dict()
cur_dev_matrix = cur_layout_dict.get("device_matrix")
cur_alias_name = cur_layout_dict.get("alias_name")
cur_tensor_map = cur_layout_dict.get("tensor_map")
cur_rank_list = cur_layout_dict.get("rank_list")

dev_arrange = _alias_name_with_rank_id(cur_dev_matrix, cur_alias_name, cur_rank_list)
flat_tensor_map = _flatten_tensor_map(cur_tensor_map)
alias_rank_stride = _tensor_map_with_rank_id(
cur_dev_matrix, flat_tensor_map, cur_alias_name, dev_arrange
)

_, cur_global_offset = _rank_id_with_slice_id(alias_rank_stride)
offsets.append(cur_global_offset)

return offsets


def get_replica_id_from_layout(param_infos: List[Dict]) -> List[List[int]]:
"""Determine replica ID for each device (0 for primary, 1 for duplicate, etc.)."""
replica_ids = []
def _get_global_offset_and_replica_id_from_layout(layout):
"""
Extracts rank-specific global offset and replica IDs from the distributed tensor layout.

for param_dict in param_infos:
for _, (cur_layout, _, _) in param_dict.items():
cur_layout_dict = cur_layout.to_dict()
cur_dev_matrix = cur_layout_dict.get("device_matrix")
cur_alias_name = cur_layout_dict.get("alias_name")
cur_tensor_map = cur_layout_dict.get("tensor_map")
cur_rank_list = cur_layout_dict.get("rank_list")
Parses the layout's device matrix, alias name, tensor map, and rank list to compute the global offset
(starting index of the local slice in the global tensor) and replica IDs (identifies the replica of the
local slice across sharded dimensions).

dev_arrange = _alias_name_with_rank_id(cur_dev_matrix, cur_alias_name, cur_rank_list)
flat_tensor_map = _flatten_tensor_map(cur_tensor_map)
alias_rank_stride = _tensor_map_with_rank_id(
cur_dev_matrix, flat_tensor_map, cur_alias_name, dev_arrange
)

rank_slice_table, _ = _rank_id_with_slice_id(alias_rank_stride)
Args:
layout: Distributed tensor layout object (must implement `to_dict()` method) containing sharding
configuration details.

slice_cnt: Dict[int, int] = defaultdict(int)
cur_replica_id: List[int] = []
for _, slice_id in enumerate(rank_slice_table):
replica_id = slice_cnt[slice_id]
slice_cnt[slice_id] += 1
cur_replica_id.append(replica_id)
replica_ids.append(cur_replica_id)
Returns:
Tuple[List[int], List[int]]: A tuple with two elements:
1. global_offset: List of integers representing the base global offset for the tensor's sharding
2. cur_replica_id: List of integers where each element is the replica index for the corresponding
sharded dimension
"""
layout_dict = layout.to_dict()
dev_matrix = layout_dict.get("device_matrix")
alias_name = layout_dict.get("alias_name")
tensor_map = layout_dict.get("tensor_map")
rank_list = layout_dict.get("rank_list")

dev_arrange = _alias_name_with_rank_id(dev_matrix, alias_name, rank_list)
flat_tensor_map = _flatten_tensor_map(tensor_map)
alias_rank_stride = _tensor_map_with_rank_id(
dev_matrix, flat_tensor_map, alias_name, dev_arrange
)

return replica_ids
rank_slice_table, global_offset = _rank_id_with_slice_id(alias_rank_stride)
slice_cnt: Dict[int, int] = defaultdict(int)
cur_replica_id: List[int] = []
for _, slice_id in enumerate(rank_slice_table):
replica_id = slice_cnt[slice_id]
slice_cnt[slice_id] += 1
cur_replica_id.append(replica_id)
return global_offset, cur_replica_id


def get_sharded_tensor_list_from_strategy_metadata(
param_infos: List[Dict],
def get_sharded_tensor_from_strategy_metadata(
param_infos: Dict[str, List],
cur_npu_rank: int,
filter_func: Callable[[str], bool] = None
) -> Optional[List[ShardedTensor]]:
) -> Optional[Dict[str, ShardedTensor]]:
"""
Transform distributed strategy of a network to a list of ShardedTensor.
Creates ShardedTensor instances for the current NPU rank based on distributed strategy metadata.

Processes parameter metadata (layout, dtype, global shape) to construct sharded tensors tailored to the current
NPU rank. Applies an optional filter to select specific parameters, computes rank-specific sharding details
(local shape, global offset, replica ID), and builds ShardedTensor objects using the provided metadata.

Args:
param_infos (List[Dict]): The distributed strategy of a rank of network.
cur_npu_rank (int): Current Rank ID of NPUs.
filter_func (Callable[[str], bool]): A filter function
that decide whether to save metadata info of optimizer weight.
param_infos: A dictionary mapping parameter names (str) to their distributed metadata. Each value is a list
containing three elements in order:
1. layout: Distributed tensor layout object (supports `to_dict()` method) containing sharding
configuration (device matrix, alias name, tensor map, rank list)
2. dtype: Data type of the parameter (e.g., torch.float32, numpy.float64)
3. global_shape: Tuple of integers representing the full global shape of the unsharded tensor
cur_npu_rank: Integer indicating the current NPU rank index (used to compute rank-specific global offset)
filter_func: Optional callable that takes a parameter name (str) and returns a boolean. If provided, only
parameters for which the function returns True are included in the output. Defaults to None (all parameters
included).

Returns:
A list of ShardedTensor or None: A list containing sharded tensor metadata, or None if no param_infos.
Optional[Dict[str, ShardedTensor]]: A dictionary where keys are parameter names (filtered if `filter_func` is
provided) and values are corresponding ShardedTensor instances for the current NPU rank. Returns None if the
input `param_infos` is empty.
"""
if not param_infos:
return None

cur_rank_sharded_tensor_list = []

cur_param_name_list = get_param_name_from_layout(param_infos)
cur_value_type_list = get_value_type_from_layout(param_infos)
cur_local_shape_list = get_local_shape_from_layout(param_infos)
cur_global_shape_list = get_global_shape_from_layout(param_infos)
cur_axis_fragmentations_list = get_axis_fragmentations_from_layout(param_infos)
cur_global_offset_list = get_global_offset_from_layout(param_infos)
cur_replica_id_list = get_replica_id_from_layout(param_infos)

for idx, param_name in enumerate(cur_param_name_list):
# If not save optimizer weight, the metadata will also not save the optimizer part.
cur_rank_sharded_tensor_dict = {}
for param_name, param_info in param_infos.items():
if filter_func and not filter_func(param_name):
continue

org_global_offset = cur_global_offset_list[idx]
npu_nums_per_pp = len(org_global_offset)

# The situation where different strategies need to be adapted later
global_offset = (org_global_offset[cur_npu_rank % npu_nums_per_pp],)
layout, dtype, global_shape = param_info
distributed_info = _DistributedTensorInfo(layout)
strategy = distributed_info.sharding_strategy
axis_fragmentations = tuple(int(x) for x in strategy)
local_shape = tuple(int(s // c) for s, c in zip(global_shape, axis_fragmentations))
global_offset, replica_id = _get_global_offset_and_replica_id_from_layout(layout)
npu_nums_per_pp = len(global_offset)
global_offset = (global_offset[cur_npu_rank % npu_nums_per_pp],)

cur_sharded_tensor = build_sharded_tensor(
param_name=param_name,
param_dtype=cur_value_type_list[idx],
local_shape=cur_local_shape_list[idx],
global_shape=cur_global_shape_list[idx],
param_dtype=dtype,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
axis_fragmentations=cur_axis_fragmentations_list[idx],
replica_id=cur_replica_id_list[idx],
axis_fragmentations=axis_fragmentations,
replica_id=replica_id,
allow_shape_mismatch=False,
allow_to_save=True,
layout=param_infos[idx][param_name][0]
layout=layout
)
cur_rank_sharded_tensor_list.append(cur_sharded_tensor)
cur_rank_sharded_tensor_dict[param_name] = cur_sharded_tensor

return cur_rank_sharded_tensor_list
return cur_rank_sharded_tensor_dict


def get_sharded_tensor_list_from_cell(
def get_sharded_tensor_from_cell(
network: Cell,
optimizer: Optional[Cell] = None,
) -> List[ShardedTensor]:
) -> Dict[str, ShardedTensor]:
"""
Extracts sharded tensor metadata from a network cell and optional optimizer cell.

@@ -409,14 +340,14 @@ def get_sharded_tensor_list_from_cell(
optimizer: Optional optimizer Cell containing additional parameters

Returns:
List of ShardedTensor objects with metadata from network and optimizer parameters
Dict of ShardedTensor objects with metadata from network and optimizer parameters
"""
logger.info(".........Get Current Strategy Metadata from Cell.........")
cur_rank_sharded_tensor_list: List[ShardedTensor] = []
sharded_tensor_dict: Dict[str, ShardedTensor] = {}

def _get_sharded_tensors_from_cell(
cell: Cell, ignore_params_list: Optional[List[str]] = None
) -> List[ShardedTensor]:
) -> Dict[str, ShardedTensor]:
"""
Helper function to extract sharded tensors from a single Cell.

@@ -428,9 +359,9 @@ def get_sharded_tensor_list_from_cell(
ignore_params_list: Optional list of parameter names to skip

Returns:
List of ShardedTensor objects for the cell's parameters
Dict of ShardedTensor objects for the cell's parameters
"""
sharded_tensor_list = []
cur_cell_sharded_tensor_dict = {}
for param in cell.get_parameters():
param_name = param.name

@@ -453,55 +384,52 @@ def get_sharded_tensor_list_from_cell(
global_offset=global_offset,
axis_fragmentations=axis_fragmentations
)
sharded_tensor_list.append(sharded_tensor)
cur_cell_sharded_tensor_dict[param_name] = sharded_tensor

return sharded_tensor_list
return cur_cell_sharded_tensor_dict

# Get sharded tensors from the main network
cur_rank_sharded_tensor_list.extend(_get_sharded_tensors_from_cell(network))
sharded_tensor_dict.update(_get_sharded_tensors_from_cell(network))

# Add sharded tensors from optimizer if provided, ignoring network parameters
if optimizer:
# Create list of parameter names already collected from network
ignore_params_list = [sharded_tensor.key for sharded_tensor in cur_rank_sharded_tensor_list]
ignore_params_list = list(sharded_tensor_dict.keys())
# Get optimizer parameters, skipping those already in network
cur_rank_sharded_tensor_list.extend(
sharded_tensor_dict.update(
_get_sharded_tensors_from_cell(optimizer, ignore_params_list)
)

return cur_rank_sharded_tensor_list
return sharded_tensor_dict


def convert_sharded_tensor_list_to_dict(
sharded_tensor_list: List[ShardedTensor]
) -> Dict[str, ShardedTensor]:
def get_all_sharded_tensor(
network: Cell,
filter_func: Callable[[str], bool] = None
) -> Dict[int, Dict[str, ShardedTensor]]:
"""
Converts a list of ShardedTensor objects to a dictionary.
Collects sharded tensor metadata for all ranks in the parallel group from the MindSpore network.

Creates a dictionary where each key is the 'key' attribute of a ShardedTensor
from the input list, and the corresponding value is the ShardedTensor object
itself.
Retrieves global distributed strategy metadata from the input network, then generates rank-specific
ShardedTensor instances for every rank in the parallel group. Applies an optional filter to select
target parameters (e.g., exclude non-trainable weights) during ShardedTensor creation.

Args:
sharded_tensor_list: List of ShardedTensor objects to convert
network (Cell): A MindSpore Network Cell containing distributed parameters and their sharding strategy.
filter_func (Optional[Callable[[str], bool]]): An optional filtering function that takes a parameter name (str)
and returns a boolean. Only parameters for which the function returns `True` are included in the
ShardedTensor collection. Defaults to `None` (all eligible parameters are included).

Returns:
Dictionary mapping ShardedTensor keys to their corresponding objects
Dict[int, Dict[str, ShardedTensor]]: A nested dictionary where:
- Outer keys: Rank IDs (int) in the parallel group (range: `[0, total_ranks - 1]`).
- Outer values: Dictionaries mapping parameter names (str) to their corresponding `ShardedTensor` instances,
containing rank-specific metadata (local shape, global offset, dtype, etc.).

Raises:
RuntimeError: If `get_strategy_metadata` returns `None`, indicating no distributed strategy metadata is
associated with the network.
"""
sharded_tensor_dict: Dict[str, ShardedTensor] = {}

for sharded_tensor in sharded_tensor_list:
param_name = sharded_tensor.key
sharded_tensor_dict[param_name] = sharded_tensor

return sharded_tensor_dict


def get_all_sharded_tensor(
network: Cell,
filter_func: Callable[[str], bool] = None
) -> List[List]:
"""Get all rank sharded tensors."""
logger.info(".........Get All Ranks' Strategy Metadata.........")
global_strategy_info = get_strategy_metadata(network)
if not global_strategy_info:
@@ -509,36 +437,101 @@ def get_all_sharded_tensor(
'Please check whether this is a distributed job.')

npu_nums = get_real_group_size()
sharded_tensor_metas = []
sharded_tensor_metas: Dict[int, Dict[str, ShardedTensor]] = {}
for cur_npu_rank in range(0, npu_nums):
org_cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]
cur_rank_strategy_layout = [
dict([item])
for item in org_cur_rank_strategy_layout.items()
]
cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]

# Get Sharded tensors from strategy metadata of current rank.
cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata(
cur_rank_sharded_tensors = get_sharded_tensor_from_strategy_metadata(
param_infos=cur_rank_strategy_layout,
cur_npu_rank=cur_npu_rank,
filter_func=filter_func
)

sharded_tensor_metas.append(cur_rank_sharded_tensors)
sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors
return sharded_tensor_metas


def get_cur_sharded_tensor(
network: Cell,
filter_func: Callable[[str], bool] = None
) -> List:
"""Get current rank sharded tensors."""
) -> Dict[str, ShardedTensor]:
"""
Retrieves rank-specific ShardedTensor instances for the current NPU rank from the MindSpore network.

Args:
network (Cell): A MindSpore Network Cell containing distributed parameters and their sharding strategy.
filter_func (Optional[Callable[[str], bool]]): An optional filtering function that takes a parameter name (str)
and returns a boolean. Only parameters for which the function returns `True` are included in the
output. Defaults to `None` (all eligible parameters assigned to the current rank are included).

Returns:
Dict[str, ShardedTensor]: A dictionary where keys are parameter names (str) and values are `ShardedTensor`
instances.
"""
logger.info(".........Get Current Strategy Metadata.........")
strategy_info = get_current_strategy_metadata(network)
# Convert strategy layout to required format
strategy_info = [dict([item]) for item in strategy_info[0].items()]
strategy_info = get_current_strategy_metadata(network)[0]
# Get sharded tensors from strategy metadata
cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata(
cur_rank_sharded_tensors = get_sharded_tensor_from_strategy_metadata(
param_infos=strategy_info, cur_npu_rank=get_real_rank(), filter_func=filter_func
)
return cur_rank_sharded_tensors


def get_cur_sharded_tensor_after_balanced(
rank_id_to_sharded_tensors: Dict[int, Dict[str, Tuple]]
) -> Dict[str, ShardedTensor]:
"""
Retrieves the load-balanced ShardedTensor instances assigned to the current rank.

Args:
rank_id_to_sharded_tensors: A nested dictionary representing the load-balanced shard distribution across ranks:
- Outer keys: Rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. Target `ShardedTensor` instance (with rank-specific metadata like local shape, global offset, etc.).
2. Rank group (Tuple[int, ...]): Redundant ranks with copies of the shard (ignored in this function).

Returns:
Dict[str, ShardedTensor]: A dictionary where keys are parameter names (str) and values are `ShardedTensor`
instances.
"""
cur_rank_sharded_tensors = {}
local_rank = get_real_rank()
sharded_tensors = rank_id_to_sharded_tensors[local_rank]
for _, shard_id_info in sharded_tensors.items():
sharded_tensor, _ = shard_id_info
param_name = sharded_tensor.key
cur_rank_sharded_tensors[param_name] = sharded_tensor
return cur_rank_sharded_tensors


def get_param_redundancy_after_balanced(
rank_id_to_sharded_tensors: Dict[int, Dict[str, Tuple]]
) -> Dict[Tuple, List]:
"""
Identifies redundant parameter shards for the current rank from a load-balanced shard distribution.

Args:
rank_id_to_sharded_tensors: A nested dictionary representing the load-balanced shard distribution across ranks:
- Outer keys: Rank IDs (int) in the parallel group.
- Outer values: Dictionaries mapping unique shard IDs (str) to tuples containing:
1. `ShardedTensor` instance (with `key` attribute specifying the original parameter name).
2. Rank group (Tuple[int, ...]): Sorted tuple of ranks that store redundant copies of the shard.

Returns:
Dict[Tuple[int, ...], List[str]]: A dictionary where:
- Keys: Rank groups (Tuple[int, ...]) – sets of ranks that share redundant copies of parameters.
Only groups containing the current rank are included.
- Values: Lists of parameter names that are redundantly stored across the corresponding rank group.
"""
param_redundancy = {}
local_rank = get_real_rank()
for _, sharded_tensors in rank_id_to_sharded_tensors.items():
for _, shard_id_info in sharded_tensors.items():
sharded_tensor, rank_group = shard_id_info
param_name = sharded_tensor.key
if len(rank_group) == 1:
continue
if local_rank in rank_group:
param_redundancy.setdefault(tuple(rank_group), []).append(param_name)
return param_redundancy
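
A toy walk-through of the two "after_balanced" helpers documented above, using a namedtuple in place of ShardedTensor and a hand-written rank_id_to_sharded_tensors mapping. The shard IDs, parameter names, and rank groups are illustrative assumptions; in practice the mapping comes from apply_balance_shard_strategy.

from collections import namedtuple

FakeShardedTensor = namedtuple("FakeShardedTensor", ["key", "global_offset"])

rank_id_to_sharded_tensors = {
    0: {"w1@(0,)": (FakeShardedTensor("w1", (0,)), (0, 1))},  # shard held redundantly on ranks 0 and 1
    1: {"w2@(0,)": (FakeShardedTensor("w2", (0,)), (1,))},    # shard held only on rank 1
}

local_rank = 0  # stand-in for get_real_rank()

# Equivalent of get_cur_sharded_tensor_after_balanced: shards this rank saves, keyed by param name.
cur_rank_tensors = {st.key: st for st, _ in rank_id_to_sharded_tensors[local_rank].values()}

# Equivalent of get_param_redundancy_after_balanced: rank groups (size > 1) containing this rank,
# mapped to the parameter names they all hold.
param_redundancy = {}
for sharded_tensors in rank_id_to_sharded_tensors.values():
    for sharded_tensor, rank_group in sharded_tensors.values():
        if len(rank_group) > 1 and local_rank in rank_group:
            param_redundancy.setdefault(tuple(rank_group), []).append(sharded_tensor.key)

print(cur_rank_tensors)   # {'w1': FakeShardedTensor(key='w1', global_offset=(0,))}
print(param_redundancy)   # {(0, 1): ['w1']}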

+4 -2  mindformers/models/glm4_moe/modeling_glm4_moe_infer.py

@@ -17,7 +17,6 @@ __all__ = ['InferenceGlm4MoeForCausalLM']

from mindformers.models.utils import jit
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.transformer_config_utils import convert_to_transformer_config
from mindformers.models.glm4_moe.utils import Glm4MoePreTrainedModel
from mindformers.parallel_core.inference.utils import update_comm_config
from mindformers.parallel_core.inference.base_models.gpt.gpt_model import GPTModel
@@ -41,8 +40,11 @@ class InferenceGlm4MoeForCausalLM(Glm4MoePreTrainedModel, InferModelMixin):

def __init__(self, config: Glm4MoeConfig):
super().__init__(config, auto_prefix=False)
if config.model_type == "glm4_moe":
setattr(config, "num_nextn_predict_layers", 0)

self.config = config
config: TransformerConfig = convert_to_transformer_config(self.config)
config: TransformerConfig = self.convert_to_transformer_config(self.config)

# update communication-related configuration in TransformerConfig
config = update_comm_config(config)


+3 -7  mindformers/trainer/base_trainer.py

@@ -15,7 +15,7 @@
"""Base Trainer."""
import os
import re
import subprocess
import socket
from pprint import pprint
from functools import partial
from typing import Optional, Union, List
@@ -108,12 +108,8 @@ class BaseTrainer:

def __init__(self, task: str = None, model_name: str = None):

host_name_output = subprocess.run(['hostname'], shell=False, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', check=True)
host_ip_output = subprocess.run(['hostname', '-I'], shell=False, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', check=True)
host_name = host_name_output.stdout.strip()
host_ip = host_ip_output.stdout.strip().split(' ')[0]
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
logger.info(f"host_name: {host_name}, host_ip: {host_ip}")

if model_name is None:
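
The hunk above replaces the subprocess-based hostname lookup with the socket module. As a side note, socket.gethostbyname can still raise socket.gaierror on hosts whose name has no address entry; a hedged sketch of how a caller could guard against that case follows. The loopback fallback is an assumption for illustration, not what BaseTrainer does.

import socket

def resolve_host():
    """Return (hostname, ip), falling back to loopback when the name does not resolve."""
    host_name = socket.gethostname()
    try:
        host_ip = socket.gethostbyname(host_name)
    except socket.gaierror:
        host_ip = "127.0.0.1"  # assumed fallback, for illustration only
    return host_name, host_ip

print(resolve_host())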


tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/__init__.py → tests/st/test_multi_cards_cases/test_model/test_glm4_moe/__init__.py

@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test column parallel linear"""
"""test mcore glm4 moe."""

+15 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mcore infer glm4 moe."""

+40 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/glm4_moe_infer.yaml

@@ -0,0 +1,40 @@
seed: 0
output_dir: './output' # path to save checkpoint/strategy
load_checkpoint: ''
use_parallel: False
run_mode: 'predict'
use_legacy: False
load_ckpt_format: 'safetensors'

trainer:
type: CausalLanguageModelingTrainer
model_name: 'glm4_moe'

# default parallel of device num = 8 for Atlas 800T A2
parallel_config:
data_parallel: 1
model_parallel: 2
# HuggingFace file directory
pretrained_model_dir: '/path/hf_dir'
model:
model_config:
compute_dtype: "bfloat16"
layernorm_compute_dtype: "float32"
softmax_compute_dtype: "float32"
rotary_dtype: "bfloat16"
params_dtype: "bfloat16"

# mindspore context init config
context:
mode: 0 #0--Graph Mode; 1--Pynative Mode
max_device_memory: "29GB"
device_id: 0
device_target: "Ascend"
affinity_cpu_list: None
deterministic: "ON"
infer_precision_sync: True

# parallel context config
parallel:
parallel_mode: "MANUAL_PARALLEL"
full_batch: False

+90 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/run_glm4_moe.py

@@ -0,0 +1,90 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mcore glm4 moe model ST of inference"""
import argparse
import os
from transformers import AutoTokenizer

from mindspore.nn.utils import no_init_parameters

from tests.st.test_multi_cards_cases.test_model.utils import compare_distance

from mindformers import AutoModel, build_context, MindFormerConfig
from mindformers.core.parallel_config import build_parallel_config
from mindformers.tools.logger import logger


def test_glm4_moe_predict_mcore(device_num: int = 1):
"""
Feature: Mcore Glm4Moe predict task
Description: Two-card tp parallel
Expectation: Success or assert precision failed
"""
max_decode_length = 32
config_path = os.path.join(os.path.dirname(__file__), "glm4_moe_infer.yaml")
config = MindFormerConfig(config_path)
config.use_parallel = device_num > 1
config.parallel_config.model_parallel = device_num
config.pretrained_model_dir = "/home/workspace/mindspore_dataset/weight/GLM-4.5-Air-tiny"
# Reduced layer network
config.model.model_config.num_hidden_layers = 2
build_context(config)
build_parallel_config(config)
# Auto tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_dir)
# init network
with no_init_parameters():
network = AutoModel.from_config(config)
network.load_weights(config.pretrained_model_dir)
# Build prompt and answer
batch_datas = {1: {"prompt": "Please introduce some scenic spots in Beijing.",
"answer": "Please introduce some scenic spots in Beijing."
"ahan flying UmbursalhsiqadereFINE写实ENVIRONMENTally NASpired "
"Biosphericoux posit Lifts-offENS小的范围内"},
4: {"prompt": "Please introduce some scenic spots in Beijing.",
"answer": "Please introduce some scenic spots in Beijing."
"ahan flying UmbursalhsiqadereFINE写实ENVIRONMENTally NASpired "
"Biosphericoux posit Lifts-offENS小的范围内"},
}

for batch_size, batch_data in batch_datas.items():
input_ids = tokenizer.encode(batch_data["prompt"])
input_ids_list = []
answer = batch_data["answer"]
for _ in range(0, batch_size):
input_ids_list.append(input_ids)

outputs = network.generate(input_ids_list,
max_length=max_decode_length,
do_sample=False,
return_dict_in_generate=False)

for output in outputs:
output_text = tokenizer.decode(output)
logger.info("test_glm4_5_air_predict, output_text: %s", str(output_text))
compare_distance(output_text, answer)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Run Glm4Moe ST")
parser.add_argument("--device_num", type=int, default=2)

args = parser.parse_args()
os.environ['MS_ENABLE_LCCL'] = "off"
os.environ['HCCL_DETERMINICTIC'] = "true"
os.environ['LCCL_DETERMINICTIC'] = "1"
os.environ['ASCEND_LAUNCH_BLOCKING'] = "1"
os.environ['CUSTOM_MATMUL_SHUFFLE'] = "off"
test_glm4_moe_predict_mcore(args.device_num)

+58 -0  tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/test_glm4_moe_infer.py

@@ -0,0 +1,58 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Mcore Glm4Moe inference"""
import os
import random
from pathlib import Path

import pytest

from tests.st.test_multi_cards_cases.utils import TaskType
from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 80
_LEVEL_1_TASK_TIME = 0
_TASK_TYPE = TaskType.TWO_CARDS_TASK


class TestMcoreGlm4MoeParallelInference:
"""Test class for Glm4Moe in inference"""

def setup_method(self):
"""Setup method to prepare test environment"""
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_glm4_moe.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level0
def test_two_cards_cases(self):
"""Test two cards for Glm4Moe."""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
"msrun",
"--worker_num=2",
"--local_worker_num=2", # Should match NPU cards available
f"--master_port={port_id}", # Ensure port is unique per test run if parallelized at pytest level
"--log_dir=./msrun_log_glm4moe",
"--join=True"]
cmd_list += [
str(self.run_script_path),
"--device_num=2"
]
cmd = " ".join(cmd_list)
logger.info(f"Running command: {cmd}")
return_code = os.system(cmd)
assert return_code == 0, "Glm4Moe inference st failed."

+4 -1  tests/st/test_ut/base_schema.json

@@ -1,6 +1,9 @@
{
"mindformers.checkpoint.load_checkpoint": {
"signature": "(checkpoint: str, network: mindspore.nn.cell.Cell, optimizer: mindspore.nn.optim.optimizer.Optimizer = None, global_step: int = None, balanced_load: bool = False)"
},
"mindformers.checkpoint.save_checkpoint": {
"signature": "(iteration: int, network: mindspore.nn.cell.Cell, optimizer: mindspore.nn.optim.optimizer.Optimizer = None, async_save_manager: mindformers.checkpoint.checkpoint.AsyncSaveManager = None, common_info: mindformers.checkpoint.checkpoint.CommonInfo = None, keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None, sharded_tensor_metas: list = None, remove_redundancy: bool = False)"
"signature": "(iteration: int, network: mindspore.nn.cell.Cell, optimizer: mindspore.nn.optim.optimizer.Optimizer = None, async_save_manager: mindformers.checkpoint.checkpoint.AsyncSaveManager = None, common_info: mindformers.checkpoint.checkpoint.CommonInfo = None, keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None, sharded_tensor_metas: Dict = None, remove_redundancy: bool = False)"
},
"mindformers.core.AdamW": {
"signature": "(params, learning_rate=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, use_fused=False, amsgrad=False, maximize=False, swap=False)"


+ 35
- 59
tests/st/test_ut/test_checkpoint/test_fully_parallel.py View File

@@ -54,18 +54,10 @@ def mock_get_all_sharded_tensor():
mock_shard_tensor3 = MockShardTensor("param3", (0,), (10,), "float32")

with patch("mindformers.checkpoint.fully_parallel.get_all_sharded_tensor") as mock:
mock.return_value = [
[mock_shard_tensor1, mock_shard_tensor2],
[mock_shard_tensor3]
]
yield mock


@pytest.fixture
def mock_get_rank():
"""Mock get_rank function"""
with patch("mindformers.checkpoint.fully_parallel.get_rank") as mock:
mock.return_value = 0
mock.return_value = {
0: {"param1": mock_shard_tensor1, "param2": mock_shard_tensor2},
1: {"param3": mock_shard_tensor3}
}
yield mock


@@ -172,12 +164,13 @@ def test_distribute_shards_basic():

# Check that all shards are assigned
assert len(result) == 3
# Check that each shard is assigned to a valid rank
for rank in result.values():
assert 0 <= rank < total_ranks
# Check that shards are assigned to ranks that cover them
for shard_id, rank in result.items():
assert rank in shard_coverage[shard_id]
for shard_id, rank_info in result.items():
selected_rank, rank_group = rank_info
# Check that each shard is assigned to a valid rank
assert 0 <= selected_rank < total_ranks
# Check that shards are assigned to ranks that cover them
assert selected_rank in shard_coverage[shard_id]
assert rank_group == tuple(shard_coverage[shard_id])


@pytest.mark.level0
@@ -219,7 +212,7 @@ def test_distribute_shards_single_rank():

result = distribute_shards(shard_coverage, shard_sizes, total_ranks)

assert result == {"shard1": 0, "shard2": 0}
assert result == {"shard1": (0, (0,)), "shard2": (0, (0,))}


@pytest.mark.level0
@@ -236,19 +229,13 @@ def test_apply_balance_shard_strategy(
"""
result = apply_balance_shard_strategy(mock_network, None)

assert len(result) == 4
shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas, param_redundancy = result

assert isinstance(shard_to_saving_rank, dict)
assert isinstance(shard_id_to_tensor, dict)
assert isinstance(dst_sharded_tensor_metas, dict)
assert isinstance(param_redundancy, dict)
assert isinstance(result, dict)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_init(mock_network, mock_get_rank):
def test_balanced_save_strategy_init(mock_network, mock_get_real_rank):
"""
Feature: BalancedSaveStrategy initialization
Description: Test BalancedSaveStrategy class initialization with various parameters
@@ -274,7 +261,7 @@ def test_balanced_save_strategy_init(mock_network, mock_get_rank):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_apply_saving_parallelization(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
@@ -289,17 +276,14 @@ def test_balanced_save_strategy_apply_saving_parallelization(

result = strategy.apply_saving_parallelization()

assert len(result) == 2
shard_id_to_ranks, shard_id_to_tensor = result
assert isinstance(shard_id_to_ranks, dict)
assert isinstance(shard_id_to_tensor, dict)
assert isinstance(result, dict)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_apply_saving_parallelization_with_cache(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
@@ -318,7 +302,7 @@ def test_balanced_save_strategy_apply_saving_parallelization_with_cache(

# Second call - should use cached distribution
with patch("mindformers.checkpoint.fully_parallel.apply_balance_shard_strategy") as mock_apply:
mock_apply.return_value = ({}, {})
mock_apply.return_value = {}
result2 = strategy.apply_saving_parallelization()
# Check that apply_balance_shard_strategy was not called again
mock_apply.assert_not_called()
@@ -330,7 +314,7 @@ def test_balanced_save_strategy_apply_saving_parallelization_with_cache(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_total_files(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
@@ -353,7 +337,7 @@ def test_balanced_save_strategy_get_total_files(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_cur_rank_file_id(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
@@ -376,7 +360,7 @@ def test_balanced_save_strategy_get_cur_rank_file_id(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_save(
tmp_path, mock_network, mock_get_rank, mock_get_all_sharded_tensor,
tmp_path, mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size, mock_save_checkpoint,
mock_get_metadata_filename, mock_get_checkpoint_name, mock_get_checkpoint_iter_dir,
mock_save_metadata, mock_load_metadata, mock_reverse_sharded_tensor_shard_id
@@ -409,7 +393,7 @@ def test_balanced_save_strategy_save(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_save_with_existing_metadata(
tmp_path, mock_network, mock_get_rank, mock_get_all_sharded_tensor,
tmp_path, mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size, mock_save_checkpoint,
mock_get_metadata_filename, mock_get_checkpoint_name, mock_get_checkpoint_iter_dir,
mock_save_metadata, mock_reverse_sharded_tensor_shard_id
@@ -448,7 +432,7 @@ def test_balanced_save_strategy_save_with_existing_metadata(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_rank_params_mappings(
mock_network, mock_get_rank
mock_network, mock_get_real_rank
):
"""
Feature: BalancedSaveStrategy._get_rank_params_mappings method
@@ -460,13 +444,6 @@ def test_balanced_save_strategy__get_rank_params_mappings(
checkpoint_path="./checkpoint"
)

# Create mock data
shared_distribution = {
"shard1": 0,
"shard2": 1,
"shard3": 0
}

mock_tensor1 = MagicMock()
mock_tensor1.key = "param1"
mock_tensor2 = MagicMock()
@@ -474,13 +451,13 @@ def test_balanced_save_strategy__get_rank_params_mappings(
mock_tensor3 = MagicMock()
mock_tensor3.key = "param3"

id_to_tensor = {
"shard1": mock_tensor1,
"shard2": mock_tensor2,
"shard3": mock_tensor3
# Create mock data
shared_distribution = {
0: {"shard1": (mock_tensor1, (0,)), "shard3": (mock_tensor3, (0,))},
1: {"shard2": (mock_tensor2, (1,))}
}

result = strategy._get_rank_params_mappings(shared_distribution, id_to_tensor)
result = strategy._get_rank_params_mappings(shared_distribution)

assert isinstance(result, dict)
assert 0 in result
@@ -494,7 +471,7 @@ def test_balanced_save_strategy__get_rank_params_mappings(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_rank_param_ids_mappings(
mock_network, mock_get_rank
mock_network, mock_get_real_rank
):
"""
Feature: BalancedSaveStrategy._get_rank_param_ids_mappings method
@@ -508,9 +485,8 @@ def test_balanced_save_strategy__get_rank_param_ids_mappings(

# Create mock data
shared_distribution = {
"shard1": 0,
"shard2": 1,
"shard3": 0
0: {"shard1": (None, (0,)), "shard3": (None, (0,))},
1: {"shard2": (None, (1,))}
}

result = strategy._get_rank_param_ids_mappings(shared_distribution)
@@ -527,7 +503,7 @@ def test_balanced_save_strategy__get_rank_param_ids_mappings(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_total_files_num(
mock_network, mock_get_rank
mock_network, mock_get_real_rank
):
"""
Feature: BalancedSaveStrategy._get_total_files_num method
@@ -563,7 +539,7 @@ def test_balanced_save_strategy__get_total_files_num(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_cur_rank_file_id(
mock_network, mock_get_rank
mock_network, mock_get_real_rank
):
"""
Feature: BalancedSaveStrategy._get_cur_rank_file_id method
@@ -612,7 +588,7 @@ def test_balanced_save_strategy__get_cur_rank_file_id(
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_total_files_and_cur_rank_file_id(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_network, mock_get_real_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
@@ -634,7 +610,7 @@ def test_balanced_save_strategy_get_total_files_and_cur_rank_file_id(

# Check that second calls use cached values
with patch("mindformers.checkpoint.fully_parallel.apply_balance_shard_strategy") as mock_apply:
mock_apply.return_value = ({}, {})
mock_apply.return_value = {}
total_files2 = strategy.get_total_files()
cur_rank_file_id2 = strategy.get_cur_rank_file_id()



+ 0
- 149
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/gpt_model_for_test.py View File

@@ -1,149 +0,0 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""a model designed for test."""


from functools import partial
import numpy as np
import mindspore as ms
from mindformers.parallel_core.inference.tensor_parallel.layers import (ColumnParallelLinear,
RowParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear)
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.models.configuration_utils import PretrainedConfig
from mindformers.parallel_core.inference.quantization.utils import get_quant_config


class LinearSpec:
"""Specification for linear layers in the model."""

def __init__(self, linear_type, input_size, output_size, has_bias, compute_dtype, quant_type):
if isinstance(compute_dtype, str):
compute_dtype = self.convert_pt_dtype_to_ms(compute_dtype)
if compute_dtype not in [ms.dtype.float32, ms.dtype.float16, ms.dtype.bfloat16]:
raise ValueError(f"Unsupported compute_dtype: {compute_dtype}")
self.linear_type = linear_type
self.input_size = input_size
self.output_size = output_size
self.has_bias = has_bias
self.skip_bias_add = False
self.compute_dtype = compute_dtype
self.transpose_b=True
self.quant_type = quant_type

def name(self):
return f"{self.linear_type}-has_bias_{self.has_bias}-" \
f"compute_dtype_{self.compute_dtype}-quant_type_{self.quant_type}"

@staticmethod
def convert_pt_dtype_to_ms(pt_dtype: str):
"""Convert PyTorch dtype to MindSpore dtype."""
dtype_mapping = {
'fp32': ms.dtype.float32,
'fp16': ms.dtype.float16,
'bf16': ms.dtype.bfloat16,
}
mstype = dtype_mapping.get(pt_dtype, None)
if mstype is None:
raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}")
return mstype


class ModelSpec:
def __init__(self, compute_dtype, param_init_dtype, tensor_parallel, linear_specs):
self.linear_specs = linear_specs
self.compute_dtype = compute_dtype
self.param_init_dtype = param_init_dtype
self.tensor_parallel = tensor_parallel


class TestPretrainedConfig(PretrainedConfig):
def __init__(self, quantization, pretrained_model_dir):
super().__init__(
quantization=quantization,
pretrained_model_dir=pretrained_model_dir,
)


class GPTModelForTest(ms.nn.Cell):
"""A model designed for testing parallel linear operations."""

def __init__(self, model_spec, comm_pgs, quantization: str, quant_model_dir=None):
super().__init__()
self.model_spec = model_spec
if quant_model_dir is None:
quant_config = None
else:
quant_config = get_quant_config(TestPretrainedConfig(quantization, quant_model_dir), [])
transformer_config = TransformerConfig(
tensor_model_parallel_size=model_spec.tensor_parallel,
compute_dtype=model_spec.compute_dtype,
params_dtype=model_spec.param_init_dtype,
num_layers=1,
num_attention_heads=model_spec.tensor_parallel,
)
self.linears = GPTModelForTest._build_linears(comm_pgs, model_spec, transformer_config, quant_config)
self.num_linears = len(self.linears)

@staticmethod
def _build_linears(comm_pgs, model_spec, transformer_config, quant_config):
"""Build a list of linear layers based on the model specifications."""
linear_map = {
"ColumnParallelLinear": partial(ColumnParallelLinear, gather_output=True),
"MergedColumnParallelLinear": MergedColumnParallelLinear,
"QKVParallelLinear": QKVParallelLinear,
"RowParallelLinear": RowParallelLinear,
"ReplicatedLinear": ReplicatedLinear,
}
linears = []
for index, linear_spec in enumerate(model_spec.linear_specs):
linear = linear_map[linear_spec.linear_type](
input_size=linear_spec.input_size,
output_size=linear_spec.output_size,
config=transformer_config,
skip_bias_add=linear_spec.skip_bias_add,
compute_dtype=linear_spec.compute_dtype,
transpose_b=linear_spec.transpose_b,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=f"linears.{index}"
)
linears.append(linear)
return ms.nn.SequentialCell(linears)

def forward(self, x):
"""Forward pass through the model, processing input through all linear layers."""
output = self.construct(x).astype(ms.dtype.float32).asnumpy()
bs = output.shape[0]
if bs != self.num_linears:
raise ValueError(f"outputs size must be equal to the number of linears: {bs} != {self.num_linears}")
outputs = np.split(output, bs, axis=0)
output_dict = {}
for index, linear_spec in enumerate(self.model_spec.linear_specs):
name = f"index_{index}-{linear_spec.name()}"
output_dict[name] = outputs[index].squeeze(axis=0)
return output_dict

def construct(self, x):
y = ms.ops.zeros_like(x)
y = y.expand_dims(axis=0)
for index in range(self.num_linears):
linear = self.linears[index]
z = linear(x).expand_dims(axis=0)
y = ms.ops.concat((y, z))
return y[1:,::]

+ 381
- 85
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/numpy_quantizer.py View File

@@ -17,108 +17,55 @@

import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Tuple

import numpy as np
from safetensors.numpy import save_file
from gpt_model_for_test import ModelSpec
from simple_gpt_model import ModelSpec


class NumpyQuantizer:
"""A class for quantizing model weights using NumPy."""
@dataclass
class QuantResult:
"""Result of weight quantization."""
weights: Dict[str, np.ndarray]
descriptions: Dict[str, str]

def __init__(self, model_spec: ModelSpec, quant_policy: list[str]):
self.model_spec = model_spec
self.quant_policy = quant_policy
self.description_file_path = None

def quant(self, quant_input: np.ndarray, weights, save_dir):
"""Quantize the input and weights, save to safetensors and JSON description."""
quant_weights, quant_desc = self._quant(quant_input, weights)
print(f"quant_weights: {quant_weights.keys()}", flush=True)
print(f"quant_desc: {quant_desc}", flush=True)
save_file(quant_weights, os.path.join(save_dir, 'quant-model-00001-00001.safetensors'))
with open(os.path.join(save_dir, "quantization_description.json"), "w", encoding='utf-8') as f:
json.dump(quant_desc, f, indent=2, ensure_ascii=False)
print(f"quantization weights saved to {save_dir}", flush=True)
class WeightQuantStrategy(ABC):
"""Abstract base class for weight quantization strategies."""

def _quant(self, quant_input: np.ndarray, weights):
"""Internal method to perform quantization on weights based on policy."""
quant_weights = {}
quant_desc = {}
for index, (qpolicy, linear_spec) in enumerate(zip(self.quant_policy, self.model_spec.linear_specs)):
weight = weights[f"linears.{index}.weight"]
if qpolicy == 'a8w8':
_, input_scale, input_offset = NumpyQuantizer._act_int8_quant(quant_input)
quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight, transpose_b=linear_spec.transpose_b)
x_zp = input_offset.astype(np.int32) # per-tensor zero-point
quant_bias = -np.sum(x_zp * quant_weight.astype(np.int32), axis=-1).astype(np.int32)
deq_scale = (input_scale.astype(np.float32) * w_scale.astype(np.float32))
beta = np.zeros(linear_spec.input_size, dtype=np.int32)
quant_weights.update({
f"linears.{index}.weight": quant_weight,
f"linears.{index}.deq_scale": deq_scale,
f"linears.{index}.input_scale": np.tile(input_scale, linear_spec.input_size),
f"linears.{index}.input_offset": np.tile(input_offset, linear_spec.input_size),
f"linears.{index}.quant_bias": quant_bias,
f"linears.{index}.beta": beta,
})
quant_desc.update({
f"linears.{index}.weight": "W8A8",
f"linears.{index}.deq_scale": "W8A8",
f"linears.{index}.input_scale": "W8A8",
f"linears.{index}.input_offset": "W8A8",
f"linears.{index}.quant_bias": "W8A8",
f"linears.{index}.beta": "W8A8",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "W8A8"
continue
if qpolicy == 'a8dynw8':
quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight, transpose_b=linear_spec.transpose_b)
quant_weights.update({
f"linears.{index}.weight": quant_weight,
f"linears.{index}.w_scale": w_scale
})
quant_desc.update({
f"linears.{index}.weight": "W8A8_DYNAMIC",
f"linears.{index}.w_scale": "W8A8_DYNAMIC",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "W8A8_DYNAMIC"
continue
if qpolicy is None:
quant_weights.update({
f"linears.{index}.weight": weight,
})
quant_desc.update({
f"linears.{index}.weight": "FLOAT",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "FLOAT"
continue
raise ValueError(f"Unsupported quant policy: {qpolicy}")
return quant_weights, quant_desc
@abstractmethod
def quantize_weight(self, weight: np.ndarray, transpose_b: bool,
input_size: int = None) -> Dict[str, np.ndarray]:
"""Quantize a single weight tensor."""

@abstractmethod
def get_description(self) -> str:
"""Get the quantization description string."""

@staticmethod
def _get_quant_min_max(num_bits=8, signed=True, narrow_range=False):
"""Calculate quantization params for minimum/maximum quantization integer"""
def get_quant_min_max(num_bits: int = 8, signed: bool = True,
narrow_range: bool = False) -> Tuple[int, int]:
"""Calculate quantization params for minimum/maximum quantization integer."""
if signed:
quant_min = 0 - 2 ** (num_bits - 1)
quant_min = -(2 ** (num_bits - 1))
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1

if narrow_range:
quant_min = quant_min + 1

return quant_min, quant_max
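    # Editor's sketch (not part of the original file): reference values produced by the
    # helper above, assuming the signature shown in this hunk:
    #   get_quant_min_max(8, signed=True)   -> (-128, 127)
    #   get_quant_min_max(4, signed=True)   -> (-8, 7)
    #   get_quant_min_max(8, signed=False)  -> (0, 255)
    # With narrow_range=True the minimum is raised by one, e.g. (-127, 127) for int8.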

@staticmethod
def _act_int8_quant(tensor):
def act_int8_quant(tensor: np.ndarray) -> Tuple[np.ndarray, float, float]:
"""Quantize activation tensor to int8."""
bits=8
quant_min, quant_max = NumpyQuantizer._get_quant_min_max(bits)
bits = 8
quant_min, quant_max = WeightQuantStrategy.get_quant_min_max(bits)

min_val = np.min(tensor)
max_val = np.max(tensor)
@@ -139,15 +86,18 @@ class NumpyQuantizer:
return quantized, scale, zero_point

@staticmethod
def _weight_int8_quant(tensor, transpose_b=True):
def weight_int8_quant(tensor: np.ndarray,
transpose_b: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""Quantize weight tensor to int8."""
bits=8
quant_min, quant_max = NumpyQuantizer._get_quant_min_max(bits)
bits = 8
quant_min, quant_max = WeightQuantStrategy.get_quant_min_max(bits)
oc_axis = 0 if transpose_b else 1
ic_axis = 1 if transpose_b else 0
oc = tensor.shape[oc_axis]

min_val = np.min(tensor, axis=ic_axis, keepdims=True)
max_val = np.max(tensor, axis=ic_axis, keepdims=True)

if (max_val == min_val).all():
scale = np.ones((oc,), dtype=np.float32)
else:
@@ -160,4 +110,350 @@ class NumpyQuantizer:
quantized = np.round(tensor / scale)
quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8)
scale = np.squeeze(scale)

return quantized, scale

@staticmethod
def weight_int4_per_group_pack(tensor: np.ndarray, group_size: int,
transpose_b: bool = True) -> (
Tuple[np.ndarray, np.ndarray]
):
"""
Quantize weight tensor to int4 per group and pack two int4 values into one uint8.

Args:
tensor: weight tensor to quantize, shape (oc, ic) if transpose_b=True
group_size: size of each quantization group along input dimension
transpose_b: whether the tensor is in (oc, ic) format

Returns:
packed: packed int4 weights as uint8, shape (oc//2, ic)
scale_uint64: quantization scales as uint64, shape (oc, ic//group_size)
"""
bits = 4
quant_min, quant_max = WeightQuantStrategy.get_quant_min_max(bits, signed=True)

if transpose_b:
oc, ic = tensor.shape[0], tensor.shape[1]
else:
ic, oc = tensor.shape[0], tensor.shape[1]

# Validate dimensions
if ic % group_size != 0:
raise ValueError(
f"Input dimension {ic} must be divisible by group_size {group_size}"
)
if oc % 2 != 0:
raise ValueError(
f"Output dimension {oc} must be even for int4 packing"
)

num_groups = ic // group_size

# Reshape tensor for per-group quantization: (oc, num_groups, group_size)
if transpose_b:
tensor_grouped = tensor.reshape(oc, num_groups, group_size)
else:
tensor_grouped = tensor.T.reshape(oc, num_groups, group_size)

# Calculate scale per group (symmetric quantization)
max_vals = np.max(np.abs(tensor_grouped), axis=2, keepdims=True)
max_vals = np.where(max_vals == 0, 1.0, max_vals)
scales = (max_vals / quant_max).astype(np.float32)

# Quantize and reshape
quantized = np.round(tensor_grouped / scales)
quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8)
quantized = quantized.reshape(oc, ic)
scales = scales.squeeze(axis=2)

# Pack two consecutive oc values into one uint8
quantized_even = quantized[0::2, :]
quantized_odd = quantized[1::2, :]

even_unsigned = (quantized_even & 0x0F).astype(np.uint8)
odd_unsigned = (quantized_odd & 0x0F).astype(np.uint8)

# Pack: even in low 4 bits, odd in high 4 bits
packed_unsigned = (odd_unsigned << 4) | even_unsigned

return (packed_unsigned,
scales.astype(np.float32).view(np.uint32).astype(np.uint64))
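    # Editor's sketch (not part of the original file): a worked packing example for the
    # scheme above. Two quantized int4 values from adjacent output channels, say 3 on an
    # even row and -2 on the following odd row, pack into one uint8 as:
    #   even_unsigned = 3  & 0x0F  -> 0x03
    #   odd_unsigned  = -2 & 0x0F  -> 0x0E
    #   packed        = (0x0E << 4) | 0x03 -> 0xE3 (227)
    # so the even channel occupies the low nibble and the odd channel the high nibble.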


class A8W8Strategy(WeightQuantStrategy):
"""INT8 weight and activation quantization strategy."""

def __init__(self, quant_input: np.ndarray):
self.quant_input = quant_input

def quantize_weight(self, weight: np.ndarray, transpose_b: bool,
input_size: int = None) -> Dict[str, np.ndarray]:
"""Quantize weight using INT8 static quantization."""
_, input_scale, input_offset = self.act_int8_quant(self.quant_input)
quant_weight, w_scale = self.weight_int8_quant(weight, transpose_b)

x_zp = input_offset.astype(np.int32)
quant_bias = -np.sum(x_zp * quant_weight.astype(np.int32), axis=-1).astype(np.int32)
deq_scale = input_scale.astype(np.float32) * w_scale.astype(np.float32)
output_size = weight.shape[0]
beta = np.zeros(output_size, dtype=np.int32)

# Input scale and offset should match input_size
input_scale_arr = np.full((input_size,), input_scale, dtype=np.float32)
input_offset_arr = np.full((input_size,), input_offset, dtype=np.float32)

return {
'weight': quant_weight,
'deq_scale': deq_scale,
'input_scale': input_scale_arr,
'input_offset': input_offset_arr.astype(np.int8),
'quant_bias': quant_bias,
'beta': beta,
}

def get_description(self) -> str:
return "W8A8"


class A8DynW8Strategy(WeightQuantStrategy):
"""INT8 dynamic weight quantization strategy."""

def quantize_weight(self, weight: np.ndarray, transpose_b: bool,
input_size: int = None) -> Dict[str, np.ndarray]:
"""Quantize weight using INT8 dynamic quantization."""
quant_weight, w_scale = self.weight_int8_quant(weight, transpose_b)
return {
'weight': quant_weight,
'w_scale': w_scale,
}

def get_description(self) -> str:
return "W8A8_DYNAMIC"


class A8W4Strategy(WeightQuantStrategy):
"""INT4 weight quantization strategy."""

def __init__(self, group_size: int = 256):
self.group_size = group_size

def quantize_weight(self, weight: np.ndarray, transpose_b: bool,
input_size: int = None) -> Dict[str, np.ndarray]:
"""Quantize weight using INT4 per-group quantization."""
qweight_packed, w_scale = self.weight_int4_per_group_pack(
weight, self.group_size, transpose_b
)
return {
'weight': qweight_packed,
'w_scale': w_scale,
}

def get_description(self) -> str:
return "W4A8_DYNAMIC"


class FloatStrategy(WeightQuantStrategy):
"""No quantization (float) strategy."""

def quantize_weight(self, weight: np.ndarray, transpose_b: bool,
input_size: int = None) -> Dict[str, np.ndarray]:
"""Return weight as-is without quantization."""
return {'weight': weight}

def get_description(self) -> str:
return "FLOAT"


class LayerWeightHandler:
"""Handler for processing weights of different layer types."""

def __init__(self, index: int, linear_spec, weights: dict, strategy: WeightQuantStrategy):
self.index = index
self.linear_spec = linear_spec
self.weights = weights
self.strategy = strategy

def process(self) -> QuantResult:
"""Process weights based on layer type."""
layer_type = self.linear_spec.linear_type

if layer_type == "QKVParallelLinear":
return self._process_qkv()
if layer_type == "MergedColumnParallelLinear":
return self._process_merged()
if layer_type in ("ColumnParallelGroupedLinear", "RowParallelGroupedLinear"):
return self._process_grouped()
return self._process_standard()

def _process_qkv(self) -> QuantResult:
"""Process QKV parallel linear weights."""
quant_weights = {}
quant_desc = {}

for qkv_name in ['q', 'k', 'v']:
weight_key = f"linears.{self.index}.{qkv_name}.weight"
weight = self.weights[weight_key]

quant_result = self.strategy.quantize_weight(
weight, self.linear_spec.transpose_b, self.linear_spec.input_size
)

# Add quantized weights with proper keys
for suffix, value in quant_result.items():
key = f"linears.{self.index}.{qkv_name}.{suffix}"
quant_weights[key] = value
quant_desc[key] = self.strategy.get_description()

# Add bias if present
if self.linear_spec.has_bias:
bias_key = f"linears.{self.index}.{qkv_name}.bias"
quant_weights[bias_key] = self.weights[bias_key]
quant_desc[bias_key] = self.strategy.get_description()

return QuantResult(quant_weights, quant_desc)

def _process_merged(self) -> QuantResult:
"""Process merged column parallel linear weights."""
quant_weights = {}
quant_desc = {}

for merge_name in ['gating', 'hidden']:
weight_key = f"linears.{self.index}.{merge_name}.weight"
weight = self.weights[weight_key]

quant_result = self.strategy.quantize_weight(
weight, self.linear_spec.transpose_b, self.linear_spec.input_size
)

# Add quantized weights with proper keys
for suffix, value in quant_result.items():
key = f"linears.{self.index}.{merge_name}.{suffix}"
quant_weights[key] = value
quant_desc[key] = self.strategy.get_description()

# Add bias if present
if self.linear_spec.has_bias:
bias_key = f"linears.{self.index}.{merge_name}.bias"
quant_weights[bias_key] = self.weights[bias_key]
quant_desc[bias_key] = self.strategy.get_description()

return QuantResult(quant_weights, quant_desc)

def _process_grouped(self) -> QuantResult:
"""Process grouped linear (MoE) weights."""
quant_weights = {}
quant_desc = {}

for gate_name in ['gate', 'up']:
weight_key = f"linears.{self.index}.{gate_name}.weight"
weight = self.weights[weight_key]

quant_result = self.strategy.quantize_weight(weight, transpose_b=True)

# Add quantized weights with proper keys
for suffix, value in quant_result.items():
key = f"linears.{self.index}.{gate_name}.{suffix}"
quant_weights[key] = value

# Description uses base key for grouped layers
quant_desc[f"linears.{self.index}.weight"] = self.strategy.get_description()
quant_desc[f"linears.{self.index}.w_scale"] = self.strategy.get_description()

return QuantResult(quant_weights, quant_desc)

def _process_standard(self) -> QuantResult:
"""Process standard linear layer weights."""
quant_weights = {}
quant_desc = {}

weight_key = f"linears.{self.index}.weight"
weight = self.weights[weight_key]

quant_result = self.strategy.quantize_weight(
weight, self.linear_spec.transpose_b, self.linear_spec.input_size
)

# Add quantized weights with proper keys
for suffix, value in quant_result.items():
key = f"linears.{self.index}.{suffix}"
quant_weights[key] = value
quant_desc[key] = self.strategy.get_description()

# Add bias if present
if self.linear_spec.has_bias:
bias_key = f"linears.{self.index}.bias"
quant_weights[bias_key] = self.weights[bias_key]
quant_desc[bias_key] = self.strategy.get_description()

return QuantResult(quant_weights, quant_desc)


class NumpyQuantizer:
"""A class for quantizing model weights using NumPy."""

def __init__(self, model_spec: ModelSpec, quant_policy: list):
self.model_spec = model_spec
self.quant_policy = quant_policy
self.global_group_size = None

def quant(self, quant_input: np.ndarray, weights: dict, save_dir: str):
"""Quantize the input and weights, save to safetensors and JSON description."""
quant_weights, quant_desc = self._quant(quant_input, weights)
print(f"quant_weights: {quant_weights.keys()}", flush=True)
print(f"quant_desc: {quant_desc}", flush=True)

save_file(quant_weights, os.path.join(save_dir, 'quant-model-00001-00001.safetensors'))
with open(os.path.join(save_dir, "quantization_description.json"), "w",
encoding='utf-8') as f:
json.dump(quant_desc, f, indent=2, ensure_ascii=False)

print(f"quantization weights saved to {save_dir}", flush=True)

def _quant(self, quant_input: np.ndarray, weights: dict) -> Tuple[dict, dict]:
"""Internal method to perform quantization on weights based on policy."""
all_quant_weights = {}
all_quant_desc = {}

for index, (qpolicy, linear_spec) in enumerate(
zip(self.quant_policy, self.model_spec.linear_specs)
):
# Create appropriate quantization strategy
strategy = self._create_strategy(qpolicy, quant_input, linear_spec)

# Process weights using the strategy
handler = LayerWeightHandler(index, linear_spec, weights, strategy)
result = handler.process()

# Merge results
all_quant_weights.update(result.weights)
all_quant_desc.update(result.descriptions)

# Add global group size if set
if self.global_group_size is not None:
all_quant_desc["group_size"] = int(self.global_group_size)

return all_quant_weights, all_quant_desc

def _create_strategy(self, qpolicy: str, quant_input: np.ndarray,
linear_spec) -> WeightQuantStrategy:
"""Create appropriate quantization strategy based on policy."""
if qpolicy == 'a8w8':
return A8W8Strategy(quant_input)
if qpolicy == 'a8dynw8':
return A8DynW8Strategy()
if qpolicy == 'a8w4':
# Validate that a8w4 is only used with grouped layers
layer_type = linear_spec.linear_type
if layer_type not in ("ColumnParallelGroupedLinear",
"RowParallelGroupedLinear"):
raise ValueError(
"a8w4 quantization only supports grouped linear layers"
)
group_size = 256
self.global_group_size = group_size
return A8W4Strategy(group_size)
if qpolicy is None or qpolicy == 'float':
return FloatStrategy()

raise ValueError(f"Unsupported quant policy: {qpolicy}")

+ 0
- 168
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/run_parallel_linear.py View File

@@ -1,168 +0,0 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run ColumnParallelLinear accuracy test with configurable parameters via args"""


import argparse
import glob
import os
import tempfile

import numpy as np
from safetensors import safe_open

import mindspore as ms
from mindspore.communication import init
from numpy_quantizer import NumpyQuantizer
from gpt_model_for_test import GPTModelForTest, LinearSpec, ModelSpec
from mindformers.parallel_core.inference.parallel_state import initialize_model_parallel
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups


class ParallelModelRunner:
"""Runner for parallel model testing with quantization support."""

def __init__(self, config):
"""Initialize the parallel model runner with given arguments."""
self.config = config
# set up parallel context
rank_id_str = os.environ.get("RANK_ID")
self.rank_id = int(rank_id_str) if rank_id_str is not None else None
self.worker_num = int(os.environ.get("MS_WORKER_NUM", "1"))
self.model_comm_pgs = ModelCommProcessGroups.get_default_model_comm_pgs()
if self.rank_id is not None:
init()
initialize_model_parallel(tensor_model_parallel_size=self.config.tensor_parallel)
self.model_comm_pgs = ModelCommProcessGroups.use_parallel_state_groups(required_groups=['tp'])

linear_specs = []
quant_policys = []
self.quantization = config.quantization
for linear_type in config.linear_types:
for has_bias in [True, False]:
for quant_policy in config.quant_policies:
quant_policy = quant_policy if config.quantization == 'golden-stick' else 'float'
linear_specs.append(LinearSpec(linear_type, config.input_size, config.output_size,
has_bias, config.compute_dtype, quant_policy))
quant_policys.append(quant_policy)

self.model_spec = ModelSpec(
compute_dtype=config.compute_dtype,
param_init_dtype=config.param_init_dtype,
tensor_parallel=config.tensor_parallel,
linear_specs=linear_specs,
)
self.quant_model_dir = None
if self.quantization == 'golden-stick':
self.quantizer = NumpyQuantizer(self.model_spec, quant_policys)
self.quant_model_dir = tempfile.mkdtemp(prefix="quant_model_for_test_")

@staticmethod
def _gen_float_weights(model_spec):
"""Generate random float weights for model specifications."""
np.random.seed(42)
weights = {}
for index, linear_spec in enumerate(model_spec.linear_specs):
weight_shape = (linear_spec.output_size, linear_spec.input_size)
output_size = linear_spec.output_size
weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32)
weights[f"linears.{index}.weight"] = weight
if linear_spec.has_bias:
bias = 0.01 * np.random.randn(output_size).astype(np.float32)
weights[f"linears.{index}.bias"] = bias
return weights

@staticmethod
def _gen_input(model_spec):
"""Generate random input data for model specifications."""
np.random.seed(42)
return 0.01 * np.random.randn(2 * 2, model_spec.linear_specs[0].input_size).astype(np.float32)

def _create_network(self):
"""Create the network model for testing."""
return GPTModelForTest(self.model_spec, self.model_comm_pgs, self.quantization, self.quant_model_dir)

def _load_quant_weights(self):
"""Load quantized weights from the model directory."""
if not os.path.isdir(self.quant_model_dir):
raise ValueError(f"Invalid quant_model_dir: {self.quant_model_dir}")
safetensor_files = glob.glob(os.path.join(self.quant_model_dir, "*.safetensors"))
if len(safetensor_files) == 1:
safetensor_file = safetensor_files[0]
elif len(safetensor_files) > 1:
raise FileNotFoundError(f"Found multiple safetensor files in {self.quant_model_dir}")
else:
raise FileNotFoundError(f"Found no safetensor file in {self.quant_model_dir}")
if not os.path.exists(safetensor_file):
raise FileNotFoundError(f"File {safetensor_file} not found.")
with safe_open(safetensor_file, framework="np", device="cpu") as f:
weights = {}
for key in f.keys():
weights[key] = f.get_slice(key)
return weights

@staticmethod
def load_weights_into_network(network, weights):
"""Load weights into the network parameters."""
params = network.parameters_dict()
loaded = []
for k, v in weights.items():
param = params.get(k)
if param is None:
continue
loaded.append(k)
param.weight_loader(param, v)
print(f"weights not use: {set(weights.keys()) - set(loaded)}", flush=True)
print(f"params not load: {set(params.keys()) - set(loaded)}", flush=True)

def run(self):
"""Run the parallel model test."""
input_data = ParallelModelRunner._gen_input(self.model_spec)
weights = ParallelModelRunner._gen_float_weights(self.model_spec)
if self.quantization == 'golden-stick':
self.quantizer.quant(input_data, weights, self.quant_model_dir)
weights = self._load_quant_weights()
network = self._create_network()
ParallelModelRunner.load_weights_into_network(network, weights)
net_input = ms.Tensor(input_data, dtype=LinearSpec.convert_pt_dtype_to_ms(self.model_spec.compute_dtype))
output_dict = network.forward(net_input)

if self.rank_id is None or int(self.rank_id) == 0:
np.savez(self.config.output_path, **output_dict)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run ColumnParallelLinear test")
parser.add_argument("--linear_types", type=str, action='append', default=None,
help="List of linear types, e.g., --linear_types ColumnParallelLinear "\
"--linear_types RowParallelLinear")
parser.add_argument("--tensor_parallel", type=int, default=1)
parser.add_argument("--compute_dtype", type=str, default='bf16')
parser.add_argument("--param_init_dtype", type=str, default='bf16')
parser.add_argument("--output_path", type=str, default="output.npz")
parser.add_argument("--quantization", type=str, default=None)
parser.add_argument("--quant_policies", type=str, action='append', default=None,
help="List of quantization policies, e.g., --quant_policies a8w8 --quant_policies a8dynw8")
args = parser.parse_args()
args.input_size = 32
args.output_size = 32

ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": "O0", "infer_boost": "on"},
deterministic="ON")

quant_runner = ParallelModelRunner(args)
quant_runner.run()

tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/gpt_model_for_test.py → tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/simple_gpt_model.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""a model designed for test."""
"""Simple GPT model for testing parallel linear layers."""


from functools import partial
@@ -29,15 +29,30 @@ from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.models.configuration_utils import PretrainedConfig
from mindformers.parallel_core.inference.quantization.utils import get_quant_config
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.tensor_parallel.grouped_layers import UnquantizedGroupedLinearMethod
from mindformers.parallel_core.inference.quantization.base_config import QuantizeMethodBase
from mindformers.parallel_core.inference.tensor_parallel.grouped_layers import (
UnquantizedGroupedLinearMethod
)


def convert_dtype_str_to_ms(dtype_str: str):
"""Convert dtype string to MindSpore dtype."""
dtype_mapping = {
'fp32': ms.dtype.float32,
'fp16': ms.dtype.float16,
'bf16': ms.dtype.bfloat16,
}
mstype = dtype_mapping.get(dtype_str, None)
if mstype is None:
raise ValueError(f"Unsupported dtype string: {dtype_str}")
return mstype


class LinearSpec:
"""Specification for linear layers in the model."""
"""Specification for standard linear layers."""

def __init__(self, linear_type, input_size, output_size, has_bias, compute_dtype, quant_type):
if isinstance(compute_dtype, str):
compute_dtype = self.convert_pt_dtype_to_ms(compute_dtype)
compute_dtype = convert_dtype_str_to_ms(compute_dtype)
if compute_dtype not in [ms.dtype.float32, ms.dtype.float16, ms.dtype.bfloat16]:
raise ValueError(f"Unsupported compute_dtype: {compute_dtype}")
self.linear_type = linear_type
@@ -46,101 +61,84 @@ class LinearSpec:
self.has_bias = has_bias
self.skip_bias_add = False
self.compute_dtype = compute_dtype
self.transpose_b=True
self.transpose_b = True
self.quant_type = quant_type

def name(self):
"""Generate a unique name for this layer configuration."""
return f"{self.linear_type}-has_bias_{self.has_bias}-" \
f"compute_dtype_{self.compute_dtype}-quant_type_{self.quant_type}"

@staticmethod
def convert_pt_dtype_to_ms(pt_dtype: str):
"""Convert PyTorch dtype to MindSpore dtype."""
dtype_mapping = {
'fp32': ms.dtype.float32,
'fp16': ms.dtype.float16,
'bf16': ms.dtype.bfloat16,
}
mstype = dtype_mapping.get(pt_dtype, None)
if mstype is None:
raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}")
return mstype
def infer_shape(self):
"""Infer output shape. Returns output size."""
if self.linear_type == "MergedColumnParallelLinear":
# MergedColumnParallelLinear outputs 2 * output_size (gating + hidden)
return self.output_size * 2
return self.output_size


class QKVLinearSpec:
"""Specification for linear layers in the model."""
"""Specification for QKV parallel linear layers."""

def __init__(self, linear_type, hidden_size, head_size, total_num_heads, total_num_kv_heads,
has_bias, compute_dtype, quant_type):
if isinstance(compute_dtype, str):
compute_dtype = self.convert_pt_dtype_to_ms(compute_dtype)
compute_dtype = convert_dtype_str_to_ms(compute_dtype)
if compute_dtype not in [ms.dtype.float32, ms.dtype.float16, ms.dtype.bfloat16]:
raise ValueError(f"Unsupported compute_dtype: {compute_dtype}")

self.linear_type = linear_type
self.input_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
self.output_size = (
(self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
)
self.output_size = (total_num_heads + 2 * total_num_kv_heads) * head_size
self.output_sizes = [
self.total_num_heads * self.head_size, # q_proj
self.total_num_kv_heads * self.head_size, # k_proj
self.total_num_kv_heads * self.head_size, # v_proj
total_num_heads * head_size, # q_proj
total_num_kv_heads * head_size, # k_proj
total_num_kv_heads * head_size, # v_proj
]
self.has_bias = has_bias
self.skip_bias_add = False
self.compute_dtype = compute_dtype
self.transpose_b=True
self.transpose_b = True
self.quant_type = quant_type

def name(self):
"""Generate a unique name for this layer configuration."""
return f"{self.linear_type}-has_bias_{self.has_bias}-" \
f"compute_dtype_{self.compute_dtype}-quant_type_{self.quant_type}"

@staticmethod
def convert_pt_dtype_to_ms(pt_dtype: str):
"""Convert PyTorch dtype to MindSpore dtype."""
dtype_mapping = {
'fp32': ms.dtype.float32,
'fp16': ms.dtype.float16,
'bf16': ms.dtype.bfloat16,
}
mstype = dtype_mapping.get(pt_dtype, None)
if mstype is None:
raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}")
return mstype
def infer_shape(self):
"""Infer output shape. Returns output size (q + k + v concatenated)."""
return self.output_size


class GroupLinearSpec:
"""Specification for linear layers in the model."""
"""Specification for grouped linear layers (MoE)."""

def __init__(self,linear_type, num_local_experts,input_size, output_size, quant_type):
def __init__(self, linear_type, num_local_experts, input_size, output_size, quant_type):
self.linear_type = linear_type
self.num_local_experts = num_local_experts
self.input_size = input_size
self.output_size = output_size
self.has_bias = None
self.skip_bias_add = False
self.transpose_b = True
self.quant_type = quant_type

def name(self):
return f"{self.linear_type}-has_bias_{self.has_bias}-" \
f"quant_type_{self.quant_type}"
"""Generate a unique name for this layer configuration."""
return f"{self.linear_type}-has_bias_{self.has_bias}-quant_type_{self.quant_type}"

def infer_shape(self):
"""Infer output shape. Returns output size."""
return self.output_size

@staticmethod
def convert_pt_dtype_to_ms(pt_dtype: str):
"""Convert PyTorch dtype to MindSpore dtype."""
dtype_mapping = {
'fp32': ms.dtype.float32,
'fp16': ms.dtype.float16,
'bf16': ms.dtype.bfloat16,
}
mstype = dtype_mapping.get(pt_dtype, None)
if mstype is None:
raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}")
return mstype

class ModelSpec:
"""Specification for the entire model."""

def __init__(self, compute_dtype, param_init_dtype, tensor_parallel, linear_specs):
self.linear_specs = linear_specs
self.compute_dtype = compute_dtype
@@ -148,7 +146,9 @@ class ModelSpec:
self.tensor_parallel = tensor_parallel


class TestPretrainedConfig(PretrainedConfig):
class SimplePretrainedConfig(PretrainedConfig):
"""Simple pretrained config for testing."""

def __init__(self, quantization, pretrained_model_dir):
super().__init__(
quantization=quantization,
@@ -156,16 +156,20 @@ class TestPretrainedConfig(PretrainedConfig):
)


class GPTModelForTest(ms.nn.Cell):
"""A model designed for testing parallel linear operations."""
class SimpleGPTModel(ms.nn.Cell):
"""A simple GPT model for testing parallel linear operations."""

def __init__(self, model_spec, comm_pgs, quantization: str, quant_model_dir=None):
super().__init__()
self.model_spec = model_spec

# Setup quantization config
if quant_model_dir is None:
quant_config = None
else:
quant_config = get_quant_config(TestPretrainedConfig(quantization, quant_model_dir), [])
quant_config = get_quant_config(SimplePretrainedConfig(quantization, quant_model_dir), [])

# Setup transformer config
transformer_config = TransformerConfig(
tensor_model_parallel_size=model_spec.tensor_parallel,
compute_dtype=model_spec.compute_dtype,
@@ -173,9 +177,16 @@ class GPTModelForTest(ms.nn.Cell):
num_layers=1,
num_attention_heads=model_spec.tensor_parallel,
)
self.linears = GPTModelForTest._build_linears(comm_pgs, model_spec, transformer_config, quant_config)

self.linears = self._build_linears(comm_pgs, model_spec, transformer_config, quant_config)
self.num_linears = len(self.linears)

def process_weights_after_loading(self):
"""Process weights after loading - convert format if needed."""
for cell in self.linears:
if hasattr(cell, 'quant_method') and cell.quant_method is not None:
cell.quant_method.process_weights_after_loading(cell)

@staticmethod
def _build_linears(comm_pgs, model_spec, transformer_config, quant_config):
"""Build a list of linear layers based on the model specifications."""
@@ -187,94 +198,130 @@ class GPTModelForTest(ms.nn.Cell):
"RowParallelLinear": RowParallelLinear,
"ReplicatedLinear": ReplicatedLinear,
}

linears = []
for index, linear_spec in enumerate(model_spec.linear_specs):
if linear_spec.linear_type=="QKVParallelLinear":
linear = linear_map[linear_spec.linear_type](
hidden_size=linear_spec.input_size,
head_size=linear_spec.head_size,
total_num_heads=linear_spec.total_num_heads,
total_num_kv_heads=linear_spec.total_num_kv_heads,
config=transformer_config,
compute_dtype=linear_spec.compute_dtype,
transpose_b=linear_spec.transpose_b,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=f"linears.{index}"
)
elif linear_spec.linear_type=="ColumnParallelGroupedLinear":
if quant_config is None:
quant_method: Optional[QuantizeMethodBase] = UnquantizedGroupedLinearMethod()
weight = quant_method.create_weights(
linear = SimpleGPTModel._build_single_linear(
linear_spec, index, linear_map, comm_pgs,
transformer_config, quant_config
)
linears.append(linear)

return ms.nn.SequentialCell(linears)

@staticmethod
def _build_single_linear(linear_spec, index, linear_map, comm_pgs, transformer_config, quant_config):
"""Build a single linear layer based on its specification."""
linear_type = linear_spec.linear_type
prefix = f"linears.{index}"

if linear_type == "QKVParallelLinear":
return linear_map[linear_type](
hidden_size=linear_spec.input_size,
head_size=linear_spec.head_size,
total_num_heads=linear_spec.total_num_heads,
total_num_kv_heads=linear_spec.total_num_kv_heads,
config=transformer_config,
compute_dtype=linear_spec.compute_dtype,
transpose_b=linear_spec.transpose_b,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=prefix
)

if linear_type == "MergedColumnParallelLinear":
return linear_map[linear_type](
hidden_size=linear_spec.input_size,
ffn_hidden_size=linear_spec.output_size,
config=transformer_config,
bias=linear_spec.has_bias,
gather_output=True,
transpose_b=linear_spec.transpose_b,
compute_dtype=linear_spec.compute_dtype,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=prefix
)

if linear_type == "ColumnParallelGroupedLinear":
# Create weights for grouped linear
if quant_config is None:
quant_method = UnquantizedGroupedLinearMethod()
weight = quant_method.create_weights(
layer=None,
num_local_experts=linear_spec.num_local_experts,
input_size_per_partition=linear_spec.input_size,
output_partition_sizes=[linear_spec.output_size],
params_dtype=ms.bfloat16
)
else:
quant_method = quant_config.get_quant_method(quant_config, f"linears.{index}")
weight = quant_method.create_weights(
)
else:
quant_method = quant_config.get_quant_method(quant_config, prefix)
weight = quant_method.create_weights(
layer=None,
num_local_experts=linear_spec.num_local_experts,
input_size_per_partition=linear_spec.input_size,
output_partition_sizes=[linear_spec.output_size],
params_dtype="bf16"
)
linear = linear_map[linear_spec.linear_type](
num_local_experts=linear_spec.num_local_experts,
input_size=linear_spec.input_size,
output_size=linear_spec.output_size,
config=transformer_config,
weight=weight,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=f"linears.{index}"
)
set_weight_attrs(weight, {"weight_loader": linear.weight_loader})
else:
linear = linear_map[linear_spec.linear_type](
input_size=linear_spec.input_size,
output_size=linear_spec.output_size,
config=transformer_config,
skip_bias_add=linear_spec.skip_bias_add,
compute_dtype=linear_spec.compute_dtype,
transpose_b=linear_spec.transpose_b,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=f"linears.{index}"
)
linears.append(linear)
return ms.nn.SequentialCell(linears)

linear = linear_map[linear_type](
num_local_experts=linear_spec.num_local_experts,
input_size=linear_spec.input_size,
output_size=linear_spec.output_size,
config=transformer_config,
weight=weight,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=prefix
)
set_weight_attrs(weight, {"weight_loader": linear.weight_loader})
return linear

# Standard linear layers (ColumnParallelLinear, RowParallelLinear, ReplicatedLinear)
return linear_map[linear_type](
input_size=linear_spec.input_size,
output_size=linear_spec.output_size,
config=transformer_config,
skip_bias_add=linear_spec.skip_bias_add,
compute_dtype=linear_spec.compute_dtype,
transpose_b=linear_spec.transpose_b,
bias=linear_spec.has_bias,
tp_group=comm_pgs.tp,
quant_config=quant_config,
prefix=prefix
)

def forward(self, x):
"""Forward pass through the model, processing input through all linear layers."""
output = self.construct(x).astype(ms.dtype.float32).asnumpy()
bs = output.shape[0]
if bs != self.num_linears:
raise ValueError(f"outputs size must be equal to the number of linears: {bs} != {self.num_linears}")
outputs = np.split(output, bs, axis=0)
outputs = self.construct(x)

# Process each layer's output into a dictionary
output_dict = {}
for index, linear_spec in enumerate(self.model_spec.linear_specs):
name = f"index_{index}-{linear_spec.name()}"
output_dict[name] = outputs[index].squeeze(axis=0)
output_dict[name] = outputs[index].astype(ms.dtype.float32).asnumpy()

return output_dict

def construct(self, x):
"""Forward pass through one layer."""
y = ms.ops.zeros_like(x)
y = y.expand_dims(axis=0)
"""Forward pass through all layers, returns a list of outputs."""
outputs = []
for index in range(self.num_linears):
linear = self.linears[index]

# Special handling for grouped linear (MoE)
if isinstance(linear, ColumnParallelGroupedLinear):
group_list = np.random.multinomial(x.shape[0],
np.ones(linear.num_local_experts)/linear.num_local_experts)
group_list = np.random.multinomial(
x.shape[0],
np.ones(linear.num_local_experts) / linear.num_local_experts
)
group_list = ms.Tensor(group_list)
z = linear(x,group_list=group_list).expand_dims(axis=0)
z = linear(x, group_list=group_list)
else:
z = linear(x).expand_dims(axis=0)
y = ms.ops.concat((y, z))
return y[1:,::]
z = linear(x)

outputs.append(z)

return outputs

+ 412
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/simple_mcore.py View File

@@ -0,0 +1,412 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple MCore parallel linear inference runner with YAML configuration"""


import argparse
import glob
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import yaml
import numpy as np
from safetensors import safe_open
from safetensors.numpy import save_file

import mindspore as ms
from mindspore.communication import init
from numpy_quantizer import NumpyQuantizer
from simple_gpt_model import (SimpleGPTModel, LinearSpec, ModelSpec,
QKVLinearSpec, GroupLinearSpec, convert_dtype_str_to_ms)
from mindformers.parallel_core.inference.parallel_state import initialize_model_parallel
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups


@dataclass
class WeightLoadingInfo:
"""Information about how to load a weight parameter."""
cleaned_key: str
shard_id: Optional[str] = None
expert_id: Optional[int] = None
layer_type: str = "standard" # "standard", "qkv", "merged", "grouped"


class WeightKeyParser:
"""Parser for weight keys to extract loading information."""

# Define weight patterns with their parsing rules
WEIGHT_PATTERNS = [
(".gate", "w1", 0, "grouped"),
(".up", "w3", 0, "grouped"),
(".gating", "gating", None, "merged"),
(".hidden", "hidden", None, "merged"),
(".q.", "q", None, "qkv"),
(".k.", "k", None, "qkv"),
(".v.", "v", None, "qkv"),
]

@staticmethod
def parse(weight_key: str) -> WeightLoadingInfo:
"""Parse a weight key and return loading information."""
for pattern, shard_id, expert_id, layer_type in WeightKeyParser.WEIGHT_PATTERNS:
if pattern in weight_key:
cleaned_key = weight_key.replace(pattern.rstrip('.'), '')
return WeightLoadingInfo(cleaned_key, shard_id, expert_id, layer_type)

# Default: standard layer without special shard_id
return WeightLoadingInfo(weight_key, None, None, "standard")
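    # Editor's note (not part of the original file): examples of the rules above:
    #   parse("linears.2.gate.weight") -> WeightLoadingInfo("linears.2.weight", "w1", 0, "grouped")
    #   parse("linears.0.q.weight")    -> WeightLoadingInfo("linears.0.weight", "q", None, "qkv")
    #   parse("linears.1.weight")      -> WeightLoadingInfo("linears.1.weight", None, None, "standard")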


class WeightGenerator:
"""Generator for creating random weights for different layer types."""

@staticmethod
def generate(model_spec: ModelSpec) -> dict:
"""Generate random float weights for all layers in the model spec."""
np.random.seed(42)
weights = {}

for index, linear_spec in enumerate(model_spec.linear_specs):
layer_weights = WeightGenerator._generate_for_layer(index, linear_spec)
weights.update(layer_weights)

return weights

@staticmethod
def _generate_for_layer(index: int, linear_spec) -> dict:
"""Generate weights for a single layer based on its type."""
layer_type = linear_spec.linear_type

if layer_type == "QKVParallelLinear":
return WeightGenerator._generate_qkv(index, linear_spec)
if layer_type == "ColumnParallelGroupedLinear":
return WeightGenerator._generate_grouped(index, linear_spec)
if layer_type == "MergedColumnParallelLinear":
return WeightGenerator._generate_merged(index, linear_spec)
return WeightGenerator._generate_standard(index, linear_spec)

@staticmethod
def _generate_qkv(index: int, spec) -> dict:
"""Generate weights for QKV parallel linear."""
weights = {}
qkv_names = ["q", "k", "v"]

for name, output_size in zip(qkv_names, spec.output_sizes):
weight_shape = (output_size, spec.input_size)
weights[f"linears.{index}.{name}.weight"] = (
0.01 * np.random.randn(*weight_shape).astype(np.float32)
)

if spec.has_bias:
weights[f"linears.{index}.{name}.bias"] = (
0.01 * np.random.randn(output_size).astype(np.float32)
)

return weights

@staticmethod
def _generate_grouped(index: int, spec) -> dict:
"""Generate weights for grouped linear (MoE)."""
weights = {}
half_size = spec.output_size // 2

for name in ["gate", "up"]:
weight_shape = (half_size, spec.input_size)
weights[f"linears.{index}.{name}.weight"] = 0.01 * np.random.randn(*weight_shape).astype(np.float32)

return weights

@staticmethod
def _generate_merged(index: int, spec) -> dict:
"""Generate weights for merged column parallel linear."""
weights = {}
weight_shape = (spec.output_size, spec.input_size)

for name in ["gating", "hidden"]:
weights[f"linears.{index}.{name}.weight"] = 0.01 * np.random.randn(*weight_shape).astype(np.float32)

if spec.has_bias:
weights[f"linears.{index}.{name}.bias"] = 0.01 * np.random.randn(spec.output_size).astype(np.float32)

return weights

@staticmethod
def _generate_standard(index: int, spec) -> dict:
"""Generate weights for standard linear layers."""
weights = {}
weight_shape = (spec.output_size, spec.input_size)

weights[f"linears.{index}.weight"] = 0.01 * np.random.randn(*weight_shape).astype(np.float32)

if spec.has_bias:
weights[f"linears.{index}.bias"] = 0.01 * np.random.randn(spec.output_size).astype(np.float32)

return weights


class WeightLoader:
"""Loader for network weights with support for different layer types."""

@staticmethod
def load_into_network(network, weights: dict):
"""Load weights into network parameters."""
params = network.parameters_dict()
print(params)
loaded = []

for original_key, weight_value in weights.items():
load_info = WeightKeyParser.parse(original_key)
param = params.get(load_info.cleaned_key)

if param is None:
continue

WeightLoader._load_single_weight(param, weight_value, load_info)
loaded.append(original_key)

# Report loading status
WeightLoader._report_loading_status(weights, loaded, params)

@staticmethod
def _load_single_weight(param, weight_value, load_info: WeightLoadingInfo):
"""Load a single weight into a parameter based on its type."""
if load_info.layer_type == "grouped":
# ColumnParallelGroupedLinear: needs both shard_id and expert_id
param.weight_loader(param, weight_value, load_info.shard_id, load_info.expert_id)
elif load_info.shard_id is not None:
# QKV or Merged layers: needs shard_id only
param.weight_loader(param, weight_value, load_info.shard_id)
else:
# Standard layers: no special arguments
param.weight_loader(param, weight_value)

@staticmethod
def _report_loading_status(weights: dict, loaded: list, params: dict):
"""Report which weights were not used and which params were not loaded."""
weights_not_used = set(weights.keys()) - set(loaded)
params_not_loaded = set(params.keys()) - set(loaded)

if weights_not_used:
print(f"weights not used: {weights_not_used}", flush=True)
if params_not_loaded:
print(f"params not loaded: {params_not_loaded}", flush=True)


class LinearSpecFactory:
"""Factory for creating LinearSpec instances based on layer type."""

@staticmethod
def create(layer_type: str, has_bias: bool, quant_policy: str, config) -> object:
"""Create appropriate LinearSpec based on layer type."""
if layer_type == "QKVParallelLinear":
return QKVLinearSpec(
layer_type, config.input_size, config.head_size,
config.total_num_heads, config.total_num_kv_heads,
has_bias, config.compute_dtype, quant_policy
)
if layer_type == "ColumnParallelGroupedLinear":
return GroupLinearSpec(
layer_type, config.num_local_experts, config.input_size,
config.output_size, quant_policy
)
# Standard and merged layers use LinearSpec
return LinearSpec(
layer_type, config.input_size, config.output_size,
has_bias, config.compute_dtype, quant_policy
)


class SimpleMCoreRunner:
"""Simple runner for MCore parallel linear layers with quantization support."""

def __init__(self, config):
"""Initialize the simple MCore runner with given arguments."""
self.config = config
self._setup_parallel_context()
self._load_config_and_build_model()

def _setup_parallel_context(self):
"""Setup parallel computing context."""
rank_id_str = os.environ.get("RANK_ID")
self.rank_id = int(rank_id_str) if rank_id_str is not None else None
self.worker_num = int(os.environ.get("MS_WORKER_NUM", "1"))
self.model_comm_pgs = ModelCommProcessGroups.get_default_model_comm_pgs()

if self.rank_id is not None:
init()
initialize_model_parallel(tensor_model_parallel_size=self.config.tensor_parallel)
self.model_comm_pgs = ModelCommProcessGroups.use_parallel_state_groups(required_groups=['tp'])

def _load_config_and_build_model(self):
"""Load YAML configuration and build model specification."""
# Load YAML configuration
config_path = Path(self.config.config_file)
with open(config_path, 'r', encoding='utf-8') as f:
yaml_config = yaml.safe_load(f)

# Extract model configuration
model_config = yaml_config['model_config']
self._set_model_config(model_config)

# Build linear specs from test cases
test_cases = yaml_config['test_cases']
linear_specs, quant_policies = self._build_linear_specs(test_cases)

# Create model spec
self.model_spec = ModelSpec(
compute_dtype=self.compute_dtype,
param_init_dtype=self.param_init_dtype,
tensor_parallel=self.config.tensor_parallel,
linear_specs=linear_specs,
)

# Setup quantization if needed
self.quantization = self.config.quantization
self.quant_model_dir = None
if self.quantization == 'golden-stick':
self.quantizer = NumpyQuantizer(self.model_spec, quant_policies)
self.quant_model_dir = tempfile.mkdtemp(prefix="quant_model_for_test_")

def _set_model_config(self, model_config: dict):
"""Set model configuration from YAML."""
self.input_size = model_config['input_size']
self.output_size = model_config['output_size']
self.head_size = model_config['head_size']
self.total_num_heads = model_config['total_num_heads']
self.total_num_kv_heads = model_config['total_num_kv_heads']
self.compute_dtype = model_config['compute_dtype']
self.param_init_dtype = model_config['param_init_dtype']
self.num_local_experts = model_config['num_local_experts']

def _build_linear_specs(self, test_cases: list) -> Tuple[list, list]:
"""Build linear specs and quant policies from test cases."""
linear_specs = []
quant_policies = []

for case in test_cases:
quant_policy = case['quant_policy']
quant_policy = quant_policy if self.config.quantization == 'golden-stick' else 'float'

linear_spec = LinearSpecFactory.create(
case['linear_type'],
case['has_bias'],
quant_policy,
self
)

linear_specs.append(linear_spec)
quant_policies.append(quant_policy)

return linear_specs, quant_policies

def _create_network(self):
"""Create the network model for testing."""
return SimpleGPTModel(self.model_spec, self.model_comm_pgs, self.quantization, self.quant_model_dir)

def _load_quant_weights(self) -> dict:
"""Load quantized weights from the model directory."""
if not os.path.isdir(self.quant_model_dir):
raise ValueError(f"Invalid quant_model_dir: {self.quant_model_dir}")

safetensor_files = glob.glob(os.path.join(self.quant_model_dir, "*.safetensors"))
if len(safetensor_files) != 1:
raise FileNotFoundError(
f"Expected 1 safetensor file in {self.quant_model_dir}, found {len(safetensor_files)}"
)

safetensor_file = safetensor_files[0]
with safe_open(safetensor_file, framework="np", device="cpu") as f:
return {key: f.get_slice(key) for key in f.keys()}

def _prepare_weights(self, weights: dict) -> dict:
"""Prepare weights for loading (convert to safetensors if needed)."""
first_value = next(iter(weights.values()))

# MoE must use safetensors format
if isinstance(first_value, np.ndarray):
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "model.safetensors")
save_file(weights, path)
weights.clear()
with safe_open(path, framework="np", device="cpu") as f:
weights = {key: f.get_slice(key) for key in f.keys()}

return weights

def run(self):
"""Run the simple MCore test."""
# Generate input and weights
input_data = self._generate_input()
weights = WeightGenerator.generate(self.model_spec)

# Apply quantization if needed
if self.quantization == 'golden-stick':
self.quantizer.quant(input_data, weights, self.quant_model_dir)
weights = self._load_quant_weights()

# Create network and load weights
network = self._create_network()
weights = self._prepare_weights(weights)
WeightLoader.load_into_network(network, weights)

# Process weights after loading (e.g., format conversion for custom ops)
if hasattr(network, 'process_weights_after_loading'):
network.process_weights_after_loading()

# Run inference
net_input = self._create_tensor(input_data)
output_dict = network.forward(net_input)

# Save output
if self.rank_id is None or int(self.rank_id) == 0:
np.savez(self.config.output_path, **output_dict)

def _generate_input(self) -> np.ndarray:
"""Generate random input data."""
np.random.seed(42)
batch_size = 4
return 0.01 * np.random.randn(batch_size, self.model_spec.linear_specs[0].input_size).astype(np.float32)

def _create_tensor(self, data: np.ndarray) -> ms.Tensor:
"""Create a MindSpore tensor with appropriate dtype."""
dtype = self.model_spec.compute_dtype
if isinstance(dtype, str):
dtype = convert_dtype_str_to_ms(dtype)
return ms.Tensor(data, dtype=dtype)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run simple MCore parallel linear test with YAML configuration")
parser.add_argument("--config_file", type=str, required=True,
help="Path to YAML configuration file")
parser.add_argument("--tensor_parallel", type=int, default=1,
help="Tensor parallel size")
parser.add_argument("--output_path", type=str, default="output.npz",
help="Output file path")
parser.add_argument("--quantization", type=str, default=None,
help="Quantization method (e.g., 'golden-stick')")

args = parser.parse_args()

ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": "O0", "infer_boost": "on"},
deterministic="ON")

runner = SimpleMCoreRunner(args)
runner.run()
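
Reviewer note: as a quick reference for the key handling above, here is a minimal standalone sketch of what WeightKeyParser.parse returns for a few checkpoint keys (the pattern table mirrors WEIGHT_PATTERNS, and the keys follow the naming used by WeightGenerator; this is an illustrative sketch, not part of the test suite):

# Minimal sketch of WeightKeyParser.parse; pattern table copied from WEIGHT_PATTERNS above.
PATTERNS = [
    (".gate", "w1", 0, "grouped"),
    (".up", "w3", 0, "grouped"),
    (".gating", "gating", None, "merged"),
    (".hidden", "hidden", None, "merged"),
    (".q.", "q", None, "qkv"),
    (".k.", "k", None, "qkv"),
    (".v.", "v", None, "qkv"),
]

def parse(key):
    for pattern, shard_id, expert_id, layer_type in PATTERNS:
        if pattern in key:
            return key.replace(pattern.rstrip('.'), ''), shard_id, expert_id, layer_type
    return key, None, None, "standard"

print(parse("linears.0.gate.weight"))  # ('linears.0.weight', 'w1', 0, 'grouped')
print(parse("linears.1.q.weight"))     # ('linears.1.weight', 'q', None, 'qkv')
print(parse("linears.2.weight"))       # ('linears.2.weight', None, None, 'standard')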

+ 108
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_configs.yaml View File

@@ -0,0 +1,108 @@
# Test configurations for parallel linear layers quantization
# Each test case is a dictionary with linear_type, has_bias, quant_policy, and optional precision thresholds

# Model configuration parameters
model_config:
input_size: 2048
output_size: 2048
head_size: 16 # changed to 16 to satisfy the 32-alignment requirement of trans_data (16*2*3=96, 96%32=0)
total_num_heads: 2
total_num_kv_heads: 2
compute_dtype: bf16
param_init_dtype: bf16
num_local_experts: 1

# Default precision thresholds (used when not specified in test case)
default_precision:
cos_sim_thd: 0.999
l1_norm_thd: 0.01
kl_dvg_thd: 0.01

# All test cases
test_cases:
# ColumnParallelLinear tests
- linear_type: ColumnParallelLinear
has_bias: true
quant_policy: a8w8
- linear_type: ColumnParallelLinear
has_bias: true
quant_policy: a8dynw8
- linear_type: ColumnParallelLinear
has_bias: false
quant_policy: a8w8
- linear_type: ColumnParallelLinear
has_bias: false
quant_policy: a8dynw8
# RowParallelLinear tests
- linear_type: RowParallelLinear
has_bias: true
quant_policy: a8w8
- linear_type: RowParallelLinear
has_bias: true
quant_policy: a8dynw8
- linear_type: RowParallelLinear
has_bias: false
quant_policy: a8w8
- linear_type: RowParallelLinear
has_bias: false
quant_policy: a8dynw8
# QKVParallelLinear tests
- linear_type: QKVParallelLinear
has_bias: true
quant_policy: a8w8
- linear_type: QKVParallelLinear
has_bias: true
quant_policy: a8dynw8
- linear_type: QKVParallelLinear
has_bias: false
quant_policy: a8w8
# Custom precision for this edge case
precision:
cos_sim_thd: 0.996
l1_norm_thd: 0.01
kl_dvg_thd: 0.01
- linear_type: QKVParallelLinear
has_bias: false
quant_policy: a8dynw8
# MergedColumnParallelLinear tests
- linear_type: MergedColumnParallelLinear
has_bias: true
quant_policy: a8w8
- linear_type: MergedColumnParallelLinear
has_bias: true
quant_policy: a8dynw8
- linear_type: MergedColumnParallelLinear
has_bias: false
quant_policy: a8w8
- linear_type: MergedColumnParallelLinear
has_bias: false
quant_policy: a8dynw8

# ColumnParallelGroupedLinear (MoE) tests
- linear_type: ColumnParallelGroupedLinear
has_bias: null
quant_policy: a8dynw8

- linear_type: ColumnParallelGroupedLinear
has_bias: null
quant_policy: a8w4
# Relaxed thresholds for int4 quantization
precision:
cos_sim_thd: 0.96
l1_norm_thd: 0.05
kl_dvg_thd: 0.01
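
Reviewer note: the per-case precision blocks above override default_precision only for the cases that declare them; all other cases fall back to the defaults. A minimal sketch of that lookup, assuming the file is read the same way test_parallel_linear.py reads it (path relative to the test directory):

import yaml

# Resolve thresholds per test case; mirrors run_test_from_yaml in test_parallel_linear.py.
with open("test_configs.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

default_precision = cfg["default_precision"]
for case in cfg["test_cases"]:
    # Cases without an explicit 'precision' block use the document-level defaults.
    thresholds = case.get("precision", default_precision)
    print(case["linear_type"], case["has_bias"], case["quant_policy"],
          thresholds["cos_sim_thd"], thresholds["l1_norm_thd"], thresholds["kl_dvg_thd"])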

+ 62
- 28
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_parallel_linear.py View File

@@ -15,17 +15,18 @@
"""Test ColumnParallelLinear with various configurations"""


from typing import Optional
from pathlib import Path
import subprocess

import pytest
import numpy as np
import yaml
from tests.utils.precision_utils import PrecisionChecker
from mindformers.tools.logger import logger


def build_msrun_command_list(linear_types, log_dir, run_script_path, output_path_param, tensor_parallel,
port, quantization, quant_policies:Optional[list]=None):
def build_msrun_command_list(log_dir, run_script_path, output_path_param, tensor_parallel,
port, quantization, config_yaml_path):
""" Build the msrun command with the specified parameters. """
if tensor_parallel == 1:
cmd_list = ["python"]
@@ -34,24 +35,19 @@ def build_msrun_command_list(linear_types, log_dir, run_script_path, output_path
"msrun",
f"--worker_num={tensor_parallel}",
f"--local_worker_num={tensor_parallel}",
f"--master_port={port}", # Ensure port is unique per test run if parallelized at pytest level
f"--master_port={port}",
f"--log_dir={log_dir}",
"--join=True",
]

cmd_list += [
str(run_script_path),
f"--config_file={config_yaml_path}",
f"--output_path={output_path_param}",
f"--tensor_parallel={tensor_parallel}",
]
for linear_type in linear_types:
cmd_list.append(f"--linear_types={linear_type}")
for quant_policy in quant_policies:
cmd_list.append(f"--quant_policies={quant_policy}")
if quantization is not None:
cmd_list.append(f"--quantization={quantization}")
if quant_policies is None:
raise RuntimeError("quant_policies must be provided when quantization is enabled.")

logger.info(f"Equivalent shell command for debugging (approximate): {' '.join(cmd_list)}")
return cmd_list
@@ -62,22 +58,25 @@ class TestParallelLinear:
def setup_method(self):
"""Setup method to prepare test environment"""
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_parallel_linear.py"
self.run_script_path = self.sh_path / "simple_mcore.py"
self.log_file_path = self.sh_path / 'test_output' / 'logs'
self.log_file_path.mkdir(parents=True, exist_ok=True)

def infer(self, linear_types, log_dir_path, output_file_path, tensor_parallel, port, quantization,
quant_policies=None):
# Load test configurations from yaml
self.config_file = self.sh_path / "test_configs.yaml"
with open(self.config_file, 'r', encoding='utf-8') as f:
self.configs = yaml.safe_load(f)

def infer(self, log_dir_path, output_file_path, tensor_parallel, port, quantization, config_yaml_path):
"""Run inference with the specified parameters and check for output file."""
cmd_list = build_msrun_command_list(
linear_types=linear_types,
log_dir=log_dir_path,
run_script_path=self.run_script_path,
output_path_param=output_file_path,
tensor_parallel=tensor_parallel,
port=port,
quantization=quantization,
quant_policies=quant_policies,
config_yaml_path=config_yaml_path,
)

result = subprocess.run(
@@ -91,37 +90,75 @@ class TestParallelLinear:
f"Output file {output_file_path} was not created."
)

def run_test(self, linear_types, quant_policies, tmp_path, tensor_parallel=1, port=8118):
"""Helper function to run test and check results"""
def run_test_from_yaml(self, test_cases_key, tmp_path, tensor_parallel=1, port=8118):
"""Run test based on yaml configurations."""
test_cases = self.configs[test_cases_key]
default_precision = self.configs['default_precision']

# Build precision map: key -> (linear_type, has_bias, quant_policy)
precision_map = {}
for case in test_cases:
linear_type = case['linear_type']
has_bias = case['has_bias']
quant_policy = case['quant_policy']
precision = case.get('precision', default_precision)
key = (linear_type, has_bias, quant_policy)
precision_map[key] = precision

# Run quantized inference
output_file_path = tmp_path / 'quant-output.npz'
self.infer(
linear_types=linear_types,
log_dir_path=self.log_file_path,
output_file_path=output_file_path,
tensor_parallel=tensor_parallel,
port=port,
quantization='golden-stick',
quant_policies=quant_policies,
config_yaml_path=self.config_file,
)
quant_output = np.load(output_file_path)

# Run float inference
output_file_path = tmp_path / 'float-output.npz'
self.infer(
linear_types=linear_types,
log_dir_path=self.log_file_path,
output_file_path=output_file_path,
tensor_parallel=tensor_parallel,
port=port+1,
quantization=None,
quant_policies=quant_policies,
config_yaml_path=self.config_file,
)
float_output = np.load(output_file_path)
checker = PrecisionChecker()

# Check precision for each output
succeed = True
for key in quant_output:
fkey = key[:key.rfind('-')] + '-quant_type_float'
if fkey not in float_output:
raise ValueError(f"Diff key in quant_output but not in float_output: {key}")

# Parse key to get linear_type, has_bias, and quant_policy
# key format: index_{index}-{linear_type}-has_bias_{has_bias}-compute_dtype_{dtype}-quant_type_{policy}
parts = key.split('-')
linear_type = parts[1]
has_bias_str = parts[2].split('_')[-1]
if has_bias_str == 'None':
has_bias = None
else:
has_bias = has_bias_str == 'True'
# Extract quant_policy from "quant_type_POLICY" format
quant_policy = parts[-1].split('_', 2)[-1] # Split by '_' and get the last part after 'quant_type_'

# Get precision config for this specific case
config_key = (linear_type, has_bias, quant_policy)
precision = precision_map.get(config_key, default_precision)

# Create checker with appropriate thresholds
checker = PrecisionChecker(
cos_sim_thd=precision['cos_sim_thd'],
l1_norm_thd=precision['l1_norm_thd'],
kl_dvg_thd=precision['kl_dvg_thd']
)

try:
checker.check_precision(float_output[fkey], quant_output[key])
print(f"Check precision for {key} succeed", flush=True)
@@ -133,9 +170,6 @@ class TestParallelLinear:
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_single_card_configurations(self, tmp_path):
"""Test single card with various configurations."""
linear_types = ["ColumnParallelLinear", "RowParallelLinear"]
quant_policies = ["a8w8", "a8dynw8"]
self.run_test(linear_types=linear_types, quant_policies=quant_policies,
tmp_path=tmp_path, port=8888)
def test_parallel_linear_quantization(self, tmp_path):
"""Test parallel linear layers with various configurations from yaml."""
self.run_test_from_yaml('test_cases', tmp_path, tensor_parallel=1, port=8888)
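
Reviewer note: the precision loop above recovers the test-case identity from the output key names. A small worked example of that parsing, using a hypothetical key in the documented index_{index}-{linear_type}-has_bias_{has_bias}-compute_dtype_{dtype}-quant_type_{policy} format:

# Hypothetical output key in the documented format.
key = "index_0-ColumnParallelLinear-has_bias_True-compute_dtype_bf16-quant_type_a8w8"

# Matching float baseline: same key with the trailing quant_type segment replaced.
fkey = key[:key.rfind('-')] + '-quant_type_float'
assert fkey == "index_0-ColumnParallelLinear-has_bias_True-compute_dtype_bf16-quant_type_float"

parts = key.split('-')
linear_type = parts[1]                          # 'ColumnParallelLinear'
has_bias = parts[2].split('_')[-1] == 'True'    # True
quant_policy = parts[-1].split('_', 2)[-1]      # 'a8w8'
precision_key = (linear_type, has_bias, quant_policy)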

+ 0
- 224
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py View File

@@ -1,224 +0,0 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NumpyQuantizer for test."""


import json
import os
import numpy as np
from safetensors.numpy import save_file
from gpt_model_for_test import ModelSpec


class NumpyQuantizer:
"""A class for quantizing model weights using NumPy."""

def __init__(self, model_spec: ModelSpec, quant_policy: list[str]):
self.model_spec = model_spec
self.quant_policy = quant_policy
self.description_file_path = None
self.global_group_size = None

def quant(self, quant_input: np.ndarray, weights, save_dir):
"""Quantize the input and weights, save to safetensors and JSON description."""
quant_weights, quant_desc = self._quant(quant_input, weights)
print(f"quant_weights: {quant_weights.keys()}", flush=True)
print(f"quant_desc: {quant_desc}", flush=True)
save_file(quant_weights, os.path.join(save_dir, 'quant-model-00001-00001.safetensors'))
with open(os.path.join(save_dir, "quantization_description.json"), "w", encoding='utf-8') as f:
json.dump(quant_desc, f, indent=2, ensure_ascii=False)
print(f"quantization weights saved to {save_dir}", flush=True)

def _quant(self, quant_input: np.ndarray, weights):
"""Internal method to perform quantization on weights based on policy."""
quant_weights = {}
quant_desc = {}
for index, (qpolicy, linear_spec) in enumerate(zip(self.quant_policy, self.model_spec.linear_specs)):
if qpolicy == 'a8w8':
weight = weights[f"linears.{index}.weight"]
_, input_scale, input_offset = NumpyQuantizer._act_int8_quant(quant_input)
quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight, transpose_b=linear_spec.transpose_b)
x_zp = input_offset.astype(np.int32) # per-tensor zero-point
quant_bias = -np.sum(x_zp * quant_weight.astype(np.int32), axis=-1).astype(np.int32)
deq_scale = (input_scale.astype(np.float32) * w_scale.astype(np.float32))
beta = np.zeros(linear_spec.output_size, dtype=np.int32)
quant_weights.update({
f"linears.{index}.weight": quant_weight,
f"linears.{index}.deq_scale": deq_scale,
f"linears.{index}.input_scale": np.tile(input_scale, linear_spec.output_size),
f"linears.{index}.input_offset": np.tile(input_offset, linear_spec.output_size),
f"linears.{index}.quant_bias": quant_bias,
f"linears.{index}.beta": beta,
})
quant_desc.update({
f"linears.{index}.weight": "W8A8",
f"linears.{index}.deq_scale": "W8A8",
f"linears.{index}.input_scale": "W8A8",
f"linears.{index}.input_offset": "W8A8",
f"linears.{index}.quant_bias": "W8A8",
f"linears.{index}.beta": "W8A8",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "W8A8"
continue
if qpolicy == 'a8dynw8':
is_grouped = linear_spec.linear_type in ("ColumnParallelGroupedLinear", "RowParallelGroupedLinear")
if not is_grouped:
weight = weights[f"linears.{index}.weight"]
quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight,
transpose_b=linear_spec.transpose_b)
quant_weights.update({
f"linears.{index}.weight": quant_weight,
f"linears.{index}.w_scale": w_scale
})
quant_desc.update({
f"linears.{index}.weight": "W8A8_DYNAMIC",
f"linears.{index}.w_scale": "W8A8_DYNAMIC",
})
else:
quant_weight_gate, w_scale_gate = NumpyQuantizer._weight_int8_quant(
weights[f"linears.{index}.gate.weight"], transpose_b=True)
quant_weight_up, w_scale_up = NumpyQuantizer._weight_int8_quant(
weights[f"linears.{index}.up.weight"], transpose_b=True)
quant_weights.update({
f"linears.{index}.gate.weight": quant_weight_gate,
f"linears.{index}.gate.w_scale": w_scale_gate,
f"linears.{index}.up.weight": quant_weight_up,
f"linears.{index}.up.w_scale": w_scale_up,
})
quant_desc.update({
f"linears.{index}.weight": "W8A8_DYNAMIC",
f"linears.{index}.w_scale": "W8A8_DYNAMIC",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "W8A8_DYNAMIC"
continue
if qpolicy == 'a8w4':
group_size = 256
self.global_group_size = group_size
is_grouped = linear_spec.linear_type in ("ColumnParallelGroupedLinear", "RowParallelGroupedLinear")
if not is_grouped:
raise ValueError("a8w4 quantization only support grouped linear")
qweight_packed_gate, w_scale_uint64_gate = NumpyQuantizer._weight_int4_per_group_pack(
weights[f"linears.{index}.gate.weight"], group_size, transpose_b=True)
qweight_packed_up, w_scale_uint64_up = NumpyQuantizer._weight_int4_per_group_pack(
weights[f"linears.{index}.up.weight"], group_size, transpose_b=True)
quant_weights.update({
f"linears.{index}.gate.weight": qweight_packed_gate,
f"linears.{index}.gate.w_scale": w_scale_uint64_gate,
f"linears.{index}.up.weight": qweight_packed_up,
f"linears.{index}.up.w_scale": w_scale_uint64_up,
})
quant_desc.update({
f"linears.{index}.weight": "W4A8_DYNAMIC",
f"linears.{index}.w_scale": "W4A8_DYNAMIC",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "W4A8_DYNAMIC"
continue
if qpolicy is None:
weight = weights[f"linears.{index}.weight"]
quant_weights.update({
f"linears.{index}.weight": weight,
})
quant_desc.update({
f"linears.{index}.weight": "FLOAT",
})
if linear_spec.has_bias:
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"]
quant_desc[f"linears.{index}.bias"] = "FLOAT"
continue
raise ValueError(f"Unsupported quant policy: {qpolicy}")
if self.global_group_size is not None:
quant_desc["group_size"] = int(self.global_group_size)
return quant_weights, quant_desc

@staticmethod
def _get_quant_min_max(num_bits=8, signed=True, narrow_range=False):
"""Calculate quantization params for minimum/maximum quantization integer"""
if signed:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
return quant_min, quant_max

@staticmethod
def _act_int8_quant(tensor):
"""Quantize activation tensor to int8."""
bits=8
quant_min, quant_max = NumpyQuantizer._get_quant_min_max(bits)

min_val = np.min(tensor)
max_val = np.max(tensor)

if (max_val == min_val).all():
scale = np.array([1.0], dtype=np.float32)
zero_point = np.array([0.0], dtype=np.float32)
else:
min_val = min_val.astype(np.float64)
max_val = max_val.astype(np.float64)
scale = (max_val - min_val) / (quant_max - quant_min)
zero_point = quant_min - min_val / scale.astype(np.float32)
scale = scale.astype(np.float32)

quantized = np.round(tensor / scale + zero_point)
quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8)

return quantized, scale, zero_point

@staticmethod
def _weight_int8_quant(tensor, transpose_b=True):
"""Quantize weight tensor to int8."""
bits=8
quant_min, quant_max = NumpyQuantizer._get_quant_min_max(bits)
oc_axis = 0 if transpose_b else 1
ic_axis = 1 if transpose_b else 0
oc = tensor.shape[oc_axis]
min_val = np.min(tensor, axis=ic_axis, keepdims=True)
max_val = np.max(tensor, axis=ic_axis, keepdims=True)
if (max_val == min_val).all():
scale = np.ones((oc,), dtype=np.float32)
else:
min_val = min_val.astype(np.float64)
max_val = max_val.astype(np.float64)
max_val = np.maximum(np.abs(min_val), np.abs(max_val))
min_val = -max_val
scale = ((max_val - min_val) / (quant_max - quant_min)).astype(np.float32)

quantized = np.round(tensor / scale)
quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8)
scale = np.squeeze(scale)
return quantized, scale

@staticmethod
def _weight_int4_per_group_pack(tensor, group_size, transpose_b=True):
"""weight_int4_per_group_pack."""
if transpose_b:
oc, ic = tensor.shape[0], tensor.shape[1]
else:
ic, oc = tensor.shape[0], tensor.shape[1]
q = np.empty((oc//2,ic), dtype=np.int8)
scale = np.empty((oc,ic//group_size), dtype=np.float32)
scale_uint64 = scale.astype(np.float32).view(np.uint32).astype(np.uint64)
scale_uint64 = scale_uint64.reshape(scale.shape)
packed = q
return packed, scale_uint64
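
Reviewer note: the a8w8 branch of the removed quantizer derives deq_scale and quant_bias from the standard asymmetric-activation / symmetric-weight identity: with x ≈ input_scale * (q_x - input_offset) and W ≈ w_scale * q_w per output channel, x @ W.T ≈ deq_scale * (q_x @ q_w.T + quant_bias). A small NumPy check of that algebra, using arbitrary scales rather than the min/max calibration above (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((16, 8)).astype(np.float32)   # transpose_b layout: (oc, ic)

# Arbitrary (uncalibrated) quantization parameters, chosen only to exercise the identity.
input_scale, input_offset = np.float32(0.02), np.int32(3)
q_x = np.clip(np.round(x / input_scale + input_offset), -128, 127).astype(np.int8)
w_scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q_w = np.clip(np.round(w / w_scale), -128, 127).astype(np.int8)

quant_bias = -np.sum(input_offset * q_w.astype(np.int32), axis=-1)   # per output channel
deq_scale = input_scale * np.squeeze(w_scale)

dequant_ref = (input_scale * (q_x.astype(np.float32) - input_offset)) @ (
    np.squeeze(w_scale) * q_w.T.astype(np.float32))
int_path = deq_scale * (q_x.astype(np.int32) @ q_w.astype(np.int32).T + quant_bias)
assert np.allclose(int_path, dequant_ref, rtol=1e-3, atol=1e-3)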

+ 0
- 242
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py View File

@@ -1,242 +0,0 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run ColumnParallelLinear accuracy test with configurable parameters via args"""


import argparse
import glob
import os
import tempfile
import numpy as np
from safetensors import safe_open
from safetensors.numpy import save_file

import mindspore as ms
from mindspore.communication import init
from numpy_quantizer import NumpyQuantizer
from gpt_model_for_test import GPTModelForTest, LinearSpec, ModelSpec, QKVLinearSpec, GroupLinearSpec
from mindformers.parallel_core.inference.parallel_state import initialize_model_parallel
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups


class ParallelModelRunner:
"""Runner for parallel model testing with quantization support."""

def __init__(self, config):
"""Initialize the parallel model runner with given arguments."""
self.config = config
# set up parallel context
rank_id_str = os.environ.get("RANK_ID")
self.rank_id = int(rank_id_str) if rank_id_str is not None else None
self.worker_num = int(os.environ.get("MS_WORKER_NUM", "1"))
self.model_comm_pgs = ModelCommProcessGroups.get_default_model_comm_pgs()
if self.rank_id is not None:
init()
initialize_model_parallel(tensor_model_parallel_size=self.config.tensor_parallel)
self.model_comm_pgs = ModelCommProcessGroups.use_parallel_state_groups(required_groups=['tp'])

linear_specs = []
quant_policys = []
self.quantization = config.quantization
for linear_type in config.linear_types:
for has_bias in [True, False]:
for quant_policy in config.quant_policies:
quant_policy = quant_policy if config.quantization == 'golden-stick' else 'float'
if linear_type=="QKVParallelLinear":
linear_specs.append(QKVLinearSpec(linear_type, config.input_size, config.head_size,
config.total_num_heads,config.total_num_kv_heads,
has_bias, config.compute_dtype, quant_policy))
elif linear_type=="ColumnParallelGroupedLinear":
linear_specs.append(GroupLinearSpec(linear_type, config.num_local_experts,config.input_size,
config.output_size,
quant_policy))
else:
linear_specs.append(LinearSpec(linear_type, config.input_size, config.output_size,
has_bias, config.compute_dtype, quant_policy))
quant_policys.append(quant_policy)

self.model_spec = ModelSpec(
compute_dtype=config.compute_dtype,
param_init_dtype=config.param_init_dtype,
tensor_parallel=config.tensor_parallel,
linear_specs=linear_specs,
)
self.quant_model_dir = None
if self.quantization == 'golden-stick':
self.quantizer = NumpyQuantizer(self.model_spec, quant_policys)
self.quant_model_dir = tempfile.mkdtemp(prefix="quant_model_for_test_")

@staticmethod
def _gen_float_weights(model_spec):
"""Generate random float weights for model specifications."""
np.random.seed(42)
weights = {}
for index, linear_spec in enumerate(model_spec.linear_specs):
if linear_spec.linear_type=="QKVParallelLinear":
#qkv
weight_shapes = [(linear_spec.output_sizes[0], linear_spec.input_size),
(linear_spec.output_sizes[1], linear_spec.input_size),
(linear_spec.output_sizes[2], linear_spec.input_size)]
output_size = linear_spec.output_size
qkv_map = {0:"q",1:"k",2:"v"}
for shared_id,weight_shape in enumerate(weight_shapes):
weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32)
weights[f"linears.{index}.{qkv_map[shared_id]}.weight"] = weight
if linear_spec.has_bias:
for shared_id,weight_shape in enumerate(weight_shapes):
bias = 0.01 * np.random.randn(weight_shape[0]).astype(np.float32)
weights[f"linears.{index}.{qkv_map[shared_id]}.bias"]= bias
elif linear_spec.linear_type=="ColumnParallelGroupedLinear":
# gate,up
weight_shapes = [(linear_spec.output_size//2,linear_spec.input_size),
(linear_spec.output_size//2,linear_spec.input_size)]
output_size = linear_spec.output_size
gate_up_map = {0:"gate",1:"up"}
for shared_id,weight_shape in enumerate(weight_shapes):
weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32)
weights[f"linears.{index}.{gate_up_map[shared_id]}.weight"]=weight
else:
weight_shape = (linear_spec.output_size, linear_spec.input_size)
output_size = linear_spec.output_size
weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32)
weights[f"linears.{index}.weight"] = weight
if linear_spec.has_bias:
bias = 0.01 * np.random.randn(output_size).astype(np.float32)
weights[f"linears.{index}.bias"] = bias
return weights

@staticmethod
def _gen_input(model_spec):
"""Generate random input data for model specifications."""
np.random.seed(42)
return 0.01 * np.random.randn(2 * 2, model_spec.linear_specs[0].input_size).astype(np.float32)

def _create_network(self):
"""Create the network model for testing."""
return GPTModelForTest(self.model_spec, self.model_comm_pgs, self.quantization, self.quant_model_dir)

def _load_quant_weights(self):
"""Load quantized weights from the model directory."""
if not os.path.isdir(self.quant_model_dir):
raise ValueError(f"Invalid quant_model_dir: {self.quant_model_dir}")
safetensor_files = glob.glob(os.path.join(self.quant_model_dir, "*.safetensors"))
if len(safetensor_files) == 1:
safetensor_file = safetensor_files[0]
elif len(safetensor_files) > 1:
raise FileNotFoundError(f"Found multiple safetensor files in {self.quant_model_dir}")
else:
raise FileNotFoundError(f"Found no safetensor file in {self.quant_model_dir}")
if not os.path.exists(safetensor_file):
raise FileNotFoundError(f"File {safetensor_file} not found.")
with safe_open(safetensor_file, framework="np", device="cpu") as f:
weights = {}
for key in f.keys():
weights[key] = f.get_slice(key)
return weights

@staticmethod
def load_weights_into_network(network, weights):
"""Load weights into the network parameters."""
params = network.parameters_dict()
print(params)
loaded = []
for k, v in weights.items():
shard_id = None
expert_id = None
original_key = k
if ".gate" in k or ".q." in k:
k = k.replace(".gate","")
k = k.replace(".q","")
expert_id = 0
shard_id = "w1" # For ColumnParallelGroupedLinear, use "w1" for gate weights
if ".up" in k or ".k." in k:
k = k.replace(".up","")
k = k.replace(".k","")
shard_id = "w3" # For ColumnParallelGroupedLinear, use "w3" for up weights
if expert_id is None:
expert_id = 0
if ".v." in k:
k = k.replace(".v","")
shard_id = 2
expert_id = None
param = params.get(k)
if param is None:
continue
loaded.append(original_key) # Track original key, not transformed key
if shard_id is not None:
if expert_id is not None:
param.weight_loader(param, v,shard_id,expert_id)
else:
param.weight_loader(param, v,shard_id)
else:
param.weight_loader(param, v)


print(f"weights not use: {set(weights.keys()) - set(loaded)}", flush=True)
print(f"params not load: {set(params.keys()) - set(loaded)}", flush=True)

def run(self):
"""Run the parallel model test."""
input_data = ParallelModelRunner._gen_input(self.model_spec)
weights = ParallelModelRunner._gen_float_weights(self.model_spec)
if self.quantization == 'golden-stick':
self.quantizer.quant(input_data, weights, self.quant_model_dir)
weights = self._load_quant_weights()
network = self._create_network()
first_value = next(iter(weights.values()))
# Moe must input safetensors
if isinstance(first_value, np.ndarray):
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "model.safetensors")
save_file(weights, path)
weights.clear()
with safe_open(path, framework="np", device="cpu") as f:
for key in f.keys():
weights[key] = f.get_slice(key)
ParallelModelRunner.load_weights_into_network(network, weights)
net_input = ms.Tensor(input_data, dtype=LinearSpec.convert_pt_dtype_to_ms(self.model_spec.compute_dtype))
output_dict = network.forward(net_input)

if self.rank_id is None or int(self.rank_id) == 0:
np.savez(self.config.output_path, **output_dict)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run ColumnParallelLinear test")
parser.add_argument("--linear_types", type=str, action='append', default=None,
help="List of linear types, e.g., --linear_types ColumnParallelLinear "\
"--linear_types RowParallelLinear")
parser.add_argument("--tensor_parallel", type=int, default=1)
parser.add_argument("--head_size", type=int, default=10)
parser.add_argument("--total_num_heads", type=int, default=2)
parser.add_argument("--total_num_kv_heads", type=int, default=2)
parser.add_argument("--compute_dtype", type=str, default='bf16')
parser.add_argument("--param_init_dtype", type=str, default='bf16')
parser.add_argument("--num_local_experts", type=int, default=1)
parser.add_argument("--output_path", type=str, default="output.npz")
parser.add_argument("--quantization", type=str, default=None)
parser.add_argument("--quant_policies", type=str, action='append', default=None,
help="List of quantization policies, e.g., --quant_policies a8w8 --quant_policies a8dynw8")
args = parser.parse_args()
args.input_size = 2048
args.output_size = 2048

ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": "O0", "infer_boost": "on"},
deterministic="ON")

quant_runner = ParallelModelRunner(args)
quant_runner.run()

+ 0
- 142
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/test_parallel_linear.py View File

@@ -1,142 +0,0 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test ColumnParallelLinear with various configurations"""


from typing import Optional
from pathlib import Path
import subprocess
import pytest
import numpy as np
from tests.utils.precision_utils import PrecisionChecker
from mindformers.tools.logger import logger


def build_msrun_command_list(linear_types, log_dir, run_script_path, output_path_param, tensor_parallel,
port, quantization, quant_policies:Optional[list]=None):
""" Build the msrun command with the specified parameters. """
if tensor_parallel == 1:
cmd_list = ["python"]
else:
cmd_list = [
"msrun",
f"--worker_num={tensor_parallel}",
f"--local_worker_num={tensor_parallel}",
f"--master_port={port}", # Ensure port is unique per test run if parallelized at pytest level
f"--log_dir={log_dir}",
"--join=True",
]

cmd_list += [
str(run_script_path),
f"--output_path={output_path_param}",
f"--tensor_parallel={tensor_parallel}",
]
for linear_type in linear_types:
cmd_list.append(f"--linear_types={linear_type}")
for quant_policy in quant_policies:
cmd_list.append(f"--quant_policies={quant_policy}")
if quantization is not None:
cmd_list.append(f"--quantization={quantization}")
if quant_policies is None:
raise RuntimeError("quant_policies must be provided when quantization is enabled.")

logger.info(f"Equivalent shell command for debugging (approximate): {' '.join(cmd_list)}")
return cmd_list


class TestParallelLinear:
"""Test class for ParallelLinear with different configurations"""
def setup_method(self):
"""Setup method to prepare test environment"""
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_parallel_linear.py"
self.log_file_path = self.sh_path / 'test_output' / 'logs'
self.log_file_path.mkdir(parents=True, exist_ok=True)

def infer(self, linear_types, log_dir_path, output_file_path, tensor_parallel, port, quantization,
quant_policies=None):
"""Run inference with the specified parameters and check for output file."""
cmd_list = build_msrun_command_list(
linear_types=linear_types,
log_dir=log_dir_path,
run_script_path=self.run_script_path,
output_path_param=output_file_path,
tensor_parallel=tensor_parallel,
port=port,
quantization=quantization,
quant_policies=quant_policies,
)

result = subprocess.run(
cmd_list, shell=False, capture_output=True, text=True, check=False)

assert result.returncode == 0, (
f"Test script failed with non-zero exit code: "
f"{result.returncode}.\nStdout:\n{result.stdout}\nStderr:\n{result.stderr}"
)
assert output_file_path.exists(), (
f"Output file {output_file_path} was not created."
)

def run_test(self, linear_types, quant_policies, tmp_path, tensor_parallel=1, port=8118):
"""Helper function to run test and check results"""
output_file_path = tmp_path / 'quant-output.npz'
self.infer(
linear_types=linear_types,
log_dir_path=self.log_file_path,
output_file_path=output_file_path,
tensor_parallel=tensor_parallel,
port=port,
quantization='golden-stick',
quant_policies=quant_policies,
)
quant_output = np.load(output_file_path)

output_file_path = tmp_path / 'float-output.npz'
self.infer(
linear_types=linear_types,
log_dir_path=self.log_file_path,
output_file_path=output_file_path,
tensor_parallel=tensor_parallel,
port=port+1,
quantization=None,
quant_policies=quant_policies,
)
float_output = np.load(output_file_path)
checker = PrecisionChecker()
succeed = True
for key in quant_output:
fkey = key[:key.rfind('-')] + '-quant_type_float'
if fkey not in float_output:
raise ValueError(f"Diff key in quant_output but not in float_output: {key}")
try:
checker.check_precision(float_output[fkey], quant_output[key])
print(f"Check precision for {key} succeed", flush=True)
except AssertionError as e:
print(f"Check precision for {key} failed: {e}", flush=True)
succeed = False
succeed = True
assert succeed, "Some precision check failed"

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_single_card_moe_configurations(self, tmp_path):
"""Test single card with various configurations."""
linear_types = ["ColumnParallelGroupedLinear"]
quant_policies = ["a8w4","a8dynw8"]
self.run_test(linear_types=linear_types, quant_policies=quant_policies,
tmp_path=tmp_path, port=8888)
