107 Commits

Author SHA1 Message Date
  i-robot 3995c1ed1b
!7857 add and fix llm_trainer function comments. 5 days ago
  i-robot 13516477ef
!7856 [master commit sync] 5 days ago
  Lin-Bert ff21e7c436 add and fix llm_trainer function comments. 6 days ago
  zyw_hw 0ebb0721bf update ms pkg 1 week ago
  zyw_hw f303fd1976 update third party info 1 week ago
  pengjingyou a4cd09b47d [master][infer] Add an ST case for Glm4Moe full-network layer reduction 1 week ago
  Yule100 16b0f49559 bugfix: hostname could not be obtained 1 week ago
  hangangqiang 96384d4d22 add gmm-linear quant ut 1 week ago
  JavaZero 98ca4431ca update: increase default qk_clip_threshold from 4 to 100 1 week ago
  JavaZero bbb0d7f8ed test: add unit tests for Muon optimizer initialization and computation 1 week ago
  lanxiang bfeca34e53 Add test cases for SlidingWindowAttention and SharedCrossAttention 1 week ago
  Yule100 61dd5bde22 bugfix: documentation fixes 1 week ago
  lzy0920232 56d913cb15 code_bugfix_cp_moe 1 week ago
  宋佳琪 93bfe6282f stable_rank_fix_moe 1 week ago
  SaiYao dc9e76daf6 [master][bugfix][weights] Unify offline conversion script argument names to match convert_weight.py 1 week ago
  Xinrui Chen d5e9bbd206 [Docs] Delete expired statement in llama3.1 README.md 1 week ago
  senzhen 3c3c2fdf82 Add broadcast logic 2 weeks ago
  魏琢艺 0a42a3bd26 trainingstatemonitor doc fix 2 weeks ago
  zyw_hw 403e6ace5a fix docs statement 1 week ago
  Yule100 0a9bd0d981 bugfix: fix errors in the qwen3_moe documentation 1 week ago
  SaiYao 5d5ee8696c [master][bugfix][weights] Reuse the directory config item for the safetensors 2.0 save path 1 week ago
  zhangyihuiben fe80f966e6 [master][mcore][bugfix] Fix the incorrect path in the YAML file 1 week ago
  renyujin a24a49bd95 add test ut for model_runner 2 weeks ago
  senzhen 59226e23be Fix Glm spelling 1 week ago
  zyw_hw 91b58785d6 fix callback out time 1 week ago
  senzhen 8c025b2aa6 Fix Glm capitalization 1 week ago
  qinsichun 057b90d3d9 add_quant 1 week ago
  JavaZero dd2318262f refactor: update MoE expert validation logic in Muon and GPTModel 1 week ago
  JavaZero 83b4f6da41 Refactor max attention logit handling in GPT model and related components 1 week ago
  SaiYao 6620315306 [master][bugfix][logging] For weight-related logs, call logger.error before raising an Error so a matching entry lands in error.log 1 week ago
  husichao 77ff7487c9 add sharded_state_dict for muon op group 2 weeks ago
  SaiYao 9c2939423f [master][bugfix][logging] For weight-related logs, call logger.error before raising an Error so a matching entry lands in error.log 1 week ago
  Yule100 bb7df1d098 bugfix: add supplementary UTs 1 week ago
  zyw_hw 18f33836cb add callback testcase 2 weeks ago
  zyw_hw aba5d1639a fix huge cc 2 weeks ago
  lanxiang 4e78181f0e Add modeling_utils test cases 1 week ago
  Hsshuai 9517612d38 fix testcase of get_last_checkpoint 1 week ago
  lanxiang 5ff48468f8 Fix test_pma case timeout and remove unused cases from test_all_reduce 1 week ago
  zzzkeke 2b5f93449d Fix failing build cases in the blended_megatron_dataset_builder tests 1 week ago
  Yule100 28e7c204f2 Improve coverage of inference test cases 2 weeks ago
  senzhen 81b8258fcd Fix TeleChat capitalization in the docs 1 week ago
  pengjingyou 1a15215957 [master] Improve inference coverage 2 weeks ago
  zxq c3e7c2ebf9 [UT] Add supplementary test cases 2 weeks ago
  李宜杰 5e4b10043c add pipeline and metric ut_test 2 weeks ago
  senzhen bff73eaf60 Fix documentation spelling 2 weeks ago
  yiyison 759ddbb811 Documentation corrections 1 week ago
  JavaZero e9201f7885 fix: adjust chunking logic in _slice_tensor_to_shards for tensor distribution 2 weeks ago
  niujunhao 2911df1a56 fix bs>1 in hf dataloader tnd. 2 weeks ago
  yiyison 5adb7dbfb1 Add transform_checkpoint_utils.py test cases 2 weeks ago
  yiyison f554016963 Documentation corrections 1 week ago
  JingweiHuang fb505ca730 Fix the error of the_build_context 2 weeks ago
  yiyison b5a2fc2e97 Add transform_checkpoint_utils.py test cases 2 weeks ago
  qinsichun 45f416cc1c test_conv 2 weeks ago
  zxq 14f1792ccf [master][UT] Add test cases for the weight_utils, logger, and version_control files 2 weeks ago
  宋佳琪 32b54c2f11 stable_rank_fix 3 weeks ago
  yiyison d8dbeea321 fully_parallel test cases 2 weeks ago
  zhangyihuiben 8d801da03c [master][mcore][bugfix] Fix the incorrect path in the YAML file and fix inconsistencies in formatting, capitalization, and typos in the documentation. 2 weeks ago
  zyw_hw 152c8676fc fix softmax ques 2 weeks ago
  JavaZero f0e7b7c521 ensure config is copied in build_optim 2 weeks ago
  JavaZero 3b0afc5164 fix redistribution op in flash_attn 2 weeks ago
  yiyison ab7a873053 Add test cases for load_checkpoint_utils.py and run_check.py 2 weeks ago
  yiyison fc82d2cdc1 Add adamw.py test cases 2 weeks ago
  zxq 8c65d0e32a [master][bug-fix] Fix weights failing to load for inference tasks when q_lora_rank is None 2 weeks ago
  yiyison f4528be603 Bugfix for the test_checkpoint test cases 2 weeks ago
  yiyison d6cd975b33 Optimize log printing 3 weeks ago
  zxq c89a094702 [master][bug-fix] Fix spelling errors in the documentation 2 weeks ago
  zzzkeke 0a4b0d8b7a Add blended_megatron_dataset_builder test cases 2 weeks ago
  zzzkeke 08a9956e5d Add gpt dataset test UT 2 weeks ago
  yiyison 9af71006c8 Add checkpoint.py test cases 2 weeks ago
  yiyison 06df550072 Add model_mixin.py test cases 2 weeks ago
  SaiYao 987e5a142c [master][bugfix][docs] Update the DeepSeek-V3 offline script documentation 2 weeks ago
  kongziyi 1d2f701bd0 [master][test case] Add a swap test case for LayerSetting 2 weeks ago
  SaiYao 06442e00d5 [master][bugfix][docs] Improve documentation readability 2 weeks ago
  zyw_hw 935a117754 add convert weight test cases 2 weeks ago
  zyw_hw c60ae815ba add tokenizer cases 2 weeks ago
  zyw_hw 7059ef839e fix profiler step question 3 weeks ago
  zyw_hw f043cf486a fix tokenizer case bug 2 weeks ago
  lanxiang 68473c7ddd Add callback test cases 2 weeks ago
  魏琢艺 0f93579720 trainingstatemonitor doc fix 2 weeks ago
  魏琢艺 a3701eb235 add TokenDispatcher testcase 2 weeks ago
  SaiYao a4a1ebf7cf [UT] Add UT cases for sharded_tensor 2 weeks ago
  kongziyi dd6b4ff597 [master][bugfix] Validate that dp >= op and swap=False when the Muon optimizer is enabled 2 weeks ago
  Hsshuai e5dd76d81e add testcase for trainer 2 weeks ago
  yiyison e5e0b6d5c5 Reject invalid device ids 3 weeks ago
  Yule100 f3880f9cf9 [Feature] Upgrade the transformers version 3 weeks ago
  SaiYao c945804bc7 [MCore] Speed up the DeepSeek-V3 offline weight conversion script with multiprocessing 2 weeks ago
  魏琢艺 1023b27134 remove onehot recompute 3 weeks ago
  niujunhao c519a18c54 muon apply grouped lr. 2 weeks ago
  niujunhao fc5efc59a4 fix input config in megatron dataset. 3 weeks ago
  lanxiang f5b3b35ca2 Add PMA test cases 2 weeks ago
  Hsshuai e46ed991dd Enhance set_safe_mode_for_file_or_dir function with retry logic for file permission changes and remove unnecessary cache refresh in set_strategy_save_path. 2 weeks ago
  lzy0920232 937a4fa113 code_bugfix_rope 3 weeks ago
  Hsshuai bed70f8480 add test for tokenization 3 weeks ago
  zhangyihuiben 9675a2abc9 [master][mcore][bugfix] fix nope_layer_interval not rejecting invalid values 3 weeks ago
  zyw_hw 3122257fcc update ms pkg url 2 weeks ago
  JingweiHuang 2b8e8a983e Add unit tests to context 3 weeks ago
  zzzkeke 77bbfcdfa1 The position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute'. 3 weeks ago
  魏琢艺 02ad951507 avoid repeatedly create group 3 weeks ago
  SaiYao 3003738754 [MCore] Speed up the Qwen3-series reverse conversion scripts with multiprocessing 3 weeks ago
  JavaZero 7da119ae66 Enhance FlashAttention: optimize max logits tracking, reduce max operations, and fix TND layout 3 weeks ago
  niujunhao b6dceec92d flash new made dir. 3 weeks ago
  niujunhao a22563c4b9 fix toolalpaca case. 3 weeks ago
  kongziyi 6d7b8a6904 [master][bugfix] Add offset validation to avoid assigning a negative number of layers 1 month ago
  senzhen 0b3b96761e Improve error reporting for non-shared path validation 3 weeks ago
  niujunhao 5cfdc1d8ae fix iter num_parallel_workers in hf streaming load. 3 weeks ago
  yiyison 2cdfc06bf2 Bugfix in the checkpoint save workflow 3 weeks ago
  SaiYao f912e2b292 [Telechat2] Ignore unused configuration items 3 weeks ago
100 changed files with 10128 additions and 4251 deletions
1. +1 -1 .jenkins/test/config/dependent_packages.yaml
2. +4414 -2939 Third_Party_Open_Source_Software_Notice
3. +16 -16 configs/glm4/README.md
4. +10 -10 configs/glm4_moe/README.md
5. +13 -13 configs/qwen3_moe/README.md
6. +10 -2 convert_weight.py
7. +3 -1 docs/api/api_python/core/mindformers.core.CheckpointMonitor.rst
8. +9 -9 docs/api/api_python/core/mindformers.core.TrainingStateMonitor.rst
9. +2 -2 docs/api/api_python/models/mindformers.models.LlamaForCausalLM.rst
10. +1 -1 docs/api/api_python/tools/mindformers.tools.MindFormerConfig.rst
11. +9 -9 docs/model_cards/glm4.md
12. +4 -4 docs/security_statement.md
13. +3 -4 docs/transformer仓Python编程规范.md
14. +0 -1 mindformers/__init__.py
15. +189 -0 mindformers/checkpoint/broadcast.py
16. +106 -71 mindformers/checkpoint/checkpoint.py
17. +36 -20 mindformers/checkpoint/fully_parallel.py
18. +15 -8 mindformers/checkpoint/sharded_tensor.py
19. +66 -79 mindformers/core/callback/callback.py
20. +18 -9 mindformers/core/context/build_context.py
21. +12 -6 mindformers/core/context/validators.py
22. +40 -37 mindformers/core/optim/__init__.py
23. +2 -1 mindformers/core/optim/build_optim.py
24. +2 -2 mindformers/core/optim/fused_pma_adamw.py
25. +65 -50 mindformers/core/optim/muon.py
26. +6 -0 mindformers/dataset/causal_language_model_dataset.py
27. +1 -1 mindformers/dataset/dataloader/blended_megatron_dataloader.py
28. +4 -0 mindformers/dataset/dataloader/hf_dataloader.py
29. +11 -11 mindformers/generation/utils.py
30. +2 -2 mindformers/model_runner.py
31. +4 -2 mindformers/models/glm4_moe/modeling_glm4_moe_infer.py
32. +15 -13 mindformers/models/llama/llama.py
33. +11 -0 mindformers/models/qwen3/modeling_qwen3_infer.py
34. +1 -0 mindformers/models/telechat2/configuration_telechat2.py
35. +0 -1 mindformers/modules/__init__.py
36. +21 -37 mindformers/modules/layers.py
37. +3 -3 mindformers/modules/quantizers/base.py
38. +1 -11 mindformers/modules/quantizers/ptq_quantizer.py
39. +0 -11 mindformers/modules/quantizers/rtn_quantizer.py
40. +0 -1 mindformers/modules/transformer/__init__.py
41. +20 -279 mindformers/modules/transformer/transformer.py
42. +4 -2 mindformers/parallel_core/inference/model_utils.py
43. +3 -19 mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py
44. +28 -25 mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py
45. +10 -10 mindformers/parallel_core/inference/utils.py
46. +14 -14 mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py
47. +139 -37 mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
48. +23 -26 mindformers/parallel_core/training_graph/loss_func.py
49. +129 -0 mindformers/parallel_core/training_graph/tensor_parallel/layers.py
50. +20 -11 mindformers/parallel_core/training_graph/transformer/flash_attention.py
51. +25 -0 mindformers/parallel_core/training_graph/transformer/moe/ffn.py
52. +3 -24 mindformers/parallel_core/training_graph/transformer/moe/moe_layer.py
53. +15 -0 mindformers/parallel_core/training_graph/transformer/moe/router.py
54. +15 -0 mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py
55. +13 -73 mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py
56. +93 -0 mindformers/parallel_core/training_graph/transformer/moe/utils.py
57. +10 -13 mindformers/parallel_core/training_graph/transformer/multi_latent_attention.py
58. +6 -4 mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py
59. +89 -14 mindformers/parallel_core/training_graph/transformer/norm.py
60. +26 -5 mindformers/parallel_core/training_graph/transformer/utils.py
61. +17 -0 mindformers/parallel_core/transformer_config.py
62. +7 -2 mindformers/parallel_core/utils/model_mixin.py
63. +65 -38 mindformers/tools/ckpt_transform/transform_checkpoint.py
64. +39 -18 mindformers/tools/resume_ckpt.py
65. +43 -5 mindformers/tools/utils.py
66. +16 -15 mindformers/trainer/base_trainer.py
67. +424 -67 mindformers/trainer/llm_trainer_for_graph_experimental/llm_trainer.py
68. +0 -12 mindformers/trainer/optimizer_grouped_parameters.py
69. +25 -5 mindformers/trainer/trainer.py
70. +37 -15 mindformers/trainer/utils.py
71. +61 -27 mindformers/utils/load_checkpoint_utils.py
72. +3 -1 mindformers/utils/resume_ckpt_utils.py
73. +1 -1 mindformers/wrapper/wrapper.py
74. +2 -2 requirements.txt
75. +16 -1 research/deepseek3/README.md
76. +1 -2 research/llama3_1/README.md
77. +2 -2 research/llama3_1/llama.py
78. +2 -2 research/qwen2_5/README.md
79. +18 -23 research/telechat2/README.md
80. +79 -9 tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py
81. +67 -0 tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_alltoall_deredundency_train.py
82. +67 -0 tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_alltoall_zero_redundancy_train.py
83. +6 -4 tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_train.py
84. +15 -0 tests/st/test_multi_cards_cases/test_model/test_glm4_moe/__init__.py
85. +15 -0 tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/__init__.py
86. +40 -0 tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/glm4_moe_infer.yaml
87. +90 -0 tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/run_glm4_moe.py
88. +58 -0 tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/test_glm4_moe_infer.py
89. +67 -0 tests/st/test_multi_cards_cases/test_optimizer/test_pma/config.json
90. +134 -0 tests/st/test_multi_cards_cases/test_optimizer/test_pma/deepseekv3_train.yaml
91. +241 -0 tests/st/test_multi_cards_cases/test_optimizer/test_pma/run_deepseek3.py
92. +72 -0 tests/st/test_multi_cards_cases/test_optimizer/test_pma/test_pma.py
93. +100 -2 tests/st/test_optim/optimizer_util.py
94. +816 -4 tests/st/test_optim/test_adamw.py
95. +15 -0 tests/st/test_optim/test_muon/__init__.py
96. +63 -0 tests/st/test_optim/test_muon/data_utils.py
97. +236 -0 tests/st/test_optim/test_muon/run_muon.py
98. +202 -0 tests/st/test_optim/test_muon/test_muon.py
99. +29 -0 tests/st/test_run_check.py
100. +1126 -50 tests/st/test_safetensors/test_checkpoint_utils.py

+1 -1 .jenkins/test/config/dependent_packages.yaml

@@ -1,4 +1,4 @@
mindspore:
'https://repo.mindspore.cn/mindspore/mindspore/version/202511/20251104/r2.7.2_20251104020009_c23dbc0e0cc1f8e8a82cc6c5acea3d8c22846e71_newest/'
'https://repo.mindspore.cn/mindspore/mindspore/version/202512/20251208/r2.7.2_20251208020007_3c91386a5d0d898f8ebd32011c649f78a9b2918b_newest/'
ms_custom_ops:
'https://repo.mindspore.cn/mindspore/ms_custom_ops/version/202510/20251029/master_20251029031507_da84a03f2297b5499e9a47a8e294078229230087_newest/'

+4414 -2939 Third_Party_Open_Source_Software_Notice
File diff suppressed because it is too large


+16 -16 configs/glm4/README.md

@@ -1,4 +1,4 @@
# Glm4
# GLM-4

## 模型描述

@@ -8,8 +8,8 @@ GLM-4 系列模型是专为智能代理设计的基础模型, 其性能可与Ope

| 模型名称 | 规格 | 支持任务 | 模型架构 | 支持设备 | 模型级别 |
|:----------:|:---------:|:------:|:-----:|:-------------------------------------------------:|:-------------------:|
|GLM4-32B | 32B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM4-9B | 9B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM-4-32B | 32B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |
|GLM-4-9B | 9B | 推理 | Mcore | Atlas 800T A2/Atlas 800I A2/Atlas 900 A3 SuperPoD | [Validated](#模型级别介绍) |

说明:

@@ -18,15 +18,15 @@ GLM-4 系列模型是专为智能代理设计的基础模型, 其性能可与Ope

## 版本配套

GLM4 当前支持的版本配套如下。
GLM-4 当前支持的版本配套如下。

| | Mindspore Transformers | MindSpore | CANN | HDK |
| | MindSpore Transformers | MindSpore | CANN | HDK |
|:---------:|:----------------------:|:---------:|:----:|:---:|
| 当前支持的版本 | 在研版本 | 在研版本 | 在研版本 | 在研版本 |

## 使用样例

MindSpore Transformers 支持使用 GLM4 进行推理。各任务的整体使用流程如下:
MindSpore Transformers 支持使用 GLM-4 进行推理。各任务的整体使用流程如下:

| 任务 | 前期准备 | 使用流程 |
|:---:|:------------------------|:---------------------------|
@@ -69,7 +69,7 @@ parallel_config:
- pretrained_model_dir:Hugging Face模型目录路径,放置模型配置、Tokenizer等文件。`/path/hf_dir`中的内容如下:

```text
📂GLM4
📂GLM-4
├── 📄config.json
├── 📄generation_config.json
├── 📄merges.txt
@@ -130,7 +130,7 @@ python run_mindformer.py \

多卡推理:

Glm4的32B规模模型,只能进行多卡推理,多卡推理的配置需参考下面修改配置:
GLM-4的32B规模模型,只能进行多卡推理,多卡推理的配置需参考下面修改配置:

1. 模型并行model_parallel的配置和使用的卡数需保持一致,下文用例为8卡推理,需将model_parallel设置成8;
2. 当前版本的多卡推理不支持数据并行,需将data_parallel设置为1。
@@ -188,15 +188,15 @@ bash scripts/msrun_launcher.sh "run_mindformer.py \

### 模型文件说明

Glm4的模型文件包括以下内容:
GLM-4的模型文件包括以下内容:

```text
📦glm4
├── 📄__init__.py # glm4模块初始化文件
├── 📄configuration_glm4.py # glm4模型配置类定义
├── 📄modeling_glm4.py # glm4模型主体实现
├── 📄modeling_glm4_infer.py # glm4推理模型实现
└── 📄utils.py # glm4工具函数和基础类
├── 📄__init__.py # GLM-4模块初始化文件
├── 📄configuration_glm4.py # GLM-4模型配置类定义
├── 📄modeling_glm4.py # GLM-4模型主体实现
├── 📄modeling_glm4_infer.py # GLM-4推理模型实现
└── 📄utils.py # GLM-4工具函数和基础类
```

### 并行配置建议
@@ -218,7 +218,7 @@ Glm4的模型文件包括以下内容:
<th>模型级别</th>
</tr>
<tr>
<td>GLM4-32B</td>
<td>GLM-4-32B</td>
<td>32B</td>
<td>1 × Atlas 800T A2 (2P)</td>
<td>2</td>
@@ -235,7 +235,7 @@ Glm4的模型文件包括以下内容:
<td> Validated </td>
</tr>
<tr>
<td>GLM4-9B</td>
<td>GLM-4-9B</td>
<td>9B</td>
<td>1 × Atlas 800T A2 (1P)</td>
<td>1</td>


+10 -10 configs/glm4_moe/README.md

@@ -2,7 +2,7 @@

## 模型描述

GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM4采用了MoE结构的变体,也标记为GLM4-MoE。GLM-4.5 总参数 3550 亿,激活参数 320 亿,而 GLM-4.5-Air 采用更紧凑的设计,总参数 1060 亿,激活参数 120 亿。GLM-4.5模型统一了推理、编码和智能体能力,满足智能体应用的复杂需求。
GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM-4采用了MoE结构的变体,也标记为GLM-4-MoE。GLM-4.5 总参数 3550 亿,激活参数 320 亿,而 GLM-4.5-Air 采用更紧凑的设计,总参数 1060 亿,激活参数 120 亿。GLM-4.5模型统一了推理、编码和智能体能力,满足智能体应用的复杂需求。
具体模型能力查看以下技术报告:[GLM-4.5: Reasoning, Coding, and Agentic Abililties](https://z.ai/blog/glm-4.5)

## 支持规格
@@ -21,7 +21,7 @@ GLM-4.5 系列模型是专为智能代理设计的基础模型,基于GLM4采

GLM-4.5 当前支持的版本配套如下。

| | Mindspore Transformers | MindSpore | CANN | HDK |
| | MindSpore Transformers | MindSpore | CANN | HDK |
|:---------:|:----------------------:|:---------:|:----:|:---:|
| 当前支持的版本 | 在研版本 | 在研版本 | 在研版本 | 在研版本 |

@@ -70,7 +70,7 @@ parallel_config:
- pretrained_model_dir:Hugging Face模型目录路径,放置模型配置、Tokenizer等文件。`/path/hf_dir`中的内容如下:

```text
📂Glm4.5
📂GLM-4.5
├── 📄config.json
├── 📄generation_config.json
├── 📄merges.txt
@@ -110,7 +110,7 @@ msrun_launcher.sh包括run_mindformer.py命令和推理卡数两个参数。

多卡推理:

Glm4.5分别有355B和106B两种规模,只能进行多卡推理,多卡推理的配置需参考下面修改配置:
GLM-4.5分别有355B和106B两种规模,只能进行多卡推理,多卡推理的配置需参考下面修改配置:

1. 模型并行model_parallel的配置和使用的卡数需保持一致,下文用例为8卡推理,需将model_parallel设置成8;
2. 当前版本的多卡推理不支持数据并行,需将data_parallel设置为1。
@@ -190,15 +190,15 @@ bash scripts/msrun_launcher.sh "run_mindformer.py \

### 模型文件说明

glm4_moe的模型文件包括以下内容:
GLM-4-MoE的模型文件包括以下内容:

```text
📦glm4_moe
├── 📄__init__.py # glm4_moe模块初始化文件
├── 📄configuration_glm4_moe.py # glm4_moe模型配置类定义
├── 📄modeling_glm4_moe.py # glm4_moe模型主体实现
├── 📄modeling_glm4_moe_infer.py # glm4_moe推理模型实现
└── 📄utils.py # glm4_moe工具函数和基础类
├── 📄__init__.py # GLM-4-MoE模块初始化文件
├── 📄configuration_glm4_moe.py # GLM-4-MoE模型配置类定义
├── 📄modeling_glm4_moe.py # GLM-4-MoE模型主体实现
├── 📄modeling_glm4_moe_infer.py # GLM-4-MoE推理模型实现
└── 📄utils.py # GLM-4-MoE工具函数和基础类
```

### 并行配置建议


+13 -13 configs/qwen3_moe/README.md

@@ -136,7 +136,7 @@ train_dataset: &train_dataset

#### 3. 启动预训练任务

通过指定模型路径和配置文件[configs/qwen3/pretrain_qwen3_30b_a3b_4k.yaml](https://gitee.com/mindspore/mindformers/blob/master/configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml)以msrun的方式启动[run_mindformer.py](https://gitee.com/mindspore/mindformers/blob/master/run_mindformer.py)脚本,进行16卡分布式训练。可参考如下方式拉起两台Atlas 800T A2(64G)训练。
通过指定模型路径和配置文件[configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml](https://gitee.com/mindspore/mindformers/blob/master/configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml)以`msrun`的方式启动[run_mindformer.py](https://gitee.com/mindspore/mindformers/blob/master/run_mindformer.py)脚本,进行16卡分布式训练。可参考如下方式拉起两台Atlas 800T A2(64G)训练。

在每台服务器上执行如下命令。设置`master_ip`为主节点IP地址,即`Rank 0`服务器的IP;`node_rank`为每个节点的序号;`port`为当前进程的端口号(可在50000~65536中选择)。

@@ -145,7 +145,7 @@ master_ip=192.168.1.1
node_rank=0
port=50001
bash scripts/msrun_launcher.sh "run_mindformer.py \
--config configs/qwen3_moe/pretrain_qwen3_moe_30b_a3b_4k.yaml \
--config configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml \
--auto_trans_ckpt False \
--use_parallel True \
--run_mode train" \
@@ -211,16 +211,16 @@ parallel_config:

run_mindformer.py的参数说明如下:

| 参数 | 参数说明 |
|:-------------------------------|:----------------------------------------------------------|
| config | yaml配置文件的路径 |
| run_mode | 运行的模式,推理设置为predict |
| use_parallel | 是否使用多卡推理 |
| predict_data | 推理的输入数据,多batch推理时需要传入输入数据的txt文件路径,包含多行输入 |
| predict_batch_size | 多batch推理的batch_size大小 |
| pretrained_model_dir | Hugging Face模型目录路径,放置模型配置、Tokenizer等文件 |
| parallel_config.data_parallel | 数据并行,当前推理模式下设置为1 |
| parallel_config.model_parallel | 模型并行,默认值为 1。需根据实际模型规模及硬件资源情况,调整该参数为相应的device_nu(即实际使用的卡数) |
| 参数 | 参数说明 |
|:-------------------------------|:-----------------------------------------------------------|
| config | yaml配置文件的路径 |
| run_mode | 运行的模式,推理设置为predict |
| use_parallel | 是否使用多卡推理 |
| predict_data | 推理的输入数据,多batch推理时需要传入输入数据的txt文件路径,包含多行输入 |
| predict_batch_size | 多batch推理的batch_size大小 |
| pretrained_model_dir | Hugging Face模型目录路径,放置模型配置、Tokenizer等文件 |
| parallel_config.data_parallel | 数据并行,当前推理模式下设置为1 |
| parallel_config.model_parallel | 模型并行,默认值为 1。需根据实际模型规模及硬件资源情况,调整该参数为相应的device_npu(即实际使用的卡数) |

msrun_launcher.sh包括run_mindformer.py命令和推理卡数两个参数。

@@ -310,7 +310,7 @@ node_rank=0
port=50001

bash scripts/msrun_launcher.sh "run_mindformer.py \
--config configs/qwen3_moe/predict_qwen3_moe.yaml" \
--config configs/qwen3_moe/predict_qwen3_moe.yaml \
--run_mode predict \
--use_parallel True \
--pretrained_model_dir '/path/hf_dir' \


+10 -2 convert_weight.py

@@ -52,7 +52,7 @@ reversed_convert_map = {
}


if __name__ == '__main__':
def main(args_list=None):
parser = argparse.ArgumentParser()
parser.add_argument('--model', default=None, type=str, required=True, help='model name')
parser.add_argument('--reversed', action='store_true', help="convert ms to hf")
@@ -65,7 +65,11 @@ if __name__ == '__main__':
help="Only for telechat. Telechat version.")
parser.add_argument('--is_lora', default=False, type=str2bool, required=False)

args, extra_args = parser.parse_known_args()
if args_list is not None:
args, extra_args = parser.parse_known_args(args_list)
else:
args, extra_args = parser.parse_known_args()

extra_args = [i
for item in extra_args
for i in item.split("=")]
@@ -103,3 +107,7 @@ if __name__ == '__main__':

merged_args = argparse.Namespace(**{**vars(args), **extra_kwargs})
convert_func(merged_args)


if __name__ == '__main__':
main()
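
The change above turns the former `__main__` block into a callable `main(args_list=None)`, so the converter can be driven programmatically (for example from the new weight-conversion test cases) instead of only through `sys.argv`. A minimal usage sketch under that assumption; the flag names mirror the `convert_weight.py` command line shown in the GLM-4 model card further below, and the paths are placeholders:

```python
# Hedged sketch: assumes convert_weight.py is importable from the repository root
# and that the chosen flags are valid for the target model.
import convert_weight

args = [
    "--model", "glm4",
    "--input_path", "/path/to/hf_ckpt",
    "--output_path", "/path/to/ms_ckpt",
    "--dtype", "bf16",
]
# Passing an explicit argument list bypasses sys.argv, which is what makes
# main() usable from unit tests and other scripts.
convert_weight.main(args)
```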

+3 -1 docs/api/api_python/core/mindformers.core.CheckpointMonitor.rst

@@ -7,7 +7,7 @@ mindformers.core.CheckpointMonitor

参数:
- **prefix** (str, 可选) - checkpoint文件的前缀名。默认值: ``CKP`` 。
- **directory** (str, 可选) - checkpoint文件将要保存的文件夹路径。默认值: ``None`` 。
- **directory** (str, 可选) - checkpoint文件将要保存的文件夹路径。默认值: ``None`` ,此时checkpoint文件将保存在当前工作目录下的 `./output/checkpoint` 文件夹中
- **config** (CheckpointConfig, 可选) - checkpoint的配置。默认值: ``None`` 。
- **save_checkpoint_steps** (int, 可选) - 每隔多少个step保存一次checkpoint。默认值: ``1`` 。
- **save_checkpoint_seconds** (int, 可选) - 每隔多少秒保存一次checkpoint。不能同时与 `save_checkpoint_steps` 一起使用。默认值: ``0`` 。
@@ -29,6 +29,8 @@ mindformers.core.CheckpointMonitor
- **use_checkpoint_health_monitor** (bool, 可选) - 是否开启通过embedding norm来进行权重健康监测的功能。默认值: ``False`` 。
- **embedding_local_norm_threshold** (float, 可选) - embedding norm的阈值。默认值: ``1.0`` 。
- **health_ckpts_record_dir** (str, 可选) - 记录权重健康状态文件的保存地址。默认值: ``./output`` 。
- **use_legacy_format** (bool, 可选) - 是否使用旧版权重保存格式,默认值:``True``。
- **save_optimizer** (bool, 可选) - 是否保存优化器权重,仅用于新版权重保存流程。旧版流程中不启用该配置,保持为 ``None`` 。默认值:``True``。


异常:

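Two new keyword arguments (`use_legacy_format` and `save_optimizer`) are documented above. For orientation, a minimal sketch of how they might be passed when building the callback; the import path follows this .rst, and the surrounding training setup is assumed rather than shown:

```python
# Hedged sketch: assumes a standard MindFormers training setup around it.
from mindformers.core import CheckpointMonitor

ckpt_cb = CheckpointMonitor(
    prefix="CKP",
    directory="./output/checkpoint",
    save_checkpoint_steps=1000,
    use_legacy_format=False,  # switch to the new save flow
    save_optimizer=True,      # only takes effect in the new flow
)
```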

+9 -9 docs/api/api_python/core/mindformers.core.TrainingStateMonitor.rst

@@ -1,7 +1,7 @@
mindformers.core.TrainingStateMonitor
=====================================

.. py:class:: mindformers.core.TrainingStateMonitor(origin_epochs, config=None, step_interval=1, dataset_size=None, initial_epoch=0, initial_step=0, global_batch_size=0, check_for_nan_in_loss_and_grad=False, use_skip_data_by_global_norm=False, embedding_size=4096, use_local_norm=False)
.. py:class:: mindformers.core.TrainingStateMonitor(origin_epochs, config=None, step_interval=1, dataset_size=None, initial_epoch=0, initial_step=0, micro_batch_num=0, global_batch_size=0, tensor_model_parallel_size=0, check_for_nan_in_loss_and_grad=False, use_skip_data_by_global_norm=False, embedding_size=4096, use_local_norm=False)

监控训练过程中指标变化的回调函数。

@@ -11,20 +11,20 @@ mindformers.core.TrainingStateMonitor

- ``"target"`` - 指定要监控的参数的命名或正则表达式。必须是字符串列表,例如["layers.[01]", "attention"]。默认值: ``[".*"]`` ,即选择所有参数。
- ``"invert"`` - 反选target指定的参数,即target指定的参数不会被监控。默认值: ``False`` 。
- ``"local_norm_format"`` - 决定local norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。只有指定的参数会被监控,选择 'log' 时可能引入大量打印信息。设置为 ``None`` 以忽略该指标。默认值:``None`` 。
- ``"device_local_norm_format"`` - 决定device local norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值:``None`` 。
- ``"local_loss_format"`` - 决定local loss的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值:``None`` 。
- ``"device_local_loss_format"`` - 决定device local loss的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值:``None`` 。
- ``"optimizer_state_format"`` - 决定优化器状态的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。只有指定参数的优化器状态会被监控,选择 'log' 时可能引入大量打印信息。设置为 ``None`` 以忽略该指标。默认值:'tensorboard' 。
- ``"weight_state_format"`` - 决定权重L2-norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值:'tensorboard' 。
- ``"throughput_baseline"`` - 模型吞吐量的基线,用于计算线性度。必须为正数。会同时写入日志文件和tensorboard。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"print_struct"`` - 是否打印模型结构。若是,则会在第一个step打印所有可训练参数的名字,并退出训练。默认值: ``False`` 。
- ``"local_norm_format"`` - 决定local norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。只有指定的参数会被监控,选择 'log' 时可能引入大量打印信息。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"device_local_norm_format"`` - 决定device local norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"local_loss_format"`` - 决定local loss的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"device_local_loss_format"`` - 决定device local loss的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"optimizer_state_format"`` - 决定优化器状态的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。只有指定参数的优化器状态会被监控,选择 'log' 时可能引入大量打印信息。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。
- ``"weight_state_format"`` - 决定权重L2-norm的展示方式。必须是字符串'tensorboard'、'log'之一(分别代表写入tensorboard、日志),或包含它们的列表,或 ``None`` 。设置为 ``None`` 以忽略该指标。默认值: ``None`` 。

- **step_interval** (int, 可选) - 每多少次step对指标进行展示。默认值: ``1`` 。
- **dataset_size** (int, 可选) - 数据下沉模式必选。训练的数据集数量。默认值: ``None`` 。
- **initial_epoch** (int, 可选) - 训练开始的epoch数。默认值: ``0`` 。
- **initial_step** (int, 可选) - 训练开始的step数。默认值: ``0`` 。
- **micro_batch_num** (int, 可选) - 流水线并行时设置的MicroBatch大小。默认值: ``0`` 。
- **global_batch_size** (int, 可选) - 总BatchSize大小。默认值: ``0`` 。
- **tensor_model_parallel_size** (int, 可选) - 张量并行的切分数量。默认值: ``0`` 。
- **check_for_nan_in_loss_and_grad** (bool, 可选) - 是否检查损失和梯度存在Nan。默认值: ``False`` 。
- **use_skip_data_by_global_norm** (bool, 可选) - 是否开启通过global norm来进行数据跳过的功能。默认值: ``False`` 。
- **embedding_size** (int, 可选) - 通过hidden_size * vocab_size来计算embedding norm的size。默认值: ``4096`` 。

+2 -2 docs/api/api_python/models/mindformers.models.LlamaForCausalLM.rst

@@ -3,10 +3,10 @@ mindformers.models.LlamaForCausalLM

.. py:class:: mindformers.models.LlamaForCausalLM(config=None)

在线计算并提供执行LLama训练时的损失值和逻辑值。
在线计算并提供执行Llama训练时的损失值和逻辑值。

参数:
- **config** (LlamaConfig, 可选) - LLama模型的配置。默认值: ``None`` 。
- **config** (LlamaConfig, 可选) - Llama模型的配置。默认值: ``None`` 。

输入:
- **input_ids** (Tensor) - 数据类型为Int64/Int32的词汇表中输入序列标记的索引,张量的形状为::math:`(batch, seq\_length)`。


+1 -1 docs/api/api_python/tools/mindformers.tools.MindFormerConfig.rst

@@ -3,7 +3,7 @@ mindformers.tools.MindFormerConfig

.. py:class:: mindformers.tools.MindFormerConfig(*args, **kwargs)

一个配置类,继承于Python的dict类。可以解析来自yaml文件或dict实例的配置参数。
一个配置类,继承于Python的dict类。可以解析来自yaml文件或dict实例的配置参数。

参数:
- **args** (Any) - 可扩展参数列表,可以是yaml配置文件路径或配置字典。


+9 -9 docs/model_cards/glm4.md

@@ -74,7 +74,7 @@ MindSpore Transformers 提供 `alpaca` 数据集示例处理脚本制作[全参

| 数据集名称 | 适用模型 | 适用阶段 | 下载链接 |
|:-------------|:-------:|:--------:|:------------------------------------------------------------------------------------------:|
| alpaca | glm4-9b | Finetune | [Link](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) |
| alpaca | GLM-4-9B | Finetune | [Link](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) |

数据预处理中所用的 `tokenizer.model` 可以参考[模型权重下载](#模型权重下载)进行下载。

@@ -138,7 +138,7 @@ MindSpore TransFormers 提供已经转换完成的预训练权重、词表文件
mlp_concat: False
```

2. 执行 mindforers 根目录下的 `convert_weight.py` [转换脚本](https://gitee.com/mindspore/mindformers/blob/master/convert_weight.py),将 HuggingFace 的权重转换为完整的 MindSpore ckpt 权重。
2. 执行 mindformers 根目录下的 `convert_weight.py` [转换脚本](https://gitee.com/mindspore/mindformers/blob/master/convert_weight.py),将 HuggingFace 的权重转换为完整的 MindSpore ckpt 权重。

```shell
python convert_weight.py --model glm4 --input_path HF_CKPT_PATH --output_path MS_NOT_CONCAT_CKPT_PATH --dtype DTYPE --config YAML_PATH
@@ -201,13 +201,13 @@ bash scripts/examples/glm4/run_glm4_predict.sh PARALLEL CONFIG_PATH CKPT_PATH TO

参数说明如下表:

| 参数名 | 含义 | 取值说明 |
|-------------|--------------------------|---------------------------------------------------------------------------------------------|
| PARALLEL | 指定选择推理模式为单卡推理 or 多卡推理。 | (str, 必选) - 单卡推理配置为 `single` ,多卡推理配置为 `parallel` 。 |
| CONFIG_PATH | 模型配置文件路径。 | (str, 必选) - 如 `/path/to/glm4/predict_glm4_9b_chat.yaml` 。 |
| CKPT_PATH | 推理时用到的模型权重文件路径。 | (str, 必选) - 单卡为完整权重,双卡为分布式权重。<br>如单卡推理 `/path/to/glm4.ckpt`,多卡推理 `/path/to/glm4_ckpt_dir` 。 |
| TOKENIZER | gml4 模型的 tokenizer 文件路径。 | (str, 必选) - 如 `/path/to/tokenizer.model` 。 |
| DEVICE_NUM | 指定多卡推理的卡数。 | (int, 可选) - 多卡推理时必须指定推理卡数。<br>如双卡推理时,则配置为 `2` 。 |
| 参数名 | 含义 | 取值说明 |
|-------------|---------------------------|---------------------------------------------------------------------------------------------|
| PARALLEL | 指定选择推理模式为单卡推理或多卡推理。 | (str, 必选) - 单卡推理配置为 `single` ,多卡推理配置为 `parallel` 。 |
| CONFIG_PATH | 模型配置文件路径。 | (str, 必选) - 如 `/path/to/glm4/predict_glm4_9b_chat.yaml` 。 |
| CKPT_PATH | 推理时用到的模型权重文件路径。 | (str, 必选) - 单卡为完整权重,双卡为分布式权重。<br>如单卡推理 `/path/to/glm4.ckpt`,多卡推理 `/path/to/glm4_ckpt_dir` 。 |
| TOKENIZER | GLM-4 模型的 tokenizer 文件路径。 | (str, 必选) - 如 `/path/to/tokenizer.model` 。 |
| DEVICE_NUM | 指定多卡推理的卡数。 | (int, 可选) - 多卡推理时必须指定推理卡数。<br>如双卡推理时,则配置为 `2` 。 |

运行如下命令进行推理:



+4 -4 docs/security_statement.md

@@ -6,11 +6,11 @@

### 运行用户建议

出于安全性及权限最小化角度考虑,不建议使用root等管理员类型账户使用
出于安全性及权限最小化角度考虑,不建议使用root等管理员类型账户。

### 安全隐私声明

请用户在使用个人数据时遵从当地适用的法律法规。
请用户在使用个人数据时遵从当地适用的法律法规。

### 文件权限控制

@@ -20,7 +20,7 @@
表1 文件(夹)各场景权限管控推荐最大值

| 类型 | linux权限参考最大值 |
|-------------------|----------------|
|----------------|----------------|
| 用户主目录 | 750(rwxr-x---) |
| 程序文件(含脚本文件、库文件等) | 550(r-xr-x---) |
| 程序文件目录 | 550(r-xr-x---) |
@@ -35,7 +35,7 @@
| 维护升级文件目录 | 770(rwxrwx---) |
| 业务数据文件 | 640(rw-r-----) |
| 业务数据文件目录 | 750(rwxr-x---) |
| 密钥组件、私钥、证书、密文文件目录 | 700(rwx—----) |
| 密钥组件、私钥、证书、密文文件目录 | 700(rwx------) |
| 密钥组件、私钥、证书、加密密文 | 600(rw-------) |
| 加解密接口、加解密脚本 | 500(r-x------) |



+3 -4 docs/transformer仓Python编程规范.md

@@ -2,8 +2,7 @@

本规范以[PEP8](https://www.python.org/dev/peps/pep-0008/)为基础,参考华为Python通用编码规范、安全编程规范,并结合业界共识整理而成,参与MindSpore社区开发需要首先遵循本规范内容(与PEP8冲突部分),其余遵循PEP8规范。

如果对规则异议,建议提交issue并说明理由,经MindSpore社区运营团队评审接纳后可修改生效。
a
如果对规则有异议,建议提交 issue 并说明理由,经MindSpore社区运营团队评审接纳后可修改生效。

## 适用范围

@@ -105,7 +104,7 @@ import xxx

### 2.2 数据校验

<font size=3>**规则 2.2.1 对所有外部数据进行合法性检查,包括但不限于:函数入参、外部输入命行、文件格式,文件大小、环境变量、用户数据等。**</font>
<font size=3>**规则 2.2.1 对所有外部数据进行合法性检查,包括但不限于:函数入参、外部输入命行、文件格式,文件大小、环境变量、用户数据等。**</font>

<font size=3>**建议 2.2.2 必须对文件路径进行规范化后再使用。**</font>

@@ -212,7 +211,7 @@ import xxx

【例外情况】:

1. 在资源释放失败不会影响程序后续行为的情况下,释放资源时发生的异常可以被抑制。释放资源的例子包括关闭文件、网络套接字、线程等等。这些资源通常是在except或者fianlly块中被释放,并且在后续的程序运行中都不会再被使用。因此,除非资源被耗尽,否则不会有其他途径使得这些异常会影响程序后续的行为。在充分处理了资源耗尽问题的情况下,只需对异常进行净化和记录日志(以备日后改进)就足够了;在这种情况下没必要做其他额外的错误处理。
1. 在资源释放失败不会影响程序后续行为的情况下,释放资源时发生的异常可以被抑制。释放资源的例子包括关闭文件、网络套接字、线程等等。这些资源通常是在except或者finally块中被释放,并且在后续的程序运行中都不会再被使用。因此,除非资源被耗尽,否则不会有其他途径使得这些异常会影响程序后续的行为。在充分处理了资源耗尽问题的情况下,只需对异常进行净化和记录日志(以备日后改进)就足够了;在这种情况下没必要做其他额外的错误处理。
2. 如果在特定的抽象层次上不可能从异常情况中恢复过来,则在那个层级的代码就不用处理这个异常,而是应该抛出一个合适的异常,让更高层次的代码去捕获处理,并尝试恢复。对于这种情况,最通常的实现方法是省略掉catch语句块,允许异常被广播出去。

<font size=3>**规则 2.3.2 使用try…except…结构对代码作保护时,需要在异常后使用finally…结构保证操作对象的释放。**</font>


+0 -1 mindformers/__init__.py

@@ -157,7 +157,6 @@ from mindformers.modules import (
AlibiTensorV2,
Dropout,
EmbeddingOpParallelConfig,
FeedForward,
FixedSparseAttention,
LayerNorm,
Linear,


+189 -0 mindformers/checkpoint/broadcast.py

@@ -0,0 +1,189 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""broadcast params across redundant rank groups."""
import numpy as np

from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.communication import create_group, destroy_group
from mindspore.communication._comm_helper import _get_group_map, _remove_group_info
from mindspore.runtime import synchronize
from mindspore.mint.distributed import all_gather_object

from mindformers.tools.utils import get_real_rank, get_real_group_size
from mindformers.tools.logger import logger


class SingleCommunicator(Cell):
"""
Used to broadcast single parameter.
"""

def __init__(self, group_name):
super().__init__()
self.allreduce = P.AllReduce(group=group_name)
self.add_flags(skip_auto_parallel_compile=True)

def construct(self, loaded_param):
result = self.allreduce(loaded_param)
return result


def _change_parallel_context(origin_dataset_strategy):
"""Change the original parallel state."""
context.set_auto_parallel_context(parallel_mode="hybrid_parallel")
if origin_dataset_strategy != "data_parallel":
context.set_auto_parallel_context(dataset_strategy="data_parallel")


def _get_sorted_group_map():
"""Get the world group map."""
group_map = _get_group_map()
if group_map:
group_map = {key: group_map[key] for key in sorted(group_map.keys())}
return group_map


def _get_param_index_in_group(total_param_loaded, group, param):
"""Get param_index in group."""
param_rank_index = []
for rank_id in group:
if rank_id < len(total_param_loaded):
if param in total_param_loaded[rank_id]:
param_rank_index.append(rank_id)
else:
raise ValueError("rank_id should be smaller than total rank num")
return param_rank_index


def _remove_param_not_load(param_name, param_not_load):
"""Remove param_name from param_not_load."""
if param_not_load is not None and param_name in param_not_load:
param_not_load.remove(param_name)


def _create_allreduce_input(params, group, net_param_dict, total_param_loaded, param_not_load, cur_rank):
"""Creates allreduce input."""
allreduce_input = []
for param in params:
if param not in net_param_dict:
continue
if param.startswith("accu_grads") or param.endswith("expert_load"):
continue
param_rank_index = _get_param_index_in_group(total_param_loaded, group, param)
if not param_rank_index:
continue
if len(param_rank_index) == 1:
real_param = net_param_dict[param]
_remove_param_not_load(real_param.name, param_not_load)
if cur_rank != param_rank_index[0]:
real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
allreduce_input.append(real_param)
elif len(param_rank_index) > 1:
raise ValueError(f"For param {param} in group {group} should be in one rank, but in {param_rank_index}.")
return allreduce_input


def _get_group_name(group_map, group):
"""get group name"""
group_name = "remove_redundancy" + str(group)
is_manual_communication_group = True
if group_map:
for name, rank_list in group_map.items():
if list(group) == rank_list:
group_name = name
is_manual_communication_group = False
break
return group_name, is_manual_communication_group


def _communicate_allreduce(allreduce_input, group_map, group):
"""Communicate allreduce input."""
if not allreduce_input:
return
group_name, is_manual_communication_group = _get_group_name(group_map, group)
if is_manual_communication_group:
create_group(group_name, list(group))
communicator = SingleCommunicator(group_name)
for real_param in allreduce_input:
real_param.set_data(communicator(Tensor(real_param)), real_param.sliced)
if is_manual_communication_group:
destroy_group(group_name)
_remove_group_info(group_name)


def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
"""Restore the original parallel state."""
context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
if origin_dataset_strategy != "data_parallel":
if origin_dataset_strategy is not None and isinstance(origin_dataset_strategy, list):
origin_dataset_strategy = tuple(tuple(ds_item) for ds_item in origin_dataset_strategy)
context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)


def single_parameter_broadcast(net, param_redundancy, param_not_load, param_loaded):
"""
Broadcasts unique parameter values across redundant rank groups to eliminate duplicate parameter storage.

This function coordinates parameter sharing among ranks that hold redundant copies of the same parameters
(identified by `param_redundancy`). It temporarily adjusts parallel execution contexts, synchronizes the
loading status of parameters across all ranks, and performs allreduce communication within redundant groups
to ensure all ranks in a group access the same parameter values—thus removing redundant parameter storage.

Core workflow:
1. Capture the current rank ID and original parallel context configurations (to restore later).
2. Retrieve the network's parameter dictionary and adjust parallel context settings for communication.
3. Track the current rank's parameter loading status and synchronize this status across all ranks via allgather.
4. For each group of ranks with redundant parameters:
a. Filter parameters to only those present in the network's parameter dictionary.
b. Create input data for allreduce communication, including valid parameter values from loaded ranks.
c. Execute allreduce to broadcast consistent parameter values to all ranks in the group.
5. Restore the original parallel context settings and synchronize all ranks to complete the process.

Args:
net: MindSpore Network instance containing the parameters to be broadcasted and synchronized.
param_redundancy (dict): Mapping of redundant rank groups (tuples of rank IDs) to lists of parameter keys.
Each entry represents parameters that are duplicated across the ranks in the corresponding group.
param_not_load (set): Set of parameter keys that have not been loaded by any rank (excluded from broadcast).
param_loaded (set): Set tracking rank IDs that have successfully loaded their assigned parameters.
Updated to include the current rank before synchronizing loading status.
"""
cur_rank = get_real_rank()
origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")

net_param_dict = net.parameters_dict()
_change_parallel_context(origin_dataset_strategy)
param_loaded.add(cur_rank)
total_num = get_real_group_size()
total_param_loaded = [None] * total_num
synchronize()
all_gather_object(total_param_loaded, param_loaded)
logger.debug("Total params loaded:")
for param_loaded_ in total_param_loaded:
logger.debug(param_loaded_)

group_map = _get_sorted_group_map()
for group, params in param_redundancy.items():
logger.debug(f"Rank group: {group}")
logger.debug(f"AllReduce params: {params}")
params = [param for param in params if param in set(net_param_dict.keys())]
allreduce_input = _create_allreduce_input(
params, group, net_param_dict, total_param_loaded, param_not_load, cur_rank)
_communicate_allreduce(allreduce_input, group_map, group)
_restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
synchronize()
logger.info("End loading the parameter broadcast for removing redundant parameters.")

+106 -71 mindformers/checkpoint/checkpoint.py

@@ -66,6 +66,7 @@ from mindformers.checkpoint.sharded_tensor import (
get_strategy_info_from_sharded_tensor,
ShardedTensor, get_sharded_tensor_list_from_cell, get_cur_sharded_tensor
)
from mindformers.checkpoint.broadcast import single_parameter_broadcast


@dataclass
@@ -832,32 +833,17 @@ def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]:
def params_key_mapping(
sharded_tensor_metas: Dict[str, List[ShardedTensor]],
network: Cell
) -> tuple[dict, dict, Cell]:
) -> tuple[dict, dict]:
"""
Mapping Hugging Face checkpoint keys to MindSpore Transformers.

Args:
sharded_tensor_metas: Metadata about sharded tensors.
network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight.
network: The network (Cell) which has method `convert_name` to convert Hugging Face weight.

Returns:
A dictionary after mapping about sharded tensor metas.
"""

# pylint: disable=W0212
def get_core_network(network):
"""Get the core network that has `convert_name` method."""
if hasattr(network, 'convert_name'):
return network
if hasattr(network, '_backbone'):
return get_core_network(network._backbone)
if hasattr(network, 'network'):
return get_core_network(network.network)
raise NotImplementedError("Network has no function `convert_name`.")

# Get the core network and check the convert method is illegal
core_network = get_core_network(network)

# The key of `mapped_sharded_tensor_metas` is in the network,
# such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... }
mapped_sharded_tensor_metas = {}
@@ -868,7 +854,7 @@ def params_key_mapping(
key_mapping = {}

for param_name in sharded_tensor_metas:
param_name_converted = core_network.convert_name(param_name)
param_name_converted = network.convert_name(param_name)
sharded_tensor_list = sharded_tensor_metas.get(param_name)

for sharded_tensor in sharded_tensor_list:
@@ -876,10 +862,20 @@ def params_key_mapping(
sharded_tensor.org_key = param_name

key_mapping[param_name] = param_name_converted
param_name_converted_concat = core_network.convert_concat_name(param_name_converted)
param_name_converted_concat = network.convert_concat_name(param_name_converted)
mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list)

return mapped_sharded_tensor_metas, key_mapping, core_network
return mapped_sharded_tensor_metas, key_mapping


# pylint: disable=W0212
def get_core_network(network):
"""Get the core network that has `convert_name` method."""
if hasattr(network, '_backbone'):
return get_core_network(network._backbone)
if hasattr(network, 'network'):
return get_core_network(network.network)
return network


def load_checkpoint(
@@ -916,17 +912,11 @@ def load_checkpoint(

# Retrieve metadata from checkpoint files
src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir)
# Mapping the weight keys, which is used to determine whether to load the Hugging Face weights.

try:
src_sharded_tensor_metas, key_mapping, core_network = params_key_mapping(src_sharded_tensor_metas, network)
# Validate the returned values
if not isinstance(src_sharded_tensor_metas, dict) or not isinstance(key_mapping, dict) or core_network is None:
raise ValueError("Mapping the params sharded metas failed.")
except NotImplementedError as e:
raise NotImplementedError(
f"Network '{type(network).__name__}' does not have the method to convert Hugging Face weights. "
"Please ensure the network or its backbone implements this method.") from e
# Get the core network and check the convert method is illegal
network = get_core_network(network)
# Mapping the weight keys, which is used to determine whether to load the Hugging Face weights.
src_sharded_tensor_metas, key_mapping = params_key_mapping(src_sharded_tensor_metas, network)

if not src_sharded_tensor_metas or not param_file_mappings:
raise RuntimeError(
@@ -940,8 +930,9 @@ def load_checkpoint(
return "accu_grads" not in param_name
return param_name in list(network.parameters_dict().keys())

param_redundancy = None
if balanced_load:
dst_sharded_tensor_metas = apply_balance_shard_strategy(network, filter_func)[-1]
dst_sharded_tensor_metas, param_redundancy = apply_balance_shard_strategy(network, filter_func)[2:]
else:
if get_real_group_size() > 1:
cur_rank_sharded_tensors = get_cur_sharded_tensor(network, filter_func)
@@ -961,7 +952,7 @@ def load_checkpoint(
state_dict: Dict[str, Parameter] = {}

# Concat parameters
concat_params(checkpoint_dir, core_network, key_mapping, need_concat_params, state_dict)
concat_params(checkpoint_dir, network, key_mapping, need_concat_params, state_dict)

# Load parameters that don't require resharding
for file_name, param_info in no_shard_params.items():
@@ -999,12 +990,18 @@ def load_checkpoint(
)

# Load state dictionary into network and optimizer
load_parameters(network, state_dict, optimizer, balanced_load=balanced_load)
load_parameters(
network,
state_dict,
optimizer,
balanced_load=balanced_load,
param_redundancy=param_redundancy
)


def concat_params(checkpoint_dir: str, core_network, key_mapping: dict, need_concat_params, state_dict: dict):
def concat_params(checkpoint_dir: str, network: Cell, key_mapping: dict, need_concat_params, state_dict: dict):
"""Concat the need_concat_params dict in checkpoint."""
if need_concat_params and not hasattr(core_network, 'convert_hf_weight'):
if need_concat_params and not hasattr(network, 'convert_hf_weight'):
raise NotImplementedError("The `convert_hf_weight` method of network is not implemented.")

for param_name, concat_info in need_concat_params.items():
@@ -1025,7 +1022,7 @@ def concat_params(checkpoint_dir: str, core_network, key_mapping: dict, need_con
for k, v in org_weight_dict.items()
}
# Concat the weight.
concated_weight = core_network.convert_hf_weight(concat_dict)
concated_weight = network.convert_hf_weight(concat_dict)

if reshard_info:
# Get the offset of the Tensor to reshard.
@@ -1055,59 +1052,97 @@ def load_parameters(
state_dict: Dict[str, Parameter],
optimizer: Optional[Cell] = None,
state_dict_opt: Optional[Dict[str, Parameter]] = None,
balanced_load: Optional[bool] = False
balanced_load: Optional[bool] = False,
param_redundancy: Optional[Dict[Tuple, str]] = None
):
"""
Loads parameters into network and optimizer.

Separates network-specific and optimizer-specific parameters from the input state dictionaries,
loads them into their respective components, and returns lists of parameters that couldn't be loaded.
Filters out cache-related parameters from the unloaded network parameters list.
Loads parameters into a MindSpore network and optional optimizer, with support for redundant parameter handling.

This function separates network-specific and optimizer-specific parameters from input state dictionaries,
loads them into their respective components, and provides detailed logging of unloaded parameters. When
`balanced_load` is enabled, it leverages shard balancing and parameter broadcasting to eliminate redundant
parameter storage across ranks, improving memory efficiency in distributed training scenarios.

Core workflow:
1. Initialize optimizer state dictionary if not provided.
2. (If balanced load enabled) Generate parameter redundancy map via shard balancing if not explicitly provided.
3. Separate parameters from the main state dict into network-specific and optimizer-specific (state_dict_opt).
4. Load network parameters, track unloaded parameters, and filter out cache-related entries from unloaded logs.
5. (If balanced load enabled) Broadcast redundant parameters across ranks to ensure consistency.
6. Load optimizer parameters (if optimizer and state_dict_opt are provided) and apply balanced load if enabled.
7. Log detailed information about loaded/unloaded parameters for both network and optimizer.

Args:
network: The target network Cell to load parameters into
state_dict: Dictionary containing network parameters to load
optimizer: Optional optimizer Cell to load optimizer parameters into
state_dict_opt: Optional dictionary containing optimizer parameters to load
network (Cell): Target MindSpore Network Cell to load parameters into. Must be a valid Cell instance.
state_dict (Dict[str, Parameter]): Dictionary containing network parameters to load. Keys must match
parameter names in the network (or optimizer, for parameters to be redirected).
optimizer (Optional[Cell]): Optional MindSpore Optimizer Cell to load optimizer-specific parameters into.
If provided, must be a valid Cell instance.
state_dict_opt (Optional[Dict[str, Parameter]]): Optional dictionary containing optimizer parameters to load.
Initialized as an empty dict if not provided.
balanced_load (Optional[bool]): Whether to enable balanced loading with redundant parameter elimination.
When True, uses `apply_balance_shard_strategy` to identify redundant parameters and
`single_parameter_broadcast` to synchronize values across ranks. Defaults to False.
param_redundancy (Optional[Dict[Tuple[int, ...], List[str]]]): Precomputed mapping of redundant rank groups
(tuples of rank IDs) to lists of parameter keys. Only used if `balanced_load` is True; if not provided,
generated dynamically via `apply_balance_shard_strategy`. Defaults to None.

Raises:
ValueError: If network is not a Cell, state_dict is invalid, state_dict_opt is not a dict,
or optimizer is provided but is not a Cell
ValueError: If `network` is not a valid MindSpore Cell, `state_dict` is invalid (e.g., not a dict),
`state_dict_opt` is provided but not a dict, or `optimizer` is provided but not a valid Cell.
RuntimeError: If parameter loading fails due to mismatched keys or invalid parameter types (propagated from
`load_param_into_net`).
"""
def split_state_dict(network, state_dict, optimizer, state_dict_opt):
"""split state dict"""
network_param_names = set(network.parameters_dict().keys())
optimizer_param_names = set(optimizer.parameters_dict().keys()) if optimizer else set()
for param_name in list(state_dict.keys()):
if param_name not in network_param_names and param_name in optimizer_param_names and \
param_name not in state_dict_opt:
state_dict_opt[param_name] = state_dict.pop(param_name)
return network_param_names, optimizer_param_names, state_dict, state_dict_opt

def print_not_load_info(param_list: List, param_info: str):
if not param_list:
logger.info(f"All {param_info} are loaded.")
return

logger.info(f"{param_info} not loaded:")
for p in param_list:
logger.info(f" - {p}")

state_dict_opt: Dict[str, Parameter] = {} if not state_dict_opt else state_dict_opt

# Separate network and optimizer parameters
network_param_names = set(network.parameters_dict().keys())
optimizer_param_names = set(optimizer.parameters_dict().keys()) if optimizer else set()
for param_name in list(state_dict.keys()):
if param_name not in network_param_names and param_name in optimizer_param_names \
and param_name not in state_dict_opt:
state_dict_opt[param_name] = state_dict.pop(param_name)
if balanced_load and param_redundancy is None:
param_redundancy = apply_balance_shard_strategy(network)[-1]

network_param_names, _, state_dict, state_dict_opt = \
split_state_dict(network, state_dict, optimizer, state_dict_opt)

# Load parameters into network
param_not_load, ckpt_not_load = [], []
logger.debug(f"Network state_dict keys: {list(state_dict.keys())}")
param_not_load, ckpt_not_load = load_param_into_net(network, state_dict, remove_redundancy=balanced_load)
logger.info(f"Network parameters not loaded: {list(param_not_load)}")
logger.info(f"Checkpoint weights not loaded: {list(ckpt_not_load)}")

# Filter out cache parameters from unloaded list
param_not_load, ckpt_not_load = load_param_into_net(network, state_dict, strict_load=True)
if balanced_load:
param_loaded = {param_name for param_name in state_dict if param_name not in ckpt_not_load}
single_parameter_broadcast(network, param_redundancy, param_not_load, param_loaded)
# Filter out cache and optimizer parameters from unloaded list
param_not_load = [p for p in param_not_load if "key_cache" not in p and "value_cache" not in p]
print_not_load_info(param_not_load, "Network parameters")
print_not_load_info(ckpt_not_load, "Checkpoint weights")

# Load parameters into optimizer if available
param_not_load_opt, ckpt_not_load_opt = [], []
if optimizer and state_dict_opt:
for param_name in list(state_dict.keys()):
if param_name in optimizer_param_names and param_name not in state_dict_opt:
state_dict_opt[param_name] = state_dict.pop(param_name)
logger.debug(f"Optimizer state_dict keys: {list(state_dict_opt.keys())}")
param_not_load_opt, ckpt_not_load_opt = load_param_into_net(
optimizer,
state_dict_opt,
remove_redundancy=balanced_load
)
logger.info(f"Optimizer parameters not loaded: {list(param_not_load_opt)}")
logger.info(f"Optimizer weights not loaded: {list(ckpt_not_load_opt)}")
param_not_load_opt, ckpt_not_load_opt = load_param_into_net(optimizer, state_dict_opt, strict_load=True)
if balanced_load:
param_loaded_opt = {param_name for param_name in state_dict_opt if param_name not in ckpt_not_load_opt}
single_parameter_broadcast(optimizer, param_redundancy, param_not_load_opt, param_loaded_opt)

param_not_load_opt = [p for p in param_not_load_opt if p not in network_param_names]
print_not_load_info(param_not_load_opt, "Optimizer parameters")
print_not_load_info(ckpt_not_load_opt, "Optimizer weights")


def get_checkpoint_path(checkpoint: str) -> str:

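Putting the pieces together, the balanced-load path of `load_parameters` can be exercised roughly as follows. This is an illustrative sketch only: `network`, `optimizer`, `state_dict` and `filter_func` are assumed to come from the caller, and the import paths assume the functions are exposed by their defining modules:

```python
# Hedged sketch of the balanced-load call chain shown in the diff above.
from mindformers.checkpoint.checkpoint import load_parameters
from mindformers.checkpoint.fully_parallel import apply_balance_shard_strategy

# load_checkpoint keeps the last two returned values: per-rank shard metas
# and the redundancy map used for the post-load broadcast.
dst_metas, param_redundancy = apply_balance_shard_strategy(network, filter_func)[2:]

load_parameters(
    network,
    state_dict,                # network weights read from the checkpoint shards
    optimizer=optimizer,
    balanced_load=True,        # strict load plus redundant-parameter broadcast
    param_redundancy=param_redundancy,
)
```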

+36 -20 mindformers/checkpoint/fully_parallel.py

@@ -15,6 +15,7 @@
"""save / load parallelization strategy."""
import os
from collections import defaultdict
from typing import Callable

from mindspore import save_checkpoint
from mindspore.communication import get_rank
@@ -32,7 +33,7 @@ from mindformers.checkpoint.utils import (
_get_shard_size,
sharded_tensor_shard_id
)
from mindformers.tools.utils import get_real_local_rank
from mindformers.tools.utils import get_real_rank


class BalancedSaveStrategy():
@@ -164,7 +165,7 @@ class BalancedSaveStrategy():
if self.do_cache_distribution and self.cached_distribution is not None:
shared_distribution = self.cached_distribution
else:
shard_id_to_ranks, shard_id_to_tensor, _ = apply_balance_shard_strategy(self.network, self.filter_func)
shard_id_to_ranks, shard_id_to_tensor = apply_balance_shard_strategy(self.network, self.filter_func)[:2]
shared_distribution = (shard_id_to_ranks, shard_id_to_tensor)

if self.do_cache_distribution:
@@ -349,28 +350,35 @@ def distribute_shards(shard_coverage, shard_sizes, total_ranks):
return shard_assignment


def apply_balance_shard_strategy(network: Cell, filter_func):
def apply_balance_shard_strategy(network: Cell, filter_func: Callable[[str], bool] = None):
"""
Process and balance sharded tensor metadata across all ranks.
Distributes and balances sharded tensor metadata across all ranks in a parallel group.

This function retrieves strategy metadata from the network (and optimizer if provided),
processes sharding information, and distributes shards across ranks to generate balanced
sharded tensor metadata. If no strategy metadata exists, it falls back to directly extracting
sharded tensors from the network and optimizer.
This function aggregates sharded tensor metadata from the input network (filtered by the provided
`filter_func`), calculates shard sizes, maps shards to their associated ranks, and distributes
shards to ranks in a load-balanced manner. It also identifies redundant shard copies across ranks
and returns key metadata for shard management.

Core workflow:
1. Collect all sharded tensor metadata from the network, filtered by `filter_func`.
2. Generate unique shard IDs for each tensor shard and track which ranks own each shard.
3. Calculate the size (in bytes) of each unique shard based on its local shape and data type.
4. Distribute shards to ranks using a load-balanced strategy via the `distribute_shards` function.
5. Compile metadata for the current rank, including assigned shards and redundant shard copies.

Args:
network (Cell): The MindSpore network cell containing parameters and sharding strategies.
optimizer (Optional[Optimizer]): Optional optimizer instance (if provided, filters out
accumulator gradient parameters from sharding metadata).
network (Cell): MindSpore Network Cell containing parameters and their sharding information.
filter_func: A filtering function to select specific tensors from the network (e.g., exclude
non-trainable parameters). Applied during sharded tensor collection via `get_all_sharded_tensor`.

Returns:
list: Balanced sharded tensor metadata for the current rank, either derived from
strategy metadata distribution or directly extracted from the network/optimizer.
Notes:
- Relies on MindSpore's `get_strategy_metadata` for strategy-based sharding info.
- Filters out "accu_grads" parameters when an optimizer is provided to avoid redundant sharding.
- Falls back to direct tensor extraction if no strategy metadata is available.
tuple: A 4-element tuple containing:
1. shard_to_saving_rank (dict): Mapping of unique shard IDs to the rank responsible for storing the shard.
2. shard_id_to_tensor (dict): Mapping of unique shard IDs to their corresponding sharded tensor metadata.
3. dst_sharded_tensor_metas (dict): Mapping of original tensor keys to sharded tensor metadata
assigned to the current local rank.
4. param_redundancy (dict): Mapping of rank groups (tuples of ranks) to lists of tensor keys.
Represents shards that exist redundantly across multiple ranks in the group.
"""
total_shard_metadata = get_all_sharded_tensor(network, filter_func)
shard_id_to_ranks = defaultdict(list)
@@ -399,8 +407,16 @@ def apply_balance_shard_strategy(network: Cell, filter_func):
)

dst_sharded_tensor_metas = {} # {shard_name: ShardTensor}
local_rank = get_real_local_rank()
local_rank = get_real_rank()
for shard_id, rank_id in shard_to_saving_rank.items():
if rank_id == local_rank:
dst_sharded_tensor_metas[_reverse_sharded_tensor_shard_id(shard_id)[0]] = shard_id_to_tensor[shard_id]
return shard_id_to_ranks, shard_id_to_tensor, dst_sharded_tensor_metas

param_redundancy = {}
for shard_id, rank_group in shard_id_to_ranks.items():
if len(rank_group) == 1:
continue
if local_rank in rank_group:
param_redundancy.setdefault(tuple(rank_group), []).append(_reverse_sharded_tensor_shard_id(shard_id)[0])

return shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas, param_redundancy
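To make the new 4-tuple contract concrete, a hedged usage sketch; it assumes a distributed job where strategy metadata is available, and the filter predicate is purely illustrative.

# Sketch only, not part of the diff above.
shard_to_saving_rank, shard_id_to_tensor, dst_metas, param_redundancy = apply_balance_shard_strategy(
    network, filter_func=lambda name: not name.startswith("adam")
)
for key in dst_metas:
    # shards this rank was assigned to save
    print(f"this rank saves a shard of {key}")
for rank_group, tensor_keys in param_redundancy.items():
    # shard copies held redundantly by every rank in rank_group
    print(f"ranks {rank_group} redundantly hold {tensor_keys}")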

+ 15
- 8
mindformers/checkpoint/sharded_tensor.py View File

@@ -18,12 +18,11 @@ from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Union, Callable

import mindspore as ms
from mindspore.communication import get_group_size
from mindspore.nn import Cell
from mindspore.parallel.shard import _DistributedTensorInfo
from mindspore.parallel.strategy import get_current_strategy_metadata, get_strategy_metadata

from mindformers.tools.utils import get_real_rank
from mindformers.tools.utils import get_real_rank, get_real_group_size
from mindformers.tools.logger import logger


@@ -335,8 +334,11 @@ def get_replica_id_from_layout(param_infos: List[Dict]) -> List[List[int]]:
return replica_ids


def get_sharded_tensor_list_from_strategy_metadata(param_infos: List[Dict], cur_npu_rank: int,
filter_func: Callable[[str], bool]) -> Optional[List[ShardedTensor]]:
def get_sharded_tensor_list_from_strategy_metadata(
param_infos: List[Dict],
cur_npu_rank: int,
filter_func: Callable[[str], bool] = None
) -> Optional[List[ShardedTensor]]:
"""
Transform distributed strategy of a network to a list of ShardedTensor.

@@ -495,7 +497,10 @@ def convert_sharded_tensor_list_to_dict(
return sharded_tensor_dict


def get_all_sharded_tensor(network, filter_func) -> list:
def get_all_sharded_tensor(
network: Cell,
filter_func: Callable[[str], bool] = None
) -> List[List]:
"""Get all rank sharded tensors."""
logger.info(".........Get All Ranks' Strategy Metadata.........")
global_strategy_info = get_strategy_metadata(network)
@@ -503,9 +508,8 @@ def get_all_sharded_tensor(network, filter_func) -> list:
raise RuntimeError('`get_strategy_metadata` returns `None`, which indicates there is no strategy info. '
'Please check whether this is a distributed job.')

npu_nums = get_group_size()
npu_nums = get_real_group_size()
sharded_tensor_metas = []

for cur_npu_rank in range(0, npu_nums):
org_cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]
cur_rank_strategy_layout = [
@@ -524,7 +528,10 @@ def get_all_sharded_tensor(network, filter_func) -> list:
return sharded_tensor_metas


def get_cur_sharded_tensor(network, filter_func):
def get_cur_sharded_tensor(
network: Cell,
filter_func: Callable[[str], bool] = None
) -> List:
"""Get current rank sharded tensors."""
logger.info(".........Get Current Strategy Metadata.........")
strategy_info = get_current_strategy_metadata(network)


+ 66
- 79
mindformers/core/callback/callback.py View File

@@ -58,6 +58,7 @@ from mindspore.communication.comm_func import all_gather_into_tensor, barrier
from mindspore.profiler import ProfilerLevel, schedule
from mindspore.utils import stress_detect

from mindformers.wrapper.wrapper import get_real_models
from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor
from mindformers.core.context.build_context import is_legacy_model
from mindformers.tools import get_output_root_path
@@ -216,11 +217,9 @@ def _get_max_eigenvalue(input_tensor, num_iter):
v_tensor = ms.ops.matmul(input_seq, u_tensor) # (b,n,n) * (n,1) = (b,n,1)
eigenvalue = ms.ops.matmul(v_tensor.transpose(-2, -1), u_tensor).squeeze() # (b,1,n) * (b,n,1) = b
v_norm = v_tensor.norm(dim=1, keepdim=True) # (b,1,1)
if (v_norm != 0).all():
u_tensor = v_tensor / v_norm
else:
return 0.0
return eigenvalue.asnumpy()
v_norm_safe = ms.ops.select(v_norm == 0, ms.ops.ones_like(v_norm), v_norm)
u_tensor = v_tensor / v_norm_safe
return eigenvalue


def _get_stable_rank(weight, num_iter):
@@ -230,11 +229,13 @@ def _get_stable_rank(weight, num_iter):
except Exception as e:
logger.warning(f"{weight.name} calculate max eigenvalue failed: {e}")
return 0.0, 0.0
eig = np.asarray(eig)
if (eig != 0.0).all():
f_norm = ms.ops.norm(weight, ord='fro', dim=(-2, -1))
return ms.ops.square(f_norm).asnumpy() / eig, eig
return 0.0, 0.0
f_norm_square = ms.ops.square(ms.ops.norm(weight, ord='fro', dim=(-2, -1)))
stable_rank = ms.ops.select(
eig != 0,
f_norm_square / eig,
ms.ops.zeros_like(eig)
)
return stable_rank.asnumpy(), eig.asnumpy()
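For orientation, the stable rank computed above is the squared Frobenius norm divided by the largest eigenvalue of W·Wᵀ, with that eigenvalue approximated by power iteration. A small NumPy sketch of the same math (not the MindSpore implementation; batching and MoE shapes are ignored):

import numpy as np

def approx_max_eigenvalue(mat, num_iter=5, seed=0):
    """Approximate the largest eigenvalue of a symmetric matrix by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(mat.shape[0], 1))
    for _ in range(num_iter):
        v = mat @ u
        norm = np.linalg.norm(v)
        if norm == 0:
            return 0.0
        u = v / norm
    return float(u.T @ mat @ u)   # Rayleigh quotient at the converged direction

w = np.random.default_rng(1).normal(size=(64, 32))
gram = w @ w.T                                     # eigenvalues are squared singular values of w
eig = approx_max_eigenvalue(gram, num_iter=20)
stable_rank = (np.linalg.norm(w, "fro") ** 2) / eig
print(stable_rank)                                 # lies in [1, min(64, 32)] for full-rank w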


def _get_optimizer_state(optim_params, filter_fn: Callable = None):
@@ -246,6 +247,11 @@ def _get_optimizer_state(optim_params, filter_fn: Callable = None):
return norms


def _is_positive_natural_number(x):
"""Check if it is a positive natural number"""
return isinstance(x, int) and x > 0


@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class MFLossMonitor(Callback):
"""
@@ -679,59 +685,35 @@ class TrainingStateMonitor(Callback):

- local_norm_format: Determine where to display the local norm.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Only params specified will be monitored.
or a `list` containing them, or ``None``. Only params specified will be monitored.
may cause a large amount of print info if 'log' is selected.
Set to ``None`` to ignore this metric. Default: ``None``.

- device_local_norm_format: Determine where to display the device local norm.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``.
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``.

- local_loss_format: Determine where to display the local loss.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
Default: ``None``.

- device_local_loss_format: Determine where to display the device local loss.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
Default: ``None``.

- optimizer_state_format: Determine where to display the optimizer state.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Only the optimizer state of params specified
or a `list` containing them, or ``None``. Only the optimizer state of params specified
will be monitored, may cause a large amount of print info if 'log' is selected.
Set to ``None`` to ignore this metric. Default: ``None``.

- weight_state_format: Determine where to display the weight L2-norm.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
Default: ``None``.

- stable_rank_config
- format: Determine where to display the weight stable_rank and max eigenvalue.
Should be a `str` in ['tensorboard', 'log'] (meaning that data is written to tensorboard or to a log file),
or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric.
Default: ``None``.

- step_interval (int, optional): Specify the frequency of monitoring stable rank.
Must be a positive integer. Default: ``1``.

- target (list[str], optional): Specify the name or regular expression of params to calculate
stable rank. e.g. ["layers.[01]", "attention"]. Default: ``['*']``.

- do_aggregation (bool, optional): Whether to aggregate the weight parameter when it has been
sliced, in order to calculate stable_rank and eigenvalue. Default: ``False``.

- moe_show_mode (Literal['full', 'statistics', 'all'], optional): Only works when calculating
weight_stable_rank and weight_eigenvalue for MOE model (3D param). Set to `full` to list all
experts data. Set to `statistics` to show min, max, mean of all experts. Set to `all` to show
both original data and statistic data. Default: ``'all'``.

- power_iteration_num (int, optional): The power iteration method is used to approximate the max
eigenvalue. The more iterations performed, the closer the computed result is to the true value,
but the computational cost increases accordingly. Must be a positive integer. Default: ``5``.

- throughput_baseline: The model throughput baseline to calculate linearity. Must be a positive number.
Will be displayed both to tensorboard and log. Set to ``None`` to ignore this metric. Default: ``None``.
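For readers of the stable_rank_config options above (and of _init_stable_rank_config further down), a hypothetical Python fragment showing one way such a monitor_config could look; the key names come from this diff, the values are examples only.

# Hypothetical configuration fragment, not a recommended setting.
monitor_config = {
    "step_interval": 10,
    "stable_rank_config": {
        "format": ["tensorboard", "log"],  # where stable rank / max eigenvalue are reported
        "step_interval": 100,              # must be a positive integer
        "target": ["layers.[01]", "attention"],
        "do_aggregation": False,
        "moe_show_mode": "statistics",     # 'full', 'statistics' or 'all'
        "power_iteration_num": 5,
    },
    "throughput_baseline": None,
}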

@@ -752,7 +734,6 @@ class TrainingStateMonitor(Callback):
embedding_size (int, optional): The size of embedding norm which is get
by hidden_size * vocab_size. Default: ``4096``.
use_local_norm (bool, optional): Whether to turn on the local norm. Default: ``False``.
Default: ``False``.
"""

@args_type_check(embedding_size=int, use_skip_data_by_global_norm=bool)
@@ -771,10 +752,9 @@ class TrainingStateMonitor(Callback):
embedding_size: int = 4096,
use_local_norm: bool = False):
super().__init__()
if not (isinstance(step_interval, int) and step_interval > 0):
logger.warning("The value of 'monitor_config.step_interval' should be positive integer, "
f"but get {step_interval}. Use default value: 1.")
step_interval = 1
if not _is_positive_natural_number(step_interval):
raise TypeError("The value of 'monitor_config.step_interval' should be positive integer, "
f"but get {step_interval}.")
self.step_interval = step_interval
self.last_print_time = 0
self.step_time = time.time()
@@ -808,7 +788,7 @@ class TrainingStateMonitor(Callback):
self.device_local_norm_pattern = re.compile('(device_local_norm)_[a-z]+[0-9]+_([0-9]+)')
# when pipeline_stages > 2, param aggregation is not supported for now
pp_parallel = context.get_auto_parallel_context("pipeline_stages") > 1
if pp_parallel and self.sr_format:
if pp_parallel and self.sr_format and self.do_aggregation:
raise TypeError("When pipeline_stages > 1, weight aggregation is not supported")

def on_train_epoch_begin(self, run_context):
@@ -1076,15 +1056,37 @@ class TrainingStateMonitor(Callback):
if hasattr(config.get('stable_rank_config'), "get"):
self.sr_format = config.get('stable_rank_config').get('format', None)
self.sr_step_interval = config.get('stable_rank_config').get('step_interval', 100)
if not _is_positive_natural_number(self.sr_step_interval):
raise TypeError("'monitor_config.stable_rank_config.step_interval' should be positive integer,"
f"but get {self.sr_step_interval}.")
self.sr_last_print_time = 0
self.sr_target = config.get('stable_rank_config').get('target') or ['.*']
self.sr_target_cache = {}
self.do_aggregation = config.get('stable_rank_config').get('do_aggregation', False)
self.moe_show_mode = config.get('stable_rank_config').get('moe_show_mode') or ["all"]
self.power_iteration_num = config.get('stable_rank_config').get('power_iteration_num', 5)
if not _is_positive_natural_number(self.power_iteration_num):
raise TypeError("'monitor_config.stable_rank_config.power_iteration_num' should be positive integer,"
f"but get {self.power_iteration_num}.")
else:
self.sr_format = None

def _init_global_norm_monitor_config(self, config):
"""Initialize global norm monitor config"""
self.check_for_global_norm = config.get('check_for_global_norm')
self.global_norm_record_path = os.path.join(get_output_root_path(), "abnormal_global_norm.json")
self.global_norm_spike_threshold = config.get('global_norm_spike_threshold')
self.global_norm_spike_count_threshold = config.get('global_norm_spike_count_threshold', 10)
if not _is_positive_natural_number(self.global_norm_spike_count_threshold):
raise TypeError("'monitor_config.global_norm_spike_count_threshold' should be positive integer, "
f"but get {self.global_norm_spike_count_threshold}.")
self.abnormal_global_norms: dict[str, list[float]] = {}
if self.global_norm_record_path and os.path.exists(self.global_norm_record_path):
# the data format might be like {"300": [3.3], "600": [4.1, 4.2],}
# because json cannot use number as key, we convert it to string
with open(self.global_norm_record_path, 'r', encoding="utf-8") as file:
self.abnormal_global_norms = json.load(file)

def _init_config(self, config):
"""Initialize members from config"""
if config is None:
@@ -1106,15 +1108,9 @@ class TrainingStateMonitor(Callback):
self.optimizer_state_format = config.get('optimizer_state_format', None)
self.weight_state_format = config.get('weight_state_format', None)
self._init_stable_rank_config(config)
self._init_global_norm_monitor_config(config)
self.throughput_baseline = config.get('throughput_baseline', None)
self.print_struct = config.get('print_struct')

self.check_for_global_norm = config.get('check_for_global_norm')
self.global_norm_record_path = os.path.join(get_output_root_path(), "abnormal_global_norm.json")
self.global_norm_spike_threshold = config.get('global_norm_spike_threshold')
self.global_norm_spike_count_threshold = config.get('global_norm_spike_count_threshold')
self.abnormal_global_norms: dict[str, list[float]] = {}

if self.print_struct is None:
self.print_struct = False
if not (isinstance(self.target, list) and self.target and all(isinstance(i, str) for i in self.target)):
@@ -1130,11 +1126,6 @@ class TrainingStateMonitor(Callback):
'optimizer_state_format', 'weight_state_format', 'max_attention_logit_format']
for attr in attrs:
self._check_attr_formats(attr)
if self.global_norm_record_path and os.path.exists(self.global_norm_record_path):
# the data format might be like {"300": [3.3], "600": [4.1, 4.2],}
# because json cannot use number as key, we convert it to string
with open(self.global_norm_record_path, 'r', encoding="utf-8") as file:
self.abnormal_global_norms = json.load(file)

def _print_stable_rank(self, name, param, cur_step_num):
"""output stable rank and max eigenvalues"""
@@ -1145,8 +1136,12 @@ class TrainingStateMonitor(Callback):
stable_rank, eigenvalue = _get_stable_rank(param, self.power_iteration_num)
self._output(f'weight_stable_rank/{name}', stable_rank, cur_step_num, self.sr_format)
self._output(f'weight_eigenvalue/{name}', eigenvalue, cur_step_num, self.sr_format)
if (stable_rank, eigenvalue) == (0.0, 0.0):
logger.info(f"{name}'s stable rank might be 0 or some exception happened, check warning above.")
else:
stable_rank, eigenvalue = _get_stable_rank(param, self.power_iteration_num)
if (stable_rank, eigenvalue) == (0.0, 0.0):
return
if self.moe_show_mode in ('all', 'full'):
for index, sr in enumerate(stable_rank):
self._output(f'weight_stable_rank/{name}/expert_{index}', sr, cur_step_num,
@@ -1276,25 +1271,19 @@ class TrainingStateMonitor(Callback):

def _dump_max_attention_logit(self, cb_params):
"""write the max attention logit to log/tensorboard"""
if cb_params.optimizer is not None:
cb_optimizer = cb_params.optimizer
else:
cb_optimizer = cb_params.network.optimizer
params = cb_optimizer._parameters # pylint: disable=W0212
network = cb_params.train_network
network = get_real_models(network)
params = network.get_max_attention_logit()

if not params:
return
step = cb_params.cur_step_num
vals = []
for param in params:
name = getattr(param, "name", "")
if "max_logits_val" not in name:
continue

t = param.value()
v = t.asnumpy().squeeze()
for param_name, param in params.items():
v = param.asnumpy().squeeze()
v = v / max(1, self.micro_batch_num)

tag = f"max_attention_logit/{name}"
tag = f"max_attention_logit/{param_name}"
if 'log' in self.max_attention_logit_format:
self._output(tag, v.tolist(), step, ['log'])
if 'tensorboard' in self.max_attention_logit_format:
@@ -1517,8 +1506,6 @@ class CheckpointMonitor(ModelCheckpoint):
save_optimizer (bool, optional): Whether to save optimizer weights,
only used in the megatron-format weight saving scenario. In the legacy scenario it will be set to ``None``.
Default: ``True``.
save_checkpoint_path (str, optional): Users can specify the path to store weights.
If None, the checkpoints will be saved at './output_dir/checkpoint'. Default: ``None``.

Raises:
ValueError: If `prefix` is not str or contains the '/' character.
@@ -1555,8 +1542,7 @@ class CheckpointMonitor(ModelCheckpoint):
use_checkpoint_health_monitor=False,
health_ckpts_record_dir="./output",
use_legacy_format=True,
save_optimizer=True,
save_checkpoint_path=None):
save_optimizer=True):

self.config = config
self.save_network_params = save_network_params
@@ -1570,7 +1556,7 @@ class CheckpointMonitor(ModelCheckpoint):
# Ensure that 'save_optimizer' is only used when 'use_legacy_format == False'
self.save_optimizer = save_optimizer if not use_legacy_format else False
self.origin_prefix = prefix
self.save_checkpoint_path = save_checkpoint_path
self.directory = directory
self.need_remove_redundancy = remove_redundancy

prefix = prefix + f"_rank_{self.rank_id}"
@@ -2031,7 +2017,7 @@ class CheckpointMonitor(ModelCheckpoint):
network=cb_params.network,
filter_func=(lambda x: x in list(
cb_params.network.network.parameters_dict().keys())) if not self.save_optimizer else None
)
) if get_real_group_size() > 1 else None

save_checkpoint(
iteration=iteration,
@@ -2041,7 +2027,7 @@ class CheckpointMonitor(ModelCheckpoint):
common_info=self.common_info,
keep_max_num=self._config.keep_checkpoint_max,
user_prefix=self.origin_prefix,
save_checkpoint_path=self.save_checkpoint_path,
save_checkpoint_path=self.directory,
sharded_tensor_metas=sharded_tensor_metas,
remove_redundancy=self.need_remove_redundancy
)
@@ -2306,6 +2292,7 @@ class ProfileMonitor(Callback):
step_num = cb_params.cur_step_num
if self.profiler and not self.is_profiler_start:
self.profiler.start()
self.profiler.step()  # step once so the profiled steps align with the training steps
self.is_profiler_start = True

if self.mstx_enabled:
@@ -2353,7 +2340,7 @@ class ProfileMonitor(Callback):
'world_size': parallel.get('device_num', None)
}))
except AttributeError as e:
logger.warning("Profiler failed to record distributed args, %s", e)
logger.warning("Profiler failed to record distributed args, %s", e)

def _is_profile_required(self, rank_id):
"""


+ 18
- 9
mindformers/core/context/build_context.py View File

@@ -18,11 +18,11 @@ import os
from dataclasses import dataclass
from typing import Union

import mindspore as ms
import mindspore.dataset as ds
import mindspore
import mindspore.context as ms_context
import psutil

import mindformers
from mindformers.core.config_args import (
ContextConfig,
MFContextConfig,
@@ -36,9 +36,7 @@ from mindformers.tools.utils import (
MODE,
get_output_subpath,
check_in_dynamic_cluster,
get_real_local_rank
)
from mindformers.utils import get_cann_workqueue_cores
from mindformers.version_control import (
check_tft_valid,
set_ms_deterministic
@@ -155,7 +153,7 @@ class MSContextOperator:
kernel_launch_group[key] = val
thread_num = int(kernel_launch_group.get('thread_num', 2))
kernel_group_num = int(kernel_launch_group.get('kernel_group_num', 8))
ms.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)
mindspore.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)

def _set_device_id(self, ctx, ms_ctx):
if self.config.use_parallel and check_in_dynamic_cluster():
@@ -386,7 +384,18 @@ def set_ms_affinity(affinity_config, affinity_cpu_list):
affinity_cpu_list = None

if affinity_config:
device_id = get_real_local_rank()
# Check if any device_X in affinity_config has X >= device_num
max_device_id = mindformers.tools.utils.get_real_group_size() - 1
for key in affinity_config:
try:
x = int(key.split('_')[1])
except Exception as exc:
raise ValueError(f"Invalid device config key {key} in affinity_config. "
f"The pattern should be `device_X`, where X refers to device id.") from exc
if x > max_device_id:
raise ValueError(f"Invalid device id {x} in affinity_config. "
f"Maximum allowed device id is {max_device_id}.")
device_id = mindformers.tools.utils.get_real_local_rank()
device_config = affinity_config.get(f'device_{device_id}', None)
if device_config:
affinity_cpu_list = device_config.get('affinity_cpu_list', None)
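A hedged example of the affinity_config shape the check above expects: top-level keys must follow the device_X pattern and X may not exceed the largest device id. The nested affinity_cpu_list key appears in the surrounding code; the concrete CPU values are invented.

# Hypothetical affinity_config for a 2-device job.
affinity_config = {
    "device_0": {"affinity_cpu_list": ["0-7"]},
    "device_1": {"affinity_cpu_list": ["8-15"]},
}
# A key such as "device_4" would raise ValueError above, since 4 > max_device_id == 1;
# a malformed key such as "npu0" would also raise, because it does not match device_X.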
@@ -397,7 +406,7 @@ def set_ms_affinity(affinity_config, affinity_cpu_list):
else:
module_to_cpu_dict = None

ms.runtime.set_cpu_affinity(
mindspore.runtime.set_cpu_affinity(
True,
affinity_cpu_list,
module_to_cpu_dict
@@ -408,14 +417,14 @@ def set_cpu_affinity(rank_id, rank_size):
"""cpu affinity"""
use_cpu_affinity = os.environ.get('CPU_AFFINITY')
if use_cpu_affinity and use_cpu_affinity.lower() in ('1', 'true'):
ds.config.set_numa_enable(True)
mindspore.dataset.config.set_numa_enable(True)
count = psutil.cpu_count()
current_process = psutil.Process()
used_cpus_num = count // rank_size
used_cpus = list(
range(rank_id * used_cpus_num, (rank_id + 1) * used_cpus_num)
)
cann_used_cpus = get_cann_workqueue_cores(rank_id)
cann_used_cpus = mindformers.utils.get_cann_workqueue_cores(rank_id)
logger.info(f"cann workqueue cpus: {cann_used_cpus}")
used_cpus = list(set(used_cpus) - set(cann_used_cpus))
if not used_cpus:


+ 12
- 6
mindformers/core/context/validators.py View File

@@ -69,17 +69,23 @@ def validate_sink_size(config):
def validate_precision_sync(config):
"""Validate train_percision_sync and infer_percision_sync."""
"""
Validate train_precision_sync and infer_precision_sync configuration values.
Args:
config (MindFormerConfig): Configuration object containing precision sync settings
Raises:
ValueError: If train_precision_sync or infer_precision_sync are not boolean values
"""
train_precision_sync = config.get_value('train_precision_sync')
infer_percision_sync = config.get_value('train_precision_sync')
infer_precision_sync = config.get_value('infer_precision_sync')
if train_precision_sync is not None and not isinstance(
train_precision_sync, bool):
raise ValueError(
f'train_percision_sync should be bool, got {train_precision_sync}')
if infer_percision_sync is not None and not isinstance(
infer_percision_sync, bool):
f'train_precision_sync should be bool, got {train_precision_sync}')
if infer_precision_sync is not None and not isinstance(
infer_precision_sync, bool):
raise ValueError(
f'train_percision_sync should be bool, got {infer_percision_sync}')
f'infer_precision_sync should be bool, got {infer_precision_sync}')
def validate_invalid_predict_mode(config):


+ 40
- 37
mindformers/core/optim/__init__.py View File

@@ -28,47 +28,50 @@ __all__ = ['AdamW', 'PmaAdamW', 'Muon']

@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER)
class AdamW:
r"""
"""
This is the implementation of AdamW.

.. math::
\begin{array}{l}
&\newline
&\hline \\
&\textbf{Parameters}: \: 1^{\text {st }}\text {moment vector} \: m , \: 2^{\text {nd}} \:
\text{moment vector} \: v , \\
&\: gradients \: g, \: \text{learning rate} \: \gamma,
\text {exponential decay rates for the moment estimates} \: \beta_{1} \: \beta_{2} , \\
&\:\text {parameter vector} \: w_{0}, \:\text{timestep} \: t, \: \text{weight decay} \: \lambda \\
&\textbf{Init}: m_{0} \leftarrow 0, \: v_{0} \leftarrow 0, \: t \leftarrow 0, \:
\text{init parameter vector} \: w_{0} \\[-1.ex]
&\newline
&\hline \\
&\textbf{repeat} \\
&\hspace{5mm} t \leftarrow t+1 \\
&\hspace{5mm}\boldsymbol{g}_{t} \leftarrow \nabla f_{t}\left(\boldsymbol{w}_{t-1}\right) \\
&\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\lambda\boldsymbol{w}_{t-1} \\
&\hspace{5mm}\boldsymbol{m}_{t} \leftarrow \beta_{1} \boldsymbol{m}_{t-1}+\left(1-\beta_{1}\right)
\boldsymbol{g}_{t} \\
&\hspace{5mm}\boldsymbol{v}_{t} \leftarrow \beta_{2} \boldsymbol{v}_{t-1}+\left(1-\beta_{2}\right)
\boldsymbol{g}_{t}^{2} \\
&\hspace{5mm}\widehat{\boldsymbol{m}_{t}} \leftarrow \boldsymbol{m}_{t}/\big(1-\beta_{1}^{t} \big) \\
&\hspace{5mm}\widehat{\boldsymbol{v}_{t}} \leftarrow \boldsymbol{v}_{t}/\big(1-\beta_{2}^{t} \big) \\
&\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\widehat{\boldsymbol{m}_{t}}
/\left(\sqrt{\widehat{\boldsymbol{v}_{t}}}+\epsilon\right) \\
&\textbf{until}\text { stopping criterion is met } \\[-1.ex]
&\newline
&\hline \\[-1.ex]
&\textbf{return} \: \boldsymbol{w}_{t} \\[-1.ex]
&\newline
&\hline \\[-1.ex]
\end{array}
\\begin{array}{l}
&\\newline
&\\hline \\\\
&\\textbf{Parameters}: \\: 1^{\\text {st }}\\text {moment vector} \\: m , \\: 2^{\\text {nd}} \\:
\\text{moment vector} \\: v , \\\\
&\\: gradients \\: g, \\: \\text{learning rate} \\: \\gamma,
\\text {exponential decay rates for the moment estimates} \\: \\beta_{1} \\: \\beta_{2} , \\\\
&\\:\\text {parameter vector} \\: w_{0}, \\:\\text{timestep} \\: t, \\: \\text{weight decay} \\: \\lambda \\\\
&\\textbf{Init}: m_{0} \\leftarrow 0, \\: v_{0} \\leftarrow 0, \\: t \\leftarrow 0, \\:
\\text{init parameter vector} \\: w_{0} \\\\[-1.ex]
&\\newline
&\\hline \\\\
&\\textbf{repeat} \\\\
&\\hspace{5mm} t \\leftarrow t+1 \\\\
&\\hspace{5mm}\\boldsymbol{g}_{t} \\leftarrow \\nabla f_{t}\\left(\\boldsymbol{w}_{t-1}\\right) \\\\
&\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\lambda
\\boldsymbol{w}_{t-1} \\\\
&\\hspace{5mm}\\boldsymbol{m}_{t} \\leftarrow \\beta_{1} \\boldsymbol{m}_{t-1}+\\left(1-\\beta_{1}\\right)
\\boldsymbol{g}_{t} \\\\
&\\hspace{5mm}\\boldsymbol{v}_{t} \\leftarrow \\beta_{2} \\boldsymbol{v}_{t-1}+\\left(1-\\beta_{2}\\right)
\\boldsymbol{g}_{t}^{2} \\\\
&\\hspace{5mm}\\widehat{\\boldsymbol{m}_{t}} \\leftarrow \\boldsymbol{m}_{t}/
\\big(1-\\beta_{1}^{t} \\big) \\\\
&\\hspace{5mm}\\widehat{\\boldsymbol{v}_{t}} \\leftarrow \\boldsymbol{v}_{t}/
\\big(1-\\beta_{2}^{t} \\big) \\\\
&\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\widehat{\\boldsymbol{m}_{t}}
/\\left(\\sqrt{\\widehat{\\boldsymbol{v}_{t}}}+\\epsilon\\right) \\\\
&\\textbf{until}\\text { stopping criterion is met } \\\\[-1.ex]
&\\newline
&\\hline \\\\[-1.ex]
&\\textbf{return} \\: \\boldsymbol{w}_{t} \\\\[-1.ex]
&\\newline
&\\hline \\\\[-1.ex]
\\end{array}

:math:`m` represents the first moment vector moment1, :math:`v` represents the second moment vector moment2,
:math:`\widehat{m}` represents the bias-corrected first moment vector, :math:`\widehat{v}` represents
the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\gamma` represents
learning_rate, :math:`\beta_1`, `\beta_2` represent beta1 and beta2, :math:`t` represents the current step,
:math:`w` represents params, and :math:`\lambda` represents weight_decay.
:math:`\\widehat{m}` represents the bias-corrected first moment vector, :math:`\\widehat{v}` represents
the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\\gamma` represents
learning_rate, :math:`\\beta_1`, `\\beta_2` represent beta1 and beta2, :math:`t` represents the current step,
:math:`w` represents params, and :math:`\\lambda` represents weight_decay.

Args:
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
@@ -218,7 +221,7 @@ class AdamW:

@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER)
class PmaAdamW:
r"""
"""
This is the implementation of PmaAdamW.

Args:


+ 2
- 1
mindformers/core/optim/build_optim.py View File

@@ -33,6 +33,7 @@ def get_tft_wrapped_cls(class_name, config):
optim_cls = optim_cls.get_actual_adamw_cls(use_fused)

if check_tft_valid():
# pylint: disable=C0415
from mindspore.train.callback import TrainFaultTolerance
optim_cls = TrainFaultTolerance.get_optimizer_wrapper(optim_cls)
else:
@@ -89,7 +90,7 @@ def build_optim(

if default_args is not None:
config.update(default_args)
config = config.copy()
optim_cls, config = get_tft_wrapped_cls(config.pop('type'), config)
else:
optim_cls, config = get_tft_wrapped_cls(class_name, kwargs)


+ 2
- 2
mindformers/core/optim/fused_pma_adamw.py View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FusedPmaAdamW implementation"""
import mindspore.ops as ops
from mindspore import ops

from mindspore._checkparam import GT, INC_NEITHER
from mindspore import _checkparam as validator
@@ -76,7 +76,7 @@ def _check_param_value(fused_num, interleave_step, fused_algo, ema_alpha, prim_n


class FusedPmaAdamW(FusedAdamW):
r"""
"""
This is the implementation of PmaAdamW that uses fused operators.

Args:


+ 65
- 50
mindformers/core/optim/muon.py View File

@@ -98,8 +98,8 @@ def _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group):
x = Chunk()(x, tp, tp_dim)[chunk_id]

if op > 1:
if op_group == tp_group:
chunk_id = rank_id % tp
if tp_dim == -1:
chunk_id = rank_id % op
else:
chunk_id = rank_id // tp % op
x = Chunk()(x, op)[chunk_id]
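A worked example of the new chunk-index arithmetic in the op > 1 branch above (rank_id is the global rank, tp the tensor-parallel size, op the optimizer-parallel size; the numbers are illustrative):

rank_id, tp, op = 5, 2, 4
chunk_id_last_dim = rank_id % op          # tp_dim == -1 branch: 5 % 4 = 1
chunk_id_other_dim = rank_id // tp % op   # any other tp_dim:    5 // 2 % 4 = 2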
@@ -273,10 +273,13 @@ class Muon(Optimizer):
adamw_betas=(0.95, 0.95),
adamw_eps=1e-8,
micro_batch_num=1,
qk_clip_threshold=4,
qk_clip_threshold=100,
model=None,
**kwargs,
):
super().__init__(learning_rate, params, weight_decay)
if kwargs.get('swap', False):
raise ValueError("Muon does not support swap.")

self._verify_model(model)

@@ -400,32 +403,22 @@ class Muon(Optimizer):
"""Initialize parallel configuration."""
self.tp = model.get_gpt_transformer_config().tensor_model_parallel_size
self.tps = tuple(self.tp for _ in self._parameters)
self.dp = model.get_gpt_transformer_config().data_parallel_size
logger.info(f"Muon tp group size is: {self.tp}")

if not get_auto_parallel_context('enable_parallel_optimizer'):
self.op = 1
else:
self.op = get_auto_parallel_context('optimizer_weight_shard_size')
if self.op == -1:
if self.op < 1:
raise ValueError(
"Must set parallel.parallel_optimizer_config.optimizer_weight_shard_size when using Muon")
"Must set parallel.parallel_optimizer_config.optimizer_weight_shard_size > 1 "
"when enable_parallel_optimizer is True.")
if self.dp < self.op:
raise ValueError('Must set parallel_config.data_parallel >= '
'parallel.parallel_optimizer_config.optimizer_weight_shard_size when using Muon.')
logger.info(f"Muon op group size is: {self.op}")

# Validate MoE expert counts divisibility constraint:
# num_moe_experts must be divisible by (optimizer_weight_shard_size * expert_model_parallel_size)
if model.is_moe_model():
config = model.get_gpt_transformer_config()
num_moe_experts = config.num_moe_experts
expert_model_parallel_size = config.expert_model_parallel_size
if self.op * expert_model_parallel_size <= 0:
raise ValueError("Invalid optimizer_shard * expert_model_parallel_size (<=0).")
if num_moe_experts % (self.op * expert_model_parallel_size) != 0:
raise ValueError(
f"Invalid configuration: 'num_moe_experts' ({num_moe_experts}) must be divisible by "
f"'optimizer_weight_shard_size * expert_model_parallel_size' ({self.op} * "
f"{expert_model_parallel_size} = {self.op * expert_model_parallel_size})."
)
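A small worked check of the divisibility constraint spelled out in the error message above (numbers are illustrative, not defaults):

num_moe_experts = 64
optimizer_weight_shard_size = 2
expert_model_parallel_size = 4
group = optimizer_weight_shard_size * expert_model_parallel_size   # 8
assert num_moe_experts % group == 0   # 64 % 8 == 0, so the configuration is valid
# With num_moe_experts = 60 and the same parallel sizes, 60 % 8 == 4 and the check would fail.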

def _initialize_communication_groups(self):
"""Initialize communication groups for parallel training."""
self.tp_group = self._get_tp_group_name(self.rank_id, self.tp)
@@ -434,9 +427,7 @@ class Muon(Optimizer):

def _initialize_op_groups(self, model):
"""Initialize optimizer parallel groups for parameters."""
self.ops, self.op_groups = model.get_op_groups_info(
self._parameters, self.op, self.op_group, self.op_in_tp_group
)
self.ops, self.op_groups = model.get_op_groups_info(self._parameters, self.op)

def _create_communication_group(self, rank_list):
"""
@@ -473,7 +464,7 @@ class Muon(Optimizer):
logger.info(
f"op_in_tp group will reuse tp group" \
f", since tensor_parallel_size({tp}) == optimizer_parallel_size({op})."
)
)
op_in_tp_group_name = tp_group
else:
logger.info(f"Muon op_in_tp group list is: {rank_list}")
@@ -495,32 +486,11 @@ class Muon(Optimizer):
tp_group_name = self._create_communication_group(rank_list)
return tp_group_name

@jit(backend="ms_backend")
def construct(self, gradients):
"""Construct method for optimizer.

Args:
gradients: Gradients for optimization.

Returns:
Updated gradients after optimization.
def _hyper_map_func(self, lr, weight_decay, gradients):
"""
gradients = self.flatten_gradients(gradients)
weight_decay = self.get_weight_decay()
lr = self.get_lr()
self.assignadd(self.global_step, self.global_step_increase_tensor)
optim_result = self.hyper_map(
F.partial(
_muon_opt,
self.muon_momentum,
self.matched_adamw_rms,
self.beta1,
self.beta2,
self.global_step,
self.eps,
lr,
),
weight_decay,
Apply Muon optimizer update using hyper_map across parameter structures.
"""
hyper_map_args = [
self.rank_ids,
self._parameters,
self.moments1,
@@ -537,7 +507,52 @@ class Muon(Optimizer):
self.tp_groups,
self.param_name_tuple,
self.muon_split_fns,
self.muon_merge_fns,
self.muon_merge_fns
]

if self.is_group:
# If parameters are divided into groups (group-wise hyperparams)
if self.is_group_lr:
# Case 1: Both learning rate and weight decay are grouped
partial_func = F.partial(
_muon_opt, self.muon_momentum, self.matched_adamw_rms,
self.beta1, self.beta2, self.global_step, self.eps
)
hyper_map_args = [lr, weight_decay] + hyper_map_args
else:
# Case 2: Only weight decay is grouped, lr is global
partial_func = F.partial(
_muon_opt, self.muon_momentum, self.matched_adamw_rms,
self.beta1, self.beta2, self.global_step, self.eps, lr
)
hyper_map_args = [weight_decay] + hyper_map_args
else:
# No parameter groups: lr and weight decay are global hyperparameters
partial_func = F.partial(
_muon_opt, self.muon_momentum, self.matched_adamw_rms,
self.beta1, self.beta2, self.global_step, self.eps, lr, weight_decay
)

return self.hyper_map(partial_func, *hyper_map_args)

@jit(backend="ms_backend")
def construct(self, gradients):
"""Construct method for optimizer.

Args:
gradients: Gradients for optimization.

Returns:
Updated gradients after optimization.
"""
gradients = self.flatten_gradients(gradients)
weight_decay = self.get_weight_decay()
lr = self.get_lr()
self.assignadd(self.global_step, self.global_step_increase_tensor)
optim_result = self._hyper_map_func(
lr,
weight_decay,
gradients,
)

updates = self.model.apply_qk_clip_scaling(


+ 6
- 0
mindformers/dataset/causal_language_model_dataset.py View File

@@ -38,9 +38,15 @@ CAST_TO_INT_COLUMNS = ["input_ids", "labels"]


def _use_compressed_eod_mask(data_loader):
"""
Determine whether the given data loader should use a compressed EOD (End-Of-Document) mask.
"""
if (hasattr(data_loader, 'config') and data_loader.config and
data_loader.config.create_compressed_eod_mask): # megatron dataset
return True
if (hasattr(data_loader, 'create_compressed_eod_mask') and
data_loader.create_compressed_eod_mask):
return True
if (hasattr(data_loader, 'adaptor_config') and data_loader.adaptor_config and
data_loader.adaptor_config.compress_mask): # common dataloader
return True


+ 1
- 1
mindformers/dataset/dataloader/blended_megatron_dataloader.py View File

@@ -199,7 +199,7 @@ class MegatronDatasetBuilder:
eod_mask_loss=self.config.get("eod_mask_loss", False),
create_attention_mask=self.config.get("create_attention_mask", True),
create_compressed_eod_mask=self.config.get("create_compressed_eod_mask", False),
eod_pad_length=self.config.get("eod_pad_length", 128),
eod_pad_length=eod_pad_length,
s3_cache_path=self.config.get("s3_cache_path", None),
drop_last_partial_validation_sequence=self.config.get("drop_last_partial_validation_sequence", True),
add_extra_token_to_sequence=self.config.get("add_extra_token_to_sequence", True),


+ 4
- 0
mindformers/dataset/dataloader/hf_dataloader.py View File

@@ -354,6 +354,10 @@ class HFDataLoader:
"""Wrap source dataset with Mindspore Dataset."""
if getattr(config, 'streaming', False):
hf_dataset = HFIterableDataset(config, dataset, num_shards, shard_id)
if num_parallel_workers > 1:
num_parallel_workers = 1
logger.warning(
"Streaming mode only supports 'num_parallel_workers=1'. Automatically resetting the value.")
else:
hf_dataset = HFDataset(config, dataset)
dataset = GeneratorDataset(


+ 11
- 11
mindformers/generation/utils.py View File

@@ -14,8 +14,8 @@
# ============================================================================
"""utils for text generation."""
from collections import UserDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Thread
from typing import Optional
import numpy as np

@@ -59,22 +59,22 @@ def softmax(x, axis=None):


def softmax_single(i, res, x):
""" Worker used by thread pool to compute softmax safely. """
res[i] = softmax(x)


def softmax_with_threads(x, is_finished=None):
"""calculate softmax with threads"""
res = np.ones_like(x)
all_threads = []
for i in range(0, res.shape[0]):
if is_finished and is_finished[i]:
continue
thread = Thread(target=softmax_single,
args=(i, res, x[i]))
all_threads.append(thread)
thread.start()
for thread in all_threads:
thread.join()
with ThreadPoolExecutor() as executor:
futures = []
for i in range(0, res.shape[0]):
if is_finished and is_finished[i]:
continue
future = executor.submit(softmax_single, i, res, x[i])
futures.append(future)
for future in futures:
future.result()
return res
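A brief usage sketch for the thread-pool version above; the shapes and the is_finished list are illustrative, the signature comes from this diff.

import numpy as np

logits = np.random.rand(4, 8).astype(np.float32)   # 4 sequences, 8 logits each
finished = [False, True, False, False]             # finished rows keep the ones placeholder
probs = softmax_with_threads(logits, is_finished=finished)
assert probs.shape == logits.shape                 # row 1 stays all ones; other rows sum to 1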




+ 2
- 2
mindformers/model_runner.py View File

@@ -18,6 +18,7 @@ For text generation
"""
import os
import json
import importlib
from typing import Optional, List, Union, Dict
import numpy as np

@@ -186,7 +187,6 @@ class ModelRunner:
model_runner_cls = MindIEModelRunner
if model_type not in models.__all__:
try:
import importlib
model_runner_cls = importlib.import_module(model_type, ["MindIEModelRunner"]).MindIEModelRunner
except ImportError:
logger.info(f"import MindIEModelRunner from module {model_type} failed, "
@@ -581,7 +581,7 @@ class InputBuilder:
input_ids
"""
if not hasattr(self.tokenizer, "apply_chat_template"):
raise RuntimeError("The tokenizer dose not implement apply_chat_template function.")
raise RuntimeError("The tokenizer does not implement apply_chat_template function.")
if not self.tokenizer.chat_template:
raise RuntimeError("The model does not appear to be a chat model because it is not configured with a "
"`chat_template`.")


+ 4
- 2
mindformers/models/glm4_moe/modeling_glm4_moe_infer.py View File

@@ -17,7 +17,6 @@ __all__ = ['InferenceGlm4MoeForCausalLM']

from mindformers.models.utils import jit
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.transformer_config_utils import convert_to_transformer_config
from mindformers.models.glm4_moe.utils import Glm4MoePreTrainedModel
from mindformers.parallel_core.inference.utils import update_comm_config
from mindformers.parallel_core.inference.base_models.gpt.gpt_model import GPTModel
@@ -41,8 +40,11 @@ class InferenceGlm4MoeForCausalLM(Glm4MoePreTrainedModel, InferModelMixin):

def __init__(self, config: Glm4MoeConfig):
super().__init__(config, auto_prefix=False)
if config.model_type == "glm4_moe":
setattr(config, "num_nextn_predict_layers", 0)

self.config = config
config: TransformerConfig = convert_to_transformer_config(self.config)
config: TransformerConfig = self.convert_to_transformer_config(self.config)

# update communication-related configuration in TransformerConfig
config = update_comm_config(config)


+ 15
- 13
mindformers/models/llama/llama.py View File

@@ -111,22 +111,22 @@ class LlamaModel(LlamaPreTrainedModel):
else:
logger.info("MoE config is None, use normal FFN")
if not self.use_flash_attention and self.use_ring_attention:
raise ValueError(f"When the ring_attention = True, the flash_attention must be True.")
raise ValueError("When the ring_attention = True, the flash_attention must be True.")
if not self.use_flash_attention and self.use_eod_attn_mask_compression:
raise ValueError(f"When the use_eod_attn_mask_compression = True, the flash_attention must be True.")
raise ValueError("When the use_eod_attn_mask_compression = True, the flash_attention must be True.")
self.seq_split_num = config.parallel_config.seq_split_num
self.seq_pipe = self.seq_split_num > 1
if self.seq_pipe:
dp = config.parallel_config.data_parallel
if self.use_ring_attention:
raise ValueError(f"When the seq_pipe = True, the use_ring_attention cannot be True.")
raise ValueError("When the seq_pipe = True, the use_ring_attention cannot be True.")
if config.use_attn_mask_compression and not check_seqpp_fa_opt_support():
raise ValueError(f"Currently, when the seq_pipe = True, "
f"use_attn_mask_compress must be False with mindspore < 2.6.0. "
f"If you want to enable it, please upgrade mindspore to 2.6.0 or later.")
raise ValueError("Currently, when the seq_pipe = True, "
"use_attn_mask_compress must be False with mindspore < 2.6.0. "
"If you want to enable it, please upgrade mindspore to 2.6.0 or later.")
if config.use_eod_attn_mask_compression:
raise ValueError(f"Currently, when the seq_pipe = True, "
f"use_eod_attn_mask_compression cannot be True.")
raise ValueError("Currently, when the seq_pipe = True, "
"use_eod_attn_mask_compression cannot be True.")
self.n_kv_head = self.n_head if config.n_kv_heads is None else config.n_kv_heads
kv_shape = (config.batch_size * dp, self.n_kv_head, config.seq_length, self.head_dim)
self.zeros = initializer('zeros', kv_shape, dtype=self.dtype)
@@ -430,10 +430,10 @@ class LlamaModel(LlamaPreTrainedModel):
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(LlamaPreTrainedModel):
r"""
Provide llama training loss or logits through network.
Provide Llama training loss or logits through network.

Args:
config (LlamaConfig, optional): The config of llama model. Default: `None` .
config (LlamaConfig, optional): The config of Llama model. Default: `None` .

Inputs:
- **input_ids** (Tensor) - the indices of input sequence tokens in the vocabulary with data type Int64/Int32,
@@ -485,7 +485,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):

@lazy_inline
def __init__(self, config: LlamaConfig = None):
super(LlamaForCausalLM, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.config = config
self.ignore_token_id = config.ignore_token_id
@@ -507,7 +507,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
self.prefill_gather_flatten = P.Gather()
self.sub_batch_valid_len = P.Sub()
self.predict_run_mode = get_predict_run_mode()
logger.info("Predict run mode: {}".format(self.predict_run_mode))
logger.info(f"Predict run mode: {self.predict_run_mode}")
if self.predict_run_mode and self.config.is_dynamic:
logger.info("use_flash_attention is set to True when run_mode is predict and is_dynamic is True.")
self.config.use_flash_attention = True
@@ -807,7 +807,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):

@classmethod
def obtain_name_map(cls, load_checkpoint_files):
name_map = dict()
name_map = {}
for checkpoint_file in load_checkpoint_files:
with safe_open(checkpoint_file, framework="np") as f:
for k in f.keys():
@@ -829,6 +829,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
return params


# pylint: disable=C0415
def _concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condition, target_dict):
"""concat qkv weight from dicts"""
from mindformers.utils.convert_utils import qkv_concat_hf2mg
@@ -876,6 +877,7 @@ def _concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condit
target_dict.update({w_qkv_key: w_qkv_value_mg})


# pylint: disable=C0415
def _concat_ffn_weight(w1_keys, w3_keys, model_config, qkv_dict, condition, target_dict):
"""concat ffn weight from dicts"""
from mindformers.utils.convert_utils import ffn_concat_hf2mg


+ 11
- 0
mindformers/models/qwen3/modeling_qwen3_infer.py View File

@@ -118,3 +118,14 @@ class InferenceQwen3ForCausalLM(Qwen3PreTrainedModel, InferModelMixin):
value_cache=value_cache
)
return logits

def convert_name(self, weight_name):
r"""
Override the convert_name method in the inference model in order to read PTQ weights correctly.
PTQ weights are generated after training, so they should only exist in the inference model.
"""
weight_name = super().convert_name(weight_name)
# Do extra conversion for quantization parameters.
if self.config.quantization is not None:
weight_name = weight_name.replace('.weight_scale', '.w_scale')
return weight_name
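A worked string example of the extra quantization rename above; the parameter name is hypothetical:

# Only applied when self.config.quantization is not None.
name = "model.layers.0.feed_forward.w2.weight_scale"
print(name.replace(".weight_scale", ".w_scale"))
# -> model.layers.0.feed_forward.w2.w_scale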

+ 1
- 0
mindformers/models/telechat2/configuration_telechat2.py View File

@@ -84,6 +84,7 @@ class Telechat2Config(PretrainedConfig):
('embed_layernorm', NotSupportedInfo.useless),
('base_seqlen', NotSupportedInfo.useless),
('training_seqlen', NotSupportedInfo.useless),
('masked_softmax_fusion', NotSupportedInfo.useless),
])
def __init__(
self,


+ 0
- 1
mindformers/modules/__init__.py View File

@@ -15,7 +15,6 @@
"""MindFormers Transformers API."""
from .transformer import (
EmbeddingOpParallelConfig,
FeedForward,
LowerTriangularMaskWithDynamic,
MoEConfig,
OpParallelConfig,


+ 21
- 37
mindformers/modules/layers.py View File

@@ -49,6 +49,8 @@ from mindformers.tools.logger import logger
from mindformers.tools.utils import is_pynative
from mindformers.modules.activation import get_activation
from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
from mindformers.parallel_core.training_graph.base_models.common.embeddings.yarn_rotary_pos_embedding import \
_yarn_find_correction_range

__all__ = [
"FixedSparseAttention",
@@ -177,7 +179,6 @@ class _LayerInputCheck:
Check that the input shape is equal to the expected shape; the value on the 0-th dim is viewed as the batch, and the
batch size will not be checked.
"""
target_shape = target_shape
length, hidden = target_shape
if isinstance(input_shape, tuple):
input_shape = list(input_shape)
@@ -244,11 +245,9 @@ class Dropout(nn.Cell):
"""

def __init__(self, keep_prob=0.5, dtype=mstype.float32):
super(Dropout, self).__init__()
super().__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError(
"dropout probability should be a number in range (0, 1], but got {}".format(
keep_prob))
raise ValueError(f"dropout probability should be a number in range (0, 1], but got {keep_prob}")
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
@@ -269,7 +268,7 @@ class Dropout(nn.Cell):
return out

def extend_repr(self):
return 'keep_prob={}'.format(self.keep_prob)
return f'keep_prob={self.keep_prob}'

def shard(self, strategy):
self.dropout.shard(strategy)
@@ -291,10 +290,10 @@ class LayerNorm(Cell):
"""

def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32, is_self_defined=False):
super(LayerNorm, self).__init__()
super().__init__()
if param_init_type not in [mstype.float32, mstype.float16, mstype.bfloat16]:
raise TypeError("The type of parameter 'param_init_type' should in [float32, float16], "
"but got the type : {}.".format(type(param_init_type)))
raise TypeError(f"The type of parameter 'param_init_type' should in [float32, float16], "
f"but got the type : {type(param_init_type)}.")
# Since the mindspore 1.10 version, the layernorm has been changed to P.LayerNorm
self.is_self_defined = is_self_defined
if not self.is_self_defined:
@@ -441,7 +440,7 @@ class Linear(Cell):
use_gmm=False,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(Linear, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
@@ -465,6 +464,7 @@ class Linear(Cell):
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
name="weight")
if self.use_gmm:
# pylint: disable=import-outside-toplevel
from mindspore.ops.auto_generate import GroupedMatmul
# split_item only supports 0 and 3 now, 0 means the size of tensorlist not equal to 1,
# 3 means the size of tensorlist is 1.
@@ -676,7 +676,7 @@ class FixedSparseAttention(nn.Cell):
seq_length=1024,
num_different_global_patterns=4,
parallel_config=default_dpmp_config):
super(FixedSparseAttention, self).__init__()
super().__init__()
dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
if num_heads % mp != 0:
raise ValueError(f"The number of heads {num_heads} must be a "
@@ -700,17 +700,17 @@ class FixedSparseAttention(nn.Cell):
self.parallel_config = parallel_config
size_per_head_list = [64, 128]
if self.seq_length != 1024:
raise ValueError("For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, "
"but got the value : {}.".format(seq_length))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, "
f"but got the value : {seq_length}.")
if self.block_size != 64:
raise ValueError("For 'FixedSparseAttention', the class variable 'block_size' must be 64, "
"but got the value : {}.".format(block_size))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'block_size' must be 64, "
f"but got the value : {block_size}.")
if num_different_global_patterns != 4:
raise ValueError("For 'FixedSparseAttention', the class variable 'num_different_global_patterns' "
"must be 4, but got the value : {}".format(num_different_global_patterns))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'num_different_global_patterns' "
f"must be 4, but got the value : {num_different_global_patterns}")
if self.size_per_head not in size_per_head_list:
raise ValueError("For 'FixedSparseAttention', the class variable 'size_per_head' only supports {}, "
"but got the value : {}.".format(size_per_head_list, self.size_per_head))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'size_per_head' "
f"only supports {size_per_head_list}, but got the value : {self.size_per_head}.")
local_ones = np.ones((self.block_size, self.block_size),
dtype=np.float16)
global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16)
@@ -851,7 +851,7 @@ class AlibiTensor(nn.Cell):
"""

def __init__(self, seq_length, num_heads, parallel_config=default_dpmp_config):
super(AlibiTensor, self).__init__()
super().__init__()
dp = parallel_config.data_parallel

self.seq_length = seq_length
@@ -915,7 +915,7 @@ class AlibiTensorV2(nn.Cell):
"""

def __init__(self, num_heads):
super(AlibiTensorV2, self).__init__()
super().__init__()
self.num_heads = num_heads

self.expand_2d = P.ExpandDims()
@@ -1124,22 +1124,6 @@ def _check_linear_scaling_factor(scaling_factor):
raise ValueError(f"`scaling_factor`'s factor field must be a float >= 1, got {factor}")


def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
"""Inverse dim formula to find dim based on number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
"""Find dim range bounds based on rotations"""
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case


def _yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0


+ 3
- 3
mindformers/modules/quantizers/base.py View File

@@ -221,14 +221,14 @@ class Quantizer(ABC):

@abstractmethod
def _process_model_after_weight_loading(self, model, **kwargs):
pass
return model

@property
@abstractmethod
def is_serializable(self):
pass
return False

@property
@abstractmethod
def is_trainable(self):
pass
return False

+ 1
- 11
mindformers/modules/quantizers/ptq_quantizer.py View File

@@ -52,19 +52,9 @@ class PtqQuantizer(Quantizer):
def _process_model_before_weight_loading(
self, model: "PreTrainedModel", **kwargs
):
# pylint: disable=import-outside-toplevel
from mindspore_gs.ptq import PTQ
ptq = PTQ(config=self.quant_config, layer_policies=self.layer_policies)
model = ptq.apply(model)
model = ptq.convert(model)
return model

def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
return False

+ 0
- 11
mindformers/modules/quantizers/rtn_quantizer.py View File

@@ -65,14 +65,3 @@ class RtnQuantizer(Quantizer):
model = ptq.apply(model)
model = ptq.convert(model)
return model

def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
return False

+ 0
- 1
mindformers/modules/transformer/__init__.py View File

@@ -21,7 +21,6 @@ This is an experimental interface that is subject to change or deletion.

from .transformer import (
EmbeddingOpParallelConfig,
FeedForward,
LowerTriangularMaskWithDynamic,
TransformerOpParallelConfig,
TransformerRecomputeConfig,


+ 20
- 279
mindformers/modules/transformer/transformer.py View File

@@ -26,7 +26,6 @@ import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import Zero
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -40,18 +39,14 @@ except ImportError:
from mindspore import log as logger
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindformers.modules.layers import Linear, _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_input_dtype
from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
_Config, _check_config, MoEParallelConfig
from mindformers.version_control import get_dropout
from mindformers.modules.transformer.op_parallel_config import _PipeLineConfig, OpParallelConfig, \
_Config, MoEParallelConfig

from mindformers.tools.logger import _LogActionOnce
from mindformers.tools.utils import is_pynative

__all__ = [
"LowerTriangularMaskWithDynamic",
"FeedForward",
"TransformerOpParallelConfig",
"EmbeddingOpParallelConfig",
"TransformerRecomputeConfig",
@@ -211,7 +206,7 @@ class TransformerSwapConfig(_Config):
if isinstance(layer_swap, dict):
layer_swap = [layer_swap]
if self._validate_layers_consistency(layer_swap):
return [dict(backward_prefetch=layer_swap[0][self.backward_prefetch], layers=True)]
return [{"backward_prefetch": layer_swap[0][self.backward_prefetch], "layers": True}]
return layer_swap

def _initialize_op_swap(self, op_swap):
@@ -225,7 +220,7 @@ class TransformerSwapConfig(_Config):
op_swap_dict = self.op_swap_to_dict(op_swap)
for k, v in op_swap_dict.items():
if self._validate_layers_consistency(v, mode=f'op_swap: {k}'):
op_swap_dict[k] = [dict(backward_prefetch=v[0][self.backward_prefetch], layers=True)]
op_swap_dict[k] = [{"backward_prefetch": v[0][self.backward_prefetch], "layers": True}]
return op_swap_dict

def _validate_layers_consistency(self, layer_swap, mode='layer_swap'):
@@ -283,17 +278,17 @@ class TransformerSwapConfig(_Config):
"""Adds an operation swap configuration to the dictionary."""
if key in dic:
dic[key].append(
dict(
layers=item.get(self.layers),
backward_prefetch=item.get(self.backward_prefetch)
)
{
'layers': item.get(self.layers),
'backward_prefetch': item.get(self.backward_prefetch)
}
)
else:
dic[key] = [
dict(
layers=item.get(self.layers),
backward_prefetch=item.get(self.backward_prefetch)
)
{
'layers': item.get(self.layers),
'backward_prefetch': item.get(self.backward_prefetch)
}
]
return dic

@@ -507,9 +502,9 @@ class ContextParallelAlgo(Enum):
Args:
Enum (str): chooses the context parallel type
"""
colossalai_cp = "colossalai_cp"
ulysses_cp = "ulysses_cp"
hybrid_cp = "hybrid_cp"
COLOSSALAI_CP = "colossalai_cp"
ULYSSES_CP = "ulysses_cp"
HYBRID_CP = "hybrid_cp"


default_transformer_swap_config = TransformerSwapConfig()
@@ -601,7 +596,7 @@ class TransformerOpParallelConfig(_Config):
ValueError: in hybrid_cp algorithm, context_parallel should be divisible by ulysses_degree_in_cp
"""
if self.context_parallel == 1:
if self.context_parallel_algo != ContextParallelAlgo.colossalai_cp:
if self.context_parallel_algo != ContextParallelAlgo.COLOSSALAI_CP:
logger.warning(f"context_parallel_algo {self.context_parallel_algo.value} will not take effect "
"when context_parallel == 1.")
if self.ulysses_degree_in_cp > 1:
@@ -610,10 +605,10 @@ class TransformerOpParallelConfig(_Config):
return

# here context parallel > 1
if self.context_parallel_algo != ContextParallelAlgo.hybrid_cp and self.ulysses_degree_in_cp > 1:
if self.context_parallel_algo != ContextParallelAlgo.HYBRID_CP and self.ulysses_degree_in_cp > 1:
logger.warning(f"ulysses_degree_in_cp {self.ulysses_degree_in_cp} will not take effect when "
f"context_parallel_algo {self.context_parallel_algo.value} is not `hybrid_cp`.")
if (self.context_parallel_algo == ContextParallelAlgo.hybrid_cp and
if (self.context_parallel_algo == ContextParallelAlgo.HYBRID_CP and
self.context_parallel % self.ulysses_degree_in_cp != 0):
raise ValueError(f"When using hybrid_cp algorithm, context_parallel {self.context_parallel} "
f"should be divisible by ulysses_degree_in_cp {self.ulysses_degree_in_cp}. "
@@ -627,9 +622,9 @@ class TransformerOpParallelConfig(_Config):
"""
if self.context_parallel == 1:
return 1
if self.context_parallel_algo == ContextParallelAlgo.colossalai_cp:
if self.context_parallel_algo == ContextParallelAlgo.COLOSSALAI_CP:
return 1
if self.context_parallel_algo == ContextParallelAlgo.ulysses_cp:
if self.context_parallel_algo == ContextParallelAlgo.ULYSSES_CP:
return self.context_parallel
# hybrid
return self.ulysses_degree_in_cp
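
To make the renamed enum members concrete, here is a small illustrative helper that mirrors the branch above using plain strings and numbers (not the config class itself):

def ulysses_size(algo, context_parallel, ulysses_degree_in_cp):
    # Mirrors the branch above: colossalai_cp keeps the ulysses size at 1,
    # ulysses_cp uses the whole CP group, hybrid_cp uses the configured degree.
    if context_parallel == 1 or algo == "colossalai_cp":
        return 1
    if algo == "ulysses_cp":
        return context_parallel
    return ulysses_degree_in_cp

assert ulysses_size("hybrid_cp", 8, 2) == 2  # the remaining factor, 4, is handled by the other CP scheme
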
@@ -786,260 +781,6 @@ class TransformerOpParallelConfig(_Config):
default_transformer_config = TransformerOpParallelConfig()


class FeedForward(Cell):
r"""
The multilayer perceptron with two linear layers with dropout applied at final output. The first linear
will project the input dimension from hidden_size to ffn_hidden_size. The second linear will project the
dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
and the second linear is sharded on the output dimension. The overview process can be:

.. math::
Dropout((xW_1+b_1)W_2 + b_2)

where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.

Args:
hidden_size (int): The dimension of the inputs.
ffn_hidden_size (int): The intermediate hidden size.
dropout_rate (float): The dropout rate for the second linear's output.
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see examples. Default: gelu.
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
and the first dimension in BatchMatMul indicate expert_num. Default: 1.
expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
param_init_type (dtype.Number): The parameter initialization type. Should be mstype.float32 or
mstype.float16. Default: mstype.float32.
parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see
`OpParallelConfig` or `MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective,
otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.

Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
Float tensor.

Outputs:
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
[batch * seq_length, hidden_size]`.

Raises:
TypeError: `hidden_act` is not a string or nn.Cell.
TypeError: `parallel_config` is not a subclass of OpParallelConfig.
ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
ValueError: `hidden_size` is not a multiple of the model parallel way.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> import numpy as np
>>> from mindformers.modules.transformer import FeedForward
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor, nn
>>> import mindspore.ops as ops
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
>>> # Example 2 using custom hidden activation
>>> class MyActivationNoShard(nn.Cell):
... def __init__(self):
... super(MyActivationNoShard, self).__init__()
... self.add = ops.Add()
... def construct(self, x):
... return self.add(x, 0.1)
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
... hidden_act=MyActivationNoShard)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
>>> # Example 3 using custom hidden activation with activation_shard
>>> # If the user wants to run in the SEMI/AUTO parallel mode, the custom activation must provide
>>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
>>> # MoEParallelConfig) and set the shard for the primitives used in the construct.
>>> class MyActivationWithShard(nn.Cell):
... def __init__(self):
... super(MyActivationWithShard, self).__init__()
... self.add = ops.Add()
... def construct(self, x):
... return self.add(x, 0.1)
... def activation_shard(self, parallel_config):
... self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
>>>
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
... hidden_act=MyActivationWithShard)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
"""

@_LogActionOnce(m_logger=logger, key='FeedForward',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
ffn_hidden_size=Validator.check_positive_int,
dropout_rate=Validator.check_non_negative_float,
param_init_type=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16],
"FeedForward"),
compute_dtype=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"FeedForward"))
def __init__(self, hidden_size,
ffn_hidden_size,
dropout_rate,
hidden_act='gelu',
expert_num=1,
expert_group_size=None,
param_init_type=mstype.float32,
parallel_config=default_dpmp_config,
compute_dtype=mstype.float16):
super(FeedForward, self).__init__()
self.dtype = compute_dtype
if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
raise TypeError(f"For FeedForward cell, the hidden_act should be str type or nn.Cell type, "
f"but got {hidden_act}.")
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
_check_config(parallel_config)
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
# ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
dp = parallel_config.data_parallel // ep
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the "
"num of model parallel, but got the ffn_hidden_size is {} and the num of model "
"parallel is {}.".format(ffn_hidden_size, mp))
if hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
"model parallel, but got the hidden_size is {} and the num of model parallel is {}."
.format(hidden_size, mp))
if dropout_rate < 0 or dropout_rate >= 1:
raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size

# Project to ffn_hidden_size
self.mapping = Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)

# Project back to hidden_size
self.projection = Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)
if expert_num > 1:
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
self.projection.bias.parallel_optimizer = False
self.dropout = get_dropout(dropout_rate)
self.dropout_3d = get_dropout(dropout_rate)
self.dropout_4d = get_dropout(dropout_rate)
self.cast = P.Cast()
else:
_check_config(parallel_config)
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
# ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
dp = parallel_config.data_parallel // ep
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the "
"num of model parallel, but got the ffn_hidden_size is {} and the num of model "
"parallel is {}.".format(ffn_hidden_size, mp))
if hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
"model parallel, but got the hidden_size is {} and the num of model parallel is {}."
.format(hidden_size, mp))
if dropout_rate < 0 or dropout_rate >= 1:
raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size

# Project to ffn_hidden_size
self.mapping = Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)

if expert_num > 1:
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
strategy_bias=((dp, ep, 1, mp), (1, ep, 1, mp)),
strategy_activation=((dp, ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
strategy_activation=((dp, mp),))
# Project back to hidden_size
self.projection = Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)
if expert_num > 1:
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
strategy_bias=((dp, ep, 1, 1), (1, ep, 1, 1)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
self.projection.bias.parallel_optimizer = False
self.dropout = get_dropout(dropout_rate)
self.dropout_3d = get_dropout(dropout_rate)
self.dropout_4d = get_dropout(dropout_rate)
self.dropout.dropout.shard(((dp, 1),))
self.dropout_3d.dropout.shard(((dp, 1, 1),))
self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
self.cast = P.Cast()

def construct(self, x):
"""Forward process of the FeedForward"""
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
x = self.cast(x, self.dtype)
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
hidden = self.mapping(x)
output = self.projection(hidden)
# returned shape is [bs, seq_length, hidden_size] or [bs * seq_length, hidden_size]
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
elif len(F.shape(output)) == 2:
output = self.dropout(output)
else:
output = self.dropout_4d(output)
return output


class LowerTriangularMaskWithDynamic(Cell):
r"""
Get the Strictly Lower triangular matrix from the input_ids.


+ 4
- 2
mindformers/parallel_core/inference/model_utils.py

@@ -195,7 +195,7 @@ class InferModelMixin(ModelMixin):
def _safetensors_weights_iterator(self, weights_files: List[str]) -> Generator[Tuple[str, Any], None, None]:
"""Iterate over the weights in the model safetensor files."""
rank_id = get_tensor_model_parallel_rank()
is_main_rank = (rank_id == 0)
is_main_rank = rank_id == 0
for st_file in tqdm(
weights_files,
desc=f"[Rank {rank_id}] Loading safetensors checkpoint shards",
@@ -236,7 +236,9 @@ class InferModelMixin(ModelMixin):
"""
mapping_rules = {
'.linear_q_down_proj.': ('.linear_qkv_down_proj.', '.linear_q_down_proj.', 'q_down'),
'.linear_kv_down_proj.': ('.linear_qkv_down_proj.', '.linear_kv_down_proj.', 'kv_down'),
'.linear_kv_down_proj.': ('.linear_qkv_down_proj.', '.linear_kv_down_proj.', 'kv_down')
if getattr(self.config, 'q_lora_rank', None) is not None else
('.linear_kv_down_proj.', '.linear_kv_down_proj.', 'kv_down'),
'.linear_q_up_proj.': ('.linear_q_up_proj.', '.linear_q_up_proj.', 'q_up'),
'.linear_kv_up_proj.': ('.linear_kv_up_proj.', '.linear_kv_up_proj.', 'kv_up'),
'.linear_q.': ('.linear_qkv.', '.linear_q.', 'q'),


+ 3
- 19
mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py

@@ -20,7 +20,7 @@ import numpy as np
import mindspore
from mindspore import nn, Parameter, ops, mint
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import WeightQuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4
from mindspore.ops.auto_generate import DynamicQuantExt, GroupedMatmulV4

from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.transformer.moe.experts import GroupedMLP
@@ -77,20 +77,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
layer.insert_param_to_cell("gmm_bias", gmm_bias)

else:
self.matmul = WeightQuantBatchMatmul(False, True, group_size)
weight_shape = (self.output_size_per_partition, self.input_size_per_partition)
weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False)

w_scale_shape = (output_size_per_partition,)
w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(w_scale, {"output_dim": 0})

set_weight_attrs(weight, extra_weight_attrs)
set_weight_attrs(w_scale, extra_weight_attrs)
raise ValueError("A8W4DynamicQuant is now only support for MOE")

if layer is not None:
layer.insert_param_to_cell("weight", weight)
@@ -143,10 +130,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
group_type=0,
group_list_type=1)[0]
else:
w_scale = ops.cast(w_scale, mindspore.float16)
qx = ops.cast(qx, mindspore.float16)
out = self.matmul(qx, weight, w_scale, None, None, None, None)
out = ops.mul(out, qx_scale.unsqueeze(1))
raise ValueError("A8W4DynamicQuant is now only support for MOE")
if bias is not None:
out = self.bias_add(out, bias)
out = out.reshape(output_shape)


+ 28
- 25
mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py

@@ -84,6 +84,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase):
self.cast = P.Cast()
self.matmul = ops.auto_generate.GroupedMatmulV4()

# pylint: disable=W0237
def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
@@ -216,17 +217,17 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now."
@@ -275,6 +276,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

# pylint: disable=W0237
def construct(self, input_parallel, weight=None, group_list=None):
"""Forward of ColumnParallelGroupedLinear."""
if weight is None:
@@ -386,15 +388,15 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):


class RowParallelGroupedLinear(GroupedLinearBase):
r"""
"""
The grouped linear layer with its weight sliced along the first dimension by the tensor parallel size.
This layer implements the operation as:

.. math::
\text{outputs} = \text{inputs} * \text{weight} + \text{bias},
\\text{outputs} = \\text{inputs} * \\text{weight} + \\text{bias},

where :math:`inputs` is the input tensors, :math:`\text{weight}` is a weight matrix created by the layer,
and :math:`\text{bias}` is a bias vector created by the layer (only if has_bias is True).
where :math:`inputs` is the input tensors, :math:`\\text{weight}` is a weight matrix created by the layer,
and :math:`\\text{bias}` is a bias vector created by the layer (only if has_bias is True).

Args:
num_local_experts (int): The number of local experts.
@@ -416,11 +418,11 @@ class RowParallelGroupedLinear(GroupedLinearBase):
prefix (str): The prefix string for this linear layer. Default: empty string("").

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `input_size` in `Args` should be equal
to :math:`in\_channels` in `Inputs`.
- **x** (Tensor) - Tensor of shape :math:`(*, in\\_channels)`. The `input_size` in `Args` should be equal
to :math:`in\\_channels` in `Inputs`.

Outputs:
Tensor of shape :math:`(*, out\_channels)`.
Tensor of shape :math:`(*, out\\_channels)`.

Supported Platforms:
``Ascend``
@@ -445,17 +447,17 @@ class RowParallelGroupedLinear(GroupedLinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if not is_expert:
raise NotImplementedError(
"For RowParallelGroupedLinear, `is_expert=False` is not supported for now.")
@@ -502,6 +504,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

# pylint: disable=W0237
def construct(self, input_, weight=None, group_list=None):
"""Forward of RowParallelGroupedLinear."""
if weight is None:


+ 10
- 10
mindformers/parallel_core/inference/utils.py

@@ -20,12 +20,15 @@ __all__ = [
"update_comm_config",
]

import os
import stat
from contextlib import contextmanager
import numpy as np

import mindspore as ms
from mindspore import Tensor, ops, Parameter, mint
from mindspore.communication import get_group_size
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

from mindformers.version_control import is_310p
from mindformers.parallel_core.transformer_config import TransformerConfig
@@ -65,7 +68,7 @@ ATTNMASK_FUNC_MAP = {


def get_attn_mask_func(mask_func_type):
r"""
"""
Get attention mask function.

Args:
@@ -75,9 +78,9 @@ def get_attn_mask_func(mask_func_type):
Function, the attention mask function.
"""
if mask_func_type not in ATTNMASK_FUNC_MAP:
raise KeyError("Invalid attention mask function. Supported attention "
"mask function are ['attn_mask_fill', 'attn_mask_add'] "
", but got {}.".format(mask_func_type))
raise KeyError(f"Invalid attention mask function. Supported attention "
f"mask function are ['attn_mask_fill', 'attn_mask_add'] "
f", but got {mask_func_type}.")
return ATTNMASK_FUNC_MAP[mask_func_type]


@@ -158,7 +161,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs):
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
if numerator % denominator != 0:
raise ValueError("{} is not divisible by {}".format(numerator, denominator))
raise ValueError(f"{numerator} is not divisible by {denominator}")


def divide(numerator, denominator):
@@ -178,10 +181,6 @@ def save_strategy_file(state_dict, strategy_file_name):
Supported Platforms:
``Ascend``
"""
import os
import stat
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

stra = ckpt_strategy()

stage_rank_size = state_dict["stage_rank_size"]
@@ -361,12 +360,13 @@ def get_num_layers_and_offset(config):
return int(layer_list[pp_rank]), int(sum(layer_list[:pp_rank]))
return num_layers, 0


def use_ms_custom_ops():
"""
Determine whether the ms_custom_ops package is available
"""
try:
# pylint: disable=W0611
# pylint: disable=W0611, C0415
import ms_custom_ops
except ModuleNotFoundError:
# the ms_custom_ops package needs to be installed in the environment


+ 14
- 14
mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py

@@ -189,24 +189,24 @@ class ApplyRotaryPosEmb(nn.Cell):
self.split.shard((layout("cp", "dp", "tp", "None"),))
self.neg.shard((layout("cp", "dp", "tp", "None"),))
self.add.shard((layout("cp", "dp", "tp", "None"), layout("cp", "dp", "tp", "None")))
self.mul.shard(in_strategy=(layout("None", "dp", "tp", "None"), layout("None", "None", "None", "None")))
self.slice.shard(in_strategy=(layout("None", "dp", "tp", "None"),),
out_strategy=(layout("None", "dp", "tp", "None"),))
self.strideslice.shard(in_strategy=(layout("None", "dp", "tp", "None"),),
out_strategy=(layout("None", "dp", "tp", "None"),))
self.cat.shard(in_strategy=((layout("None", "dp", "tp", "None"), layout("None", "dp", "tp", "None")),),
out_strategy=(layout("None", "dp", "tp", "None"),))
self.mul.shard(in_strategy=(layout("cp", "dp", "tp", "None"), layout("cp", "None", "None", "None")))
self.slice.shard(in_strategy=(layout("cp", "dp", "tp", "None"),),
out_strategy=(layout("cp", "dp", "tp", "None"),))
self.strideslice.shard(in_strategy=(layout("cp", "dp", "tp", "None"),),
out_strategy=(layout("cp", "dp", "tp", "None"),))
self.cat.shard(in_strategy=((layout("cp", "dp", "tp", "None"), layout("cp", "dp", "tp", "None")),),
out_strategy=(layout("cp", "dp", "tp", "None"),))
else:
self.split.shard((layout("cp", "dp", "None", "None"),))
self.neg.shard((layout("cp", "dp", "None", "None"),))
self.add.shard((layout("cp", "dp", "None", "None"), layout("cp", "dp", "None", "None")))
self.mul.shard(in_strategy=(layout("None", "dp", "None", "None"), layout("None", "None", "None", "None")))
self.slice.shard(in_strategy=(layout("None", "dp", "None", "None"),),
out_strategy=(layout("None", "dp", "None", "None"),))
self.strideslice.shard(in_strategy=(layout("None", "dp", "None", "None"),),
out_strategy=(layout("None", "dp", "None", "None"),))
self.cat.shard(in_strategy=((layout("None", "dp", "None", "None"), layout("None", "dp", "None", "None")),),
out_strategy=(layout("None", "dp", "None", "None"),))
self.mul.shard(in_strategy=(layout("cp", "dp", "None", "None"), layout("cp", "None", "None", "None")))
self.slice.shard(in_strategy=(layout("cp", "dp", "None", "None"),),
out_strategy=(layout("cp", "dp", "None", "None"),))
self.strideslice.shard(in_strategy=(layout("cp", "dp", "None", "None"),),
out_strategy=(layout("cp", "dp", "None", "None"),))
self.cat.shard(in_strategy=((layout("cp", "dp", "None", "None"), layout("cp", "dp", "None", "None")),),
out_strategy=(layout("cp", "dp", "None", "None"),))

if self.apply_rope_fusion:
self.rope.shard(in_strategy=(layout("cp", "dp", "tp", "None"),


+ 139
- 37
mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py

@@ -15,10 +15,12 @@
"""mindformers GPT model"""
__all__ = ['GPTModel']

import hashlib
from typing import Literal, Optional, Union
import numpy as np

import mindspore as ms
from mindspore.communication import create_group, get_group_size, get_rank
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import auto_generate as aclnn_ops
@@ -29,7 +31,8 @@ from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagati
from mindspore import ops

from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss
from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock
from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock, \
func_infer_dtype, func_infer_shape, func_infer_shape_labels_and_masks
from mindformers.parallel_core.training_graph.device_matrix import layout
from mindformers.parallel_core.utils.spec_utils import ModuleSpec
from mindformers.parallel_core.training_graph.transformer.mask_generate import CausalMaskGenerate
@@ -54,26 +57,60 @@ from mindformers.tools.logger import logger
from mindformers.models.utils import get_current_rank_stage, get_model_parameters
from mindformers.version_control import get_lazy_inline as lazy_inline
from mindformers.core.optim.muon_utils import make_muon_fns
from mindformers.checkpoint.sharded_tensor import ShardedTensor


def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int):
"""Compute real op size."""
axis_fragmentations = sharded_info.axis_fragmentations
flag = False
weight_sharded_size = 1
for axis in axis_fragmentations:
if axis == 1:
continue
if flag:
raise ValueError("Only one axis can be fragmented in Muon optimizer.")
flag = True
weight_sharded_size *= axis
repeat_num = world_size // pp // weight_sharded_size
real_op_size = min(op, repeat_num)
if sharded_info.local_shape[0] % real_op_size != 0:
real_op_size = 1
return real_op_size, weight_sharded_size
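
A worked example of the arithmetic above, with assumed sizes (not taken from any config in this diff):

# World of 16 ranks, no pipeline parallel, optimizer_weight_shard_size=8,
# weight sharded 2-way on axis 0 with 4096 local rows.
world_size, pp, op = 16, 1, 8
axis_fragmentations, local_rows = (2, 1), 4096
weight_sharded_size = max(axis_fragmentations)         # 2; only one axis may be fragmented
repeat_num = world_size // pp // weight_sharded_size   # 8 replicas of this shard
real_op_size = min(op, repeat_num)                     # 8
if local_rows % real_op_size != 0:                     # 4096 % 8 == 0, so 8 is kept
    real_op_size = 1
assert (real_op_size, weight_sharded_size) == (8, 2)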


def create_communication_group(rank_list):
"""
Create a communication group with a hashed name.

Args:
rank_list: List of ranks in the communication group

Returns:
str: The created group name
"""
rank_list_str = "-".join([str(i) for i in rank_list])
hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48]
group_name = str(hashed)
create_group(group_name, rank_list)
return group_name

def func_infer_dtype(*args):
"""infer_dtype for Morph."""
return args[0]

OP_GROUP_NAME = {}

def func_infer_shape(*args):
"""infer_shape for Morph."""
input_shape = args[0]
shape_value = np.prod(input_shape[:-1])
output_shape = [int(shape_value), args[0][-1]]
return output_shape

def func_infer_shape_labels_and_masks(*args):
"""infer_shape for Morph."""
input_shape = args[0]
shape_value = np.prod(input_shape)
output_shape = [int(shape_value)]
return output_shape
def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int):
"""Get op group name."""
if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME:
return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)]
dp_range = model_parallel_size
op_range = model_parallel_size * real_op_size
rank_start = rank_id % dp_range + rank_id // op_range * op_range
rank_end = rank_start + op_range
rank_list = list(range(rank_start, rank_end, dp_range))
op_group_name = create_communication_group(rank_list)
OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list)
return op_group_name, rank_list
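
For intuition, the rank-list arithmetic behind get_op_group_name with assumed sizes (rank 5, model-parallel size 2, optimizer shard size 4):

rank_id, real_op_size, model_parallel_size = 5, 4, 2
op_range = model_parallel_size * real_op_size                                  # 8 ranks per op block
rank_start = rank_id % model_parallel_size + rank_id // op_range * op_range    # 1
rank_list = list(range(rank_start, rank_start + op_range, model_parallel_size))
assert rank_list == [1, 3, 5, 7]  # rank 5's optimizer-parallel peers, stepping by the mp size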


class PreprocessLabelsAndMasks(nn.Cell):
@@ -660,18 +697,45 @@ class GPTModel(nn.Cell):
if hasattr(self.decoder.layers[i].mlp, "router"):
self.assign_aux(self.decoder.layers[i].mlp.router.fi_accu, self.zeros_tensor)

def reset_max_attention_logit(self,):
"""Reset max attention logit for all layers."""
def _iter_core_attentions(self):
"""Iterate over all core_attention modules with their param names.

Yields:
Tuple[str, module]: A tuple of (param_name, core_attention_module).
"""
num_layers = self.config.num_layers
mtp_num_layers = self.config.mtp_num_layers

for i in range(num_layers):
if hasattr(self.decoder.layers[i].self_attention.core_attention, "max_logits_val"):
param = self.decoder.layers[i].self_attention.core_attention.max_logits_val
F.assign(param, F.zeros_like(param))
core_attn = self.decoder.layers[i].self_attention.core_attention
yield f"decoder.layers.{i}.self_attention.core_attention", core_attn

for i in range(mtp_num_layers):
if hasattr(self.mtp.layers[i].transformer_layer.self_attention.core_attention, "max_logits_val"):
param = self.mtp.layers[i].transformer_layer.self_attention.core_attention.max_logits_val
core_attn = self.mtp.layers[i].transformer_layer.self_attention.core_attention
yield f"mtp.layers.{i}.transformer_layer.self_attention.core_attention", core_attn

def get_max_attention_logit(self):
"""Get max attention logit values for all layers.

Returns:
dict: A dictionary mapping parameter names to their max logit values.
Only includes layers with valid (sum > 0) max_logits_val.
"""
max_logits = {}
for param_name, core_attn in self._iter_core_attentions():
if not hasattr(core_attn, "max_logits_val"):
continue
param = core_attn.max_logits_val.value()
if param.sum() <= 0:
continue
max_logits[f"{param_name}.max_logits_val"] = param
return max_logits

def reset_max_attention_logit(self):
"""Reset max attention logit to zeros for all layers."""
for _, core_attn in self._iter_core_attentions():
if hasattr(core_attn, "max_logits_val"):
param = core_attn.max_logits_val
F.assign(param, F.zeros_like(param))
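
A hedged sketch of how a training-side monitor might consume these two helpers (the loop and variable names are assumptions, not code from this diff):

# Hypothetical monitoring step for a GPTModel instance `model`.
max_logits = model.get_max_attention_logit()   # {param_name: per-head max logits}
for name, values in max_logits.items():
    print(name, float(values.max()))           # log the largest per-head logit
model.reset_max_attention_logit()              # zero the accumulators for the next window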

def shard(self, config: TransformerConfig):
@@ -723,6 +787,14 @@ class GPTModel(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Get all sharded state dict."""
sharded_state_dict = {}
for _, sub_cell in self.cells_and_names():
if sub_cell != self and hasattr(sub_cell, "sharded_state_dict"):
sharded_state_dict.update(sub_cell.sharded_state_dict())
return sharded_state_dict

def get_model_parameters(self):
"""Get current rank trainable parameters in gpt model ."""
params = set()
@@ -849,7 +921,7 @@ class GPTModel(nn.Cell):
tp_dims.append(0)
return tuple(tp_dims)

def get_op_groups_info(self, params, op, op_group, op_in_tp_group):
def get_op_groups_info(self, params, op):
"""Return optimizer parallel group information for each parameter.
Args:
@@ -868,16 +940,15 @@ class GPTModel(nn.Cell):
"self_attention.linear_q_down_proj.weight",
"self_attention.linear_kv_up_proj.weight",
"self_attention.linear_kv_down_proj.weight",
"eh_proj"
]

use_tp_group_list = [
"mlp.router.weight",
"mlp.shared_experts.linear_fc",
"self_attention.linear_q_down_proj.weight",
"eh_proj",
"max_logits_val"
]

sharded_state_dict = self.sharded_state_dict()
world_size = get_group_size()
ep = self.config.expert_model_parallel_size
pp = self.config.pipeline_model_parallel_size

def name_filter(param_name, full_name_list):
for full_name in full_name_list:
if full_name in param_name:
@@ -897,13 +968,44 @@ class GPTModel(nn.Cell):
logger.warning(
f"Parameter {param.name}: parallel_optimizer was set to False due to the use of Muon optimizer."
)
else:
op_list.append(op)
continue

if name_filter(param.name, use_tp_group_list):
op_groups.append(op_in_tp_group)
else:
op_groups.append(op_group)
# compute real op size
sharded_info = sharded_state_dict.get(param.name)
real_op_size, weight_sharded_size = compute_repeat_num_and_model_parallel_size(sharded_info, world_size, pp,
op)
if real_op_size == 1:
op_list.append(1)
op_groups.append("")
logger.info(f"Parameter {param.name} : No op group.")
continue

op_list.append(real_op_size)
op_group_name, rank_list = get_op_group_name(get_rank(), real_op_size, weight_sharded_size)
logger.info(f"Parameter {param.name} : Muon real_op_size={real_op_size} group list is: {rank_list}")
op_groups.append(op_group_name)

# check if op is valid for expert
for param, real_op_size in zip(params, op_list):
if "mlp.experts.weight1" not in param.name:
continue
# Validate MoE expert counts divisibility constraint:
# num_moe_experts must be divisible by (optimizer_weight_shard_size * expert_model_parallel_size)
num_moe_experts = self.config.num_moe_experts
if bool(num_moe_experts and num_moe_experts > 0):
if num_moe_experts % (real_op_size * ep) != 0:
error_msg = (f"Invalid configuration: 'num_moe_experts' ({num_moe_experts}) must be divisible by "
f"'real_op_size * expert_model_parallel_size' ({real_op_size} * "
f"{ep} = {real_op_size * ep}).\n"
f"Hint:\n"
f" Although you set `optimizer_weight_shard_size={op}`, the maximum optimizer shard size "
f"for `{param.name}` is `{real_op_size}`. Try reducing 'optimizer_weight_shard_size'.")
logger.error(error_msg)
raise ValueError(
error_msg
)
# All expert weights share the same real_op_size, so we only need to check once
break

return tuple(op_list), tuple(op_groups)
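
The divisibility check above is plain arithmetic; with assumed sizes:

# num_moe_experts must be divisible by real_op_size * expert_model_parallel_size.
num_moe_experts, real_op_size, ep = 64, 2, 8
assert num_moe_experts % (real_op_size * ep) == 0   # 64 % 16 == 0 -> valid
# 40 experts with the same ep and real_op_size would fail (40 % 16 != 0); the hint in
# the error message suggests lowering optimizer_weight_shard_size in that case.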



+ 23
- 26
mindformers/parallel_core/training_graph/loss_func.py

@@ -40,26 +40,23 @@ _device_local_loss = {}

def get_device_local_loss(tag="lm"):
"""Get `_device_local_loss` Parameter after init"""
global _device_local_loss
if tag is None:
return _device_local_loss
if _device_local_loss.get(tag, None) is None:
_device_local_loss[tag] = Parameter(
Tensor([0.0], mstype.float32), name=f"_device_local_loss", requires_grad=False
Tensor([0.0], mstype.float32), name="_device_local_loss", requires_grad=False
)
return _device_local_loss[tag]


def reset_device_local_loss():
"""Reset `_device_local_loss` parameter to zero"""
global _device_local_loss
for _, loss in _device_local_loss.items():
F.assign(loss, Tensor([0.0], mstype.float32))


def check_device_local_loss():
"""check if Nan or Inf in `_device_local_loss` parameter then terminate training"""
global _device_local_loss
if not _device_local_loss:
return
for tag, device_local_loss in _device_local_loss.items():
@@ -88,7 +85,7 @@ class _LogSoftmax(nn.Cell):
The corresponding log softmax results.
"""
def __init__(self, config: TransformerConfig = default_transformer_config):
super(_LogSoftmax, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -143,7 +140,7 @@ class _NLLLoss(nn.Cell):
The corresponding loss results.
"""
def __init__(self, config: TransformerConfig = default_transformer_config):
super(_NLLLoss, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -176,7 +173,7 @@ class _NLLLoss(nn.Cell):


class CrossEntropyLoss(nn.Cell):
r"""
"""
Calculate the cross entropy loss.

CrossEntropyLoss supports two different types of targets:
@@ -185,9 +182,9 @@ class CrossEntropyLoss(nn.Cell):
When reduction is set to 'none', the cross-entropy loss is computed as follows:

.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}
\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\top, \\quad
l_n = - w_{y_n} \\log \\frac{\\exp(x_{n,y_n})}{\\sum_{c=1}^C \\exp(x_{n,c})}
\\cdot \\mathbb{1}\\{y_n \\not= \\text{ignore_index}\\}

where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the weights,
and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices,
@@ -196,19 +193,19 @@ class CrossEntropyLoss(nn.Cell):
If reduction is not set to 'none' (the default is 'mean'), the loss is computed as:

.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, &
\text{if reduction} = \text{'mean',}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{'sum'.}
\end{cases}
\\ell(x, y) = \\begin{cases}
\\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n} \\cdot \\mathbb{1}\\{y_n \\not=
\\text{ignore_index}\\}} l_n, &\\text{if reduction} = \\text{'mean',}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}

- Class probabilities (float), used when the target is a probability distribution over multiple class labels.
When reduction is set to 'none', the cross-entropy loss is computed as follows:

.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad
l_n = - \\sum_{c=1}^C w_c \\log \\frac{\\exp(x_{n,c})}{\\sum_{i=1}^C \\exp(x_{n,i})} y_{n,c}

where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the weights,
and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices,
@@ -217,12 +214,12 @@ class CrossEntropyLoss(nn.Cell):
If reduction is not set to 'none' (the default is 'mean'), the loss is computed as:

.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{'mean',}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{'sum'.}
\end{cases}
\\ell(x, y) = \\begin{cases}
\\frac{\\sum_{n=1}^N l_n}{N}, &
\\text{if reduction} = \\text{'mean',}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}

Args:
config (TransformerConfig): The parallel configuration. Default: default_transformer_config,
@@ -258,7 +255,7 @@ class CrossEntropyLoss(nn.Cell):
@_LogActionOnce(m_logger=logger, key='CrossEntropyLoss',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
def __init__(self, config: TransformerConfig = default_transformer_config, loss_tag='lm', **kwargs):
super(CrossEntropyLoss, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -347,7 +344,7 @@ class VocabParallelCrossEntropy(nn.Cell):
"""calculate cross entropy loss"""

def __init__(self, config: TransformerConfig = default_transformer_config, **kwargs):
super(VocabParallelCrossEntropy, self).__init__()
super().__init__()
self.cross_entropy = CrossEntropyLoss(config, **kwargs)

def construct(self, vocab_parallel_logits, target, input_mask=None, label_smoothing=None):


+ 129
- 0
mindformers/parallel_core/training_graph/tensor_parallel/layers.py

@@ -36,6 +36,7 @@ from mindspore.ops.operations import Morph
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore import mint

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.utils.init_method import init_method_zero
from mindformers.parallel_core.inference.utils import divide
@@ -207,6 +208,28 @@ class VocabParallelEmbedding(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.num_embeddings, self.embedding_dim)

if self.enable_embedding_tp:
axis_fragmentations = (self.tp, 1)
local_shape = (self.num_embeddings // self.tp, self.embedding_dim)
else:
axis_fragmentations = (1, 1)
local_shape = (self.num_embeddings, self.embedding_dim)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=axis_fragmentations)
return sharded_state_dict


class ColumnParallelLinear(nn.Cell):
"""Linear layer with column parallelism.
@@ -427,6 +450,33 @@ class ColumnParallelLinear(nn.Cell):
matmul_in_strategy = ((dp * cp, 1), weight_strategy)
self.matmul.shard(in_strategy=matmul_in_strategy)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
tp = self.config.tensor_model_parallel_size

sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size // tp, self.input_size)
if not self.skip_weight_param_allocation:
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(tp, 1))
if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict
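
For intuition, with an assumed tensor_model_parallel_size of 4 and a 1024 x 512 weight, the entry built above describes a column-parallel shard roughly like this (values are illustrative):

# Illustrative ShardedTensor fields for a column-parallel weight:
#   global_shape        = (1024, 512)   # full weight
#   local_shape         = (256, 512)    # 1024 // 4 output rows held on this rank
#   axis_fragmentations = (4, 1)        # axis 0 split 4 ways, axis 1 unsplit
#   global_offset       = (0, 0)        # offsets are not fully populated; Muon only needs the shapes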


class RowParallelLinear(nn.Cell):
"""Linear layer with row parallelism.
@@ -663,6 +713,33 @@ class RowParallelLinear(nn.Cell):
matmul_in_strategy = ((dp * cp, tp), weight_strategy)
self.matmul.shard(in_strategy=matmul_in_strategy)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
tp = self.config.tensor_model_parallel_size
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size // tp)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(1, tp))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict


class LinearNoTP(ColumnParallelLinear):
"""Linear layer without tensor parallelism.
@@ -712,6 +789,32 @@ class LinearNoTP(ColumnParallelLinear):
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(1, 1))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict


class SequenceParallelLinear(ColumnParallelLinear):
"""Linear layer without tensor parallelism.
@@ -761,3 +864,29 @@ class SequenceParallelLinear(ColumnParallelLinear):
layout(("cp", "tp"), "dp", "None"), # output [S, B, H]
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_offset=(0, 0),
global_shape=weight_shape,
axis_fragmentations=(1, 1))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_offset=(0,),
global_shape=(self.output_size,),
axis_fragmentations=(1,))
return sharded_state_dict

+ 20
- 11
mindformers/parallel_core/training_graph/transformer/flash_attention.py

@@ -177,9 +177,11 @@ class FlashAttention(Cell):

if self.monitor_max_attention_logit:
self.max_logits_val = Parameter(
Tensor(np.zeros((1, self.head_num)), dtype=mstype.float32),
Tensor(np.zeros((self.head_num)), dtype=mstype.float32),
parallel_optimizer=False, requires_grad=False
)
self.reduce_max = aclnn_ops.ReduceMax()
self.reduce_max.add_prim_attr("self_define_shard", True)
self.assign_add = ops.AssignAdd()
self.assign_add.add_prim_attr("self_define_shard", True)

@@ -288,7 +290,7 @@ class FlashAttention(Cell):
attention_mask = self.lower_triangle_mask

if self.input_layout == "TND":
_, _, _, output = self.flash_attention(query,
softmax_val, _, _, output = self.flash_attention(query,
key,
value,
alibi_mask,
@@ -298,6 +300,9 @@ class FlashAttention(Cell):
prefix,
actual_seq_qlen,
actual_seq_kvlen)
if self.monitor_max_attention_logit:
max_logits = self.reduce_max(softmax_val, (0, 2))
output = F.depend(output, self.assign_add(self.max_logits_val, max_logits))
return output

q_seq_len, bsz = query.shape[:2]
@@ -331,8 +336,7 @@ class FlashAttention(Cell):
attention_mask,
prefix)
if self.monitor_max_attention_logit:
max_logits = ops.ReduceMax()(softmax_val, (2, 3))
max_logits = ops.ReduceMax(keep_dims=True)(max_logits, (0))
max_logits = self.reduce_max(softmax_val, (0, 2, 3))
output = F.depend(output, self.assign_add(self.max_logits_val, max_logits))

if self.input_layout == "BNSD":
@@ -377,12 +381,17 @@ class FlashAttention(Cell):

if self.monitor_max_attention_logit:
self.assign_add.shard(
in_strategy=(
layout("None", "tp"),
layout("None", "tp"),
),
out_strategy=(
layout("None", "tp"),
)
in_strategy=(layout("tp"), layout("tp")),
out_strategy=(layout("tp"),)
)
if self.input_layout == "BNSD":
self.reduce_max.shard(
in_strategy=(layout("None", "tp", "None", "None"),),
out_strategy=(layout("tp"),)
)
elif self.input_layout == "TND":
self.reduce_max.shard(
in_strategy=(layout("None", "tp", "None"),),
out_strategy=(layout("tp"),)
)
return self
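
The monitoring now keeps one value per attention head: the softmax statistics are max-reduced over every axis except the head axis and then accumulated into max_logits_val. A small numpy sketch of the same reduction, with assumed shapes (the width of the trailing stats axis is an assumption):

import numpy as np

softmax_val = np.random.rand(32, 8, 4).astype(np.float32)  # TND-layout stats: (tokens, heads, ...)
per_head = softmax_val.max(axis=(0, 2))                     # -> shape (8,), one value per head
running = np.zeros(8, dtype=np.float32)
running += per_head                                         # AssignAdd-style accumulation across steps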

+ 25
- 0
mindformers/parallel_core/training_graph/transformer/moe/ffn.py

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Shape, Cast, GroupedMatmul, Reshape, Swi
from mindspore.ops.operations import Morph
from mindspore.parallel._utils import _get_parallel_mode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout
from mindformers.parallel_core.training_graph.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher, MoEAlltoAllDeredundencyTokenDispatcher, MoEAlltoAllZeroRedundancyTokenDispatcher
from mindformers.parallel_core.transformer_config import TransformerConfig
@@ -215,3 +216,27 @@ class FFNGroupedGEMM(nn.Cell):
layout(dp, sp, mp0), # output [B, S, h]
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
ep = self.config.expert_model_parallel_size
sharded_state_dict = {}
sharded_state_dict[self.weight1.name] = ShardedTensor(
key=self.weight1.name,
org_key=self.weight1.name,
dtype=self.weight1.dtype,
local_shape=(self.num_local_experts // ep * self.hidden_size, self.moe_ffn_hidden_size * 2),
global_shape=(self.num_local_experts * self.hidden_size, self.moe_ffn_hidden_size * 2),
global_offset=(0, 0),
axis_fragmentations=(ep, 1),
)
sharded_state_dict[self.weight2.name] = ShardedTensor(
key=self.weight2.name,
org_key=self.weight2.name,
dtype=self.weight2.dtype,
local_shape=(self.num_local_experts // ep * self.moe_ffn_hidden_size, self.hidden_size),
global_shape=(self.num_local_experts * self.moe_ffn_hidden_size, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(ep, 1),
)
return sharded_state_dict

+ 3
- 24
mindformers/parallel_core/training_graph/transformer/moe/moe_layer.py

@@ -13,21 +13,21 @@
# limitations under the License.
# ============================================================================
"""Transformer MoE Layer."""
import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np

import mindspore as ms
from mindspore import nn, ops, Tensor, Parameter
from mindspore.context import ParallelMode
from mindspore.ops.operations import Morph
from mindspore.communication.management import get_rank, get_group_size, create_group
from mindspore.ops.auto_generate import AddExt, Reshape, Shape, Transpose
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.parallel_core.training_graph.device_matrix import layout, layout_moe
from mindformers.parallel_core.training_graph.transformer.moe.router import TopKRouter
from mindformers.parallel_core.training_graph.transformer.moe.utils import get_dp_mod_ep_group_name
from mindformers.parallel_core.utils.spec_utils import ModuleSpec, build_module
from mindformers.parallel_core.transformer_config import TransformerConfig

@@ -112,7 +112,7 @@ class MoELayer(BaseMoELayer):
self.ep = config.expert_model_parallel_size
self.expert_num = config.num_moe_experts
self.num_local_experts = self.expert_num // self.ep
self.dp_modulo_ep_group = self._dp_modulo_ep_group()
self.dp_modulo_ep_group = get_dp_mod_ep_group_name(self.dp, self.ep)

# ops
self.add = AddExt()
@@ -161,27 +161,6 @@ class MoELayer(BaseMoELayer):
elif _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,):
self.shard(config)

def _dp_modulo_ep_group(self):
"""Create MoE data parallel group across DP."""
rank_id = get_rank()
world_size = get_group_size()
dp_group_id = rank_id // self.dp

start_rank = dp_group_id * self.dp
end_rank = min(start_rank + self.dp, world_size)

rank_list = []
ep_group_id_in_dp = (rank_id % self.dp) % self.ep
for r in range(start_rank, end_rank):
if r % self.ep == ep_group_id_in_dp:
rank_list.append(r)

rank_list_str = "-".join([str(i) for i in rank_list])
hashed = hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]
dp_group_name = str(hashed)
create_group(dp_group_name, rank_list)
return dp_group_name
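
The removed in-class helper above is now provided by get_dp_mod_ep_group_name in moe/utils.py; assuming it mirrors this logic, the rank selection can be re-derived by hand with assumed sizes (dp=8, ep=2, current rank 5 in an 8-rank world):

dp, ep, rank_id, world_size = 8, 2, 5, 8
start = rank_id // dp * dp
ranks = [r for r in range(start, min(start + dp, world_size))
         if r % ep == (rank_id % dp) % ep]
assert ranks == [1, 3, 5, 7]  # every ep-th rank inside this DP block joins rank 5's group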

def permute_reshape_infer_shape(self, *args):
origin_shape = args[0]
return (self.dp, origin_shape[0] * origin_shape[1] // self.dp, origin_shape[-1])


+ 15
- 0
mindformers/parallel_core/training_graph/transformer/moe/router.py

@@ -26,6 +26,7 @@ from mindspore.ops.auto_generate import AddExt, AssignAdd, Cast, Div, Mul, Resha
from mindspore.ops.operations import Shape, ReduceSum, ReduceMean
from mindspore.parallel._utils import _get_parallel_mode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.tools.utils import get_real_group_size, get_real_rank
@@ -117,6 +118,20 @@ class Router(ABC, nn.Cell):
router_logits = self.linear(inputs.astype(self.moe_router_dtype), weight)
return router_logits

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.expert_dim, self.hidden_size),
global_shape=(self.expert_dim, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(1, 1),
)
return sharded_state_dict


class TopKRouter(Router):
"""Route each token to the top-k experts."""


+ 15
- 0
mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Cast, Mul, Sigmoid
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.transformer.mlp import MLP, MLPSubmodules, MLPInterleaved
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.device_matrix import layout
@@ -108,6 +109,20 @@ class SharedExpertMLP(MLP):
def expert_sharding_propagation(self, config: TransformerConfig):
super().sharding_propagation(config)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
sharded_state_dict[self.shared_experts_gate.weight.name] = ShardedTensor(
key=self.shared_experts_gate.weight.name,
org_key=self.shared_experts_gate.weight.name,
dtype=self.shared_experts_gate.weight.dtype,
local_shape=(1, self.hidden_size),
global_shape=(1, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(1, 1),
)
return sharded_state_dict


class SharedExpertMLPInterleaved(MLPInterleaved):
"""


+ 13
- 73
mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py

@@ -14,22 +14,24 @@
# ============================================================================
"""MoETokenDispatcher for MoE."""

import hashlib
from abc import abstractmethod
from typing import Any
import numpy as np

import mindspore as ms
from mindspore import nn, ParallelMode, ops, mint, Parameter
from mindspore import nn, ops, mint, Parameter
from mindspore.common.tensor import Tensor
from mindspore.communication import create_group, get_rank
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.communication import get_rank
from mindspore.ops.auto_generate import CumsumExt, FmodScalar, SortExt, IndexSelect, OneHotExt, Cast, Reshape, Zeros, Transpose, ReduceSum, MaskedSelect

from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.transformer.moe.utils import (
get_ep_group_name,
get_iep_group_name,
get_oep_group_name
)
from mindformers.version_control import get_all2allvc

_OEP_GROUP_NAME = {}
_IEP_GROUP_NAME = {}


class AlltoAll(nn.Cell):
"""AlltoAll operation wrapper."""
@@ -80,29 +82,12 @@ class MoETokenDispatcher:
self.cp = config.context_parallel_size

self.num_out_tokens = None
self.ep_group = self._ep_group()
self.ep_group = get_ep_group_name(get_rank(), self.ep)

self.d2h = ops.MoveTo().add_prim_attr("recompute", False)
if self.config.print_expert_load:
self.assign_add = ops.AssignAdd()

def _ep_group(self):
"""Get expert model parallel group."""
if _get_parallel_mode() == ParallelMode.STAND_ALONE:
return None
rank_id = get_rank()
ep = self.config.expert_model_parallel_size

rank_start = rank_id // ep * ep
rand_end = rank_id // ep * ep + ep
rank_list = list(range(rank_start, rand_end))

rank_list_str = "-".join([str(i) for i in range(rank_start, rand_end)])
hashed = hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]
ep_group_name = str(hashed)
create_group(ep_group_name, rank_list)
return ep_group_name

@property
def tp_group(self):
"""Get expert tensor parallel group."""
@@ -306,61 +291,16 @@ class MoEAlltoAllDeredundencyTokenDispatcher(MoETokenDispatcher):
self.b = self.a + node_expert_num

# communication group
self.oep_group = self._get_oep_group_name()
self.iep_group = self._get_iep_group_name()
self.oep_group = get_oep_group_name(self.rank_id, self.ep, self.iep)
self.iep_group = get_iep_group_name(self.rank_id, self.iep)

self.mul = ops.Mul().recompute(True)
self.nonzero = ops.NonZero().recompute(False)
self.squeeze_0 = ops.Squeeze(0).recompute(False)
self.oep_allgather = ops.AllGather(group=self.oep_group).recompute(False)
self.onehot = ops.OneHot().recompute(False)
self.onehot = ops.OneHot()
self.iep_alltoallv = ops.AlltoAllV(group=self.iep_group, block_size=1).recompute(False)

def _get_oep_group_name(self):
"""
Generates a unique group name for a set of ranks involved in outer expert partitioning (oep)
and creates a communication group with this name.
This method calculates a range of ranks based on the current rank id
and the expert partition size, hashes this range to create a unique
identifier, and then establishes a new communication group using this identifier.
"""
rank_start = self.rank_id // self.ep * self.ep
rank_start = rank_start + self.rank_id % self.iep
rand_end = rank_start + self.ep
rank_list = list(range(rank_start, rand_end, self.iep))

rank_list_str = "-".join([str(i) for i in rank_list])
if rank_list_str in _OEP_GROUP_NAME:
return _OEP_GROUP_NAME[rank_list_str]

hashed = hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]
oep_group_name = str(hashed)
create_group(oep_group_name, rank_list)
_OEP_GROUP_NAME[rank_list_str] = oep_group_name
return oep_group_name

def _get_iep_group_name(self):
"""
Generates a unique group name for a set of ranks involved in inner expert partitioning (iep)
and creates a communication group with this name.
This method calculates a range of ranks based on the current rank id
and the expert partition size, hashes this range to create a unique
identifier, and then establishes a new communication group using this identifier.
"""
rank_start = self.rank_id // self.iep * self.iep
rand_end = rank_start + self.iep
rank_list = list(range(rank_start, rand_end))

rank_list_str = "-".join([str(i) for i in rank_list])
if rank_list_str in _IEP_GROUP_NAME:
return _IEP_GROUP_NAME[rank_list_str]

hashed = hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]
iep_group_name = str(hashed)
create_group(iep_group_name, rank_list)
_IEP_GROUP_NAME[rank_list_str] = iep_group_name
return iep_group_name

def get_exdispatch_idx(self, x, expert_ids, router_coeff):
"""
Obtain nddispatch information within nodes.


mindformers/parallel_core/training_graph/transformer/moe/utils.py (+93 -0)

@@ -0,0 +1,93 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" utils """

import hashlib
from mindspore import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.communication.management import get_rank, get_group_size, create_group

GROUP_NAME = {}


def get_group(rank_list):
"""check whether a group has been created."""
rank_list_str = "-".join([str(i) for i in rank_list])
if rank_list_str in GROUP_NAME:
return GROUP_NAME[rank_list_str]

hashed = hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]
group_name = str(hashed)
create_group(group_name, rank_list)
GROUP_NAME[rank_list_str] = group_name
return group_name


def get_dp_mod_ep_group_name(data_parallel_size, expert_model_parallel_size):
"""Create MoE data parallel group across DP."""
rank_id = get_rank()
world_size = get_group_size()
dp_group_id = rank_id // data_parallel_size

start_rank = dp_group_id * data_parallel_size
end_rank = min(start_rank + data_parallel_size, world_size)

rank_list = []
ep_group_id_in_dp = (rank_id % data_parallel_size) % expert_model_parallel_size
for r in range(start_rank, end_rank):
if r % expert_model_parallel_size == ep_group_id_in_dp:
rank_list.append(r)
return get_group(rank_list)


def get_ep_group_name(rank_id, expert_model_parallel_size):
"""Get expert model parallel group."""
if _get_parallel_mode() == ParallelMode.STAND_ALONE:
return None

rank_start = rank_id // expert_model_parallel_size * expert_model_parallel_size
rand_end = rank_id // expert_model_parallel_size * expert_model_parallel_size + expert_model_parallel_size
rank_list = list(range(rank_start, rand_end))
return get_group(rank_list)


def get_oep_group_name(rank_id, expert_model_parallel_size, npu_nums_per_device):
"""
Generates a unique group name for a set of ranks involved in outer expert partitioning (oep)
and creates a communication group with this name.
This function calculates a range of ranks based on the current rank id
and the expert partition size, hashes this range to create a unique
identifier, and then establishes a new communication group using this identifier.
"""
rank_start = rank_id // expert_model_parallel_size * expert_model_parallel_size
rank_start = rank_start + rank_id % npu_nums_per_device
rand_end = rank_start + expert_model_parallel_size
rank_list = list(range(rank_start, rand_end, npu_nums_per_device))
return get_group(rank_list)


def get_iep_group_name(rank_id, npu_nums_per_device):
"""
Generates a unique group name for a set of ranks involved in inner expert partitioning (iep)
and creates a communication group with this name.
This function calculates a range of ranks based on the current rank id
and the expert partition size, hashes this range to create a unique
identifier, and then establishes a new communication group using this identifier.
"""
rank_start = rank_id // npu_nums_per_device * npu_nums_per_device
rand_end = rank_start + npu_nums_per_device
rank_list = list(range(rank_start, rand_end))
return get_group(rank_list)
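
For reference, the naming scheme above can be reproduced offline: a group's name is the first 48 hex
characters of the SHA-256 digest of the dash-joined rank list. A minimal sketch (pure Python; the
create_group call is deliberately omitted so it runs without a communication backend) of how the
ep/oep/iep rank lists and names come out:

import hashlib

def group_name_for(rank_list):
    # Same naming rule as get_group(): SHA-256 of "r0-r1-..." truncated to 48 hex chars.
    rank_list_str = "-".join(str(r) for r in rank_list)
    return hashlib.sha256(rank_list_str.encode()).hexdigest()[:48]

def ep_ranks(rank_id, ep):
    # Expert-parallel group: the contiguous block of `ep` ranks containing rank_id.
    start = rank_id // ep * ep
    return list(range(start, start + ep))

def oep_ranks(rank_id, ep, iep):
    # Outer expert-parallel group: stride `iep` within the ep block, offset by rank_id % iep.
    start = rank_id // ep * ep + rank_id % iep
    return list(range(start, start + ep, iep))

def iep_ranks(rank_id, iep):
    # Inner expert-parallel group: the contiguous block of `iep` ranks containing rank_id.
    start = rank_id // iep * iep
    return list(range(start, start + iep))

# Example: ep=8, iep=2 -> rank 3 joins oep group [1, 3, 5, 7] and iep group [2, 3].
print(oep_ranks(3, 8, 2), group_name_for(oep_ranks(3, 8, 2)))
print(iep_ranks(3, 2), group_name_for(iep_ranks(3, 2)))
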

mindformers/parallel_core/training_graph/transformer/multi_latent_attention.py (+10 -13)

@@ -113,10 +113,9 @@ class MultiLatentAttention(nn.Cell):
self.cp_co = self.cp // self.cp_ds

if self.num_attention_heads % (self.tp * self.cp_ds) != 0:
raise ValueError("For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
"'tensor_parallel * ulysses_cp_num', but got num_heads is {}, tensor_parallel is {}, "
"ulysses_cp_num is {}."
.format(self.num_attention_heads, self.tp, self.cp_ds))
raise ValueError(f"For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
f"'tensor_parallel * ulysses_cp_num', but got num_heads is {self.num_attention_heads}, "
f"tensor_parallel is {self.tp}, ulysses_cp_num is {self.cp_ds}.")

zero_pad_length = self.q_head_dim - self.v_head_dim
if zero_pad_length < 0:
@@ -183,7 +182,7 @@ class MultiLatentAttention(nn.Cell):
cp = self.cp

self.bs_transpose.shard(((dp, cp, tp),))
self.tnd_transpose.shard(((cp, dp, tp, 1),))
self.tnd_transpose.shard((layout("cp", "dp", "tp", "None"),))

def construct(self, x: Tensor, attention_mask=None, rotary_pos_emb=None, rotary_pos_cos=None,
rotary_pos_sin=None, prefix_keys_values=None, pad_zeros=None, actual_seq_len=None):
@@ -422,12 +421,11 @@ class MLASelfAttentionConcatenated(MultiLatentAttention):

def shard_self_attn(self):
"""sharding for MLASelfAttentionConcatenated with semi_auto_parallel"""
dp = self.config.data_parallel_size
tp = self.config.tensor_model_parallel_size
cp = self.config.context_parallel_size
self.pe_concat.add_prim_attr("self_define_shard", True)

self.tile_kv.shard((layout("cp", "dp", "None", "None"),))
self.pe_concat.shard(((cp, dp, tp, 1), (cp, dp, tp, 1)))
self.pe_concat.shard(in_strategy=((layout("cp", "dp", "tp", "None"), layout("cp", "dp", "tp", "None")),),
out_strategy=(layout("cp", "dp", "tp", "None"),))
self.split.shard((layout("cp", "dp", "tp", "None"),))
self.split_3d.shard((layout(("cp", "tp"), "dp", "None"),))

@@ -629,12 +627,11 @@ class MLASelfAttention(MultiLatentAttention):

def shard_self_attn(self):
"""sharding for MLASelfAttention with semi_auto_parallel"""
dp = self.config.data_parallel_size
tp = self.config.tensor_model_parallel_size
cp = self.config.context_parallel_size
self.pe_concat.add_prim_attr("self_define_shard", True)

self.tile_kv.shard((layout("cp", "dp", "None", "None"),))
self.pe_concat.shard(((cp, dp, tp, 1), (cp, dp, tp, 1)))
self.pe_concat.shard(in_strategy=((layout("cp", "dp", "tp", "None"), layout("cp", "dp", "tp", "None")),),
out_strategy=(layout("cp", "dp", "tp", "None"),))
self.split.shard((layout("cp", "dp", "tp", "None"),))
self.split_3d.shard((layout(("cp", "tp"), "dp", "None"),))
self.expand_dims.shard((layout("cp", "dp", "None"),))
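
The num_heads check near the top of this file is a plain divisibility constraint: every
(tensor_parallel, ulysses context-parallel) slice must own a whole number of attention heads.
A one-function sketch of the same rule:

def check_heads(num_attention_heads, tensor_parallel, ulysses_cp_num):
    # Each (tp, ulysses-cp) slice must receive a whole number of attention heads.
    if num_attention_heads % (tensor_parallel * ulysses_cp_num) != 0:
        raise ValueError(f"num_heads {num_attention_heads} must be a multiple of "
                         f"tensor_parallel * ulysses_cp_num = {tensor_parallel * ulysses_cp_num}")

check_heads(32, tensor_parallel=4, ulysses_cp_num=2)   # ok: 32 % 8 == 0
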


mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py (+6 -4)

@@ -193,13 +193,12 @@ class MultiTokenPredictionLayer(nn.Cell):
"""Set parallel strategy."""
dp = self.config.data_parallel_size
tp = self.config.tensor_model_parallel_size
cp = self.config.context_parallel_size
self.concat.add_prim_attr("self_define_shard", True)
self.concat.shard(in_strategy=((layout("cp", "dp", "None"), layout("cp", "dp", "None")),),
out_strategy=(layout("cp", "dp", "None"),))
if self.use_seq_parallel and cp == 1:
self.enorm.shard(config, in_strategy=(layout("tp", "dp", "None"), layout("None",)))
self.hnorm.shard(config, in_strategy=(layout("tp", "dp", "None"), layout("None",)))
if self.use_seq_parallel:
self.enorm.shard(config, in_strategy=(layout(("cp", "tp"), "dp", "None"), layout("None",)))
self.hnorm.shard(config, in_strategy=(layout(("cp", "tp"), "dp", "None"), layout("None",)))
self.concat.add_prim_attr("self_define_shard", True)
self.concat.shard(in_strategy=((layout(("cp", "tp"), "dp", "None"), layout(("cp", "tp"), "dp", "None")),),
out_strategy=(layout(("cp", "tp"), "dp", "None"),))
@@ -541,6 +540,9 @@ class MtpSharedVocabParallelEmbedding(VocabParallelEmbedding):
output = self.embedding_morph(input_ids, weight)
return output

def sharded_state_dict(self):
return {}


class MtpSharedLanguageModelEmbedding(LanguageModelEmbedding):
"""Embedding layer used in Multi-Token Prediction module, same to standard LanguageModelEmbedding."""


mindformers/parallel_core/training_graph/transformer/norm.py (+89 -14)

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import MeanExt, Sqrt, Rsqrt, SubExt, AddExt, Mu
from mindspore.common.initializer import initializer
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.device_matrix import layout

@@ -36,7 +37,7 @@ def get_strategy(config: TransformerConfig):


class LayerNorm(nn.Cell):
r"""
"""
Layer norm operation.

Args:
@@ -52,9 +53,10 @@ class LayerNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-5):
super(LayerNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim

self.gamma = Parameter(initializer('ones', dim, self.params_dtype), name="gamma",
parallel_optimizer=False)
@@ -115,9 +117,32 @@ class LayerNorm(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for LayerNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.gamma.name] = ShardedTensor(
key=self.gamma.name,
org_key=self.gamma.name,
dtype=self.gamma.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
sharded_state_dict[self.beta.name] = ShardedTensor(
key=self.beta.name,
org_key=self.beta.name,
dtype=self.beta.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class FusedLayerNorm(nn.Cell):
r"""
"""
Layer norm operation.

Args:
@@ -133,10 +158,10 @@ class FusedLayerNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-5):
super(FusedLayerNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
begin_params_axis=-1,
epsilon=eps)
@@ -170,17 +195,39 @@ class FusedLayerNorm(nn.Cell):
strategy = (cp, dp, 1)

if strategy[-1] != 1:
raise TypeError(
'The last dim in FusedLayerNorm can not equal to 1! Strategy {} not supported!'.format(strategy))
raise TypeError(f'The last dim of the shard strategy in FusedLayerNorm must be 1! Strategy {strategy} not supported!')

self.layer_norm.shard((strategy, (strategy[-1],), (strategy[-1],)))

def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for FusedLayerNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.gamma.name] = ShardedTensor(
key=self.gamma.name,
org_key=self.gamma.name,
dtype=self.gamma.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
sharded_state_dict[self.beta.name] = ShardedTensor(
key=self.beta.name,
org_key=self.beta.name,
dtype=self.beta.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class RMSNorm(nn.Cell):
r"""
"""
A self-defined RMSNorm operation using reduce mean.

Args:
@@ -196,10 +243,10 @@ class RMSNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-6):
super(RMSNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.eps = eps
self.weight = Parameter(initializer('ones', (dim), self.params_dtype))

@@ -249,9 +296,23 @@ class RMSNorm(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for RMSNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class FusedRMSNorm(nn.Cell):
r"""
"""
FusedRMSNorm operation

Args:
@@ -267,10 +328,10 @@ class FusedRMSNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-6):
super(FusedRMSNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.eps = eps
self.weight = Parameter(initializer('ones', (dim), self.params_dtype))

@@ -294,11 +355,25 @@ class FusedRMSNorm(nn.Cell):
if in_strategy:
self.norm.shard(in_strategy)
else:
self.norm.shard((layout("cp", "dp", "None"), layout("None",)))
self.norm.shard((layout("cp", "dp", "None"), layout("None", )))

def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for FusedRMSNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class Norm:
"""


mindformers/parallel_core/training_graph/transformer/utils.py (+26 -5)

@@ -77,7 +77,7 @@ ATTNMASK_FUNC_MAP = {


def get_attn_mask_func(mask_func_type):
r"""
"""
Get attention mask function.

Args:
@@ -159,7 +159,6 @@ class LayerSetting:
self.pp_interleave_num = pp_interleave_num if use_pp_interleave else 1
self.offset = np.array(offset, np.int32)
self._check_inputs()
self.offset = np.broadcast_to(self.offset, (self.pp_interleave_num, self.pp))

self.is_zbv = ms.get_auto_parallel_context("pipeline_scheduler") == "zero_bubble_v"
avg_layer = self.num_layers // (self.pp * self.pp_interleave_num)
@@ -571,17 +570,39 @@ class LayerSetting:

def _check_inputs(self):
"""Check the inputs of offset."""
total_stages = self.pp * self.pp_interleave_num
if total_stages <= 1:
self.offset = np.zeros((1, 1), dtype=np.int32)
return
base_layers = self.num_layers // total_stages
remainder_layers = self.num_layers % total_stages

if self.offset.ndim >= 1 and self.offset.shape[-1] != self.pp:
raise ValueError(f"offset.shape[-1] should equal to `pp` ({self.pp}), "
f"but got ({self.offset.shape[-1]}). `offset`: {self.offset}")
if self.offset.ndim >= 2 and self.offset.shape[-2] != self.pp_interleave_num:
raise ValueError(f"offset.shape[-2] should equal to `pp_interleave_num` ({self.pp_interleave_num}), "
f"but got ({self.offset.shape[-2]}). `offset`: {self.offset}")
if self.offset.sum() != self.num_layers % (self.pp * self.pp_interleave_num):
r = self.num_layers % (self.pp * self.pp_interleave_num)
if self.offset.sum() != remainder_layers:
raise ValueError(f"The sum of `offset` ({self.offset.sum()}) should equal to remainder of `num_layers` "
f"({self.num_layers}) % (pp ({self.pp}) * pp_interleave_num ({self.pp_interleave_num})) "
f"= {r}")
f"= {remainder_layers}")

# Broadcast offset
self.offset = np.broadcast_to(self.offset, (self.pp_interleave_num, self.pp))
actual_layers = self.offset.flatten() + base_layers

# Ensure head and tail stages have non-negative layer counts (0 is allowed)
if actual_layers[0] < 0:
raise ValueError(f"Head stage has negative layers. Offset must be ≥ {-base_layers}.")
if actual_layers[-1] < 0:
raise ValueError(f"Tail stage has negative layers. Offset must be ≥ {-base_layers}.")

# Ensure all middle stages have at least 1 layer
if total_stages > 2:
middle_layers = actual_layers[1:-1]
if np.any(middle_layers < 1):
raise ValueError(f"Some middle stage has fewer than 1 layer. Offset must be ≥ {1 - base_layers}.")

@staticmethod
def set_pattern_recompute(layer, p_list, add_prim_attr=False, set_on=True, info=''):
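
The reworked _check_inputs boils down to simple arithmetic: each of the pp * pp_interleave_num
stages receives num_layers // total_stages layers plus its offset, the offsets must sum to the
remainder, head and tail stages may drop to 0 layers, and every middle stage needs at least 1.
A standalone sketch of that check (the total_stages <= 1 shortcut is skipped for brevity):

import numpy as np

def stage_layers(num_layers, pp, pp_interleave_num, offset):
    total_stages = pp * pp_interleave_num
    base = num_layers // total_stages
    remainder = num_layers % total_stages
    offset = np.broadcast_to(np.array(offset, np.int32), (pp_interleave_num, pp))
    if offset.sum() != remainder:
        raise ValueError(f"offset must sum to {remainder}, got {offset.sum()}")
    actual = offset.flatten() + base
    if actual[0] < 0 or actual[-1] < 0:
        raise ValueError("head/tail stages must keep a non-negative layer count")
    if total_stages > 2 and np.any(actual[1:-1] < 1):
        raise ValueError("every middle stage needs at least 1 layer")
    return actual

# 61 layers, pp=4, no interleave: base=15, remainder=1, give the extra layer to stage 1.
print(stage_layers(61, 4, 1, [0, 1, 0, 0]))   # [15 16 15 15]
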


mindformers/parallel_core/transformer_config.py (+17 -0)

@@ -650,6 +650,12 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
"When using moe_dry_run, moe_token_dispatcher_type must be 'alltoall' or 'alltoall_deredundency'."
)

if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute", "partial_rope"]:
raise ValueError(
f"The current value of position_embedding_type is {self.position_embedding_type},"
" but position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute'."
)

if isinstance(self.rope_scaling, dict):
self.position_embedding_type = (self.rope_scaling.pop("type", None) or
self.rope_scaling.pop("rope_type", None))
@@ -660,6 +666,17 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
setattr(self, k, v)
del self.rope_scaling

if self.position_embedding_type == "none":
self.nope_layer_interval = None

if self.nope_layer_interval is None:
pass
elif not isinstance(self.nope_layer_interval, int):
raise TypeError("nope_layer_interval must be a int, "
f"but got {type(self.nope_layer_interval)}.")
elif self.nope_layer_interval <= 0:
raise ValueError("nope_layer_interval must be larger than 0.")

if self.bias_swiglu_fusion and self.hidden_act != 'swiglu':
raise ValueError(
"When using bias_swiglu_fusion, hidden_act must be swiglu."


mindformers/parallel_core/utils/model_mixin.py (+7 -2)

@@ -586,6 +586,11 @@ class TrainModelMixin:
model = self.check_and_get_model()
return model.get_model_parameters()

def get_max_attention_logit(self):
"""Get max attention logit values from the model."""
model = self.check_and_get_model()
return model.get_max_attention_logit()

def make_model_muon_fns(self):
"""Make model muon functions."""
model = self.check_and_get_model()
@@ -601,10 +606,10 @@ class TrainModelMixin:
model = self.check_and_get_model()
return model.get_tp_dims(parameters)

def get_op_groups_info(self, parameters, op_size, tp_group, op_group):
def get_op_groups_info(self, parameters, op_size):
"""Get operation groups information for parameters."""
model = self.check_and_get_model()
return model.get_op_groups_info(parameters, op_size, tp_group, op_group)
return model.get_op_groups_info(parameters, op_size)

def get_parallel_config_for_muon(self):
"""Get parallel configuration for Muon optimizer."""


mindformers/tools/ckpt_transform/transform_checkpoint.py (+65 -38)

@@ -58,6 +58,7 @@ __all__ = ['TransformCkpt']

class TransformCkpt:
"""Transform src_checkpoint from src_strategy to dst_strategy."""

def __init__(self,
auto_trans_ckpt: bool = False,
rank_id: Optional[int] = None,
@@ -212,23 +213,29 @@ class TransformCkpt:
if self.world_size > 1:
dst_strategy_list = glob(os.path.join(self.dst_strategy_dir, f"*_rank_{self.rank_id}.ckpt"))
if not dst_strategy_list:
raise RuntimeError(f"The `dst_strategy`={self.dst_strategy_dir} \
does not contain strategy file of rank_{self.rank_id}.")
err_msg = (f"The `dst_strategy`={self.dst_strategy_dir} "
f"does not contain strategy file of rank_{self.rank_id}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
if len(dst_strategy_list) > 1:
raise RuntimeError(f"There can only be one strategy file corresponding to rank_{self.rank_id}, \
but multiple strategy files corresponding to rank_{self.rank_id} were found \
in {self.dst_strategy_dir}.")
err_msg = (f"There can only be one strategy file corresponding to rank_{self.rank_id}, "
f"but multiple strategy files corresponding to rank_{self.rank_id} "
f"were found in {self.dst_strategy_dir}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
dst_strategy = dst_strategy_list[0]
else:
dst_strategy = None

if check_in_modelarts():
if not mox.file.exists(self.transformed_checkpoint_dir_obs):
raise ValueError(f"transformed_checkpoint_dir_obs: "
f"{self.transformed_checkpoint_dir_obs} is not found!")
err_msg = f"transformed_checkpoint_dir_obs: {self.transformed_checkpoint_dir_obs} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
if self.world_size > 1 and not mox.file.exists(self.dst_strategy_dir_obs):
raise ValueError(f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!")

err_msg = f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)

# Get final dst_strategy in auto_trans_ckpt mode.
dst_strategy = self.get_dst_strategy(dst_strategy)
@@ -247,13 +254,7 @@ class TransformCkpt:
barrier_world(f"Remake {dst_ckpt_dir} by main rank.")

logger.info("The transformed checkpoint will be saved under %s.", dst_ckpt_dir)
self.transform_ckpt(
src_checkpoint=src_ckpt_dir,
dst_checkpoint_dir=dst_ckpt_dir,
src_strategy=src_strategy,
dst_strategy=dst_strategy,
prefix=prefix
)
self.transform_ckpt(src_ckpt_dir, dst_ckpt_dir, src_strategy, dst_strategy, prefix)

self.clear_cache()
return dst_checkpoint_dir
@@ -267,7 +268,9 @@ class TransformCkpt:
"""Transform ckpt using mindspore.transform_checkpoint"""
self.check_src_checkpoint_and_strategy(src_checkpoint, src_strategy)
if src_strategy is None and dst_strategy is None:
raise ValueError("`src_strategy` and `dst_strategy` cannot both be None!")
err_msg = "`src_strategy` and `dst_strategy` cannot both be None!"
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts():
dst_checkpoint_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs,
os.path.basename(dst_checkpoint_dir))
@@ -339,7 +342,7 @@ class TransformCkpt:
dst_strategy):
"""transform checkpoints using mindspore.transform_checkpoint_by_rank"""
for current_transform_rank_id in \
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
logger.info(".........Transforming Ckpt For Rank: %d.........", current_transform_rank_id)
src_rank_list = ms.rank_list_for_transform(current_transform_rank_id,
src_strategy,
@@ -349,7 +352,9 @@ class TransformCkpt:
checkpoint_rank_dir = os.path.join(src_checkpoint, f"rank_{src_rank_id}")
checkpoint_file_list = glob(os.path.join(checkpoint_rank_dir, "*.ckpt"))
if not checkpoint_file_list:
raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!")
err_msg = f"The checkpoint of rank_{src_rank_id} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime)
checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1]
save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}")
@@ -371,11 +376,15 @@ class TransformCkpt:
def build_soft_link_of_checkpoint(checkpoint, soft_link_dir):
"""Build softlink of src checkpoint"""
if os.path.isdir(checkpoint) and not check_rank_folders(checkpoint, 0) and \
not check_ckpt_file_exist(checkpoint):
raise ValueError(f"No rank_0 folder or ckpt files are found under {checkpoint}.")
not check_ckpt_file_exist(checkpoint):
err_msg = f"No rank_0 folder or ckpt files are found under {checkpoint}."
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isfile(checkpoint) and not checkpoint.endswith('.ckpt'):
raise ValueError(f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
err_msg = (f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isdir(checkpoint):
if check_rank_folders(checkpoint, 0):
@@ -418,7 +427,9 @@ class TransformCkpt:
return None

if not os.path.exists(strategy_path):
raise ValueError(f'strategy_path: {strategy_path} not found!')
err_msg = f'strategy_path: {strategy_path} not found!'
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isfile(strategy_path):
return strategy_path
@@ -454,8 +465,9 @@ class TransformCkpt:

if not (dst_strategy.endswith(f"_rank_{self.rank_id}.ckpt") and
os.path.exists(dst_strategy)):
raise ValueError(f"dst_strategy: {dst_strategy} is not found!")

err_msg = f"dst_strategy: {dst_strategy} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)

logger.info(".........Collecting strategy.........")
if check_in_modelarts():
@@ -521,16 +533,19 @@ class TransformCkpt:
"""
# Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size.
if transform_process_num < 1:
raise ValueError("transform_process_num should not smaller than 1,"
f"but got {transform_process_num}.")
err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
logger.error(err_msg)
raise ValueError(err_msg)
if transform_process_num > self.world_size:
logger.warning(f"transform_process_num: {transform_process_num} should not "
f"bigger than world_size: {self.world_size}. "
f"transform_process_num is set to {self.world_size}.")
transform_process_num = self.world_size
if self.world_size % transform_process_num != 0:
raise ValueError(f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
err_msg = (f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
logger.error(err_msg)
raise ValueError(err_msg)

if check_in_modelarts() and 1 < transform_process_num < self.node_num:
logger.warning("transform_process_num: %d should not smaller than \
@@ -551,15 +566,15 @@ class TransformCkpt:

return transform_rank_id_list


@staticmethod
def check_src_checkpoint_and_strategy(src_checkpoint, src_strategy):
"""check src checkpoint and strategy"""
check_path(src_checkpoint, "src_checkpoint")
if not os.path.isdir(src_checkpoint) or not glob(os.path.join(src_checkpoint, "rank_*")):
raise ValueError("The load_checkpoint must be a dir and "
"ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt,"
f"but get {src_checkpoint}.")
err_msg = ("The load_checkpoint must be a dir and ckpt should be stored "
f"in the format of load_checkpoint/rank_x/xxx.ckpt, but get {src_checkpoint}.")
logger.error(err_msg)
raise ValueError(err_msg)
# Check rank_dirs is continuous.
# For example, rank_0, rank_1, rank_4 is not continuous because it is missing rank_3
src_checkpoint_rank_dir_list = glob(os.path.join(src_checkpoint, "rank_*"))
@@ -568,7 +583,9 @@ class TransformCkpt:
src_checkpoint_rank_num = len(src_checkpoint_rank_id_list)
for i in range(src_checkpoint_rank_num):
if src_checkpoint_rank_id_list[i] != i:
raise FileNotFoundError(f"The rank_{i} folder was not found under src_checkpoint folder.")
err_msg = f"The rank_{i} folder was not found under src_checkpoint folder."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# A full checkpoint do not require a strategy.
if len(src_checkpoint_rank_id_list) == 1 and src_strategy:
@@ -576,7 +593,9 @@ class TransformCkpt:
src_strategy = None
# Distributed checkpoints must be accompanied by strategy.
if len(src_checkpoint_rank_id_list) > 1 and src_strategy is None:
raise ValueError("`src_strategy` should not be None when `src_checkpoint` is sliced.")
err_msg = "`src_strategy` should not be None when `src_checkpoint` is sliced."
logger.error(err_msg)
raise ValueError(err_msg)

def send_strategy_to_obs(self, strategy):
"""Local rank send strategy file to obs."""
@@ -623,6 +642,7 @@ class TransformCkpt:
last_strategy_num = dst_strategy_num
if dst_strategy_num < self.world_size:
if time.time() - start_time > 7200:
logger.error("Timeout while collecting all strategy!")
raise TimeoutError("Timeout while collecting all strategy!")
time.sleep(5)
else:
@@ -642,7 +662,9 @@ class TransformCkpt:
transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt'))
if transform_failed_txts:
raise ValueError(f"Transform failed, find {transform_failed_txts}.")
err_msg = f"Transform failed, find {transform_failed_txts}."
logger.error(err_msg)
raise ValueError(err_msg)
current_count = len(transform_succeed_txts)
progress = (current_count / self.transform_process_num) * 100
if current_count != last_count:
@@ -653,7 +675,8 @@ class TransformCkpt:
else:
break

if __name__ == '__main__':

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--src_checkpoint',
default="",
@@ -707,3 +730,7 @@ if __name__ == '__main__':
)

print("......Transform finished!......")


if __name__ == '__main__':
main()

mindformers/tools/resume_ckpt.py (+39 -18)

@@ -84,7 +84,9 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt',
if check_in_modelarts():
resume_record_dir = os.path.join(get_remote_save_url(), "resume_record")
if not Validator.is_obs_url(resume_record_dir):
raise ValueError(f"{resume_record_dir} is not a valid obs path.")
err_meg = f"{resume_record_dir} is not a valid obs path."
logger.error(err_meg)
raise ValueError(err_meg)
else:
resume_record_dir = os.path.join(get_output_root_path(), "resume_record")
remake_folder(resume_record_dir, permissions=0o750)
@@ -151,7 +153,9 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list):
health_ckpts.append(item)

if not health_ckpts:
raise ValueError("The training has no healthy checkpoints yet, please start training again.")
err_msg = "The training has no healthy checkpoints yet, please start training again."
logger.error(err_msg)
raise ValueError(err_msg)

if not not_health_ckpts:
not_health_ckpts_set = set(not_health_ckpts)
@@ -167,12 +171,16 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):

if not check_in_modelarts():
if not os.path.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f:
resume_info = [line.strip() for line in f.readlines()]
else:
if not mox.file.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"OBS: Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"OBS: Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with mox.file.File(latest_checkpointed_iteration_txt, 'r') as f:
resume_info = [line.strip() for line in f.readlines()]

@@ -182,7 +190,9 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
return True

if resume_info[0].startswith("Error"):
raise ValueError(f"Get resume-able checkpoint failed, due to {resume_info[0]}")
err_msg = f"Get resume-able checkpoint failed, due to {resume_info[0]}"
logger.error(err_msg)
raise ValueError(err_msg)

resume_ckpt = replace_rank_id_in_ckpt_name(resume_info[-1], rank_id)
logger.info("Get resume checkpoint: %s", resume_ckpt)
@@ -242,7 +252,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_prefix_tmp = ckpt_prefix.replace(f"rank_{original_rank}", f"rank_{rank_id_tmp}")
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
if not os.path.exists(checkpoint_rank_dir):
raise FileNotFoundError(f"{checkpoint_rank_dir} is not found!")
err_msg = f"{checkpoint_rank_dir} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
for ckpt_file in os.listdir(checkpoint_rank_dir):
health_ckpt_match = (ckpt_file.startswith(ckpt_prefix_tmp[:ckpt_prefix_tmp.rfind("_")])
and use_checkpoint_health_monitor)
@@ -262,10 +274,14 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id)
resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file)
if not os.path.exists(resume_ckpt):
raise FileNotFoundError(f"{resume_ckpt} is not found!")
err_msg = f"{resume_ckpt} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_ckpt_list.append(resume_ckpt)
if not resume_ckpt_list:
raise RuntimeError("No checkpoint could be resumed.")
err_msg = "No checkpoint could be resumed."
logger.error(err_msg)
raise RuntimeError(err_msg)

if use_checkpoint_health_monitor:
resume_ckpt_list.sort(key=lambda x: get_times_epoch_and_step_from_ckpt_name(x, ckpt_format))
@@ -325,8 +341,9 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not last_checkpoint:
raise ValueError(f"Checkpoint not found under {checkpoint_rank_dir} "
f"with config.load_ckpt_format:{ckpt_format}.")
err_msg = f"Checkpoint not found under {checkpoint_rank_dir} with config.load_ckpt_format:{ckpt_format}."
logger.error(err_msg)
raise ValueError(err_msg)
if check_ckpt_file_name(last_checkpoint, ckpt_format):
compared_checkpoint_name = replace_rank_id_in_ckpt_name(last_checkpoint, 0)
compared_original_checkpoint_name = os.path.basename(last_checkpoint)
@@ -353,12 +370,16 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
compared_checkpoint_name = current_checkpoint_name
compared_original_checkpoint_name = original_checkpoint_name
elif compared_checkpoint_name != current_checkpoint_name:
raise ValueError(f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
err_msg = (f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
logger.error(err_msg)
raise ValueError(err_msg)
if find_diff_ckpt:
raise ValueError(f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
err_msg = (f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
logger.error(err_msg)
raise ValueError(err_msg)

mindformers/tools/utils.py (+43 -5)

@@ -18,6 +18,7 @@ import os
import re
import stat
import tempfile
import time
from multiprocessing import Process
from typing import Dict, List, Tuple, Union
from importlib import import_module
@@ -475,15 +476,52 @@ def is_last_pipeline_stage():
return (rank // device_num_per_stage + 1) == stage_num


def set_safe_mode_for_file_or_dir(path):
def set_safe_mode_for_file_or_dir(path, max_retries=5, retry_interval=2):
"""Set safe file permissions for file or directory with retry mechanism.

Args:
path (str or list): Path(s) to set permissions for.
max_retries (int): Maximum number of retry attempts. Default: 5.
retry_interval (int): Interval between retries in seconds. Default: 2.

Raises:
Exception: If all retry attempts fail.
"""
if isinstance(path, str):
path = [path]

for item in path:
success = True
err_msg = ""
item = Path(item)
if item.is_dir():
item.chmod(DIRECTORY_PERMISSION)
if item.is_file():
item.chmod(FILE_PERMISSION)
for attempt in range(max_retries):
try:
if attempt > 0:
time.sleep(retry_interval)
os.stat(item)

if item.is_dir():
item.chmod(DIRECTORY_PERMISSION)
if item.is_file():
item.chmod(FILE_PERMISSION)
break
except FileNotFoundError as e:
if attempt < max_retries - 1:
continue
success = False
err_msg = e
except PermissionError as e:
if attempt < max_retries - 1:
continue
success = False
err_msg = e
except Exception as e:
if attempt < max_retries - 1:
continue
success = False
err_msg = e
if not success:
raise Exception(err_msg)


def get_epoch_and_step_from_ckpt_name(ckpt_file, ckpt_fmt='ckpt'):
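
A short usage sketch of the retried permission helper set_safe_mode_for_file_or_dir added above,
assuming mindformers is importable and relying only on the new signature
(path, max_retries, retry_interval):

import os
import tempfile
from mindformers.tools.utils import set_safe_mode_for_file_or_dir

tmp_dir = tempfile.mkdtemp()
tmp_file = os.path.join(tmp_dir, "demo.txt")
with open(tmp_file, "w") as f:
    f.write("demo")

# Accepts a single path or a list of paths; retries up to max_retries times, sleeping
# retry_interval seconds between attempts, and re-raises the last error if all attempts fail.
set_safe_mode_for_file_or_dir([tmp_dir, tmp_file], max_retries=3, retry_interval=1)
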


mindformers/trainer/base_trainer.py (+16 -15)

@@ -15,7 +15,7 @@
"""Base Trainer."""
import os
import re
import subprocess
import socket
from pprint import pprint
from functools import partial
from typing import Optional, Union, List
@@ -108,12 +108,8 @@ class BaseTrainer:

def __init__(self, task: str = None, model_name: str = None):

host_name_output = subprocess.run(['hostname'], shell=False, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', check=True)
host_ip_output = subprocess.run(['hostname', '-I'], shell=False, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', check=True)
host_name = host_name_output.stdout.strip()
host_ip = host_ip_output.stdout.strip().split(' ')[0]
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
logger.info(f"host_name: {host_name}, host_ip: {host_ip}")

if model_name is None:
@@ -1133,7 +1129,9 @@ class BaseTrainer:
logger.info(".............Start load resume context from common.json..................")
common_file = os.path.join(config.load_checkpoint, 'common.json')
if not os.path.exists(common_file):
raise FileNotFoundError(f"No common.json found in directory '{config.load_checkpoint}'.")
error_msg = f"No common.json found in directory '{config.load_checkpoint}'."
logger.error(error_msg)
raise FileNotFoundError(error_msg)
common_info = CommonInfo.load_common(common_file)
step_scale = common_info.global_batch_size / config.runner_config.global_batch_size
config.runner_config.initial_step = int(common_info.step_num * step_scale)
@@ -1167,9 +1165,11 @@ class BaseTrainer:
logger.info("..............Start resume checkpoint path from strategy..............")
resume_ckpt_path = self.resume_ckpt_path_with_strategy(config)
if resume_ckpt_path is None:
raise ValueError(f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
err_msg = (f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
logger.error(err_msg)
raise ValueError(err_msg)
config.load_checkpoint = resume_ckpt_path
load_resume_context_from_checkpoint(config, dataset)
resume_dict = {
@@ -1253,8 +1253,9 @@ class BaseTrainer:
if hasattr(network, "get_model_parameters"):
model_params.update(network.get_model_parameters())
else:
raise NotImplementedError(f"The {type(network)} has not implemented the interface: "
f"get_model_parameters.")
err_msg = f"The {type(network)} has not implemented the interface: `get_model_parameters`."
logger.error(err_msg)
raise NotImplementedError(err_msg)

is_moe_model = False
is_mtp_model = False
@@ -1475,7 +1476,7 @@ class BaseTrainer:
)
load_checkpoint(
checkpoint=config.load_checkpoint,
network=model.train_network,
network=network,
optimizer=optimizer,
global_step=global_step,
balanced_load=config.balanced_load
@@ -1483,7 +1484,7 @@ class BaseTrainer:
else:
load_checkpoint(
checkpoint=config.load_checkpoint,
network=model.train_network,
network=network,
balanced_load=config.balanced_load
)
elif (config.load_checkpoint or config.only_save_strategy) and not check_is_reboot_node():


mindformers/trainer/llm_trainer_for_graph_experimental/llm_trainer.py (+424 -67)

@@ -79,11 +79,26 @@ __all__ = ['LLMTrainer']

@MindFormerRegister.register(MindFormerModuleType.TRAINER, legacy=False)
class LLMTrainer:
"""
LLM Model Trainer Class.
This class provides training and inference capabilities for Large Language Models,
handling dataset creation, model building, optimizer setup, and training loop execution.
"""
"""Initialize LLM Trainer instance.

This method initializes all instance variables required for training and inference.
All attributes are set to their default values (None for objects, False for booleans,
empty list for callbacks) and will be configured during the setup and training process.

Instance Variables:
llm_model: The neural network model instance (initialized as None)
train_dataset: The training dataset instance (initialized as None)
callbacks: List of training callbacks for monitoring and control
global_batch_size: Global batch size across all devices in distributed training
dataset_batch_size: Batch size for dataset processing
predict_batch_size: Batch size for prediction/inference
append_restore_info: Additional information for checkpoint restoration
common_restore_info: Common checkpoint restoration information
network_delay_inited: Flag indicating if network parameters use delayed initialization
optimizer_delay_inited: Flag indicating if optimizer parameters use delayed initialization
lr_scheduler: Learning rate scheduler instance
grouped_lr_scheduler: Grouped learning rate scheduler for different parameter groups
"""
def __init__(self) -> None:
self.llm_model = None
self.train_dataset = None
@@ -102,18 +117,43 @@ class LLMTrainer:
"""Initialize and setup configuration for training or inference.

This method sets up the configuration based on whether it's for training or inference mode.
For training, it configures parallel context, batch sizes, and other training-specific settings.
For inference, it validates parallel mode and sets data parallel size.
It performs comprehensive initialization including parallel context setup, batch size calculation,
seed configuration, and various training/inference-specific settings.

For training mode (is_train=True), it performs:
- Sets random seed for reproducibility
- Configures optimizer parallel context
- Sets up pipeline parallel context
- Configures dump local norm parallel context
- Resets and validates gradient accumulation steps
- Computes and sets data parallel size
- Sets dataset strategy parallel context
- Calculates training batch sizes based on global batch size
- Configures model settings for Muon optimizer if applicable
- Validates auto parallel mode for training

For inference mode (is_train=False), it performs:
- Sets inference seed
- Validates parallel mode for prediction
- Sets data parallel size

Args:
config (MindFormerConfig): Configuration object containing all training/inference settings.
Must include: training_args, distribute_parallel_config, optimizer, model, etc.
is_train (bool): Flag indicating whether setup is for training (True) or inference (False).
Defaults to True.

Raises:
ValueError: If config is None.
ValueError: If config is None or contains invalid values.
TypeError: If config is not an instance of MindFormerConfig.
RuntimeError: If parallel mode is not supported for the specified mode.

Side Effects:
- Sets self.config to the provided config object
- Modifies config.model structure (via _set_model_config_adapter_old_format)
- Updates various config attributes based on computed values
- Sets output directory path
- Logs host information
"""
if config is None:
raise ValueError("Configuration must be provided, but received None.")
@@ -160,12 +200,29 @@ class LLMTrainer:

This method configures the pipeline parallel settings for the model training process.
It sets the pipeline stages and pipeline configuration parameters such as interleave
and scheduler type when in auto parallel mode.
and scheduler type when in auto parallel mode. Pipeline parallelism splits the model
across multiple devices, with each device processing different layers sequentially.

The pipeline configuration includes:
- pipeline_model_parallel_size: Number of pipeline stages
- pipeline_interleave: Whether to enable pipeline interleave
- pipeline_scheduler: Scheduler type, default is "1f1b"
- pipeline_model_parallel_size: Number of pipeline stages (how many devices the model
is split across). Each stage processes a subset of model layers.
- pipeline_interleave: Whether to enable pipeline interleave optimization, which
improves pipeline efficiency by overlapping computation and communication.
- pipeline_scheduler: Scheduler type for pipeline execution. Default is "1f1b" (1 forward,
1 backward), which alternates forward and backward passes. Other options may include
"gpipe" "zero_bubble_v" or custom schedulers.

Note:
- This method only takes effect when auto parallel mode is valid and
distribute_parallel_config is provided
- Pipeline parallelism requires careful configuration of micro_batch_num to ensure
it's >= pipeline_model_parallel_size
- The configuration is applied to MindSpore's auto parallel context

Side Effects:
- Calls ms.set_auto_parallel_context() to configure pipeline stages
- Sets pipeline_config with interleave and scheduler settings
- Logs pipeline parallel configuration information
"""
# New process uses distribute_parallel_config to set PP-related parallel configuration
distribute_parallel_config = self.config.distribute_parallel_config
@@ -188,13 +245,32 @@ class LLMTrainer:
"""Set optimizer parallel context based on distributed parallel configuration.

This method configures the optimizer parallel settings for model training.
It enables parallel optimizer and sets optimizer level and weight shard size
when the distribute parallel configuration is provided and parallel optimizer is enabled.
Optimizer parallelism distributes optimizer states (e.g., momentum, variance) across
devices to reduce memory usage per device, enabling training of larger models.

The optimizer parallel configuration includes:
- enable_parallel_optimizer: Whether to enable parallel optimizer
- optimizer_level: Optimizer level, default is "level1"
- optimizer_weight_shard_size: Weight shard size, default is -1
- enable_parallel_optimizer: Whether to enable parallel optimizer. When enabled,
optimizer states are sharded across devices, reducing memory footprint.
- optimizer_level: Optimizer parallelization level. "level1" shards optimizer states
across data parallel dimension, "level2" may include additional optimizations.
Default is "level1".
- optimizer_weight_shard_size: Weight shard size for optimizer parallelism. -1 means
automatic calculation based on available devices. Positive values specify explicit
shard size.
- parallel_optimizer_threshold: Minimum parameter size threshold (in MB) for applying
optimizer parallelism. Parameters smaller than this threshold won't be sharded.
Default is 64 MB.

Note:
- This method only takes effect when distribute_parallel_config is provided and
enable_parallel_optimizer is True
- Optimizer parallelism is particularly useful for large models with limited
per-device memory
- The configuration is applied to MindSpore's auto parallel context

Side Effects:
- Calls ms.set_auto_parallel_context() to configure optimizer parallelism
- Logs optimizer parallel configuration information
"""
# New process uses distribute_parallel_config to set optimizer-related parallel configuration
distribute_parallel_config = self.config.distribute_parallel_config
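
As an illustration of the knobs described in the two docstrings above, a plausible
distribute_parallel_config could look as follows (key names follow the docstrings; the exact
layout of a real YAML config may differ):

# Illustrative values only, mirroring the fields documented above.
distribute_parallel_config = {
    "pipeline_model_parallel_size": 2,   # model split into 2 pipeline stages
    "pipeline_interleave": False,        # no interleaved pipeline schedule
    "pipeline_scheduler": "1f1b",        # alternate one forward / one backward micro-batch
    "enable_parallel_optimizer": True,   # shard optimizer states across the DP dimension
    "optimizer_level": "level1",
    "optimizer_weight_shard_size": -1,   # -1 lets the framework pick the shard size
    "parallel_optimizer_threshold": 64,  # parameters below 64 MB stay unsharded
}
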
@@ -292,17 +368,44 @@ class LLMTrainer:
"""Check runner config and set training step parameters.

This method calculates and configures the training steps based on the dataset size,
epochs, and sink mode settings. It adjusts the number of epochs when sink mode
is enabled and sink_size is specified. It also sets initial epoch and step values
for training resumption.
epochs, and sink mode settings. It handles sink mode optimization, which improves
training efficiency by processing data in batches within the computational graph.
It also sets initial epoch and step values for training resumption from checkpoints.

Args:
dataset (GeneratorDataset): Training dataset used to determine data size for
calculations. The dataset must have a get_dataset_size() method that returns
the total number of samples.
The method performs the following operations:
1. Gets the training dataset size
2. Sets original epochs value
3. Initializes gradient accumulation steps if not set
4. Sets initial epoch and step to 0 if not specified
5. Adjusts epochs calculation based on sink mode and sink size
6. Updates configuration with dataset size and training parameters
1. Gets the training dataset size via _get_train_dataset_size()
2. Stores original epochs value in origin_epochs for reference
3. Initializes gradient_accumulation_steps to 1 if not set
4. Sets initial_epoch and initial_step to 0 if not specified (for fresh training)
5. Adjusts epochs calculation based on sink mode:
- If sink_mode is True and sink_size is specified (> 0):
* Validates sink_size is positive
* Warns if dataset size < sink_size
* Calculates adjusted epochs: epochs = (data_size / sink_size) * original_epochs
* Sets sink_size to data_size if sink_size is -1
- If sink_mode is False: Sets sink_size to -1 (disabled)
6. Updates config.data_size with the dataset size
7. Logs training configuration information

Raises:
ValueError: If sink_size is set but is <= 0 (and not -1) when sink_mode is True.
RuntimeError: If train_dataset is not set (via _get_train_dataset_size).

Note:
- Sink mode improves training efficiency by reducing Python overhead
- When resuming training, initial_epoch and initial_step should be set before
calling this method
- The adjusted epochs calculation ensures consistent training duration regardless
of sink_size configuration

Side Effects:
- Modifies self.config.training_args with computed values
- Sets self.config.data_size
- Logs configuration information
"""
data_size = self._get_train_dataset_size()
new_epochs = self.config.training_args.epochs
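
The sink-mode adjustment described in the docstring is plain arithmetic; a simplified standalone
sketch (resume bookkeeping and logging omitted):

def adjust_epochs(epochs, data_size, sink_mode, sink_size):
    # In sink mode one outer "epoch" of the training loop consumes sink_size steps, so the epoch
    # count is rescaled to keep the total number of training steps unchanged.
    if not sink_mode:
        return epochs, -1                 # sink disabled
    if sink_size == -1:
        sink_size = data_size             # sink a full dataset pass at a time
    if sink_size <= 0:
        raise ValueError(f"sink_size must be positive or -1, got {sink_size}")
    return (data_size * epochs) // sink_size, sink_size

print(adjust_epochs(epochs=2, data_size=1000, sink_mode=True, sink_size=100))   # (20, 100)
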
@@ -413,6 +516,16 @@ class LLMTrainer:

This method computes the appropriate data parallel size based on the available
devices and other parallel configuration settings, then updates the configuration.

Data parallelism splits the training data across multiple devices, enabling
parallel processing of different data batches.

The method delegates the actual calculation to _compute_data_parallel_size()
and stores the result in the distribute_parallel_config.

Side Effects:
- Updates self.config.distribute_parallel_config.data_parallel_size
with the computed value
"""
self.config.distribute_parallel_config.data_parallel_size = self._compute_data_parallel_size()

@@ -420,14 +533,46 @@ class LLMTrainer:
"""Compute the data parallel size based on distributed configuration.

This method calculates the appropriate data parallel size based on the available
devices and other parallel configuration settings. It ensures that the parallel
configuration is valid and compatible with the device setup.
devices and other parallel configuration settings. Data parallelism splits the
training data across multiple devices, with each device processing a different
subset of the data.

Calculation logic:
1. If auto parallel mode is not valid, returns 1 (no data parallelism)
2. If data_parallel_size is explicitly set in config, returns that value
3. Otherwise, calculates: dp = device_num / (tp * pp * cp)
where:
- device_num: Total number of available devices
- tp: tensor_model_parallel_size (model parallelism across tensor dimensions)
- pp: pipeline_model_parallel_size (model parallelism across layers)
- cp: context_parallel_size (sequence parallelism)
4. For prediction mode with batch_size=1, forces dp=1
5. Validates expert_parallel_size constraints if MoE is used

Args:
self: Trainer instance with configured distribute_parallel_config

Returns:
int: The computed data parallel size.
int: The computed data parallel size. Represents how many data parallel groups
will be created. Each group processes a different shard of the training data.

Raises:
ValueError: If the parallel configuration is invalid or incompatible.
ValueError:
- If device_num is not divisible by (tp * pp * cp)
- If expert_parallel_size > data_parallel_size * tensor_parallel_size * context_parallel_size
- If (dp * tp * cp) is not divisible by expert_parallel_size when MoE is used

Note:
- Data parallel size must satisfy: device_num = dp * tp * pp * cp
- Expert parallelism (for MoE models) has additional constraints:
ep <= dp * tp * cp and (dp * tp * cp) % ep == 0
- In prediction mode with batch_size=1, data parallelism is disabled (dp=1)
- The method ensures all parallel dimensions are compatible

Example:
With 8 devices, tp=2, pp=2, cp=1:
dp = 8 / (2 * 2 * 1) = 2
This means 2 data parallel groups, each with 4 devices (2 tp * 2 pp)
"""
if not self._check_auto_parallel_mode_valid():
return 1
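
A worked version of the formula in the docstring above, dp = device_num / (tp * pp * cp) together
with the MoE expert-parallel constraints, as a hedged standalone sketch:

def compute_dp(device_num, tp, pp, cp, ep=1):
    model_parallel = tp * pp * cp
    if device_num % model_parallel != 0:
        raise ValueError(f"device_num {device_num} must be divisible by tp*pp*cp = {model_parallel}")
    dp = device_num // model_parallel
    # MoE constraint: experts are sharded over the dp * tp * cp dimension.
    if ep > dp * tp * cp or (dp * tp * cp) % ep != 0:
        raise ValueError(f"expert_parallel_size {ep} is incompatible with dp*tp*cp = {dp * tp * cp}")
    return dp

print(compute_dp(device_num=8, tp=2, pp=2, cp=1))          # 2
print(compute_dp(device_num=16, tp=4, pp=1, cp=1, ep=4))   # 4
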
@@ -467,18 +612,51 @@ class LLMTrainer:
"""Compute training batch size according to Global Batch Size (GBS).

This method calculates the appropriate dataset batch size and global batch size based on
user-specified GBS and micro_batch_size. It handles different modes including:
1. Semi-auto/automatic parallel mode with various configurations
2. Data parallel/standalone mode

The method validates basic constraints and computes gradient accumulation steps or
micro batch numbers based on the parallel configuration.
user-specified GBS and micro_batch_size. The Global Batch Size (GBS) represents the
effective batch size across all devices and micro-batches, which is crucial for maintaining
consistent training dynamics in distributed settings.

The method handles different training modes:
1. Semi-auto/automatic parallel mode:
- Validates that GBS is divisible by (dp * micro_batch_size * micro_batch_interleave_num)
- Computes num_micro_batches = GBS / (dp * micro_batch_size * micro_batch_interleave_num)
- Sets gradient_accumulation_steps or micro_batch_num based on training configuration
- Validates pipeline parallel constraints (micro_batch_num >= pipeline_stages)
- Calculates train_data_batch_size = GBS / data_parallel_size

2. Data parallel/standalone mode:
- Calculates train_data_batch_size = GBS / device_num
- Resets distribute_parallel_config to defaults

Formula:
GBS = data_parallel_size * micro_batch_size * micro_batch_interleave_num * num_micro_batches
train_data_batch_size = GBS / data_parallel_size (in auto parallel mode)
train_data_batch_size = GBS / device_num (in standalone/data parallel mode)

Returns:
tuple: A tuple containing (train_data_batch_size, global_batch_size)
tuple[int, int]: A tuple containing:
- train_data_batch_size (int): Batch size for each device's dataset processing
- global_batch_size (int): Global batch size across all devices (same as input GBS)

Raises:
ValueError: If batch sizes are invalid or incompatible with parallel configuration.
ValueError:
- If global_batch_size or micro_batch_size <= 0
- If GBS is not divisible by (dp * micro_batch_size * micro_batch_interleave_num)
- If gradient_accumulation_steps > 1 and pipeline_parallel_size > 1 simultaneously
- If micro_batch_num < pipeline_model_parallel_size in pipeline parallel mode

Note:
- Gradient accumulation and pipeline parallel cannot be used simultaneously
- In pipeline parallel mode, micro_batch_num must be >= pipeline_model_parallel_size
- The method modifies config values (gradient_accumulation_steps, micro_batch_num)
based on computed values
- For standalone mode, distribute_parallel_config is reset to defaults

Side Effects:
- Modifies self.config.training_args.gradient_accumulation_steps (if applicable)
- Modifies self.config.distribute_parallel_config.micro_batch_num (if applicable)
- May reset distribute_parallel_config in standalone mode
- Logs batch size calculation information
"""
# Validate basic constraints
global_batch_size = self.config.training_args.global_batch_size
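The relationship between GBS, micro-batches, and the per-data-parallel-group batch size can be checked with a small standalone helper. `split_global_batch` below is a sketch of the formula in the docstring, not the trainer's actual code.

```python
def split_global_batch(gbs, dp, micro_batch_size, interleave_num=1, pipeline_stages=1):
    """GBS = dp * micro_batch_size * interleave_num * num_micro_batches."""
    if gbs <= 0 or micro_batch_size <= 0:
        raise ValueError("global_batch_size and micro_batch_size must be positive")
    denom = dp * micro_batch_size * interleave_num
    if gbs % denom != 0:
        raise ValueError(f"GBS={gbs} must be divisible by dp*mbs*interleave={denom}")
    num_micro_batches = gbs // denom
    if pipeline_stages > 1 and num_micro_batches < pipeline_stages:
        raise ValueError("micro_batch_num must be >= pipeline stages")
    return gbs // dp, num_micro_batches  # (train_data_batch_size, num_micro_batches)


print(split_global_batch(64, dp=4, micro_batch_size=2))  # -> (16, 8)
```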
@@ -554,12 +732,25 @@ class LLMTrainer:
This method resets the `distribute_parallel_config` to its default configuration
and explicitly sets the `data_parallel_size` to 1. This is typically used when
falling back to a standalone or data parallel mode where complex parallel strategies
(tensor parallelism, pipeline parallelism, etc.) are not needed or supported.

The method performs the following operations:
1. Updates the distribute_parallel_config with default values from TrainingParallelConfig
This resets all parallel dimensions (tensor_parallel_size, pipeline_parallel_size,
context_parallel_size, expert_parallel_size) to their defaults (typically 1)
2. Explicitly sets the data_parallel_size to 1, indicating no data parallelism
3. Logs the configuration change for debugging purposes

Note:
- This method is typically called when not in auto parallel mode
- After resetting, the configuration represents a single-device or simple
data parallel setup
- All model parallelism features are disabled after this reset

Side Effects:
- Modifies self.config.distribute_parallel_config with default values
- Sets data_parallel_size to 1
- Logs configuration reset information
"""
self.config.distribute_parallel_config.update(TrainingParallelConfig().default_value())
self.config.distribute_parallel_config.data_parallel_size = 1
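For reference, the reset described above leaves the parallel configuration in roughly the shape sketched below; the field names follow the docstring and may not match TrainingParallelConfig exactly.

```python
# Approximate post-reset state (illustrative field names only):
standalone_parallel_config = {
    "data_parallel_size": 1,          # set explicitly by this method
    "tensor_model_parallel_size": 1,
    "pipeline_model_parallel_size": 1,
    "context_parallel_size": 1,
    "expert_parallel_size": 1,
    "micro_batch_num": 1,
}
```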
@@ -624,12 +815,51 @@ class LLMTrainer:
self.config.model.model_config.disable_lazy_inline = True

def _set_construct_args_key(self, column_names: List[str] = None) -> None:
"""Set the construct arguments key for dataset processing.

This method configures the column names that will be used for dataset construction.
The construct_args_key specifies which columns from the dataset should be passed
to the model's forward function. This is essential for proper data flow during
training and inference.

Args:
column_names (List[str], optional): List of column names to be used for dataset
construction. These names should match the keys in the dataset output.
If None, or if construct_args_key is already set in the config, no action is taken.

Note:
- This method only sets the construct_args_key if it's not already configured
in the config and column_names is provided
- The column names are typically extracted from the dataset's input columns
during dataset creation
- This configuration affects how data is passed to the model during training
"""
if self.config.train_dataset.construct_args_key is None and column_names is not None:
self.config.train_dataset.construct_args_key = column_names
logger.info("The config of train_dataset.construct_args_key has been set to %s.",
self.config.train_dataset.construct_args_key)
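The guard above only fills the key when it is still unset; a minimal standalone sketch of that behaviour, using a plain dict in place of the real config object:

```python
def set_construct_args_key(dataset_config: dict, column_names=None) -> dict:
    """Only fill construct_args_key when it is unset and column names are given."""
    if dataset_config.get("construct_args_key") is None and column_names is not None:
        dataset_config["construct_args_key"] = list(column_names)
    return dataset_config


print(set_construct_args_key({}, ["input_ids", "labels"]))                  # filled
print(set_construct_args_key({"construct_args_key": ["input_ids"]}, None))  # unchanged
```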

def _set_model_config_for_muon_optimizer(self) -> None:
"""Configure model settings for Muon optimizer.

This method automatically enables the maximum attention logit monitoring feature
when Muon optimizer is detected. The Muon optimizer requires monitoring of maximum
attention logits during training to track attention patterns and optimize training
stability. When enabled, the model will record the maximum logit values from attention
mechanisms, which can be used for debugging, monitoring, and optimization purposes.

The configuration is set automatically during training setup and only takes effect
when the optimizer type is explicitly set to "Muon". This ensures that the necessary
monitoring infrastructure is in place before training begins.

Note:
This method should be called during the training configuration phase, typically
as part of the initial setup process before model construction.

Side Effects:
- Modifies `self.config.model.model_config.monitor_max_attention_logit` to True
- Logs an informational message when Muon optimizer is detected
"""
# Enable max attention logits for Muon optimizer
if self.config.optimizer.type == "Muon":
self.config.model.model_config.monitor_max_attention_logit = True
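For orientation, the monitored quantity is the largest pre-softmax attention score seen in a step. The snippet below only illustrates that quantity on random data; it is not how the model records it.

```python
import numpy as np

scores = np.random.randn(2, 4, 16, 16).astype(np.float32)  # [batch, heads, query, key]
max_attention_logit = float(scores.max())                   # the value the monitor would log
print(max_attention_logit)
```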
@@ -853,24 +1083,50 @@ class LLMTrainer:
def _wrap_network_with_tool_cells(self, network: nn.Cell) -> nn.Cell:
"""Wrap the network with tool cells for training process.

This method wraps the network with various tool cells based on the training configuration.
These wrappers enable advanced training features like gradient accumulation, pipeline
parallelism, and data parallelism optimizations. The wrappers are applied in a specific
order to ensure proper functionality.

Wrapping order (from innermost to outermost):
1. MicroBatchInterleaved: Enables double copy parallel feature for improved memory
efficiency when micro_batch_interleave_num > 1
2. GradAccumulationCellWithMultiOutputs: Enables gradient accumulation when
gradient_accumulation_steps > 1 and not using pipeline parallel
3. PipelineCellWithMultiOutputs: Enables pipeline parallel training when
pipeline_stages > 1
4. _VirtualDatasetCell: Enables virtual dataset for auto parallel mode in graph mode,
ensuring proper data distribution across devices

Args:
network (nn.Cell): The base network to be wrapped. This should be the core model
without any training wrappers.

Returns:
nn.Cell: The wrapped network with appropriate tool cells applied in the correct
order. The returned network is ready for training with the configured
parallel and optimization features.

The method performs the following operations:
- Applies MicroBatchInterleaved wrapper when micro_batch_interleave_num > 1
- Applies GradAccumulationCellWithMultiOutputs when gradient_accumulation_steps > 1
and not using pipeline parallel (pipeline parallel has its own gradient handling)
- Applies PipelineCellWithMultiOutputs when pipeline stages > 1, which handles
micro-batch scheduling across pipeline stages
- Applies _VirtualDatasetCell for auto parallel mode in graph mode (mode == 0),
which ensures proper data sharding and broadcasting
- Configures dataset broadcast optimization level if applicable

Note:
- Wrappers are applied sequentially, with each wrapper wrapping the previous result
- The order matters: MicroBatchInterleaved is innermost, _VirtualDatasetCell is outermost
- Pipeline parallel and gradient accumulation are mutually exclusive at this level
- Virtual dataset cell is only applied in graph mode (not in PyNative mode)

Side Effects:
- Modifies the network structure by adding wrapper layers
- Logs wrapper application information for each wrapper type
- May configure dataset broadcast optimization attributes
"""
micro_batch_interleave_num = self.config.distribute_parallel_config.micro_batch_interleave_num
gradient_accumulation_steps = self.config.training_args.gradient_accumulation_steps
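The wrapping order can be summarised with placeholder callables. The sketch below only mirrors the control flow described in the docstring, with the real MindSpore/MindFormers wrapper classes swapped for stand-ins supplied by the caller.

```python
def wrap_network(network, wrappers, interleave_num=1, grad_accum_steps=1,
                 pipeline_stages=1, graph_mode=True):
    """Order-only sketch: interleave -> grad accumulation -> pipeline -> virtual dataset."""
    if interleave_num > 1:
        network = wrappers["micro_batch_interleaved"](network, interleave_num)
    if grad_accum_steps > 1 and pipeline_stages <= 1:
        network = wrappers["grad_accumulation"](network, grad_accum_steps)
    if pipeline_stages > 1:
        network = wrappers["pipeline_cell"](network)
    if graph_mode:
        network = wrappers["virtual_dataset"](network)
    return network


identity = {name: (lambda net, *args: net) for name in
            ("micro_batch_interleaved", "grad_accumulation", "pipeline_cell", "virtual_dataset")}
print(wrap_network("core_model", identity, grad_accum_steps=4))  # -> "core_model"
```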
@@ -909,18 +1165,44 @@ class LLMTrainer:
def _init_parameters_data(self, network: nn.Cell, optimizer: Optional[nn.Optimizer] = None) -> None:
"""Initialize network and optimizer parameters data.

This method initializes the parameters data for both network and optimizer when delayed
initialization was used. Delayed initialization is a memory optimization technique where
parameter memory allocation is deferred until after checkpoint loading, allowing for
more efficient memory usage during model construction.

The method handles two scenarios:
1. Network parameter initialization: When network_delay_inited is True, initializes
all network parameters. This is typically needed when parameters were created with
no_init_parameters() context manager.
2. Optimizer parameter initialization: When optimizer_delay_inited is True and an
optimizer is provided, initializes optimizer state parameters (e.g., momentum buffers,
variance estimates for Adam).

Args:
network (nn.Cell): The neural network model whose parameters need to be initialized.
Must have an init_parameters_data() method if network_delay_inited is True.
optimizer (Optional[nn.Optimizer]): The optimizer whose parameters need to be initialized.
Must have an init_parameters_data() method if optimizer_delay_inited is True.
Defaults to None. If None and optimizer_delay_inited is True, optimizer
initialization is skipped.

The method performs the following operations:
- Calls network.init_parameters_data() when network_delay_inited is True, which
allocates memory for all network parameters
- Calls optimizer.init_parameters_data() when optimizer_delay_inited is True and
optimizer is not None, which allocates memory for optimizer state variables

Note:
- Delayed initialization is typically used when loading checkpoints, as it allows
loading weights before allocating parameter memory
- Both network and optimizer can have delayed initialization independently
- If flags are False, no initialization is performed (parameters were already initialized)
- This method should be called after checkpoint loading but before training starts

Side Effects:
- Allocates memory for network parameters if network_delay_inited is True
- Allocates memory for optimizer state if optimizer_delay_inited is True
- Logs initialization status for debugging
"""
if self.network_delay_inited:
logger.info("Initializing network parameters data with delay initialization...")
@@ -1033,22 +1315,66 @@ class LLMTrainer:
"""Create training dataset for LLM training.

This method creates and configures a training dataset based on the configuration settings.
It handles the complete dataset creation pipeline including data loading, preprocessing,
sharding, batching, and special optimizations. The method supports various data loader
types and applies appropriate dataset processing for distributed training.

Dataset Creation Pipeline:
1. Creates LLMDataset instance from train_dataset configuration
2. Configures MindSpore dataset settings (seed, prefetch, NUMA)
3. Generates shard information for data parallel distribution
4. Determines input columns based on attention mask and EOD mask requirements
5. Creates data loader with appropriate column names and sharding
6. Builds final dataset with batching, micro-batch handling, and other options
7. Applies special processing for BlendedMegatronDatasetDataLoader if needed

Returns:
GeneratorDataset: Configured training dataset ready for model training. The dataset
is properly sharded for data parallelism, batched according to configuration,
and includes all necessary preprocessing (padding, masking, etc.).

Raises:
ValueError:
- If dataset broadcast optimization level is incompatible with sink mode
- If broadcast data configuration is invalid for certain data loaders
- If required configuration parameters are missing
RuntimeError: If dataset creation fails due to invalid data paths or formats

The method performs the following operations:
- Validates dataset broadcast optimization level compatibility
- Creates LLMDataset instance with train_dataset configuration
- Sets MindSpore dataset configuration (seed, prefetch_size, numa_enable)
- Generates shard information (shard_id, num_shards) for data parallel distribution
- Determines if compressed EOD mask and attention mask should be created
- Gets default input columns based on mask requirements
- Sets construct_args_key for dataset construction
- Enables EOD attention mask compression if needed
- Creates data loader with column names and sharding configuration
- Calculates micro_batch_num based on pipeline stages or gradient accumulation
- Creates final dataset with:
* Data batch size from training_args
* Drop remainder setting
* Input/output column configuration
* Micro batch number for pipeline parallel or gradient accumulation
* EOD reset configuration
* Dynamic batching settings
* Padding token configuration
* Parallel workers configuration
* Token profiling configuration (if enabled)
- Applies special handling for BlendedMegatronDatasetDataLoader:
* Adjusts dataset size to align with global batch size
* Removes redundant data

Note:
- The dataset is sharded based on data_parallel_size for distributed training
- Micro batch number is determined by pipeline_stages or gradient_accumulation_steps
- EOD mask compression can be enabled for memory efficiency
- BlendedMegatronDatasetDataLoader requires special size adjustment
- The dataset supports token profiling for performance analysis

Side Effects:
- Sets self.config.model.model_config.use_eod_attn_mask_compression if needed
- Sets train_dataset.construct_args_key via _set_construct_args_key()
- Logs dataset creation and configuration information
"""
llm_dataset = LLMDataset(dataset_config=self.config.train_dataset)
dataset_seed = self.config.training_args.dataset_seed or self.config.training_args.training_seed or 1234
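The size alignment applied for BlendedMegatronDatasetDataLoader amounts to dropping the trailing partial global batch; a small illustration of the idea (the helper name is hypothetical, not the real function):

```python
def align_to_global_batch(num_samples: int, global_batch_size: int) -> int:
    """Keep only whole global batches, discarding the redundant tail."""
    return (num_samples // global_batch_size) * global_batch_size


print(align_to_global_batch(1005, global_batch_size=64))  # -> 960
```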
@@ -1830,6 +2156,26 @@ class LLMTrainer:
return load_checkpoint_path_or_dir

def _get_muon_optimizer_kwargs(self) -> dict:
"""Get keyword arguments for Muon optimizer initialization.

This method prepares the required keyword arguments for creating a Muon optimizer
instance. The Muon optimizer requires specific parameters including the model instance
and micro batch number, which is determined based on the training configuration.

Returns:
dict: Dictionary containing keyword arguments for Muon optimizer:
- "model": The LLM model instance
- "micro_batch_num": Number of micro batches per step, determined by:
- pipeline_model_parallel_size > 1: Uses micro_batch_num from config
- Otherwise: Uses gradient_accumulation_steps from training args
- Defaults to 1 if neither is set

Note:
- This method is specifically for Muon optimizer configuration
- The micro_batch_num calculation ensures compatibility with both pipeline
parallel and gradient accumulation training modes
- The model instance must be set before calling this method
"""
micro_batch_num = self.config.distribute_parallel_config.get("micro_batch_num", 1) \
if self._get_pipeline_stages() > 1 \
else self.config.training_args.get("gradient_accumulation_steps", 1)
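The selection rule reads more clearly as a tiny pure function; this is a sketch only, with hypothetical parameter names:

```python
def muon_kwargs(model, pipeline_stages, micro_batch_num=1, grad_accum_steps=1):
    """Under pipeline parallel use the config's micro_batch_num; otherwise use grad accumulation."""
    return {
        "model": model,
        "micro_batch_num": micro_batch_num if pipeline_stages > 1 else grad_accum_steps,
    }


print(muon_kwargs(model=None, pipeline_stages=2, micro_batch_num=4, grad_accum_steps=8))
# -> {'model': None, 'micro_batch_num': 4}
```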
@@ -1876,10 +2222,21 @@ class LLMTrainer:
"""Get the GPT model's Transformer configuration information.

This method retrieves the GPT Transformer configuration information from the LLM model instance.
It is primarily used to obtain specific configuration parameters of the model, such as
embedding dimension size, number of layers, attention heads, MoE configuration, etc.

Returns:
TransformerConfig: The GPT Transformer configuration object containing all model
architecture parameters including:
- Model dimensions (hidden_size, num_layers, num_heads, etc.)
- MoE configuration (if applicable)
- Attention and MLP configurations
- Parallel and optimization settings
- Other transformer-specific parameters

Raises:
AttributeError: If llm_model is not set or doesn't have the get_gpt_transformer_config method.
"""
return self.llm_model.get_gpt_transformer_config()
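Typical use is to read architecture fields off the returned object. The stand-in below only shows the access pattern; the attribute names are assumptions and should be checked against TransformerConfig, and the real object comes from trainer.get_gpt_transformer_config().

```python
from types import SimpleNamespace

# Hypothetical stand-in for the returned configuration object (values are arbitrary).
cfg = SimpleNamespace(hidden_size=4096, num_layers=32, num_attention_heads=32)
print(cfg.hidden_size, cfg.num_layers, cfg.num_attention_heads)
```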



+ 0
- 12
mindformers/trainer/optimizer_grouped_parameters.py View File

@@ -182,18 +182,6 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
# Append parameter to its group
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(param.name)
for param in model.get_parameters():
if "max_logits_val" in param.name:
param.requires_grad = False
if parameter_group_vars.get("max_logits") is None:
parameter_group_names["max_logits"] = {
"params": [],
}
parameter_group_vars["max_logits"] = {
"params": [],
}
parameter_group_vars["max_logits"]["params"].append(param)
parameter_group_names["max_logits"]["params"].append(param.name)
param_groups = json.dumps(parameter_group_names, indent=2)
logger.info("Param groups = %s", param_groups)
return list(parameter_group_vars.values())

+ 25
- 5
mindformers/trainer/trainer.py View File

@@ -1356,6 +1356,7 @@ class Trainer:
self._clear_redundant_model_checkpoint_name()
self._warn_if_resume_training_is_string()
self._error_if_checkpoint_prefix_contains_rank_info()
self._check_balanced_load_if_not_use_parallel()

def _adjust_resume_training_if_ckpt_path_invalid(self):
"""Disable resume_training if checkpoint path is empty or invalid."""
@@ -1401,9 +1402,21 @@ class Trainer:
"""Ensure auto_trans_ckpt has shared directory requirements met."""
if self.config.auto_trans_ckpt:
if not is_publicly_accessible_path(get_output_root_path()):
raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True, "
f"the output_dir should be a shared directory that can be accessed by all nodes. "
f"But {os.path.abspath(self.config.output_dir)} is not a shared directory.")
raise ValueError(
f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True, "
f"the output_dir must be a shared directory accessible by all nodes. "
f"However, {os.path.abspath(self.config.output_dir)} is not recognized as a shared path. "
f"\n\nDetails:"
f"\n1. This error occurs because distributed training with multiple nodes requires "
f"the checkpoint output directory to be accessible across all nodes "
f"(e.g., NFS-mounted directory, distributed file system path, or shared storage)."
f"\n2. If you confirm that {os.path.abspath(self.config.output_dir)} IS a shared directory "
f"accessible by all nodes, explicitly mark it as a shared path by setting the following "
f"environment variable to bypass this check:\n "
f"export SHARED_PATHS={os.path.abspath(self.config.output_dir)}"
f"\n3. If the path is NOT a shared directory, update config.output_dir to a "
f"valid shared directory path that all nodes can read from and write to."
)
clear_auto_trans_output(
self.config.load_checkpoint, self.config.src_strategy_path_or_dir, self.config.load_ckpt_format)

@@ -1450,6 +1463,12 @@ class Trainer:
if "rank" in callback.get("prefix", "mindformers"):
raise ValueError("The prefix for saving checkpoint is not allowed to contain 'rank'.")

def _check_balanced_load_if_not_use_parallel(self):
if self.config.balanced_load and not self.config.use_parallel:
logger.warning(f"`balanced_load={self.config.balanced_load}` is not valid "
f"when `use_parallel={self.config.use_parallel}`.")
self.config.balanced_load = False

def _check_args_task_and_model(self):
"""Check args, task and model."""
# get support model names of task
@@ -1568,8 +1587,9 @@ def _reset_config_for_save(config: dict = None):
"""
if config is None:
config = {}
config = copy.deepcopy(config)
reset_parallel_config(config)
else:
config = copy.deepcopy(config)
reset_parallel_config(config)

config_dict = OrderedDict()



+ 37
- 15
mindformers/trainer/utils.py View File

@@ -58,9 +58,9 @@ class BaseEnum(str, Enum):
@classmethod
def _missing_(cls, value):
"""Enum with more explicit error message for missing values."""
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
err_msg = f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
logger.error(err_msg)
raise ValueError(err_msg)


class IntervalStrategy(BaseEnum):
@@ -132,7 +132,9 @@ def preload_ckpt(config):
set_mindio_server_info(mindio_pool_capacity)

if not os.path.realpath(ckpt_path) or not os.path.exists(ckpt_path):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {ckpt_path}")
err_log = f"The load_checkpoint must be correct, but get {ckpt_path}"
logger.error(err_log)
raise FileNotFoundError(err_log)

preload_ok = False
if os.path.isfile(ckpt_path):
@@ -153,7 +155,9 @@ def preload_ckpt(config):
logger.info(f"MindIO preloading `{checkpoint_path}`...")
preload_ok = mindio_preload(checkpoint_path)
else:
raise ValueError(f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False.")
err_msg = f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)

if preload_ok:
logger.info("MindIO preload checkpoint successfully!")
@@ -197,8 +201,10 @@ def check_runner_config(config, dataset):
if config.runner_config.sink_mode:
if config.runner_config.sink_size != -1:
if config.runner_config.sink_size <= 0:
raise ValueError("per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
err_log = (f"per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
logger.error(err_log)
raise ValueError(err_log)
if data_size < config.runner_config.sink_size:
logger.warning("The data size %s (get from dataset.get_dataset_size()) is smaller "
"than the sink_size %s (get from config.runner_config.sink_size), "
@@ -320,7 +326,9 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck
logger.info("Your load_checkpoint is file, it will be load in network.")
distribute_checkpoint_path = checkpoint_dir
else:
raise FileNotFoundError(f"{checkpoint_dir} is not found.")
err_msg = f"{checkpoint_dir} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return distribute_checkpoint_path


@@ -348,7 +356,9 @@ def load_resume_context_from_checkpoint(config, dataset):
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(config.load_checkpoint) or \
not os.path.exists(config.load_checkpoint):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_log = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_log)
raise FileNotFoundError(err_log)

if os.path.isdir(config.load_checkpoint):
# When graceful exit is enabled or auto checkpoint transformation is disabled,
@@ -375,6 +385,10 @@ def load_resume_context_from_checkpoint(config, dataset):
else:
checkpoint_tmp = os.path.join(config.load_checkpoint, f"rank_{rank_id}",
replace_rank_id_in_ckpt_name(config.resume_training, rank_id))
if not os.path.isfile(checkpoint_tmp):
err_msg = f"{checkpoint_tmp} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_dict = load_checkpoint(
checkpoint_tmp,
choice_func=lambda x: x in ["loss_scale", "epoch_num", "step_num", "global_batch_size"],
@@ -512,15 +526,20 @@ def transform_and_load_checkpoint(config, model, network, dataset, optimizer=Non


def check_checkpoint_config_valid(config):
"""Check valid load checkpoint path in config."""
# check valid load checkpoint path
if not config.only_save_strategy and (not os.path.realpath(config.load_checkpoint) or
not os.path.exists(config.load_checkpoint)):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_msg = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# check valid format
if config.load_ckpt_format is not None and config.load_ckpt_format not in CkptFormat.support_type():
raise ValueError(
f"config.load_ckpt_format only support for 'ckpt' or 'safetensors', but got {config.load_ckpt_format}.")
err_msg = ("config.load_ckpt_format only support for 'ckpt' or 'safetensors', "
f"but got {config.load_ckpt_format}.")
logger.error(err_msg)
raise ValueError(err_msg)


def check_path_include_total_ckpt(path):
@@ -570,7 +589,9 @@ def load_slora_ckpt(checkpoint_dict, config, network):
logger.info("............Start load slora checkpoint ............")
adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json")
if not os.path.exists(adapter_path):
raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}")
err_msg = f"The adapter_path must be correct, but get {adapter_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
with open(adapter_path, 'r', encoding='utf-8') as file:
path_dict = json.load(file)
adapter_list = []
@@ -686,8 +707,9 @@ def get_load_checkpoint_result(config):
else:
checkpoint_dict = load_distributed_checkpoint(config.load_checkpoint)
else:
raise ValueError(f"{config.load_checkpoint} is not a valid path to load checkpoint "
f"when auto_trans_ckpt is False.")
err_msg = f"{config.load_checkpoint} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)
return checkpoint_dict if checkpoint_dict else checkpoint_future




+ 61
- 27
mindformers/utils/load_checkpoint_utils.py View File

@@ -65,7 +65,7 @@ def _get_origin_network(network):
"""recursive find if cells which have function <convert_name>"""
if 'convert_name' in dir(network):
return network, True
#DFS for network
# DFS for network
for cell in list(network.cells()):
network, find_cell = _get_origin_network(cell)
if find_cell:
@@ -91,9 +91,13 @@ def get_load_path_after_hf_convert(config, network):
def _check_checkpoint_path(path):
"""check checkpoint path."""
if not isinstance(path, str) or isinstance(path, os.PathLike):
raise ValueError(f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`.")
err_msg = f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`."
logger.error(err_msg)
raise ValueError(err_msg)
if not os.path.exists(path):
raise FileNotFoundError(f"config.load_checkpoint `{path}` does not exist.")
err_msg = f"config.load_checkpoint `{path}` does not exist."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if path[-1] == '/': # remove last '/' in path
return path[:-1]
@@ -113,7 +117,9 @@ def _get_checkpoint_mode(config):

# check path is dir
if not os.path.isdir(checkpoint_path):
raise ValueError("Provided path is neither a file nor a directory.")
err_msg = "Provided path is neither a file nor a directory."
logger.error(err_msg)
raise ValueError(err_msg)

dir_files = os.listdir(checkpoint_path)
if any(folder_name.startswith('rank_') for folder_name in dir_files):
@@ -122,7 +128,9 @@ def _get_checkpoint_mode(config):
if any(file_name.endswith(config.load_ckpt_format) for file_name in dir_files):
return CheckpointFileMode.MULTI_CHECKPOINT_FILE.value

raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def _get_src_strategy(config):
@@ -139,8 +147,10 @@ def _get_src_strategy(config):
src_strategy_path = os.path.join(upper_dir, 'strategy')
logger.info(f"src_strategy_path_or_dir is empty, load source strategy from {src_strategy_path}.")
else:
raise ValueError("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
err_msg = ("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
logger.error(err_msg)
raise ValueError(err_msg)
logger.info(f"load source strategy from {src_strategy_path}.")
return src_strategy_path

@@ -246,7 +256,9 @@ def _get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'):
else:
ckpt_path = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"{ckpt_path} is not found.")
err_msg = f"{ckpt_path} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return ckpt_path


@@ -262,7 +274,9 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval

pet_config = config.model.model_config.get("pet_config")
if pet_config and pet_config.pet_type == "slora" and network.lora_list:
raise ValueError(f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon.")
err_msg = f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon."
logger.error(err_msg)
raise ValueError(err_msg)
ckpt_file_mode = _get_checkpoint_mode(config)
validate_config_with_file_mode(ckpt_file_mode, config.use_parallel, config.auto_trans_ckpt)
# reduce compile time in prediction
@@ -314,7 +328,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
if config.resume_training or (config.get('remove_redundancy', False) and not do_predict):
# pylint: disable=W0212
network = model._train_network
#build model
# build model
if config.use_parallel:
compile_model(
model=model,
@@ -325,7 +339,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
sink_size=config.runner_config.sink_size,
do_eval=do_eval, do_predict=do_predict
)
#wait generate all rank strategy files
# wait generate all rank strategy files
barrier()

# only execute qkv concat check on the main rank in predict mode
@@ -337,7 +351,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
barrier()

process_for_stand_alone_mode(config, network, strategy_path)
#merge dst strategy
# merge dst strategy
strategy_path = get_merged_dst_strategy_path(config, strategy_path)
load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy_path, load_checkpoint, optimizer)

@@ -366,18 +380,26 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt
"""validate use_parallel and auto_trans_ckpt config with different file mode"""
if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value:
if use_parallel:
raise ValueError("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
err_msg = ("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value:
if use_parallel and not auto_trans_ckpt:
raise ValueError("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True to enable automatic slicing function.")
err_msg = ("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True, to enable automatic slicing function.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value:
if not use_parallel:
raise ValueError("when input checkpoint file is rank dir, Please set use_parallel: True to enable "
"distributed ckpt load.")
err_msg = ("when input checkpoint file is rank dir, Please set use_parallel: "
"True, to enable distributed ckpt load.")
logger.error(err_msg)
raise ValueError(err_msg)
else:
raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False,
@@ -422,6 +444,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
logger.info("......obtain name map for HF safetensors.....")
name_map = origin_network.obtain_name_map(load_checkpoint_files)
except Exception as e:
logger.error(f"Please complete abstract function obtain_name_map. Details: {e}")
raise TypeError(f"Please complete abstract function obtain_name_map. Details: {e}") from e
if is_main_rank():
_convert_index_json(load_ckpt_path, load_ckpt_path, origin_network.convert_map_dict, False)
@@ -439,7 +462,9 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
hyper_param_file = os.path.join(load_ckpt_path, 'hyper_param.safetensors')
if optimizer and config.resume_training:
if not os.path.exists(hyper_param_file):
raise FileNotFoundError(rf"No hyper_param.safetensors in given dir: {load_ckpt_path}")
err_msg = rf"No hyper_param.safetensors in given dir: {load_ckpt_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
logger.info("......Start load hyper param into optimizer......")
hyper_param_dict = ms.load_checkpoint(ckpt_file_name=hyper_param_file, format='safetensors')
update_global_step(config, hyper_param_dict)
@@ -457,7 +482,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
format=config.load_ckpt_format
))
if not config.model.model_config.get("qkv_concat", False) \
and is_hf_safetensors_dir(load_ckpt_path, origin_network):
and is_hf_safetensors_dir(load_ckpt_path, origin_network):
logger.info("......HuggingFace weights convert name......")
params_dict = origin_network.convert_weight_dict(params_dict, model_config=config.model.model_config)
if optimizer and config.resume_training:
@@ -505,7 +530,9 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):

# If the child process exits abnormally, the main process should also throw an exception.
if p.exitcode != 0:
raise RuntimeError("convert HuggingFace weight failed.")
err_msg = "convert HuggingFace weight failed."
logger.error(err_msg)
raise RuntimeError(err_msg)

# If used parallel mode, other cards are waiting for the main card to complete the weight conversion.
barrier_world("for the main rank to convert HuggingFace weight...")
@@ -518,11 +545,14 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):
def get_last_checkpoint(checkpoint_dir, ckpt_format='ckpt'):
"""get last checkpoint for resuming or finetune."""
if not os.path.isdir(checkpoint_dir):
raise NotADirectoryError(
err_msg = (
f"{checkpoint_dir} is not a real directory,"
f"When distributed loads are sliced weights,"
f"load_checkpoint should be a checkpoint directory containing the directory of rank_{{0-*}},"
f"The directory structure is as follows: **checkpoint_root_dir/rank_{{0-*}}/**.{ckpt_format}")
logger.error(err_msg)
raise NotADirectoryError(err_msg)

output_checkpoint_path = [
checkpoint
for checkpoint in os.listdir(checkpoint_dir)
@@ -562,11 +592,15 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin
break

if is_qkv_concat and not qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been concatenated,"
" but qkv_concat is set to false.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been concatenated, "
"but qkv_concat is set to false.")
logger.error(err_msg)
raise ValueError(err_msg)
if not is_qkv_concat and qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been not concatenated,"
" but qkv_concat is set to true.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been not concatenated, "
"but qkv_concat is set to true.")
logger.error(err_msg)
raise ValueError(err_msg)
if is_qkv_concat and qkv_concat_config:
logger.info("The qkv concat check succeed! The qkv in the model weights has been concatenated and "
"qkv_concat is set to true.")


+ 3
- 1
mindformers/utils/resume_ckpt_utils.py View File

@@ -57,7 +57,9 @@ def load_resume_checkpoint(load_checkpoint_path, remove_redundancy, load_ckpt_fo
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(load_checkpoint_path) or \
not os.path.exists(load_checkpoint_path):
raise FileNotFoundError(f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}")
err_msg = f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if os.path.isdir(load_checkpoint_path):
hyper_param_file = os.path.join(load_checkpoint_path, 'hyper_param.safetensors')


+ 1
- 1
mindformers/wrapper/wrapper.py View File

@@ -207,7 +207,7 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
**kwargs (Any): Additional parameters.

Inputs:
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
- **\\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \\ldots)`.

Outputs:
Tuple of 5 or 7 Tensor, the loss, overflow flag, current loss scale value, learning rate,


+ 2
- 2
requirements.txt View File

@@ -10,7 +10,7 @@ nltk==3.9.1
mindpet==1.0.4
opencv-python-headless
pyarrow==19.0.0
tokenizers==0.21.0
tokenizers
astunparse>=1.6.3
numpy<2.0.0
tiktoken
@@ -20,5 +20,5 @@ safetensors
tensorboardX
deprecated>=1.2.0
pybind11 # for megatron dataset
transformers==4.51.3
transformers==4.57.1
datasets>=4.0.0

+ 16
- 1
research/deepseek3/README.md View File

@@ -2,7 +2,7 @@

## 模型描述

DeepSeek-V3是由DeepSeek(深度求索)推出的一个强大的专家混合(MoE)语言模型,它拥有671B总参数,其中激活参数量为37B。为了实现高效推理和低成本训练,DeepSeek-V3采用了多头潜注意力(MLA)和DeepSeekMoE架构,这在DeepSeek-V2中得到了充分验证。此外,DeepSeek-V3 还率先采用了无辅助损失的负载均衡策略,并设定了多token预测训练目标,以提高性能。DeepSeek-V3在14.8万亿个多种类的高质量token上进行预训练,接着通过监督微调和强化学习充分优化其能力。综合评估显示,在发布时DeepSeek-V3的性能优于其他开源模型,并可与领先的闭源模型相媲美。尽管性能卓越,DeepSeek-V3 的全部训练成本非常低,且其训练过程也非常稳定。
DeepSeek-V3是由DeepSeek(深度求索)推出的一个强大的专家混合(MoE)语言模型,它拥有671B总参数,其中激活参数量为37B。为了实现高效推理和低成本训练,DeepSeek-V3采用了多头潜注意力(MLA)和DeepSeekMoE架构,这在DeepSeek-V2中得到了充分验证。此外,DeepSeek-V3 还率先采用了无辅助损失的负载均衡策略,并设定了多token预测训练目标,以提高性能。DeepSeek-V3在14.8万亿个多种类的高质量token上进行预训练,接着通过监督微调和强化学习充分优化其能力。综合评估显示,在发布时DeepSeek-V3的性能优于其他开源模型,并可与领先的闭源模型相媲美。尽管性能卓越,DeepSeek-V3 的全部训练成本非常低,且其训练过程也非常稳定。

```text
@misc{deepseekai2024deepseekv3technicalreport,
@@ -87,6 +87,11 @@ python research/deepseek3/fp8_cast_bf16.py \
--output-bf16-hf-path path/to/hf_model_bf16_dir/
```

参数说明:

- input-fp8-hf-path:数据类型为fp8的原始权重文件夹路径。
- output-bf16-hf-path:转换成数据类型为bf16后的权重文件夹路径。

>`path/to/hf_model_bf16_dir/` 可修改为自定义路径,确保该路径有足够的磁盘空间(约 1.4TB)。

## 推理
@@ -135,6 +140,11 @@ python research/deepseek3/convert_weight.py \
- infer:是否进行推理权重的转换,默认值:`False`。
- mindspore_ckpt_path:转换后的MindSpore权重文件夹保存路径
- worker_num:多进程转换的进程数,默认值:`4`。
- use_grouped_gemm:是否使用grouped_gemm,默认值:`False`。
- n_head:模型结构中Attention的头数,默认值:`128`。
- v_head_dim:单个注意力头中,Value向量的维度大小,默认值为:`128`。
- save_format:权重保存的格式,默认值:`safetensors`。
- param_json:权重的参数映射表的JSON文件名,默认值:`model.safetensors.index.json`。

如果使用训练后保存的权重进行推理,需要使用`deepseek3_train2infer.py`脚本将其转换为推理格式。执行以下命令进行转换:

@@ -220,6 +230,11 @@ bash scripts/msrun_launcher.sh "research/deepseek3/run_predict_deepseek.py \
32 8 $master_ip 8888 3 output/msrun_log False 300
```

参数说明:

- config: 推理的YAML配置文件路径。
- input: 推理的问题输入。

预期的推理结果如下:

```txt


+ 1
- 2
research/llama3_1/README.md View File

@@ -97,7 +97,6 @@ MindFormers提供**alpaca**作为[微调](#微调)数据集。
2. 执行`research/llama3_1/llama3_1_preprocess.py`,生成Mindrecord数据,将带有prompt模板的数据转换为mindrecord格式。

```shell
# 此工具依赖fschat工具包解析prompt模板, 请提前安装fschat >= 0.2.13 python = 3.9
python llama3_1_preprocess.py \
--dataset_type qa \
--input_glob /{path}/alpaca-data-conversation.json \
@@ -146,7 +145,7 @@ dtype: 转换权重的精度

### 全参微调

MindFormers提供`Llama3_1-8b`单机多卡以及`Llama3_1-70b`多机多卡的的微调示例,过程中使用`alpaca`
MindSpore Transformers提供 `Llama3.1-8B` 单机多卡以及 `Llama3.1-70B` 多机多卡的微调示例,过程中使用 `alpaca`
数据集对模型进行微调,数据集可以参考[数据集下载](#数据集下载)获得。

#### 单机训练


+ 2
- 2
research/llama3_1/llama.py View File

@@ -50,10 +50,10 @@ from mindformers.generation.utils import convert_pin
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
r"""
Provide llama training loss or logits through network.
Provide Llama training loss or logits through network.

Args:
config (LlamaConfig): The config of llama model.
config (LlamaConfig): The config of Llama model.

Returns:
output: Tensor, the output of llama decoderlayer


+ 2
- 2
research/qwen2_5/README.md View File

@@ -205,7 +205,7 @@ mindspore_ckpt_path: qkv_concat转换后权重文件保存路径,单卡权重
1. 当前支持模型已提供推理相关配置文件,请根据实际使用模型更改配置文件。
2. 运行下面的代码需要在`mindformers/`目录下,或者先将`mindformers/`目录所在路径加入到`PYTHONPATH`环境变量中。

``qwen2_5-7b` 8卡微调为例,执行如下命令进行微调。
`qwen2_5-7b` 8卡微调为例,执行如下命令进行微调。

1. 主要参数配置参考:

@@ -236,7 +236,7 @@ mindspore_ckpt_path: qkv_concat转换后权重文件保存路径,单卡权重
tokenizer:
model_max_length: 32768
vocab_file: "./path/vocab.json" # 参考qwen2_5-7b官网下载的词表
merges_file: "./path/merges.txt" # # 参考qwen2_5-7b官网下载的merge文件
merges_file: "./path/merges.txt" # 参考qwen2_5-7b官网下载的merge文件
#callbacks config
callbacks:
- type: CheckpointMonitor


+ 18
- 23
research/telechat2/README.md View File

@@ -30,33 +30,29 @@

以下模型性能均由Atlas 800T A2硬件环境下测试得出。

TeleChat2-7b:
TeleChat2-7B:

| config | task | Datasets | SeqLength | phase | performance |
|:---------------------------------------------------:| :-------------------: |:----------:|:---------:|:---------------:|:------------:|
| [TeleChat2_7b](./telechat2-7b/finetune_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 2950 tokens/s/p |
| [TeleChat2_7b](./telechat2-7b/predict_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 54.1 tokens/s |
| [TeleChat2_7B](./telechat2-7b/finetune_telechat_7b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 2950 tokens/s/p |

TeleChat2-35b:
TeleChat2-35B:

| config | task | Datasets | SeqLength | phase | performance |
|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------|
| [TeleChat2_35b](./telechat2-35b/finetune_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 516 tokens/s/p |
| [TeleChat2_35b](./telechat2-35b/predict_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 27.7 tokens/s |
| [TeleChat2_35B](./telechat2-35b/finetune_telechat_35b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 516 tokens/s/p |

TeleChat2-115b:
TeleChat2-115B:

| config | task | Datasets | SeqLength | phase | performance |
|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------|
| [TeleChat2_115b](./telechat2-115b/finetune_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 158 tokens/s/p |
| [TeleChat2_115b](./telechat2-115b/predict_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 26.5 tokens/s |
| [TeleChat2_115B](./telechat2-115b/finetune_telechat_115b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 158 tokens/s/p |

TeleChat2-39b-a12b:
TeleChat2-39B-A12B:

| config | task | Datasets | SeqLength | phase | performance |
| ------------------------------------------------------------ | --------------- | --------------- | --------- | ---------------- | ------------- |
| [TeleChat2_39b_a12b](./telechat2-39b-a12b/finetune_telechat_39b_a12b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 865 tokens/s/p |
| [TeleChat2_39b_a12b](./telechat2-39b-a12b/predict_telechat_39b_a12b_parallel.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 36.4 tokens/s |
| [TeleChat2_39B_A12B](./telechat2-39b-a12b/finetune_telechat_39b_a12b.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 865 tokens/s/p |

## 模型文件

@@ -138,6 +134,12 @@ input_dataset_file: 预训练的数据集
vocab_file_path: 词模型文件路径(如使用上述链接下载,指定到对应路径下即可)
max_length: 数据集长度
output_path: 生成数据集的路径
seed: 随机数种子,默认值:2024
start_token: 输入的首token,默认值:<_start>
user_token: 用户输入的提示词token,默认值:<_usr>
bot_token: 机器人输入的提示词token,默认值:<_bot>
end_token: 终止符对应的token,默认值:<_end>
pad_token: padding时补齐的token,默认值:<_pad>
```

> 注:`bos`, `eos`, `pad`等特殊`ids`要和`yaml`配置文件中`model_config`部分保持一致,默认`bos_token_id=1`, `eos_token_id=2`, `pad_token_id=3`。
@@ -149,10 +151,10 @@ MindFormers提供已经转换完成的预训练权重、词表文件用于预训

1.torch模型权重及词模型下载链接:

- [TeleChat2-7b](https://modelscope.cn/models/TeleAI/TeleChat2-7B-32K)
- [TeleChat2-7B](https://modelscope.cn/models/TeleAI/TeleChat2-7B-32K)
- [TeleChat2-39B-A12B](https://modelscope.cn/models/TeleAI/TeleChat2-39B-A12B)
- [TeleChat2-35b](https://modelscope.cn/models/TeleAI/TeleChat2-35B)
- [TeleChat2-115b](https://modelscope.cn/models/TeleAI/TeleChat2-115B)
- [TeleChat2-35B](https://modelscope.cn/models/TeleAI/TeleChat2-35B)
- [TeleChat2-115B](https://modelscope.cn/models/TeleAI/TeleChat2-115B)

下载完成后,运行如下转换脚本,将全量微调的权重转换为完整的ckpt权重。

@@ -168,13 +170,6 @@ torch_path: torch版本权重保存目录路径
mindspore_path: 权重保存文件名,可以指定自定义保存路径
```

2.获取MindFormers提供的已转换权重,可直接从下面的链接获取。

- [TeleChat2-7b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_7B/Telechat_7B.zip)
- [TeleChat2-35b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_35B/Telechat_35B.zip)
- [TeleChat2-115b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_115B/Telechat_115B.zip)
- [Telechat2-39b-a12b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_39B_A12.tar):仅适用于8卡推理,使用方式请参考[Telechat2-39B-A12B推理](#Telechat2-39B-A12B推理)章节。

### 分布式权重切分与合并

分布式训练/微调后所得到的权重文件为根据策略切分后的权重,需要手动将切分权重合一,以用于评估和推理。
@@ -226,7 +221,7 @@ MindFormers提供`TeleChat2-115B`的微调示例,过程中使用中电信人
- step 2. 根据服务器节点数等信息,修改相应的配置。

```yaml
# 以telechat-115b模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。
# 以telechat-115B模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。
# 配置文件路径:finetune_telechat_115b.yaml
parallel_config:
data_parallel: 1


+ 79
- 9
tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py View File

@@ -34,7 +34,7 @@ ms.set_context(mode=ms.GRAPH_MODE)
def ds3_train(config, dataset, construct_args_key, checker_config):
"""set model train."""
callback = TrainingChecker(**checker_config)
task_trainer = Trainer(task="text_generation",
task_trainer = Trainer(task='text_generation',
args=config,
train_dataset=dataset,
callbacks=callback)
@@ -59,7 +59,7 @@ def parallel_train_dp2_mp2_cp2_ep2():
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.parallel.full_batch = True
config.parallel.full_batch = False
config.parallel.dataset_strategy = 'full_batch'
config.parallel_config.context_parallel = 2
config.parallel_config.pipeline_stage = 1
@@ -121,6 +121,73 @@ def parallel_train_dp2_pp2_ep2_tnd():
ds3_train(config, dataset, construct_args_key, checker_config)


def parallel_train_alltoall_deredundency():
"""test mcore deepseekv3 train in dp=pp=ep=2 with moe_token_dispatcher = alltoall_deredundency."""
ms.set_seed(0)
config = MindFormerConfig(f'{CUR_DIR}/deepseekv3_train.yaml')
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.runner_config.sink_mode = True
config.parallel.full_batch = False
config.model.model_config.moe_token_dispatcher_type = 'alltoall_deredundency'
config.model.model_config.npu_nums_per_device = 2
dp = config.parallel_config.data_parallel
config.parallel_config.model_parallel = 1
config.parallel_config.use_seq_parallel = False
config.parallel.dataset_strategy = [[dp, 1], [dp, 1]]
build_context(config)

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 2, 20)

loss_std = [13.485920, 13.485721, 13.485860, 13.485846, 13.486056,
13.485989, 13.486071, 13.485976, 13.485886, 13.485951,
13.485897, 13.486100, 13.485947, 13.486015, 13.486045,
13.485865, 13.486064, 13.486042, 13.485949, 13.485986,]
checker_config = {
'loss_list_std': loss_std,
'experiment_mode': False,
'micro_batch_num': 2,
'micro_batch_interleave_num': 1
}
ds3_train(config, dataset, construct_args_key, checker_config)


def parallel_train_alltoall_zero_redundancy():
"""test mcore deepseekv3 train in dp=pp=ep=2 with moe_token_dispatcher = alltoall_zero_redundancy."""
ms.set_seed(0)
config = MindFormerConfig(f'{CUR_DIR}/deepseekv3_train.yaml')
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.runner_config.sink_mode = True
config.parallel.full_batch = False
config.model.model_config.moe_token_dispatcher_type = 'alltoall_zero_redundancy'
dp = config.parallel_config.data_parallel
config.parallel_config.model_parallel = 1
config.parallel_config.use_seq_parallel = False
config.parallel.dataset_strategy = [[dp, 1], [dp, 1]]
build_context(config)

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 2, 20)

loss_std = [13.485920, 13.485721, 13.485860, 13.485846, 13.486056,
13.485989, 13.486071, 13.485976, 13.485886, 13.485951,
13.485897, 13.486100, 13.485947, 13.486015, 13.486045,
13.485865, 13.486064, 13.486042, 13.485949, 13.485986,]
checker_config = {
'loss_list_std': loss_std,
'experiment_mode': False,
'micro_batch_num': 2,
'micro_batch_interleave_num': 1
}
ds3_train(config, dataset, construct_args_key, checker_config)


def parallel_train_dp2_mp2_ep2_calculate_per_token_loss_and_print_seperate_loss():
"""test mcore deepseekv3 train in dp=mp=ep=2."""
ms.set_seed(0)
@@ -130,7 +197,7 @@ def parallel_train_dp2_mp2_ep2_calculate_per_token_loss_and_print_seperate_loss(

config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.parallel.full_batch = True
config.parallel.full_batch = False
config.parallel.dataset_strategy = 'full_batch'
config.parallel_config.pipeline_stage = 1
build_context(config)
@@ -160,7 +227,7 @@ def moe_token_permute():
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.parallel.full_batch = True
config.parallel.full_batch = False
config.parallel.dataset_strategy = 'full_batch'
config.model.model_config.moe_permute_fusion = True
build_context(config)
@@ -189,7 +256,7 @@ def parallel_train_pp2_mp2_ep2_zbv():
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.runner_config.sink_mode = True
config.parallel.full_batch = True
config.parallel.full_batch = False
config.parallel.dataset_strategy = 'full_batch'
config.parallel_config.data_parallel = 1
config.parallel_config.micro_batch_num = 4
@@ -222,7 +289,7 @@ def moe_eplb():
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.parallel.full_batch = True
config.parallel.full_batch = False
config.parallel.dataset_strategy = 'full_batch'
config.model.model_config.print_expert_load = True
build_context(config)
@@ -245,12 +312,15 @@ def moe_eplb():


TEST_MAP = {
'parallel_train_dp2_mp2_cp2_ep2': parallel_train_dp2_mp2_cp2_ep2,
'parallel_train_dp2_pp2_ep2_tnd': parallel_train_dp2_pp2_ep2_tnd,
"parallel_train_dp2_mp2_ep2_calculate_per_token_loss_and_print_seperate_loss":
'parallel_train_dp2_mp2_ep2_calculate_per_token_loss_and_print_seperate_loss':
parallel_train_dp2_mp2_ep2_calculate_per_token_loss_and_print_seperate_loss,
'parallel_train_pp2_mp2_ep2_zbv': parallel_train_pp2_mp2_ep2_zbv,
"moe_token_permute": moe_token_permute,
"moe_eplb": moe_eplb,
'parallel_train_alltoall_deredundency': parallel_train_alltoall_deredundency,
'parallel_train_alltoall_zero_redundancy': parallel_train_alltoall_zero_redundancy,
'moe_token_permute': moe_token_permute,
'moe_eplb': moe_eplb,
}

if __name__ == '__main__':


+ 67
- 0
tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_alltoall_deredundency_train.py View File

@@ -0,0 +1,67 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test DeepseekV3 training"""
import os
from multiprocessing.pool import Pool
from pathlib import Path
import random
import pytest

from tests.st.test_multi_cards_cases.utils import TaskType
from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 0
_LEVEL_1_TASK_TIME = 135
_TASK_TYPE = TaskType.FOUR_CARDS_TASK

def run_command(command_info):
cmd, log_path = command_info
logger.info(f"Running command: {cmd}")
ret = os.system(cmd)
return ret, log_path


def check_results(commands, results):
error_idx = [_ for _ in range(len(results)) if results[_][0] != 0]
for idx in error_idx:
print(f"testcase {commands[idx]} failed. please check log {results[idx][1]}.")
os.system(f"grep -E 'ERROR|error|Error' {results[idx][1]} -C 5")
assert error_idx == []


class TestDeepseekV3AlltoAllDeredundency:
"""Test class for DeepseekV3"""

def setup_method(self):
"""Setup method to prepare test environment"""
os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '0'
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_deepseek3.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level1
def test_four_card_configurations(self):
"""Test four cards for DeepseekV3."""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
(f"msrun --worker_num=4 --local_worker_num=4 --master_port={port_id} "
"--log_dir=./msrun_log_deepseekv3_alltoall_deredundency "
f"--join=True {self.run_script_path} --mode=parallel_train_alltoall_deredundency",
"./msrun_log_deepseekv3_alltoall_deredundency/worker_3.log"),
]
with Pool(len(cmd_list)) as pool:
results = list(pool.imap(run_command, cmd_list))
check_results(cmd_list, results)

+ 67
- 0
tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_alltoall_zero_redundancy_train.py View File

@@ -0,0 +1,67 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test DeepseekV3 training"""
import os
from multiprocessing.pool import Pool
from pathlib import Path
import random
import pytest

from tests.st.test_multi_cards_cases.utils import TaskType
from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 0
_LEVEL_1_TASK_TIME = 135
_TASK_TYPE = TaskType.FOUR_CARDS_TASK

def run_command(command_info):
cmd, log_path = command_info
logger.info(f"Running command: {cmd}")
ret = os.system(cmd)
return ret, log_path


def check_results(commands, results):
error_idx = [_ for _ in range(len(results)) if results[_][0] != 0]
for idx in error_idx:
print(f"testcase {commands[idx]} failed. please check log {results[idx][1]}.")
os.system(f"grep -E 'ERROR|error|Error' {results[idx][1]} -C 5")
assert error_idx == []


class TestDeepseekV3AlltoAllZeroRedundancy:
"""Test class for DeepseekV3"""

def setup_method(self):
"""Setup method to prepare test environment"""
os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '0'
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_deepseek3.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level1
def test_four_card_configurations(self):
"""Test four cards for DeepseekV3."""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
(f"msrun --worker_num=4 --local_worker_num=4 --master_port={port_id} "
"--log_dir=./msrun_log_deepseekv3_alltoall_zero_redundancy "
f"--join=True {self.run_script_path} --mode=parallel_train_alltoall_zero_redundancy",
"./msrun_log_deepseekv3_alltoall_zero_redundancy/worker_3.log"),
]
with Pool(len(cmd_list)) as pool:
results = list(pool.imap(run_command, cmd_list))
check_results(cmd_list, results)

+ 6  - 4   tests/st/test_multi_cards_cases/test_model/test_deepseek3/test_deepseek3_train.py

@@ -17,13 +17,14 @@ import os
from multiprocessing.pool import Pool
from pathlib import Path
import random
import pytest

from mindformers.tools.logger import logger
from tests.st.test_multi_cards_cases.utils import TaskType
from mindformers.tools.logger import logger


-_LEVEL_0_TASK_TIME = 170
-_LEVEL_1_TASK_TIME = 0
+_LEVEL_0_TASK_TIME = 0
+_LEVEL_1_TASK_TIME = 170
_TASK_TYPE = TaskType.EIGHT_CARDS_TASK

def run_command(command_info):
@@ -51,13 +52,14 @@ class TestDeepseekV3:
self.run_script_path = self.sh_path / "run_deepseek3.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level1
def test_eight_card_configurations(self):
"""Test eight cards for DeepseekV3."""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
(f"msrun --worker_num=8 --local_worker_num=8 --master_port={port_id} --log_dir=./msrun_log_deepseekv3 "
f"--join=True {self.run_script_path} --mode=parallel_train_dp2_mp2_cp2_ep2",
f"./msrun_log_deepseekv3/worker_7.log"),
"./msrun_log_deepseekv3/worker_7.log"),
]
with Pool(len(cmd_list)) as pool:
results = list(pool.imap(run_command, cmd_list))


+ 15  - 0   tests/st/test_multi_cards_cases/test_model/test_glm4_moe/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mcore glm4 moe."""

+ 15  - 0   tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mcore infer glm4 moe."""

+ 40  - 0   tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/glm4_moe_infer.yaml

@@ -0,0 +1,40 @@
seed: 0
output_dir: './output' # path to save checkpoint/strategy
load_checkpoint: ''
use_parallel: False
run_mode: 'predict'
use_legacy: False
load_ckpt_format: 'safetensors'

trainer:
type: CausalLanguageModelingTrainer
model_name: 'glm4_moe'

# default parallel config: device num = 2 for Atlas 800T A2
parallel_config:
data_parallel: 1
model_parallel: 2
# HuggingFace file directory
pretrained_model_dir: '/path/hf_dir'
model:
model_config:
compute_dtype: "bfloat16"
layernorm_compute_dtype: "float32"
softmax_compute_dtype: "float32"
rotary_dtype: "bfloat16"
params_dtype: "bfloat16"

# mindspore context init config
context:
mode: 0 #0--Graph Mode; 1--Pynative Mode
max_device_memory: "29GB"
device_id: 0
device_target: "Ascend"
affinity_cpu_list: None
deterministic: "ON"
infer_precision_sync: True

# parallel context config
parallel:
parallel_mode: "MANUAL_PARALLEL"
full_batch: False

+ 90  - 0   tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/run_glm4_moe.py

@@ -0,0 +1,90 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mcore glm4 moe model ST of inference"""
import argparse
import os
from transformers import AutoTokenizer

from mindspore.nn.utils import no_init_parameters

from tests.st.test_multi_cards_cases.test_model.utils import compare_distance

from mindformers import AutoModel, build_context, MindFormerConfig
from mindformers.core.parallel_config import build_parallel_config
from mindformers.tools.logger import logger


def test_glm4_moe_predict_mcore(device_num: int = 1):
"""
Feature: Mcore Glm4Moe predict task
Description: Two-card tp parallel
Expectation: Success or assert precision failed
"""
max_decode_length = 32
config_path = os.path.join(os.path.dirname(__file__), "glm4_moe_infer.yaml")
config = MindFormerConfig(config_path)
config.use_parallel = device_num > 1
config.parallel_config.model_parallel = device_num
config.pretrained_model_dir = "/home/workspace/mindspore_dataset/weight/GLM-4.5-Air-tiny"
# Reduced layer network
config.model.model_config.num_hidden_layers = 2
build_context(config)
build_parallel_config(config)
# Auto tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_dir)
# init network
with no_init_parameters():
network = AutoModel.from_config(config)
network.load_weights(config.pretrained_model_dir)
# Build prompt and answer
batch_datas = {1: {"prompt": "Please introduce some scenic spots in Beijing.",
"answer": "Please introduce some scenic spots in Beijing."
"ahan flying UmbursalhsiqadereFINE写实ENVIRONMENTally NASpired "
"Biosphericoux posit Lifts-offENS小的范围内"},
4: {"prompt": "Please introduce some scenic spots in Beijing.",
"answer": "Please introduce some scenic spots in Beijing."
"ahan flying UmbursalhsiqadereFINE写实ENVIRONMENTally NASpired "
"Biosphericoux posit Lifts-offENS小的范围内"},
}

for batch_size, batch_data in batch_datas.items():
input_ids = tokenizer.encode(batch_data["prompt"])
input_ids_list = []
answer = batch_data["answer"]
for _ in range(0, batch_size):
input_ids_list.append(input_ids)

outputs = network.generate(input_ids_list,
max_length=max_decode_length,
do_sample=False,
return_dict_in_generate=False)

for output in outputs:
output_text = tokenizer.decode(output)
logger.info("test_glm4_5_air_predict, output_text: %s", str(output_text))
compare_distance(output_text, answer)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Run Glm4Moe ST")
parser.add_argument("--device_num", type=int, default=2)

args = parser.parse_args()
os.environ['MS_ENABLE_LCCL'] = "off"
os.environ['HCCL_DETERMINISTIC'] = "true"
os.environ['LCCL_DETERMINISTIC'] = "1"
os.environ['ASCEND_LAUNCH_BLOCKING'] = "1"
os.environ['CUSTOM_MATMUL_SHUFFLE'] = "off"
test_glm4_moe_predict_mcore(args.device_num)

+ 58  - 0   tests/st/test_multi_cards_cases/test_model/test_glm4_moe/test_glm4_moe_infer/test_glm4_moe_infer.py

@@ -0,0 +1,58 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Mcore Glm4Moe inference"""
import os
import random
from pathlib import Path

import pytest

from tests.st.test_multi_cards_cases.utils import TaskType
from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 80
_LEVEL_1_TASK_TIME = 0
_TASK_TYPE = TaskType.TWO_CARDS_TASK


class TestMcoreGlm4MoeParallelInference:
"""Test class for Glm4Moe in inference"""

def setup_method(self):
"""Setup method to prepare test environment"""
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_glm4_moe.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level0
def test_two_cards_cases(self):
"""Test two cards for Glm4Moe."""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
"msrun",
"--worker_num=2",
"--local_worker_num=2", # Should match NPU cards available
f"--master_port={port_id}", # Ensure port is unique per test run if parallelized at pytest level
"--log_dir=./msrun_log_glm4moe",
"--join=True"]
cmd_list += [
str(self.run_script_path),
"--device_num=2"
]
cmd = " ".join(cmd_list)
logger.info(f"Running command: {cmd}")
return_code = os.system(cmd)
assert return_code == 0, "Glm4Moe inference st failed."
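For reference, with the placeholder port left symbolic, the command this test joins and executes is:
msrun --worker_num=2 --local_worker_num=2 --master_port=<port_id> --log_dir=./msrun_log_glm4moe --join=True <sh_path>/run_glm4_moe.py --device_num=2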

+ 67  - 0   tests/st/test_multi_cards_cases/test_optimizer/test_pma/config.json

@@ -0,0 +1,67 @@
{
"architectures": [
"DeepseekV3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
"AutoModel": "modeling_deepseek.DeepseekV3Model",
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
},
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": true,
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [
128,
128
]
},
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": false,
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.33.1",
"use_cache": true,
"v_head_dim": 128,
"vocab_size": 129280
}

+ 134  - 0   tests/st/test_multi_cards_cases/test_optimizer/test_pma/deepseekv3_train.yaml

@@ -0,0 +1,134 @@
seed: 0
output_dir: './output' # path to save checkpoint/strategy
load_checkpoint: ''
src_strategy_path_or_dir: ''
auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed model
only_save_strategy: False
resume_training: False
use_parallel: True
run_mode: 'train'
train_precision_sync: True
load_ckpt_format: 'safetensors'
use_legacy: False
# trainer config
trainer:
type: CausalLanguageModelingTrainer
model_name: 'deepseekV3'

# runner config
runner_config:
epochs: 1
batch_size: 1
sink_mode: True
sink_size: 1

# optimizer
optimizer:
type: AdamW
betas: [0.9, 0.95]
eps: 1.e-8

# lr schedule
lr_schedule:
type: ConstantWarmUpLR
learning_rate: 2.2e-4
lr_end: 2.2e-4
warmup_steps: 0
total_steps: -1

# mindspore context init config
context:
mode: 0 #0--Graph Mode; 1--Pynative Mode
device_target: "Ascend"
max_call_depth: 10000
max_device_memory: "28GB"
save_graphs: False
save_graphs_path: "./graph"
jit_config:
jit_level: "O0"

parallel_config:
data_parallel: 2
model_parallel: 2
pipeline_stage: 2
expert_parallel: 2
micro_batch_num: 2
vocab_emb_dp: True
use_seq_parallel: True
gradient_aggregation_group: 4
# when model_parallel is greater than 1, micro_batch_interleave_num can be set to 2, which may accelerate training.
micro_batch_interleave_num: 1

# parallel context config
parallel:
parallel_mode: 1
gradients_mean: False
enable_alltoall: True
search_mode: "sharding_propagation"
enable_parallel_optimizer: True
strategy_ckpt_save_file: "./ckpt_strategy.ckpt"
parallel_optimizer_config:
gradient_accumulation_shard: False
parallel_optimizer_threshold: 64
optimizer_weight_shard_size: 4

# recompute config
recompute_config:
recompute: False
select_recompute: False
parallel_optimizer_comm_recompute: False
mp_comm_recompute: False
recompute_slice_activation: False

pretrained_model_dir: ''
# model config
model:
model_config:
vocab_size: 32000
seq_length: 4096
hidden_size: 256
num_hidden_layers: 4
first_k_dense_replace: 1
num_attention_heads: 8
num_key_value_heads: 8
max_position_embeddings: 4096
intermediate_size: 512
moe_intermediate_size: 512
v_head_dim: 192
compute_dtype: "bfloat16"
layernorm_compute_dtype: "float32"
softmax_compute_dtype: "float32"
rotary_dtype: "float32"
params_dtype: "float32"
router_dense_type: "float32"
hidden_dropout: 0.1
offset: 0
use_flash_attention: True
mtp_loss_scaling_factor: 0.3
input_sliced_sig: True
n_routed_experts: 16
add_bias_linear: False
gated_linear_unit: True
qk_layernorm: True
moe_aux_loss_coeff: 0.0001
moe_router_load_balancing_type: "seq_aux_loss"
moe_token_drop_policy: probs
moe_router_enable_expert_bias: True
moe_router_bias_update_rate: 0.001
moe_grouped_gemm: True
use_interleaved_weight_layout_mlp: False
topk_group: 0
n_group: 0

# callbacks
callbacks:
- type: MFLossMonitor
per_print_times: 1
# balance topk bias with callback
- type: TopkBiasBalanceCallback

# wrapper cell config
runner_wrapper:
type: MFTrainOneStepCell
scale_sense: 1.0
use_clip_grad: True

+ 241  - 0   tests/st/test_multi_cards_cases/test_optimizer/test_pma/run_deepseek3.py

@@ -0,0 +1,241 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing the paralleled mcore deepseek3 interface used for mindformers.
"""
import os
import argparse
from types import MethodType
from safetensors import safe_open
import numpy as np

import mindspore as ms
from mindspore.ops.operations import Cast

from tests.st.test_multi_cards_cases.test_model.test_deepseek3.data_gen_utils import get_dataset, generate_weight

from mindformers import build_context, MindFormerConfig
from mindformers.trainer import Trainer
from mindformers.core.callback.callback import CheckpointMonitor, TrainCallBack


cpu_cast = Cast().set_device("CPU")

CUR_DIR = os.path.dirname(__file__)

ms.set_context(mode=ms.GRAPH_MODE)

OPTIMIZER_KEYS = {"adam_m", "adam_v", "epoch_num", "global_step",
"loss_scale", "step_num", "scale_sense"}


def ds3_train(config, dataset, construct_args_key, callback):
"""set model train."""
task_trainer = Trainer(task="text_generation",
args=config,
train_dataset=dataset,
callbacks=callback)

task_trainer.config.train_dataset.input_columns = construct_args_key
task_trainer.config.train_dataset.construct_args_key = construct_args_key
def create_network(self, default_args):
network = type(self).create_network(self, default_args)
param_dict = generate_weight(network)
ms.load_param_into_net(network, param_dict)
return network
task_trainer.trainer.create_network = MethodType(create_network, task_trainer.trainer)
task_trainer.train()


def create_config(mode='origin'):
"""Create config for different test modes."""
ms.set_seed(0)
config = MindFormerConfig(f'{CUR_DIR}/deepseekv3_train.yaml')
config.print_separate_loss = False
config.train_precision_sync = True
config.pretrained_model_dir = CUR_DIR
config.parallel.full_batch = True
config.parallel.dataset_strategy = 'full_batch'
config.parallel_config.data_parallel = 1

if mode != 'load_origin':
config.load_checkpoint = f'{CUR_DIR}/test/checkpoint/'

if mode == 'ema':
config.optimizer.type = 'PmaAdamW'
config.optimizer.interleave_step = 1
config.optimizer.fused_num = 2
elif mode == 'sma':
config.optimizer.type = 'PmaAdamW'
config.optimizer.interleave_step = 1
config.optimizer.fused_num = 2
config.optimizer.fused_algo = 'sma'

build_context(config)
return config


def run_load_origin():
"""Run origin test for loading checkpoint."""
ms.set_seed(0)
config = create_config('load_origin')

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 4, 20)

callback = [CheckpointMonitor(
save_checkpoint_steps=1,
checkpoint_format='safetensors',
directory=f"{CUR_DIR}/test"),
TrainCallBack(stop_step=1)]

ds3_train(config, dataset, construct_args_key, callback)


def run_origin():
"""Run origin test to compare."""
ms.set_seed(0)
config = create_config('origin')

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 4, 20)

callback = [
CheckpointMonitor(save_checkpoint_steps=1, checkpoint_format='safetensors', directory=f"{CUR_DIR}/origin"),
TrainCallBack(stop_step=2)]
ds3_train(config, dataset, construct_args_key, callback)


def run_ema():
"""Run ema test to compare."""
ms.set_seed(0)
config = create_config('ema')

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 4, 20)

callback = [CheckpointMonitor(save_checkpoint_steps=1, checkpoint_format='safetensors', directory=f"{CUR_DIR}/ema"),
TrainCallBack(stop_step=2)]
ds3_train(config, dataset, construct_args_key, callback)


def run_sma():
"""Run sma test to compare."""
ms.set_seed(0)
config = create_config('sma')

construct_args_key = ['input_ids', 'labels']
model_config = config.model.model_config
dataset = get_dataset(model_config.seq_length, model_config.vocab_size, 4, 20)

callback = [CheckpointMonitor(save_checkpoint_steps=1, checkpoint_format='safetensors', directory=f"{CUR_DIR}/sma"),
TrainCallBack(stop_step=2)]
ds3_train(config, dataset, construct_args_key, callback)


def test_pma():
"""
Feature: test pma.
Description: Run pma function.
Expectation: Success or assert precision failed
"""
run_load_origin()
run_origin()
run_ema()
run_sma()
origin_dict1 = load_safetensors("CKP_rank_0-1_1.safetensors", f"{CUR_DIR}/origin/checkpoint/rank_0")
origin_dict2 = load_safetensors("CKP_rank_0-2_1.safetensors", f"{CUR_DIR}/origin/checkpoint/rank_0")

check_dict1 = load_safetensors("CKP_rank_0-1_1.safetensors", f"{CUR_DIR}/ema/checkpoint/rank_0")
check_dict2 = load_safetensors("CKP_rank_0-2_1.safetensors", f"{CUR_DIR}/ema/checkpoint/rank_0")

check_dict3 = load_safetensors("CKP_rank_0-1_1.safetensors", f"{CUR_DIR}/sma/checkpoint/rank_0")
check_dict4 = load_safetensors("CKP_rank_0-2_1.safetensors", f"{CUR_DIR}/sma/checkpoint/rank_0")

compare_checkpoint_step_one(check_dict1, origin_dict1, 0.2, "pma_weight_ema.")
compare_checkpoint_step_one(check_dict3, origin_dict1, 1, "pma_weight_sma.")

compare_checkpoint_step_two(check_dict2, origin_dict1, origin_dict2, 0.2 * 0.8, 0.2)
compare_checkpoint_step_two(check_dict4, origin_dict1, origin_dict2, 0.5, 0.5)

TEST_MAP = {
"test_pma": test_pma,
}


def compare_checkpoint_step_one(check_dict, origin_dict, alpha, pma_prefix):
"""Compare checkpoint for first step."""
unexpected_keys = []
for k, _ in check_dict.items():
if origin_dict.get(k) is not None:
origin_value = cpu_cast(ms.Tensor(origin_dict.get(k)), ms.float32)
check_value = cpu_cast(ms.Tensor(check_dict.get(k)), ms.float32)
assert np.allclose(origin_value, check_value)
else:
if 'pma' in k:
pma_value = cpu_cast(ms.Tensor(check_dict.get(k)), ms.float32)
assert origin_dict.get(k.replace(pma_prefix, "")) is not None
origin_value = cpu_cast(ms.Tensor(origin_dict.get(k.replace(pma_prefix, ""))),
ms.float32)
assert np.allclose(pma_value, alpha * origin_value)
else:
unexpected_keys.append(k)
assert not unexpected_keys


def compare_checkpoint_step_two(check_dict, origin_dict1, origin_dict2, alpha1, alpha2):
"""Compare checkpoint for second step."""
unexpected_keys = []
for k, _ in check_dict.items():
if "router.expert_bias" in k:
continue
if origin_dict2.get(k) is not None and origin_dict1.get(k) is not None:
origin_value1 = cpu_cast(ms.Tensor(origin_dict1.get(k)), ms.float32)
origin_value2 = cpu_cast(ms.Tensor(origin_dict2.get(k)), ms.float32)
check_value = cpu_cast(ms.Tensor(check_dict.get(k)), ms.float32)
if any(key in k for key in OPTIMIZER_KEYS):
assert np.allclose(origin_value2, check_value)
continue
assert np.allclose(origin_value1 * alpha1 + alpha2 * origin_value2, check_value)
else:
if 'pma' in k:
pma_value = cpu_cast(ms.Tensor(check_dict.get(k)), ms.float32)
assert np.allclose(pma_value, pma_value * 0)
else:
unexpected_keys.append(k)
assert not unexpected_keys


def load_safetensors(ckpt, path):
"""Load checkpoint which format is safetensors."""
ckpt = os.path.join(path, ckpt)
ckpt_dict = {}

with safe_open(ckpt, framework='np', device='cpu') as f:
for k in f.keys():
ckpt_dict[k] = f.get_tensor(k)

return ckpt_dict


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, help='test mode of deepseek model.')

args = parser.parse_args()
TEST_MAP[args.mode]()
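
The coefficients passed to the comparison helpers above encode a simple fused-weight rule. A minimal numpy sketch of that arithmetic (illustrative only: scalar stand-ins for the origin weights saved after step 1 and step 2, coefficients taken from the calls in test_pma; the 0.2 factor is inferred from those assertions rather than read from a documented PmaAdamW default):

import numpy as np

w1, w2 = np.float32(1.5), np.float32(2.5)  # stand-ins for the origin weights after step 1 and step 2

# 'ema' mode: consistent with pma_t = (1 - alpha) * pma_{t-1} + alpha * w_t, alpha = 0.2, pma_0 = 0
alpha = 0.2
pma_step1 = alpha * w1                                     # checked by compare_checkpoint_step_one(..., 0.2, "pma_weight_ema.")
pma_step2 = (1 - alpha) * pma_step1 + alpha * w2
assert np.isclose(pma_step2, 0.2 * 0.8 * w1 + 0.2 * w2)    # coefficients used in compare_checkpoint_step_two

# 'sma' mode with fused_num=2: plain mean of the last two weights
sma_step1 = 1.0 * w1                                       # checked by compare_checkpoint_step_one(..., 1, "pma_weight_sma.")
sma_step2 = (w1 + w2) / 2
assert np.isclose(sma_step2, 0.5 * w1 + 0.5 * w2)          # coefficients used in compare_checkpoint_step_two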

+ 72  - 0   tests/st/test_multi_cards_cases/test_optimizer/test_pma/test_pma.py

@@ -0,0 +1,72 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test DeepseekV3 training with ZeroBubbleV"""
import os
from multiprocessing.pool import Pool
from pathlib import Path
import random
import pytest

from tests.st.test_multi_cards_cases.utils import TaskType

from mindformers.tools.logger import logger


_LEVEL_0_TASK_TIME = 0
_LEVEL_1_TASK_TIME = 436
_TASK_TYPE = TaskType.FOUR_CARDS_TASK


def run_command(command_info):
cmd, log_path = command_info
logger.info(f"Running command: {cmd}")
ret = os.system(cmd)
return ret, log_path


def check_results(commands, results):
error_idx = [_ for _ in range(len(results)) if results[_][0] != 0]
for idx in error_idx:
print(f"testcase {commands[idx]} failed. please check log {results[idx][1]}.")
os.system(f"grep -E 'ERROR|error|Error' {results[idx][1]} -C 5")
assert error_idx == []


class TestDeepseekV3WithPma:
"""Test class for DeepseekV3 with Pma"""

def setup_method(self):
"""Setup method to prepare test environment"""
os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '0'
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_deepseek3.py"
assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"

@pytest.mark.level1
def test_pma(self):
"""
Feature: Test DeepseekV3 with pma.
Description: Test four cards for DeepseekV3.
Expectation: Success
"""
port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))
cmd_list = [
(f"msrun --worker_num=4 --local_worker_num=4 --master_port={port_id} --log_dir=./msrun_log_pma "
f"--join=True {self.run_script_path} --mode=test_pma",
"./msrun_log_pma/worker_0.log"),
]
with Pool(len(cmd_list)) as pool:
results = list(pool.imap(run_command, cmd_list))
check_results(cmd_list, results)

+ 100  - 2   tests/st/test_optim/optimizer_util.py

@@ -22,6 +22,7 @@ from mindspore import nn, Tensor
from mindspore.ops import operations as P

from mindformers.core.optim import build_optim
from mindformers.core.optim.muon import Muon

np.random.seed(1024)

@@ -58,7 +59,7 @@ class NetWithLoss(nn.Cell):
"""

def __init__(self, network, loss_fn):
-super(NetWithLoss, self).__init__()
+super().__init__()
self.network = network
self.loss = loss_fn

@@ -74,7 +75,7 @@ class FakeNet(nn.Cell):
"""

def __init__(self):
-super(FakeNet, self).__init__()
+super().__init__()
self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias))
self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias))
self.relu = nn.ReLU()
@@ -155,3 +156,100 @@ default_fc1_weight_adamw_v = (
default_fc2_weight_adamw_v = (
np.array([[35.217834, 42.283375, 26.52298, 21.510029]], dtype=np.float32)
)


class MockTransformerConfig:
"""Mock transformer config for testing Muon optimizer."""
def __init__(self):
self.multi_latent_attention = True
self.tensor_model_parallel_size = 1
self.data_parallel_size = 1


class MockModel:
"""
Mock model class that provides required interfaces for Muon optimizer.
This simulates the model interface that Muon optimizer expects.
"""
def __init__(self):
self.config = MockTransformerConfig()

def get_gpt_transformer_config(self):
"""Return transformer config."""
return self.config

def make_model_muon_fns(self):
"""Return muon split and merge functions."""
def muon_split_fn(param_name, tensor): # pylint: disable=unused-argument
"""Split function - returns tensor as list."""
return [tensor]

def muon_merge_fn(param_name, tensor_list): # pylint: disable=unused-argument
"""Merge function - returns first tensor."""
return tensor_list[0]

return muon_split_fn, muon_merge_fn

def get_param_layer_indices(self, params):
"""Return layer indices for parameters."""
return {p.name: 0 for p in params}

def get_muon_filter(self):
"""Return filter function to determine which params use Muon."""
def muon_filter(param):
# Apply Muon to weight parameters with 2D shape (not bias)
return len(param.shape) == 2 and 'bias' not in param.name
return muon_filter

def get_tp_dims(self, params):
"""Return tensor parallel dimensions."""
return tuple(-1 for _ in params)

def get_op_groups_info(self, params, op): # pylint: disable=unused-argument
"""Return optimizer parallel group info."""
ops = tuple(1 for _ in params)
op_groups = tuple("" for _ in params)
return ops, op_groups


def build_muon_network(net, mock_model, learning_rate=0.02):
"""
Build network with Muon optimizer for testing.

Args:
net: The network to train
mock_model: Mock model providing Muon interface
learning_rate: Learning rate for optimizer

Returns:
tuple: (losses, optimizer)
"""

loss_fn = nn.L1Loss(reduction='mean')
networkwithloss = NetWithLoss(net, loss_fn)
networkwithloss.set_train()

params = networkwithloss.trainable_params()

# Create Muon optimizer
optimizer = Muon(
params=params,
learning_rate=learning_rate,
weight_decay=0.1,
matched_adamw_rms=0.2,
momentum=0.95,
nesterov=True,
adamw_betas=(0.95, 0.95),
adamw_eps=1e-8,
model=mock_model,
)

trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, optimizer)

losses = []
data, label = make_fake_data()
for i in range(20):
loss = trainonestepcell(data[i], label[i])
losses.append(loss.asnumpy())

return np.array(losses), optimizer
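
A short, hypothetical usage sketch for the helper above (not part of the diff; it mirrors how build_network drives the AdamW baseline and assumes the FakeNet weights defined earlier in optimizer_util.py):

net = FakeNet()
mock_model = MockModel()
losses, optimizer = build_muon_network(net, mock_model, learning_rate=0.02)
# losses is a length-20 float32 array; baselines such as BASELINE_LOSSES_NESTEROV_TRUE in
# tests/st/test_optim/test_muon/data_utils.py appear to correspond to runs of this kind.
print(losses[:5])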

+ 816  - 4   tests/st/test_optim/test_adamw.py

@@ -21,18 +21,27 @@ import pytest
import numpy as np

import mindspore as ms

-from tests.st.test_optim.optimizer_util import build_network, FakeNet, default_fc1_weight_adamw_m, \
-default_fc2_weight_adamw_m, default_fc1_weight_adamw_v, default_fc2_weight_adamw_v
+from mindspore import Parameter, Tensor, dtype as mstype
+from mindspore.nn import Cell
+from tests.st.test_optim.optimizer_util import (
+build_network,
+FakeNet,
+default_fc1_weight_adamw_m,
+default_fc2_weight_adamw_m,
+default_fc1_weight_adamw_v,
+default_fc2_weight_adamw_v
+)
+from mindformers.core.optim.adamw import AdamW, _check_param_value

ms.set_context(mode=0)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
class TestAdamW:
"""A test class for testing optimizer computation."""

def test_computation(self):
"""
Feature: Trainer.train()
@@ -45,3 +54,806 @@ class TestAdamW:
assert np.allclose(cells.exp_avg[2].asnumpy(), default_fc2_weight_adamw_m, atol=1.e-4)
assert np.allclose(cells.exp_avg_sq[0].asnumpy(), default_fc1_weight_adamw_v, atol=1.e-4)
assert np.allclose(cells.exp_avg_sq[2].asnumpy(), default_fc2_weight_adamw_v, atol=1.e-4)


class SimpleNet(Cell):
"""Simple network for testing"""

def __init__(self):
super().__init__()
self.weight = Parameter(Tensor(np.ones([2, 3]), mstype.float32), name="weight")
self.bias = Parameter(Tensor(np.zeros([3]), mstype.float32), name="bias")

def construct(self, x):
return x * self.weight + self.bias


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init():
"""
Feature: AdamW optimizer initialization
Description: Test AdamW initialization with default parameters
Expectation: Successfully initialize AdamW optimizer with default parameters and verify beta1, beta2, and eps values
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())
assert optimizer is not None
assert np.allclose(optimizer.beta1.asnumpy(), np.array([0.9]))
assert np.allclose(optimizer.beta2.asnumpy(), np.array([0.999]))
assert np.allclose(optimizer.eps.asnumpy(), np.array([1e-8]))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_custom_params():
"""
Feature: AdamW optimizer initialization with custom parameters
Description: Test AdamW initialization with custom learning rate, betas, eps, and weight decay
Expectation: Successfully initialize AdamW with custom parameters and verify the values are set correctly
"""
net = SimpleNet()
learning_rate = 0.005
betas = (0.8, 0.99)
eps = 1e-7
weight_decay = 0.01

optimizer = AdamW(
net.trainable_params(),
learning_rate=learning_rate,
betas=betas,
eps=eps,
weight_decay=weight_decay
)

assert np.allclose(optimizer.beta1.asnumpy(), np.array([0.8]))
assert np.allclose(optimizer.beta2.asnumpy(), np.array([0.99]))
assert np.allclose(optimizer.eps.asnumpy(), np.array([1e-7]))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_swap():
"""
Feature: AdamW optimizer initialization with swap parameter
Description: Test AdamW initialization with swap=True to offload optimizer states to CPU
Expectation: Successfully initialize AdamW with swap=True and verify optimizer states are on CPU
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), swap=True)
assert optimizer.swap is True
# Check if exp_avg parameters are on CPU
for param in optimizer.exp_avg:
assert param.device == 'CPU'
for param in optimizer.exp_avg_sq:
assert param.device == 'CPU'


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_group_params():
"""
Feature: AdamW optimizer initialization with group parameters
Description: Test AdamW initialization with grouped parameters having different learning rates and weight decays
Expectation: Successfully initialize AdamW with grouped parameters
"""
net = SimpleNet()
params = [
{'params': [net.weight], 'lr': 0.001, 'weight_decay': 0.01},
{'params': [net.bias], 'lr': 0.0001, 'weight_decay': 0.0}
]
optimizer = AdamW(params)
assert optimizer is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value():
"""
Feature: _check_param_value function
Description: Test _check_param_value function with valid parameters
Expectation: Successfully validate parameters without raising exceptions
"""
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0.01
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_betas_type():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid betas type (string instead of tuple/list)
Expectation: Raise TypeError when betas is not a tuple or list
"""
betas = "invalid"
eps = 1e-8
weight_decay = 0.01
with pytest.raises(TypeError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_betas_length():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid betas length (3 elements instead of 2)
Expectation: Raise ValueError when betas length is not 2
"""
betas = (0.9, 0.999, 0.9999)
eps = 1e-8
weight_decay = 0.01
with pytest.raises(ValueError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_beta1():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid beta1 (equal to 1.0, which is out of range)
Expectation: Raise ValueError when beta1 is not in (0.0, 1.0)
"""
betas = (1.0, 0.999)
eps = 1e-8
weight_decay = 0.01
with pytest.raises(ValueError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_beta2():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid beta2 (equal to 0.0, which is out of range)
Expectation: Raise ValueError when beta2 is not in (0.0, 1.0)
"""
betas = (0.9, 0.0)
eps = 1e-8
weight_decay = 0.01
with pytest.raises(ValueError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_eps():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid eps (equal to 0.0, which is not greater than 0)
Expectation: Raise ValueError when eps is not greater than 0
"""
betas = (0.9, 0.999)
eps = 0.0
weight_decay = 0.01
with pytest.raises(ValueError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_weight_decay():
"""
Feature: _check_param_value function parameter validation
Description: Test _check_param_value function with invalid weight_decay type (string instead of float/int/Cell)
Expectation: Raise TypeError when weight_decay is not a float, int, or Cell
"""
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = "invalid"
with pytest.raises(TypeError):
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct():
"""
Feature: AdamW optimizer construct method
Description: Test AdamW construct method with dummy gradients
Expectation: Successfully execute AdamW construct method and return non-None result
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Test construct
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_with_group_lr():
"""
Feature: AdamW optimizer construct method with group learning rates
Description: Test AdamW construct method with grouped parameters having different learning rates
Expectation: Successfully execute AdamW construct method with group learning rates and return non-None result
"""
net = SimpleNet()
params = [
{'params': [net.weight], 'lr': 0.001},
{'params': [net.bias], 'lr': 0.0001}
]
optimizer = AdamW(params)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Test construct
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state():
"""
Feature: AdamW optimizer clone_state method
Description: Test AdamW clone_state method to create copies of optimizer states
Expectation: Successfully clone optimizer states with correct shape and dtype
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Clone state
cloned_state = optimizer.clone_state("test", "zeros")
assert len(cloned_state) == len(optimizer.parameters)
for i, param in enumerate(cloned_state):
assert param.shape == optimizer.parameters[i].shape
assert param.dtype == mstype.float32


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_swap():
"""
Feature: AdamW optimizer clone_state method with swap=True
Description: Test AdamW clone_state method with swap=True to clone states to CPU
Expectation: Successfully clone optimizer states to CPU when swap=True
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), swap=True)

# Clone state
cloned_state = optimizer.clone_state("test", "zeros")
assert len(cloned_state) == len(optimizer.parameters)
for param in cloned_state:
assert param.device == 'CPU'


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_dynamic_weight_decay():
"""
Feature: AdamW optimizer with dynamic weight decay
Description: Test AdamW initialization with dynamic weight decay
Expectation: Successfully initialize AdamW with dynamic weight decay
"""
net = SimpleNet()
weight_decay = 1.0
optimizer = AdamW(net.trainable_params(), weight_decay=weight_decay)
assert optimizer is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_step():
"""
Feature: AdamW optimizer step execution
Description: Test AdamW step execution with dummy gradients
Expectation: Successfully execute one optimization step and update parameters
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Create dummy input and gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step
optimizer(gradients)

# Check if parameters are updated
for param in net.trainable_params():
assert not np.all(param.asnumpy() == np.ones_like(param.asnumpy()))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_zero_gradients():
"""
Feature: AdamW optimizer with zero gradients
Description: Test AdamW execution with zero gradients and non-zero weight decay
Expectation: Parameters are updated due to weight decay even with zero gradients
"""
net = SimpleNet()
# Use non-zero weight decay to ensure parameters are updated
optimizer = AdamW(net.trainable_params(), weight_decay=0.1)

# Create zero gradients
gradients = (
Tensor(np.zeros([2, 3]), mstype.float32),
Tensor(np.zeros([3]), mstype.float32)
)

# Run one step
optimizer(gradients)

# Check if parameters are still updated due to weight decay
for param in net.trainable_params():
assert not np.allclose(param.asnumpy(), np.ones_like(param.asnumpy()))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_invalid_learning_rate():
"""
Feature: AdamW optimizer with invalid learning rate
Description: Test AdamW initialization with invalid learning rate type
Expectation: Raise ValueError when learning_rate is invalid
"""
net = SimpleNet()
with pytest.raises(ValueError):
AdamW(net.trainable_params(), learning_rate="invalid")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_large_weight_decay():
"""
Feature: AdamW optimizer with large weight decay
Description: Test AdamW execution with large weight decay (0.1)
Expectation: Parameters are significantly updated due to large weight decay
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), weight_decay=0.1)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step
optimizer(gradients)

# Check if parameters are updated significantly due to large weight decay
for param in net.trainable_params():
assert np.all(param.asnumpy() < np.ones_like(param.asnumpy()))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_small_eps():
"""
Feature: AdamW optimizer with small epsilon
Description: Test AdamW execution with very small epsilon (1e-12)
Expectation: Successfully execute AdamW with small epsilon and return non-None result
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), eps=1e-12)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_tuple_params():
"""
Feature: AdamW optimizer with tuple parameters
Description: Test AdamW initialization with tuple of parameters
Expectation: Successfully initialize AdamW with tuple of parameters
"""
net = SimpleNet()
params_tuple = tuple(net.trainable_params())
optimizer = AdamW(params_tuple)
assert optimizer is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_global_step_increase():
"""
Feature: AdamW optimizer global step
Description: Test AdamW global step attribute and step execution
Expectation: Verify global_step attribute exists and is a Tensor, and successfully execute one step
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Just verify that the global_step attribute exists and is a Tensor
assert hasattr(optimizer, 'global_step')
assert isinstance(optimizer.global_step, Tensor)

# Run one step to ensure the optimizer works correctly
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_mixed_precision_params():
"""
Feature: AdamW optimizer with mixed precision parameters
Description: Test AdamW execution with parameters of different precisions
Expectation: Successfully initialize AdamW with mixed precision parameters and execute one step
"""
# Create parameters with different precisions
params = [
Parameter(Tensor(np.ones([2, 3]), mstype.float32), name="fp32_param"),
Parameter(Tensor(np.ones([3]), mstype.float16), name="fp16_param")
]
optimizer = AdamW(params)
assert optimizer is not None

# Create dummy gradients with matching precisions
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float16)
)

# Run one step
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_optim_filter():
"""
Feature: AdamW optimizer with optim_filter
Description: Test AdamW execution with optim_filter set to include all parameters
Expectation: Successfully execute AdamW with optim_filter and return non-None result
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Set optim_filter to include all parameters
optimizer.optim_filter = (True, True)

# Run one step
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_without_group():
"""
Feature: AdamW optimizer construct method without group
Description: Test AdamW construct method with is_group set to False
Expectation: Successfully execute AdamW construct method with is_group=False
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Set is_group to False
optimizer.is_group = False

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Test construct
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_large_lr():
"""
Feature: AdamW optimizer with large learning rate
Description: Test AdamW execution with large learning rate (1.0)
Expectation: Successfully execute AdamW with large learning rate and return non-None result
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), learning_rate=1.0)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_tensor_lr():
"""
Feature: AdamW optimizer with Tensor learning rate
Description: Test AdamW initialization with Tensor learning rate
Expectation: Successfully initialize AdamW with Tensor learning rate
"""
net = SimpleNet()
lr_tensor = Tensor(np.array([0.001]), mstype.float32)
optimizer = AdamW(net.trainable_params(), learning_rate=lr_tensor)
assert optimizer is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_iterable_lr():
"""
Feature: AdamW optimizer with Iterable learning rate
Description: Test AdamW initialization with Iterable learning rate
Expectation: Successfully initialize AdamW with Iterable learning rate
"""
net = SimpleNet()
lr_iter = [0.001, 0.0009, 0.0008]
optimizer = AdamW(net.trainable_params(), learning_rate=lr_iter)
assert optimizer is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_optim_filter_false():
"""
Feature: AdamW optimizer with optim_filter=False
Description: Test AdamW execution with optim_filter set to False for some parameters
Expectation: Successfully execute AdamW with optim_filter=False and verify gradients are returned for
filtered parameters
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Set optim_filter to False for all parameters
optimizer.optim_filter = (False, False)

# Run one step
result = optimizer(gradients)
assert result is not None
assert len(result) == len(gradients)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_get_weight_decay_and_lr():
"""
Feature: AdamW optimizer get_weight_decay and get_lr methods
Description: Test AdamW get_weight_decay and get_lr methods
Expectation: Successfully call get_weight_decay and get_lr methods
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), weight_decay=0.01, learning_rate=0.001)

# Test get_weight_decay method
weight_decay = optimizer.get_weight_decay()
assert weight_decay is not None

# Test get_lr method
lr = optimizer.get_lr()
assert lr is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_cloned_obj():
"""
Feature: AdamW optimizer clone_state method with existing cloned_obj
Description: Test AdamW clone_state method when old_param.param_info already has cloned_obj
Expectation: Successfully clone state and append to existing cloned_obj list
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# First clone to create cloned_obj
first_clone = optimizer.clone_state("first", "zeros")

# Second clone should append to existing cloned_obj
second_clone = optimizer.clone_state("second", "zeros")

assert len(first_clone) == len(optimizer.parameters)
assert len(second_clone) == len(optimizer.parameters)

# Verify cloned_obj exists and has both clones
for old_param in optimizer.parameters:
assert hasattr(old_param.param_info, "cloned_obj")
assert len(old_param.param_info.cloned_obj) >= 2


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_all_branches():
"""
Feature: AdamW optimizer construct method branches
Description: Test all branches of AdamW construct method
Expectation: Successfully execute all branches of construct method
"""
# Test is_group=False branch
net = SimpleNet()
optimizer = AdamW(net.trainable_params())
optimizer.is_group = False
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)
result = optimizer(gradients)
assert result is not None

# Test is_group=True, is_group_lr=True branch
params = [
{'params': [net.weight], 'lr': 0.001},
{'params': [net.bias], 'lr': 0.0001}
]
optimizer = AdamW(params)
optimizer.is_group = True
optimizer.is_group_lr = True
result = optimizer(gradients)
assert result is not None

# Test is_group=True, is_group_lr=False branch
optimizer.is_group_lr = False
result = optimizer(gradients)
assert result is not None


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_more_cases():
"""
Feature: _check_param_value function with more cases
Description: Test _check_param_value function with more parameter combinations
Expectation: Successfully validate parameters or raise expected exceptions
"""
# Test with float weight_decay
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0.0
_check_param_value(betas, eps, weight_decay, "AdamW")

# Test with edge case betas
betas = (0.0001, 0.9999)
_check_param_value(betas, eps, weight_decay, "AdamW")

# Test with large eps
eps = 1e-3
_check_param_value(betas, eps, weight_decay, "AdamW")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_ones_init():
"""
Feature: AdamW clone_state method with ones init
Description: Test AdamW clone_state method with ones initialization
Expectation: Successfully clone state with ones init
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params())

# Test with 'ones' init
ones_clone = optimizer.clone_state("adam_m_ones", "ones")
assert len(ones_clone) == len(optimizer.parameters)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_large_learning_rate():
"""
Feature: AdamW optimizer with large learning rate
Description: Test AdamW execution with very large learning rate
Expectation: Successfully execute AdamW with large learning rate and update parameters significantly
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), learning_rate=1.0)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step
optimizer(gradients)

# Check if parameters are significantly updated
for param in net.trainable_params():
assert not np.allclose(param.asnumpy(), np.ones_like(param.asnumpy()), atol=0.1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_very_small_learning_rate():
"""
Feature: AdamW optimizer with very small learning rate
Description: Test AdamW execution with very small learning rate
Expectation: Successfully execute AdamW with very small learning rate
"""
net = SimpleNet()
optimizer = AdamW(net.trainable_params(), learning_rate=1e-10)

# Create dummy gradients
gradients = (
Tensor(np.ones([2, 3]), mstype.float32),
Tensor(np.ones([3]), mstype.float32)
)

# Run one step and check if it completes successfully
result = optimizer(gradients)
assert result is not None

+ 15  - 0   tests/st/test_optim/test_muon/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test muon optimizer."""

+ 63  - 0   tests/st/test_optim/test_muon/data_utils.py

@@ -0,0 +1,63 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Baseline data for Muon optimizer tests.
"""
import numpy as np

# Default tolerance for loss comparison
DEFAULT_RTOL = 1e-4
DEFAULT_ATOL = 1e-4

# Baseline losses for single card test cases
# learning_rate=0.02, weight_decay=0.1, momentum=0.95, nesterov=True
BASELINE_LOSSES_NESTEROV_TRUE = np.array([
0.3881023, 7.8122883, 15.039654, 22.062939, 28.884716,
35.514862, 41.940598, 48.178577, 54.222153, 60.07846,
65.739815, 71.20518, 76.508705, 81.63688, 86.58084,
91.356064, 95.94581, 100.37069, 104.620384, 108.72005
], dtype=np.float32)

# learning_rate=0.02, weight_decay=0.1, momentum=0.95, nesterov=False
BASELINE_LOSSES_NESTEROV_FALSE = np.array([
0.3881023, 7.8122883, 15.032751, 22.052126, 28.875042,
35.503002, 41.92948, 48.16231, 54.218227, 60.07244,
65.745224, 71.22119, 76.5374, 81.64788, 86.525246,
91.292816, 95.89634, 100.308716, 104.57111, 108.64668
], dtype=np.float32)

# learning_rate=0.01, weight_decay=0.05, momentum=0.9, nesterov=True
BASELINE_LOSSES_DIFF_LR = np.array([
0.3881023, 7.8966713, 15.322964, 22.66404, 29.917278,
37.085056, 44.168663, 51.175865, 58.094597, 64.92998,
71.680595, 78.34835, 84.92714, 91.44285, 97.866035,
104.204056, 110.46475, 116.63603, 122.729706, 128.74644
], dtype=np.float32)


def compare_losses(actual_losses, expected_losses, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
"""
Compare actual losses with expected baseline losses.

Args:
actual_losses (np.ndarray): Actual losses from the test run
expected_losses (np.ndarray): Expected baseline losses
rtol (float): Relative tolerance for comparison
atol (float): Absolute tolerance for comparison

Returns:
bool: True if losses match within tolerance, False otherwise
"""
return np.allclose(actual_losses, expected_losses, rtol=rtol, atol=atol)
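
For context, a minimal usage sketch of the helper above (editor's illustration, not part of the diff; the import path is assumed from the file location shown):

import numpy as np
from tests.st.test_optim.test_muon.data_utils import (
    BASELINE_LOSSES_NESTEROV_TRUE, compare_losses)

# Pretend these losses came from a training run; a tiny perturbation stays
# within the default rtol/atol of 1e-4, so the comparison passes.
actual = BASELINE_LOSSES_NESTEROV_TRUE + np.float32(5e-6)
assert compare_losses(actual, BASELINE_LOSSES_NESTEROV_TRUE)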

+ 236  - 0   tests/st/test_optim/test_muon/run_muon.py

@@ -0,0 +1,236 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run Muon optimizer accuracy test with configurable parameters via args"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

from mindformers.core.context.build_context import build_context
from mindformers.core.optim.muon import Muon

np.random.seed(1024)

# Test weight initialization - same as optimizer_util.py
FC1_WEIGHT = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
0.6942514, 0.39767185, 0.24918061, 0.4548748],
[0.7203382, 0.19086994, 0.76286614, 0.87920564,
0.3169892, 0.9462494, 0.62827677, 0.27504718],
[0.3544535, 0.2524781, 0.5370583, 0.8313121,
0.6670143, 0.0488653, 0.62225235, 0.7546456],
[0.17985944, 0.05106374, 0.31064633, 0.4863033,
0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")

FC1_BIAS = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")

FC2_WEIGHT = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")

FC2_BIAS = np.array([0.09996348]).astype("float32")


class MockTransformerConfig:
"""Mock transformer config for testing Muon optimizer."""
def __init__(self):
self.multi_latent_attention = True
self.tensor_model_parallel_size = 1
self.data_parallel_size = 1


class MockModel:
"""
Mock model class that provides required interfaces for Muon optimizer.
This simulates the model interface that Muon optimizer expects.
"""
def __init__(self):
self.config = MockTransformerConfig()

def get_gpt_transformer_config(self):
"""Return transformer config."""
return self.config

def make_model_muon_fns(self):
"""Return muon split and merge functions."""
def muon_split_fn(param_name, tensor): # pylint: disable=unused-argument
"""Split function - returns tensor as list."""
return [tensor]

def muon_merge_fn(param_name, tensor_list): # pylint: disable=unused-argument
"""Merge function - returns first tensor."""
return tensor_list[0]

return muon_split_fn, muon_merge_fn

# pylint: disable=unused-argument
def apply_qk_clip_scaling(self, params, param_names, param_layer, logit_threshold,
muon_split_fn, muon_merge_fn):
"""Apply query-key clipping scaling."""
return [(0, params[0])]

def get_param_layer_indices(self, params):
"""Return layer indices for parameters."""
return {p.name: 0 for p in params}

def get_muon_filter(self):
"""Return filter function to determine which params use Muon."""
def muon_filter(param):
# Apply Muon to weight parameters with 2D shape (not bias)
return len(param.shape) == 2 and 'bias' not in param.name
return muon_filter

def get_tp_dims(self, params):
"""Return tensor parallel dimensions."""
return tuple(-1 for _ in params)

def get_op_groups_info(self, params, op): # pylint: disable=unused-argument
"""Return optimizer parallel group info."""
ops = tuple(1 for _ in params)
op_groups = tuple("" for _ in params)
return ops, op_groups


class FakeNet(nn.Cell):
"""Build fake net for testing."""

def __init__(self):
super().__init__()
self.fc1 = nn.Dense(in_channels=8, out_channels=4,
weight_init=Tensor(FC1_WEIGHT),
bias_init=Tensor(FC1_BIAS))
self.fc2 = nn.Dense(in_channels=4, out_channels=1,
weight_init=Tensor(FC2_WEIGHT),
bias_init=Tensor(FC2_BIAS))
self.relu = nn.ReLU()

def construct(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x


class NetWithLoss(nn.Cell):
"""Build net with loss."""

def __init__(self, network, loss_fn):
super().__init__()
self.network = network
self.loss = loss_fn

def construct(self, x, label):
out = self.network(x)
loss = self.loss(out, label)
return loss


def make_fake_data():
"""Make fake data for testing."""
data, label = [], []
for i in range(20):
data.append(ms.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32)))
label.append(ms.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32)))
return data, label


class MuonRunner:
"""Class to manage Muon optimizer test and training."""

def __init__(self, args_from_parser):
self.args = args_from_parser
self.learning_rate = self.args.learning_rate
self.weight_decay = self.args.weight_decay
self.momentum = self.args.momentum
self.nesterov = self.args.nesterov
self.num_steps = self.args.num_steps

def build_network(self):
"""Build network with Muon optimizer."""
net = FakeNet()
mock_model = MockModel()

loss_fn = nn.L1Loss(reduction='mean')
networkwithloss = NetWithLoss(net, loss_fn)
networkwithloss.set_train()

params = networkwithloss.trainable_params()

# Create Muon optimizer
optimizer = Muon(
params=params,
learning_rate=self.learning_rate,
weight_decay=self.weight_decay,
matched_adamw_rms=0.2,
momentum=self.momentum,
nesterov=self.nesterov,
adamw_betas=(0.95, 0.95),
adamw_eps=1e-8,
model=mock_model,
)

return networkwithloss, optimizer, mock_model

def run(self):
"""Run the training with Muon optimizer."""
networkwithloss, optimizer, mock_model = self.build_network()
trainonestepcell = nn.TrainOneStepCell(networkwithloss, optimizer)

losses = []
data, label = make_fake_data()
for i in range(self.num_steps):
loss = trainonestepcell(data[i], label[i])
losses.append(loss.asnumpy())

# Save results
output_dict = {
"losses": np.array(losses),
"num_muon_m": len(optimizer.muon_m),
"num_moments1": len(optimizer.moments1),
"num_moments2": len(optimizer.moments2),
}

# Save muon momentum values for weight parameters
muon_filter = mock_model.get_muon_filter()
# pylint: disable=protected-access
for idx, param in enumerate(optimizer._parameters):
if muon_filter(param):
muon_m_value = optimizer.muon_m[idx].asnumpy()
output_dict[f"muon_m_{idx}"] = muon_m_value

np.savez(self.args.output_path, **output_dict)
print(f"Results saved to {self.args.output_path}")


def main():
parser = argparse.ArgumentParser(description="Run Muon optimizer test")
parser.add_argument("--learning_rate", type=float, default=0.02)
parser.add_argument("--weight_decay", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.95)
parser.add_argument("--nesterov", type=lambda x: x.lower() == "true", default=True)
parser.add_argument("--num_steps", type=int, default=20)
parser.add_argument("--output_path", type=str, default="output_muon.npz")

args = parser.parse_args()

# Set context
build_context({"use_legacy": False, "use_parallel": True})
ms.set_deterministic(True)
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_seed(42)

# Run training
runner = MuonRunner(args)
runner.run()


if __name__ == "__main__":
main()
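
As a rough illustration (not part of the diff), the artifact written by run_muon.py could be inspected like this, assuming the script was run with the default --output_path of output_muon.npz:

import numpy as np

# Load the results saved by MuonRunner.run() and mirror the checks the ST applies later.
out = np.load("output_muon.npz")
losses = out["losses"]
assert not np.any(np.isnan(losses)) and not np.any(np.isinf(losses))
print("steps:", len(losses), "muon momentum buffers:", int(out["num_muon_m"]))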

+ 202  - 0   tests/st/test_optim/test_muon/test_muon.py

@@ -0,0 +1,202 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing the Muon optimizer interface used for MindFormers.
How to run this:
pytest tests/st/test_optim/test_muon/test_muon.py
"""
from pathlib import Path
import subprocess
import pytest
import numpy as np

from tests.st.test_optim.test_muon.data_utils import (
BASELINE_LOSSES_NESTEROV_TRUE,
BASELINE_LOSSES_NESTEROV_FALSE,
BASELINE_LOSSES_DIFF_LR,
compare_losses,
DEFAULT_RTOL,
DEFAULT_ATOL,
)

from mindformers.tools.logger import logger

# Test parameter definitions
SINGLE_CARD_TEST_CASES = [
# Default config with nesterov=True
{
"learning_rate": 0.02,
"weight_decay": 0.1,
"momentum": 0.95,
"nesterov": True,
"num_steps": 20,
"baseline_losses": BASELINE_LOSSES_NESTEROV_TRUE,
},
# Config without Nesterov momentum
{
"learning_rate": 0.02,
"weight_decay": 0.1,
"momentum": 0.95,
"nesterov": False,
"num_steps": 20,
"baseline_losses": BASELINE_LOSSES_NESTEROV_FALSE,
},
# Config with different learning rate
{
"learning_rate": 0.01,
"weight_decay": 0.05,
"momentum": 0.9,
"nesterov": True,
"num_steps": 20,
"baseline_losses": BASELINE_LOSSES_DIFF_LR,
},
]
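
A hypothetical sketch (editor's illustration, not part of the diff) of how a further single-card configuration could be registered; the baseline here is a placeholder, not a measured result:

# Example only: another setting with its own baseline; None skips the baseline comparison.
EXTRA_CASE = {
    "learning_rate": 0.005,
    "weight_decay": 0.1,
    "momentum": 0.95,
    "nesterov": True,
    "num_steps": 20,
    "baseline_losses": None,  # would hold losses measured via run_muon.py
}
# SINGLE_CARD_TEST_CASES.append(EXTRA_CASE)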


def build_msrun_command_list(
worker_num,
local_worker_num,
log_dir,
run_script_path,
learning_rate,
weight_decay,
momentum,
nesterov,
num_steps,
output_path,
port=29500
):
"""Build the msrun command with the specified parameters."""
cmd_list = [
"msrun",
f"--worker_num={worker_num}",
f"--local_worker_num={local_worker_num}",
f"--master_port={port}",
f"--log_dir={log_dir}",
"--join=True",
str(run_script_path),
f"--learning_rate={learning_rate}",
f"--weight_decay={weight_decay}",
f"--momentum={momentum}",
f"--nesterov={str(nesterov).lower()}",
f"--num_steps={num_steps}",
f"--output_path={output_path}",
]
logger.info(f"Equivalent shell command for Muon test: {' '.join(cmd_list)}")
return cmd_list


class TestMuon:
"""Test class for Muon optimizer with different configurations."""
OUTPUT_FILENAME = "output_muon.npz"
LOG_DIR_NAME = "msrun_log"

def setup_method(self):
"""Setup method to prepare test environment."""
self.sh_path = Path(__file__).parent.resolve()
self.run_script_path = self.sh_path / "run_muon.py"

def check_results(self, output_dict, baseline_losses=None):
"""
Check the output results from the Muon optimizer run.

Args:
output_dict: Dictionary containing the output results
baseline_losses: Expected baseline losses for comparison
"""
# Check losses
losses = output_dict.get("losses")
assert losses is not None, "Losses not found in output"
assert len(losses) > 0, "Losses array is empty"
assert not np.any(np.isnan(losses)), "Losses contain NaN values"
assert not np.any(np.isinf(losses)), "Losses contain Inf values"

# Compare with baseline if provided
if baseline_losses is not None:
assert compare_losses(losses, baseline_losses, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL), (
f"Losses do not match baseline.\n"
f"Actual: {losses}\n"
f"Expected: {baseline_losses}\n"
f"Max diff: {np.max(np.abs(losses - baseline_losses))}"
)

def run_test(
self,
worker_num,
local_worker_num,
optimizer_args,
tmp_path,
port=29500,
baseline_losses=None
):
"""Helper function to run test and check results."""
output_file_path = tmp_path / self.OUTPUT_FILENAME
log_dir_path = tmp_path / self.LOG_DIR_NAME
log_dir_path.mkdir(parents=True, exist_ok=True)

cmd_list = build_msrun_command_list(
worker_num=worker_num,
local_worker_num=local_worker_num,
log_dir=log_dir_path,
run_script_path=self.run_script_path,
learning_rate=optimizer_args["learning_rate"],
weight_decay=optimizer_args["weight_decay"],
momentum=optimizer_args["momentum"],
nesterov=optimizer_args["nesterov"],
num_steps=optimizer_args["num_steps"],
output_path=output_file_path,
port=port
)

result = subprocess.run(
cmd_list, shell=False, capture_output=True, text=True, check=False
)

assert result.returncode == 0, (
f"Test script failed with non-zero exit code: "
f"{result.returncode}.\nStdout:\n{result.stdout}\nStderr:\n{result.stderr}"
)
assert output_file_path.exists(), (
f"Output file {output_file_path} was not created."
)

output_dict = np.load(output_file_path)
self.check_results(output_dict, baseline_losses=baseline_losses)

return output_dict


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
class TestMuonSingleCard(TestMuon):
"""Test class for Muon optimizer with single card configurations."""

@pytest.mark.parametrize("optimizer_args", SINGLE_CARD_TEST_CASES)
def test_muon_single_card(self, optimizer_args, tmp_path):
"""
Feature: Muon optimizer training
Description: Test computation of Muon optimizer with various configurations.
Expectation: Training completes successfully with valid losses matching baseline
"""
baseline_losses = optimizer_args.get("baseline_losses")
self.run_test(
worker_num=1,
local_worker_num=1,
optimizer_args=optimizer_args,
tmp_path=tmp_path,
baseline_losses=baseline_losses
)

+ 29  - 0   tests/st/test_run_check.py

@@ -0,0 +1,29 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for run_check function"""
import pytest
from mindformers import run_check


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_run_check():
"""
Feature: Test run_check function
Description: Call run_check to check if MindSpore, MindFormers, CANN and driver versions are compatible
Expectation: No exceptions raised, all checks pass
"""
run_check()

+ 1126  - 50   tests/st/test_safetensors/test_checkpoint_utils.py

@@ -13,17 +13,101 @@
# limitations under the License.
# ============================================================================
"""test for load_checkpoint_utils."""
# pylint: disable=W0621
import os
import json
import tempfile
from unittest.mock import patch, MagicMock

import pytest
import numpy as np
from mindspore import Parameter

from mindformers.tools.register import MindFormerConfig
from mindformers.checkpoint.utils import compile_model
from mindformers.utils.load_checkpoint_utils import CkptFormat, _get_checkpoint_mode, CheckpointFileMode, \
_check_checkpoint_path

from mindformers.checkpoint.utils import compile_model, check_checkpoints_dir_max_num, get_checkpoint_iter_dir, \
get_checkpoint_tracker_filename, get_common_filename, get_metadata_filename, \
get_latest_iteration_from_tracker, get_checkpoint_name, get_sharded_tensor_shard_id, \
sharded_tensor_shard_id, _reverse_sharded_tensor_shard_id, _get_shard_size, verify_ckpt_valid, FileType
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.utils.load_checkpoint_utils import (
CkptFormat, _get_checkpoint_mode, CheckpointFileMode, _check_checkpoint_path,
extract_suffix, get_last_checkpoint, validate_config_with_file_mode,
update_global_step, unify_safetensors, _revise_remove_redundancy_with_file,
_get_origin_network, get_load_path_after_hf_convert, _get_src_strategy,
_get_src_file_suffix, _get_src_file, load_safetensors_checkpoint,
process_hf_checkpoint, validate_qkv_concat, get_merged_src_strategy_path,
get_merged_dst_strategy_path, process_for_stand_alone_mode,
load_checkpoint_with_safetensors
)


@pytest.fixture
def mock_config():
"""Create a mock config with default values"""

class MockConfig:
"""Mock configuration class for testing"""

def __init__(self):
self.load_checkpoint = "/path/to/checkpoint"
self.load_ckpt_format = "safetensors"
self.use_parallel = False
self.auto_trans_ckpt = False
self.resume_training = None
self.remove_redundancy = False
self.output_dir = "/output"
self.src_strategy_path_or_dir = None
self.load_ckpt_async = False
self.context = type('', (), {})()
self.context.mode = "GRAPH_MODE"
self.runner_config = type('', (), {})()
self.runner_config.sink_mode = True
self.runner_config.epochs = 1
self.runner_config.sink_size = 1
self.runner_config.step_scale = 2.0
self.model = type('', (), {})()
self.model.model_config = {}
self.parallel = type('', (), {})()
self.parallel.parallel_mode = "DATA_PARALLEL"

def get(self, key, default=None):
return getattr(self, key, default)

return MockConfig()


@pytest.fixture
def mock_network():
"""Create a mock network"""
mock_net = MagicMock()
mock_net.cells.return_value = []
return mock_net


@pytest.fixture
def mock_model():
"""Create a mock model"""
mock_mod = MagicMock()
mock_mod.config = MagicMock()
mock_mod.config.model_type = "test_model"
return mock_mod


@pytest.fixture
def mock_file():
"""Create a mock file"""
mock_f = MagicMock()
mock_f.metadata.return_value = None
return mock_f


class TestCommonCheckpointMethod:
"""A test class for testing common methods"""

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_support_type(self):
"""test CkptFormat support type"""
# run the test
@@ -32,25 +116,713 @@ class TestCommonCheckpointMethod:
# verify the results
assert result == ['ckpt', 'safetensors']

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_non_string_pathlike(self):
"""test check checkpoint path with non string pathlike"""
path = 123
with pytest.raises(ValueError,
match=r"config.load_checkpoint must be a str, but got 123 as type <class 'int'>."):
match=r"config.load_checkpoint must be a `str`, but got `123` as type `<class 'int'>`."):
_check_checkpoint_path(path)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_nonexistent_path(self):
"""test check checkpoint path with nonexistent path"""
path = 'NoneExistPath'
with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint NoneExistPath does not exist."):
with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint `NoneExistPath` does not exist."):
_check_checkpoint_path(path)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_valid_path(self):
"""test check checkpoint path with valid path"""
# create a temporary directory for testing
with tempfile.TemporaryDirectory() as tmpdir:
# test with directory path
result = _check_checkpoint_path(tmpdir)
assert result == tmpdir

# test with directory path ending with slash
result = _check_checkpoint_path(tmpdir + '/')
assert result == tmpdir

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"file_path, expected",
[
# test pattern 1: {prefix}_rank_{rank_id}-{epoch}_{step}.safetensors
("model_rank_0-10_200.safetensors", "-10_200"),
# test pattern 2: {prefix}_rank_{rank_id}_{task_id}-{epoch}_{step}.safetensors
("model_rank_0_1-10_200.safetensors", "_1-10_200"),
# test with invalid pattern
("invalid_filename.safetensors", "invalid_filename")
]
)
def test_extract_suffix(self, file_path, expected):
"""test extract_suffix function"""
result = extract_suffix(file_path)
assert result == expected

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_last_checkpoint(self):
"""test get_last_checkpoint function"""
# setup mocks using context managers
with patch('os.path.isdir') as mock_isdir, \
patch('os.path.exists') as mock_exists, \
patch('os.listdir') as mock_listdir, \
patch('os.path.getmtime') as mock_getmtime:
# setup mock return values
mock_isdir.return_value = True
mock_exists.return_value = True
mock_listdir.return_value = ["model_0.ckpt", "model_1.ckpt", "model_2.ckpt"]
mock_getmtime.side_effect = lambda x: {
"/test/model_0.ckpt": 100,
"/test/model_1.ckpt": 200,
"/test/model_2.ckpt": 300
}[x]

# test with valid directory
result = get_last_checkpoint("/test", "ckpt")
assert result == "/test/model_2.ckpt"

# test with no checkpoint files
mock_listdir.return_value = ["other_file.txt"]
result = get_last_checkpoint("/test", "ckpt")
assert result is None

# test with invalid directory
mock_isdir.return_value = False
with pytest.raises(NotADirectoryError):
get_last_checkpoint("/invalid/dir", "ckpt")

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"file_mode, use_parallel, auto_trans_ckpt, expected_exception",
[
# test single checkpoint file mode with parallel
(CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, True, False, ValueError),
# test multi checkpoint file mode with parallel but no auto_trans_ckpt
(CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, False, ValueError),
# test multi checkpoint file with rank id mode without parallel
(CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, False, False, ValueError),
# test invalid mode
("invalid_mode", False, False, ValueError),
# test valid cases - no exception expected
(CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, False, False, None),
(CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, True, None),
(CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, True, False, None)
]
)
def test_validate_config_with_file_mode(self, file_mode, use_parallel, auto_trans_ckpt, expected_exception):
"""test validate_config_with_file_mode function"""
if expected_exception:
with pytest.raises(expected_exception):
validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)
else:
validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"step_scale, initial_global_step, expected_global_step, expected_in_dict",
[
(2.0, 100, 200, True),
(None, 100, 100, True),
(2.0, None, None, False)
]
)
def test_update_global_step(self, step_scale, initial_global_step, expected_global_step, expected_in_dict):
"""test update_global_step function"""
# setup config
config = type('', (), {})()
config.runner_config = type('', (), {})()
config.runner_config.step_scale = step_scale

# setup hyper_param_dict
hyper_param_dict = {}
if initial_global_step is not None:
hyper_param_dict["global_step"] = Parameter(np.array(initial_global_step, dtype=np.int32))

# test update_global_step
update_global_step(config, hyper_param_dict)

# verify the results
if expected_in_dict:
assert "global_step" in hyper_param_dict
assert hyper_param_dict["global_step"].asnumpy() == expected_global_step
else:
assert "global_step" not in hyper_param_dict

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unify_safetensors(self):
"""test unify_safetensors function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindspore.unified_safetensors') as mock_unified_safetensors:
# test when is_main_rank is True
mock_is_main_rank.return_value = True
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False)
mock_unified_safetensors.assert_called_once()
mock_barrier.assert_called_once()

# test when is_main_rank is False
mock_is_main_rank.return_value = False
mock_barrier.reset_mock()
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False)
mock_unified_safetensors.assert_called_once() # should not be called again
mock_barrier.assert_called_once()

# test without parallel
mock_is_main_rank.return_value = True
mock_barrier.reset_mock()
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", False, "-10_200", False)
mock_barrier.assert_not_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"config_remove_redundancy, metadata, expected_result",
[
# test with metadata remove_redundancy=True and config remove_redundancy=False
(False, {"remove_redundancy": "True"}, True),
# test with metadata remove_redundancy=False and config remove_redundancy=True
(True, {"remove_redundancy": "False"}, False),
# test with matching metadata and config
(True, {"remove_redundancy": "True"}, True),
# test with no metadata
(True, None, True),
# test with metadata but no remove_redundancy key
(True, {"other_key": "value"}, True)
]
)
def test__revise_remove_redundancy_with_file(self, config_remove_redundancy, metadata, expected_result, mock_file):
"""test _revise_remove_redundancy_with_file function"""
mock_file.metadata.return_value = metadata
result = _revise_remove_redundancy_with_file(config_remove_redundancy, mock_file)
assert result == expected_result

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"network_has_convert_name, child_has_convert_name, expected_found",
[
# test with network that has convert_name
(True, False, True),
# test with nested network where child has convert_name
(False, True, True),
# test with network that doesn't have convert_name and no children with it
(False, False, False)
]
)
def test__get_origin_network(self, network_has_convert_name, child_has_convert_name, expected_found):
"""test _get_origin_network function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.logger'):
if network_has_convert_name:
# create a mock network with convert_name attribute
mock_network = MagicMock()
mock_network.convert_name = MagicMock()
# Return empty list for cells() to avoid recursion
mock_network.cells.return_value = []
else:
if child_has_convert_name:
# create a mock network without convert_name but with a child that has it
mock_child = MagicMock()
mock_child.convert_name = MagicMock()
# Return empty list for cells() to avoid further recursion
mock_child.cells.return_value = []

# Create a network that returns the child directly when cells() is called
mock_network = MagicMock()
mock_network.cells.return_value = [mock_child]
else:
# create a mock network without convert_name and no children with it
mock_network = MagicMock()
mock_network.cells.return_value = []

# Remove convert_name attribute to simulate network without it
if hasattr(mock_network, 'convert_name'):
delattr(mock_network, 'convert_name')

# run the test
_, found = _get_origin_network(mock_network)
assert found == expected_found

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_load_path_after_hf_convert(self, mock_config, mock_network):
"""test get_load_path_after_hf_convert function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir, \
patch('mindformers.utils.load_checkpoint_utils.'
'check_safetensors_addition_param_support') as mock_check_support:
# test when not hf safetensors
mock_is_hf_safetensors_dir.return_value = False
result = get_load_path_after_hf_convert(mock_config, mock_network)
assert result == "/path/to/checkpoint"

# test when hf safetensors but not qkv_concat and not supported
mock_is_hf_safetensors_dir.return_value = True
mock_check_support.return_value = False
mock_config.model.model_config = {"qkv_concat": False}

with patch('mindformers.utils.load_checkpoint_utils.process_hf_checkpoint',
return_value="/path/to/converted"):
with patch('mindformers.utils.load_checkpoint_utils.barrier'):
result = get_load_path_after_hf_convert(mock_config, mock_network)
assert result == "/path/to/converted"

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_strategy(self, mock_config):
"""test _get_src_strategy function"""
# setup mocks using context managers
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir, \
patch('os.path.join') as mock_join, \
patch('os.path.exists') as mock_exists, \
patch('os.path.dirname') as mock_dirname, \
patch('mindformers.utils.load_checkpoint_utils.logger'):
# Test case 1: input_src_strategy is provided
mock_config.load_checkpoint = "/test/checkpoint.ckpt"
mock_config.src_strategy_path_or_dir = "/input/strategy"
mock_isdir.return_value = True
result = _get_src_strategy(mock_config)
assert result == "/input/strategy"

# Test case 2: no strategy dir exists
mock_config.src_strategy_path_or_dir = None
mock_isfile.return_value = True
mock_exists.return_value = False

with pytest.raises(
ValueError,
match="when use checkpoint after train/finetune, src_strategy_path_or_dir should be set"
):
_get_src_strategy(mock_config)

# Test case 3: config.load_checkpoint is a directory and strategy dir exists
mock_isfile.return_value = False
mock_exists.return_value = True

# Setup mock_dirname to return a valid parent directory
mock_dirname.return_value = "/test"

# Setup mock_join to return a valid path
mock_join.return_value = "/test/strategy"

mock_config.load_checkpoint = "/test/checkpoint_dir"
mock_config.src_strategy_path_or_dir = None

result = _get_src_strategy(mock_config)
assert result == "/test/strategy"

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file_suffix(self, mock_config):
"""test _get_src_file_suffix function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint, \
patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
# test when is_main_rank is True and resume_training is string
mock_is_main_rank.return_value = True
mock_config.resume_training = "checkpoint-10_200.safetensors"
mock_config.load_checkpoint = "/path/to/checkpoint"
mock_config.load_ckpt_format = "safetensors"

with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to/checkpoint", "-10_200")

# test when is_main_rank is True and load_checkpoint is file
mock_isfile.return_value = True
mock_isdir.return_value = False
mock_config.resume_training = None
mock_config.load_checkpoint = "/path/to/rank_0/checkpoint-10_200.safetensors"

with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to", "-10_200")

# test when is_main_rank is True and load_checkpoint is dir
mock_isfile.return_value = False
mock_isdir.return_value = True
mock_config.load_checkpoint = "/path/to/checkpoint"
mock_get_last_checkpoint.return_value = "/path/to/checkpoint/rank_0/checkpoint-10_200.safetensors"

with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to/checkpoint", "-10_200")

# test when is_main_rank is False
mock_is_main_rank.return_value = False
result = _get_src_file_suffix(mock_config)
assert result == (None, None)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('mindformers.utils.load_checkpoint_utils.logger')
def test__get_src_file(self, mock_logger):
"""test _get_src_file function"""
# setup mocks using context managers
with patch('os.path.exists') as mock_exists, \
patch('os.path.join') as mock_join, \
patch('mindformers.utils.load_checkpoint_utils.get_real_rank') as mock_get_real_rank, \
patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint:
# test with checkpoint_name provided
mock_get_real_rank.return_value = 0
mock_join.return_value = "/test/rank_0/checkpoint.ckpt"
mock_exists.return_value = True

result = _get_src_file("/test", "checkpoint.ckpt", "ckpt")
assert result == "/test/rank_0/checkpoint.ckpt"

# test without checkpoint_name
mock_get_last_checkpoint.return_value = "/test/rank_0/last_checkpoint.ckpt"
result = _get_src_file("/test", None, "ckpt")
assert result == "/test/rank_0/last_checkpoint.ckpt"

# test with non-existent file
mock_exists.return_value = False
with pytest.raises(FileNotFoundError):
_get_src_file("/test", "non_existent.ckpt", "ckpt")

# Verify that logger.error has been called.
mock_logger.error.assert_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_safetensors_checkpoint(self, mock_config, mock_network, mock_file):
"""test load_safetensors_checkpoint function"""
# Setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._get_origin_network') as mock_get_origin_network, \
patch('mindformers.utils.load_checkpoint_utils.ms') as mock_ms, \
patch('mindformers.utils.load_checkpoint_utils.logger'), \
patch('mindformers.utils.load_checkpoint_utils.safe_open') as mock_safe_open, \
patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir:
# Setup mock return values
mock_get_origin_network.return_value = (MagicMock(), False)
mock_ms.load_checkpoint.return_value = {"param1": MagicMock()}
mock_is_hf_safetensors_dir.return_value = False

# Mock the safe_open context manager
mock_safe_open.return_value.__enter__.return_value = mock_file

strategy_path = "/path/to/strategy"
load_ckpt_path = "/path/to/checkpoint"
optimizer = None

load_safetensors_checkpoint(mock_config, ["/path/to/checkpoint.safetensors"], mock_network, strategy_path,
load_ckpt_path,
optimizer)
mock_ms.load_param_into_net.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_hf_checkpoint(self, mock_model, tmp_path):
"""test process_hf_checkpoint function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier_world') as mock_barrier_world, \
patch('mindformers.utils.load_checkpoint_utils.Process') as mock_process:
# test when is_main_rank is True
mock_is_main_rank.return_value = True
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 0
mock_process.return_value = mock_process_instance

# Use tmp_path for output and input paths
output_dir = tmp_path / "output" / "dir"
input_checkpoint = tmp_path / "input" / "checkpoint"
# Create input directory
input_checkpoint.parent.mkdir(parents=True, exist_ok=True)

result = process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
expected_path = str(output_dir / "test_model_ms_converted_weight")
assert result == expected_path
mock_process_instance.start.assert_called_once()
mock_process_instance.join.assert_called_once()
mock_barrier_world.assert_called_once()

# Reset mocks for next test case
mock_process.reset_mock()
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 1
mock_process.return_value = mock_process_instance

# test when process exits with error
with pytest.raises(RuntimeError, match="convert HuggingFace weight failed."):
process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))

# Reset mocks for next test case
mock_process.reset_mock()
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 0
mock_process.return_value = mock_process_instance

# test when is_main_rank is False
mock_is_main_rank.return_value = False
process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
mock_process_instance.start.assert_not_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"model, qkv_concat_config, check_safetensors_key_return, "
"has_concat_keys, expected_exception, should_log_warning",
[
# test with non-PreTrainedModel
("not_a_model", False, False, False, None, True),
# test with PreTrainedModel but no concat keys
(MagicMock(spec=PreTrainedModel), False, False, False, None, False),
# Test case where check_safetensors_key returns True and qkv_concat_config is True
(MagicMock(spec=PreTrainedModel), True, True, True, None, False),
# Test case where check_safetensors_key returns False and qkv_concat_config is True
(MagicMock(spec=PreTrainedModel), True, False, True, ValueError, False),
# Test case where check_safetensors_key returns True and qkv_concat_config is False
(MagicMock(spec=PreTrainedModel), False, True, True, ValueError, False)
]
)
def test_validate_qkv_concat(self, model, qkv_concat_config,
check_safetensors_key_return, has_concat_keys, expected_exception, should_log_warning):
"""test validate_qkv_concat function"""
# Setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.logger') as mock_logger, \
patch('mindformers.utils.load_checkpoint_utils.check_safetensors_key') as mock_check_safetensors_key:

# Setup mock behavior
mock_check_safetensors_key.return_value = check_safetensors_key_return

# If it's a PreTrainedModel, set up obtain_qkv_ffn_concat_keys
if hasattr(model, 'obtain_qkv_ffn_concat_keys'):
model.obtain_qkv_ffn_concat_keys.return_value = ["qkv_concat_key"] if has_concat_keys else None

# Run the test and check results
if expected_exception:
with pytest.raises(expected_exception, match="The qkv concat check failed!"):
validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")
else:
validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")

# Check if warning was logged when expected
if should_log_warning:
mock_logger.warning.assert_called_once()
else:
mock_logger.warning.assert_not_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_src_strategy_path(self, mock_config):
"""test get_merged_src_strategy_path function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils._get_src_strategy') as mock_get_src_strategy, \
patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \
patch('os.makedirs'):
# test when is_main_rank is True
mock_is_main_rank.return_value = True
mock_get_src_strategy.return_value = "/input/strategy"

result = get_merged_src_strategy_path(mock_config)
assert result == "/output/merged_strategy/src_strategy.ckpt"
mock_merge_strategys.assert_called_once()
mock_barrier.assert_called_once()

# test when is_main_rank is False
mock_is_main_rank.return_value = False
mock_barrier.reset_mock()
result = get_merged_src_strategy_path(mock_config)
mock_merge_strategys.assert_called_once() # should not be called again
mock_barrier.assert_called_once()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_dst_strategy_path(self, mock_config):
"""test get_merged_dst_strategy_path function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \
patch('os.makedirs'):
# test with use_parallel=True, auto_trans_ckpt=True, not stand_alone
mock_is_main_rank.return_value = True

mock_config.use_parallel = True
mock_config.auto_trans_ckpt = True
mock_config.parallel.parallel_mode = "DATA_PARALLEL"

strategy_path = "/path/to/strategy.ckpt"

result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/output/merged_strategy/dst_strategy.ckpt"
mock_merge_strategys.assert_called_once()
mock_barrier.assert_called_once()

# test with stand_alone mode
mock_config.parallel.parallel_mode = "STAND_ALONE"
result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/path/to/strategy.ckpt"

# test with use_parallel=False
mock_config.use_parallel = False
result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/path/to/strategy.ckpt"

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_for_stand_alone_mode(self, mock_config, mock_network):
"""test process_for_stand_alone_mode function"""
strategy_path = "/path/to/strategy.ckpt"

# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._pynative_executor'), \
patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils.generate_state_dict') as mock_generate_state_dict, \
patch('mindformers.utils.load_checkpoint_utils.save_strategy_file') as mock_save_strategy_file, \
patch('os.makedirs') as mock_makedirs, \
patch('shutil.rmtree') as mock_rmtree, \
patch('os.path.exists') as mock_exists:
# test with stand_alone mode
mock_is_main_rank.return_value = True
mock_exists.return_value = True
mock_config.parallel.parallel_mode = "STAND_ALONE"
mock_config.use_parallel = True

process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_called_once()
mock_makedirs.assert_called_once()
mock_generate_state_dict.assert_called_once()
mock_save_strategy_file.assert_called_once()
mock_barrier.assert_called()

# Reset mocks for next test case
mock_barrier.reset_mock()
mock_rmtree.reset_mock()
mock_makedirs.reset_mock()
mock_generate_state_dict.reset_mock()
mock_save_strategy_file.reset_mock()

# test when strategy dir doesn't exist
mock_exists.return_value = False
process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_not_called()

# Reset mocks for next test case
mock_barrier.reset_mock()
mock_rmtree.reset_mock()
mock_makedirs.reset_mock()
mock_generate_state_dict.reset_mock()
mock_save_strategy_file.reset_mock()

# test when not stand_alone mode
mock_config.parallel.parallel_mode = "DATA_PARALLEL"
process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_not_called()
mock_barrier.assert_not_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_checkpoint_with_safetensors(self, mock_config, mock_model, mock_network):
"""test load_checkpoint_with_safetensors function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._check_checkpoint_path') as mock_check_checkpoint_path, \
patch('mindformers.utils.load_checkpoint_utils._get_checkpoint_mode') as mock_get_checkpoint_mode, \
patch('mindformers.utils.load_checkpoint_utils.'
'validate_config_with_file_mode') as mock_validate_config_with_file_mode, \
patch('mindformers.utils.load_checkpoint_utils.compile_model') as mock_compile_model, \
patch('mindformers.utils.load_checkpoint_utils.validate_qkv_concat'), \
patch('mindformers.utils.load_checkpoint_utils.process_for_stand_alone_mode'), \
patch('mindformers.utils.load_checkpoint_utils.'
'get_merged_dst_strategy_path') as mock_get_merged_dst_strategy_path, \
patch('mindformers.utils.load_checkpoint_utils.'
'load_safetensors_checkpoint') as mock_load_safetensors_checkpoint, \
patch('mindformers.utils.load_checkpoint_utils.logger'), \
patch('mindformers.utils.load_checkpoint_utils.barrier'):
# setup mocks return values
mock_check_checkpoint_path.return_value = "/valid/checkpoint"
mock_get_checkpoint_mode.return_value = CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
mock_get_merged_dst_strategy_path.return_value = "/path/to/merged/strategy"

# setup input_data and optimizer
input_data = MagicMock()
optimizer = None

# test with do_eval=True
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=True,
do_predict=False,
optimizer=optimizer)
mock_check_checkpoint_path.assert_called_once()
mock_get_checkpoint_mode.assert_called_once()
mock_validate_config_with_file_mode.assert_called_once()
mock_load_safetensors_checkpoint.assert_called_once()

# test with do_predict=True
mock_load_safetensors_checkpoint.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=True,
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()

# test with use_parallel=True
mock_config.use_parallel = True
mock_load_safetensors_checkpoint.reset_mock()
mock_compile_model.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=False,
optimizer=optimizer)
mock_compile_model.assert_called_once()
mock_load_safetensors_checkpoint.assert_called_once()

# test with resume_training=True
mock_config.resume_training = True
# Access protected member for testing purposes
# pylint: disable=W0212
mock_model._train_network = MagicMock()
mock_load_safetensors_checkpoint.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=False,
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()


class TestBuildModel:
"""A test class for testing build_model"""
runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1}
config = {'runner_config': runner_config}
config = {
'runner_config': runner_config,
'context': {'mode': 0} # 0 is typically ms.GRAPH_MODE, 1 is ms.PYNATIVE_MODE
}
model = MagicMock()
dataset = MagicMock()

@@ -92,18 +864,16 @@ class TestBuildModel:
"""test build model infer predict layout when do predict is true"""
mock_get_auto_parallel_context.return_value = 'auto_parallel'
config = MindFormerConfig(**self.config)
model = MagicMock()
dataset = MagicMock()
compile_model(
model=model,
dataset=dataset,
model=self.model,
dataset=self.dataset,
mode=config.context.mode,
sink_mode=config.runner_config.sink_mode,
epoch=config.runner_config.epochs,
sink_size=config.runner_config.sink_size,
do_eval=False, do_predict=True
)
model.infer_predict_layout.assert_called_once_with(*dataset)
self.model.infer_predict_layout.assert_called_once_with(*self.dataset)

@patch('mindspore.context.get_auto_parallel_context')
def test_build_model_model_build(self, mock_get_auto_parallel_context):
@@ -124,6 +894,10 @@ class TestBuildModel:

class TestGetCheckpointMode:
"""A test class for testing get_checkpoint_mode"""

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_single_checkpoint_file(self, mock_isdir, mock_isfile):
@@ -134,49 +908,351 @@ class TestGetCheckpointMode:
config.load_checkpoint = '/test/checkpoint_file.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value

@patch('os.path.isfile')
@patch('os.path.isdir')
def test_multi_checkpoint_file_with_rank_id(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file_with_rank_id(self):
"""test multi checkpoint file with rank id"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['rank_0']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['rank_0']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value

@patch('os.path.isfile')
@patch('os.path.isdir')
def test_multi_checkpoint_file(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file(self):
""" test multi checkpoint file"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['checkpoint.safetensors']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['checkpoint.safetensors']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value

@patch('os.path.isfile')
@patch('os.path.isdir')
def test_invalid_path(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_invalid_path(self):
"""test invalid path"""
mock_isfile.return_value = False
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = 'invalid_path'
with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."):
_get_checkpoint_mode(config)

@patch('os.path.isfile')
@patch('os.path.isdir')
def test_no_valid_checkpoint_files(self, mock_isdir, mock_isfile):
"""test no valid checkpoint files"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['not_a_checkpoint_file']):
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
config.load_checkpoint = 'invalid_path'
with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."):
_get_checkpoint_mode(config)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_no_valid_checkpoint_files(self):
"""test no valid checkpoint files"""
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['not_a_checkpoint_file']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
_get_checkpoint_mode(config)


class TestCheckpointUtils:
"""A test class for testing checkpoint utils functions"""

# Test get_checkpoint_iter_dir function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_iter_dir(self):
"""test get_checkpoint_iter_dir function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_checkpoint_iter_dir(checkpoints_path, iteration)
# Use os.path.normpath to handle different path separators
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123')

# Test with different iteration format
iteration = 1000
result = get_checkpoint_iter_dir(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00001000')

# Test get_checkpoint_tracker_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_tracker_filename(self):
"""test get_checkpoint_tracker_filename function"""
checkpoints_path = '/test/checkpoints'
result = get_checkpoint_tracker_filename(checkpoints_path)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/latest_checkpointed_iteration.txt')

# Test get_common_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_common_filename(self):
"""test get_common_filename function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_common_filename(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123/common.json')

# Test get_metadata_filename function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_metadata_filename(self):
"""test get_metadata_filename function"""
checkpoints_path = '/test/checkpoints'
iteration = 123
result = get_metadata_filename(checkpoints_path, iteration)
assert os.path.normpath(result) == os.path.normpath('/test/checkpoints/iteration_00000123/metadata.json')

# Test get_checkpoint_name function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_checkpoint_name(self):
"""test get_checkpoint_name function"""
# Test with cur_iter_checkpoint_dir and user_prefix
cur_iter_checkpoint_dir = '/test/checkpoints/iteration_00000123'
user_prefix = 'model'
file_idx = 0
total_file_num = 1
file_type = FileType.MODEL
result = get_checkpoint_name(cur_iter_checkpoint_dir, user_prefix, file_idx, total_file_num, file_type)
expected = '/test/checkpoints/iteration_00000123/model-model-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test with optimizer type
file_type = FileType.OPTIMIZER
result = get_checkpoint_name(cur_iter_checkpoint_dir, user_prefix, file_idx, total_file_num, file_type)
expected = '/test/checkpoints/iteration_00000123/model-opt-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test without user_prefix
result = get_checkpoint_name(cur_iter_checkpoint_dir, None, file_idx, total_file_num, FileType.MODEL)
expected = '/test/checkpoints/iteration_00000123/model-0000000-0000001'
assert os.path.normpath(result) == os.path.normpath(expected)

# Test without cur_iter_checkpoint_dir
result = get_checkpoint_name(None, None, file_idx, total_file_num, FileType.MODEL)
expected = 'model-0000000-0000001'
assert result == expected

# Test sharded tensor related functions
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sharded_tensor_shard_id_functions(self):
"""test sharded tensor shard id functions"""
param_name = 'model.layer.weight'
global_offset = (100, 200)

# Test get_sharded_tensor_shard_id
shard_id1 = get_sharded_tensor_shard_id(param_name, global_offset)
expected = "('model.layer.weight', (100, 200))"
assert shard_id1 == expected

# Test sharded_tensor_shard_id (duplicate function)
shard_id2 = sharded_tensor_shard_id(param_name, global_offset)
assert shard_id2 == expected
assert shard_id1 == shard_id2

# Test _reverse_sharded_tensor_shard_id function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reverse_sharded_tensor_shard_id(self):
"""test _reverse_sharded_tensor_shard_id function"""
# Test normal case
shard_id = "('model.layer.weight', (100, 200))"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert global_offset == (100, 200)

# Test with empty offset
shard_id = "('model.layer.weight', ())"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert not global_offset

# Test with single element offset
shard_id = "('model.layer.weight', (50,))"
param_name, global_offset = _reverse_sharded_tensor_shard_id(shard_id)
assert param_name == 'model.layer.weight'
assert global_offset == (50,)

# Test invalid shard id
with pytest.raises(ValueError):
_reverse_sharded_tensor_shard_id("invalid_shard_id")

# Test _get_shard_size function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_shard_size(self):
"""test _get_shard_size function"""
# Test with float32 type (32 bits per float32)
local_shape = (100, 200)
dtype = 'Float32'
expected = 100 * 200 * 32 # 100*200 elements * 32 bits each
assert _get_shard_size(local_shape, dtype) == expected

# Test with int8 type (8 bits per int8)
dtype = 'Int8'
expected = 100 * 200 * 8 # 100*200 elements * 8 bits each
assert _get_shard_size(local_shape, dtype) == expected

# Test with unknown dtype (should default to 16 bits)
dtype = 'UnknownType'
expected = 100 * 200 * 16 # 100*200 elements * 16 bits default
assert _get_shard_size(local_shape, dtype) == expected

# Test with empty shape
local_shape = ()
dtype = 'Float32'
expected = 1 * 32 # scalar, 32 bits
assert _get_shard_size(local_shape, dtype) == expected

# Test verify_ckpt_valid function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_verify_ckpt_valid(self, tmp_path):
"""test verify_ckpt_valid function"""
# Test with valid directory containing safetensors file
ckpt_dir = tmp_path / "valid_ckpt"
ckpt_dir.mkdir()
safetensor_file = ckpt_dir / "model.safetensors"
safetensor_file.touch()
assert verify_ckpt_valid(str(ckpt_dir)) is None

# Test with valid directory containing metadata and safetensors file
ckpt_dir = tmp_path / "valid_ckpt_with_metadata"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
metadata_content = {
"storage_data": {
"param1": [{
"file_name": "model.safetensors"
}]
}
}

with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata_content, f)
safetensor_file = ckpt_dir / "model.safetensors"
safetensor_file.touch()
assert verify_ckpt_valid(str(ckpt_dir)) is None

# Test with invalid directory (not exists)
with pytest.raises(NotADirectoryError):
verify_ckpt_valid("/non/existent/directory")

# Test with directory containing no files
ckpt_dir = tmp_path / "empty_ckpt"
ckpt_dir.mkdir()
with pytest.raises(FileNotFoundError):
verify_ckpt_valid(str(ckpt_dir))

# Test with metadata referencing missing safetensors file
ckpt_dir = tmp_path / "invalid_metadata_ckpt"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
metadata_content = {
"storage_data": {
"param1": [{
"file_name": "missing.safetensors"
}]
}
}
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata_content, f)
with pytest.raises(FileNotFoundError):
verify_ckpt_valid(str(ckpt_dir))

# Test with invalid metadata json
ckpt_dir = tmp_path / "invalid_json_ckpt"
ckpt_dir.mkdir()
metadata_path = ckpt_dir / "metadata.json"
with open(metadata_path, 'w', encoding='utf-8') as f:
f.write("invalid json content")
with pytest.raises(RuntimeError):
verify_ckpt_valid(str(ckpt_dir))
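
# Taken together, the cases above pin down the validation contract: the path
# must be an existing directory, it must hold at least one *.safetensors file,
# and if a metadata.json is present it must be valid JSON whose storage_data
# entries all reference existing files. A hedged sketch of that contract
# (error types mirror the pytest.raises checks; the real checks may be stricter):
import glob
import json
import os

def _verify_ckpt_valid_sketch(ckpt_dir):
    if not os.path.isdir(ckpt_dir):
        raise NotADirectoryError(ckpt_dir)
    metadata_path = os.path.join(ckpt_dir, 'metadata.json')
    if os.path.isfile(metadata_path):
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"broken metadata.json under {ckpt_dir}") from err
        for shards in metadata.get('storage_data', {}).values():
            for shard in shards:
                if not os.path.isfile(os.path.join(ckpt_dir, shard['file_name'])):
                    raise FileNotFoundError(shard['file_name'])
    elif not glob.glob(os.path.join(ckpt_dir, '*.safetensors')):
        raise FileNotFoundError(f"no safetensors file found in {ckpt_dir}")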

# Test check_checkpoints_dir_max_num function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoints_dir_max_num(self, tmp_path):
"""test check_checkpoints_dir_max_num function"""
# Create test directory structure
checkpoints_root_path = tmp_path / "checkpoints"
checkpoints_root_path.mkdir()

# Create more directories than max_keep_num
for i in range(5):
dir_path = checkpoints_root_path / f"iteration_{i:08d}"
dir_path.mkdir()

# Test with max_keep_num = 3, should keep newest 3 directories
check_checkpoints_dir_max_num(3, str(checkpoints_root_path))

# Verify only 3 directories remain
remaining_dirs = list(checkpoints_root_path.iterdir())
remaining_dirs.sort()
assert len(remaining_dirs) == 3
assert [d.name for d in remaining_dirs] == ["iteration_00000002", "iteration_00000003", "iteration_00000004"]

# Test with max_keep_num larger than existing directories
check_checkpoints_dir_max_num(10, str(checkpoints_root_path))
remaining_dirs = list(checkpoints_root_path.iterdir())
assert len(remaining_dirs) == 3
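
# The retention policy asserted above: keep the max_keep_num newest
# iteration_* directories (newest by zero-padded name) and remove the older
# ones; if fewer directories exist, nothing is touched. A minimal sketch of
# that policy (the real code may sort by iteration number or mtime instead):
import os
import shutil

def _prune_checkpoints_sketch(max_keep_num, checkpoints_root_path):
    dirs = sorted(
        d for d in os.listdir(checkpoints_root_path)
        if os.path.isdir(os.path.join(checkpoints_root_path, d))
    )
    for stale in dirs[:max(len(dirs) - max_keep_num, 0)]:
        shutil.rmtree(os.path.join(checkpoints_root_path, stale))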

# Test get_latest_iteration_from_tracker function with tmp_path
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_latest_iteration_from_tracker(self, tmp_path):
"""test get_latest_iteration_from_tracker function"""
checkpoints_path = tmp_path / "checkpoints"
checkpoints_path.mkdir()

# Create tracker file
tracker_path = checkpoints_path / "latest_checkpointed_iteration.txt"
tracker_path.write_text("123", encoding="utf-8")

# Create corresponding directory
iter_dir = checkpoints_path / "iteration_00000123"
iter_dir.mkdir()

# Test normal case
assert get_latest_iteration_from_tracker(str(checkpoints_path)) == 123

# Test with missing tracker file
tracker_path.unlink()
with pytest.raises(FileNotFoundError):
get_latest_iteration_from_tracker(str(checkpoints_path))

# Test with invalid iteration number in tracker file
tracker_path.write_text("invalid_iter", encoding="utf-8")
with pytest.raises(ValueError):
get_latest_iteration_from_tracker(str(checkpoints_path))

# Test with missing iteration directory
tracker_path.write_text("456", encoding="utf-8")
with pytest.raises(FileNotFoundError):
get_latest_iteration_from_tracker(str(checkpoints_path))
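
# The cases above fix the tracker contract: read
# latest_checkpointed_iteration.txt, parse it as an integer, and require the
# matching iteration_XXXXXXXX directory to exist. A hedged sketch of that flow
# (the helper name is hypothetical; the file name comes from the test itself):
import os

def _latest_iteration_sketch(checkpoints_path):
    tracker_path = os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
    if not os.path.isfile(tracker_path):
        raise FileNotFoundError(tracker_path)
    with open(tracker_path, 'r', encoding='utf-8') as f:
        iteration = int(f.read().strip())  # non-numeric content raises ValueError
    iter_dir = os.path.join(checkpoints_path, f"iteration_{iteration:08d}")
    if not os.path.isdir(iter_dir):
        raise FileNotFoundError(iter_dir)
    return iteration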
