42 Commits

Author SHA1 Message Date
i-robot f271c58cd2 !7786 Add modeling_utils test cases 1 week ago
i-robot cb3083aab9 !7810 [bugfix][master] Fix test case bugs 1 week ago
i-robot 656b143f71 !7807 Fix test_pma case timeout and remove unused cases from test_all_reduce 1 week ago
i-robot 4e95ab6f51 !7808 Fix failing build cases in the blended_megatron_dataset_builder tests 1 week ago
i-robot 1738940a3d !7800 [MindFormers][coverage] Improve low mindformers test coverage: add parallel decoding and streaming inference cases 1 week ago
lanxiang 0d55e8959d Add modeling_utils test cases 1 week ago
lanxiang 695e7a7e2e Fix test_pma case timeout and remove unused cases from test_all_reduce 1 week ago
i-robot 3d81fa1cb1 !7812 [master][bugfix] Fix TeleChat capitalization in docs 1 week ago
senzhen a9530fbbe2 Fix TeleChat capitalization in docs 1 week ago
i-robot fa9210dabd !7799 [master] Improve inference test coverage 1 week ago
i-robot 3988632900 !7763 [UT] Add test cases 1 week ago
Yule100 a5baae49c0 Add inference test cases to improve coverage 1 week ago
i-robot 4b1e9b9814 !7768 add pipeline and metric ut_test 1 week ago
Hsshuai c7937097f1 fix testcase of get_last_checkpoint 1 week ago
i-robot 773328361c !7779 [master][bugfix] Fix documentation spelling 1 week ago
i-robot 646ebbe674 !7809 [bugfix][master] Documentation fixes 1 week ago
yiyison 514f58b821 Documentation fixes 1 week ago
i-robot 82f8c9545e !7804 [master][bugfix] Fix accuracy issue in the Muon optimizer when tp != op 1 week ago
i-robot 394342b40e !7793 [bugfix] fix bs>1 in hf dataloader tnd. 1 week ago
i-robot 7119deb050 !7801 [master] Add transform_checkpoint_utils.py test cases 1 week ago
zzzkeke 550c063fd8 Fix failing build cases in the blended_megatron_dataset_builder tests 1 week ago
i-robot 3619a5e8d4 !7806 [bugfix][master] Documentation fixes 1 week ago
yiyison fb8be3937e Documentation fixes 1 week ago
i-robot 48d81f9ccc !7748 Fix the error of the_build_context 1 week ago
i-robot ec99b68fd4 !7795 [master] Add checkpoint/utils test cases 1 week ago
i-robot 20ea6cb185 !7750 [master][bugfix] Improve test coverage 1 week ago
yiyison 3b78fd99dd Add transform_checkpoint_utils.py test cases 2 weeks ago
i-robot 520113827b !7798 [master][UT] Add test cases for weight_utils, logger, and version_control 1 week ago
i-robot 3d7747a89e !7709 stable_rank_fix 1 week ago
i-robot d9d612874e !7794 [master] Add fully_parallel test cases 1 week ago
JavaZero 6f25245ef7 fix: adjust chunking logic in _slice_tensor_to_shards for tensor distribution 1 week ago
yiyison 158fbb74db Add transform_checkpoint_utils.py test cases 1 week ago
pengjingyou 3d61efd427 [master] Improve inference test coverage 1 week ago
qinsichun 93eb56dd28 test_conv 1 week ago
yiyison 4c68ae07ef fully_parallel test cases 2 weeks ago
zxq 26492c13fa [UT] Add test cases 2 weeks ago
zxq 285388d5f3 [master][UT] Add test cases for weight_utils, logger, and version_control 1 week ago
宋佳琪 da79b68ee6 stable_rank_fix 3 weeks ago
senzhen a044bdbc35 Fix documentation spelling 2 weeks ago
JingweiHuang 9ca643d51f Fix the error of the_build_context 2 weeks ago
李宜杰 b01eff2a5d add pipeline and metric ut_test 2 weeks ago
niujunhao 32f6442f82 fix bs>1 in hf dataloader tnd. 2 weeks ago
58 changed files with 8162 additions and 163 deletions
1. +14 -14 configs/glm4/README.md
2. +9 -9 configs/glm4_moe/README.md
3. +2 -2 docs/model_cards/glm4.md
4. +0 -1 docs/transformer仓Python编程规范.md
5. +32 -17 mindformers/core/callback/callback.py
6. +8 -11 mindformers/core/context/build_context.py
7. +2 -2 mindformers/core/optim/muon.py
8. +6 -0 mindformers/dataset/causal_language_model_dataset.py
9. +1 -1 mindformers/parallel_core/transformer_config.py
10. +6 -1 mindformers/tools/ckpt_transform/transform_checkpoint.py
11. +1 -1 research/deepseek3/README.md
12. +2 -2 research/qwen2_5/README.md
13. +16 -23 research/telechat2/README.md
14. +1 -1 tests/st/test_multi_cards_cases/test_optimizer/test_pma/test_pma.py
15. +14 -0 tests/st/test_run_check.py
16. +311 -16 tests/st/test_safetensors/test_checkpoint_utils.py
17. +644 -0 tests/st/test_ut/test_checkpoint/test_fully_parallel.py
18. +0 -24 tests/st/test_ut/test_core/test_callback/test_all_reduce.py
19. +43 -35 tests/st/test_ut/test_core/test_context/test_build_context.py
20. +3 -1 tests/st/test_ut/test_dataset/test_dataloader/test_blended_megatron_dataset_builder.py
21. +41 -0 tests/st/test_ut/test_generation/qwen3_0_6b_infer.yaml
22. +286 -0 tests/st/test_ut/test_generation/test_parallel_decoding.py
23. +56 -0 tests/st/test_ut/test_generation/test_streamer.py
24. +282 -1 tests/st/test_ut/test_metrics.py
25. +168 -0 tests/st/test_ut/test_mindformer_book.py
26. +264 -0 tests/st/test_ut/test_models/test_auto/test_configuration_auto.py
27. +86 -0 tests/st/test_ut/test_models/test_auto/test_utils.py
28. +0 -0 tests/st/test_ut/test_models/test_glm4/__init__.py
29. +47 -0 tests/st/test_ut/test_models/test_glm4/test_configuration_glm4.py
30. +46 -0 tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py
31. +0 -0 tests/st/test_ut/test_models/test_glm4_moe/__init__.py
32. +48 -0 tests/st/test_ut/test_models/test_glm4_moe/test_configuration_glm4_moe.py
33. +46 -0 tests/st/test_ut/test_models/test_glm4_moe/test_modeling_glm4_moe.py
34. +1216 -0 tests/st/test_ut/test_models/test_modeling_utils.py
35. +302 -0 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py
36. +0 -0 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/__init__.py
37. +174 -0 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/test_infer_mapping.py
38. +0 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/__init__.py
39. +92 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/test_infer_fused_softmax.py
40. +0 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/__init__.py
41. +70 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/test_infer_lower_triangular_mask.py
42. +0 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/__init__.py
43. +114 -0 tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/test_infer_moe_utils.py
44. +0 -0 tests/st/test_ut/test_parallel_core/test_inference/test_utils/__init__.py
45. +331 -0 tests/st/test_ut/test_parallel_core/test_inference/test_utils/test_utils.py
46. +234 -0 tests/st/test_ut/test_parallel_core/test_inference/test_weights_utils.py
47. +283 -0 tests/st/test_ut/test_pipeline/test_base_pipeline.py
48. +405 -0 tests/st/test_ut/test_pipeline/test_pipeline.py
49. +176 -0 tests/st/test_ut/test_pipeline/test_pipeline_registry.py
50. +234 -0 tests/st/test_ut/test_pipeline/test_registry_constant.py
51. +173 -0 tests/st/test_ut/test_tools/test_generic.py
52. +147 -0 tests/st/test_ut/test_tools/test_logger.py
53. +160 -0 tests/st/test_ut/test_tools/test_register/test_config.py
54. +1330 -0 tests/st/test_ut/test_tools/test_transform_checkpoint.py
55. +3 -0 tests/st/test_ut/test_trainer/test_trainer_methods.py
56. +104 -0 tests/st/test_ut/test_utils/test_convert_utils.py
57. +128 -0 tests/st/test_ut/test_version_control.py
58. +1 -1 toolkit/safetensors/README.md

+14 -14 configs/glm4/README.md

@@ -1,4 +1,4 @@
# Glm4
# GLM-4

## 模型描述

@@ -8,8 +8,8 @@ GLM-4 系列模型是专为智能代理设计的基础模型, 其性能可与Ope

| 模型名称 | 规格 | 支持任务 | 模型架构 | 支持设备 | 模型级别 |
|:----------:|:---------:|:------:|:-----:|:-------------------------------------------------:|:-------------------:|
|GLM4-32B | 32B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM4-9B | 9B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM-4-32B | 32B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM-4-9B | 9B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |

说明:

@@ -18,15 +18,15 @@ GLM-4 系列模型是专为智能代理设计的基础模型, 其性能可与Ope

## 版本配套

GLM4 当前支持的版本配套如下。
GLM-4 当前支持的版本配套如下。

| | Mindspore Transformers | MindSpore | CANN | HDK |
| | MindSpore Transformers | MindSpore | CANN | HDK |
|:---------:|:----------------------:|:---------:|:----:|:---:|
| 当前支持的版本 | 在研版本 | 在研版本 | 在研版本 | 在研版本 |

## 使用样例

MindSpore Transformers 支持使用 GLM4 进行推理。各任务的整体使用流程如下:
MindSpore Transformers 支持使用 GLM-4 进行推理。各任务的整体使用流程如下:

| 任务 | 前期准备 | 使用流程 |
|:---:|:------------------------|:---------------------------|
@@ -69,7 +69,7 @@ parallel_config:
- pretrained_model_dir:Hugging Face模型目录路径,放置模型配置、Tokenizer等文件。`/path/hf_dir`中的内容如下:

```text
📂GLM4
📂GLM-4
├── 📄config.json
├── 📄generation_config.json
├── 📄merges.txt
@@ -192,11 +192,11 @@ Glm4的模型文件包括以下内容:

```text
📦glm4
├── 📄__init__.py # glm4模块初始化文件
├── 📄configuration_glm4.py # glm4模型配置类定义
├── 📄modeling_glm4.py # glm4模型主体实现
├── 📄modeling_glm4_infer.py # glm4推理模型实现
└── 📄utils.py # glm4工具函数和基础类
├── 📄__init__.py # GLM-4模块初始化文件
├── 📄configuration_glm4.py # GLM-4模型配置类定义
├── 📄modeling_glm4.py # GLM-4模型主体实现
├── 📄modeling_glm4_infer.py # GLM-4推理模型实现
└── 📄utils.py # GLM-4工具函数和基础类
```

### 并行配置建议
@@ -218,7 +218,7 @@ Glm4的模型文件包括以下内容:
<th>模型级别</th>
</tr>
<tr>
<td>GLM4-32B</td>
<td>GLM-4-32B</td>
<td>32B</td>
<td>1 × Atlas 800T A2 (2P)</td>
<td>2</td>
@@ -235,7 +235,7 @@ Glm4的模型文件包括以下内容:
<td> Validated </td>
</tr>
<tr>
<td>GLM4-9B</td>
<td>GLM-4-9B</td>
<td>9B</td>
<td>1 × Atlas 800T A2 (1P)</td>
<td>1</td>


+9 -9 configs/glm4_moe/README.md

@@ -2,7 +2,7 @@

## 模型描述

GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM4采用了MoE结构的变体,也标记为GLM4-MoE。GLM-4.5 总参数 3550 亿,激活参数 320 亿,而 GLM-4.5-Air 采用更紧凑的设计,总参数 1060 亿,激活参数 120 亿。GLM-4.5模型统一了推理、编码和智能体能力,满足智能体应用的复杂需求。
GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM-4采用了MoE结构的变体,也标记为GLM-4-MoE。GLM-4.5 总参数 3550 亿,激活参数 320 亿,而 GLM-4.5-Air 采用更紧凑的设计,总参数 1060 亿,激活参数 120 亿。GLM-4.5模型统一了推理、编码和智能体能力,满足智能体应用的复杂需求。
具体模型能力查看以下技术报告:[GLM-4.5: Reasoning, Coding, and Agentic Abililties](https://z.ai/blog/glm-4.5)

## 支持规格
@@ -21,7 +21,7 @@ GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM4采

GLM-4.5 当前支持的版本配套如下。

| | Mindspore Transformers | MindSpore | CANN | HDK |
| | MindSpore Transformers | MindSpore | CANN | HDK |
|:---------:|:----------------------:|:---------:|:----:|:---:|
| 当前支持的版本 | 在研版本 | 在研版本 | 在研版本 | 在研版本 |

@@ -70,7 +70,7 @@ parallel_config:
- pretrained_model_dir:Hugging Face模型目录路径,放置模型配置、Tokenizer等文件。`/path/hf_dir`中的内容如下:

```text
📂Glm4.5
📂GLM-4.5
├── 📄config.json
├── 📄generation_config.json
├── 📄merges.txt
@@ -190,15 +190,15 @@ bash scripts/msrun_launcher.sh "run_mindformer.py \

### 模型文件说明

glm4_moe的模型文件包括以下内容:
GLM-4-MoE的模型文件包括以下内容:

```text
📦glm4_moe
├── 📄__init__.py # glm4_moe模块初始化文件
├── 📄configuration_glm4_moe.py # glm4_moe模型配置类定义
├── 📄modeling_glm4_moe.py # glm4_moe模型主体实现
├── 📄modeling_glm4_moe_infer.py # glm4_moe推理模型实现
└── 📄utils.py # glm4_moe工具函数和基础类
├── 📄__init__.py # GLM-4-MoE模块初始化文件
├── 📄configuration_glm4_moe.py # GLM-4-MoE模型配置类定义
├── 📄modeling_glm4_moe.py # GLM-4-MoE模型主体实现
├── 📄modeling_glm4_moe_infer.py # GLM-4-MoE推理模型实现
└── 📄utils.py # GLM-4-MoE工具函数和基础类
```

### 并行配置建议


+2 -2 docs/model_cards/glm4.md

@@ -74,7 +74,7 @@ MindSpore Transformers 提供 `alpaca` 数据集示例处理脚本制作[全参

| 数据集名称 | 适用模型 | 适用阶段 | 下载链接 |
|:-------------|:-------:|:--------:|:------------------------------------------------------------------------------------------:|
| alpaca | glm4-9b | Finetune | [Link](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) |
| alpaca | GLM-4-9B | Finetune | [Link](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) |

数据预处理中所用的 `tokenizer.model` 可以参考[模型权重下载](#模型权重下载)进行下载。

@@ -203,7 +203,7 @@ bash scripts/examples/glm4/run_glm4_predict.sh PARALLEL CONFIG_PATH CKPT_PATH TO

| 参数名 | 含义 | 取值说明 |
|-------------|---------------------------|---------------------------------------------------------------------------------------------|
| PARALLEL | 指定选择推理模式为单卡推理 or 多卡推理。 | (str, 必选) - 单卡推理配置为 `single` ,多卡推理配置为 `parallel` 。 |
| PARALLEL | 指定选择推理模式为单卡推理或多卡推理。 | (str, 必选) - 单卡推理配置为 `single` ,多卡推理配置为 `parallel` 。 |
| CONFIG_PATH | 模型配置文件路径。 | (str, 必选) - 如 `/path/to/glm4/predict_glm4_9b_chat.yaml` 。 |
| CKPT_PATH | 推理时用到的模型权重文件路径。 | (str, 必选) - 单卡为完整权重,双卡为分布式权重。<br>如单卡推理 `/path/to/glm4.ckpt`,多卡推理 `/path/to/glm4_ckpt_dir` 。 |
| TOKENIZER | GLM-4 模型的 tokenizer 文件路径。 | (str, 必选) - 如 `/path/to/tokenizer.model` 。 |


+0 -1 docs/transformer仓Python编程规范.md

@@ -3,7 +3,6 @@
本规范以[PEP8](https://www.python.org/dev/peps/pep-0008/)为基础,参考华为Python通用编码规范、安全编程规范,并结合业界共识整理而成,参与MindSpore社区开发需要首先遵循本规范内容(与PEP8冲突部分),其余遵循PEP8规范。

如果对规则有异议,建议提交 issue 并说明理由,经MindSpore社区运营团队评审接纳后可修改生效。
a

## 适用范围



+32 -17 mindformers/core/callback/callback.py

@@ -246,6 +246,11 @@ def _get_optimizer_state(optim_params, filter_fn: Callable = None):
return norms


def _is_positive_natural_number(x):
"""Check if it is a positive natural number"""
return isinstance(x, int) and x > 0


@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class MFLossMonitor(Callback):
"""
@@ -771,10 +776,9 @@ class TrainingStateMonitor(Callback):
embedding_size: int = 4096,
use_local_norm: bool = False):
super().__init__()
if not (isinstance(step_interval, int) and step_interval > 0):
logger.warning("The value of 'monitor_config.step_interval' should be positive integer, "
f"but get {step_interval}. Use default value: 1.")
step_interval = 1
if not _is_positive_natural_number(step_interval):
raise TypeError("The value of 'monitor_config.step_interval' should be positive integer, "
f"but get {step_interval}.")
self.step_interval = step_interval
self.last_print_time = 0
self.step_time = time.time()
@@ -808,7 +812,7 @@ class TrainingStateMonitor(Callback):
self.device_local_norm_pattern = re.compile('(device_local_norm)_[a-z]+[0-9]+_([0-9]+)')
# when pipeline_stages > 2, param aggregation is not supported for now
pp_parallel = context.get_auto_parallel_context("pipeline_stages") > 1
if pp_parallel and self.sr_format:
if pp_parallel and self.sr_format and self.do_aggregation:
raise TypeError("When pipeline_stages > 1, weight aggregation is not supported")

def on_train_epoch_begin(self, run_context):
@@ -1076,15 +1080,37 @@ class TrainingStateMonitor(Callback):
if hasattr(config.get('stable_rank_config'), "get"):
self.sr_format = config.get('stable_rank_config').get('format', None)
self.sr_step_interval = config.get('stable_rank_config').get('step_interval', 100)
if not _is_positive_natural_number(self.sr_step_interval):
raise TypeError("'monitor_config.stable_rank_config.step_interval' should be positive integer,"
f"but get {self.sr_step_interval}.")
self.sr_last_print_time = 0
self.sr_target = config.get('stable_rank_config').get('target') or ['.*']
self.sr_target_cache = {}
self.do_aggregation = config.get('stable_rank_config').get('do_aggregation', False)
self.moe_show_mode = config.get('stable_rank_config').get('moe_show_mode') or ["all"]
self.power_iteration_num = config.get('stable_rank_config').get('power_iteration_num', 5)
if not _is_positive_natural_number(self.power_iteration_num):
raise TypeError("'monitor_config.stable_rank_config.power_iteration_num' should be positive integer,"
f"but get {self.power_iteration_num}.")
else:
self.sr_format = None

def _init_global_norm_monitor_config(self, config):
"""Initialize global norm monitor config"""
self.check_for_global_norm = config.get('check_for_global_norm')
self.global_norm_record_path = os.path.join(get_output_root_path(), "abnormal_global_norm.json")
self.global_norm_spike_threshold = config.get('global_norm_spike_threshold')
self.global_norm_spike_count_threshold = config.get('global_norm_spike_count_threshold', 10)
if not _is_positive_natural_number(self.global_norm_spike_count_threshold):
raise TypeError("'monitor_config.global_norm_spike_count_threshold' should be positive integer, "
f"but get {self.global_norm_spike_count_threshold}.")
self.abnormal_global_norms: dict[str, list[float]] = {}
if self.global_norm_record_path and os.path.exists(self.global_norm_record_path):
# the data format might be like {"300": [3.3], "600": [4.1, 4.2],}
# because json cannot use number as key, we convert it to string
with open(self.global_norm_record_path, 'r', encoding="utf-8") as file:
self.abnormal_global_norms = json.load(file)

def _init_config(self, config):
"""Initialize members from config"""
if config is None:
@@ -1106,15 +1132,9 @@ class TrainingStateMonitor(Callback):
self.optimizer_state_format = config.get('optimizer_state_format', None)
self.weight_state_format = config.get('weight_state_format', None)
self._init_stable_rank_config(config)
self._init_global_norm_monitor_config(config)
self.throughput_baseline = config.get('throughput_baseline', None)
self.print_struct = config.get('print_struct')

self.check_for_global_norm = config.get('check_for_global_norm')
self.global_norm_record_path = os.path.join(get_output_root_path(), "abnormal_global_norm.json")
self.global_norm_spike_threshold = config.get('global_norm_spike_threshold')
self.global_norm_spike_count_threshold = config.get('global_norm_spike_count_threshold')
self.abnormal_global_norms: dict[str, list[float]] = {}

if self.print_struct is None:
self.print_struct = False
if not (isinstance(self.target, list) and self.target and all(isinstance(i, str) for i in self.target)):
@@ -1130,11 +1150,6 @@ class TrainingStateMonitor(Callback):
'optimizer_state_format', 'weight_state_format', 'max_attention_logit_format']
for attr in attrs:
self._check_attr_formats(attr)
if self.global_norm_record_path and os.path.exists(self.global_norm_record_path):
# the data format might be like {"300": [3.3], "600": [4.1, 4.2],}
# because json cannot use number as key, we convert it to string
with open(self.global_norm_record_path, 'r', encoding="utf-8") as file:
self.abnormal_global_norms = json.load(file)

def _print_stable_rank(self, name, param, cur_step_num):
"""output stable rank and max eigenvalues"""


+8 -11 mindformers/core/context/build_context.py

@@ -18,11 +18,11 @@ import os
from dataclasses import dataclass
from typing import Union

import mindspore as ms
import mindspore.dataset as ds
import mindspore
import mindspore.context as ms_context
import psutil

import mindformers
from mindformers.core.config_args import (
ContextConfig,
MFContextConfig,
@@ -36,10 +36,7 @@ from mindformers.tools.utils import (
MODE,
get_output_subpath,
check_in_dynamic_cluster,
get_real_local_rank,
get_real_group_size
)
from mindformers.utils import get_cann_workqueue_cores
from mindformers.version_control import (
check_tft_valid,
set_ms_deterministic
@@ -156,7 +153,7 @@ class MSContextOperator:
kernel_launch_group[key] = val
thread_num = int(kernel_launch_group.get('thread_num', 2))
kernel_group_num = int(kernel_launch_group.get('kernel_group_num', 8))
ms.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)
mindspore.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)

def _set_device_id(self, ctx, ms_ctx):
if self.config.use_parallel and check_in_dynamic_cluster():
@@ -388,7 +385,7 @@ def set_ms_affinity(affinity_config, affinity_cpu_list):

if affinity_config:
# Check if any device_X in affinity_config has X >= device_num
max_device_id = get_real_group_size() - 1
max_device_id = mindformers.tools.utils.get_real_group_size() - 1
for key in affinity_config:
try:
x = int(key.split('_')[1])
@@ -398,7 +395,7 @@ def set_ms_affinity(affinity_config, affinity_cpu_list):
if x > max_device_id:
raise ValueError(f"Invalid device id {x} in affinity_config. "
f"Maximum allowed device id is {max_device_id}.")
device_id = get_real_local_rank()
device_id = mindformers.tools.utils.get_real_local_rank()
device_config = affinity_config.get(f'device_{device_id}', None)
if device_config:
affinity_cpu_list = device_config.get('affinity_cpu_list', None)
@@ -409,7 +406,7 @@ def set_ms_affinity(affinity_config, affinity_cpu_list):
else:
module_to_cpu_dict = None

ms.runtime.set_cpu_affinity(
mindspore.runtime.set_cpu_affinity(
True,
affinity_cpu_list,
module_to_cpu_dict
@@ -420,14 +417,14 @@ def set_cpu_affinity(rank_id, rank_size):
"""cpu affinity"""
use_cpu_affinity = os.environ.get('CPU_AFFINITY')
if use_cpu_affinity and use_cpu_affinity.lower() in ('1', 'true'):
ds.config.set_numa_enable(True)
mindspore.dataset.config.set_numa_enable(True)
count = psutil.cpu_count()
current_process = psutil.Process()
used_cpus_num = count // rank_size
used_cpus = list(
range(rank_id * used_cpus_num, (rank_id + 1) * used_cpus_num)
)
cann_used_cpus = get_cann_workqueue_cores(rank_id)
cann_used_cpus = mindformers.utils.get_cann_workqueue_cores(rank_id)
logger.info(f"cann workqueue cpus: {cann_used_cpus}")
used_cpus = list(set(used_cpus) - set(cann_used_cpus))
if not used_cpus:
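
Two editorial notes on the change above: replacing the bare `ms`/`ds` aliases and directly imported helpers with fully qualified `mindspore.*` / `mindformers.*` references presumably lets the new unit tests patch those symbols at their module path, and the per-rank CPU binding in `set_cpu_affinity` reduces to simple integer arithmetic. A rough, stand-alone illustration of that arithmetic (all numbers are invented; the real code reads them from psutil and the CANN workqueue query):

```python
cpu_count, rank_size, rank_id = 192, 8, 3        # hypothetical host: 192 CPUs, 8 ranks
used_cpus_num = cpu_count // rank_size           # 24 CPUs per rank
used_cpus = list(range(rank_id * used_cpus_num, (rank_id + 1) * used_cpus_num))
cann_used_cpus = {72, 73}                        # pretend CANN workqueue cores for this rank
used_cpus = sorted(set(used_cpus) - cann_used_cpus)
print(used_cpus)                                 # rank 3 binds CPUs 74..95; 72 and 73 are excluded
```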


+2 -2 mindformers/core/optim/muon.py

@@ -98,8 +98,8 @@ def _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group):
x = Chunk()(x, tp, tp_dim)[chunk_id]

if op > 1:
if op_group == tp_group:
chunk_id = rank_id % tp
if tp_dim == -1:
chunk_id = rank_id % op
else:
chunk_id = rank_id // tp % op
x = Chunk()(x, op)[chunk_id]
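
A hedged sketch of the shard-index arithmetic after the fix above, where `tp` is the tensor-parallel size, `op` the optimizer-parallel size, and `tp_dim == -1` marks a tensor that is not split along a tensor-parallel axis (the helper name and example values are illustrative only):

```python
def op_chunk_id(rank_id: int, tp: int, op: int, tp_dim: int) -> int:
    """Pick the optimizer-parallel chunk for a rank, following the fixed branch."""
    if tp_dim == -1:
        return rank_id % op
    return rank_id // tp % op


# With tp=2, op=4 and a tp-sliced tensor, ranks 0..7 map to chunks [0, 0, 1, 1, 2, 2, 3, 3]:
print([op_chunk_id(r, tp=2, op=4, tp_dim=0) for r in range(8)])
```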


+6 -0 mindformers/dataset/causal_language_model_dataset.py

@@ -38,9 +38,15 @@ CAST_TO_INT_COLUMNS = ["input_ids", "labels"]


def _use_compressed_eod_mask(data_loader):
"""
Determine whether the given data loader should use a compressed EOD (End-Of-Document) mask.
"""
if (hasattr(data_loader, 'config') and data_loader.config and
data_loader.config.create_compressed_eod_mask): # megatron dataset
return True
if (hasattr(data_loader, 'create_compressed_eod_mask') and
data_loader.create_compressed_eod_mask):
return True
if (hasattr(data_loader, 'adaptor_config') and data_loader.adaptor_config and
data_loader.adaptor_config.compress_mask): # common dataloader
return True
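
For reference, the three data-loader shapes that make `_use_compressed_eod_mask` return True can be mimicked with simple stand-ins; the attribute names come from the diff, but the objects below are fabricated purely for illustration:

```python
from types import SimpleNamespace

# Megatron-style loader: the flag lives on loader.config
megatron_loader = SimpleNamespace(config=SimpleNamespace(create_compressed_eod_mask=True))
# Loader exposing the flag directly as an attribute
plain_loader = SimpleNamespace(create_compressed_eod_mask=True)
# Common dataloader: the flag lives on loader.adaptor_config.compress_mask
common_loader = SimpleNamespace(adaptor_config=SimpleNamespace(compress_mask=True))
# _use_compressed_eod_mask(...) would return True for each of these shapes.
```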


+1 -1 mindformers/parallel_core/transformer_config.py

@@ -650,7 +650,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
"When using moe_dry_run, moe_token_dispatcher_type must be 'alltoall' or 'alltoall_deredundency'."
)

if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute"]:
if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute", "partial_rope"]:
raise ValueError(
f"The current value of position_embedding_type is {self.position_embedding_type},"
" but position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute'."


+6 -1 mindformers/tools/ckpt_transform/transform_checkpoint.py

@@ -653,7 +653,8 @@ class TransformCkpt:
else:
break

if __name__ == '__main__':

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--src_checkpoint',
default="",
@@ -707,3 +708,7 @@ if __name__ == '__main__':
)

print("......Transform finished!......")


if __name__ == '__main__':
main()
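
A short note on the refactor above: with the argparse body moved into `main()` and only the `if __name__ == '__main__'` guard left at module level, the CLI can be driven in-process, for example from a test that patches `sys.argv`. The sketch below is a hypothetical invocation, not part of the PR, and the path argument is a placeholder:

```python
import sys
from unittest.mock import patch

from mindformers.tools.ckpt_transform.transform_checkpoint import main

# Hypothetical in-process invocation; the checkpoint path is a placeholder.
with patch.object(sys, "argv", ["transform_checkpoint.py",
                                "--src_checkpoint", "/path/to/src_ckpt"]):
    main()
```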

+1 -1 research/deepseek3/README.md

@@ -2,7 +2,7 @@

## 模型描述

DeepSeek-V3是由DeepSeek(深度求索)推出的一个强大的专家混合(MoE)语言模型,它拥有671B总参数,其中激活参数量为37B。为了实现高效推理和低成本训练,DeepSeek-V3采用了多头潜注意力(MLA)和DeepSeekMoE架构,这在DeepSeek-V2中得到了充分验证。此外,DeepSeek-V3 还率先采用了无辅助损失的负载均衡策略,并设定了多token预测训练目标,以提高性能。DeepSeek-V3在14.8万亿个多种类的高质量token上进行预训练,接着通过监督微调和强化学习充分优化其能力。综合评估显示,在发布时DeepSeek-V3的性能优于其他开源模型,并可与领先的闭源模型相媲美。尽管性能卓越,DeepSeek-V3 的全部训练成本非常低,且其训练过程也非常稳定。
DeepSeek-V3是由DeepSeek(深度求索)推出的一个强大的专家混合(MoE)语言模型,它拥有671B总参数,其中激活参数量为37B。为了实现高效推理和低成本训练,DeepSeek-V3采用了多头潜注意力(MLA)和DeepSeekMoE架构,这在DeepSeek-V2中得到了充分验证。此外,DeepSeek-V3 还率先采用了无辅助损失的负载均衡策略,并设定了多token预测训练目标,以提高性能。DeepSeek-V3在14.8万亿个多种类的高质量token上进行预训练,接着通过监督微调和强化学习充分优化其能力。综合评估显示,在发布时DeepSeek-V3的性能优于其他开源模型,并可与领先的闭源模型相媲美。尽管性能卓越,DeepSeek-V3 的全部训练成本非常低,且其训练过程也非常稳定。

```text
@misc{deepseekai2024deepseekv3technicalreport,


+2 -2 research/qwen2_5/README.md

@@ -205,7 +205,7 @@ mindspore_ckpt_path: qkv_concat转换后权重文件保存路径,单卡权重
1. 当前支持模型已提供推理相关配置文件,请根据实际使用模型更改配置文件。
2. 运行下面的代码需要在`mindformers/`目录下,或者先将`mindformers/`目录所在路径加入到`PYTHONPATH`环境变量中。

``qwen2_5-7b` 8卡微调为例,执行如下命令进行微调。
`qwen2_5-7b` 8卡微调为例,执行如下命令进行微调。

1. 主要参数配置参考:

@@ -236,7 +236,7 @@ mindspore_ckpt_path: qkv_concat转换后权重文件保存路径,单卡权重
tokenizer:
model_max_length: 32768
vocab_file: "./path/vocab.json" # 参考qwen2_5-7b官网下载的词表
merges_file: "./path/merges.txt" # # 参考qwen2_5-7b官网下载的merge文件
merges_file: "./path/merges.txt" # 参考qwen2_5-7b官网下载的merge文件
#callbacks config
callbacks:
- type: CheckpointMonitor


+16 -23 research/telechat2/README.md

@@ -30,33 +30,33 @@

以下模型性能均由Atlas 800T A2硬件环境下测试得出。

TeleChat2-7b:
TeleChat2-7B:

| config | task | Datasets | SeqLength | phase | performance |
|:---------------------------------------------------:| :-------------------: |:----------:|:---------:|:---------------:|:------------:|
| [TeleChat2_7b](./telechat2-7b/finetune_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 2950 tokens/s/p |
| [TeleChat2_7b](./telechat2-7b/predict_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 54.1 tokens/s |
| [TeleChat2_7B](./telechat2-7b/finetune_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 2950 tokens/s/p |
| [TeleChat2_7B](./telechat2-7b/predict_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 54.1 tokens/s |

TeleChat2-35b:
TeleChat2-35B:

| config | task | Datasets | SeqLength | phase | performance |
|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------|
| [TeleChat2_35b](./telechat2-35b/finetune_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 516 tokens/s/p |
| [TeleChat2_35b](./telechat2-35b/predict_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 27.7 tokens/s |
| [TeleChat2_35B](./telechat2-35b/finetune_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 516 tokens/s/p |
| [TeleChat2_35B](./telechat2-35b/predict_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 27.7 tokens/s |

TeleChat2-115b:
TeleChat2-115B:

| config | task | Datasets | SeqLength | phase | performance |
|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------|
| [TeleChat2_115b](./telechat2-115b/finetune_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 158 tokens/s/p |
| [TeleChat2_115b](./telechat2-115b/predict_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 26.5 tokens/s |
| [TeleChat2_115B](./telechat2-115b/finetune_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 158 tokens/s/p |
| [TeleChat2_115B](./telechat2-115b/predict_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 26.5 tokens/s |

TeleChat2-39b-a12b:
TeleChat2-39B-A12B:

| config | task | Datasets | SeqLength | phase | performance |
| ------------------------------------------------------------ | --------------- | --------------- | --------- | ---------------- | ------------- |
| [TeleChat2_39b_a12b](./telechat2-39b-a12b/finetune_telechat_39b_a12b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 865 tokens/s/p |
| [TeleChat2_39b_a12b](./telechat2-39b-a12b/predict_telechat_39b_a12b_parallel.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 36.4 tokens/s |
| [TeleChat2_39B_A12B](./telechat2-39b-a12b/finetune_telechat_39b_a12b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 865 tokens/s/p |
| [TeleChat2_39B_A12B](./telechat2-39b-a12b/predict_telechat_39b_a12b_parallel.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 36.4 tokens/s |

## 模型文件

@@ -149,10 +149,10 @@ MindFormers提供已经转换完成的预训练权重、词表文件用于预训

1.torch模型权重及词模型下载链接:

- [TeleChat2-7b](https://modelscope.cn/models/TeleAI/TeleChat2-7B-32K)
- [TeleChat2-7B](https://modelscope.cn/models/TeleAI/TeleChat2-7B-32K)
- [TeleChat2-39B-A12B](https://modelscope.cn/models/TeleAI/TeleChat2-39B-A12B)
- [TeleChat2-35b](https://modelscope.cn/models/TeleAI/TeleChat2-35B)
- [TeleChat2-115b](https://modelscope.cn/models/TeleAI/TeleChat2-115B)
- [TeleChat2-35B](https://modelscope.cn/models/TeleAI/TeleChat2-35B)
- [TeleChat2-115B](https://modelscope.cn/models/TeleAI/TeleChat2-115B)

下载完成后,运行如下转换脚本,将全量微调的权重转换为完整的ckpt权重。

@@ -168,13 +168,6 @@ torch_path: torch版本权重保存目录路径
mindspore_path: 权重保存文件名,可以指定自定义保存路径
```

2.获取MindFormers提供的已转换权重,可直接从下面的链接获取。

- [TeleChat2-7b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_7B/Telechat_7B.zip)
- [TeleChat2-35b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_35B/Telechat_35B.zip)
- [TeleChat2-115b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_115B/Telechat_115B.zip)
- [Telechat2-39b-a12b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_39B_A12.tar):仅适用于8卡推理,使用方式请参考[Telechat2-39B-A12B推理](#Telechat2-39B-A12B推理)章节。

### 分布式权重切分与合并

分布式训练/微调后所得到的权重文件为根据策略切分后的权重,需要手动将切分权重合一,以用于评估和推理。
@@ -226,7 +219,7 @@ MindFormers提供`TeleChat2-115B`的微调示例,过程中使用中电信人
- step 2. 根据服务器节点数等信息,修改相应的配置。

```yaml
# 以telechat-115b模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。
# 以telechat-115B模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。
# 配置文件路径:finetune_telechat_115b.yaml
parallel_config:
data_parallel: 1


+1 -1 tests/st/test_multi_cards_cases/test_optimizer/test_pma/test_pma.py

@@ -25,7 +25,7 @@ from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 0
_LEVEL_1_TASK_TIME = 124
_LEVEL_1_TASK_TIME = 436
_TASK_TYPE = TaskType.FOUR_CARDS_TASK




+14 -0 tests/st/test_run_check.py

@@ -1,3 +1,17 @@
# Copyright 2025 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for run_check function"""
import pytest
from mindformers import run_check


+311 -16 tests/st/test_safetensors/test_checkpoint_utils.py

@@ -14,6 +14,8 @@
# ============================================================================
"""test for load_checkpoint_utils."""
# pylint: disable=W0621
import os
import json
import tempfile
from unittest.mock import patch, MagicMock

@@ -22,7 +24,11 @@ import numpy as np
from mindspore import Parameter

from mindformers.tools.register import MindFormerConfig
from mindformers.checkpoint.utils import compile_model

from mindformers.checkpoint.utils import compile_model, check_checkpoints_dir_max_num, get_checkpoint_iter_dir, \
get_checkpoint_tracker_filename, get_common_filename, get_metadata_filename, \
get_latest_iteration_from_tracker, get_checkpoint_name, get_sharded_tensor_shard_id, \
sharded_tensor_shard_id, _reverse_sharded_tensor_shard_id, _get_shard_size, verify_ckpt_valid, FileType
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.utils.load_checkpoint_utils import (
CkptFormat, _get_checkpoint_mode, CheckpointFileMode, _check_checkpoint_path,
@@ -96,6 +102,8 @@ def mock_file():
return mock_f




class TestCommonCheckpointMethod:
"""A test class for testing common methods"""

@@ -805,13 +813,12 @@ class TestCommonCheckpointMethod:
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()


class TestBuildModel:
"""A test class for testing build_model"""
runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1}
config = {
'runner_config': runner_config,
'context': {'mode': 0} # Add context.mode to fix AttributeError
'context': {'mode': 0} # 0 is typically ms.GRAPH_MODE, 1 is ms.PYNATIVE_MODE
}
model = MagicMock()
dataset = MagicMock()
@@ -854,18 +861,16 @@ class TestBuildModel:
"""test build model infer predict layout when do predict is true"""
mock_get_auto_parallel_context.return_value = 'auto_parallel'
config = MindFormerConfig(**self.config)
model = MagicMock()
dataset = MagicMock()
compile_model(
model=model,
dataset=dataset,
model=self.model,
dataset=self.dataset,
mode=config.context.mode,
sink_mode=config.runner_config.sink_mode,
epoch=config.runner_config.epochs,
sink_size=config.runner_config.sink_size,
do_eval=False, do_predict=True
)
model.infer_predict_layout.assert_called_once_with(*dataset)
self.model.infer_predict_layout.assert_called_once_with(*self.dataset)

@patch('mindspore.context.get_auto_parallel_context')
def test_build_model_model_build(self, mock_get_auto_parallel_context):
@@ -890,15 +895,15 @@ class TestGetCheckpointMode:
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_single_checkpoint_file(self):
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_single_checkpoint_file(self, mock_isdir, mock_isfile):
"""test single checkpoint file"""
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = True
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_file.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
mock_isfile.return_value = True
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_file.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@@ -958,3 +963,293 @@ class TestGetCheckpointMode:
config.load_ckpt_format = '.safetensors'
with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
_get_checkpoint_mode(config)


class TestCheckpointUtils:
"""A test class for testing checkpoint utils functions"""

# Test get_checkpoint_iter_dir function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_iter_dir(self):
"""test get_checkpoint_iter_dir function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_checkpoint_iter_dir(checkpoints_path, iteration)
# Use os.path.normpath to handle different path separators
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123')

# Test with different iteration format
iteration = 1000
result = get_checkpoint_iter_dir(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00001000')

# Test get_checkpoint_tracker_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_tracker_filename(self):
"""test get_checkpoint_tracker_filename function"""
checkpoints_path = '/test/checkpoints'
result = get_checkpoint_tracker_filename(checkpoints_path)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/latest_checkpointed_iteration.txt')

# Test get_common_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_common_filename(self):
"""test get_common_filename function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_common_filename(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123/common.json')

# Test get_metadata_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_metadata_filename(self):
"""test get_metadata_filename function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_metadata_filename(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123/metadata.json')

# Test get_checkpoint_name function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_name(self):
"""test get_checkpoint_name function"""
# Test with cur_iter_checkpoint_dir and user_prefix
cur_iter_checkpoint_dir = '/test/checkpoints/iteration_00000123'
user_prefix = 'model'
file_idx = 0
total_file_num = 1
file_type = FileType.MODEL
result = get_checkpoint_name(cur_iter_checkpoint_dir, user_prefix, file_idx, total_file_num, file_type)
expected = '/test/checkpoints/iteration_00000123/model-model-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test with optimizer type
file_type = FileType.OPTIMIZER
result = get_checkpoint_name(cur_iter_checkpoint_dir, user_prefix, file_idx, total_file_num, file_type)
expected = '/test/checkpoints/iteration_00000123/model-opt-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test without user_prefix
result = get_checkpoint_name(cur_iter_checkpoint_dir, None, file_idx, total_file_num, FileType.MODEL)
expected = '/test/checkpoints/iteration_00000123/model-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test without cur_iter_checkpoint_dir
result = get_checkpoint_name(None, None, file_idx, total_file_num, FileType.MODEL)
expected = 'model-0000000-0000001'
assert result == expected

# Test sharded tensor related functions
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sharded_tensor_shard_id_functions(self):
"""test sharded tensor shard id functions"""
param_name = 'model.layer.weight'
global_offset = (100, 200)

# Test get_sharded_tensor_shard_id
shard_id1 = get_sharded_tensor_shard_id(param_name, global_offset)
expected = "('model.layer.weight', (100, 200))"
assert shard_id1 == expected

# Test sharded_tensor_shard_id (duplicate function)
shard_id2 = sharded_tensor_shard_id(param_name, global_offset)
assert shard_id2 == expected
assert shard_id1 == shard_id2

# Test _reverse_sharded_tensor_shard_id function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reverse_sharded_tensor_shard_id(self):
"""test _reverse_sharded_tensor_shard_id function"""
# Test normal case
shard_id = "('model.layer.weight', (100, 200))"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert global_offset == (100, 200)

# Test with empty offset
shard_id = "('model.layer.weight', ())"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert not global_offset

# Test with single element offset
shard_id = "('model.layer.weight', (50,))"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert global_offset == (50,)

# Test invalid shard id
with pytest.raises(ValueError):
_reverse_sharded_tensor_shard_id("invalid_shard_id")

# Test _get_shard_size function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_shard_size(self):
"""test _get_shard_size function"""
# Test with float32 type (32 bits per float32)
local_shape = (100, 200)
dtype = 'Float32'
expected = 100 * 200 * 32 # 100*200 elements * 32 bits each
assert _get_shard_size(local_shape, dtype) == expected

# Test with int8 type (8 bits per int8)
dtype = 'Int8'
expected = 100 * 200 * 8 # 100*200 elements * 8 bits each
assert _get_shard_size(local_shape, dtype) == expected

# Test with unknown dtype (should default to 16 bits)
dtype = 'UnknownType'
expected = 100 * 200 * 16 # 100*200 elements * 16 bits default
assert _get_shard_size(local_shape, dtype) == expected

# Test with empty shape
local_shape = ()
dtype = 'Float32'
expected = 1 * 32 # scalar, 32 bits
assert _get_shard_size(local_shape, dtype) == expected

# Test verify_ckpt_valid function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_verify_ckpt_valid(self, tmp_path):
"""test verify_ckpt_valid function"""
# Test with valid directory containing safetensors file
ckpt_dir = tmp_path / "valid_ckpt"
ckpt_dir.mkdir()
safetensor_file = ckpt_dir / "model.safetensors"
safetensor_file.touch()
assert verify_ckpt_valid(str(ckpt_dir)) is None

# Test with valid directory containing metadata and safetensors file
ckpt_dir = tmp_path / "valid_ckpt_with_metadata"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
metadata_content = {
"storage_data": {
"param1": [{
"file_name": "model.safetensors"
}]
}
}

with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata_content, f)
safetensor_file = ckpt_dir / "model.safetensors"
safetensor_file.touch()
assert verify_ckpt_valid(str(ckpt_dir)) is None

# Test with invalid directory (not exists)
with pytest.raises(NotADirectoryError):
verify_ckpt_valid("/non/existent/directory")

# Test with directory containing no files
ckpt_dir = tmp_path / "empty_ckpt"
ckpt_dir.mkdir()
with pytest.raises(FileNotFoundError):
verify_ckpt_valid(str(ckpt_dir))

# Test with metadata referencing missing safetensors file
ckpt_dir = tmp_path / "invalid_metadata_ckpt"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
metadata_content = {
"storage_data": {
"param1": [{
"file_name": "missing.safetensors"
}]
}
}
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata_content, f)
with pytest.raises(FileNotFoundError):
verify_ckpt_valid(str(ckpt_dir))

# Test with invalid metadata json
ckpt_dir = tmp_path / "invalid_json_ckpt"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
with open(metadata_path, 'w', encoding='utf-8') as f:
f.write("invalid json content")
with pytest.raises(RuntimeError):
verify_ckpt_valid(str(ckpt_dir))

# Test check_checkpoints_dir_max_num function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoints_dir_max_num(self, tmp_path):
"""test check_checkpoints_dir_max_num function"""
# Create test directory structure
checkpoints_root_path = tmp_path / "checkpoints"
checkpoints_root_path.mkdir()

# Create more directories than max_keep_num
for i in range(5):
dir_path = checkpoints_root_path / f"iteration_{i:08d}"
dir_path.mkdir()

# Test with max_keep_num = 3, should keep newest 3 directories
check_checkpoints_dir_max_num(3, str(checkpoints_root_path))

# Verify only 3 directories remain
remaining_dirs = list(checkpoints_root_path.iterdir())
remaining_dirs.sort()
assert len(remaining_dirs) == 3
assert [d.name for d in remaining_dirs] == ["iteration_00000002", "iteration_00000003", "iteration_00000004"]

# Test with max_keep_num larger than existing directories
check_checkpoints_dir_max_num(10, str(checkpoints_root_path))
remaining_dirs = list(checkpoints_root_path.iterdir())
assert len(remaining_dirs) == 3

# Test get_latest_iteration_from_tracker function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_latest_iteration_from_tracker(self, tmp_path):
"""test get_latest_iteration_from_tracker function"""
checkpoints_path = tmp_path / "checkpoints"
checkpoints_path.mkdir()

# Create tracker file
tracker_path = checkpoints_path / "latest_checkpointed_iteration.txt"
tracker_path.write_text("123", encoding="utf-8")

# Create corresponding directory
iter_dir = checkpoints_path / "iteration_00000123"
iter_dir.mkdir()

# Test normal case
assert get_latest_iteration_from_tracker(str(checkpoints_path)) == 123

# Test with missing tracker file
tracker_path.unlink()
with pytest.raises(FileNotFoundError):
get_latest_iteration_from_tracker(str(checkpoints_path))

# Test with invalid iteration number in tracker file
tracker_path.write_text("invalid_iter", encoding="utf-8")
with pytest.raises(ValueError):
get_latest_iteration_from_tracker(str(checkpoints_path))

# Test with missing iteration directory
tracker_path.write_text("456", encoding="utf-8")
with pytest.raises(FileNotFoundError):
get_latest_iteration_from_tracker(str(checkpoints_path))
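
The `_get_shard_size` expectations in the new tests can be re-derived in a few lines. This is an illustrative helper, not the library implementation; the per-dtype bit widths and the 16-bit fallback are read off the assertions above:

```python
from math import prod

BITS_PER_DTYPE = {"Float32": 32, "Int8": 8}  # unknown dtypes fall back to 16 bits


def shard_size_bits(local_shape, dtype):
    """Shard size in bits: element count (prod(()) == 1 for scalars) times bits per element."""
    return prod(local_shape) * BITS_PER_DTYPE.get(dtype, 16)


assert shard_size_bits((100, 200), "Float32") == 100 * 200 * 32
assert shard_size_bits((100, 200), "UnknownType") == 100 * 200 * 16
assert shard_size_bits((), "Float32") == 32
```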

+644 -0 tests/st/test_ut/test_checkpoint/test_fully_parallel.py

@@ -0,0 +1,644 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for fully_parallel.py"""
# pylint: disable=W0621, W0212, W0613
import os
from unittest.mock import patch, MagicMock

import pytest
from mindspore import nn

from mindformers.checkpoint.utils import FileType
from mindformers.checkpoint.fully_parallel import (
BalancedSaveStrategy,
distribute_shards,
apply_balance_shard_strategy
)


class MockShardTensor:
"""Mock ShardTensor class for testing"""

def __init__(self, key, global_offset, local_shape, dtype, size=100):
self.key = key
self.global_offset = global_offset
self.local_shape = local_shape
self.dtype = dtype
self.size = size


@pytest.fixture
def mock_network():
"""Create a mock network for testing"""
network = MagicMock(spec=nn.Cell)
return network


@pytest.fixture
def mock_get_all_sharded_tensor():
"""Mock get_all_sharded_tensor function"""
mock_shard_tensor1 = MockShardTensor("param1", (0,), (10,), "float32")
mock_shard_tensor2 = MockShardTensor("param2", (10,), (10,), "float32")
mock_shard_tensor3 = MockShardTensor("param3", (0,), (10,), "float32")

with patch("mindformers.checkpoint.fully_parallel.get_all_sharded_tensor") as mock:
mock.return_value = [
[mock_shard_tensor1, mock_shard_tensor2],
[mock_shard_tensor3]
]
yield mock


@pytest.fixture
def mock_get_rank():
"""Mock get_rank function"""
with patch("mindformers.checkpoint.fully_parallel.get_rank") as mock:
mock.return_value = 0
yield mock


@pytest.fixture
def mock_get_real_local_rank():
"""Mock get_real_local_rank function"""
with patch("mindformers.checkpoint.fully_parallel.get_real_local_rank") as mock:
mock.return_value = 0
yield mock


@pytest.fixture
def mock_save_checkpoint():
"""Mock save_checkpoint function"""
with patch("mindformers.checkpoint.fully_parallel.save_checkpoint") as mock:
yield mock


@pytest.fixture
def mock_get_metadata_filename():
"""Mock get_metadata_filename function"""
with patch("mindformers.checkpoint.fully_parallel.get_metadata_filename") as mock:
mock.return_value = "metadata.json"
yield mock


@pytest.fixture
def mock_get_checkpoint_name():
"""Mock get_checkpoint_name function"""
with patch("mindformers.checkpoint.fully_parallel.get_checkpoint_name") as mock:
mock.return_value = "checkpoint_0-2"
yield mock


@pytest.fixture
def mock_get_checkpoint_iter_dir():
"""Mock get_checkpoint_iter_dir function"""
with patch("mindformers.checkpoint.fully_parallel.get_checkpoint_iter_dir") as mock:
mock.return_value = "./checkpoint_iter_0"
yield mock


@pytest.fixture
def mock_save_metadata():
"""Mock save_metadata function"""
with patch("mindformers.checkpoint.fully_parallel.save_metadata") as mock:
yield mock


@pytest.fixture
def mock_load_metadata():
"""Mock load_metadata function"""
with patch("mindformers.checkpoint.fully_parallel.load_metadata") as mock:
mock.return_value = ({}, {})
yield mock


@pytest.fixture
def mock_reverse_sharded_tensor_shard_id():
"""Mock _reverse_sharded_tensor_shard_id function"""
with patch("mindformers.checkpoint.fully_parallel._reverse_sharded_tensor_shard_id") as mock:
mock.return_value = "param1"
yield mock


@pytest.fixture
def mock_sharded_tensor_shard_id():
"""Mock sharded_tensor_shard_id function"""
with patch("mindformers.checkpoint.fully_parallel.sharded_tensor_shard_id") as mock:
mock.side_effect = lambda key, offset: f"{key}_{offset}"
yield mock


@pytest.fixture
def mock_get_shard_size():
"""Mock _get_shard_size function"""
with patch("mindformers.checkpoint.fully_parallel._get_shard_size") as mock:
mock.return_value = 100
yield mock


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_distribute_shards_basic():
"""
Feature: distribute_shards function basic functionality
Description: Test distribute_shards function with basic input data, including different shard coverage and sizes
Expectation: All shards are assigned to valid ranks, and each shard is assigned to a rank that covers it
"""
shard_coverage = {
"shard1": [0, 1],
"shard2": [0],
"shard3": [1]
}
shard_sizes = {
"shard1": 100,
"shard2": 200,
"shard3": 150
}
total_ranks = 2

result = distribute_shards(shard_coverage, shard_sizes, total_ranks)

# Check that all shards are assigned
assert len(result) == 3
# Check that each shard is assigned to a valid rank
for rank in result.values():
assert 0 <= rank < total_ranks
# Check that shards are assigned to ranks that cover them
for shard_id, rank in result.items():
assert rank in shard_coverage[shard_id]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_distribute_shards_empty():
"""
Feature: distribute_shards function with empty input
Description: Test distribute_shards function when shard_coverage and shard_sizes are empty
Expectation: Return an empty dictionary
"""
shard_coverage = {}
shard_sizes = {}
total_ranks = 2

result = distribute_shards(shard_coverage, shard_sizes, total_ranks)

assert result == {}


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_distribute_shards_single_rank():
"""
Feature: distribute_shards function with single rank
Description: Test distribute_shards function when there is only one rank available
Expectation: All shards are assigned to the single rank
"""
shard_coverage = {
"shard1": [0],
"shard2": [0]
}
shard_sizes = {
"shard1": 100,
"shard2": 200
}
total_ranks = 1

result = distribute_shards(shard_coverage, shard_sizes, total_ranks)

assert result == {"shard1": 0, "shard2": 0}


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_balance_shard_strategy(
mock_network, mock_get_all_sharded_tensor, mock_get_real_local_rank,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: apply_balance_shard_strategy function
Description: Test apply_balance_shard_strategy function with mock network and related fixtures
Expectation: Return three dictionaries: shard_to_saving_rank, shard_id_to_tensor, and dst_sharded_tensor_metas
"""
result = apply_balance_shard_strategy(mock_network, None)

assert len(result) == 3
shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas = result

assert isinstance(shard_to_saving_rank, dict)
assert isinstance(shard_id_to_tensor, dict)
assert isinstance(dst_sharded_tensor_metas, dict)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_init(mock_network, mock_get_rank):
"""
Feature: BalancedSaveStrategy initialization
Description: Test BalancedSaveStrategy class initialization with various parameters
Expectation: All attributes are correctly set according to the input parameters
"""

strategy = BalancedSaveStrategy(
network=mock_network,
user_prefix="test",
do_cache_distribution=True,
checkpoint_path="./checkpoint"
)

assert strategy.network == mock_network
assert strategy.user_prefix == "test"
assert strategy.do_cache_distribution is True
assert strategy.cached_distribution is None
assert strategy.checkpoint_path == "./checkpoint"
assert strategy.file_type == FileType.MODEL


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_apply_saving_parallelization(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: BalancedSaveStrategy.apply_saving_parallelization method
Description: Test apply_saving_parallelization method without cache
Expectation: Return a tuple of two dictionaries: shared_distribution and id_to_tensor
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

result = strategy.apply_saving_parallelization()

assert len(result) == 2
shard_id_to_ranks, shard_id_to_tensor = result
assert isinstance(shard_id_to_ranks, dict)
assert isinstance(shard_id_to_tensor, dict)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_apply_saving_parallelization_with_cache(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: BalancedSaveStrategy.apply_saving_parallelization method with cache
Description: Test apply_saving_parallelization method with cache enabled
Expectation: First call computes distribution, second call uses cached distribution without recomputing
"""
strategy = BalancedSaveStrategy(
network=mock_network,
do_cache_distribution=True,
checkpoint_path="./checkpoint"
)

# First call - should compute distribution
result1 = strategy.apply_saving_parallelization()

# Second call - should use cached distribution
with patch("mindformers.checkpoint.fully_parallel.apply_balance_shard_strategy") as mock_apply:
mock_apply.return_value = ({}, {})
result2 = strategy.apply_saving_parallelization()
# Check that apply_balance_shard_strategy was not called again
mock_apply.assert_not_called()

assert result1 == result2


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_total_files(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: BalancedSaveStrategy.get_total_files method
Description: Test get_total_files method to get the total number of checkpoint files
Expectation: Return a non-negative integer representing the total number of files
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

total_files = strategy.get_total_files()

assert isinstance(total_files, int)
assert total_files >= 0


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_cur_rank_file_id(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: BalancedSaveStrategy.get_cur_rank_file_id method
Description: Test get_cur_rank_file_id method to get the current rank's file ID
Expectation: Return a non-negative integer representing the current rank's file ID
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

cur_rank_file_id = strategy.get_cur_rank_file_id()

assert isinstance(cur_rank_file_id, int)
assert cur_rank_file_id >= 0


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_save(
tmp_path, mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size, mock_save_checkpoint,
mock_get_metadata_filename, mock_get_checkpoint_name, mock_get_checkpoint_iter_dir,
mock_save_metadata, mock_load_metadata, mock_reverse_sharded_tensor_shard_id
):
"""
Feature: BalancedSaveStrategy.save method
Description: Test save method to save model checkpoint without existing metadata
Expectation: save_checkpoint is called, get_checkpoint_iter_dir is called, get_checkpoint_name is called
"""
checkpoint_path = str(tmp_path / "checkpoint")
os.makedirs(checkpoint_path, exist_ok=True)

strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path=checkpoint_path
)

with patch("mindformers.checkpoint.fully_parallel.os.path.exists", return_value=False):
strategy.save(0)

# Check that save_checkpoint was called
mock_save_checkpoint.assert_called_once()
# Check that get_checkpoint_iter_dir was called
mock_get_checkpoint_iter_dir.assert_called_once_with(checkpoint_path, 0)
# Check that get_checkpoint_name was called
mock_get_checkpoint_name.assert_called()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_save_with_existing_metadata(
tmp_path, mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size, mock_save_checkpoint,
mock_get_metadata_filename, mock_get_checkpoint_name, mock_get_checkpoint_iter_dir,
mock_save_metadata, mock_reverse_sharded_tensor_shard_id
):
"""
Feature: BalancedSaveStrategy.save method with existing metadata
Description: Test save method to save model checkpoint with existing metadata file
Expectation: save_checkpoint is called, load_metadata is called
"""
checkpoint_path = str(tmp_path / "checkpoint")
os.makedirs(checkpoint_path, exist_ok=True)

# Create a mock metadata file
metadata_file = os.path.join(checkpoint_path, "metadata.json")
with open(metadata_file, "w", encoding="utf-8") as f:
f.write("{}")

strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path=checkpoint_path
)

with patch("mindformers.checkpoint.fully_parallel.os.path.exists", return_value=True):
with patch("mindformers.checkpoint.fully_parallel.load_metadata") as mock_load:
mock_load.return_value = ({"shard1": MagicMock()},
{"param1": [{"file_name": "test.safetensors", "storage_rank": 0}]})
strategy.save(0)

# Check that save_checkpoint was called
mock_save_checkpoint.assert_called_once()
# Check that load_metadata was called
mock_load.assert_called_once()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_rank_params_mappings(
mock_network, mock_get_rank
):
"""
Feature: BalancedSaveStrategy._get_rank_params_mappings method
Description: Test _get_rank_params_mappings method to create mapping from rank IDs to parameter names
Expectation: Return a dictionary mapping rank IDs to lists of parameter names
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

# Create mock data
shared_distribution = {
"shard1": 0,
"shard2": 1,
"shard3": 0
}

mock_tensor1 = MagicMock()
mock_tensor1.key = "param1"
mock_tensor2 = MagicMock()
mock_tensor2.key = "param2"
mock_tensor3 = MagicMock()
mock_tensor3.key = "param3"

id_to_tensor = {
"shard1": mock_tensor1,
"shard2": mock_tensor2,
"shard3": mock_tensor3
}

result = strategy._get_rank_params_mappings(shared_distribution, id_to_tensor)

assert isinstance(result, dict)
assert 0 in result
assert 1 in result
assert "param1" in result[0]
assert "param3" in result[0]
assert "param2" in result[1]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_rank_param_ids_mappings(
mock_network, mock_get_rank
):
"""
Feature: BalancedSaveStrategy._get_rank_param_ids_mappings method
Description: Test _get_rank_param_ids_mappings method to create mapping from rank IDs to parameter IDs
Expectation: Return a dictionary mapping rank IDs to lists of parameter IDs
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

# Create mock data
shared_distribution = {
"shard1": 0,
"shard2": 1,
"shard3": 0
}

result = strategy._get_rank_param_ids_mappings(shared_distribution)

assert isinstance(result, dict)
assert 0 in result
assert 1 in result
assert "shard1" in result[0]
assert "shard3" in result[0]
assert "shard2" in result[1]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_total_files_num(
mock_network, mock_get_rank
):
"""
Feature: BalancedSaveStrategy._get_total_files_num method
Description: Test _get_total_files_num method to calculate total number of files based on rank params mappings
Expectation: Return the correct number of files based on the input mappings
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

# Test with non-empty params
rank_params_mappings = {
0: ["param1", "param2"],
1: ["param3"],
2: []
}

result = strategy._get_total_files_num(rank_params_mappings)
assert result == 2

# Test with all empty params
rank_params_mappings = {
0: [],
1: []
}

result = strategy._get_total_files_num(rank_params_mappings)
assert result == 0
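
# Hedged sketch (editor's illustration): the expected file counts above are consistent
# with simply counting the ranks that own at least one parameter, so ranks with an
# empty parameter list produce no checkpoint file.
def _sketch_total_files_num(rank_params_mappings):
    """Reference count assumed by test_balanced_save_strategy__get_total_files_num."""
    return sum(1 for params in rank_params_mappings.values() if params)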


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy__get_cur_rank_file_id(
mock_network, mock_get_rank
):
"""
Feature: BalancedSaveStrategy._get_cur_rank_file_id method
Description: Test _get_cur_rank_file_id method to get the current rank's file ID based on rank params mappings
Expectation: Return the correct file ID for the current rank based on the input mappings
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

# Test when current rank has params
rank_params_mappings = {
0: [],
1: ["param1"],
2: ["param2"]
}

with patch.object(strategy, 'rank_id', 1):
result = strategy._get_cur_rank_file_id(rank_params_mappings)
assert result == 0

# Test when current rank has no params
rank_params_mappings = {
0: ["param1"],
1: [],
2: ["param2"]
}

with patch.object(strategy, 'rank_id', 1):
result = strategy._get_cur_rank_file_id(rank_params_mappings)
assert result == 1

# Test when current rank is not in mappings
rank_params_mappings = {
0: ["param1"],
2: ["param2"]
}

with patch.object(strategy, 'rank_id', 1):
result = strategy._get_cur_rank_file_id(rank_params_mappings)
assert result is None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balanced_save_strategy_get_total_files_and_cur_rank_file_id(
mock_network, mock_get_rank, mock_get_all_sharded_tensor,
mock_sharded_tensor_shard_id, mock_get_shard_size
):
"""
Feature: BalancedSaveStrategy.get_total_files and get_cur_rank_file_id methods with caching
Description: Test that calling get_total_files and get_cur_rank_file_id caches the results
Expectation: Second calls to these methods should use cached values without recomputing
"""
strategy = BalancedSaveStrategy(
network=mock_network,
checkpoint_path="./checkpoint"
)

total_files = strategy.get_total_files()
cur_rank_file_id = strategy.get_cur_rank_file_id()

# Check that values are cached
assert strategy.total_files_num == total_files
assert strategy.cur_rank_file_id == cur_rank_file_id

# Check that second calls use cached values
with patch("mindformers.checkpoint.fully_parallel.apply_balance_shard_strategy") as mock_apply:
mock_apply.return_value = ({}, {})
total_files2 = strategy.get_total_files()
cur_rank_file_id2 = strategy.get_cur_rank_file_id()

# Mock should not be called since values are cached
mock_apply.assert_not_called()

assert total_files2 == total_files
assert cur_rank_file_id2 == cur_rank_file_id
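
# Summary of the caching contract exercised above (editor's note; attribute names
# follow the assertions in this file): apply_saving_parallelization() computes the
# shard distribution once and stores it, while get_total_files() and
# get_cur_rank_file_id() reuse self.total_files_num / self.cur_rank_file_id instead
# of calling apply_balance_shard_strategy again.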

+ 0
- 24
tests/st/test_ut/test_core/test_callback/test_all_reduce.py View File

@@ -63,30 +63,6 @@ class TestHelperFunctions(unittest.TestCase):
self.assertEqual(loss, 0.8)
self.assertFalse(overflow)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_weight_norm(self):
"""Test _get_weight_norm function"""
# Create mock network
mock_network = Mock()
param1 = Mock()
param1.to.return_value = param1
param1.norm.return_value = Tensor(np.array([2.0]))
param2 = Mock()
param2.to.return_value = param2
param2.norm.return_value = Tensor(np.array([3.0]))

mock_network.trainable_params.return_value = [param1, param2]

with patch('mindspore.ops.functional.stack') as mock_stack:
mock_stack.return_value = Tensor(np.array([3.605551]))

# pylint: disable=W0212
norm = callback_module._get_weight_norm(mock_network)

self.assertAlmostEqual(norm, 3.605551, places=5)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard


+ 43
- 35
tests/st/test_ut/test_core/test_context/test_build_context.py View File

@@ -326,14 +326,16 @@ def test_build_parallel_context():

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.context.build_context.get_real_local_rank')
@patch('mindformers.core.context.build_context.ms.runtime.set_cpu_affinity')
def test_set_ms_affinity_with_affinity_config(mock_set_affinity, mock_rank):
@patch('mindformers.tools.utils.get_real_group_size')
@patch('mindformers.tools.utils.get_real_local_rank')
@patch('mindspore.runtime.set_cpu_affinity')
def test_set_ms_affinity_with_affinity_config(mock_set_affinity, mock_rank, mock_group_size):
"""
Feature: Test set_ms_affinity with affinity_config.
Description: Verify affinity_config overrides affinity_cpu_list and passes module config.
Expectation: MindSpore set_cpu_affinity called with config values.
"""
mock_group_size.return_value = 8
mock_rank.return_value = 1
affinity_config = {
'device_1': {
@@ -351,14 +353,16 @@ def test_set_ms_affinity_with_affinity_config(mock_set_affinity, mock_rank):

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.context.build_context.get_real_local_rank')
@patch('mindformers.core.context.build_context.ms.runtime.set_cpu_affinity')
def test_set_ms_affinity_without_device_entry(mock_set_affinity, mock_rank):
@patch('mindformers.tools.utils.get_real_group_size')
@patch('mindformers.tools.utils.get_real_local_rank')
@patch('mindspore.runtime.set_cpu_affinity')
def test_set_ms_affinity_without_device_entry(mock_set_affinity, mock_rank, mock_group_size):
"""
Feature: Test set_ms_affinity when device entry missing.
Description: Verify defaults are used when affinity_config lacks device info.
Expectation: MindSpore set_cpu_affinity called with None values.
"""
mock_group_size.return_value = 8
mock_rank.return_value = 0
affinity_config = {
'device_1': {
@@ -376,43 +380,47 @@ def test_set_ms_affinity_without_device_entry(mock_set_affinity, mock_rank):

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.context.build_context.get_cann_workqueue_cores', return_value=[0, 1])
@patch('mindformers.core.context.build_context.psutil.Process')
@patch('mindformers.core.context.build_context.psutil.cpu_count', return_value=8)
@patch('mindformers.core.context.build_context.ds.config.set_numa_enable')
def test_set_cpu_affinity_bind_available_cpus(mock_set_numa, mock_cpu_count,
mock_process_cls, mock_get_cores,
monkeypatch):
"""
Feature: Test set_cpu_affinity binding behavior.
Description: Verify CPU affinity excludes CANN workqueue cores when available.
Expectation: Process cpu_affinity receives filtered CPU list.
@patch('mindformers.tools.utils.get_real_group_size')
@patch('mindformers.tools.utils.get_real_local_rank')
def test_set_ms_affinity_with_invalid_device_id(mock_rank, mock_group_size):
"""
monkeypatch.setenv('CPU_AFFINITY', 'True')
process_mock = mock_process_cls.return_value

set_cpu_affinity(rank_id=0, rank_size=2)

mock_set_numa.assert_called_once_with(True)
mock_cpu_count.assert_called_once()
mock_get_cores.assert_called_once_with(0)
process_mock.cpu_affinity.assert_called_once_with([2, 3])
Feature: Test set_ms_affinity when device entry missing.
Description: Verify defaults are used when affinity_config lacks device info.
Expectation: MindSpore set_cpu_affinity called with None values.
"""
mock_group_size.return_value = 1
mock_rank.return_value = 0
affinity_config = {
'device_1': {
'affinity_cpu_list': [4, 5],
'module_to_cpu_dict': {'module_a': [6, 7]}
}
}
with pytest.raises(ValueError) as exc_info:
set_ms_affinity(affinity_config, None)
assert 'Invalid device id' in str(exc_info.value)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.context.build_context.get_cann_workqueue_cores', return_value=[0, 1, 2, 3])
@patch('mindformers.core.context.build_context.psutil.Process')
@patch('mindformers.core.context.build_context.psutil.cpu_count', return_value=8)
@patch('mindformers.core.context.build_context.ds.config.set_numa_enable')
def test_set_cpu_affinity_fallback_when_all_cores_taken(mock_set_numa, mock_cpu_count,
mock_process_cls, mock_get_cores,
monkeypatch):
@pytest.mark.parametrize('cann_workqueue_cores, cpu_affinity', [
([0, 1], [2, 3]),
([0, 1, 2, 3], [0, 1, 2, 3])
])
@patch('mindformers.utils.get_cann_workqueue_cores')
@patch('psutil.Process')
@patch('psutil.cpu_count', return_value=8)
@patch('mindspore.dataset.config.set_numa_enable')
def test_set_cpu_affinity(
mock_set_numa, mock_cpu_count, mock_process_cls, mock_get_cores,
monkeypatch, cann_workqueue_cores, cpu_affinity):
"""
Feature: Test set_cpu_affinity fallback behavior.
Description: Verify original CPU list is used when CANN occupies all candidate cores.
Description: Verify that the original CPU list is used when CANN occupies all candidate cores,
and CPU affinity excludes CANN workqueue cores when available.
Expectation: Process cpu_affinity receives unfiltered CPU list.
"""
mock_get_cores.return_value = cann_workqueue_cores
monkeypatch.setenv('CPU_AFFINITY', 'True')
process_mock = mock_process_cls.return_value

@@ -421,4 +429,4 @@ def test_set_cpu_affinity_fallback_when_all_cores_taken(mock_set_numa, mock_cpu_
mock_set_numa.assert_called_once_with(True)
mock_cpu_count.assert_called_once()
mock_get_cores.assert_called_once_with(0)
process_mock.cpu_affinity.assert_called_once_with([0, 1, 2, 3])
process_mock.cpu_affinity.assert_called_once_with(cpu_affinity)

+ 3
- 1
tests/st/test_ut/test_dataset/test_dataloader/test_blended_megatron_dataset_builder.py View File

@@ -23,7 +23,7 @@ from mindformers.dataset.blended_datasets.blended_megatron_dataset_builder impor
_get_size_per_split_per_dataset
)
from mindformers.dataset.blended_datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from mindformers.dataset.blended_datasets.utils import Split
from mindformers.dataset.blended_datasets.utils import Split, compile_helpers


class DummyTokenizer:
@@ -522,6 +522,7 @@ class TestBlendedMegatronDatasetBuilder:
Description: Test build method works with blend configuration having weights and size
Expectation: Method builds datasets correctly with weights processing
"""
compile_helpers()
config = create_test_config()
config.mock = False
config.blend = (["prefix1", "prefix2"], [0.3, 0.7])
@@ -645,6 +646,7 @@ class TestBlendedMegatronDatasetBuilder:
Description: Test parallel building of megatron datasets
Expectation: Method builds datasets in parallel correctly
"""
compile_helpers()
config = create_test_config()
config.mock = False
config.blend = (["prefix1", "prefix2"], [0.5, 0.5])


+ 41
- 0
tests/st/test_ut/test_generation/qwen3_0_6b_infer.yaml View File

@@ -0,0 +1,41 @@
seed: 0
output_dir: './output' # path to save checkpoint/strategy
load_checkpoint: ''
use_parallel: False
run_mode: 'predict'
use_legacy: False
load_ckpt_format: 'safetensors'

trainer:
type: CausalLanguageModelingTrainer
model_name: 'qwen3'

# default parallel of device num = 8 for Atlas 800T A2
parallel_config:
data_parallel: 1
model_parallel: 1
pretrained_model_dir: '/home/workspace/mindspore_dataset/weight/Qwen3-0.6B'
generation:
max_length: 128
model:
model_config:
compute_dtype: "bfloat16"
layernorm_compute_dtype: "float32"
softmax_compute_dtype: "float32"
rotary_dtype: "bfloat16"
params_dtype: "bfloat16"

# mindspore context init config
context:
mode: 0 #0--Graph Mode; 1--Pynative Mode
enable_graph_kernel: False
ascend_config:
precision_mode: "must_keep_origin_dtype"
max_device_memory: "29GB"
save_graphs: False
save_graphs_path: "./graph"

# parallel context config
parallel:
parallel_mode: "MANUAL_PARALLEL"
enable_alltoall: False

+ 286
- 0
tests/st/test_ut/test_generation/test_parallel_decoding.py View File

@@ -0,0 +1,286 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test parallel decoding"""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor

from mindformers.generation.parallel_decoding import (
_logits_process,
_pre_process,
_la_logits_process,
_la_pre_process,
_memory_decoding_pre_process,
_prefix_cache_pre_process,
parallel_decoding_control,
parallel_decoding_logits_process,
_construct_mask,
_parallel_decoding_pad,
_parallel_decoding_pad_2d_tensor
)


class MockConfig:
def __init__(self, parallel_decoding=None):
if parallel_decoding:
self.parallel_decoding_params = {"parallel_decoding": parallel_decoding}
else:
self.parallel_decoding_params = None


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_register_decorators():
"""
Feature: parallel decoding.
    Description: check that the registration decorators populate _logits_process and
        _pre_process with the 'la', 'memory_decoding' and 'prefix_cache' handlers.
Expectation: success.
"""
assert 'la' in _logits_process
assert 'la' in _pre_process
assert 'memory_decoding' in _pre_process
assert 'prefix_cache' in _pre_process


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_construct_mask():
"""
Feature: parallel decoding.
    Description: verify _construct_mask builds a block-wise lower-triangular mask from q_seq_lens.
Expectation: success.
"""
q_seq_lens = [2, 3]
mask = _construct_mask(q_seq_lens)
expected = np.array([
[-0, 1, 1, 1, 1],
[-0, -0, 1, 1, 1],
[1, 1, -0, 1, 1],
[1, 1, -0, -0, 1],
[1, 1, -0, -0, -0]
], dtype=np.float16)
np.testing.assert_array_equal(mask, expected)
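
# Hedged re-derivation (editor's sketch) of the expected mask above: within each
# sequence block the mask is lower-triangular (0 = visible), and positions across
# different sequences stay masked (1).
def _sketch_block_causal_mask(q_seq_lens):
    """Reference mask assumed by test_construct_mask."""
    total = sum(q_seq_lens)
    mask = np.ones((total, total), dtype=np.float16)
    start = 0
    for length in q_seq_lens:
        # zero out the lower triangle of this sequence's own block
        mask[start:start + length, start:start + length] = np.triu(
            np.ones((length, length), dtype=np.float16), k=1)
        start += length
    return mask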


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_decoding_pad():
"""
Feature: parallel decoding.
    Description: verify _parallel_decoding_pad right-pads along the given axis and returns
        the input unchanged when pad_len does not exceed the current length.
Expectation: success.
"""
arr = np.array([1, 2, 3])
padded = _parallel_decoding_pad(arr, axis=0, pad_len=5, value=-1)
expected = np.array([1, 2, 3, -1, -1])
np.testing.assert_array_equal(padded, expected)

# pad_len < current len → no change
same = _parallel_decoding_pad(arr, axis=0, pad_len=2, value=-1)
np.testing.assert_array_equal(same, arr)
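
# Editor's note: for 1-D inputs the padding above behaves like
# np.pad(arr, (0, pad_len - arr.shape[0]), constant_values=value) when pad_len
# exceeds the current length, and like a no-op otherwise (an assumption based on
# the assertions in this test, not on the implementation).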


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_decoding_pad_2d_tensor():
"""
Feature: parallel decoding.
    Description: verify _parallel_decoding_pad_2d_tensor splits the flat inputs by lens and
        right-pads each row to pad_seq_len.
Expectation: success.
"""
inputs = np.array([1, 2, 3, 4, 5, 6])
lens = [2, 3]
padded = _parallel_decoding_pad_2d_tensor(inputs, pad_seq_len=4, lens=lens, value=-1)
expected = np.array([
[1, 2, -1, -1],
[3, 4, 5, -1]
])
np.testing.assert_array_equal(padded, expected)
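
# Hedged sketch (editor's illustration) of the 2-D padding semantics checked above:
# slice the flat inputs per sequence length, then right-pad each row to pad_seq_len.
def _sketch_pad_2d(inputs, pad_seq_len, lens, value):
    """Reference layout assumed by test_parallel_decoding_pad_2d_tensor."""
    rows, start = [], 0
    for length in lens:
        row = list(inputs[start:start + length]) + [value] * (pad_seq_len - length)
        rows.append(row)
        start += length
    return np.array(rows)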


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_la_logits_process_simple():
"""
Feature: parallel decoding.
    Description: verify _la_logits_process keeps the logits shape unchanged when no q_seq_lens are given.
Expectation: success.
"""
logits = Tensor(np.random.rand(4, 100), ms.float32)
result = _la_logits_process(logits, None, None, False)
assert result.shape == (4, 100)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_la_logits_process_with_q_seq_lens():
"""
Feature: parallel decoding.
    Description: verify _la_logits_process gathers the last-token logits of each sequence when prefill is True.
Expectation: success.
"""
logits = Tensor(np.random.rand(6, 100), ms.float32) # batch=2, max_seq=3
q_seq_lens = [2, 3]
block_tables = [[1, 2], [3, 4]]
result = _la_logits_process(logits, q_seq_lens, block_tables, prefill=True)
assert result.shape == (2, 100) # last token of each seq


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_la_pre_process_normal():
"""
Feature: parallel decoding.
    Description: verify _la_pre_process returns int32 tensors and a q_seq_lens tensor for a single
        un-padded sequence.
Expectation: success.
"""
config = MockConfig("la")
input_ids = Tensor([[1, 2, 3]], ms.int32)
model_inputs = {}
block_tables = np.array([[10, 11]])
slot_mapping = np.array([0, 1, 2])
q_seq_lens = [3]

out_model_inputs, out_block, out_slot = _la_pre_process(
config, input_ids, model_inputs,
block_tables=block_tables,
slot_mapping=slot_mapping,
q_seq_lens=q_seq_lens
)

assert isinstance(out_model_inputs['input_ids'], Tensor)
assert out_model_inputs['input_ids'].shape == (1, 3)
assert out_model_inputs['q_seq_lens'].shape == (1,)
assert np.array_equal(out_block, block_tables.astype(np.int32))
assert np.array_equal(out_slot, slot_mapping.astype(np.int32))


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_la_pre_process_with_max_padding():
"""
Feature: parallel decoding.
    Description: verify the output shape of _la_pre_process when padded multi-sequence inputs are
        compacted to max_len.
Expectation: success.
"""
config = MockConfig("la")
input_ids = Tensor([[1, 2, 0, 0, 3, 4, 0, 0]], ms.int32) # shape (1,8), max_len=4, two seqs
model_inputs = {}
block_tables = np.array([[1, 2], [3, 4]])
slot_mapping = np.array([0, 1, 0, 0, 2, 3, 0, 0])
q_seq_lens = [2, 2] # each seq has 2 real tokens

out_model_inputs, _, _ = _la_pre_process(
config, input_ids, model_inputs,
block_tables=block_tables,
slot_mapping=slot_mapping,
q_seq_lens=q_seq_lens
)

    # Padded inputs are compacted to shape (1, 4); the kept values are [1, 2, 0, 0], not [1, 2, 3, 4]
assert out_model_inputs['input_ids'].shape == (1, 4)
expected_ids = np.array([[1, 2, 0, 0]])
np.testing.assert_array_equal(out_model_inputs['input_ids'].asnumpy(), expected_ids)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_la_pre_process_no_q_seq_lens():
"""
Feature: parallel decoding.
    Description: verify _la_pre_process derives q_seq_lens from valid_length_each_example when
        q_seq_lens is None.
Expectation: success.
"""
config = MockConfig("la")
input_ids = Tensor([[1, 2, 3]], ms.int32)
model_inputs = {}
block_tables = np.array([[10, 11]])
slot_mapping = np.array([0, 1, 2])

out_model_inputs, _, _ = _la_pre_process(
config, input_ids, model_inputs,
block_tables=block_tables,
slot_mapping=slot_mapping,
q_seq_lens=None,
valid_length_each_example=[3]
)

assert out_model_inputs['q_seq_lens'].shape == (1,)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_memory_and_prefix_preprocess():
"""
Feature: parallel decoding.
    Description: verify the memory_decoding and prefix_cache pre-process hooks cast block_tables
        and slot_mapping to int32.
Expectation: success.
"""
config = MockConfig("memory_decoding")
input_ids = Tensor([], ms.int32)
model_inputs = {}
block_tables = np.array([0])
slot_mapping = np.array([0])

out1 = _memory_decoding_pre_process(config, input_ids, model_inputs,
block_tables=block_tables, slot_mapping=slot_mapping)
out2 = _prefix_cache_pre_process(config, input_ids, model_inputs,
block_tables=block_tables, slot_mapping=slot_mapping)

assert np.array_equal(out1[1], block_tables.astype(np.int32))
assert np.array_equal(out2[2], slot_mapping.astype(np.int32))


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_decoding_control():
"""
Feature: parallel decoding.
    Description: verify parallel_decoding_control accepts only supported decoding methods and
        rejects invalid or missing configs.
Expectation: success.
"""
assert parallel_decoding_control(MockConfig("la")) is True
assert parallel_decoding_control(MockConfig("memory_decoding")) is True
assert parallel_decoding_control(MockConfig("prefix_cache")) is True
assert parallel_decoding_control(MockConfig("invalid")) is False
assert parallel_decoding_control(MockConfig(None)) is False
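
# Hedged sketch (editor's illustration): the assertions above are consistent with a
# plain membership check against the registered methods; the set below mirrors the
# registry keys checked in test_register_decorators and is an assumption, not the
# actual implementation.
def _sketch_parallel_decoding_control(config):
    params = getattr(config, "parallel_decoding_params", None)
    return bool(params) and params.get("parallel_decoding") in {"la", "memory_decoding", "prefix_cache"}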


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_decoding_logits_process():
"""
Feature: parallel decoding.
    Description: verify parallel_decoding_logits_process preserves the logits shape for the 'la' method.
Expectation: success.
"""
config = MockConfig("la")
logits = Tensor(np.random.rand(2, 100), ms.float32)
result = parallel_decoding_logits_process(config, logits, None, None, False)
assert result.shape == (2, 100)

+ 56
- 0
tests/st/test_ut/test_generation/test_streamer.py View File

@@ -0,0 +1,56 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test stremer inference"""
import os
import pytest

from transformers import AutoTokenizer

from mindspore.nn.utils import no_init_parameters

from mindformers import AutoModel, build_context, MindFormerConfig
from mindformers import pipeline, TextStreamer, TextIteratorStreamer


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_streamer():
"""
Feature: Streamer inference.
Description: Test streamer inference.
Expectation: Success.
"""
config_path = os.path.join(os.path.dirname(__file__), "qwen3_0_6b_infer.yaml")
config = MindFormerConfig(config_path)
config.use_parallel = False
config.parallel_config.model_parallel = 1
build_context(config)

inputs = ["I love Beijing, because", "请介绍北京", "生成以换行符结尾的句子"]

tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_dir, trust_remote_code=True)

with no_init_parameters():
network = AutoModel.from_config(config)
network.load_weights(config.pretrained_model_dir)

streamer = TextStreamer(tokenizer)
text_generation_pipeline = pipeline(task="text_generation", model=network, tokenizer=tokenizer, streamer=streamer)
_ = text_generation_pipeline(inputs, max_length=64, do_sample=False, top_k=3, top_p=1)

streamer = TextIteratorStreamer(tokenizer)
text_generation_pipeline = pipeline(task="text_generation", model=network, tokenizer=tokenizer, streamer=streamer)
_ = text_generation_pipeline(inputs, max_length=64, do_sample=False, top_k=3, top_p=1)

+ 282
- 1
tests/st/test_ut/test_metrics.py View File

@@ -13,16 +13,43 @@
# limitations under the License.
# ============================================================================
"""test metric schedule."""
import pytest
import importlib
import logging

import numpy as np
import pytest

import mindspore as ms
from mindspore.common import Tensor
from mindspore.common import dtype as mstype

from mindformers.core.metric import PromptAccMetric, EmF1Metric
from mindformers.core.metric import utils as metric_utils

PIPELINE_STAGE = 1
DEFAULT_NUM_DATA = 1
DEFAULT_TOTAL_LOSS = 0.5
CONSTANT_CELL_OUTPUT = 0.5


class ConstantTensorCell:
"""Utility callable returning a constant tensor, reused across tests."""

def __init__(self, value):
self.value = value

def __call__(self, *args, **kwargs):
del args
del kwargs
return ms.Tensor(np.array([self.value], dtype=np.float32))

ms.set_context(mode=1, device_target='CPU')
# Ensure pipeline_stages is configured for tests to avoid division-by-zero
# inside PerplexityMetric initialization.
try:
ms.set_auto_parallel_context(pipeline_stages=PIPELINE_STAGE)
except RuntimeError as exc: # pragma: no cover - best effort for CI environments
logging.warning("Failed to set pipeline_stages for tests: %s", exc)


@pytest.mark.level0
@@ -76,3 +103,257 @@ def test_emf1_metric():
error = 1e-8
f1_score, em_score = 75.0, 50.0
assert abs(result.get("F1", 0) - f1_score) < error and abs(result.get("Em", 0) - em_score) < error


# ----- Additional tests to improve coverage for core.metric -----
metric_mod = importlib.import_module("mindformers.core.metric.metric")

EntityScore = metric_mod.EntityScore
PerplexityMetric = metric_mod.PerplexityMetric
ADGENMetric = metric_mod.ADGENMetric
PromptAccMetric = metric_mod.PromptAccMetric
EmF1Metric = metric_mod.EmF1Metric


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_entityscore_get_entities_and_eval():
"""
Feature: EntityScore
Description: Validate entity extraction, accumulation, and evaluation outputs.
Expectation: Evaluation returns overall metrics and per-class dict without errors.
"""
metric = EntityScore()
metric.clear()

seq = ["S-name", "B-address", "I-address", "O", "B-org", "I-org", "I-org"]
chunks = metric.get_entities_bios(seq)
assert isinstance(chunks, list)
assert len(chunks) >= 1

recall, precision, f1 = metric.compute(0, 0, 0)
assert recall == 0 and precision == 0 and f1 == 0.0

num_labels = len(metric.label2id)
batch_logits = Tensor(np.zeros((1, 3, num_labels)).astype(np.float32))
batch_labels = Tensor(np.zeros((1, 3)).astype(np.int32))
metric.update(batch_logits, batch_labels)
overall, per_class = metric.eval()
assert "precision" in overall and "recall" in overall and "f1" in overall
assert isinstance(per_class, dict)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_perplexitymetric_non_pipeline_and_pipeline(monkeypatch):
"""
Feature: PerplexityMetric
Description: Cover behavior in both non-pipeline and pipeline-parallel modes.
Expectation: Metric outputs contain loss/PPL when applicable and handle stages safely.
"""
monkeypatch.setattr(
metric_mod,
"PerplexityCell",
lambda pipeline_parallel: ConstantTensorCell(CONSTANT_CELL_OUTPUT)
)
monkeypatch.setattr(ms, "get_auto_parallel_context", lambda key: 1 if key == "pipeline_stages" else "GRAPH_MODE")

metric = PerplexityMetric()
metric.clear()

logits = Tensor(np.random.rand(1, 1, 2).astype(np.float32))
labels = Tensor(np.array([[0]]).astype(np.int32))
mask = Tensor(np.array([[1]]).astype(np.int32))

metric.update(logits, labels, mask)
metric.update(logits, labels, mask)
# guard: if update didn't increment num_data for any reason, set values to avoid ZeroDivisionError
if getattr(metric, "num_data", 0) == 0:
metric.num_data = DEFAULT_NUM_DATA
metric.total_loss = DEFAULT_TOTAL_LOSS
result = metric.eval()
assert "loss" in result and "PPL" in result

monkeypatch.setattr(ms, "get_auto_parallel_context", lambda key: 2 if key == "pipeline_stages" else "GRAPH_MODE")
monkeypatch.setattr(metric_mod, "get_group_size", lambda: 2)
monkeypatch.setattr(metric_mod, "get_rank", lambda: 0)
metric2 = PerplexityMetric()
metric2.clear()
metric2.pipeline_parallel = True
metric2.is_last_stage = False
res = metric2.eval()
assert res is None, "Pipeline intermediate stage should return None"


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adgenmetric_empty_and_normal(monkeypatch):
"""
Feature: ADGENMetric
Description: Ensure rouge and bleu statistics accumulate for empty and normal inputs.
Expectation: Evaluation dict includes rouge-1 and bleu-4 keys.
"""
class FakeRouge:
def get_scores(self, hyp_inputs, ref_inputs):
del hyp_inputs
del ref_inputs
return [{"rouge-1": {"f": 0.5}, "rouge-2": {"f": 0.4}, "rouge-l": {"f": 0.45}}]

monkeypatch.setattr(metric_mod, "Rouge", lambda *args, **kwargs: FakeRouge())
monkeypatch.setattr(metric_mod, "sentence_bleu", lambda refs, hyp, smoothing_function=None: 0.25)

metric = ADGENMetric()
metric.clear()
metric.update([""], np.array([""]))
metric.update(["hello world"], np.array(["hello world"]))
out = metric.eval()
assert "rouge-1" in out and "bleu-4" in out


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_promptaccmetric_calculate_and_update(monkeypatch):
"""
Feature: PromptAccMetric
Description: Validate calculate/update flow with mocked loss implementation.
Expectation: Evaluation dictionary contains "Acc" field.
"""
class FakeLoss:
def __call__(self, logits, tokens, mask):
return ms.Tensor(np.array([0.1], dtype=np.float32))

monkeypatch.setattr(metric_mod, "CrossEntropyLoss", FakeLoss)

metric = PromptAccMetric()
metric.clear()

logits = Tensor(np.random.rand(1, 1, 3, 2).astype(np.float32))
input_ids = Tensor(np.random.randint(0, 2, size=(1, 3)).astype(np.int32))
input_mask = Tensor(np.ones((1, 3)).astype(np.int32))
labels = Tensor(np.array([0]).astype(np.int32))

metric.update(logits, input_ids, input_mask, labels)
out = metric.eval()
assert "Acc" in out


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_emf1_helpers_and_evaluate_pairs_edge_cases():
"""
Feature: EmF1Metric helpers
Description: Exercise helper methods and edge cases for segmentation, EM/F1 computation.
Expectation: Helper calls return expected types/values and evaluation handles empty inputs.
"""
m = EmF1Metric()
m.clear()

segs = m.mixed_segmentation("Hello world, nice!")
assert isinstance(segs, list)

# note: ASCII comma and ASCII exclamation are not listed in the implementation's
# punctuation list, so remove_punctuation preserves them; expect lowercase with comma and '!'
assert m.remove_punctuation("Hello,World!") == "hello,world!"

seq_prefix, lcs_len = m.find_lcs(list("abcdef"), list("abxyef"))
assert isinstance(seq_prefix, list)
assert isinstance(lcs_len, int)

assert m.calc_em_score(["Hello"], "Hello") == 1
assert m.calc_f1_score(["abc"], "abc") == 1.0

result, cnt = m.evaluate_pairs([], [])
assert result == {} and cnt == 0


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_perplexitycell_construct_basic():
"""
Feature: PerplexityCell construct
Description: Whitebox validation of reshape behavior and injected loss execution path.
Expectation: Construct returns tensor value and reshaped dimensions match expectations.
"""
# prepare a small batch
batch_size, seq_length, vocab = 2, 4, 5
logits = Tensor(np.random.rand(batch_size, seq_length, vocab).astype(np.float32))
labels = Tensor(np.random.randint(0, vocab, size=(batch_size, seq_length)).astype(np.int32))
mask = Tensor(np.ones((batch_size, seq_length)).astype(np.int32))

called = {}

class FakeLoss(ms.nn.Cell):
def construct(self, l_logits, l_labels, l_mask):
# record shapes seen by loss
called['logits_shape'] = tuple(l_logits.shape)
called['labels_shape'] = tuple(l_labels.shape)
called['mask_shape'] = tuple(l_mask.shape)
return ms.Tensor(np.array([0.42], dtype=np.float32))

# create cell and override the runtime ops to pure-Python callables
cell = metric_utils.PerplexityCell(is_pipeline_parallel=False)
cell.loss = FakeLoss()

# replace reshape with a callable that returns a mindspore Tensor with numpy reshape
def reshape_op(x, shape):
# x may be a Tensor or numpy array; convert to numpy
arr = x.asnumpy() if hasattr(x, 'asnumpy') else np.array(x)
new = arr.reshape(shape)
return ms.Tensor(new)

cell.reshape = reshape_op

out = cell.construct(logits, labels, mask)
# loss returns a 1-element tensor with value 0.42
assert isinstance(out, ms.Tensor)
assert abs(float(out.asnumpy().ravel()[0]) - 0.42) < 1e-6

# verify reshape produced expected flattened sizes
expected_n = batch_size * (seq_length - 1)
assert called['logits_shape'][0] == expected_n
assert called['labels_shape'][0] == expected_n
assert called['mask_shape'][0] == expected_n
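    # Editor's note: seq_length - 1 reflects shifted next-token targets, where the last
    # position of each sequence has no label (an assumption inferred from the reshape
    # sizes asserted above).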


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_perplexitycell_pipeline_flag_and_attrs():
"""
Feature: PerplexityCell pipeline flag
Description: Confirm attributes remain intact and construct executes when pipeline mode is True.
Expectation: Construct call succeeds and returns tensor output.
"""
batch_size, seq_length, vocab = 1, 3, 4
logits = Tensor(np.random.rand(batch_size, seq_length, vocab).astype(np.float32))
labels = Tensor(np.random.randint(0, vocab, size=(batch_size, seq_length)).astype(np.int32))
mask = Tensor(np.ones((batch_size, seq_length)).astype(np.int32))

cell = metric_utils.PerplexityCell(is_pipeline_parallel=True)
assert cell.is_pipeline_parallel is True

# simple loss that returns sum of labels as a float tensor
class SumLabelsLoss(ms.nn.Cell):
def construct(self, logits_value, l_labels, mask_value):
del logits_value
del mask_value
arr = l_labels.asnumpy() if hasattr(l_labels, 'asnumpy') else np.array(l_labels)
return ms.Tensor(np.array([float(arr.sum())], dtype=np.float32))

cell.loss = SumLabelsLoss()

# override reshape to identity reshape preserving shape semantics
def reshape_id(x, shape):
arr = x.asnumpy() if hasattr(x, 'asnumpy') else np.array(x)
new = arr.reshape(shape)
return ms.Tensor(new)

cell.reshape = reshape_id
out = cell.construct(logits, labels, mask)
assert isinstance(out, ms.Tensor)

+ 168
- 0
tests/st/test_ut/test_mindformer_book.py View File

@@ -0,0 +1,168 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindformer_book.py
"""
from unittest.mock import patch
import pytest

from mindformers.mindformer_book import MindFormerBook


#pylint: disable=W0212
class TestMindFormerBook:
""" A test class for testing mindformer_book."""
def setup_method(self):
"""Execute before each test method: save original data and set up test data"""
self.original_trainer_list = getattr(MindFormerBook, '_TRAINER_SUPPORT_TASKS_LIST', {})
self.original_pipeline_list = getattr(MindFormerBook, '_PIPELINE_SUPPORT_TASK_LIST', {})

MindFormerBook._TRAINER_SUPPORT_TASKS_LIST = {
"general": {"some_key": "some_value"},
"text_generation": {
"common": {"config": "value"},
"model1": "path1",
"model2": "path2"
},
"text_classification": {
"common": {"config": "value"},
"model3": "path3"
}
}

MindFormerBook._PIPELINE_SUPPORT_TASK_LIST = {
"text_generation": {
"common": {"config": "value"},
"model1": "path1",
"model2": "path2"
},
"text_classification": {
"common": {"config": "value"},
"model3": "path3"
},
"image_classification": {
"common": {"config": "value"},
"model4": "path4"
}
}

def teardown_method(self):
"""Execute after each test method: restore original data"""
MindFormerBook._TRAINER_SUPPORT_TASKS_LIST = self.original_trainer_list
MindFormerBook._PIPELINE_SUPPORT_TASK_LIST = self.original_pipeline_list

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_trainer_support_model_list_without_task(self):
"""Test case when no task is specified"""
with patch('mindformers.mindformer_book.print_dict') as mock_print_dict, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_trainer_support_model_list()
mock_logger.info.assert_called_with("Trainer support model list of MindFormer is: ")
mock_print_dict.assert_called_once()
call_args = mock_print_dict.call_args[0][0]
assert "text_generation" in call_args
assert "text_classification" in call_args
assert "general" not in call_args
assert call_args["text_generation"] == ["model1", "model2"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_trainer_support_model_list_with_valid_task(self):
"""Test case when a valid task is specified"""
with patch('mindformers.mindformer_book.print_path_or_list') as mock_print_list, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_trainer_support_model_list(task="text_generation")
mock_logger.info.assert_called_with("Trainer support model list for %s task is: ", "text_generation")
mock_print_list.assert_called_once()
call_args = mock_print_list.call_args[0][0]
assert call_args == ["model1", "model2"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_trainer_support_model_list_with_another_valid_task(self):
"""Test case when another valid task is specified"""
with patch('mindformers.mindformer_book.print_path_or_list') as mock_print_list, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_trainer_support_model_list(task="text_classification")
mock_logger.info.assert_called_with("Trainer support model list for %s task is: ", "text_classification")
mock_print_list.assert_called_once()
call_args = mock_print_list.call_args[0][0]
assert call_args == ["model3"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_trainer_support_model_list_with_invalid_task(self):
"""Test case when an invalid task is specified"""
with patch('mindformers.mindformer_book.logger') as mock_logger:
with pytest.raises(KeyError, match="unsupported task"):
MindFormerBook.show_trainer_support_model_list(task="invalid_task")
mock_logger.info.assert_not_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_pipeline_support_model_list_without_task(self):
"""Test pipeline case when no task is specified"""
with patch('mindformers.mindformer_book.print_dict') as mock_print_dict, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_pipeline_support_model_list()
mock_logger.info.assert_called_with("Pipeline support model list of MindFormer is: ")
mock_print_dict.assert_called_once()
call_args = mock_print_dict.call_args[0][0]
assert "text_generation" in call_args
assert "text_classification" in call_args
assert "image_classification" in call_args
assert call_args["text_generation"] == ["model1", "model2"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_pipeline_support_model_list_with_valid_task(self):
"""Test pipeline case when a valid task is specified"""
with patch('mindformers.mindformer_book.print_path_or_list') as mock_print_list, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_pipeline_support_model_list(task="text_generation")
mock_logger.info.assert_called_with("Pipeline support model list for %s task is: ", "text_generation")
mock_print_list.assert_called_once()
call_args = mock_print_list.call_args[0][0]
assert call_args == ["model1", "model2"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_pipeline_support_model_list_with_another_valid_task(self):
"""Test pipeline case when another valid task is specified"""
with patch('mindformers.mindformer_book.print_path_or_list') as mock_print_list, \
patch('mindformers.mindformer_book.logger') as mock_logger:
MindFormerBook.show_pipeline_support_model_list(task="image_classification")
mock_logger.info.assert_called_with("Pipeline support model list for %s task is: ", "image_classification")
mock_print_list.assert_called_once()
call_args = mock_print_list.call_args[0][0]
assert call_args == ["model4"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_show_pipeline_support_model_list_with_invalid_task(self):
"""Test pipeline case when an invalid task is specified"""
with patch('mindformers.mindformer_book.logger') as mock_logger:
with pytest.raises(KeyError, match="unsupported task"):
MindFormerBook.show_pipeline_support_model_list(task="invalid_task")
mock_logger.info.assert_not_called()

+ 264
- 0
tests/st/test_ut/test_models/test_auto/test_configuration_auto.py View File

@@ -0,0 +1,264 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 Huawei Technologies
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for mindformers.models.auto.configuration_auto."""
from types import SimpleNamespace

import pytest

import mindformers.models.auto.configuration_auto as auto_cfg
from mindformers.models.auto.configuration_auto import (
AutoConfig,
CONFIG_MAPPING,
config_class_to_model_type,
_LazyConfigMapping,
_list_model_options,
replace_list_option_in_docstrings,
)
from mindformers.models.configuration_utils import PretrainedConfig


class DummyMindFormerConfig(dict):
"""Stub MindFormerConfig for unit tests."""

def __init__(self, use_legacy=True, has_pretrained=False, has_generation=False):
super().__init__()
self._use_legacy = use_legacy
self.model = SimpleNamespace(
model_config={"type": "DemoConfig"},
arch=SimpleNamespace(type="demo_arch"),
)
if has_pretrained:
self["pretrained_model_dir"] = "pretrained_dir"
self.pretrained_model_dir = "pretrained_dir"
else:
self.pretrained_model_dir = None
if has_generation:
self["generation_config"] = {"gen": True}
self.generation_config = {"gen": True}
else:
self.generation_config = None

def get_value(self, key, default=None):
"""Get value from config."""
if key == "use_legacy":
return self._use_legacy
return default


@pytest.fixture(autouse=True)
def restore_extra_content():
"""Ensure CONFIG_MAPPING extra registrations are restored between tests."""
# Accessing protected member for test cleanup is intentional
original = CONFIG_MAPPING._extra_content.copy() # pylint: disable=W0212,protected-access
yield
CONFIG_MAPPING._extra_content = original # pylint: disable=W0212,protected-access


class TestConfigurationAuto:
"""Test class for mindformers.models.auto.configuration_auto."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_config_class_to_model_type_core_and_extra(self, monkeypatch):
"""config_class_to_model_type should inspect default and extra registries."""
assert config_class_to_model_type("LlamaConfig") == "llama"
dummy_class = type("NewConfig", (), {})
# Accessing protected member for test setup is intentional
monkeypatch.setitem(CONFIG_MAPPING._extra_content, "custom", dummy_class) # pylint: disable=W0212,protected-access
assert config_class_to_model_type("NewConfig") == "custom"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_lazy_config_mapping_register_and_getitem(self, monkeypatch):
"""_LazyConfigMapping should import modules lazily and honor register."""
module = SimpleNamespace(MockConfig="sentinel")
monkeypatch.setattr(auto_cfg.importlib, "import_module", lambda name, package=None: module)
mapping = _LazyConfigMapping({"mock": "MockConfig"})
assert mapping["mock"] == "sentinel"
mapping.register("extra", "ExtraConfig", exist_ok=True)
assert mapping["extra"] == "ExtraConfig"
with pytest.raises(ValueError):
mapping.register("mock", "OtherConfig")
with pytest.raises(KeyError):
_ = mapping["missing"]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_list_model_options_and_docstring_replacement(self):
"""_list_model_options and decorator should update docstrings or raise errors."""
doc = _list_model_options(" ", {"llama": ["LlamaConfig"]}, use_model_types=False)
assert "LlamaConfig" in doc

@replace_list_option_in_docstrings({"llama": ["LlamaConfig"]})
def sample():
"""List options"""

assert "llama" in sample.__doc__

def broken():
"""no placeholder"""

decorator = replace_list_option_in_docstrings({"llama": ["LlamaConfig"]})
with pytest.raises(ValueError):
decorator(broken)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_autoconfig_invalid_yaml_name_branches(self, monkeypatch):
"""AutoConfig.invalid_yaml_name should validate against support list."""
monkeypatch.setattr(AutoConfig, "_support_list",
{"llama": ["llama_7b"], "glm": {"9b": ["glm_9b"]}})
assert AutoConfig.invalid_yaml_name("unknown_model")
assert not AutoConfig.invalid_yaml_name("llama_7b")
assert not AutoConfig.invalid_yaml_name("glm_9b")
with pytest.raises(ValueError):
AutoConfig.invalid_yaml_name("glm_bad")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_autoconfig_for_model_and_error(self):
"""AutoConfig.for_model should instantiate registered configs or raise."""
class DummyConfig(PretrainedConfig):
"""Dummy config for unit tests."""
model_type = "dummy_key"

def __init__(self, value=None):
super().__init__()
self.value = value

CONFIG_MAPPING.register("dummy_key", DummyConfig, exist_ok=True)
result = AutoConfig.for_model("dummy_key", value=3)
assert isinstance(result, DummyConfig) and result.value == 3
with pytest.raises(ValueError):
AutoConfig.for_model("missing")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_from_pretrained_switches_modes(self, monkeypatch):
"""AutoConfig.from_pretrained should delegate based on experimental flag."""
monkeypatch.setattr(auto_cfg, "is_experimental_mode", lambda _: False)
monkeypatch.setattr(AutoConfig, "get_config_origin_mode",
classmethod(lambda cls, name, **_: ("origin", name)))
res = AutoConfig.from_pretrained("path/model.yaml", pretrained_model_name_or_path="override")
assert res == ("origin", "override")
monkeypatch.setattr(auto_cfg, "is_experimental_mode", lambda _: True)
monkeypatch.setattr(AutoConfig, "get_config_experimental_mode",
classmethod(lambda cls, name, **_: ("exp", name)))
assert AutoConfig.from_pretrained("path/model.yaml") == ("exp", "path/model.yaml")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_origin_mode_type_and_extension_errors(self, tmp_path):
"""get_config_origin_mode should validate input types and extensions."""
with pytest.raises(TypeError):
AutoConfig.get_config_origin_mode(123)
bad_file = tmp_path / "not_yaml.txt"
bad_file.write_text("content", encoding="utf-8")
with pytest.raises(ValueError):
AutoConfig.get_config_origin_mode(str(bad_file))

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_origin_mode_invalid_yaml_name(self, monkeypatch):
"""Non-existing yaml names should raise ValueError."""
monkeypatch.setattr(AutoConfig, "invalid_yaml_name", classmethod(lambda cls, _: True))
with pytest.raises(ValueError):
AutoConfig.get_config_origin_mode("unknown_name")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_origin_mode_legacy_flow(self, monkeypatch, tmp_path):
"""Legacy pathway should build configs and update auxiliary fields."""
dummy = DummyMindFormerConfig(use_legacy=True, has_pretrained=True, has_generation=True)
monkeypatch.setattr(auto_cfg, "MindFormerConfig", lambda *_: dummy)
built = {}
monkeypatch.setattr(auto_cfg, "build_model_config",
lambda cfg: built.setdefault("config", cfg) or "legacy")
monkeypatch.setattr(auto_cfg.MindFormerBook, "set_model_config_to_name",
lambda *args, **kwargs: built.setdefault("mark", args))
yaml_file = tmp_path / "model.yaml"
yaml_file.write_text("model: {}", encoding="utf-8")
AutoConfig.get_config_origin_mode(str(yaml_file), hidden_size=128)
assert dummy.model.model_config["hidden_size"] == 128
assert dummy.model.pretrained_model_dir == "pretrained_dir"
assert dummy.model.generation_config == {"gen": True}
assert built["config"] == dummy.model.model_config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_origin_mode_nonlegacy_flow(self, monkeypatch, tmp_path):
"""Non-legacy pathway should use get_model_config without calling builder."""
dummy = DummyMindFormerConfig(use_legacy=False)
monkeypatch.setattr(auto_cfg, "MindFormerConfig", lambda *_: dummy)
marker = {}
monkeypatch.setattr(auto_cfg, "build_model_config",
lambda *_: marker.setdefault("should_not_call", True))
monkeypatch.setattr(auto_cfg, "get_model_config",
lambda model: marker.setdefault("model", model) or "new_config")
yaml_file = tmp_path / "model.yaml"
yaml_file.write_text("model: {}", encoding="utf-8")
AutoConfig.get_config_origin_mode(str(yaml_file), dropout=0.1)
assert dummy.model.model_config["dropout"] == 0.1
assert marker["model"] == dummy.model
assert "should_not_call" not in marker

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_experimental_mode_remote_code(self, monkeypatch):
"""Remote code configs should be loaded via dynamic modules when trusted."""
monkeypatch.setattr(auto_cfg.PretrainedConfig, "get_config_dict",
classmethod(lambda cls, name,
**kwargs: ({"auto_map": {"AutoConfig": "mod.Class"}}, {})))
monkeypatch.setattr(auto_cfg, "resolve_trust_remote_code", lambda trust, *args, **kwargs: True)

class RemoteConfig:
"""Remote config for unit tests."""
@staticmethod
def register_for_auto_class():
"""Register for auto class."""
RemoteConfig.registered = True

@staticmethod
def from_pretrained(name, **kwargs):
"""From pretrained."""
return {"name": name, "kwargs": kwargs}

monkeypatch.setattr(auto_cfg, "get_class_from_dynamic_module", lambda *args, **kwargs: RemoteConfig)
monkeypatch.setattr(auto_cfg.os.path, "isdir", lambda _: True)
result = AutoConfig.get_config_experimental_mode("remote_repo", trust_remote_code=True)
assert result["name"] == "remote_repo"
assert RemoteConfig.registered is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_config_experimental_mode_local_config(self, monkeypatch):
"""Local configs with model_type should resolve via CONFIG_MAPPING."""
class LocalConfig(PretrainedConfig):
"""LocalConfig for tests."""
model_type = "custom_dummy"

@classmethod
def from_dict(cls, config_dict, **kwargs):
return {"config": config_dict, "extra": kwargs}

CONFIG_MAPPING.register("custom_dummy", LocalConfig, exist_ok=True)
monkeypatch.setattr(auto_cfg.PretrainedConfig, "get_config_dict",
classmethod(lambda cls, name,
**kwargs: ({"model_type": "custom_dummy", "value": 1}, {"unused": True})))
result = AutoConfig.get_config_experimental_mode("local_repo")
assert result["config"]["value"] == 1
assert result["extra"]["unused"] is True

+ 86
- 0
tests/st/test_ut/test_models/test_auto/test_utils.py View File

@@ -0,0 +1,86 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test utils.py
"""
from unittest.mock import patch
import pytest

from mindformers.mindformer_book import MindFormerBook
from mindformers.models.auto.utils import get_default_yaml_file, set_default_yaml_file


class TestYamlFileFunctions:
""" A test class for testing utils."""
@pytest.fixture
def mock_trainer_support_list(self):
"""Mock the trainer support task list"""
return {
"text_generation": {
"model1": "/path/to/model1.yaml",
"model2": "/path/to/model2.yaml"
}
}

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_default_yaml_file_found(self, mock_trainer_support_list):
"""Test getting default yaml file when model exists"""
with patch.object(
MindFormerBook, 'get_trainer_support_task_list', return_value=mock_trainer_support_list):
result = get_default_yaml_file("model1")
assert result == "/path/to/model1.yaml"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_default_yaml_file_not_found(self, mock_trainer_support_list):
"""Test getting default yaml file when model doesn't exist"""
with patch.object(
MindFormerBook, 'get_trainer_support_task_list', return_value=mock_trainer_support_list):
result = get_default_yaml_file("nonexistent_model")
assert result == ""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_set_default_yaml_file_success(self, mock_trainer_support_list):
"""Test setting default yaml file when source exists and target doesn't"""
with patch.object(
MindFormerBook, 'get_trainer_support_task_list', return_value=mock_trainer_support_list), \
patch('os.path.exists') as mock_exists, \
patch('os.path.realpath', return_value='/real/path/to/model1.yaml'), \
patch('shutil.copy') as mock_copy, \
patch('mindformers.models.auto.utils.logger') as mock_logger:
mock_exists.side_effect = lambda path: path != '/target/path.yaml'
set_default_yaml_file("model1", "/target/path.yaml")
mock_copy.assert_called_once_with("/path/to/model1.yaml", "/target/path.yaml")
mock_logger.info.assert_called_once_with("default yaml config in %s is used.", "/target/path.yaml")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_set_default_yaml_file_source_not_found(self, mock_trainer_support_list):
"""Test setting default yaml file when source file doesn't exist"""
with (patch.object(
MindFormerBook, 'get_trainer_support_task_list', return_value=mock_trainer_support_list), \
patch('os.path.exists') as mock_exists, patch('os.path.realpath', return_value=''), \
patch('shutil.copy') as mock_copy):
mock_exists.side_effect = lambda path: False
with pytest.raises(
FileNotFoundError, match="default yaml file path must be correct, but get /path/to/model1.yaml"):
set_default_yaml_file("model1", "/target/path.yaml")
mock_copy.assert_not_called()

+ 0
- 0
tests/st/test_ut/test_models/test_glm4/__init__.py View File


+ 47
- 0
tests/st/test_ut/test_models/test_glm4/test_configuration_glm4.py View File

@@ -0,0 +1,47 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for Glm4Config."""
import pytest

from mindformers.models.glm4.configuration_glm4 import Glm4Config


class TestGlm4Config:
"""Validates default behaviors for the Glm4 configuration."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_default_configuration_fields(self):
"""Test that the Glm4Config initializes with expected default values."""
config = Glm4Config()

assert config.vocab_size == 151552
assert config.hidden_size == 4096
assert config.num_hidden_layers == 40
assert config.num_attention_heads == 32
assert config.num_key_value_heads == 2
assert config.position_embedding_type == "partial_rope"
assert config.model_type == "glm4"
assert "layers.*.self_attn.q_proj" in Glm4Config.base_model_tp_plan

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_override_arguments_apply(self):
"""Test that arguments passed to Glm4Config constructor correctly override the defaults."""
config = Glm4Config(vocab_size=10, num_attention_heads=8, eos_token_id=(1,))

assert config.vocab_size == 10
assert config.num_attention_heads == 8
assert config.eos_token_id == (1,)

+ 46
- 0
tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py View File

@@ -0,0 +1,46 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for Glm4 modeling API."""
import os
import pytest

from mindformers.models.glm4.configuration_glm4 import Glm4Config
from mindformers.models.glm4.modeling_glm4 import Glm4ForCausalLM
from mindformers.models.glm4.modeling_glm4_infer import InferenceGlm4ForCausalLM


class TestGlm4ForCausalLM:
"""Ensure Glm4ForCausalLM routes to the proper implementation."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init_model_in_predict_mode(self):
"""When RUN_MODE is unset/predict, the inference model should be instantiated."""
os.environ['RUN_MODE'] = "predict"
config = Glm4Config()

model = Glm4ForCausalLM(config)

assert isinstance(model, InferenceGlm4ForCausalLM)
assert model.config is config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init_model_in_train_mode(self):
"""RUN_MODE=train should raise an explicit NotImplementedError."""
os.environ['RUN_MODE'] = "train"

with pytest.raises(NotImplementedError):
Glm4ForCausalLM(Glm4Config())

+ 0
- 0
tests/st/test_ut/test_models/test_glm4_moe/__init__.py View File


+ 48
- 0
tests/st/test_ut/test_models/test_glm4_moe/test_configuration_glm4_moe.py View File

@@ -0,0 +1,48 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for Glm4MoeConfig."""
import pytest

from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig


class TestGlm4MoeConfig:
"""Tests covering the Glm4Moe configuration helper."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_default_configuration_values(self):
"""Ensure defaults from the spec are propagated to attributes."""
config = Glm4MoeConfig()

assert config.vocab_size == 151552
assert config.hidden_size == 4096
assert config.num_hidden_layers == 46
assert config.num_attention_heads == 96
assert config.moe_intermediate_size == 1408
assert config.num_experts_per_tok == 8
assert config.norm_topk_prob is True
assert config.model_type == "glm4_moe"
assert "layers.*.self_attn.q_proj" in Glm4MoeConfig.base_model_tp_plan

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_rope_scaling_type_key_is_renamed(self):
"""When rope scaling contains 'type', it should be copied to 'rope_type'."""
rope_scaling = {"type": "yarn", "factor": 2.0}
config = Glm4MoeConfig(rope_scaling=rope_scaling)

assert config.rope_scaling["rope_type"] == "yarn"
assert config.rope_scaling["factor"] == 2.0

+ 46
- 0
tests/st/test_ut/test_models/test_glm4_moe/test_modeling_glm4_moe.py View File

@@ -0,0 +1,46 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for Glm4Moe modeling API."""
import os
import pytest

from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
from mindformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM
from mindformers.models.glm4_moe.modeling_glm4_moe_infer import InferenceGlm4MoeForCausalLM


class TestGlm4MoeForCausalLM:
"""Ensure Glm4MoeForCausalLM routes to the proper implementation."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init_model_in_predict_mode(self):
"""When RUN_MODE is unset/predict, the inference model should be instantiated."""
os.environ['RUN_MODE'] = "predict"
config = Glm4MoeConfig()

model = Glm4MoeForCausalLM(config)

assert isinstance(model, InferenceGlm4MoeForCausalLM)
assert model.config is config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init_model_in_train_mode(self):
"""RUN_MODE=train should raise an explicit NotImplementedError."""
os.environ['RUN_MODE'] = "train"

with pytest.raises(NotImplementedError):
Glm4MoeForCausalLM(Glm4MoeConfig())

+ 1216
- 0
tests/st/test_ut/test_models/test_modeling_utils.py View File

@@ -0,0 +1,1216 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ModelingUtils test cases"""

from unittest.mock import patch, MagicMock, DEFAULT
import os
import json
import shutil
import tempfile
import pytest
import mindspore as ms
from mindformers.models.configuration_utils import PretrainedConfig
from mindformers.models.modeling_utils import dtype_byte_size, save_checkpoint, shard_checkpoint, \
load_sharded_checkpoint, _add_variant, PreTrainedModel


# pylint: disable=W0212
class TestModelingUtils:
"""Test cases for modeling utilities functions"""

def setup_method(self):
"""Set up test environment"""
# Create mock parameters with different sizes
self.mock_param1 = MagicMock()
self.mock_param1.numel.return_value = 1000000 # 1M elements
self.mock_param1.dtype = ms.float32 # 4 bytes per element = ~4MB

self.mock_param2 = MagicMock()
self.mock_param2.numel.return_value = 2000000 # 2M elements
self.mock_param2.dtype = ms.float32 # 4 bytes per element = ~8MB

self.mock_param3 = MagicMock()
self.mock_param3.numel.return_value = 500000 # 0.5M elements
self.mock_param3.dtype = ms.float32 # 4 bytes per element = ~2MB

# Create state dict
self.state_dict = {
"param1": self.mock_param1,
"param2": self.mock_param2,
"param3": self.mock_param3
}

# Create a mock config object
self.mock_config = MagicMock()
self.mock_config.parallel_config = MagicMock()
self.mock_config.parallel_config.pipeline_stage = 0
self.mock_config.pp_interleave_num = 1

# Create a mock model instance
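# PreTrainedModel.__new__ bypasses __init__, so no real network or weights are built for these tests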
self.model = PreTrainedModel.__new__(PreTrainedModel)
self.model.config = self.mock_config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.convert_file_size_to_int')
@patch('mindformers.models.modeling_utils.dtype_byte_size')
def test_shard_checkpoint_single_shard(self, mock_dtype_byte_size, mock_convert_file_size_to_int):
"""Test shard_checkpoint with single shard (all weights fit in one shard)"""
# Mock byte size to be small enough for all weights to fit in one shard
mock_dtype_byte_size.return_value = 1e-6 # negligible bytes per element, so every weight fits in one shard
mock_convert_file_size_to_int.return_value = 10000000 # 10MB limit

shards, index = shard_checkpoint(self.state_dict, "10MB")

# Should have only one shard
assert len(shards) == 1
assert "mindspore_model.ckpt" in shards

# Index should be None for single shard
assert index is None

# All parameters should be in the shard
shard_content = shards["mindspore_model.ckpt"]
assert "param1" in shard_content
assert "param2" in shard_content
assert "param3" in shard_content

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.convert_file_size_to_int')
@patch('mindformers.models.modeling_utils.dtype_byte_size')
def test_shard_checkpoint_multiple_shards(self, mock_dtype_byte_size, mock_convert_file_size_to_int):
"""Test shard_checkpoint with multiple shards"""
# Mock byte size to create multiple shards
mock_dtype_byte_size.return_value = 1 # 1 byte per element
mock_convert_file_size_to_int.return_value = 1500000 # 1.5MB limit

shards, index = shard_checkpoint(self.state_dict, "1.5MB")

# Should have multiple shards
assert len(shards) > 1

# Index should not be None
assert index is not None
assert "metadata" in index
assert "weight_map" in index

# Check that all parameters are in the weight map
weight_map = index["weight_map"]
assert "param1" in weight_map
assert "param2" in weight_map
assert "param3" in weight_map

# Check that parameters are distributed across shards
shard_files = list(shards.keys())
assert len(shard_files) > 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.convert_file_size_to_int')
@patch('mindformers.models.modeling_utils.dtype_byte_size')
def test_shard_checkpoint_large_weight(self, mock_dtype_byte_size, mock_convert_file_size_to_int):
"""Test shard_checkpoint with a weight larger than max_shard_size"""
# Mock a very large parameter
mock_large_param = MagicMock()
mock_large_param.numel.return_value = 10000000 # 10M elements
mock_large_param.dtype = ms.float32 # 4 bytes per element = ~40MB

large_state_dict = {
"small_param": self.mock_param1,
"large_param": mock_large_param
}

# Set limit to be smaller than the large parameter
mock_dtype_byte_size.return_value = 1 # 1 byte per element
mock_convert_file_size_to_int.return_value = 5000000 # 5MB limit

shards, _ = shard_checkpoint(large_state_dict, "5MB")

# Should have multiple shards since large param exceeds limit
assert len(shards) > 1

# Large parameter should be in its own shard
large_param_shard = None
for _, shard_content in shards.items():
if "large_param" in shard_content:
large_param_shard = shard_content
break

assert large_param_shard is not None
assert "large_param" in large_param_shard

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_variant_no_variant(self):
"""Test _add_variant function with no variant"""
weights_name = "mindspore_model.ckpt"
result = _add_variant(weights_name, None)
assert result == weights_name

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_variant_with_variant(self):
"""Test _add_variant function with variant"""
weights_name = "mindspore_model.ckpt"
variant = "fp16"
expected = "mindspore_model.fp16.ckpt"
result = _add_variant(weights_name, variant)
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_variant_complex_name(self):
"""Test _add_variant function with complex file name"""
weights_name = "model.custom.extension.ckpt"
variant = "quantized"
expected = "model.custom.extension.quantized.ckpt"
result = _add_variant(weights_name, variant)
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_not_semi_auto_parallel_mode(self, mock_get_auto_parallel_context):
"""Test when parallel mode is not semi_auto_parallel, should not raise any exception"""
# Set parallel mode to standalone
mock_get_auto_parallel_context.return_value = "stand_alone"
self.mock_config.parallel_config.pipeline_stage = 2 # pp > 1

# Should not raise any exception
try:
self.model.check_pipeline_stage()
except Exception as e:
pytest.fail(f"check_pipeline_stage() raised an exception unexpectedly: {e}")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_pipeline_stage_less_than_2(self, mock_get_auto_parallel_context):
"""Test when pipeline_stage <= 1, should not raise any exception"""
# Set parallel mode to semi_auto_parallel
mock_get_auto_parallel_context.return_value = "semi_auto_parallel"
self.mock_config.parallel_config.pipeline_stage = 1 # pp <= 1

# Should not raise any exception
try:
self.model.check_pipeline_stage()
except Exception as e:
pytest.fail(f"check_pipeline_stage() raised an exception unexpectedly: {e}")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_missing_num_layers_attribute(self, mock_get_auto_parallel_context):
"""Test when pipeline_stage > 1 but num_layers is not found, should raise ValueError"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.return_value = "semi_auto_parallel"
self.mock_config.parallel_config.pipeline_stage = 2

# Remove num_layers and num_hidden_layers attributes
if hasattr(self.mock_config, 'num_layers'):
delattr(self.mock_config, 'num_layers')
if hasattr(self.mock_config, 'num_hidden_layers'):
delattr(self.mock_config, 'num_hidden_layers')

# Should raise ValueError
with pytest.raises(ValueError) as context:
self.model.check_pipeline_stage()

assert "is not found when pipeline_stage > 1" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_num_layers_less_than_pipeline_stage(self, mock_get_auto_parallel_context):
"""Test when num_layers < pipeline_stage, should raise ValueError"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.return_value = "semi_auto_parallel"
self.mock_config.parallel_config.pipeline_stage = 4 # pp = 4
self.mock_config.num_layers = 3 # num_layers < pp

# Should raise ValueError
with pytest.raises(ValueError) as context:
self.model.check_pipeline_stage()

assert "num_layers (3) < pp(4)" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_valid_num_layers_and_pipeline_stage(self, mock_get_auto_parallel_context):
"""Test when num_layers >= pipeline_stage, should not raise any exception"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.return_value = "semi_auto_parallel"
self.mock_config.parallel_config.pipeline_stage = 3 # pp = 3
self.mock_config.num_layers = 6 # num_layers > pp

# Should not raise any exception
try:
self.model.check_pipeline_stage()
except Exception as e:
pytest.fail(f"check_pipeline_stage() raised an exception unexpectedly: {e}")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_pipeline_interleave_valid_case(self, mock_get_auto_parallel_context):
"""Test when pipeline interleave is enabled and valid configuration, should not raise any exception"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.side_effect = [
"semi_auto_parallel", # First call for parallel_mode
True # Second call for pipeline_interleave
]
self.mock_config.parallel_config.pipeline_stage = 2 # pp = 2
self.mock_config.num_layers = 6 # num_layers = 6
self.mock_config.pp_interleave_num = 2 # pp_interleave_num = 2
# pp * pp_interleave_num = 4, which is < num_layers (6) - should be valid

# Should not raise any exception
try:
self.model.check_pipeline_stage()
except Exception as e:
pytest.fail(f"check_pipeline_stage() raised an exception unexpectedly: {e}")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_pipeline_interleave_invalid_case(self, mock_get_auto_parallel_context):
"""Test when pipeline interleave is enabled but pp * pp_interleave_num > num_layers, should raise ValueError"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.side_effect = [
"semi_auto_parallel", # First call for parallel_mode
True # Second call for pipeline_interleave
]
self.mock_config.parallel_config.pipeline_stage = 3 # pp = 3
self.mock_config.num_layers = 5 # num_layers = 5
self.mock_config.pp_interleave_num = 2 # pp_interleave_num = 2
# pp * pp_interleave_num = 6, which is > num_layers (5) - should be invalid

# Should raise ValueError
with pytest.raises(ValueError) as context:
self.model.check_pipeline_stage()

assert "num_layers : 5 and pp * pp_interleave_num = 6" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.get_auto_parallel_context')
def test_num_hidden_layers_fallback(self, mock_get_auto_parallel_context):
"""Test when num_layers is not present but num_hidden_layers is used as fallback"""
# Set parallel mode to semi_auto_parallel and pipeline_stage > 1
mock_get_auto_parallel_context.return_value = "semi_auto_parallel"
self.mock_config.parallel_config.pipeline_stage = 3 # pp = 3

# Remove num_layers but keep num_hidden_layers
if hasattr(self.mock_config, 'num_layers'):
delattr(self.mock_config, 'num_layers')
self.mock_config.num_hidden_layers = 2 # num_hidden_layers < pp - should be invalid

# Should raise ValueError
with pytest.raises(ValueError) as context:
self.model.check_pipeline_stage()

assert "num_layers (2) < pp(3)" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch("mindformers.tools.PushToHubMixin._get_files_timestamps")
@patch("mindformers.tools.PushToHubMixin._create_repo")
@patch("mindformers.tools.PushToHubMixin._upload_modified_files")
def test_save_pretrained_in_json(self, mock_upload_modified_files,
mock_create_repo,
mock_get_files_timestamps):
"""Test save_pretrained with save_json=True and push_to_hub enabled"""
mock_get_files_timestamps.return_value = {"test": "test"}
mock_create_repo.return_value = "test"
mock_upload_modified_files.return_value = "test"

with tempfile.TemporaryDirectory() as temp_dir:
with patch("mindformers.models.modeling_utils.shard_checkpoint") as mock_shard_checkpoint:
mock_shard_checkpoint.return_value = {"test": {"test": "test"}}, None
self.model.save_pretrained(save_directory=temp_dir, save_json=True,
token="test", state_dict={"test": "test"}, push_to_hub=True)


# pylint: disable=W0212
class TestLoadShardedCheckpoint:
"""Test cases for load_sharded_checkpoint function"""

def setup_method(self):
"""Set up test environment"""
self.temp_dir = tempfile.mkdtemp()

# Create mock model
self.mock_model = MagicMock()

# Create sample index file content
self.index_content = {
"metadata": {"total_size": 1000000},
"weight_map": {
"param1": "mindspore_model-00001-of-00002.ckpt",
"param2": "mindspore_model-00001-of-00002.ckpt",
"param3": "mindspore_model-00002-of-00002.ckpt"
}
}

# Create index file
self.index_file = os.path.join(self.temp_dir, "mindspore_model.ckpt.index.json")
with open(self.index_file, 'w', encoding="utf-8") as f:
json.dump(self.index_content, f)

# Create mock checkpoint files
self.shard1_file = os.path.join(self.temp_dir, "mindspore_model-00001-of-00002.ckpt")
self.shard2_file = os.path.join(self.temp_dir, "mindspore_model-00002-of-00002.ckpt")

# Create empty files for shards
with open(self.shard1_file, 'w', encoding="utf-8") as f:
f.write("")
with open(self.shard2_file, 'w', encoding="utf-8") as f:
f.write("")

def teardown_method(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_checkpoint')
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_sharded_checkpoint(self, mock_load_param_into_net, mock_load_checkpoint):
"""Test load_sharded_checkpoint function"""

# Mock load_checkpoint to return different state dicts for different files
def mock_load_checkpoint_side_effect(file_path):
if "00001" in file_path:
return {"param1": "value1", "param2": "value2"}
if "00002" in file_path:
return {"param3": "value3"}
return {}

mock_load_checkpoint.side_effect = mock_load_checkpoint_side_effect

# Mock load_param_into_net to return empty lists (no missing/unexpected keys)
mock_load_param_into_net.return_value = ([], [])

# Call the function
result = load_sharded_checkpoint(self.mock_model, self.temp_dir)

# Verify load_checkpoint was called for each shard
mock_load_checkpoint.assert_any_call(self.shard1_file)
mock_load_checkpoint.assert_any_call(self.shard2_file)

# Verify load_param_into_net was called with combined state dict
mock_load_param_into_net.assert_called_once()

# Result should be the return value from load_param_into_net
assert result == ([], [])

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_checkpoint')
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_sharded_checkpoint_strict_false(self, mock_load_param_into_net, mock_load_checkpoint):
"""Test load_sharded_checkpoint function with strict=False"""
# Mock load_checkpoint
mock_load_checkpoint.return_value = {"param1": "value1"}

# Mock load_param_into_net
mock_load_param_into_net.return_value = (["missing_key"], ["unexpected_key"])

# Call the function with strict=False
result = load_sharded_checkpoint(self.mock_model, self.temp_dir, strict=False)

# Verify load_param_into_net was called with strict_load=False
mock_load_param_into_net.assert_called_once()
call_args = mock_load_param_into_net.call_args
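# call_args is an (args, kwargs) pair; index 1 holds the keyword arguments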
assert call_args[1]['strict_load'] is False

# Check result
assert result == (["missing_key"], ["unexpected_key"])

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.join')
def test_load_sharded_checkpoint_invalid_folder(self, mock_join):
"""Test load_sharded_checkpoint with invalid folder"""
# Mock join to return a path that doesn't exist
mock_join.return_value = "/non/existent/path/index.json"

# Should raise FileNotFoundError when trying to open non-existent index file
with pytest.raises(FileNotFoundError):
load_sharded_checkpoint(self.mock_model, "/non/existent/path")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_byte_size_bool(self):
"""Test dtype_byte_size function with boolean type"""
# Test boolean type which returns 1/8 byte
result = dtype_byte_size(ms.bool_)
assert result == 1/8

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_byte_size_standard_types(self):
"""Test dtype_byte_size function with standard numeric types"""
# Test float32 type (32 bits = 4 bytes)
result = dtype_byte_size(ms.float32)
assert result == 4

# Test int32 type (32 bits = 4 bytes)
result = dtype_byte_size(ms.int32)
assert result == 4

# Test float16 type (16 bits = 2 bytes)
result = dtype_byte_size(ms.float16)
assert result == 2

# Test int8 type (8 bits = 1 byte)
result = dtype_byte_size(ms.int8)
assert result == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_byte_size_invalid_type(self):
"""Test dtype_byte_size function with invalid dtype"""
# Create an invalid dtype string
invalid_dtype = "invalid_dtype_123"

# Patch re.search so the dtype string never matches, forcing the invalid-dtype branch
with patch('mindformers.models.modeling_utils.re.search') as mock_search:
mock_search.return_value = None
with pytest.raises(ValueError) as context:
dtype_byte_size(invalid_dtype)
assert "is not a valid dtype" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_byte_size_edge_cases(self):
"""Test dtype_byte_size function edge cases"""
# Test with 64-bit type
with patch('mindformers.models.modeling_utils.re.search') as mock_search:
# Mock the regex search to return 64 bits
mock_match = MagicMock()
mock_match.groups.return_value = ['64']
mock_search.return_value = mock_match

# Test with a 64-bit type
result = dtype_byte_size(ms.float64)
assert result == 8 # 64 bits / 8 = 8 bytes

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.save_checkpoint')
def test_save_checkpoint(self, mock_save_checkpoint):
"""Test save_checkpoint function"""
# Create a temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
# Create a mock save object
mock_save_obj = MagicMock()

# Define checkpoint file name
ckpt_file_name = os.path.join(temp_dir, "mindspore_model.ckpt")

# Call the function
save_checkpoint(mock_save_obj, temp_dir)

# Verify save_checkpoint was called with correct arguments
mock_save_checkpoint.assert_called_once_with(mock_save_obj, ckpt_file_name)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.ms.save_checkpoint')
def test_save_checkpoint_with_custom_weights_name(self, mock_save_checkpoint):
"""Test save_checkpoint function with custom weights name"""
# Patch WEIGHTS_NAME to test custom name
with patch('mindformers.models.modeling_utils.WEIGHTS_NAME', 'custom_model.ckpt'):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a mock save object
mock_save_obj = MagicMock()

# Define expected checkpoint file name
ckpt_file_name = os.path.join(temp_dir, "custom_model.ckpt")

# Call the function
save_checkpoint(mock_save_obj, temp_dir)

# Verify save_checkpoint was called with correct arguments
mock_save_checkpoint.assert_called_once_with(mock_save_obj, ckpt_file_name)


# pylint: disable=W0212
class TestPreTrainedModelMethods:
"""Test cases for PreTrainedModel methods"""

def setup_method(self):
"""Set up test environment"""
# Create a mock config
self.mock_config = MagicMock()
self.mock_config.name_or_path = "test_model"

# Create a mock model instance
self.model = PreTrainedModel.__new__(PreTrainedModel)
self.model.config = self.mock_config
self.model.base_model_prefix = ""
self.model._keys_to_ignore_on_save = None
self.model._auto_class = None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_can_generate_default_behavior(self):
"""Test can_generate method with default behavior"""
result = PreTrainedModel.can_generate()
# Calling directly on PreTrainedModel should return True because its
# generation methods are not inherited unchanged from GenerationMixin
assert result is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_origin_mode_default_values(self):
"""Test save_pretrained_origin_mode with default values"""
with tempfile.TemporaryDirectory() as temp_dir, \
patch.multiple('mindformers.models.modeling_utils',
DEFAULT_CHECKPOINT_SAVE_FOLDER=temp_dir,
ms=DEFAULT,
yaml=DEFAULT) as mocks:
# Mock required methods
self.model._inverse_parse_config = MagicMock(return_value=(self.mock_config, []))
self.model._wrap_config = MagicMock(return_value={"model": {}})
self.model.remove_type = MagicMock()
mocks['yaml'].dump.return_value = "yaml_dump_result"

# Call the method
self.model.save_pretrained_origin_mode()

# Verify yaml.dump was called
mocks['yaml'].dump.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_origin_mode_custom_directory(self):
"""Test save_pretrained_origin_mode with custom directory"""
with tempfile.TemporaryDirectory() as temp_dir:
custom_dir = os.path.join(temp_dir, "custom_save_dir")

with patch.multiple('mindformers.models.modeling_utils',
ms=DEFAULT,
yaml=DEFAULT) as mocks:
# Mock required methods
self.model._inverse_parse_config = MagicMock(return_value=(self.mock_config, []))
self.model._wrap_config = MagicMock(return_value={"model": {}})
self.model.remove_type = MagicMock()
mocks['yaml'].dump.return_value = "yaml_dump_result"

# Call the method
self.model.save_pretrained_origin_mode(save_directory=custom_dir)

# Verify directory was created
assert os.path.exists(custom_dir)

# Verify yaml.dump was called
mocks['yaml'].dump.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_origin_mode_invalid_types(self):
"""Test save_pretrained_origin_mode with invalid parameter types"""
with pytest.raises(TypeError) as context:
self.model.save_pretrained_origin_mode(save_directory=123, save_name="test")
assert "save_directory and save_name should be a str" in str(context.value)

with pytest.raises(TypeError) as context:
self.model.save_pretrained_origin_mode(save_directory="/tmp", save_name=123)
assert "save_directory and save_name should be a str" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_origin_mode_no_config(self):
"""Test save_pretrained_origin_mode when model has no config"""
self.model.config = None
with tempfile.TemporaryDirectory() as temp_dir:
with pytest.raises(AttributeError) as context:
self.model.save_pretrained_origin_mode(save_directory=temp_dir)
assert "has no attribute" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_experimental_mode_with_file_path(self):
"""Test save_pretrained_experimental_mode with file path instead of directory"""
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, "not_a_dir.txt")
# Create a file
with open(file_path, 'w', encoding='utf-8') as f:
f.write("test")

with patch('mindformers.models.modeling_utils.logger') as mock_logger:
self.model.save_pretrained_experimental_mode(save_directory=file_path)
# Should log an error
mock_logger.error.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_remove_type_method(self):
"""Test remove_type method"""
# Test with PretrainedConfig
mock_config = PretrainedConfig()
mock_config.__dict__ = {"type": "test_type", "other_attr": "value"}

self.model.remove_type(mock_config)

# Type should be removed
assert "type" not in mock_config.__dict__

# Other attributes should remain
assert "other_attr" in mock_config.__dict__

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_inverse_parse_config_method(self):
"""Test _inverse_parse_config method"""

# Create a config with various types of attributes
config = PretrainedConfig()
config.test_attr = "test_value"
config.test_int = 42
config.test_float = 3.14
config.test_bool = True

# Call the method
result_config, _ = self.model._inverse_parse_config(config)

# Check that type was added
assert "type" in result_config.__dict__
assert result_config.__dict__["type"] == "PretrainedConfig"

# Check that basic types are preserved
assert "test_attr" in result_config.__dict__
assert "test_int" in result_config.__dict__
assert "test_float" in result_config.__dict__
assert "test_bool" in result_config.__dict__

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_wrap_config_method(self):
"""Test _wrap_config method"""

# Create a config
config = PretrainedConfig()
config.test_attr = "test_value"

# Mock to_dict method
config.to_dict = MagicMock(return_value={"test_attr": "test_value"})

# Call the method
result = self.model._wrap_config(config)

# Check the structure
assert "model" in result
assert "model_config" in result["model"]
assert "arch" in result["model"]
assert "type" in result["model"]["arch"]
assert result["model"]["arch"]["type"] == self.model.__class__.__name__


# pylint: disable=W0212
class TestPreTrainedModelLoading:
"""Test cases for PreTrainedModel loading methods"""

def setup_method(self):
"""Set up test environment"""
self.temp_dir = tempfile.mkdtemp()
self.model_class = PreTrainedModel

# Mock config
self.mock_config = MagicMock()
self.mock_config.name_or_path = "test_model"

# Mock model
self.mock_model = MagicMock()
self.mock_model.config = self.mock_config

def teardown_method(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
def test_get_config_args_nonexistent_model(self, mock_exists):
"""Test _get_config_args with nonexistent model"""
mock_exists.return_value = False
self.model_class._support_list = ['supported_model']

with pytest.raises(ValueError) as context:
self.model_class._get_config_args('unsupported_model')

assert "does not exist" in str(context.value)
assert "not supported" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
def test_get_config_args_file_instead_of_directory(self, mock_isdir, mock_exists):
"""Test _get_config_args with file path instead of directory"""
mock_exists.return_value = True
mock_isdir.return_value = False

with pytest.raises(ValueError) as context:
self.model_class._get_config_args('/path/to/file')

assert "is not a directory" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.listdir')
def test_get_config_args_missing_files(self, mock_listdir, mock_isdir, mock_exists):
"""Test _get_config_args with missing yaml or ckpt files"""
mock_exists.return_value = True
mock_isdir.return_value = True
mock_listdir.return_value = ['some_other_file.txt'] # No yaml or ckpt files

with pytest.raises(FileNotFoundError) as context:
self.model_class._get_config_args('/path/to/model_dir')

assert "no yaml file for model config" in str(context.value)

@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.listdir')
@patch('mindformers.models.modeling_utils.MindFormerConfig')
def test_get_config_args_local_directory(self, mock_config, mock_listdir, mock_isdir, mock_exists):
"""Test _get_config_args with local directory containing model files"""
mock_exists.return_value = True
mock_isdir.return_value = True
mock_listdir.return_value = ['config.yaml', 'model.ckpt']
mock_config.return_value = MagicMock()

# Mock model type indices
self.model_class._model_type = 0

self.model_class._get_config_args('/path/to/model_dir')

# Verify MindFormerConfig was called with correct yaml file
mock_config.assert_called_once()

@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.MindFormerConfig')
def test_get_config_args_not_is_dir(self, mock_config, mock_isdir, mock_exists):
"""Test _get_config_args with local directory containing model files"""
mock_exists.return_value = False
mock_isdir.return_value = False
mock_config.return_value = MagicMock()

# Mock model type indices
self.model_class._model_type = 0

self.model_class._support_list = ['common']
with pytest.raises(FileNotFoundError) as context:
self.model_class._get_config_args('common')
assert "default yaml file path must be correct" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_experimental_mode_with_file(self):
"""Test is_experimental_mode with file path instead of directory"""
with patch.multiple('mindformers.models.modeling_utils.os.path',
exists=MagicMock(return_value=True),
isdir=MagicMock(return_value=False)):
result = self.model_class.is_experimental_mode('/path/to/file')
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.listdir')
def test_is_experimental_mode_with_config_json(self, mock_listdir, mock_isdir, mock_exists):
"""Test is_experimental_mode with config.json file present"""
mock_exists.return_value = True
mock_isdir.return_value = True
mock_listdir.return_value = ['config.json'] # config.json but no yaml files

result = self.model_class.is_experimental_mode('/path/to/model_dir')
assert result is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.listdir')
def test_is_experimental_mode_with_yaml_files(self, mock_listdir, mock_isdir, mock_exists):
"""Test is_experimental_mode with yaml files present"""
mock_exists.return_value = True
mock_isdir.return_value = True
mock_listdir.return_value = ['config.yaml', 'model.ckpt']

result = self.model_class.is_experimental_mode('/path/to/model_dir')
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_experimental_mode_huggingface_style(self):
"""Test is_experimental_mode with HuggingFace-style model path"""
result = self.model_class.is_experimental_mode('bert-base-uncased')
assert result is False

result = self.model_class.is_experimental_mode('mindspore/bert-base')
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_from_pretrained_invalid_type(self):
"""Test from_pretrained with invalid pretrained_model_name_or_dir type"""
with pytest.raises(TypeError) as context:
self.model_class.from_pretrained(123)

assert "pretrained_model_name_or_dir should be a str" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.PreTrainedModel.is_experimental_mode')
@patch('mindformers.models.modeling_utils.PreTrainedModel.from_pretrained_experimental_mode')
def test_from_pretrained_experimental_mode(self, mock_experimental, mock_is_experimental):
"""Test from_pretrained routes to experimental mode"""
mock_is_experimental.return_value = True
mock_experimental.return_value = self.mock_model

result = self.model_class.from_pretrained('test_model')

mock_experimental.assert_called_once()
assert result == self.mock_model

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.PreTrainedModel.is_experimental_mode')
@patch('mindformers.models.modeling_utils.PreTrainedModel.from_pretrained_origin_mode')
def test_from_pretrained_origin_mode(self, mock_origin, mock_is_experimental):
"""Test from_pretrained routes to origin mode"""
mock_is_experimental.return_value = False
mock_origin.return_value = self.mock_model

result = self.model_class.from_pretrained('test_model')

mock_origin.assert_called_once()
assert result == self.mock_model

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.PreTrainedModel._get_config_args')
@patch('mindformers.models.modeling_utils.build_network')
def test_from_pretrained_origin_mode_success(self, mock_build_network, mock_get_config_args):
"""Test from_pretrained_origin_mode success case"""
mock_config_args = MagicMock()
mock_config_args.model = MagicMock()
mock_config_args.model.model_config = MagicMock()
mock_config_args.get.return_value = None
mock_get_config_args.return_value = mock_config_args
mock_build_network.return_value = self.mock_model

result = self.model_class.from_pretrained_origin_mode('test_model')

mock_get_config_args.assert_called_once()
mock_build_network.assert_called_once()
assert result == self.mock_model

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_from_pretrained_origin_mode_invalid_type(self):
"""Test from_pretrained_origin_mode with invalid type"""
with pytest.raises(TypeError) as context:
self.model_class.from_pretrained_origin_mode(123)

assert "pretrained_model_name_or_dir should be a str" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_pretrained_model_success(self, mock_load_param):
"""Test _load_pretrained_model success case"""
mock_model = MagicMock()
mock_model.get_parameters.return_value = []
mock_model.base_model_prefix = ""
mock_model.config = MagicMock()
mock_model.config.architectures = None

mock_load_param.return_value = ([], []) # missing_keys, unexpected_keys

result_model, missing_keys, unexpected_keys, mismatched_keys = self.model_class._load_pretrained_model(
mock_model,
{}, # state_dict
None, # resolved_archive_file
"test_model"
)

assert result_model == mock_model
assert not missing_keys
assert not unexpected_keys
assert not mismatched_keys

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_pretrained_model_state_dict_none(self, mock_load_param):
"""Test load_pretrained_model_state_dict_none_case"""
mock_model = MagicMock()
mock_model.get_parameters.return_value = []
mock_model.base_model_prefix = ""
mock_model.config = MagicMock()
mock_model.config.architectures = None

mock_load_param.return_value = ([], []) # missing_keys, unexpected_keys

with pytest.raises(ValueError) as context:
self.model_class._load_pretrained_model(
mock_model,
None, # state_dict
None, # resolved_archive_file
"test_model"
)
assert "should be str, list or tuple" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_pretrained_model_resolved_archive_file_list(self, mock_load_param):
"""Test load_pretrained_model_resolved_archive_file_list_case"""
mock_model = MagicMock()
mock_model.get_parameters.return_value = []
mock_model.base_model_prefix = ""
mock_model.config = MagicMock()
mock_model.config.architectures = None

mock_load_param.return_value = ([], []) # missing_keys, unexpected_keys

with pytest.raises(ValueError) as context:
self.model_class._load_pretrained_model(
mock_model,
None, # state_dict
["test"], # resolved_archive_file
"test_model"
)
assert "resolved_archive_file_:test not found!" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.load_param_into_net')
def test_load_pretrained_model_resolved_archive_file_str(self, mock_load_param):
"""Test load_pretrained_model_resolved_archive_file_str_case"""
mock_model = MagicMock()
mock_model.get_parameters.return_value = []
mock_model.base_model_prefix = ""
mock_model.config = MagicMock()
mock_model.config.architectures = None

mock_load_param.return_value = ([], []) # missing_keys, unexpected_keys

with pytest.raises(ValueError) as context:
self.model_class._load_pretrained_model(
mock_model,
None, # state_dict
"test", # resolved_archive_file
"test_model"
)
assert "resolved_archive_file:test not found!" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
def test_get_src_checkpoint_with_string_path(self, mock_exists):
"""Test _get_src_checkpoint with string path"""
mock_exists.return_value = True

with patch('mindformers.models.modeling_utils.make_soft_link') as mock_link:
result = self.model_class._get_src_checkpoint(
state_dict=None,
resolved_archive_file='/path/to/checkpoint.ckpt',
src_checkpoint='/path/to/src_checkpoint.ckpt'
)

mock_link.assert_called_once()
assert result == '/path/to/src_checkpoint.ckpt'

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.exists')
def test_get_src_checkpoint_with_list_path(self, mock_exists):
"""Test _get_src_checkpoint with list path"""
mock_exists.return_value = True

with pytest.raises(ValueError) as context:
self.model_class._get_src_checkpoint(
state_dict=None,
resolved_archive_file=['/path/to/checkpoint.ckpt'],
src_checkpoint='/path/to/src_checkpoint.ckpt'
)

assert "Failed to read the checkpoint file /path/to/checkpoint.ckpt" in str(context.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_register_for_auto_class_invalid_class(self):
"""Test register_for_auto_class with invalid auto class"""
with patch('mindformers.models.modeling_utils.auto_module') as mock_auto:
mock_auto.InvalidClass = None
self.model_class.register_for_auto_class('InvalidClass')


# pylint: disable=W0212
class TestFromPretrainedExperimentalMode:
"""Test cases for from_pretrained_experimental_mode method"""

def setup_method(self):
"""Set up test environment"""

# Create a mock model class
class MockModel(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config

self.mock_model_class = MockModel

# Create a mock config
self.mock_config = MagicMock(spec=PretrainedConfig)
self.mock_config.name_or_path = "test_model"
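# Assigning __class__ makes type checks on the MagicMock resolve to PretrainedConfig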
self.mock_config.__class__ = PretrainedConfig

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.cached_file')
def test_commit_hash_extraction_from_config(self, mock_cached_file):
"""Test commit hash extraction when config doesn't have _commit_hash"""
mock_cached_file.return_value = "/path/to/config.json"

with patch.multiple('mindformers.models.modeling_utils',
GenerationConfig=DEFAULT,
extract_commit_hash=DEFAULT) as mocks:
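# DEFAULT makes patch.multiple substitute each named attribute with an auto-created MagicMock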
mocks['extract_commit_hash'].return_value = "abc123"

with patch.multiple(self.mock_model_class,
config_class=DEFAULT,
_load_pretrained_model=DEFAULT) as model_mocks:
model_mocks['config_class'].from_pretrained.return_value = (self.mock_config, {})
model_mocks['_load_pretrained_model'].return_value = (MagicMock(), [], [], [])

# Call the method with config that doesn't have _commit_hash
self.mock_model_class.from_pretrained_experimental_mode("test_model")

# Verify extract_commit_hash was called
mocks['extract_commit_hash'].assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.path.isfile')
def test_local_directory_with_weights_file(self, mock_isfile, mock_isdir):
"""Test handling of local directory with weights file"""
mock_isdir.return_value = True
mock_isfile.return_value = True

with patch.multiple('mindformers.models.modeling_utils',
GenerationConfig=DEFAULT,
logger=DEFAULT) as mocks:
with patch.multiple(self.mock_model_class,
config_class=DEFAULT,
_load_pretrained_model=DEFAULT) as model_mocks:
model_mocks['config_class'].from_pretrained.return_value = (self.mock_config, {})
model_mocks['_load_pretrained_model'].return_value = (MagicMock(), [], [], [])

# Call the method with local directory path
self.mock_model_class.from_pretrained_experimental_mode("/local/model/path")

# Verify logger was called with loading info
mocks['logger'].info.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.models.modeling_utils.os.path.isdir')
@patch('mindformers.models.modeling_utils.os.path.isfile')
@patch('mindformers.models.modeling_utils.cached_file')
def test_missing_weights_file_error(self, mock_cached_file, mock_isfile, mock_isdir):
"""Test error handling when weights file is missing"""
mock_isdir.return_value = True
mock_isfile.return_value = False # No weights file found
mock_cached_file.return_value = None # No cached file either

with patch.object(self.mock_model_class, 'config_class') as mock_config_class:
mock_config_class.from_pretrained.return_value = (self.mock_config, {})

with patch('mindformers.models.modeling_utils.has_file') as mock_has_file:
mock_has_file.return_value = False

# Should raise EnvironmentError when no weights file is found
with pytest.raises(EnvironmentError):
self.mock_model_class.from_pretrained_experimental_mode("/local/model/path")

+ 302
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization/test_base_config.py View File

@@ -0,0 +1,302 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for base_config.py"""

from typing import List, Any, Optional
import pytest
import mindspore
from mindspore import nn, Tensor

from mindformers.parallel_core.inference.quantization.base_config import (
QuantizeMethodBase,
QuantizationConfig,
)


class ConcreteQuantizeMethod(QuantizeMethodBase):
"""Concrete implementation of QuantizeMethodBase for testing."""

def __init__(self):
super().__init__()
self.weights_created = False

def create_weights(self, layer: nn.Cell, *_weight_args, **_extra_weight_attrs):
"""Create weights for a layer."""
self.weights_created = True
_ = layer

def apply(self, layer: nn.Cell, *_args, **_kwargs) -> Tensor:
"""Apply the weights in layer to the input tensor."""
_ = layer
if not self.weights_created:
raise RuntimeError("Weights must be created before applying")
return Tensor([1.0])


class ConcreteQuantizationConfig(QuantizationConfig):
"""Concrete implementation of QuantizationConfig for testing."""

def __init__(self, name: str = "test_quant"):
super().__init__()
self._name = name

def get_name(self) -> str:
"""Name of the quantization method."""
return self._name

def get_supported_act_dtypes(self) -> List[str]:
"""List of supported activation dtypes."""
return ["float16", "float32"]

@classmethod
def get_min_capability(cls) -> int:
"""Minimum capability to support the quantization method."""
return 70

@staticmethod
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
return ["quantization_config.json"]

@classmethod
def from_config(cls, config: dict[str, Any]) -> "ConcreteQuantizationConfig":
"""Create a config class from the model's quantization config."""
name = config.get("quantization_type", "test_quant")
return cls(name=name)

def get_quant_method(
self, layer: mindspore.nn.Cell, prefix: str
) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer."""
_ = prefix
if isinstance(layer, nn.Dense):
return ConcreteQuantizeMethod()
return None


class TestQuantizeMethodBase:
"""Test class for QuantizeMethodBase."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_quantize_method_create_weights(self):
"""Test that concrete implementation can create weights."""
method = ConcreteQuantizeMethod()
layer = nn.Dense(10, 20)
method.create_weights(layer)
assert method.weights_created is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_quantize_method_apply(self):
"""Test that concrete implementation can apply weights."""
method = ConcreteQuantizeMethod()
layer = nn.Dense(10, 20)
method.create_weights(layer)
result = method.apply(layer)
assert isinstance(result, Tensor)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_quantize_method_base_embedding_raises_runtime_error(self):
"""Test that embedding method raises RuntimeError by default."""
method = ConcreteQuantizeMethod()
layer = nn.Dense(10, 20)
with pytest.raises(RuntimeError):
method.embedding(layer)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_quantize_method_base_process_weights_after_loading(self):
"""Test that process_weights_after_loading returns None by default."""
method = ConcreteQuantizeMethod()
layer = nn.Dense(10, 20)
assert method.process_weights_after_loading(layer) is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_apply_without_create_weights_raises_error(self):
"""Test that apply without create_weights raises error."""
method = ConcreteQuantizeMethod()
layer = nn.Dense(10, 20)
# The concrete implementation should check if weights are created
with pytest.raises(RuntimeError):
method.apply(layer)


class TestQuantizationConfig:
"""Test class for QuantizationConfig."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_initialization(self):
"""Test that concrete config initializes packed_modules_mapping."""
config = ConcreteQuantizationConfig()
assert isinstance(config.packed_modules_mapping, dict)
assert len(config.packed_modules_mapping) == 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_name(self):
"""Test get_name method."""
config = ConcreteQuantizationConfig(name="test_quantization")
assert config.get_name() == "test_quantization"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_supported_act_dtypes(self):
"""Test get_supported_act_dtypes method."""
config = ConcreteQuantizationConfig()
dtypes = config.get_supported_act_dtypes()
assert isinstance(dtypes, list)
assert "float16" in dtypes
assert "float32" in dtypes

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_min_capability(self):
"""Test get_min_capability class method."""
capability = ConcreteQuantizationConfig.get_min_capability()
assert isinstance(capability, int)
assert capability == 70

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_config_filenames(self):
"""Test get_config_filenames static method."""
filenames = ConcreteQuantizationConfig.get_config_filenames()
assert isinstance(filenames, list)
assert "quantization_config.json" in filenames

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_from_config(self):
"""Test from_config class method."""
config_dict = {"quantization_type": "custom_quant"}
config = ConcreteQuantizationConfig.from_config(config_dict)
assert isinstance(config, ConcreteQuantizationConfig)
assert config.get_name() == "custom_quant"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_quant_method(self):
"""Test get_quant_method method."""
config = ConcreteQuantizationConfig()
layer = nn.Dense(10, 20)
quant_method = config.get_quant_method(layer, prefix="dense_layer")
assert isinstance(quant_method, ConcreteQuantizeMethod)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_concrete_config_get_quant_method_returns_none(self):
"""Test get_quant_method returns None for unsupported layer."""
config = ConcreteQuantizationConfig()
layer = nn.ReLU() # Not a Dense layer
quant_method = config.get_quant_method(layer, prefix="relu_layer")
assert quant_method is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_finds_first_key(self):
"""Test get_from_keys finds value using first matching key."""
config = {"quantization_type": "test", "quant_type": "alternative"}
result = QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"])
assert result == "test"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_finds_second_key(self):
"""Test get_from_keys finds value using second key when first not present."""
config = {"quant_type": "alternative"}
result = QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"])
assert result == "alternative"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_raises_value_error(self):
"""Test get_from_keys raises ValueError when no key found."""
config = {"other_key": "value"}
with pytest.raises(ValueError, match="Cannot find any of"):
QuantizationConfig.get_from_keys(config, ["quantization_type", "quant_type"])

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_with_empty_keys(self):
"""Test get_from_keys with empty keys list raises ValueError."""
config = {"key": "value"}
with pytest.raises(ValueError, match="Cannot find any of"):
QuantizationConfig.get_from_keys(config, [])

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_or_returns_value(self):
"""Test get_from_keys_or returns value when key exists."""
config = {"quantization_type": "test"}
result = QuantizationConfig.get_from_keys_or(
config, ["quantization_type"], "default_value"
)
assert result == "test"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_or_returns_default(self):
"""Test get_from_keys_or returns default when key does not exist."""
config = {"other_key": "value"}
default_value = "default_quant"
result = QuantizationConfig.get_from_keys_or(
config, ["quantization_type"], default_value
)
assert result == default_value

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_or_with_none_default(self):
"""Test get_from_keys_or works with None as default."""
config = {"other_key": "value"}
result = QuantizationConfig.get_from_keys_or(config, ["quantization_type"], None)
assert result is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_or_with_empty_config(self):
"""Test get_from_keys_or with empty config returns default."""
config = {}
default_value = "default"
result = QuantizationConfig.get_from_keys_or(config, ["any_key"], default_value)
assert result == default_value

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_packed_modules_mapping_mutable(self):
"""Test that packed_modules_mapping can be modified."""
config = ConcreteQuantizationConfig()
config.packed_modules_mapping["module1"] = ["weight1", "weight2"]
assert config.packed_modules_mapping["module1"] == ["weight1", "weight2"]
assert len(config.packed_modules_mapping) == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_from_keys_with_different_value_types(self):
"""Test get_from_keys works with different value types."""
config = {
"int_value": 42,
"float_value": 3.14,
"list_value": [1, 2, 3],
"dict_value": {"nested": "value"},
}
assert QuantizationConfig.get_from_keys(config, ["int_value"]) == 42
assert QuantizationConfig.get_from_keys(config, ["float_value"]) == 3.14
assert QuantizationConfig.get_from_keys(config, ["list_value"]) == [1, 2, 3]
assert QuantizationConfig.get_from_keys(config, ["dict_value"]) == {"nested": "value"}
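
# For reference, a minimal sketch of the lookup contract exercised above (a hypothetical
# re-implementation, not the actual QuantizationConfig code): get_from_keys returns the
# value of the first key present in the config and raises ValueError otherwise, while
# get_from_keys_or falls back to a caller-supplied default instead of raising.
def _get_from_keys_sketch(config: dict, keys: list):
"""Return the value of the first matching key or raise ValueError (illustrative only)."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the given config")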

+ 0
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/__init__.py View File


+ 174
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/test_infer_mapping.py View File

@@ -0,0 +1,174 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for tensor-parallel mapping helpers."""
from functools import partial

import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.parallel_state import ProcessGroup
from mindformers.parallel_core.inference.tensor_parallel import mappings


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class FakeGather:
"""Mock AllGather operator recording inputs."""

def __init__(self):
self.calls = []

def __call__(self, tensor):
self.calls.append(tensor)
return tensor


class FakeReduceScatter:
"""Mock ReduceScatter returning half-size tensor."""

def __init__(self):
self.calls = []

def __call__(self, tensor):
self.calls.append(tensor)
# Return the first split chunk
return tensor[:tensor.shape[0] // 2]


class FakeAllReduce:
"""Mock AllReduce returning tensor doubled."""

def __init__(self):
self.calls = []

def __call__(self, tensor):
self.calls.append(tensor)
return tensor * 2


class FakeSplit:
"""Mock Split op returning chunks."""

def __init__(self, axis, output_num):
self.axis = axis
self.output_num = output_num

def __call__(self, tensor):
return tuple(np.split(tensor.asnumpy(), self.output_num, axis=self.axis))


class TestTensorParallelMappings:
"""Groups mapping tests into a single suite."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_gather_returns_input_when_group_size_one(self):
"""
Test that gather_from_model_parallel_region returns the original tensor unchanged
when the process group size is 1.
"""
group = ProcessGroup(group=None, rank=0, size=1)
# pylint: disable=W0212
group._is_group_created = True
tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32)

output = mappings.gather_from_model_parallel_region(tensor, group, dim=-1)

assert np.array_equal(output.asnumpy(), tensor.asnumpy())


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_gather_transposes_when_dim_nonzero(self, monkeypatch):
"""
Test that gather_from_model_parallel_region correctly handles gathering along a nonzero dimension (dim=1).
"""
fake_gather = FakeGather()
monkeypatch.setattr(mappings.ops, "AllGather", lambda group: fake_gather)
group = ProcessGroup(group="test", rank=0, size=2)
# pylint: disable=W0212
group._is_group_created = True
tensor = Tensor(np.arange(6).reshape(3, 2).astype(np.float32), dtype=mstype.float32)

output = mappings.gather_from_model_parallel_region(tensor, group, dim=1)

assert output.shape == tensor.shape


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_reduce_allreduce_invoked(self, monkeypatch):
"""
Test that reduce_from_model_parallel_region performs an AllReduce operation.
"""
fake_reduce = FakeAllReduce()
monkeypatch.setattr(mappings.ops, "AllReduce", lambda group: fake_reduce)
group = ProcessGroup(group="test", rank=0, size=2)
# pylint: disable=W0212
group._is_group_created = True
tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32)

output = mappings.reduce_from_model_parallel_region(tensor, group)

assert np.array_equal(output.asnumpy(), (tensor * 2).asnumpy())


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_reduce_scatter_returns_split(self, monkeypatch):
"""
Test that reduce_scatter_to_model_parallel_region performs a ReduceScatter operation.
"""
fake_reduce_scatter = FakeReduceScatter()
monkeypatch.setattr(mappings.ops, "ReduceScatter", lambda group: fake_reduce_scatter)
group = ProcessGroup(group="test", rank=0, size=2)
# pylint: disable=W0212
group._is_group_created = True
tensor = Tensor(np.ones((4, 2), dtype=np.float32), dtype=mstype.float32)

output = mappings.reduce_scatter_to_model_parallel_region(tensor, group)

assert output.shape[0] == tensor.shape[0] // 2


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_scatter_returns_rank_chunk(self, monkeypatch):
"""
Test that scatter_to_model_parallel_region splits the input tensor along the specified dimension.
"""
monkeypatch.setattr(mappings.ops, "Split", partial(FakeSplit))
group = ProcessGroup(group="test", rank=1, size=2)
# pylint: disable=W0212
group._is_group_created = True
tensor = Tensor(np.arange(8).reshape(2, 4).astype(np.float32), dtype=mstype.float32)

output = mappings.scatter_to_model_parallel_region(tensor, group, dim=1)

assert output.shape == (2, 2)

+ 0
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/__init__.py View File


+ 92
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/test_infer_fused_softmax.py View File

@@ -0,0 +1,92 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for FusedScaleMaskSoftmax."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.fused_softmax import FusedScaleMaskSoftmax


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


def simple_mask(tensor, mask):
"""Mask function for tests that multiplies by mask."""
return tensor + mask


class TestFusedScaleMaskSoftmax:
"""Tests covering the fused softmax helper."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_forward_pass_with_scale_and_mask(self):
"""
Test the forward pass of FusedScaleMaskSoftmax with both scaling and a mask applied.

Verifies that the module correctly applies the scale factor to the input tensor,
applies the provided attention mask, and computes the softmax, returning an output
with the expected shape. This tests the core functionality under typical conditions.
"""
fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.5, softmax_compute_type=mstype.float32)
x = Tensor(np.array([[2.0, 0.0]], dtype=np.float32), dtype=mstype.float32)
mask = Tensor(np.array([[0.0, -1.0]], dtype=np.float32), dtype=mstype.float32)

output = fused_softmax(x, mask)

assert output.shape == (1, 2)
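# Worked numbers for this case, assuming the module scales the scores before applying the
# mask function: 0.5 * [[2.0, 0.0]] = [[1.0, 0.0]], adding the mask gives [[1.0, -1.0]],
# and softmax([[1.0, -1.0]]) is roughly [[0.881, 0.119]]; the test only asserts the shape.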

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_precision_casts_to_fp32_when_needed(self):
"""
Test that FusedScaleMaskSoftmax automatically casts inputs to float32 when required.

Verifies that when the softmax computation type is set to float32 but the input
tensor is in float16, the module performs the necessary precision casting to fp32
for the softmax operation, ensuring numerical stability, and returns an output
with the correct shape.
"""
fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=None, softmax_compute_type=mstype.float32)
x = Tensor(np.array([[1.0, 1.0]], dtype=np.float16), dtype=mstype.float16)

output = fused_softmax(x, mask=None)

assert output.shape == (1, 2)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_invalid_scale_precision_combination_raises(self):
"""
Test that FusedScaleMaskSoftmax raises a ValueError for invalid precision combinations.

Verifies that the module enforces the rule that if a scale factor is applied,
the softmax computation must be performed in float32 to maintain precision.
Attempting to use a scale with float16 computation should raise a ValueError.
"""
with pytest.raises(ValueError):
FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.1, softmax_compute_type=mstype.float16)

+ 0
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/__init__.py View File


+ 70
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/test_infer_lower_triangular_mask.py View File

@@ -0,0 +1,70 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for LowerTriangularMaskWithDynamic."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.lower_triangular_mask import (
LowerTriangularMaskWithDynamic,
)


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class TestLowerTriangularMask:
"""Validates lower-triangular mask generation."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_lower_triangular_mask_in_prefill(self):
"""Prefill path should directly return the static fa mask."""
lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16)
lower_triangular_mask.is_prefill = True

mask = lower_triangular_mask(positions=Tensor(np.zeros((1, 4)), dtype=mstype.int32))
assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_lower_triangular_mask_in_decode(self):
"""Decode path should gather using provided positions."""
lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16)
lower_triangular_mask.is_prefill = False
positions = Tensor(np.array([0, 2], dtype=np.int32))

mask = lower_triangular_mask(positions=positions)
expected_shape = (positions.shape[0], lower_triangular_mask.pa_lower_triangle_mask.shape[1])
assert mask.shape == expected_shape

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_lower_triangular_mask_bfloat16_in_prefill_mask(self):
"""When using bf16 compute type, mask coefficient becomes +1."""
lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.bfloat16)
mask = lower_triangular_mask.prefill()
assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape

+ 0
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/__init__.py View File


+ 114
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/test_infer_moe_utils.py View File

@@ -0,0 +1,114 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for `moe_utils.py`."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.moe.moe_utils import (
group_limited_topk,
topk_routing_with_score_function,
)


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
mode=ms.GRAPH_MODE,
jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class TestTopkRoutingWithScoreFunction:
"""Unit tests for the top-k routing helper."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_softmax_routing_returns_normalized_weights(self):
"""The weights should sum to one per token when normalization is enabled."""
logits = Tensor(
np.array([[1.0, 2.0, 0.5, -0.5], [-1.0, 0.0, 2.5, 1.0]], dtype=np.float32),
dtype=mstype.float32,
)
expert_weight, routing_map = topk_routing_with_score_function(
logits=logits,
topk=2,
num_experts=4,
score_function="softmax",
norm_topk_prob=True,
)

assert expert_weight.shape == (2, 2)
assert routing_map.shape == (2, 2)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_sigmoid_routing_with_bias_without_normalization(self):
"""Bias should affect the chosen experts while weights stay unnormalized when disabled."""
logits = Tensor(
np.array([[0.0, -2.0, 2.0, 1.0]], dtype=np.float32),
dtype=mstype.float32,
)
expert_bias = Tensor(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), dtype=mstype.float32)

expert_weight, routing_map = topk_routing_with_score_function(
logits=logits,
topk=2,
num_experts=4,
score_function="sigmoid",
expert_bias=expert_bias,
norm_topk_prob=False,
)

assert expert_weight.shape == (1, 2)
assert routing_map.shape == (1, 2)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_group_limited_topk_only_selects_from_best_group(self):
"""group_limited_topk should not route experts outside the best group subset."""
scores = Tensor(np.array([[0.9, 0.8, 0.1, 0.2]], dtype=np.float32), dtype=mstype.float32)

probs, top_indices = group_limited_topk(
scores=scores,
topk=2,
num_experts=4,
num_groups=2,
group_topk=1,
)

assert probs.shape == (1, 2)
assert top_indices.shape == (1, 2)
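# With these scores the two expert groups are {0, 1} and {2, 3}; group 0 scores higher, so
# with group_topk=1 both selected experts are expected to come from {0, 1}
# (the assertions above only check the returned shapes).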

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_invalid_score_function_raises(self):
"""An unsupported score function name should raise ValueError."""
logits = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32)

with pytest.raises(ValueError):
topk_routing_with_score_function(
logits=logits,
topk=1,
num_experts=2,
score_function="unsupported",
)

+ 0
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_utils/__init__.py View File


+ 331
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_utils/test_utils.py View File

@@ -0,0 +1,331 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for transformer utils helpers."""
import sys
from types import SimpleNamespace

import numpy as np
import pytest

from mindspore import Tensor, Parameter
import mindspore.common.dtype as mstype

from mindformers.parallel_core.transformer_config import TransformerConfig
import mindformers.parallel_core.inference.utils as transformer_utils


class DummySubCell:
"""Subcell exposing a sharded state dict."""

def __init__(self):
self.param = Parameter(
Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32), name="sub.param"
)

def sharded_state_dict(self):
return {
"sub.param": {
"shape": self.param.shape,
"shard": (1, 2),
}
}

def name_cells(self):
return {"self": self}


class DummyNetwork:
"""Minimal network exposing parameters and cells."""

def __init__(self):
self.sub = DummySubCell()
self.head = Parameter(Tensor(np.ones((2,), dtype=np.float32), dtype=mstype.float32), name="head.bias")

def name_cells(self):
return {"self": self, "sub": self.sub}

def parameters_dict(self):
return {"sub.param": self.sub.param, "head.bias": self.head}


class TestAttnMaskHelpers:
"""Tests for attention mask helpers."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_attn_mask_fill_applies_value(self):
"""
Test 'attn_mask_fill' function correctly applies the fill value to masked positions.
"""
func = transformer_utils.get_attn_mask_func("attn_mask_fill")
scores = Tensor(np.ones((1, 2), dtype=np.float32), dtype=mstype.float32)
mask = Tensor(np.array([[False, True]]), dtype=mstype.bool_)
output = func(scores, mask, fill_value=-9.0)

output_np = output.asnumpy()
assert output_np[0, 0] == pytest.approx(1.0, rel=1e-6)
assert output_np[0, 1] == pytest.approx(-9.0, rel=1e-6)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_attn_mask_add_casts_mask(self):
"""
Test 'attn_mask_add' function adding a float mask to attention scores.
"""
func = transformer_utils.get_attn_mask_func("attn_mask_add")
scores = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32)
mask = Tensor(np.array([[0.0, -5.0]], dtype=np.float32), dtype=mstype.float32)
output = func(scores, mask)

output_np = output.asnumpy()
assert output_np.shape == (1, 2)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_get_attn_mask_func_with_invalid_name(self):
"""
Test get_attn_mask_func raising a KeyError for an unsupported mask function type.
"""
with pytest.raises(KeyError):
transformer_utils.get_attn_mask_func("unknown")


class TestStateDictGeneration:
"""Tests for sharded state dict utilities."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_generate_state_dict_includes_sharded_and_full_params(self, monkeypatch):
"""
Test that generate_state_dict correctly includes both sharded and non-sharded parameters.
"""
monkeypatch.setattr(transformer_utils, "get_group_size", lambda: 2)
state_dict = transformer_utils.generate_state_dict(DummyNetwork())

assert state_dict["total_rank"] == 2
assert "sub.param" in state_dict["model"]
assert "head.bias" in state_dict["model"]
assert state_dict["model"]["head.bias"]["shard"] == (1,)


class TestCommAndTopologyHelpers:
"""Tests targeting communication helper utilities."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_update_comm_config_single_tp_multi_dp(self, monkeypatch):
"""
Test update_comm_config for a configuration with single tensor parallel group and multiple data parallel groups.
"""
monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2)
monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 1)

config = TransformerConfig(num_layers=1, num_attention_heads=1)
updated = transformer_utils.update_comm_config(config)

assert updated.use_alltoall is True
assert updated.attn_allreduce is False
assert updated.ffn_allreduce is False

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_update_comm_config_moe_tp_enabled(self, monkeypatch):
"""
Test update_comm_config when MOE tensor parallelism is enabled.
"""
monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2)
monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 2)

config = TransformerConfig(num_layers=1, num_attention_heads=1)
updated = transformer_utils.update_comm_config(config)

assert updated.attn_allgather is True
assert updated.ffn_reduce_scatter is True
assert updated.ffn_allreduce is False

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_get_num_layers_and_offset_with_pp_offsets(self, monkeypatch):
"""
Test get_num_layers_and_offset with a valid pipeline parallel offset configuration.
"""
monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2)
monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_rank", lambda: 1)

config = TransformerConfig(num_layers=5, offset=[1, 0], num_attention_heads=1)

layers, offset = transformer_utils.get_num_layers_and_offset(config)

assert layers == 2
assert offset == 3
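# A worked check of the figures above, assuming the usual split of num_layers // pp_size plus
# the per-stage offset: stage 0 gets 5 // 2 + 1 = 3 layers, stage 1 gets 5 // 2 + 0 = 2 layers,
# and rank 1's offset equals the 3 layers owned by stage 0, matching the asserted (2, 3).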

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_get_num_layers_and_offset_raises_for_small_model(self, monkeypatch):
"""
Test that get_num_layers_and_offset raises RuntimeError when the model has too few layers.
"""
monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 8)

config = TransformerConfig(num_layers=4, num_attention_heads=1)

with pytest.raises(RuntimeError):
transformer_utils.get_num_layers_and_offset(config)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_get_num_layers_and_offset_invalid_offset_shape(self, monkeypatch):
"""
Test that get_num_layers_and_offset raises ValueError for an offset list with incorrect length.
"""
monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2)

config = TransformerConfig(num_layers=6, offset=[1, 0, 0], num_attention_heads=1)

with pytest.raises(ValueError):
transformer_utils.get_num_layers_and_offset(config)


class TestMathHelpers:
"""Tests for small math helpers."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_divide_checks_divisibility(self):
"""
Test that the divide function checks for exact divisibility.
"""
assert transformer_utils.divide(6, 3) == 2
with pytest.raises(ValueError):
transformer_utils.divide(5, 3)


class TestCustomOpsToggle:
"""Tests for custom ops toggling."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_use_ms_custom_ops_false_when_module_missing(self, monkeypatch):
"""
Test that use_ms_custom_ops returns False when the 'ms_custom_ops' module is not imported.

Ensures the fallback mechanism works correctly if the custom operators package is unavailable.
"""
monkeypatch.setitem(sys.modules, "ms_custom_ops", None)
assert transformer_utils.use_ms_custom_ops() is False

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_use_ms_custom_ops_true_when_module_present(self, monkeypatch):
"""
Test that use_ms_custom_ops returns True when the 'ms_custom_ops' module is present and not on 310p.

Verifies the primary condition for enabling custom operators based on module availability.
"""
monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace())
monkeypatch.setattr(transformer_utils, "is_310p", lambda: False)
assert transformer_utils.use_ms_custom_ops() is True

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_use_ms_custom_ops_false_when_310p(self, monkeypatch):
"""
Test that use_ms_custom_ops returns False even if the module is present when running on 310p.

Confirms the hardware-specific override that disables custom operators on the Ascend 310P platform.
"""
monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace())
monkeypatch.setattr(transformer_utils, "is_310p", lambda: True)
assert transformer_utils.use_ms_custom_ops() is False


class TestParameterUtility:
"""Covers helpers related to parameter creation."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_create_empty_parameter_returns_expected_shape(self):
"""
Test that create_empty_parameter creates a Parameter with the specified shape and data type.
"""
param = transformer_utils.create_empty_parameter((2, 3), dtype=mstype.float32, name="dummy")
assert param.shape == (2, 3)
assert param.dtype == mstype.float32


class TestWorldSizeFallbacks:
"""Ensure fallback logic returns non-zero defaults."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_world_size_helpers_default_to_one(self, monkeypatch):
"""
Test that world size helper functions default to 1 when underlying query functions return 0.

Ensures robustness by providing safe defaults for parallelism degrees, preventing division by zero.
"""
monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 0)
monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 0)
monkeypatch.setattr(transformer_utils, "get_moe_expert_parallel_world_size", lambda: 0)
monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 0)

assert transformer_utils.get_tp_world_size() == 1
assert transformer_utils.get_moe_tp_world_size() == 1
assert transformer_utils.get_moe_ep_world_size() == 1
assert transformer_utils.get_dp_world_size() == 1


class TestPaddingIndexGeneration:
"""Tests for generate_padding_index helper."""

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_generate_padding_index_single_dp(self, monkeypatch):
"""
Test generate_padding_index for a simple case with single data parallel group.

Verifies that the function generates padding and unpadding indices with the correct shape
based on the input sequence lengths.
"""
monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 1)
monkeypatch.setattr(transformer_utils, "get_data_parallel_group",
lambda: SimpleNamespace(rank=0, group=None))

q_seq_lens = Tensor(np.array([[2]], dtype=np.int32))
attn_pad, attn_unpad, ffn_pad, ffn_unpad = transformer_utils.generate_padding_index(q_seq_lens)

assert attn_pad.shape == (2,)
assert attn_unpad.shape == (2,)
assert ffn_pad.shape == (2,)
assert ffn_unpad.shape == (2,)
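# With a single sequence of length 2 and tp = dp = 1, no extra padding should be required,
# so each of the four index tensors is expected to cover just the 2 token positions,
# giving shape (2,).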

+ 234
- 0
tests/st/test_ut/test_parallel_core/test_inference/test_weights_utils.py View File

@@ -0,0 +1,234 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test weights_utils.py
"""
from unittest.mock import patch, MagicMock
import numpy as np
import pytest

from mindformers.parallel_core.inference.weights_utils import (deal_training_qkv_weight, deal_training_ffn_weight,
deal_training_moe_weight,
make_expert_params_mapping_with_expert_dim,
split_fusion_loaded_weight)


class TestDealTrainingQkvWeight:
"""Test class for testing deal_training_qkv_weight."""
@pytest.fixture
def mock_config(self):
"""Create a mock config object"""
config = MagicMock()
config.kv_channels = None
config.hidden_size = 768
config.num_attention_heads = 12
config.num_query_groups = 12
return config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_2d_weight_tp_size(self, mock_config):
"""Test with 2D weight and tensor parallel size 2"""
head_dim = mock_config.hidden_size // mock_config.num_attention_heads
q_channel = mock_config.num_attention_heads * head_dim
kv_channel = mock_config.num_query_groups * head_dim
total_channels = q_channel + 2 * kv_channel
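# With the mocked config this works out to head_dim = 768 // 12 = 64,
# q_channel = kv_channel = 12 * 64 = 768, and total_channels = 768 + 2 * 768 = 2304.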

weight = np.random.rand(total_channels, 1024).astype(np.float32)

with patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_world_size',
return_value=2), \
patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_rank',
return_value=0), \
patch('mindformers.parallel_core.inference.weights_utils.split_loaded_weight') as mock_split:
def mock_split_side_effect(w, axis, start, size):
if axis == 0:
return w[start:start + size, :]
return w

mock_split.side_effect = mock_split_side_effect
result = deal_training_qkv_weight(weight, mock_config)
assert result.ndim == 2
assert result.shape[0] == weight.shape[0]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tp_size_greater_than_num_query_groups(self, mock_config):
"""Test when tensor parallel size is greater than number of query groups"""
mock_config.num_query_groups = 2
head_dim = mock_config.hidden_size // mock_config.num_attention_heads
q_channel = mock_config.num_attention_heads * head_dim
kv_channel = mock_config.num_query_groups * head_dim
total_channels = q_channel + 2 * kv_channel
weight = np.random.rand(total_channels, 1024).astype(np.float32)
with patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_world_size',
return_value=4), \
patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_rank',
return_value=1), \
patch('mindformers.parallel_core.inference.weights_utils.split_loaded_weight') as mock_split:
def mock_split_side_effect(w, axis, start, size):
if axis == 0:
return w[start:start + size, :]
return w
mock_split.side_effect = mock_split_side_effect
result = deal_training_qkv_weight(weight, mock_config)
assert result is not None
assert result.ndim == 2


class TestDealTrainingFfnWeight:
"""Test class for testing deal_training_ffn_weight."""
@pytest.fixture
def mock_config(self):
"""Create a mock config object"""
config = MagicMock()
config.ffn_hidden_size = 4096
return config

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_1d_weight_tp_size_2(self, mock_config):
"""Test with 1D weight and tensor parallel size 2"""
w = mock_config.ffn_hidden_size * 2 # For W1 and W3
weight = np.random.rand(w).astype(np.float32)
with patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_world_size',
return_value=2), \
patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_rank',
return_value=0), \
patch('mindformers.parallel_core.inference.weights_utils.split_loaded_weight') as mock_split:
def mock_split_side_effect(w, axis, start, size):
if axis == 0:
return w[start:start + size]
return w

mock_split.side_effect = mock_split_side_effect
result = deal_training_ffn_weight(weight, mock_config)
assert result.ndim == 1
assert result.shape[0] == weight.shape[0]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_2d_weight_tp_size_2(self, mock_config):
"""Test with 2D weight and tensor parallel size 2"""
w = mock_config.ffn_hidden_size * 2
h = 1024
weight = np.random.rand(w, h).astype(np.float32)

with patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_world_size',
return_value=2), \
patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_rank',
return_value=0), \
patch('mindformers.parallel_core.inference.weights_utils.split_loaded_weight') as mock_split:
def mock_split_side_effect(w, axis, start, size):
if axis == 0:
return w[start:start + size, :]
return w
mock_split.side_effect = mock_split_side_effect
result = deal_training_ffn_weight(weight, mock_config)
assert result.ndim == 2
assert result.shape[0] == weight.shape[0]
assert result.shape[1] == weight.shape[1]


class TestDealTrainingMoeWeight:
"""Test class for testing deal_training_moe_weight."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_moe_weight_tp_size_4(self):
"""Test with tensor parallel size 4"""
w = 2048
h = 1024
weight = np.random.rand(h, w).astype(np.float32)
with patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_world_size',
return_value=4), \
patch('mindformers.parallel_core.inference.parallel_state.get_tensor_model_parallel_rank',
return_value=1), \
patch('mindformers.parallel_core.inference.weights_utils.split_loaded_weight') as mock_split:
def mock_split_side_effect(w, axis, start, size):
if axis == 1:
return w[:, start:start + size]
return w
mock_split.side_effect = mock_split_side_effect
result = deal_training_moe_weight(weight)
assert result.shape[0] == weight.shape[0]
assert result.shape[1] == weight.shape[1]


class TestMakeExpertParamsMappingWithExpertDim:
"""Test class for testing make_expert_params_mapping_with_expert_dim."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_make_expert_params_mapping_basic(self):
"""Test basic expert parameter mapping generation"""
ckpt_gate_proj_name = "gate_proj"
ckpt_down_proj_name = "down_proj"
ckpt_up_proj_name = "up_proj"
result = make_expert_params_mapping_with_expert_dim(
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
)
assert len(result) == 3
for param_tuple in result:
assert len(param_tuple) == 3
assert isinstance(param_tuple[0], str)
assert isinstance(param_tuple[1], str)
assert isinstance(param_tuple[2], str)

expected_shard_ids = ['w1', 'w2', 'w3']
actual_shard_ids = [item[2] for item in result]
assert actual_shard_ids == expected_shard_ids

for param_tuple in result:
weight_name = param_tuple[1]
weight_prefix = param_tuple[0]
if 'gate_proj' in weight_name or 'up_proj' in weight_name:
assert weight_prefix == "experts.weight1"
else:
assert weight_prefix == "experts.weight2"


class TestSplitFusionLoadedWeight:
"""Test class for testing split_fusion_loaded_weight."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_split_fusion_loaded_weight_2d(self):
"""Test with 2D array input"""
loaded_weight = np.array([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18]
], dtype=np.float32)

start_idxs = [0, 2, 4]
shard_sizes = [2, 2, 2]
result = split_fusion_loaded_weight(loaded_weight, start_idxs, shard_sizes)
expected = np.array([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18]
], dtype=np.float32)
np.testing.assert_array_equal(result, expected)
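# The expected output equals the input here: assuming split_fusion_loaded_weight slices rows
# [0:2], [2:4] and [4:6] and concatenates the shards along axis 0, the re-assembled array
# reproduces the original weight.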

+ 283
- 0
tests/st/test_ut/test_pipeline/test_base_pipeline.py View File

@@ -0,0 +1,283 @@
"""Unit tests for the base pipeline helpers."""

import pytest

from mindformers.pipeline.base_pipeline import Pipeline
from mindformers.pipeline import base_pipeline as base_pipeline_module


class _SaveableComponent:
"""Simple helper capturing save_pretrained calls."""

def __init__(self):
self.calls = []

def save_pretrained(self, *args, **kwargs):
self.calls.append((args, kwargs))


class DummyPipeline(Pipeline):
"""Lightweight Pipeline implementation for white-box tests."""

def __init__(self):
# Skip the heavy super().__init__ and wire up the minimal attributes needed
# by the helper methods under test.
self.model = _SaveableComponent()
self.tokenizer = None
self.feature_extractor = None
self.image_processor = None
self.network = None
self._preprocess_params = {"base_pre": 0}
self._forward_params = {"base_fw": 0}
self._postprocess_params = {"base_post": 0}
self.call_count = 0
self._batch_size = None
self.records = []

def _sanitize_parameters(self, **pipeline_parameters):
# Mirror the real contract: return tuple of (preprocess, forward, postprocess) overrides
return (
pipeline_parameters.get("preprocess_params", {}),
pipeline_parameters.get("forward_params", {}),
pipeline_parameters.get("postprocess_params", {}),
)

def preprocess(self, inputs, **preprocess_params):
payload = {"kind": "preprocess", "inputs": inputs, "params": preprocess_params}
self.records.append(payload)
return payload

def _forward(self, model_inputs, **forward_params):
payload = {"kind": "forward", "inputs": model_inputs, "params": forward_params}
self.records.append(payload)
return payload

def postprocess(self, model_outputs, **postprocess_params):
payload = {"kind": "postprocess", "inputs": model_outputs, "params": postprocess_params}
self.records.append(payload)
# run_multi expects run_single to return an iterable so wrap the payload in a list
return [payload]


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call_with_scalar_merges_params_and_updates_call_count():
"""
Feature: __call__ merging logic
Description: Verify scalar inputs merge default/sanitized params and bump call_count.
Expectation: Preprocess receives merged dict, forward/postprocess mirror sanitized data.
"""
pipe = DummyPipeline()
result = pipe(
"single",
preprocess_params={"extra": 1},
forward_params={"fw_extra": 2},
postprocess_params={"post_extra": 3},
)

assert pipe.call_count == 1
assert pipe.records[0]["params"] == {"base_pre": 0, "extra": 1}
assert pipe.records[1]["params"] == {"base_fw": 0, "fw_extra": 2}
assert pipe.records[2]["params"] == {"base_post": 0, "post_extra": 3}
assert result[0]["kind"] == "postprocess"


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call_with_list_invokes_run_multi_with_default_batch_size():
"""
Feature: __call__ list dispatch
Description: Ensure list inputs route through run_multi using configured batch_size.
Expectation: run_multi observes provided inputs and default batch size when none passed.
"""

class SpyPipeline(DummyPipeline):
"""Capture arguments passed to run_multi for verification."""

def __init__(self):
super().__init__()
self.multi_args = None

def run_multi(self, inputs, batch_size, preprocess_params, forward_params_unused, postprocess_params_unused):
del forward_params_unused
del postprocess_params_unused
self.multi_args = {
"inputs": inputs,
"batch_size": batch_size,
"preprocess": preprocess_params,
}
return ["multi"]

pipe = SpyPipeline()
pipe.batch_size = 3
outcome = pipe([1, 2, 3], preprocess_params={"l": 1})

assert outcome == ["multi"]
assert pipe.multi_args == {
"inputs": [1, 2, 3],
"batch_size": 3,
"preprocess": {"base_pre": 0, "l": 1},
}
assert pipe.call_count == 1


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_run_single_executes_pipeline_stages():
"""
Feature: Pipeline stage orchestration
Description: Ensure `run_single` invokes preprocess/forward/postprocess sequentially.
Expectation: Stage records follow the expected order and final output originates from postprocess.
"""
pipe = DummyPipeline()
result = pipe.run_single(
inputs="sample",
preprocess_params={"prep": 1},
forward_params={"fw": 2},
postprocess_params={"post": 3},
)

assert result[0]["kind"] == "postprocess"
kinds = [entry["kind"] for entry in pipe.records]
assert kinds == ["preprocess", "forward", "postprocess"]


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_run_multi_batches_and_validates_length():
"""
Feature: Batch execution logic
Description: Validate `run_multi` chunks inputs correctly and enforces batch divisibility.
Expectation: Outputs accumulate per batch and ValueError raised when length mismatch occurs.
"""
pipe = DummyPipeline()
outputs = pipe.run_multi(
inputs=[1, 2, 3, 4],
batch_size=2,
preprocess_params={},
forward_params={},
postprocess_params={},
)
assert len(outputs) == 2

with pytest.raises(ValueError):
pipe.run_multi([1, 2, 3], batch_size=2, preprocess_params={}, forward_params={}, postprocess_params={})


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call_with_dataset_handles_batching(monkeypatch):
"""
Feature: Dataset path handling
Description: Simulate GeneratorDataset inputs to ensure batching and iteration logic execute.
Expectation: Dataset batches with provided size, outputs cover all batches, tqdm is bypassed for speed.
"""

class FakeBatchDataset:
def __init__(self, data, batch_size):
self.data = data
self.batch_size = batch_size

def create_dict_iterator(self, do_copy=False):
del do_copy
for idx in range(0, len(self.data), self.batch_size):
yield {"chunk": tuple(self.data[idx:idx + self.batch_size])}

class FakeDataset:
def __init__(self, data):
self.data = data
self.batched_with = None

def batch(self, batch_size):
self.batched_with = batch_size
return FakeBatchDataset(self.data, batch_size)

monkeypatch.setattr(base_pipeline_module, "GeneratorDataset", FakeDataset)
monkeypatch.setattr(base_pipeline_module, "BatchDataset", FakeBatchDataset)
monkeypatch.setattr(base_pipeline_module, "RepeatDataset", FakeBatchDataset)
monkeypatch.setattr(base_pipeline_module, "tqdm", lambda data, **_: data)

dataset = FakeDataset([1, 2, 3, 4])
pipe = DummyPipeline()
outputs = pipe(dataset, batch_size=2)

assert len(outputs) == 2
assert dataset.batched_with == 2


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_size_property_validation():
"""
Feature: Batch size setter guardrails
Description: `batch_size` must accept positive integers only.
Expectation: Valid integers are stored; invalid types or negatives raise ValueError.
"""
pipe = DummyPipeline()
pipe.batch_size = 4
assert pipe.batch_size == 4

with pytest.raises(ValueError):
pipe.batch_size = -1
with pytest.raises(ValueError):
pipe.batch_size = 1.5


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_pretrained_invokes_all_components(tmp_path):
"""
Feature: save_pretrained delegation
Description: Ensure model/tokenizer/image helpers receive save requests and file paths are created.
Expectation: Each component's `save_pretrained` is called once; early exit occurs when path is a file.
"""
pipe = DummyPipeline()
pipe.tokenizer = _SaveableComponent()
pipe.feature_extractor = _SaveableComponent()
pipe.image_processor = _SaveableComponent()

target_dir = tmp_path / "pipeline"
pipe.save_pretrained(str(target_dir), save_name="custom")

assert pipe.model.calls[0][0][0] == str(target_dir)
assert pipe.tokenizer.calls[0][0][0] == str(target_dir)
assert target_dir.exists()

# When a file path is provided, the helper should short-circuit without invoking save_pretrained.
file_path = tmp_path / "model.bin"
file_path.write_text("stub")
pipe.model.calls.clear()
pipe.save_pretrained(str(file_path))
assert not pipe.model.calls


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_transform_and_predict_delegate_to_call():
"""
Feature: Compatibility helpers
Description: Ensure `transform` and `predict` are thin proxies over `__call__`.
Expectation: Calls are forwarded verbatim and return value propagated.
"""

class ProxyPipeline(DummyPipeline):
def __init__(self):
super().__init__()
self.calls = []

def __call__(self, *args, **kwargs):
self.calls.append((args, kwargs))
return "delegated"

pipe = ProxyPipeline()
assert pipe.transform("x") == "delegated"
assert pipe.predict(data="y") == "delegated"
assert pipe.calls == [(("x",), {}), ((), {"data": "y"})]

+ 405
- 0
tests/st/test_ut/test_pipeline/test_pipeline.py View File

@@ -0,0 +1,405 @@
"""Unit tests covering the public pipeline entry points."""

import importlib
import sys
import types

import pytest

# Import the pipeline module object (the file `mindformers/pipeline/pipeline.py`).
# Note that `from mindformers.pipeline import pipeline` would import the `pipeline`
# symbol (a function) exported by the package, not this module file.
pipeline_module = importlib.import_module("mindformers.pipeline.pipeline")


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_experimental_mode_various_cases(tmp_path):
"""
Feature: experimental mode detection
Description: Verify `is_experimental_mode` correctly distinguishes between
string repo identifiers, local directories, non-string models,
and raises when experimental-only kwargs are present.
Expectation: Returns True for repo-like strings and directories, False for
non-string model instances, and raises ValueError when illegal
experimental kwargs are passed with a plain model string.
"""
# non-string model instance -> not experimental
assert pipeline_module.is_experimental_mode(model=123) is False

# a string with a slash (repo name) and not starting with 'mindspore' -> experimental
assert pipeline_module.is_experimental_mode("owner/repo") is True

# a local directory path -> experimental
d = tmp_path / "model_dir"
d.mkdir()
assert pipeline_module.is_experimental_mode(str(d)) is True

# model string without slash + experimental-only kw should raise
with pytest.raises(ValueError):
pipeline_module.is_experimental_mode("model_name", config=1)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_clean_custom_task_transform_ms_attribute():
"""
Feature: custom task cleaning
Description: Verify `clean_custom_task` transforms a string entry in the
`ms` field into the corresponding attribute/object from the
`mindformers` package.
Expectation: Returns a cleaned dict where `ms` is a tuple of resolved
attributes from the injected `mindformers` module.
"""
# prepare a fake mindformers module with an attribute we can resolve
dummy_mod = types.SimpleNamespace()

class DummyClass:
pass

setattr(dummy_mod, "MyDummy", DummyClass)

# inject into sys.modules so that the import inside the function picks it up
sys.modules["mindformers"] = dummy_mod

try:
task_info = {"impl": "irrelevant", "ms": "MyDummy"}
cleaned, _ = pipeline_module.clean_custom_task(task_info)
assert isinstance(cleaned, dict)
assert isinstance(cleaned["ms"], tuple)
# The resolved item should be the DummyClass we provided
assert cleaned["ms"][0] is DummyClass
finally:
# remove our injected module to avoid side effects for other tests
del sys.modules["mindformers"]


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_clean_custom_task_missing_impl_raises():
"""
Feature: custom task validation
Description: `clean_custom_task` must fail when `impl` key is missing.
Expectation: Raises RuntimeError.
"""
with pytest.raises(RuntimeError):
pipeline_module.clean_custom_task({})


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_model_success_and_failure_paths():
"""
Feature: dynamic model loading
Description: Exercise `load_model` code paths where the first candidate
model class succeeds, and where all candidates fail.
Expectation: Returns instantiated model for successful class; raises
ValueError if all classes fail to load.
"""

# Successful loader class
class SuccessModel:
def __init__(self):
self._ok = True

@classmethod
def from_pretrained(cls, *_args, **_kwargs):
return cls()

cfg = types.SimpleNamespace(architectures=[])
model = pipeline_module.load_model("some-id", cfg, model_classes=(SuccessModel,), task="t")
assert getattr(model, "_ok", False) is True

# All failing loader classes -> ValueError expected
class Fail1:
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
raise OSError("fail1")

class Fail2:
@classmethod
def from_pretrained(cls, *_args, **_kwargs):
raise ValueError("fail2")

with pytest.raises(ValueError):
pipeline_module.load_model("some-id", cfg, model_classes=(Fail1, Fail2), task="t")


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ms_pipeline_invalid_task_raises_keyerror():
"""
Feature: pipeline task validation
Description: `get_ms_pipeline` should raise if task is not registered in
`SUPPORT_PIPELINES`.
Expectation: Raises KeyError for invalid task names.
"""
with pytest.raises(KeyError):
pipeline_module.get_ms_pipeline("nonexistent_task", None, None, None, None)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ms_experimental_pipeline_requires_task_and_model():
"""
Feature: experimental pipeline preconditions
Description: `get_ms_experimental_pipeline` must raise a RuntimeError when
task or model is not provided.
Expectation: Raises RuntimeError when task and model are both None.
"""
with pytest.raises(RuntimeError):
pipeline_module.get_ms_experimental_pipeline(task=None, model=None)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pipeline_dispatches_between_standard_and_experimental(monkeypatch):
"""
Feature: pipeline entry point
Description: Ensure `pipeline` routes to experimental or standard builders based on model input.
Expectation: Correct helper is invoked and invalid backend raises ValueError.
"""

called = {"standard": 0, "experimental": 0}

def fake_is_experimental_mode(model, **_):
return isinstance(model, str) and model.startswith("exp/")

def fake_get_ms_pipeline(*args, **kwargs):
called["standard"] += 1
return ("standard", args, kwargs)

def fake_get_ms_experimental_pipeline(*args, **kwargs):
called["experimental"] += 1
return ("experimental", args, kwargs)

monkeypatch.setattr(pipeline_module, "is_experimental_mode", fake_is_experimental_mode)
monkeypatch.setattr(pipeline_module, "get_ms_pipeline", fake_get_ms_pipeline)
monkeypatch.setattr(pipeline_module, "get_ms_experimental_pipeline", fake_get_ms_experimental_pipeline)

result_exp = pipeline_module.pipeline(task="text-generation", model="exp/repo")
assert result_exp[0] == "experimental"
result_std = pipeline_module.pipeline(task="text-generation", model=object())
assert result_std[0] == "standard"
assert called == {"standard": 1, "experimental": 1}

with pytest.raises(ValueError):
pipeline_module.pipeline(task="text-generation", model="m", backend="unknown")


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ms_pipeline_constructs_components(monkeypatch):
"""
Feature: MindSpore pipeline builder
Description: Ensure `get_ms_pipeline` builds network, processors, tokenizer, and pipeline call.
Expectation: Helper functions receive expected arguments and output from build_pipeline is returned.
"""

monkeypatch.setattr(pipeline_module, "SUPPORT_PIPELINES", {"demo-task": {"foo": "cfg.yaml"}})

def fake_config(path_value):
del path_value
return types.SimpleNamespace(
model="MODEL_CFG",
processor=types.SimpleNamespace(
tokenizer="TOK_CFG", image_processor="IMG_CFG", audio_processor="AUD_CFG"
),
)

monkeypatch.setattr(pipeline_module, "MindFormerConfig", fake_config)

calls = {}

def fake_build_network(model_cfg, default_args):
calls["build_network"] = (model_cfg, default_args)
return "MODEL_OBJ"

def fake_build_processor(argument):
calls.setdefault("build_processor", []).append(argument)
return f"PROC({argument})"

def fake_build_tokenizer(argument, tokenizer_name):
calls["build_tokenizer"] = (argument, tokenizer_name)
return "TOKENIZER_OBJ"

def fake_build_pipeline(**kwargs):
calls["build_pipeline"] = kwargs
return "PIPELINE_OBJ"

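# Route the builder helpers to the fakes above so their arguments can be inspected via `calls`.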
monkeypatch.setattr(pipeline_module, "build_network", fake_build_network)
monkeypatch.setattr(pipeline_module, "build_processor", fake_build_processor)
monkeypatch.setattr(pipeline_module, "build_tokenizer", fake_build_tokenizer)
monkeypatch.setattr(pipeline_module, "build_pipeline", fake_build_pipeline)

output = pipeline_module.get_ms_pipeline(
"demo-task",
"foo",
tokenizer=None,
image_processor=None,
audio_processor=None,
batch_size=4,
use_past=True,
)

assert output == "PIPELINE_OBJ"
assert calls["build_network"] == ("MODEL_CFG", {"batch_size": 4, "use_past": True})
assert calls["build_tokenizer"] == ("TOK_CFG", "foo")
# image/audio processors are both constructed
assert calls["build_processor"] == ["IMG_CFG", "AUD_CFG"]
assert calls["build_pipeline"]["model"] == "MODEL_OBJ"
assert calls["build_pipeline"]["tokenizer"] == "TOKENIZER_OBJ"
assert calls["build_pipeline"]["image_processor"].startswith("PROC")


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ms_pipeline_invalid_inputs(monkeypatch):
"""
Feature: MindSpore pipeline validation
Description: Validate errors for unsupported tasks and model names.
Expectation: Raises KeyError in both cases.
"""

monkeypatch.setattr(pipeline_module, "SUPPORT_PIPELINES", {"valid": {"foo": "cfg"}})
with pytest.raises(KeyError):
pipeline_module.get_ms_pipeline("missing-task", None, None, None, None)
with pytest.raises(KeyError):
pipeline_module.get_ms_pipeline("valid", "bar", None, None, None)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ms_experimental_pipeline_builds_components(monkeypatch):
"""
Feature: Experimental pipeline builder
Description: Exercise a happy path covering tokenizer/image processor loading and context setup.
Expectation: Returns instantiated pipeline class with propagated kwargs.
"""

class DummyPipeline:
def __init__(self, **kwargs):
self.kwargs = kwargs

dummy_config = types.SimpleNamespace(custom_pipelines={}, _commit_hash="cfg-hash")

class DummyModelConfig:
_commit_hash = "model-hash"
tokenizer_class = None

class DummyModel:
def __init__(self):
self.config = DummyModelConfig()
self._eval_called = False

def eval(self):
self._eval_called = True
return self

def fake_check_task(task):
return task, {"impl": DummyPipeline, "ms": (object,)}, None

created = {}

def fake_load_model(model, **kwargs):
created["load_model"] = (model, kwargs)
return DummyModel()

def fake_auto_tokenizer(identifier, **_):
created["tokenizer"] = identifier
return "TOKENIZER"

def fake_auto_image_processor(identifier, **_):
created["image_processor"] = identifier
return "IMAGE_PROCESSOR"

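# Stub out task lookup, model/tokenizer loading and context setup so no real weights or devices are needed.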
monkeypatch.setattr(pipeline_module, "check_task", fake_check_task)
monkeypatch.setattr(pipeline_module, "load_model", fake_load_model)
monkeypatch.setattr(pipeline_module.AutoTokenizer, "from_pretrained", fake_auto_tokenizer)
monkeypatch.setattr(pipeline_module.AutoImageProcessor, "from_pretrained", fake_auto_image_processor)
monkeypatch.setattr(pipeline_module, "cached_file", lambda *_, **__: "cfg")
monkeypatch.setattr(pipeline_module, "extract_commit_hash", lambda *_: "commit")
monkeypatch.setattr(pipeline_module, "set_context", lambda **kwargs: created.setdefault("context", kwargs))

monkeypatch.setattr(pipeline_module, "TOKENIZER_MAPPING", {DummyModelConfig: None})
monkeypatch.setattr(pipeline_module, "IMAGE_PROCESSOR_MAPPING", {DummyModelConfig: None})
monkeypatch.setattr(pipeline_module, "NO_IMAGE_PROCESSOR_TASKS", set())

pipeline_obj = pipeline_module.get_ms_experimental_pipeline(
task="text-generation",
model="repo/model",
config=dummy_config,
tokenizer=None,
image_processor=None,
device_id=1,
device_target="Ascend",
model_kwargs={"foo": "bar"},
pipeline_class=DummyPipeline,
device_map="cpu",
torch_dtype="fp16",
)

assert isinstance(pipeline_obj, DummyPipeline)
assert pipeline_obj.kwargs["tokenizer"] == "TOKENIZER"
assert pipeline_obj.kwargs["image_processor"] == "IMAGE_PROCESSOR"
assert created["context"] == {"mode": 0, "device_id": 1, "device_target": "Ascend"}


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_model_prefers_successful_class():
"""
Feature: Model loading helper
Description: Ensure `load_model` iterates over candidate classes until one succeeds.
Expectation: Returns first successful model and raises when all fail.
"""

class FailingModel:
called = 0

@classmethod
def from_pretrained(cls, *_, **__):
cls.called += 1
raise OSError("fail")

class SuccessfulModel:
called = 0

@classmethod
def from_pretrained(cls, *_, **__):
cls.called += 1
return cls()

config = types.SimpleNamespace(architectures=None)
model = pipeline_module.load_model(
"id",
config=config,
model_classes=(FailingModel, SuccessfulModel),
task="text-generation",
)
assert isinstance(model, SuccessfulModel)
assert FailingModel.called == 1 and SuccessfulModel.called == 1

class AlwaysFail:
@classmethod
def from_pretrained(cls, *_, **__):
raise ValueError("bad")

with pytest.raises(ValueError):
pipeline_module.load_model(
"id",
config=config,
model_classes=(AlwaysFail,),
task="text-generation",
)

+ 176
- 0
tests/st/test_ut/test_pipeline/test_pipeline_registry.py View File

@@ -0,0 +1,176 @@
"""Unit tests for the pipeline registry helpers."""

import pytest

from mindformers.pipeline import pipeline_registry as pipeline_registry_module
from mindformers.pipeline.pipeline_registry import PipelineRegistry


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_supported_tasks_and_to_dict():
"""Get supported tasks and verify dict output.

Feature:
Test behavior of `PipelineRegistry.get_supported_tasks` and `to_dict`.

Description:
Initialize a registry with one task and one alias, verify the returned
task list contains both entries (sorted), and verify `to_dict` returns
the original mapping object.

Expectation:
`tasks` contains two items and `to_dict` returns the original dict.
"""
supported = {"task1": {"impl": object}}
aliases = {"alias1": "task1"}
reg = PipelineRegistry(supported, aliases)
tasks = reg.get_supported_tasks()
# two entries sorted
assert tasks in (["alias1", "task1"], ["task1", "alias1"])
# to_dict returns the underlying mapping
assert reg.to_dict() is supported


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_task_direct_and_alias():
"""Validate direct task name and alias resolution.

Feature:
Ensure `check_task` returns the normalized task name and the
corresponding implementation when given a direct task name or an alias.

Description:
Provide a mapping that contains `t1` and `translation` and an alias
that maps to `t1`. Call `check_task` with both types and assert the
returned values.

Expectation:
Both direct name and alias resolve to the same target implementation
and options are None.
"""
supported = {"t1": {"impl": object}, "translation": {"impl": object}}
aliases = {"alias": "t1"}
reg = PipelineRegistry(supported, aliases)

name, targeted, opts = reg.check_task("t1")
assert name == "t1"
assert targeted is supported["t1"]
assert opts is None

# alias should map to target
name2, targeted2, opts2 = reg.check_task("alias")
assert name2 == "t1"
assert targeted2 is supported["t1"]
assert opts2 is None


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_task_translation_valid_and_invalid():
"""Test parsing and error handling for translation tasks.

Feature:
Verify that parameterized translation tasks of the form
`translation_XX_to_YY` are parsed correctly and invalid formats raise
a KeyError.

Description:
Use a supported mapping that only contains `translation`. Test a valid
translation task and an invalid format that should raise.

Expectation:
Valid string returns language tuple; invalid format raises KeyError.
"""
supported = {"translation": {"impl": object}}
reg = PipelineRegistry(supported, {})

# valid: translation_en_to_de
name, targeted, options = reg.check_task("translation_en_to_de")
assert name == "translation"
assert targeted is supported["translation"]
assert options == ("en", "de")

# invalid format should raise KeyError
with pytest.raises(KeyError):
reg.check_task("translation_en_de")


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_task_unknown_raises():
"""Ensure unknown task names raise an informative KeyError.

Feature:
The registry should raise a KeyError for unknown tasks and include
'Unknown task' in the error message.

Description:
Initialize an empty registry and call `check_task` with a non-existent
task name. Verify that a KeyError is raised and the error message
contains the expected string.

Expectation:
KeyError is raised and message contains 'Unknown task'.
"""
reg = PipelineRegistry({}, {})
with pytest.raises(KeyError) as ei:
reg.check_task("noexist")
assert "Unknown task" in str(ei.value)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_register_pipeline_overwrite_and_defaults_and_task_type(monkeypatch):
"""Test pipeline registration, default wrapping, and task type field.

Feature:
Ensure `register_pipeline` correctly registers implementations,
wraps a default that only contains 'ms' into {'model': ...}, and logs
a warning when overwriting an existing registration.

Description:
Register a custom pipeline class into an empty supported mapping and
verify stored values and class-level `registered_impl`. Then register
again and assert a warning was emitted.

Expectation:
The supported mapping contains the correct entries, the pipeline class
has `registered_impl`, and a warning is logged on overwrite.
"""
supported = {}
reg = PipelineRegistry(supported, {})

class MyPipeline:
pass

# register with default that has only 'ms' key -> should be wrapped under 'model'
reg.register_pipeline("t", MyPipeline, ms_model=(int,), default={"ms": "value"}, task_type="typeA")
assert "t" in supported
impl = supported["t"]
assert impl["impl"] is MyPipeline
assert impl["ms"] == (int,)
# default should be wrapped into {'model': {...}}
assert impl["default"] == {"model": {"ms": "value"}}
assert impl["type"] == "typeA"
# The pipeline_class should have registered_impl attribute set
assert hasattr(MyPipeline, "registered_impl")
assert "t" in MyPipeline.registered_impl

# registering again should produce a warning about overwriting
warnings = []

def fake_warning(msg, *_args, **_kwargs):
warnings.append(msg)

monkeypatch.setattr(pipeline_registry_module.logger, "warning", fake_warning)

reg.register_pipeline("t", MyPipeline, ms_model=(int,))
assert warnings, "Expected warning to be emitted"
assert any("is already registered" in msg for msg in warnings)

+ 234
- 0
tests/st/test_ut/test_pipeline/test_registry_constant.py View File

@@ -0,0 +1,234 @@
"""Tests for registry_constant helpers and data wiring."""

import importlib
from importlib import util as importlib_util
from pathlib import Path

import pytest


# Resolve the path to the real registry_constant.py in a robust way.
# Prefer importlib's spec if the package is importable; otherwise fall back
# to searching parent directories for the repository layout.
try:
_spec = importlib_util.find_spec("mindformers.pipeline.registry_constant")
if _spec and getattr(_spec, "origin", None):
REGISTRY_PATH = Path(_spec.origin)
else:
raise RuntimeError("spec not found")
except Exception:
# fallback: look for the file under parents; this works when running
# tests from workspace without installing the package
base = Path(__file__).resolve()

# 1) Find repository root by common markers (.git, setup.py, pyproject.toml)
repo_root = None
for p in base.parents:
if (p / ".git").exists() or (p / "setup.py").exists() or (p / "pyproject.toml").exists():
repo_root = p
break

if repo_root is not None:
candidate = repo_root / "mindformers" / "pipeline" / "registry_constant.py"
if candidate.exists():
REGISTRY_PATH = candidate
else:
# if repo root found but file missing, fall back to a best-effort search below
repo_root = None

# 2) If repo root not determined, search parents for a candidate but avoid
# picking up nested 'tests/mindformers' layouts which can occur when tests
# are run from a working directory that mirrors the package tree.
if repo_root is None:
found = None
for p in base.parents:
candidate = p / "mindformers" / "pipeline" / "registry_constant.py"
if candidate.exists():
# prefer candidates not under a 'tests' directory
if "tests" not in map(str, candidate.parts):
found = candidate
break
# otherwise keep the first found as a last resort
if found is None:
found = candidate
if found is None:
# original fallback (best-effort)
REGISTRY_PATH = Path(__file__).resolve().parents[3] / "mindformers" / "pipeline" / "registry_constant.py"
else:
REGISTRY_PATH = found


def _load_registry_with_supported_tasks(supported_tasks):
"""Dynamically execute a modified copy of registry_constant with a
custom SUPPORTED_TASKS dict. Returns the execution namespace dict.

This approach recreates the module-level initialization logic (the
for-loop that populates NO_* sets) without importing the real module
(which would already have executed with its built-in SUPPORTED_TASKS).
"""
src = REGISTRY_PATH.read_text(encoding="utf-8")

# find the place where the NO_* sets start; reuse the remainder of the
# file (including the for-loop and PIPELINE_REGISTRY init) so we only
# replace SUPPORTED_TASKS.
marker = "NO_FEATURE_EXTRACTOR_TASKS = set()"
idx = src.find(marker)
assert idx != -1, "registry_constant.py structure changed; cannot locate marker"
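# Pad with blank lines so the executed snippet keeps the original file's line numbers (clearer tracebacks).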
line_offset = src[:idx].count("\n")
remainder = ("\n" * line_offset) + src[idx:]

# build a new source where SUPPORTED_TASKS is our custom dict
new_src = "SUPPORTED_TASKS = " + repr(supported_tasks) + "\n" + remainder

# prepare a namespace with minimal dependencies mocked
class DummyPipelineRegistry:
def __init__(self, supported_tasks=None, task_aliases=None):
self.supported_tasks = supported_tasks
self.task_aliases = task_aliases

namespace = {
"PipelineRegistry": DummyPipelineRegistry,
"TextGenerationPipeline": object,
"AutoModelForCausalLM": object,
"TASK_ALIASES": {},
}

exec(compile(new_src, str(REGISTRY_PATH), "exec"), namespace) # pylint: disable=W0122
return namespace


def _recompute_sets_in_module(custom_tasks):
"""Execute the real registry_constant for-loop against custom tasks."""

module = importlib.import_module("mindformers.pipeline.registry_constant")
module.SUPPORTED_TASKS = custom_tasks
module.NO_FEATURE_EXTRACTOR_TASKS = set()
module.NO_IMAGE_PROCESSOR_TASKS = set()
module.NO_TOKENIZER_TASKS = set()

src = REGISTRY_PATH.read_text(encoding="utf-8")
marker_start = "for task, values in SUPPORTED_TASKS.items():"
marker_end = "PIPELINE_REGISTRY ="
start_idx = src.find(marker_start)
end_idx = src.find(marker_end)
assert start_idx != -1 and end_idx != -1, "registry_constant structure changed"
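# Keep the extracted loop at its original line offset so errors point at the real source lines.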
line_offset = src[:start_idx].count("\n")
loop_src = ("\n" * line_offset) + src[start_idx:end_idx]

exec(compile(loop_src, str(REGISTRY_PATH), "exec"), module.__dict__) # pylint: disable=W0122
return module


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_default_supported_tasks_processing():
"""
Feature: registry_constant defaults
Description: Ensure NO_* sets include expected defaults when importing the real module.
Expectation: text-generation is categorized correctly across NO_* sets.
"""
# Importing the real module executes the loop once; assert default behavior
rc = importlib.import_module("mindformers.pipeline.registry_constant")
assert "text-generation" in rc.NO_FEATURE_EXTRACTOR_TASKS
assert "text-generation" in rc.NO_IMAGE_PROCESSOR_TASKS
assert "text-generation" not in rc.NO_TOKENIZER_TASKS


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_processing_various_types_and_error():
"""
Feature: registry_constant type handling
Description: Validate NO_* set assignment for image/audio types and error for invalid type.
Expectation: Image/audio tasks populate correct NO_* sets; unknown type raises ValueError.
"""
# image type -> tokenizers not required
ns = _load_registry_with_supported_tasks({"img-task": {"type": "image"}})
assert "img-task" in ns["NO_TOKENIZER_TASKS"]

# audio type -> tokenizer + image processor not required
ns = _load_registry_with_supported_tasks({"aud-task": {"type": "audio"}})
assert "aud-task" in ns["NO_TOKENIZER_TASKS"]
assert "aud-task" in ns["NO_IMAGE_PROCESSOR_TASKS"]

# invalid type should raise ValueError during module execution
with pytest.raises(ValueError):
_load_registry_with_supported_tasks({"bad": {"type": "unknown"}})


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pipeline_registry_aliases_and_supported_tasks_consistency():
"""
Feature: PIPELINE_REGISTRY wiring
Description: Verify exported PIPELINE_REGISTRY shares TASK_ALIASES/SUPPORTED_TASKS references.
Expectation: Registry uses the same dict objects and exposes alias entries defined in TASK_ALIASES.
"""

rc = importlib.import_module("mindformers.pipeline.registry_constant")
assert rc.PIPELINE_REGISTRY.task_aliases is rc.TASK_ALIASES
assert rc.PIPELINE_REGISTRY.supported_tasks is rc.SUPPORTED_TASKS
assert rc.TASK_ALIASES["text_generation"] == "text-generation"


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_processing_video_and_multimodal_types():
"""
Feature: registry_constant video/multimodal handling
Description: Ensure video tasks skip tokenizer and multimodal entries bypass NO_* categorization.
Expectation: Video tasks populate NO_TOKENIZER_TASKS; multimodal entries leave sets unchanged and do not raise.
"""

ns = _load_registry_with_supported_tasks({
"vid-task": {"type": "video"},
"multi-task": {"type": "multimodal"},
})
assert "vid-task" in ns["NO_TOKENIZER_TASKS"]
assert "vid-task" not in ns["NO_IMAGE_PROCESSOR_TASKS"]
assert "multi-task" not in ns["NO_TOKENIZER_TASKS"]
assert "multi-task" not in ns["NO_FEATURE_EXTRACTOR_TASKS"]


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_recompute_sets_in_module_covers_video_audio_and_multimodal():
"""
Feature: registry_constant loop execution
Description: Run the actual module-level loop with custom tasks to cover image/video/audio branches.
Expectation: Module NO_* sets reflect injected task types and state resets cleanly after reload.
"""

custom = {
"img-task": {"type": "image"},
"aud-task": {"type": "audio"},
"multi-task": {"type": "multimodal"},
}
module = _recompute_sets_in_module(custom)
assert "img-task" in module.NO_TOKENIZER_TASKS
assert "aud-task" in module.NO_TOKENIZER_TASKS
assert "aud-task" in module.NO_IMAGE_PROCESSOR_TASKS
assert "multi-task" not in module.NO_TOKENIZER_TASKS

importlib.reload(importlib.import_module("mindformers.pipeline.registry_constant"))


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_recompute_sets_invalid_type_triggers_value_error():
"""
Feature: registry_constant invalid type guard
Description: Ensure executing the real loop with unsupported type raises ValueError.
Expectation: ValueError message references offending task.
"""

with pytest.raises(ValueError):
_recompute_sets_in_module({"bad": {"type": "unknown"}})

importlib.reload(importlib.import_module("mindformers.pipeline.registry_constant"))

+ 173
- 0
tests/st/test_ut/test_tools/test_generic.py View File

@@ -0,0 +1,173 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test generic.py
"""
import tempfile
from unittest.mock import MagicMock
import os
import re
import pytest

from mindformers.tools.generic import (working_or_temp_dir, add_model_info_to_auto_map, experimental_mode_func_checker,
is_experimental_mode)


class TestWorkingOrTempDir:
""" A test class for testing working_or_temp_dir."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_working_dir_without_temp(self):
"""Test using working directory when use_temp_dir is False"""
with tempfile.TemporaryDirectory() as temp_working_dir:
with working_or_temp_dir(temp_working_dir, use_temp_dir=False) as result_dir:
assert result_dir == temp_working_dir
assert os.path.exists(result_dir)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_working_dir_with_temp(self):
"""Test using temporary directory when use_temp_dir is True"""
with tempfile.TemporaryDirectory() as temp_working_dir:
with working_or_temp_dir(temp_working_dir, use_temp_dir=True) as result_dir:
assert result_dir != temp_working_dir


class TestAddModelInfoToAutoMap:
""" A test class for testing add_model_info_to_auto_map."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_model_info_with_string_values(self):
"""Test with string values in auto_map"""
auto_map = {
"key1": "value1",
"key2": "value2",
"key3": None
}
repo_id = "my_repo"
result = add_model_info_to_auto_map(auto_map, repo_id)
expected = {
"key1": "my_repo--value1",
"key2": "my_repo--value2",
"key3": None
}
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_model_info_with_list_values(self):
"""Test with list values in auto_map"""
auto_map = {
"key1": ["value1", "value2"],
"key2": [None, "value3"],
"key3": "single_value"
}
repo_id = "my_repo"
result = add_model_info_to_auto_map(auto_map, repo_id)
expected = {
"key1": ["my_repo--value1", "my_repo--value2"],
"key2": [None, "my_repo--value3"],
"key3": "my_repo--single_value"
}
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_model_info_with_existing_dashes(self):
"""Test with values that already contain dashes"""
auto_map = {
"key1": "existing--value",
"key2": ["normal_value", "existing--value"],
"key3": None
}
repo_id = "my_repo"
result = add_model_info_to_auto_map(auto_map, repo_id)
expected = {
"key1": "existing--value",
"key2": ["my_repo--normal_value", "existing--value"],
"key3": None
}
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_add_model_info_empty_auto_map(self):
"""Test with empty auto_map"""
auto_map = {}
repo_id = "my_repo"
result = add_model_info_to_auto_map(auto_map, repo_id)
assert not result


class TestExperimentalModeFuncChecker:
""" A test class for testing experimental_mode_func_checker."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_decorator_success_case(self):
"""Test decorator when function executes successfully"""
mock_cls = MagicMock()
mock_cls.__name__ = "TestClass"

@experimental_mode_func_checker()
def test_function(x, y):
return x + y

result = test_function(2, 3)
assert result == 5

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_decorator_with_custom_error_message(self):
"""Test decorator with custom error message"""
mock_cls = MagicMock()
mock_cls.__name__ = "TestClass"
custom_msg = "Custom error message"

@experimental_mode_func_checker(custom_err_msg=custom_msg)
def test_function(cls, x, y):
raise ValueError("Test error")

with pytest.raises(RuntimeError) as exc_info:
test_function(mock_cls, 2, 3)

error_str = str(exc_info.value)
assert "Error occurred when executing function test_function" in error_str
assert custom_msg in error_str
assert "You are using TestClass in experimental mode" in error_str
assert isinstance(exc_info.value.__cause__, ValueError)


class TestIsExperimentalMode:
""" A test class for testing is_experimental_mode."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_path(self):
"""Test with non-string path parameter"""
with pytest.raises(ValueError, match=re.escape(
"param 'path' in AutoConfig.from_pretrained() must be str, but got <class 'int'>")):
is_experimental_mode(123)
result = is_experimental_mode("some/path/that/does/not/exist")
assert result is True
result = is_experimental_mode("mindspore/some/path")
assert result is False

+ 147
- 0
tests/st/test_ut/test_tools/test_logger.py View File

@@ -0,0 +1,147 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test logger.py
"""
import inspect
from unittest.mock import patch, MagicMock
import pytest

from mindformers.tools.logger import (_get_stack_info, judge_redirect, StreamRedirector, AiLogFastStreamRedirect2File,
judge_stdout, validate_nodes_devices_input)


class TestGetStackInfo:
"""Test class for testing _get_stack_info."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_starts_with_stack_prefix(self):
"""Test that returned string starts with expected prefix"""
current_frame = inspect.currentframe()
result = _get_stack_info(current_frame)
assert result.startswith('Stack (most recent call last):\n')

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.tools.utils.generate_rank_list')
@patch('mindformers.tools.utils.convert_nodes_devices_input')
@patch('mindformers.tools.utils.get_num_nodes_devices')
def test_rank_not_in_redirect_list_returns_false(self, mock_get_num, mock_convert, mock_generate):
"""Test when rank_id is not in redirect list returns False"""
mock_get_num.return_value = (2, 2)
mock_convert.return_value = [0]
mock_generate.return_value = [0, 1]
result = judge_redirect(rank_id=2, rank_size=4, redirect_nodes=[0])
assert result is True


class TestStreamRedirector:
"""Test class for testing StreamRedirector."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_context_manager_enter_calls_start(self):
"""Test that __enter__ method calls start()"""
source_stream = MagicMock()
target_stream = MagicMock()
redirector = StreamRedirector(source_stream, target_stream)
with patch.object(redirector, 'start') as mock_start:
with redirector:
mock_start.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_context_manager_exit_calls_stop(self):
"""Test that __exit__ method calls stop()"""
source_stream = MagicMock()
target_stream = MagicMock()
redirector = StreamRedirector(source_stream, target_stream)
with patch.object(redirector, 'stop') as mock_stop:
redirector.__exit__(None, None, None)
mock_stop.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_decorator_wraps_function_properly(self):
"""Test that __call__ returns a decorator that wraps the function"""
source_stream = MagicMock()
target_stream = MagicMock()
redirector = StreamRedirector(source_stream, target_stream)
test_func = MagicMock()
with patch.object(redirector, 'start') as mock_start, \
patch.object(redirector, 'stop') as mock_stop:
wrapper = redirector(test_func)
wrapper('arg1', kwarg1='value1')
mock_start.assert_called_once()
test_func.assert_called_once_with('arg1', kwarg1='value1')
mock_stop.assert_called_once()


class TestAiLogFastStreamRedirect2File:
"""Test class for testing AiLogFastStreamRedirect2File."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.tools.utils.get_rank_info')
@patch('mindformers.tools.StreamRedirector.__init__')
def test_start_when_redirect_false(self, mock_stream_redirector_init, mock_get_rank_info):
"""Test when both nodes and devices parameters are provided"""
mock_get_rank_info.return_value = (0, 4)
mock_stream_redirector_init.return_value = None
redirector = AiLogFastStreamRedirect2File()
mock_stream_redirector_init.assert_called_once()
assert redirector.is_redirect is True


class TestJudgeStdout:
"""Test class for testing judge_stdout."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.tools.utils.generate_rank_list')
@patch('mindformers.tools.utils.convert_nodes_devices_input')
@patch('mindformers.tools.utils.get_num_nodes_devices')
def test_both_nodes_and_devices_provided(self, mock_get_num, mock_convert, mock_generate):
"""Test when both nodes and devices parameters are provided"""
mock_get_num.return_value = (2, 2)
mock_convert.side_effect = [[0], [0, 1]]
mock_generate.return_value = [0, 1]

result = judge_stdout(rank_id=1, rank_size=4, is_output=True, nodes=[0], devices=[0, 1])
assert result is True


class TestValidateNodesDevicesInput:
"""Test class for testing validate_nodes_devices_input."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_invalid_type_raises_type_error(self):
"""Test that invalid type raises TypeError"""
with pytest.raises(TypeError,
match="The value of test_var can be None or a value of type tuple, list, or dict."):
validate_nodes_devices_input('test_var', "invalid_string")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_list_with_non_int_raises_type_error(self):
"""Test that list containing non-integer raises TypeError"""
with pytest.raises(TypeError, match="The elements of a variable of type list or tuple must be of type int."):
validate_nodes_devices_input('test_var', [1, '2', 3])

+ 160
- 0
tests/st/test_ut/test_tools/test_register/test_config.py View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for mindformers.tools.register.config."""
import argparse
import copy
import sys
from collections import OrderedDict

import pytest

from mindformers.tools.register import config as config_module
from mindformers.tools.register.config import (
ActionDict,
DictConfig,
MindFormerConfig,
BASE_CONFIG,
ordered_yaml_dump,
parse_args,
)

yaml = pytest.importorskip("yaml")

# pylint: disable=protected-access


class TestConfig:
"""Test class for mindformers.tools.register.config."""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_dict_config_attribute_and_to_dict(self):
"""DictConfig should expose attribute access semantics."""
cfg = DictConfig(a=1, nested=DictConfig(b=2))
assert cfg.a == 1
cfg.c = 3
assert cfg.c == 3
del cfg.a
assert cfg.a is None
plain = cfg.to_dict()
assert plain == {"nested": {"b": 2}, "c": 3}

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_dict_config_deepcopy_isolated(self):
"""Deep copy should create independent nested objects."""
cfg = DictConfig(nested=DictConfig(value=[1, 2]))
copied = copy.deepcopy(cfg)
copied.nested.value.append(3)
assert cfg.nested.value == [1, 2]
assert copied.nested.value == [1, 2, 3]

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_mindformer_config_loads_yaml_with_base(self, monkeypatch, tmp_path):
"""MindFormerConfig should merge base yaml files and convert dict to config."""
monkeypatch.setattr(config_module.ConfigTemplate, "apply_template", lambda _: None)
base_content = {"alpha": 1, "nested": {"from_base": True}}
base_file = tmp_path / "base.yaml"
base_file.write_text(yaml.safe_dump(base_content), encoding="utf-8")

child_content = {
BASE_CONFIG: "base.yaml",
"beta": 2,
"nested": {"from_child": True},
}
child_file = tmp_path / "child.yaml"
child_file.write_text(yaml.safe_dump(child_content), encoding="utf-8")

cfg = MindFormerConfig(str(child_file))
assert cfg.alpha == 1
assert cfg.beta == 2
assert isinstance(cfg.nested, MindFormerConfig)
assert cfg.nested.from_child

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_mindformer_config_merge_and_set(self, monkeypatch):
"""merge_from_dict and set_value should correctly update nested fields."""
monkeypatch.setattr(config_module.ConfigTemplate, "apply_template", lambda _: None)
cfg = MindFormerConfig(model={"model_config": {"type": "Demo"}})
cfg.merge_from_dict({"model.arch": "DemoArch", "new.branch.leaf": 10})
assert cfg.model.arch == "DemoArch"
assert cfg.new.branch.leaf == 10

cfg.set_value("context.mode", "GRAPH")
cfg.set_value(["context", "device_id"], 3)
assert cfg.get_value("context.mode") == "GRAPH"
assert cfg.get_value(["context", "device_id"]) == 3
assert cfg.get_value("context.fake", default="fallback") == "fallback"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_file2dict_without_filename_raises(self):
"""_file2dict should raise when filename is None."""
with pytest.raises(NameError):
getattr(MindFormerConfig, "_file2dict")(None)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_action_dict_parse_and_call(self):
"""ActionDict should parse ints, floats, tuples and bool strings."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--opts",
action=ActionDict,
nargs="*",
default={},
)
args = parser.parse_args(
[
"--opts",
"ints=1,2",
"floats=3.5",
"tuple=(7,8)",
"mixed=[1,(2,3),[4,5]]",
"flag=True",
]
)
assert args.opts["ints"] == [1, 2]
assert args.opts["floats"] == 3.5
assert args.opts["tuple"] == (7, 8)
assert args.opts["mixed"] == [1, (2, 3), [4, 5]]
assert args.opts["flag"] is False # current implementation compares function object

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_action_dict_find_next_comma_invalid_pairs(self):
"""find_next_comma should raise when brackets are unbalanced."""
with pytest.raises(ValueError):
ActionDict.find_next_comma("[1,2")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_ordered_yaml_dump_preserves_order(self):
"""ordered_yaml_dump should keep OrderedDict order in emitted yaml."""
ordered = OrderedDict()
ordered["first"] = 1
ordered["second"] = 2
dumped = ordered_yaml_dump(ordered)
assert dumped.index("first") < dumped.index("second")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_parse_args_reads_cli(self, monkeypatch):
"""parse_args should honor the --config cli argument."""
monkeypatch.setattr(sys, "argv", ["prog", "--config", "path/to/model.yaml"])
parsed = parse_args()
assert parsed.config == "path/to/model.yaml"

+ 1330
- 0
tests/st/test_ut/test_tools/test_transform_checkpoint.py View File

@@ -0,0 +1,1330 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for transform_checkpoint.py"""
# pylint: disable=W0212
import os
from unittest.mock import patch, MagicMock

import pytest
from mindformers.tools.ckpt_transform import transform_checkpoint
from mindformers.tools.ckpt_transform.transform_checkpoint import TransformCkpt, main


class TestTransformCkpt:
"""Test TransformCkpt class"""

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init(self):
"""Test __init__ method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
assert transform_ckpt.world_size == 1
assert transform_ckpt.rank_id == 0
assert transform_ckpt.is_main_rank is True
assert transform_ckpt.npu_num_per_node == 1
assert transform_ckpt.transform_process_num == 1
assert transform_ckpt.transform_by_rank is False

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_transform_rank_id_list(self):
"""Test _get_transform_rank_id_list method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=8), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=8):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=8,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=8
)
rank_list = transform_ckpt._get_transform_rank_id_list(2)
assert rank_list == [0, 4]

rank_list = transform_ckpt._get_transform_rank_id_list(8)
assert rank_list == [0, 1, 2, 3, 4, 5, 6, 7]

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_strategy_file(self, tmp_path):
"""Test get_strategy method with file"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
strategy_path = transform_ckpt.get_strategy(test_ckpt_path)
assert strategy_path == test_ckpt_path

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_strategy_none(self):
"""Test get_strategy method with None"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
strategy_path = transform_ckpt.get_strategy(None)
assert strategy_path is None

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_src_checkpoint_and_strategy_invalid(self, tmp_path):
"""Test check_src_checkpoint_and_strategy method with invalid input"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
with pytest.raises(ValueError):
transform_ckpt.check_src_checkpoint_and_strategy(tmp_path, None)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_soft_link_of_checkpoint(self, tmp_path):
"""Test build_soft_link_of_checkpoint method with various input types"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Test 1: Invalid directory (no rank_0 folder or ckpt files)
invalid_dir = os.path.join(tmp_path, "invalid_dir")
os.makedirs(invalid_dir)
soft_link_dir = os.path.join(tmp_path, "soft_link1")
os.makedirs(soft_link_dir)

with pytest.raises(ValueError):
transform_ckpt.build_soft_link_of_checkpoint(invalid_dir, soft_link_dir)

# Test 2: File input (ckpt file)
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

soft_link_dir = os.path.join(tmp_path, "soft_link2")
os.makedirs(soft_link_dir)

with patch("mindformers.tools.ckpt_transform.transform_checkpoint."
"make_soft_link") as mock_make_soft_link:
transform_ckpt.build_soft_link_of_checkpoint(test_ckpt_path, soft_link_dir)
mock_make_soft_link.assert_called_once()

# Test 3: Directory with rank_0 folder
valid_dir = os.path.join(tmp_path, "valid_dir")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

soft_link_dir = os.path.join(tmp_path, "soft_link3")
os.makedirs(soft_link_dir)

with patch("mindformers.tools.ckpt_transform.transform_checkpoint."
"make_soft_link") as mock_make_soft_link:
transform_ckpt.build_soft_link_of_checkpoint(valid_dir, soft_link_dir)
mock_make_soft_link.assert_called_once()

# Test 4: Directory with ckpt files directly
ckpt_dir = os.path.join(tmp_path, "ckpt_dir")
os.makedirs(ckpt_dir)
ckpt1 = os.path.join(ckpt_dir, "ckpt1.ckpt")
ckpt2 = os.path.join(ckpt_dir, "ckpt2.ckpt")
with open(ckpt1, "w", encoding="utf-8") as f:
f.write("ckpt1 content")
with open(ckpt2, "w", encoding="utf-8") as f:
f.write("ckpt2 content")

soft_link_dir = os.path.join(tmp_path, "soft_link4")
os.makedirs(soft_link_dir)

with patch("mindformers.tools.ckpt_transform.transform_checkpoint."
"make_soft_link") as mock_make_soft_link:
transform_ckpt.build_soft_link_of_checkpoint(ckpt_dir, soft_link_dir)
# Should be called twice, once for each ckpt file
assert mock_make_soft_link.call_count == 2

# Test 5: Directory with both rank folders and ckpt files
mixed_dir = os.path.join(tmp_path, "mixed_dir")
mixed_rank_0_dir = os.path.join(mixed_dir, "rank_0")
os.makedirs(mixed_rank_0_dir)
mixed_ckpt = os.path.join(mixed_rank_0_dir, "mixed.ckpt")
with open(mixed_ckpt, "w", encoding="utf-8") as f:
f.write("mixed ckpt content")

# Add a direct ckpt file in mixed_dir
direct_ckpt = os.path.join(mixed_dir, "direct.ckpt")
with open(direct_ckpt, "w", encoding="utf-8") as f:
f.write("direct ckpt content")

soft_link_dir = os.path.join(tmp_path, "soft_link5")
os.makedirs(soft_link_dir)

with patch("mindformers.tools.ckpt_transform.transform_checkpoint."
"make_soft_link") as mock_make_soft_link:
transform_ckpt.build_soft_link_of_checkpoint(mixed_dir, soft_link_dir)
# Should be called once for the rank folder, ignoring the direct ckpt file
mock_make_soft_link.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_clear_cache(self, tmp_path):
"""Test clear_cache method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Add a cache file
cache_file = os.path.join(tmp_path, "cache.txt")
with open(cache_file, "w", encoding="utf-8") as f:
f.write("cache content")
transform_ckpt.cache_list.append(cache_file)
# Clear cache
with patch("mindformers.tools.ckpt_transform.transform_checkpoint."
"delete_file") as mock_delete_file:
transform_ckpt.clear_cache()
mock_delete_file.assert_called_once_with(cache_file)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_transform_checkpoints(self, tmp_path):
"""Test transform_checkpoints method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms") as mock_ms, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
dst_ckpt_dir = os.path.join(tmp_path, "dst_ckpt")
transform_ckpt.transform_checkpoints(
src_checkpoint=tmp_path,
dst_checkpoint=dst_ckpt_dir,
prefix="checkpoint_",
src_strategy=None,
dst_strategy=None
)
mock_ms.transform_checkpoints.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_transform_checkpoint_by_rank(self, tmp_path):
"""Test transform_checkpoint_by_rank method"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=8), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms") as mock_ms, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.glob", return_value=[test_ckpt_path]), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=8):
# Mock rank_list_for_transform to return a list
mock_ms.rank_list_for_transform.return_value = [0]
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=8,
transform_process_num=1,
transform_by_rank=True,
npu_num_per_node=8
)
dst_ckpt_dir = os.path.join(tmp_path, "dst_ckpt")
transform_ckpt.transform_checkpoint_by_rank(
src_checkpoint=tmp_path,
dst_checkpoint=dst_ckpt_dir,
prefix="checkpoint_",
src_strategy=None,
dst_strategy=None
)
# Check that transform_checkpoint_by_rank was called 8 times (once for each rank)
assert mock_ms.transform_checkpoint_by_rank.call_count == 8

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call(self, tmp_path):
"""Test __call__ method"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.barrier_world"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.remake_folder"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.make_soft_link"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Mock get_strategy to return None
with patch.object(transform_ckpt, "get_strategy", return_value=None), \
patch.object(transform_ckpt, "transform_ckpt"), \
patch.object(transform_ckpt, "clear_cache"), \
patch("os.listdir", return_value=[]):
result = transform_ckpt(
src_checkpoint=test_ckpt_path,
dst_checkpoint_dir=None,
src_strategy=None,
dst_strategy=None,
prefix="checkpoint_"
)
assert result is not None

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call_auto_trans_ckpt_true(self, tmp_path):
"""Test __call__ method with auto_trans_ckpt=True"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

# Test with auto_trans_ckpt=True and world_size>1
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.barrier_world"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.remake_folder"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.make_soft_link"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
# Create dst_strategy_dir
dst_strategy_dir = os.path.join(tmp_path, "strategy")
os.makedirs(dst_strategy_dir, exist_ok=True)

transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Mock get_strategy to return a strategy file
strategy_file = os.path.join(dst_strategy_dir,
"test_strategy_rank_0.ckpt")
with open(strategy_file, "w", encoding="utf-8") as f:
f.write("test strategy")

with patch.object(transform_ckpt, "get_strategy", return_value=strategy_file), \
patch.object(transform_ckpt, "get_dst_strategy", return_value=strategy_file), \
patch.object(transform_ckpt, "transform_ckpt"), \
patch.object(transform_ckpt, "clear_cache"), \
patch("os.listdir", return_value=[]):
result = transform_ckpt(
src_checkpoint=test_ckpt_path,
dst_checkpoint_dir=None,
src_strategy=None,
dst_strategy=None,
prefix="checkpoint_"
)
assert result is not None

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_call_modelarts(self, tmp_path):
"""Test __call__ method with ModelArts environment"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

# Import the module and add mox attribute directly
mock_mox = MagicMock()
mock_mox.file = MagicMock()
mock_mox.file.exists.return_value = True
transform_checkpoint.mox = mock_mox

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.barrier_world"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.remake_folder"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.make_soft_link"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket/path"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Mock get_strategy to return None
with patch.object(transform_ckpt, "get_strategy", return_value=None), \
patch.object(transform_ckpt, "get_dst_strategy", return_value=None), \
patch.object(transform_ckpt, "transform_ckpt"), \
patch.object(transform_ckpt, "clear_cache"), \
patch("os.listdir", return_value=[]):
result = transform_ckpt(
src_checkpoint=test_ckpt_path,
dst_checkpoint_dir=None,
src_strategy=None,
dst_strategy=None,
prefix="checkpoint_"
)
assert result is not None

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_src_checkpoint_and_strategy_valid(self, tmp_path):
"""Test check_src_checkpoint_and_strategy method with valid input"""
# Create test ckpt file
test_ckpt_path = os.path.join(tmp_path, "test.ckpt")
with open(test_ckpt_path, "w", encoding="utf-8") as f:
f.write("test ckpt content")

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# This should not raise an exception
transform_ckpt.check_src_checkpoint_and_strategy(valid_dir, test_ckpt_path)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_transform_ckpt(self, tmp_path):
"""Test transform_ckpt method with various scenarios"""
# Test 1: Both src_strategy and dst_strategy are None
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

# This should raise ValueError since both strategies are None
with pytest.raises(ValueError):
transform_ckpt.transform_ckpt(
src_checkpoint=valid_dir,
dst_checkpoint_dir=tmp_path,
src_strategy=None,
dst_strategy=None,
prefix="checkpoint_"
)

# Test 2: transform_ckpt with exception handling
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.create_file") as mock_create_file, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

# Mock transform_checkpoints to raise an exception
with patch.object(transform_ckpt, "check_src_checkpoint_and_strategy"), \
patch.object(transform_ckpt, "transform_checkpoints",
side_effect=Exception("Transform failed")), \
patch.object(transform_ckpt, "wait_transform"):
transform_ckpt.transform_ckpt(
src_checkpoint=valid_dir,
dst_checkpoint_dir=tmp_path,
src_strategy="src_strategy.ckpt",
dst_strategy="dst_strategy.ckpt",
prefix="checkpoint_"
)
# Check that transform_failed file was created
mock_create_file.assert_called()

# Test 3: transform_ckpt with ModelArts case
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket/"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value="/tmp/"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.create_file") as mock_create_file, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

# Mock transform_checkpoints to succeed
with patch.object(transform_ckpt, "check_src_checkpoint_and_strategy"), \
patch.object(transform_ckpt, "transform_checkpoints"), \
patch.object(transform_ckpt, "wait_transform"), \
patch.object(transform_ckpt, "send_transformed_checkpoint_to_obs"):
transform_ckpt.transform_ckpt(
src_checkpoint=valid_dir,
dst_checkpoint_dir=tmp_path,
src_strategy="src_strategy.ckpt",
dst_strategy="dst_strategy.ckpt",
prefix="checkpoint_"
)
# Check that transform_succeed file was created
mock_create_file.assert_called()

# Test 4: transform_ckpt when rank_id is not in transform_rank_id_list
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=1,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Set transform_rank_id_list to [0] so rank 1 is not in the list
transform_ckpt.transform_rank_id_list = [0]

# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

# Mock wait_transform to avoid infinite loop
with patch.object(transform_ckpt, "check_src_checkpoint_and_strategy"), \
patch.object(transform_ckpt, "wait_transform"):
transform_ckpt.transform_ckpt(
src_checkpoint=valid_dir,
dst_checkpoint_dir=tmp_path,
src_strategy="src_strategy.ckpt",
dst_strategy="dst_strategy.ckpt",
prefix="checkpoint_"
)
# Should complete without calling transform_checkpoints

# Test 5: transform_ckpt with transform_by_rank=True
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=True,
npu_num_per_node=1
)

# Create a valid directory structure
valid_dir = os.path.join(tmp_path, "valid_ckpt")
rank_0_dir = os.path.join(valid_dir, "rank_0")
os.makedirs(rank_0_dir, exist_ok=True)
valid_ckpt = os.path.join(rank_0_dir, "test.ckpt")
with open(valid_ckpt, "w", encoding="utf-8") as f:
f.write("valid ckpt content")

# Mock transform_checkpoint_by_rank to succeed
with patch.object(transform_ckpt, "check_src_checkpoint_and_strategy"), \
patch.object(transform_ckpt, "transform_checkpoint_by_rank"), \
patch.object(transform_ckpt, "wait_transform"):
transform_ckpt.transform_ckpt(
src_checkpoint=valid_dir,
dst_checkpoint_dir=tmp_path,
src_strategy="src_strategy.ckpt",
dst_strategy="dst_strategy.ckpt",
prefix="checkpoint_"
)
# Should complete successfully

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_invalid_npu_num(self):
"""Test __init__ method with invalid npu_num_per_node"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
pytest.raises(ValueError), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=2):
TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=3 # Not a power of 2
)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_auto_trans_ckpt_true(self, tmp_path):
"""Test __init__ method with auto_trans_ckpt=True"""
# Test with world_size=1 and auto_trans_ckpt=True
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
assert transform_ckpt.auto_trans_ckpt is True
assert transform_ckpt.transformed_checkpoint_dir == os.path.join(tmp_path, "transformed_checkpoint")
# No dst_strategy_dir when world_size=1
assert not hasattr(transform_ckpt, 'dst_strategy_dir')

# Test world_size>1 and auto_trans_ckpt=True
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
assert transform_ckpt.auto_trans_ckpt is True
assert transform_ckpt.transformed_checkpoint_dir == os.path.join(tmp_path, "transformed_checkpoint")
assert transform_ckpt.dst_strategy_dir == os.path.join(tmp_path, "strategy")

# Test pipeline parallelism and auto_trans_ckpt=True
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
assert transform_ckpt.use_pipeline is True

# Test ModelArts environment and auto_trans_ckpt=True
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms.get_auto_parallel_context",
return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket/path"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
assert hasattr(transform_ckpt, 'transformed_checkpoint_dir_obs')
assert hasattr(transform_ckpt, 'dst_strategy_dir_obs')

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_main(self):
"""Test main function"""
with patch("sys.argv", [
"transform_checkpoint.py",
"--src_checkpoint", "/path/to/src/ckpt",
"--dst_checkpoint_dir", "/path/to/dst/ckpt",
"--src_strategy", "/path/to/src/strategy.ckpt",
"--dst_strategy", "/path/to/dst/strategy.ckpt",
"--prefix", "checkpoint_",
"--rank_id", "0",
"--world_size", "1",
"--transform_process_num", "1"
# transform_by_rank is not passed, so the default value False is used
]), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.TransformCkpt") as mock_transform_ckpt:
# Mock the TransformCkpt class and its __call__ method
mock_instance = mock_transform_ckpt.return_value
mock_instance.return_value = "/path/to/dst/ckpt"

# Import and call main function
main()

# Verify TransformCkpt was initialized correctly
mock_transform_ckpt.assert_called_once()
_, kwargs = mock_transform_ckpt.call_args
assert kwargs["rank_id"] == 0
assert kwargs["world_size"] == 1
assert kwargs["transform_process_num"] == 1
assert not kwargs["transform_by_rank"]

# Verify TransformCkpt instance was called correctly
mock_instance.assert_called_once()
_, call_kwargs = mock_instance.call_args
assert call_kwargs["src_checkpoint"] == "/path/to/src/ckpt"
assert call_kwargs["dst_checkpoint_dir"] == "/path/to/dst/ckpt"
assert call_kwargs["src_strategy"] == "/path/to/src/strategy.ckpt"
assert call_kwargs["dst_strategy"] == "/path/to/dst/strategy.ckpt"
assert call_kwargs["prefix"] == "checkpoint_"

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_transform_rank_id_list_invalid(self):
"""Test _get_transform_rank_id_list method with invalid input"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=8), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=8):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=8,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=8
)
# Test with transform_process_num < 1
with pytest.raises(ValueError):
transform_ckpt._get_transform_rank_id_list(0)
# Test with transform_process_num not divisible by world_size
with pytest.raises(ValueError):
transform_ckpt._get_transform_rank_id_list(3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_strategy(self, tmp_path):
"""Test get_strategy method with various inputs"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Test 1: None input
result = transform_ckpt.get_strategy(None)
assert result is None

# Test 2: "None" string input
result = transform_ckpt.get_strategy("None")
assert result is None

# Test 3: Invalid path
invalid_path = os.path.join(tmp_path, "invalid_path")
with pytest.raises(ValueError):
transform_ckpt.get_strategy(invalid_path)

# Test 4: File input
test_file = os.path.join(tmp_path, "test_strategy.ckpt")
with open(test_file, "w", encoding="utf-8") as f:
f.write("test strategy content")

result = transform_ckpt.get_strategy(test_file)
assert result == test_file

# Test 5: Directory input with main rank
strategy_dir = os.path.join(tmp_path, "strategy_dir")
os.makedirs(strategy_dir)

# Create a strategy file in the directory
strategy_file = os.path.join(strategy_dir, "strategy_0.ckpt")
with open(strategy_file, "w", encoding="utf-8") as f:
f.write("strategy content")

# Mock ms.merge_pipeline_strategys
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms."
"merge_pipeline_strategys") as mock_merge, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.create_file") as mock_create_file:
result = transform_ckpt.get_strategy(strategy_dir)
expected_merge_path = os.path.join(strategy_dir, "merged_ckpt_strategy.ckpt")
assert result == expected_merge_path
mock_merge.assert_called_once_with(strategy_dir, expected_merge_path)
mock_create_file.assert_called_once()

# Test 6: Directory input with main rank and existing merged strategy
# Create merged strategy file
merged_strategy = os.path.join(strategy_dir, "merged_ckpt_strategy.ckpt")
with open(merged_strategy, "w", encoding="utf-8") as f:
f.write("merged strategy content")

# Mock ms.merge_pipeline_strategys
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms."
"merge_pipeline_strategys") as mock_merge, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.create_file") as mock_create_file, \
patch("os.remove") as mock_remove:
result = transform_ckpt.get_strategy(strategy_dir)
expected_merge_path = os.path.join(strategy_dir, "merged_ckpt_strategy.ckpt")
assert result == expected_merge_path
mock_remove.assert_called_once_with(expected_merge_path)
mock_merge.assert_called_once_with(strategy_dir, expected_merge_path)
mock_create_file.assert_called_once()

# Test 7: Directory input with non-main rank
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=False):
transform_ckpt_non_main = TransformCkpt(
auto_trans_ckpt=False,
rank_id=1,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create merge_succeed.txt to avoid an infinite wait loop
merged_succeed_txt = os.path.join(strategy_dir, "merge_succeed.txt")
with open(merged_succeed_txt, "w", encoding="utf-8") as f:
f.write("merge succeed")

result = transform_ckpt_non_main.get_strategy(strategy_dir)
expected_merge_path = os.path.join(strategy_dir, "merged_ckpt_strategy.ckpt")
assert result == expected_merge_path

# Test 8: Directory input with rank_id parameter
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.ms."
"merge_pipeline_strategys") as mock_merge, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.create_file") as mock_create_file:
result = transform_ckpt.get_strategy(strategy_dir, rank_id=1)
expected_merge_path = os.path.join(strategy_dir, "merged_ckpt_strategy_by_rank_1.ckpt")
assert result == expected_merge_path
mock_merge.assert_called_once_with(strategy_dir, expected_merge_path)
mock_create_file.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_soft_link_of_checkpoint_invalid_file(self, tmp_path):
"""Test build_soft_link_of_checkpoint method with invalid file"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Create an invalid file (not a ckpt file)
invalid_file = os.path.join(tmp_path, "invalid.txt")
with open(invalid_file, "w", encoding="utf-8") as f:
f.write("invalid content")
soft_link_dir = os.path.join(tmp_path, "soft_link")
os.makedirs(soft_link_dir)
with pytest.raises(ValueError):
transform_ckpt.build_soft_link_of_checkpoint(invalid_file, soft_link_dir)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_send_strategy_to_obs(self, tmp_path):
# pylint: disable=W0613
"""Test send_strategy_to_obs method"""
# Create mock functions for mox.file operations
def mock_copy(*args, **kwargs):
return None

def mock_exists(*args, **kwargs):
return False

# Mock the moxing module and mox alias
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.mox", create=True) as mock_mox, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
# Configure the mock mox object
mock_mox.file.copy = mock_copy
mock_mox.file.exists = mock_exists

transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Add required attributes for ModelArts
transform_ckpt.dst_strategy_dir_obs = "s3://bucket/strategy"

# Create a strategy file
strategy_file = os.path.join(tmp_path, "test_strategy.ckpt")
with open(strategy_file, "w", encoding="utf-8") as f:
f.write("test strategy content")

transform_ckpt.send_strategy_to_obs(strategy_file)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_send_transformed_checkpoint_to_obs(self, tmp_path):
"""Test send_transformed_checkpoint_to_obs method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.mox_adapter") as mock_mox_adapter, \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Add required attributes for ModelArts
transform_ckpt.transformed_checkpoint_dir_obs = "s3://bucket/transformed"

# Create a dst checkpoint directory
dst_ckpt_dir = os.path.join(tmp_path, "dst_ckpt")
os.makedirs(dst_ckpt_dir)

transform_ckpt.send_transformed_checkpoint_to_obs(dst_ckpt_dir)
mock_mox_adapter.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_wait_transform(self, tmp_path):
"""Test wait_transform method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a ckpt_dir
ckpt_dir = os.path.join(tmp_path, "ckpt_dir")
os.makedirs(ckpt_dir)

# Create transform_succeed file
succeed_file = os.path.join(ckpt_dir, "transform_succeed_rank_0.txt")
with open(succeed_file, "w", encoding="utf-8") as f:
f.write("transform succeed")

# This should return immediately since the succeed file exists
transform_ckpt.wait_transform(ckpt_dir)

# Test with transform_failed file
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a ckpt_dir
ckpt_dir = os.path.join(tmp_path, "ckpt_dir_failed")
os.makedirs(ckpt_dir)

# Create transform_failed file
failed_file = os.path.join(ckpt_dir, "transform_failed_rank_0.txt")
with open(failed_file, "w", encoding="utf-8") as f:
f.write("transform failed")

# This should raise ValueError since a failed file exists
with pytest.raises(ValueError):
transform_ckpt.wait_transform(ckpt_dir)

# Test with ModelArts case
# Build a mock mox (moxing) module and attach it to the already-imported transform_checkpoint module

mock_mox = MagicMock()
mock_mox.file = MagicMock()

# Define a side_effect to return different results based on the pattern
def mock_glob_side_effect(pattern):
if 'transform_failed' in pattern:
return [] # No failed files
if 'transform_succeed' in pattern:
return ["s3://bucket/path/transformed_checkpoint/ckpt_dir_modelarts/transform_succeed_rank_0.txt"]
return []

mock_mox.file.glob.side_effect = mock_glob_side_effect
transform_checkpoint.mox = mock_mox

with patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_remote_save_url",
return_value="s3://bucket/path"), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_output_root_path",
return_value=str(tmp_path)):
# Create TransformCkpt instance
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)

# Create a ckpt_dir
ckpt_dir = os.path.join(tmp_path, "ckpt_dir_modelarts")
os.makedirs(ckpt_dir)

# This should return immediately since mock returns succeed file
transform_ckpt.wait_transform(ckpt_dir)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_wait_collect_all_strategy(self, tmp_path):
"""Test wait_collect_all_strategy method"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.check_in_modelarts", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=True,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Add required attributes
transform_ckpt.dst_strategy_dir = tmp_path

# Create a strategy file
strategy_file = os.path.join(tmp_path, "ckpt_strategy_rank_0.ckpt")
with open(strategy_file, "w", encoding="utf-8") as f:
f.write("test strategy content")

# This should return immediately since the strategy file exists
transform_ckpt.wait_collect_all_strategy()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_clear_cache_not_main_rank(self, tmp_path):
"""Test clear_cache method when not main rank"""
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=False), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=1,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Add a cache file
cache_file = os.path.join(tmp_path, "cache.txt")
with open(cache_file, "w", encoding="utf-8") as f:
f.write("cache content")
transform_ckpt.cache_list.append(cache_file)
# Clear cache - should not delete anything since not main rank
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.delete_file") as mock_delete_file:
transform_ckpt.clear_cache()
mock_delete_file.assert_not_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_dst_strategy(self, tmp_path):
"""Test get_dst_strategy method"""
# Test with world_size=1
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=1), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=1,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
result = transform_ckpt.get_dst_strategy("test_strategy.ckpt")
assert result is None

# Test with world_size > 1 and invalid dst_strategy
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
# Test with invalid dst_strategy (wrong rank suffix)
with pytest.raises(ValueError):
transform_ckpt.get_dst_strategy("test_strategy_rank_1.ckpt")

# Test with world_size > 1 and valid dst_strategy
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
# Create a valid strategy file
valid_strategy = os.path.join(tmp_path, "test_strategy_rank_0.ckpt")
with open(valid_strategy, "w", encoding="utf-8") as f:
f.write("valid strategy")

transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
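# Disable the pipeline branch so get_dst_strategy returns the given strategy file directly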
transform_ckpt.use_pipeline = False
result = transform_ckpt.get_dst_strategy(valid_strategy)
assert result == valid_strategy

# Test with pipeline parallelism and main rank
with patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_group_size", return_value=2), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_real_rank", return_value=0), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.is_main_rank", return_value=True), \
patch("mindformers.tools.ckpt_transform.transform_checkpoint.get_device_num_per_node", return_value=1):
# Create a valid strategy file
valid_strategy = os.path.join(tmp_path, "test_strategy_rank_0.ckpt")
with open(valid_strategy, "w", encoding="utf-8") as f:
f.write("valid strategy")

# Create dst_strategy_dir with merged strategy
dst_strategy_dir = os.path.join(tmp_path, "strategy")
os.makedirs(dst_strategy_dir)
merged_strategy = os.path.join(dst_strategy_dir, "merged_ckpt_strategy.ckpt")
with open(merged_strategy, "w", encoding="utf-8") as f:
f.write("merged strategy")

transform_ckpt = TransformCkpt(
auto_trans_ckpt=False,
rank_id=0,
world_size=2,
transform_process_num=1,
transform_by_rank=False,
npu_num_per_node=1
)
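# Enable the pipeline branch so get_dst_strategy goes through the merged-strategy path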
transform_ckpt.use_pipeline = True
transform_ckpt.dst_strategy_dir = dst_strategy_dir

with patch.object(transform_ckpt, "get_strategy", return_value=merged_strategy), \
patch.object(transform_ckpt, "wait_collect_all_strategy"):
result = transform_ckpt.get_dst_strategy(valid_strategy)
assert result == merged_strategy

+ 3
- 0
tests/st/test_ut/test_trainer/test_trainer_methods.py View File

@@ -518,6 +518,9 @@ class TestTrainerCheckpointMethods(unittest.TestCase):
f.write('mock')
os.stat(last_checkpoint_path)

# Bump the checkpoint's atime/mtime by one second so it is detected as the most recent checkpoint
os.utime(last_checkpoint_path, (os.path.getatime(last_checkpoint_path) + 1,
os.path.getmtime(last_checkpoint_path) + 1))

trainer._check_checkpoint_config(True)
assert trainer.config.model.model_config.checkpoint_name_or_path == last_checkpoint_path



+ 104
- 0
tests/st/test_ut/test_utils/test_convert_utils.py View File

@@ -0,0 +1,104 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test convert_utils
"""
import numpy as np
import pytest

from mindformers.utils.convert_utils import is_lora_param, qkv_concat_hf2mg, ffn_concat_hf2mg


class TestIsLoraParam:
""" A test class for testing is_lora_param."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lora_in_key_lowercase(self):
"""Test with 'lora' in lowercase in the key"""
assert is_lora_param("model.lora.weight") is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_no_lora_in_key(self):
"""Test with no 'lora' in the key"""
assert is_lora_param("model.linear.weight") is False


class TestQkvConcatHf2Mg:
""" A test class for testing qkv_concat_hf2mg."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_qkv_concat_2d_array(self):
"""Test with 2D array input"""
hidden_size = 768
num_heads = 12
n_kv_heads = 12
n_rep = num_heads // n_kv_heads
q_channel = hidden_size
kv_channel = hidden_size // n_rep
total_channels = q_channel + 2 * kv_channel
qkv_weights = np.random.rand(total_channels, 1024).astype(np.float32)
result = qkv_concat_hf2mg(qkv_weights, num_heads, n_kv_heads, hidden_size)
assert result.shape == qkv_weights.shape

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_qkv_concat_1d_array(self):
"""Test with 1D array input (bias case)"""
hidden_size = 768
num_heads = 12
n_kv_heads = 12
n_rep = num_heads // n_kv_heads
q_channel = hidden_size
kv_channel = hidden_size // n_rep
total_channels = q_channel + 2 * kv_channel
qkv_weights = np.random.rand(total_channels).astype(np.float32)
result = qkv_concat_hf2mg(qkv_weights, num_heads, n_kv_heads, hidden_size)
assert result.shape == qkv_weights.shape

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_qkv_concat_3d_array_raises_error(self):
"""Test with 3D array input should raise ValueError"""
qkv_weights = np.random.rand(10, 20, 30).astype(np.float32)
num_heads = 12
n_kv_heads = 12
hidden_size = 768
with pytest.raises(ValueError, match="qkv_weights shape is not supported."):
qkv_concat_hf2mg(qkv_weights, num_heads, n_kv_heads, hidden_size)


class TestFfnConcatHf2Mg:
""" A test class for testing ffn_concat_hf2mg."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ffn_concat_basic_case(self):
"""Test basic FFN concat conversion"""
ffn_weights = np.array([
[1, 2, 3], # gate weights part 1
[4, 5, 6], # gate weights part 2
[7, 8, 9], # hidden weights part 1
[10, 11, 12] # hidden weights part 2
], dtype=np.float32)
ffn_hidden_size = 2
result = ffn_concat_hf2mg(ffn_weights, ffn_hidden_size)
assert result.shape == ffn_weights.shape
assert isinstance(result, np.ndarray)

+ 128
- 0
tests/st/test_ut/test_version_control.py View File

@@ -0,0 +1,128 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test version_control.py
"""
from unittest.mock import patch
import pytest

import mindspore as ms
import mindspore_gs
from mindformers.version_control import (check_is_reboot_node, check_valid_mindspore_gs, check_valid_gmm_op,
is_version_python, get_norm)


class TestCheckIsVersion:
"""Test class for testing version_control."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('os.getenv')
def test_version_too_low_returns_false(self, mock_getenv):
"""Test when MindSpore version is lower than 2.6.0 returns False with warning."""
ms.__version__ = "2.6.0"
result = check_is_reboot_node()
mock_getenv.return_value = "ARF:1"
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_version_too_low_returns_false2(self):
"""Test when MindSpore version is lower than 2.6.0 returns False with warning."""
ms.__version__ = "2.5.0"
result = check_is_reboot_node()
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_version(self):
"""Test when mindspore_gs version."""
mindspore_gs.__version__ = "0.6.0"
result = check_valid_mindspore_gs()
assert result is True
mindspore_gs.__version__ = "0.5.0"
result = check_valid_mindspore_gs()
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindspore.__version__', '2.6.0')
def test_check_valid_gmm_op_with_version_equal_to_required(self):
"""Test when MindSpore version equals required version, should return True"""
result = check_valid_gmm_op(gmm_version="GroupedMatmulV4")
assert result is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindspore.__version__', '2.6.0-rc1')
def test_check_valid_gmm_op_with_rc_version(self):
"""Test when MindSpore version has rc suffix, should handle correctly"""
result = check_valid_gmm_op(gmm_version="GroupedMatmulV4")
assert result is True or result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_python_cur_higher_than_tar(self):
"""Test when current version is higher than target version"""
result = is_version_python("3.9.1", "3.9.0")
assert result is True
result = is_version_python("3.7.10", "3.9.0")
assert result is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_python_missing_dot_in_cur(self):
"""Test when current version string doesn't contain dot, should raise ValueError"""
with pytest.raises(ValueError) as exc_info:
is_version_python("37910", "3.9.0")
assert "The version string will contain the `.`" in str(exc_info.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_python_different_version_lengths(self):
"""Test version strings with different number of segments"""
result = is_version_python("3.9.0.1", "3.9.0")
assert result is True
result = is_version_python("3.9", "3.9.0")
assert result is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_norm_version_ge_1_11_0(self):
"""Test when mindspore version >= '1.11.0', should return tensor_norm1"""
with patch('mindspore.__version__', '1.11.0'):
with patch('mindformers.tools.utils.is_version_ge') as mock_is_version_ge:
mock_is_version_ge.return_value = True
norm_func = get_norm()
assert norm_func.__name__ == 'tensor_norm1' or norm_func.__code__.co_varnames[:5] == (
'input_tensor', 'tensor_ord', 'dim', 'keepdim', 'dtype')

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_norm_version_lt_1_11_0(self):
"""Test when mindspore version < '1.11.0', should return tensor_norm2"""
with patch('mindspore.__version__', '1.10.0'):
with patch('mindformers.tools.utils.is_version_ge') as mock_is_version_ge:
mock_is_version_ge.return_value = False
norm_func = get_norm()
assert norm_func.__name__ == 'tensor_norm2' or norm_func.__defaults__[0] == 2

+ 1
- 1
toolkit/safetensors/README.md View File

@@ -10,7 +10,7 @@
python unified_safetensors.py \
--mindspore_ckpt_dir /path/checkpoint \
--src_strategy_dirs /path/src_strategy_dirs \
--output_dir /path/src_strategy_dirs \
--output_dir /path/output_dir \
--file_suffix "1_1" \
--format "ckpt" \
--has_redundancy False \

