14 Commits

Author SHA1 Message Date
  i-robot 62526e2d43 !7823 [master] Refactor the max_logits monitoring flow and decouple it from the parameter-grouping logic 1 week ago
  i-robot 739411b385 !7822 [master][bugfix][logging] For weight-related logs, call logger.error before raising the Error so a matching entry appears in error.log 1 week ago
  i-robot 211024cc01 !7802 [master][bugfix][muon] add sharded_state_dict for muon op group 1 week ago
  SaiYao 802f3516c5 [master][bugfix][logging] For weight-related logs, call logger.error before raising the Error so a matching entry appears in error.log 1 week ago
  husichao 43fd826ae2 add sharded_state_dict for muon op group 1 week ago
  JavaZero 654512dc5e Refactor max attention logit handling in GPT model and related components 1 week ago
  i-robot d40514ad01 !7820 [master][bugfix][logging] For weight-related logs, call logger.error before raising the Error so a matching entry appears in error.log 1 week ago
  i-robot d70bab2d21 !7811 [master][bugfix][coverage] mindformers test-case coverage is low; add more test cases 1 week ago
  i-robot 325b311a98 !7796 [master][coverage] add callback testcase 1 week ago
  i-robot 84de277056 !7782 [master][cleancode] Fix excessive code duplication 1 week ago
  SaiYao 33a9a49d40 [master][bugfix][logging] For weight-related logs, call logger.error before raising the Error so a matching entry appears in error.log 1 week ago
  zyw_hw 05f92de6f9 add callback testcase 2 weeks ago
  Yule100 538d4d2793 bugfix: add unit tests 1 week ago
  zyw_hw 0b3a3b5576 fix huge cc 2 weeks ago
52 changed files with 6955 additions and 663 deletions
  1. +0 -1 mindformers/__init__.py
  2. +8 -13 mindformers/core/callback/callback.py
  3. +40 -37 mindformers/core/optim/__init__.py
  4. +2 -2 mindformers/core/optim/fused_pma_adamw.py
  5. +1 -3 mindformers/core/optim/muon.py
  6. +0 -1 mindformers/modules/__init__.py
  7. +21 -37 mindformers/modules/layers.py
  8. +3 -3 mindformers/modules/quantizers/base.py
  9. +1 -11 mindformers/modules/quantizers/ptq_quantizer.py
  10. +0 -11 mindformers/modules/quantizers/rtn_quantizer.py
  11. +0 -1 mindformers/modules/transformer/__init__.py
  12. +20 -279 mindformers/modules/transformer/transformer.py
  13. +28 -25 mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py
  14. +10 -10 mindformers/parallel_core/inference/utils.py
  15. +116 -37 mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
  16. +23 -26 mindformers/parallel_core/training_graph/loss_func.py
  17. +129 -0 mindformers/parallel_core/training_graph/tensor_parallel/layers.py
  18. +25 -0 mindformers/parallel_core/training_graph/transformer/moe/ffn.py
  19. +15 -0 mindformers/parallel_core/training_graph/transformer/moe/router.py
  20. +15 -0 mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py
  21. +3 -0 mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py
  22. +89 -14 mindformers/parallel_core/training_graph/transformer/norm.py
  23. +1 -1 mindformers/parallel_core/training_graph/transformer/utils.py
  24. +7 -2 mindformers/parallel_core/utils/model_mixin.py
  25. +59 -37 mindformers/tools/ckpt_transform/transform_checkpoint.py
  26. +39 -18 mindformers/tools/resume_ckpt.py
  27. +11 -6 mindformers/trainer/base_trainer.py
  28. +0 -12 mindformers/trainer/optimizer_grouped_parameters.py
  29. +37 -15 mindformers/trainer/utils.py
  30. +56 -22 mindformers/utils/load_checkpoint_utils.py
  31. +3 -1 mindformers/utils/resume_ckpt_utils.py
  32. +1 -1 mindformers/wrapper/wrapper.py
  33. +6 -3 tests/st/test_safetensors/test_checkpoint_utils.py
  34. +9 -8 tests/st/test_ut/test_api_compatibility.py
  35. +1279 -1 tests/st/test_ut/test_core/test_callback/test_checkpoint_monitor.py
  36. +303 -0 tests/st/test_ut/test_core/test_callback/test_cold_hot_expert_monitor.py
  37. +96 -0 tests/st/test_ut/test_core/test_callback/test_expert_migrate_callback.py
  38. +422 -0 tests/st/test_ut/test_core/test_callback/test_helper_functions.py
  39. +829 -5 tests/st/test_ut/test_core/test_callback/test_mfloss_monitor.py
  40. +302 -0 tests/st/test_ut/test_core/test_callback/test_other_callbacks.py
  41. +137 -1 tests/st/test_ut/test_core/test_callback/test_profile_monitor.py
  42. +653 -0 tests/st/test_ut/test_core/test_callback/test_stress_test_monitor.py
  43. +1050 -0 tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py
  44. +178 -0 tests/st/test_ut/test_core/test_optim/test_get_op_group.py
  45. +2 -2 tests/st/test_ut/test_model_mixin.py
  46. +57 -0 tests/st/test_ut/test_models/test_base_model/test_base_model.py
  47. +79 -0 tests/st/test_ut/test_models/test_build_models/test_build_model.py
  48. +147 -0 tests/st/test_ut/test_models/test_build_models/test_utils.py
  49. +116 -0 tests/st/test_ut/test_tools/test_check_rules.py
  50. +482 -0 tests/st/test_ut/test_tools/test_utils/test_utils.py
  51. +3 -17 tests/st/test_ut/test_transformer_apis.py
  52. +42 -0 tests/st/test_ut/test_utils/test_import_utils.py

+ 0
- 1
mindformers/__init__.py View File

@@ -157,7 +157,6 @@ from mindformers.modules import (
AlibiTensorV2,
Dropout,
EmbeddingOpParallelConfig,
FeedForward,
FixedSparseAttention,
LayerNorm,
Linear,


+ 8
- 13
mindformers/core/callback/callback.py View File

@@ -58,6 +58,7 @@ from mindspore.communication.comm_func import all_gather_into_tensor, barrier
from mindspore.profiler import ProfilerLevel, schedule
from mindspore.utils import stress_detect

from mindformers.wrapper.wrapper import get_real_models
from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor
from mindformers.core.context.build_context import is_legacy_model
from mindformers.tools import get_output_root_path
@@ -1291,25 +1292,19 @@ class TrainingStateMonitor(Callback):

def _dump_max_attention_logit(self, cb_params):
"""write the max attention logit to log/tensorboard"""
if cb_params.optimizer is not None:
cb_optimizer = cb_params.optimizer
else:
cb_optimizer = cb_params.network.optimizer
params = cb_optimizer._parameters # pylint: disable=W0212
network = cb_params.train_network
network = get_real_models(network)
params = network.get_max_attention_logit()

if not params:
return
step = cb_params.cur_step_num
vals = []
for param in params:
name = getattr(param, "name", "")
if "max_logits_val" not in name:
continue

t = param.value()
v = t.asnumpy().squeeze()
for param_name, param in params.items():
v = param.asnumpy().squeeze()
v = v / max(1, self.micro_batch_num)

tag = f"max_attention_logit/{name}"
tag = f"max_attention_logit/{param_name}"
if 'log' in self.max_attention_logit_format:
self._output(tag, v.tolist(), step, ['log'])
if 'tensorboard' in self.max_attention_logit_format:

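For context, a minimal sketch of how the reworked monitor consumes the new dict-based API, assuming plain NumPy arrays in place of MindSpore tensors and an illustrative micro_batch_num of 2; the dict keys mirror the parameter names produced by GPTModel.get_max_attention_logit, and print() stands in for self._output:

import numpy as np

# Hypothetical stand-in for network.get_max_attention_logit(): a dict mapping
# parameter names to per-layer max-logit values accumulated over micro batches.
params = {
    "decoder.layers.0.self_attention.core_attention.max_logits_val": np.array([[6.4]]),
    "decoder.layers.1.self_attention.core_attention.max_logits_val": np.array([[5.2]]),
}
micro_batch_num = 2  # assumed value, normally taken from the monitor

for param_name, param in params.items():
    v = param.squeeze() / max(1, micro_batch_num)   # average over micro batches
    tag = f"max_attention_logit/{param_name}"
    print(tag, float(v))                            # stand-in for self._output(tag, ...)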

+ 40
- 37
mindformers/core/optim/__init__.py View File

@@ -28,47 +28,50 @@ __all__ = ['AdamW', 'PmaAdamW', 'Muon']

@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER)
class AdamW:
r"""
"""
This is the implementation of AdamW.

.. math::
\begin{array}{l}
&\newline
&\hline \\
&\textbf{Parameters}: \: 1^{\text {st }}\text {moment vector} \: m , \: 2^{\text {nd}} \:
\text{moment vector} \: v , \\
&\: gradients \: g, \: \text{learning rate} \: \gamma,
\text {exponential decay rates for the moment estimates} \: \beta_{1} \: \beta_{2} , \\
&\:\text {parameter vector} \: w_{0}, \:\text{timestep} \: t, \: \text{weight decay} \: \lambda \\
&\textbf{Init}: m_{0} \leftarrow 0, \: v_{0} \leftarrow 0, \: t \leftarrow 0, \:
\text{init parameter vector} \: w_{0} \\[-1.ex]
&\newline
&\hline \\
&\textbf{repeat} \\
&\hspace{5mm} t \leftarrow t+1 \\
&\hspace{5mm}\boldsymbol{g}_{t} \leftarrow \nabla f_{t}\left(\boldsymbol{w}_{t-1}\right) \\
&\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\lambda\boldsymbol{w}_{t-1} \\
&\hspace{5mm}\boldsymbol{m}_{t} \leftarrow \beta_{1} \boldsymbol{m}_{t-1}+\left(1-\beta_{1}\right)
\boldsymbol{g}_{t} \\
&\hspace{5mm}\boldsymbol{v}_{t} \leftarrow \beta_{2} \boldsymbol{v}_{t-1}+\left(1-\beta_{2}\right)
\boldsymbol{g}_{t}^{2} \\
&\hspace{5mm}\widehat{\boldsymbol{m}_{t}} \leftarrow \boldsymbol{m}_{t}/\big(1-\beta_{1}^{t} \big) \\
&\hspace{5mm}\widehat{\boldsymbol{v}_{t}} \leftarrow \boldsymbol{v}_{t}/\big(1-\beta_{2}^{t} \big) \\
&\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\widehat{\boldsymbol{m}_{t}}
/\left(\sqrt{\widehat{\boldsymbol{v}_{t}}}+\epsilon\right) \\
&\textbf{until}\text { stopping criterion is met } \\[-1.ex]
&\newline
&\hline \\[-1.ex]
&\textbf{return} \: \boldsymbol{w}_{t} \\[-1.ex]
&\newline
&\hline \\[-1.ex]
\end{array}
\\begin{array}{l}
&\\newline
&\\hline \\\\
&\\textbf{Parameters}: \\: 1^{\\text {st }}\\text {moment vector} \\: m , \\: 2^{\\text {nd}} \\:
\\text{moment vector} \\: v , \\\\
&\\: gradients \\: g, \\: \\text{learning rate} \\: \\gamma,
\\text {exponential decay rates for the moment estimates} \\: \\beta_{1} \\: \\beta_{2} , \\\\
&\\:\\text {parameter vector} \\: w_{0}, \\:\\text{timestep} \\: t, \\: \\text{weight decay} \\: \\lambda \\\\
&\\textbf{Init}: m_{0} \\leftarrow 0, \\: v_{0} \\leftarrow 0, \\: t \\leftarrow 0, \\:
\\text{init parameter vector} \\: w_{0} \\\\[-1.ex]
&\\newline
&\\hline \\\\
&\\textbf{repeat} \\\\
&\\hspace{5mm} t \\leftarrow t+1 \\\\
&\\hspace{5mm}\\boldsymbol{g}_{t} \\leftarrow \\nabla f_{t}\\left(\\boldsymbol{w}_{t-1}\\right) \\\\
&\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\lambda
\\boldsymbol{w}_{t-1} \\\\
&\\hspace{5mm}\\boldsymbol{m}_{t} \\leftarrow \\beta_{1} \\boldsymbol{m}_{t-1}+\\left(1-\\beta_{1}\\right)
\\boldsymbol{g}_{t} \\\\
&\\hspace{5mm}\\boldsymbol{v}_{t} \\leftarrow \\beta_{2} \\boldsymbol{v}_{t-1}+\\left(1-\\beta_{2}\\right)
\\boldsymbol{g}_{t}^{2} \\\\
&\\hspace{5mm}\\widehat{\\boldsymbol{m}_{t}} \\leftarrow \\boldsymbol{m}_{t}/
\\big(1-\\beta_{1}^{t} \\big) \\\\
&\\hspace{5mm}\\widehat{\\boldsymbol{v}_{t}} \\leftarrow \\boldsymbol{v}_{t}/
\\big(1-\\beta_{2}^{t} \\big) \\\\
&\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\widehat{\\boldsymbol{m}_{t}}
/\\left(\\sqrt{\\widehat{\\boldsymbol{v}_{t}}}+\\epsilon\\right) \\\\
&\\textbf{until}\\text { stopping criterion is met } \\\\[-1.ex]
&\\newline
&\\hline \\\\[-1.ex]
&\\textbf{return} \\: \\boldsymbol{w}_{t} \\\\[-1.ex]
&\\newline
&\\hline \\\\[-1.ex]
\\end{array}

:math:`m` represents the first moment vector moment1, :math:`v` represents the second moment vector moment2,
:math:`\widehat{m}` represents the bias-corrected first moment vector, :math:`\widehat{v}` represents
the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\gamma` represents
learning_rate, :math:`\beta_1`, `\beta_2` represent beta1 and beta2, :math:`t` represents the current step,
:math:`w` represents params, and :math:`\lambda` represents weight_decay.
:math:`\\widehat{m}` represents the bias-corrected first moment vector, :math:`\\widehat{v}` represents
the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\\gamma` represents
learning_rate, :math:`\\beta_1`, `\\beta_2` represent beta1 and beta2, :math:`t` represents the current step,
:math:`w` represents params, and :math:`\\lambda` represents weight_decay.

Args:
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
@@ -218,7 +221,7 @@ class AdamW:

@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER)
class PmaAdamW:
r"""
"""
This is the implementation of PmAdamW.

Args:

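As a side note on the AdamW formula being re-escaped above, here is a small NumPy sketch of one update step following that rule (decoupled weight decay applied first, then bias-corrected moments); the hyperparameter values are illustrative only, not defaults taken from this class:

import numpy as np

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update following the formula in the AdamW docstring."""
    w = w - lr * weight_decay * w                  # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g                # first moment
    v = beta2 * v + (1 - beta2) * g * g            # second moment
    m_hat = m / (1 - beta1 ** t)                   # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)    # parameter update
    return w, m, v

w = np.array([0.5, -0.3])
m = np.zeros_like(w)
v = np.zeros_like(w)
g = np.array([0.1, -0.2])                          # illustrative gradient
w, m, v = adamw_step(w, g, m, v, t=1)
print(w)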

+ 2
- 2
mindformers/core/optim/fused_pma_adamw.py View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FusedPmaAdamW implementation"""
import mindspore.ops as ops
from mindspore import ops

from mindspore._checkparam import GT, INC_NEITHER
from mindspore import _checkparam as validator
@@ -76,7 +76,7 @@ def _check_param_value(fused_num, interleave_step, fused_algo, ema_alpha, prim_n


class FusedPmaAdamW(FusedAdamW):
r"""
"""
This is the implementation of PmaAdamW that uses fused operators.

Args:


+ 1
- 3
mindformers/core/optim/muon.py View File

@@ -442,9 +442,7 @@ class Muon(Optimizer):

def _initialize_op_groups(self, model):
"""Initialize optimizer parallel groups for parameters."""
self.ops, self.op_groups = model.get_op_groups_info(
self._parameters, self.op, self.op_group, self.op_in_tp_group
)
self.ops, self.op_groups = model.get_op_groups_info(self._parameters, self.op)

def _create_communication_group(self, rank_list):
"""


+ 0
- 1
mindformers/modules/__init__.py View File

@@ -15,7 +15,6 @@
"""MindFormers Transformers API."""
from .transformer import (
EmbeddingOpParallelConfig,
FeedForward,
LowerTriangularMaskWithDynamic,
MoEConfig,
OpParallelConfig,


+ 21
- 37
mindformers/modules/layers.py View File

@@ -49,6 +49,8 @@ from mindformers.tools.logger import logger
from mindformers.tools.utils import is_pynative
from mindformers.modules.activation import get_activation
from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
from mindformers.parallel_core.training_graph.base_models.common.embeddings.yarn_rotary_pos_embedding import \
_yarn_find_correction_range

__all__ = [
"FixedSparseAttention",
@@ -177,7 +179,6 @@ class _LayerInputCheck:
Check the input shape's is equal to the expected shape, the value on 0-th is viewed as batch, and the
batch size will not be checked.
"""
target_shape = target_shape
length, hidden = target_shape
if isinstance(input_shape, tuple):
input_shape = list(input_shape)
@@ -244,11 +245,9 @@ class Dropout(nn.Cell):
"""

def __init__(self, keep_prob=0.5, dtype=mstype.float32):
super(Dropout, self).__init__()
super().__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError(
"dropout probability should be a number in range (0, 1], but got {}".format(
keep_prob))
raise ValueError(f"dropout probability should be a number in range (0, 1], but got {keep_prob}")
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
@@ -269,7 +268,7 @@ class Dropout(nn.Cell):
return out

def extend_repr(self):
return 'keep_prob={}'.format(self.keep_prob)
return f'keep_prob={self.keep_prob}'

def shard(self, strategy):
self.dropout.shard(strategy)
@@ -291,10 +290,10 @@ class LayerNorm(Cell):
"""

def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32, is_self_defined=False):
super(LayerNorm, self).__init__()
super().__init__()
if param_init_type not in [mstype.float32, mstype.float16, mstype.bfloat16]:
raise TypeError("The type of parameter 'param_init_type' should in [float32, float16], "
"but got the type : {}.".format(type(param_init_type)))
raise TypeError(f"The type of parameter 'param_init_type' should in [float32, float16], "
f"but got the type : {type(param_init_type)}.")
# Since the mindspore 1.10 version, the layernorm has been changed to P.LayerNorm
self.is_self_defined = is_self_defined
if not self.is_self_defined:
@@ -441,7 +440,7 @@ class Linear(Cell):
use_gmm=False,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(Linear, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
@@ -465,6 +464,7 @@ class Linear(Cell):
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
name="weight")
if self.use_gmm:
# pylint: disable=import-outside-toplevel
from mindspore.ops.auto_generate import GroupedMatmul
# split_item only supports 0 and 3 now, 0 means the size of tensorlist not equal to 1,
# 3 means the size of tensorlist is 1.
@@ -676,7 +676,7 @@ class FixedSparseAttention(nn.Cell):
seq_length=1024,
num_different_global_patterns=4,
parallel_config=default_dpmp_config):
super(FixedSparseAttention, self).__init__()
super().__init__()
dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
if num_heads % mp != 0:
raise ValueError(f"The number of heads {num_heads} must be a "
@@ -700,17 +700,17 @@ class FixedSparseAttention(nn.Cell):
self.parallel_config = parallel_config
size_per_head_list = [64, 128]
if self.seq_length != 1024:
raise ValueError("For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, "
"but got the value : {}.".format(seq_length))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, "
f"but got the value : {seq_length}.")
if self.block_size != 64:
raise ValueError("For 'FixedSparseAttention', the class variable 'block_size' must be 64, "
"but got the value : {}.".format(block_size))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'block_size' must be 64, "
f"but got the value : {block_size}.")
if num_different_global_patterns != 4:
raise ValueError("For 'FixedSparseAttention', the class variable 'num_different_global_patterns' "
"must be 4, but got the value : {}".format(num_different_global_patterns))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'num_different_global_patterns' "
f"must be 4, but got the value : {num_different_global_patterns}")
if self.size_per_head not in size_per_head_list:
raise ValueError("For 'FixedSparseAttention', the class variable 'size_per_head' only supports {}, "
"but got the value : {}.".format(size_per_head_list, self.size_per_head))
raise ValueError(f"For 'FixedSparseAttention', the class variable 'size_per_head' "
f"only supports {size_per_head_list}, but got the value : {self.size_per_head}.")
local_ones = np.ones((self.block_size, self.block_size),
dtype=np.float16)
global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16)
@@ -851,7 +851,7 @@ class AlibiTensor(nn.Cell):
"""

def __init__(self, seq_length, num_heads, parallel_config=default_dpmp_config):
super(AlibiTensor, self).__init__()
super().__init__()
dp = parallel_config.data_parallel

self.seq_length = seq_length
@@ -915,7 +915,7 @@ class AlibiTensorV2(nn.Cell):
"""

def __init__(self, num_heads):
super(AlibiTensorV2, self).__init__()
super().__init__()
self.num_heads = num_heads

self.expand_2d = P.ExpandDims()
@@ -1124,22 +1124,6 @@ def _check_linear_scaling_factor(scaling_factor):
raise ValueError(f"`scaling_factor`'s factor field must be a float >= 1, got {factor}")


def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
"""Inverse dim formula to find dim based on number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
"""Find dim range bounds based on rotations"""
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case


def _yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0


+ 3
- 3
mindformers/modules/quantizers/base.py View File

@@ -221,14 +221,14 @@ class Quantizer(ABC):

@abstractmethod
def _process_model_after_weight_loading(self, model, **kwargs):
pass
return model

@property
@abstractmethod
def is_serializable(self):
pass
return False

@property
@abstractmethod
def is_trainable(self):
pass
return False

+ 1
- 11
mindformers/modules/quantizers/ptq_quantizer.py View File

@@ -52,19 +52,9 @@ class PtqQuantizer(Quantizer):
def _process_model_before_weight_loading(
self, model: "PreTrainedModel", **kwargs
):
# pylint: disable=import-outside-toplevel
from mindspore_gs.ptq import PTQ
ptq = PTQ(config=self.quant_config, layer_policies=self.layer_policies)
model = ptq.apply(model)
model = ptq.convert(model)
return model

def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
return False

+ 0
- 11
mindformers/modules/quantizers/rtn_quantizer.py View File

@@ -65,14 +65,3 @@ class RtnQuantizer(Quantizer):
model = ptq.apply(model)
model = ptq.convert(model)
return model

def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
return False

+ 0
- 1
mindformers/modules/transformer/__init__.py View File

@@ -21,7 +21,6 @@ This is an experimental interface that is subject to change or deletion.

from .transformer import (
EmbeddingOpParallelConfig,
FeedForward,
LowerTriangularMaskWithDynamic,
TransformerOpParallelConfig,
TransformerRecomputeConfig,


+ 20
- 279
mindformers/modules/transformer/transformer.py View File

@@ -26,7 +26,6 @@ import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import Zero
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -40,18 +39,14 @@ except ImportError:
from mindspore import log as logger
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindformers.modules.layers import Linear, _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_input_dtype
from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
_Config, _check_config, MoEParallelConfig
from mindformers.version_control import get_dropout
from mindformers.modules.transformer.op_parallel_config import _PipeLineConfig, OpParallelConfig, \
_Config, MoEParallelConfig

from mindformers.tools.logger import _LogActionOnce
from mindformers.tools.utils import is_pynative

__all__ = [
"LowerTriangularMaskWithDynamic",
"FeedForward",
"TransformerOpParallelConfig",
"EmbeddingOpParallelConfig",
"TransformerRecomputeConfig",
@@ -211,7 +206,7 @@ class TransformerSwapConfig(_Config):
if isinstance(layer_swap, dict):
layer_swap = [layer_swap]
if self._validate_layers_consistency(layer_swap):
return [dict(backward_prefetch=layer_swap[0][self.backward_prefetch], layers=True)]
return [{"backward_prefetch": layer_swap[0][self.backward_prefetch], "layers": True}]
return layer_swap

def _initialize_op_swap(self, op_swap):
@@ -225,7 +220,7 @@ class TransformerSwapConfig(_Config):
op_swap_dict = self.op_swap_to_dict(op_swap)
for k, v in op_swap_dict.items():
if self._validate_layers_consistency(v, mode=f'op_swap: {k}'):
op_swap_dict[k] = [dict(backward_prefetch=v[0][self.backward_prefetch], layers=True)]
op_swap_dict[k] = [{"backward_prefetch": v[0][self.backward_prefetch], "layers": True}]
return op_swap_dict

def _validate_layers_consistency(self, layer_swap, mode='layer_swap'):
@@ -283,17 +278,17 @@ class TransformerSwapConfig(_Config):
"""Adds an operation swap configuration to the dictionary."""
if key in dic:
dic[key].append(
dict(
layers=item.get(self.layers),
backward_prefetch=item.get(self.backward_prefetch)
)
{
'layers': item.get(self.layers),
'backward_prefetch': item.get(self.backward_prefetch)
}
)
else:
dic[key] = [
dict(
layers=item.get(self.layers),
backward_prefetch=item.get(self.backward_prefetch)
)
{
'layers': item.get(self.layers),
'backward_prefetch': item.get(self.backward_prefetch)
}
]
return dic

@@ -507,9 +502,9 @@ class ContextParallelAlgo(Enum):
Args:
Enum (str): chosses context parallel type
"""
colossalai_cp = "colossalai_cp"
ulysses_cp = "ulysses_cp"
hybrid_cp = "hybrid_cp"
COLOSSALAI_CP = "colossalai_cp"
ULYSSES_CP = "ulysses_cp"
HYBRID_CP = "hybrid_cp"


default_transformer_swap_config = TransformerSwapConfig()
@@ -601,7 +596,7 @@ class TransformerOpParallelConfig(_Config):
ValueError: in hybrid_cp algorithm, context_parallel should be divisible by ulysses_degree_in_cp
"""
if self.context_parallel == 1:
if self.context_parallel_algo != ContextParallelAlgo.colossalai_cp:
if self.context_parallel_algo != ContextParallelAlgo.COLOSSALAI_CP:
logger.warning(f"context_parallel_algo {self.context_parallel_algo.value} will not take effect "
"when context_parallel == 1.")
if self.ulysses_degree_in_cp > 1:
@@ -610,10 +605,10 @@ class TransformerOpParallelConfig(_Config):
return

# here context parallel > 1
if self.context_parallel_algo != ContextParallelAlgo.hybrid_cp and self.ulysses_degree_in_cp > 1:
if self.context_parallel_algo != ContextParallelAlgo.HYBRID_CP and self.ulysses_degree_in_cp > 1:
logger.warning(f"ulysses_degree_in_cp {self.ulysses_degree_in_cp} will not take effect when "
f"context_parallel_algo {self.context_parallel_algo.value} is not `hybrid_cp`.")
if (self.context_parallel_algo == ContextParallelAlgo.hybrid_cp and
if (self.context_parallel_algo == ContextParallelAlgo.HYBRID_CP and
self.context_parallel % self.ulysses_degree_in_cp != 0):
raise ValueError(f"When using hybrid_cp algorithm, context_parallel {self.context_parallel} "
f"should be divisible by ulysses_degree_in_cp {self.ulysses_degree_in_cp}. "
@@ -627,9 +622,9 @@ class TransformerOpParallelConfig(_Config):
"""
if self.context_parallel == 1:
return 1
if self.context_parallel_algo == ContextParallelAlgo.colossalai_cp:
if self.context_parallel_algo == ContextParallelAlgo.COLOSSALAI_CP:
return 1
if self.context_parallel_algo == ContextParallelAlgo.ulysses_cp:
if self.context_parallel_algo == ContextParallelAlgo.ULYSSES_CP:
return self.context_parallel
# hybird
return self.ulysses_degree_in_cp
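
A short, self-contained illustration of the ulysses-degree selection shown above; the function name and the minimal enum here are local stand-ins rather than the repository API:

from enum import Enum

class ContextParallelAlgo(Enum):       # mirrors the upper-cased members above
    COLOSSALAI_CP = "colossalai_cp"
    ULYSSES_CP = "ulysses_cp"
    HYBRID_CP = "hybrid_cp"

def ulysses_degree(context_parallel, algo, ulysses_degree_in_cp):
    """Stand-in for the selection logic shown above."""
    if context_parallel == 1:
        return 1
    if algo == ContextParallelAlgo.COLOSSALAI_CP:
        return 1
    if algo == ContextParallelAlgo.ULYSSES_CP:
        return context_parallel
    return ulysses_degree_in_cp        # hybrid case

print(ulysses_degree(8, ContextParallelAlgo.HYBRID_CP, 2))   # -> 2
print(ulysses_degree(8, ContextParallelAlgo.ULYSSES_CP, 2))  # -> 8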
@@ -786,260 +781,6 @@ class TransformerOpParallelConfig(_Config):
default_transformer_config = TransformerOpParallelConfig()


class FeedForward(Cell):
r"""
The multilayer perceptron with two linear layers with dropout applied at final output. The first linear
will project the input dimension from hidden_size to ffn_hidden_size. The second linear will project the
dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
and the second linear is sharded on the output dimension. The overview process can be:

.. math::
Dropout((xW_1+b_1)W_2 + b_2)

where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.

Args:
hidden_size (int): The dimension of the inputs.
ffn_hidden_size (int): The intermediate hidden size.
dropout_rate (float): The dropout rate for the second linear's output.
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide custom activition to the argument.
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see examples. Default: gelu.
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
and the first dimension in BatchMatMul indicate expert_num. Default: 1.
expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
param_init_type (dtype.Number): The parameter initialization type. Should be mstype.float32 or
mstype.float16. Default: mstype.float32.
parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see
`OpParallelConfig` or `MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective,
otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.

Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
Float tensor.

Outputs:
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
[batch * seq_length, hidden_size]`.

Raises:
TypeError: `hidden_act` is not a string or nn.Cell.
TypeError: `parallel_config` is not a subclass of OpParallelConfig.
ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
ValueError: `hidden_size` is not a multiple of the model parallel way.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> import numpy as np
>>> from mindformers.modules.transformer import FeedForward
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor, nn
>>> import mindspore.ops as ops
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
>>> # Example 2 using custom hidden activation
>>> class MyActivationNoShard(nn.Cell):
... def __init__(self):
... super(MyActivationNoShard, self).__init__()
... self.add = ops.Add()
... def construct(self, x):
... return self.add(x, 0.1)
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
... hidden_act=MyActivationNoShard)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
>>> # Example 3 using custom hidden activation with activation_shard
>>> # If user wantss to run on the SEMI/AUTO parallel mode, the custom activation must provide
>>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
>>> # MoEParallelConfig) and set the shard for the primitives used in the construct.
>>> class MyActivationWithShard(nn.Cell):
... def __init__(self):
... super(MyActivationWithShard, self).__init__()
... self.add = ops.Add()
... def construct(self, x):
... return self.add(x, 0.1)
... def activation_shard(self, parallel_config):
... self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
>>>
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
... hidden_act=MyActivationWithShard)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
"""

@_LogActionOnce(m_logger=logger, key='FeedForward',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
ffn_hidden_size=Validator.check_positive_int,
dropout_rate=Validator.check_non_negative_float,
param_init_type=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16],
"FeedForward"),
compute_dtype=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"FeedForward"))
def __init__(self, hidden_size,
ffn_hidden_size,
dropout_rate,
hidden_act='gelu',
expert_num=1,
expert_group_size=None,
param_init_type=mstype.float32,
parallel_config=default_dpmp_config,
compute_dtype=mstype.float16):
super(FeedForward, self).__init__()
self.dtype = compute_dtype
if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, "
f"but got {hidden_act}.")
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
_check_config(parallel_config)
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
# ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
dp = parallel_config.data_parallel // ep
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
"num of model parallel, but got the ffn_hidden_size is {} and the num of model "
"parallel is {}.".format(ffn_hidden_size, mp))
if hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
"model parallel, but got the hidden_size is {} and the num of model parallel is {}."
.format(hidden_size, mp))
if dropout_rate < 0 or dropout_rate >= 1:
raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size

# Project to ffn_hidden_size
self.mapping = Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)

# Project back to hidden_size
self.projection = Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)
if expert_num > 1:
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
self.projection.bias.parallel_optimizer = False
self.dropout = get_dropout(dropout_rate)
self.dropout_3d = get_dropout(dropout_rate)
self.dropout_4d = get_dropout(dropout_rate)
self.cast = P.Cast()
else:
_check_config(parallel_config)
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
# ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
dp = parallel_config.data_parallel // ep
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
"num of model parallel, but got the ffn_hidden_size is {} and the num of model "
"parallel is {}.".format(ffn_hidden_size, mp))
if hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
"model parallel, but got the hidden_size is {} and the num of model parallel is {}."
.format(hidden_size, mp))
if dropout_rate < 0 or dropout_rate >= 1:
raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size

# Project to ffn_hidden_size
self.mapping = Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)

if expert_num > 1:
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
strategy_bias=((dp, ep, 1, mp), (1, ep, 1, mp)),
strategy_activation=((dp, ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
strategy_activation=((dp, mp),))
# Project back to hidden_size
self.projection = Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
expert_group_size=expert_group_size,
outer_batch=dp,
param_init_type=param_init_type,
compute_dtype=compute_dtype)
if expert_num > 1:
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
strategy_bias=((dp, ep, 1, 1), (1, ep, 1, 1)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
self.projection.bias.parallel_optimizer = False
self.dropout = get_dropout(dropout_rate)
self.dropout_3d = get_dropout(dropout_rate)
self.dropout_4d = get_dropout(dropout_rate)
self.dropout.dropout.shard(((dp, 1),))
self.dropout_3d.dropout.shard(((dp, 1, 1),))
self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
self.cast = P.Cast()

def construct(self, x):
"""Forward process of the FeedForward"""
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
x = self.cast(x, self.dtype)
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
hidden = self.mapping(x)
output = self.projection(hidden)
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
elif len(F.shape(output)) == 2:
output = self.dropout(output)
else:
output = self.dropout_4d(output)
return output


class LowerTriangularMaskWithDynamic(Cell):
r"""
Get the Strictly Lower triangular matrix from the input_ids.


+ 28
- 25
mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py View File

@@ -84,6 +84,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase):
self.cast = P.Cast()
self.matmul = ops.auto_generate.GroupedMatmulV4()

# pylint: disable=W0237
def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
@@ -216,17 +217,17 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now."
@@ -275,6 +276,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

# pylint: disable=W0237
def construct(self, input_parallel, weight=None, group_list=None):
"""Forward of ColumnParallelGroupedLinear."""
if weight is None:
@@ -386,15 +388,15 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):


class RowParallelGroupedLinear(GroupedLinearBase):
r"""
"""
The group linear layer with weight sliced on first dimension by tensor parallel size.
This layer implements the operation as:

.. math::
\text{outputs} = \text{inputs} * \text{weight} + \text{bias},
\\text{outputs} = \\text{inputs} * \\text{weight} + \\text{bias},

where :math:`inputs` is the input tensors, :math:`\text{weight}` is a weight matrix created by the layer,
and :math:`\text{bias}` is a bias vector created by the layer (only if has_bias is True).
where :math:`inputs` is the input tensors, :math:`\\text{weight}` is a weight matrix created by the layer,
and :math:`\\text{bias}` is a bias vector created by the layer (only if has_bias is True).

Args:
num_local_experts (int): The number of local expert.
@@ -416,11 +418,11 @@ class RowParallelGroupedLinear(GroupedLinearBase):
prefix (str): The prefix string for this linear layer. Default: empty string("").

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `input_size` in `Args` should be equal
to :math:`in\_channels` in `Inputs`.
- **x** (Tensor) - Tensor of shape :math:`(*, in\\_channels)`. The `input_size` in `Args` should be equal
to :math:`in\\_channels` in `Inputs`.

Outputs:
Tensor of shape :math:`(*, out\_channels)`.
Tensor of shape :math:`(*, out\\_channels)`.

Supported Platforms:
``Ascend``
@@ -445,17 +447,17 @@ class RowParallelGroupedLinear(GroupedLinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if not is_expert:
raise NotImplementedError(
"For RowParallelGroupedLinear, `is_expert=False` is not supported for now.")
@@ -502,6 +504,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

# pylint: disable=W0237
def construct(self, input_, weight=None, group_list=None):
"""Forward of RowParallelGroupedLinear."""
if weight is None:


+ 10
- 10
mindformers/parallel_core/inference/utils.py View File

@@ -20,12 +20,15 @@ __all__ = [
"update_comm_config",
]

import os
import stat
from contextlib import contextmanager
import numpy as np

import mindspore as ms
from mindspore import Tensor, ops, Parameter, mint
from mindspore.communication import get_group_size
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

from mindformers.version_control import is_310p
from mindformers.parallel_core.transformer_config import TransformerConfig
@@ -65,7 +68,7 @@ ATTNMASK_FUNC_MAP = {


def get_attn_mask_func(mask_func_type):
r"""
"""
Get attention mask function.

Args:
@@ -75,9 +78,9 @@ def get_attn_mask_func(mask_func_type):
Function, the attention mask function.
"""
if mask_func_type not in ATTNMASK_FUNC_MAP:
raise KeyError("Invalid attention mask function. Supported attention "
"mask function are ['attn_mask_fill', 'attn_mask_add'] "
", but got {}.".format(mask_func_type))
raise KeyError(f"Invalid attention mask function. Supported attention "
f"mask function are ['attn_mask_fill', 'attn_mask_add'] "
f", but got {mask_func_type}.")
return ATTNMASK_FUNC_MAP[mask_func_type]


@@ -158,7 +161,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs):
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
if numerator % denominator != 0:
raise ValueError("{} is not divisible by {}".format(numerator, denominator))
raise ValueError(f"{numerator} is not divisible by {denominator}")


def divide(numerator, denominator):
@@ -178,10 +181,6 @@ def save_strategy_file(state_dict, strategy_file_name):
Supported Platforms:
``Ascend``
"""
import os
import stat
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

stra = ckpt_strategy()

stage_rank_size = state_dict["stage_rank_size"]
@@ -361,12 +360,13 @@ def get_num_layers_and_offset(config):
return int(layer_list[pp_rank]), int(sum(layer_list[:pp_rank]))
return num_layers, 0


def use_ms_custom_ops():
"""
Determine whether has custom ops
"""
try:
# pylint: disable=W0611
# pylint: disable=W0611, C0415
import ms_custom_ops
except ModuleNotFoundError:
# environment need install ms_custom_ops package


+ 116
- 37
mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py View File

@@ -15,10 +15,12 @@
"""mindformers GPT model"""
__all__ = ['GPTModel']

import hashlib
from typing import Literal, Optional, Union
import numpy as np

import mindspore as ms
from mindspore.communication import create_group, get_group_size, get_rank
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import auto_generate as aclnn_ops
@@ -29,7 +31,8 @@ from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagati
from mindspore import ops

from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss
from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock
from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock, \
func_infer_dtype, func_infer_shape, func_infer_shape_labels_and_masks
from mindformers.parallel_core.training_graph.device_matrix import layout
from mindformers.parallel_core.utils.spec_utils import ModuleSpec
from mindformers.parallel_core.training_graph.transformer.mask_generate import CausalMaskGenerate
@@ -54,26 +57,60 @@ from mindformers.tools.logger import logger
from mindformers.models.utils import get_current_rank_stage, get_model_parameters
from mindformers.version_control import get_lazy_inline as lazy_inline
from mindformers.core.optim.muon_utils import make_muon_fns
from mindformers.checkpoint.sharded_tensor import ShardedTensor


def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int):
"""Compute real op size."""
axis_fragmentations = sharded_info.axis_fragmentations
flag = False
weight_sharded_size = 1
for axis in axis_fragmentations:
if axis == 1:
continue
if flag:
raise ValueError("Only one axis can be fragmented in Muon optimizer.")
flag = True
weight_sharded_size *= axis
repeat_num = world_size // pp // weight_sharded_size
real_op_size = min(op, repeat_num)
if sharded_info.local_shape[0] % real_op_size != 0:
real_op_size = 1
return real_op_size, weight_sharded_size


def create_communication_group(rank_list):
"""
Create a communication group with a hashed name.

Args:
rank_list: List of ranks in the communication group

Returns:
str: The created group name
"""
rank_list_str = "-".join([str(i) for i in rank_list])
hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48]
group_name = str(hashed)
create_group(group_name, rank_list)
return group_name

def func_infer_dtype(*args):
"""infer_dtype for Morph."""
return args[0]

OP_GROUP_NAME = {}

def func_infer_shape(*args):
"""infer_shape for Morph."""
input_shape = args[0]
shape_value = np.prod(input_shape[:-1])
output_shape = [int(shape_value), args[0][-1]]
return output_shape

def func_infer_shape_labels_and_masks(*args):
"""infer_shape for Morph."""
input_shape = args[0]
shape_value = np.prod(input_shape)
output_shape = [int(shape_value)]
return output_shape
def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int):
"""Get op group name."""
if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME:
return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)]
dp_range = model_parallel_size
op_range = model_parallel_size * real_op_size
rank_start = rank_id % dp_range + rank_id // op_range * op_range
rank_end = rank_start + op_range
rank_list = list(range(rank_start, rank_end, dp_range))
op_group_name = create_communication_group(rank_list)
OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list)
return op_group_name, rank_list

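A self-contained sketch of the rank-list arithmetic used by get_op_group_name, with group creation reduced to the md5-based naming so it runs without a device backend; the concrete ranks are illustrative:

import hashlib

def op_group_rank_list(rank_id, real_op_size, model_parallel_size):
    """Reproduce the rank-list computation from get_op_group_name."""
    dp_range = model_parallel_size
    op_range = model_parallel_size * real_op_size
    rank_start = rank_id % dp_range + rank_id // op_range * op_range
    rank_list = list(range(rank_start, rank_start + op_range, dp_range))
    # hashed name, as in create_communication_group above
    name = hashlib.md5("-".join(str(i) for i in rank_list).encode()).hexdigest()[:48]
    return name, rank_list

# rank 5 with op size 2 and a weight sharded 4 ways joins group [1, 5]
print(op_group_rank_list(5, 2, 4)[1])   # [1, 5]
print(op_group_rank_list(6, 2, 4)[1])   # [2, 6]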

class PreprocessLabelsAndMasks(nn.Cell):
@@ -660,18 +697,45 @@ class GPTModel(nn.Cell):
if hasattr(self.decoder.layers[i].mlp, "router"):
self.assign_aux(self.decoder.layers[i].mlp.router.fi_accu, self.zeros_tensor)

def reset_max_attention_logit(self,):
"""Reset max attention logit for all layers."""
def _iter_core_attentions(self):
"""Iterate over all core_attention modules with their param names.

Yields:
Tuple[str, module]: A tuple of (param_name, core_attention_module).
"""
num_layers = self.config.num_layers
mtp_num_layers = self.config.mtp_num_layers

for i in range(num_layers):
if hasattr(self.decoder.layers[i].self_attention.core_attention, "max_logits_val"):
param = self.decoder.layers[i].self_attention.core_attention.max_logits_val
F.assign(param, F.zeros_like(param))
core_attn = self.decoder.layers[i].self_attention.core_attention
yield f"decoder.layers.{i}.self_attention.core_attention", core_attn

for i in range(mtp_num_layers):
if hasattr(self.mtp.layers[i].transformer_layer.self_attention.core_attention, "max_logits_val"):
param = self.mtp.layers[i].transformer_layer.self_attention.core_attention.max_logits_val
core_attn = self.mtp.layers[i].transformer_layer.self_attention.core_attention
yield f"mtp.layers.{i}.transformer_layer.self_attention.core_attention", core_attn

def get_max_attention_logit(self):
"""Get max attention logit values for all layers.

Returns:
dict: A dictionary mapping parameter names to their max logit values.
Only includes layers with valid (sum > 0) max_logits_val.
"""
max_logits = {}
for param_name, core_attn in self._iter_core_attentions():
if not hasattr(core_attn, "max_logits_val"):
continue
param = core_attn.max_logits_val.value()
if param.sum() <= 0:
continue
max_logits[f"{param_name}.max_logits_val"] = param
return max_logits

def reset_max_attention_logit(self):
"""Reset max attention logit to zeros for all layers."""
for _, core_attn in self._iter_core_attentions():
if hasattr(core_attn, "max_logits_val"):
param = core_attn.max_logits_val
F.assign(param, F.zeros_like(param))

def shard(self, config: TransformerConfig):
@@ -723,6 +787,14 @@ class GPTModel(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Get all sharded state dict."""
sharded_state_dict = {}
for _, sub_cell in self.cells_and_names():
if sub_cell != self and hasattr(sub_cell, "sharded_state_dict"):
sharded_state_dict.update(sub_cell.sharded_state_dict())
return sharded_state_dict

def get_model_parameters(self):
"""Get current rank trainable parameters in gpt model ."""
params = set()
@@ -849,7 +921,7 @@ class GPTModel(nn.Cell):
tp_dims.append(0)
return tuple(tp_dims)

def get_op_groups_info(self, params, op, op_group, op_in_tp_group):
def get_op_groups_info(self, params, op):
"""Return optimizer parallel group information for each parameter.
Args:
@@ -868,16 +940,14 @@ class GPTModel(nn.Cell):
"self_attention.linear_q_down_proj.weight",
"self_attention.linear_kv_up_proj.weight",
"self_attention.linear_kv_down_proj.weight",
"eh_proj"
]

use_tp_group_list = [
"mlp.router.weight",
"mlp.shared_experts.linear_fc",
"self_attention.linear_q_down_proj.weight",
"eh_proj",
"max_logits_val"
]

sharded_state_dict = self.sharded_state_dict()
world_size = get_group_size()
pp = self.config.pipeline_model_parallel_size

def name_filter(param_name, full_name_list):
for full_name in full_name_list:
if full_name in param_name:
@@ -897,13 +967,22 @@ class GPTModel(nn.Cell):
logger.warning(
f"Parameter {param.name}: parallel_optimizer was set to False due to the use of Muon optimizer."
)
else:
op_list.append(op)
continue

# compute real op size
sharded_info = sharded_state_dict.get(param.name)
real_op_size, weight_sharded_size = compute_repeat_num_and_model_parallel_size(sharded_info, world_size, pp,
op)
if real_op_size == 1:
op_list.append(1)
op_groups.append("")
logger.info(f"Parameter {param.name} : No op group.")
continue

if name_filter(param.name, use_tp_group_list):
op_groups.append(op_in_tp_group)
else:
op_groups.append(op_group)
op_list.append(real_op_size)
op_group_name, rank_list = get_op_group_name(get_rank(), real_op_size, weight_sharded_size)
logger.info(f"Parameter {param.name} : Muon op group list is: {rank_list}")
op_groups.append(op_group_name)

return tuple(op_list), tuple(op_groups)

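And a small sketch of how compute_repeat_num_and_model_parallel_size caps the requested op size by the weight's sharding, using a lightweight stand-in that carries only the two ShardedTensor fields the function reads (the single-fragmented-axis check is omitted here):

from dataclasses import dataclass

@dataclass
class FakeShardedInfo:                      # stand-in for ShardedTensor
    axis_fragmentations: tuple
    local_shape: tuple

def real_op_size(info, world_size, pp, op):
    """Mirror of the op-size computation shown above."""
    weight_sharded_size = 1
    for axis in info.axis_fragmentations:
        if axis != 1:
            weight_sharded_size *= axis
    repeat_num = world_size // pp // weight_sharded_size
    size = min(op, repeat_num)
    if info.local_shape[0] % size != 0:     # fall back when rows don't split evenly
        size = 1
    return size, weight_sharded_size

# 16 devices, pp=1, weight column-sharded 4 ways, requested op=8 -> op is capped at 4
print(real_op_size(FakeShardedInfo((4, 1), (256, 1024)), world_size=16, pp=1, op=8))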


+ 23
- 26
mindformers/parallel_core/training_graph/loss_func.py View File

@@ -40,26 +40,23 @@ _device_local_loss = {}

def get_device_local_loss(tag="lm"):
"""Get `_device_local_loss` Parameter after init"""
global _device_local_loss
if tag is None:
return _device_local_loss
if _device_local_loss.get(tag, None) is None:
_device_local_loss[tag] = Parameter(
Tensor([0.0], mstype.float32), name=f"_device_local_loss", requires_grad=False
Tensor([0.0], mstype.float32), name="_device_local_loss", requires_grad=False
)
return _device_local_loss[tag]


def reset_device_local_loss():
"""Reset `_device_local_loss` parameter to zero"""
global _device_local_loss
for _, loss in _device_local_loss.items():
F.assign(loss, Tensor([0.0], mstype.float32))


def check_device_local_loss():
"""check if Nan or Inf in `_device_local_loss` parameter then terminate training"""
global _device_local_loss
if not _device_local_loss:
return
for tag, device_local_loss in _device_local_loss.items():
@@ -88,7 +85,7 @@ class _LogSoftmax(nn.Cell):
The corresponding log softmax results.
"""
def __init__(self, config: TransformerConfig = default_transformer_config):
super(_LogSoftmax, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -143,7 +140,7 @@ class _NLLLoss(nn.Cell):
The corresponding loss results.
"""
def __init__(self, config: TransformerConfig = default_transformer_config):
super(_NLLLoss, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -176,7 +173,7 @@ class _NLLLoss(nn.Cell):


class CrossEntropyLoss(nn.Cell):
r"""
"""
Calculate the cross entropy loss.

CrossEntropyLoss supports two different types of targets:
@@ -185,9 +182,9 @@ class CrossEntropyLoss(nn.Cell):
When reduction is set to 'none', the cross-entropy loss is computed as follows:

.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}
\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\top, \\quad
l_n = - w_{y_n} \\log \\frac{\\exp(x_{n,y_n})}{\\sum_{c=1}^C \\exp(x_{n,c})}
\\cdot \\mathbb{1}\\{y_n \\not= \\text{ignore_index}\\}

where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the weights,
and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices,
@@ -196,19 +193,19 @@ class CrossEntropyLoss(nn.Cell):
If reduction is not set to 'none' (the default is 'mean'), the loss is computed as:

.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, &
\text{if reduction} = \text{'mean',}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{'sum'.}
\end{cases}
\\ell(x, y) = \\begin{cases}
\\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n} \\cdot \\mathbb{1}\\{y_n \\not=
\\text{ignore_index}\\}} l_n, &\\text{if reduction} = \\text{'mean',}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}

- Class probabilities (float), used when the target is a probability distribution over multiple class labels.
When reduction is set to 'none', the cross-entropy loss is computed as follows:

.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad
l_n = - \\sum_{c=1}^C w_c \\log \\frac{\\exp(x_{n,c})}{\\sum_{i=1}^C \\exp(x_{n,i})} y_{n,c}

where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the weights,
and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices,
@@ -217,12 +214,12 @@ class CrossEntropyLoss(nn.Cell):
If reduction is not set to 'none' (the default is 'mean'), the loss is computed as:

.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{'mean',}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{'sum'.}
\end{cases}
\\ell(x, y) = \\begin{cases}
\\frac{\\sum_{n=1}^N l_n}{N}, &
\\text{if reduction} = \\text{'mean',}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}

Args:
config (TransformerConfig): The parallel configuration. Default: default_transformer_config,
@@ -258,7 +255,7 @@ class CrossEntropyLoss(nn.Cell):
@_LogActionOnce(m_logger=logger, key='CrossEntropyLoss',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
def __init__(self, config: TransformerConfig = default_transformer_config, loss_tag='lm', **kwargs):
super(CrossEntropyLoss, self).__init__()
super().__init__()
dp = config.data_parallel_size
mp = config.tensor_model_parallel_size
cp = config.context_parallel_size
@@ -347,7 +344,7 @@ class VocabParallelCrossEntropy(nn.Cell):
"""calculate cross entropy loss"""

def __init__(self, config: TransformerConfig = default_transformer_config, **kwargs):
super(VocabParallelCrossEntropy, self).__init__()
super().__init__()
self.cross_entropy = CrossEntropyLoss(config, **kwargs)

def construct(self, vocab_parallel_logits, target, input_mask=None, label_smoothing=None):

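A tiny NumPy check of the class-index cross-entropy formula quoted in the docstring above (unit weights, no ignore_index), covering the 'none' and 'mean' reductions; this is a numerical illustration of the math only, not the sharded implementation:

import numpy as np

def cross_entropy(logits, targets, reduction="mean"):
    """Per-sample l_n = -log softmax(logits)[n, y_n]; 'mean' divides by the weight sum (N here)."""
    shifted = logits - logits.max(axis=1, keepdims=True)            # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    losses = -log_probs[np.arange(len(targets)), targets]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return losses.sum()
    return losses.mean()

logits = np.array([[2.0, 0.5, 0.3], [0.1, 1.5, 0.2]])
targets = np.array([0, 1])
print(cross_entropy(logits, targets, "none"))
print(cross_entropy(logits, targets, "mean"))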

+ 129
- 0
mindformers/parallel_core/training_graph/tensor_parallel/layers.py View File

@@ -36,6 +36,7 @@ from mindspore.ops.operations import Morph
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore import mint

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.utils.init_method import init_method_zero
from mindformers.parallel_core.inference.utils import divide
@@ -207,6 +208,28 @@ class VocabParallelEmbedding(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.num_embeddings, self.embedding_dim)

if self.enable_embedding_tp:
axis_fragmentations = (self.tp, 1)
local_shape = (self.num_embeddings // self.tp, self.embedding_dim)
else:
axis_fragmentations = (1, 1)
local_shape = (self.num_embeddings, self.embedding_dim)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=axis_fragmentations)
return sharded_state_dict


class ColumnParallelLinear(nn.Cell):
"""Linear layer with column parallelism.
@@ -427,6 +450,33 @@ class ColumnParallelLinear(nn.Cell):
matmul_in_strategy = ((dp * cp, 1), weight_strategy)
self.matmul.shard(in_strategy=matmul_in_strategy)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
tp = self.config.tensor_model_parallel_size

sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size // tp, self.input_size)
if not self.skip_weight_param_allocation:
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(tp, 1))
if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict


class RowParallelLinear(nn.Cell):
"""Linear layer with row parallelism.
@@ -663,6 +713,33 @@ class RowParallelLinear(nn.Cell):
matmul_in_strategy = ((dp * cp, tp), weight_strategy)
self.matmul.shard(in_strategy=matmul_in_strategy)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
tp = self.config.tensor_model_parallel_size
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size // tp)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(1, tp))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict


class LinearNoTP(ColumnParallelLinear):
"""Linear layer without tensor parallelism.
@@ -712,6 +789,32 @@ class LinearNoTP(ColumnParallelLinear):
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_shape=weight_shape,
global_offset=(0, 0),
axis_fragmentations=(1, 1))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_shape=(self.output_size,),
global_offset=(0,),
axis_fragmentations=(1,))
return sharded_state_dict


class SequenceParallelLinear(ColumnParallelLinear):
"""Linear layer without tensor parallelism.
@@ -761,3 +864,29 @@ class SequenceParallelLinear(ColumnParallelLinear):
layout(("cp", "tp"), "dp", "None"), # output [S, B, H]
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
weight_shape = (self.output_size, self.input_size)
local_shape = (self.output_size, self.input_size)

sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=local_shape,
global_offset=(0, 0),
global_shape=weight_shape,
axis_fragmentations=(1, 1))

if self.has_bias:
sharded_state_dict[self.bias.name] = ShardedTensor(
key=self.bias.name,
org_key=self.bias.name,
dtype=self.bias.dtype,
local_shape=(self.output_size,),
global_offset=(0,),
global_shape=(self.output_size,),
axis_fragmentations=(1,))
return sharded_state_dict
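Each of these layers now exposes per-parameter sharding metadata. A minimal sketch of how such per-cell dictionaries could be merged for a whole network is shown below; the collect_sharded_state_dict helper and its use by the Muon optimizer are assumptions for illustration, not code from this change:

# Illustrative sketch only: the helper name and the collection strategy are assumptions.
def collect_sharded_state_dict(network):
    """Merge sharded_state_dict() results from every sub-cell that provides one."""
    merged = {}
    for _, cell in network.cells_and_names():
        if hasattr(cell, "sharded_state_dict"):
            merged.update(cell.sharded_state_dict())
    return merged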

+ 25
- 0
mindformers/parallel_core/training_graph/transformer/moe/ffn.py View File

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Shape, Cast, GroupedMatmul, Reshape, Swi
from mindspore.ops.operations import Morph
from mindspore.parallel._utils import _get_parallel_mode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout
from mindformers.parallel_core.training_graph.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher, MoEAlltoAllDeredundencyTokenDispatcher, MoEAlltoAllZeroRedundancyTokenDispatcher
from mindformers.parallel_core.transformer_config import TransformerConfig
@@ -215,3 +216,27 @@ class FFNGroupedGEMM(nn.Cell):
layout(dp, sp, mp0), # output [B, S, h]
)
)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
ep = self.config.expert_model_parallel_size
sharded_state_dict = {}
sharded_state_dict[self.weight1.name] = ShardedTensor(
key=self.weight1.name,
org_key=self.weight1.name,
dtype=self.weight1.dtype,
local_shape=(self.num_local_experts // ep * self.hidden_size, self.moe_ffn_hidden_size * 2),
global_shape=(self.num_local_experts * self.hidden_size, self.moe_ffn_hidden_size * 2),
global_offset=(0, 0),
axis_fragmentations=(ep, 1),
)
sharded_state_dict[self.weight2.name] = ShardedTensor(
key=self.weight2.name,
org_key=self.weight2.name,
dtype=self.weight2.dtype,
local_shape=(self.num_local_experts // ep * self.moe_ffn_hidden_size, self.hidden_size),
global_shape=(self.num_local_experts * self.moe_ffn_hidden_size, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(ep, 1),
)
return sharded_state_dict
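The expert-parallel shapes above follow directly from splitting the stacked expert dimension by expert_model_parallel_size. A worked example with hypothetical sizes (not taken from any particular config):

# With 8 experts, ep=4, hidden_size=1024 and moe_ffn_hidden_size=2048,
# weight1 is fragmented along its first axis only.
num_local_experts, ep, hidden_size, moe_ffn_hidden_size = 8, 4, 1024, 2048

global_shape = (num_local_experts * hidden_size, moe_ffn_hidden_size * 2)           # (8192, 4096)
local_shape = (num_local_experts // ep * hidden_size, moe_ffn_hidden_size * 2)      # (2048, 4096)
axis_fragmentations = (ep, 1)  # first axis split into ep pieces, second axis unsplit

assert global_shape[0] == local_shape[0] * axis_fragmentations[0]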

+ 15
- 0
mindformers/parallel_core/training_graph/transformer/moe/router.py View File

@@ -26,6 +26,7 @@ from mindspore.ops.auto_generate import AddExt, AssignAdd, Cast, Div, Mul, Resha
from mindspore.ops.operations import Shape, ReduceSum, ReduceMean
from mindspore.parallel._utils import _get_parallel_mode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.tools.utils import get_real_group_size, get_real_rank
@@ -117,6 +118,20 @@ class Router(ABC, nn.Cell):
router_logits = self.linear(inputs.astype(self.moe_router_dtype), weight)
return router_logits

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.expert_dim, self.hidden_size),
global_shape=(self.expert_dim, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(1, 1),
)
return sharded_state_dict


class TopKRouter(Router):
"""Route each token to the top-k experts."""


+ 15
- 0
mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py View File

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Cast, Mul, Sigmoid
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.training_graph.transformer.mlp import MLP, MLPSubmodules, MLPInterleaved
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.device_matrix import layout
@@ -108,6 +109,20 @@ class SharedExpertMLP(MLP):
def expert_sharding_propagation(self, config: TransformerConfig):
super().sharding_propagation(config)

def sharded_state_dict(self):
"""Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now."""
sharded_state_dict = {}
sharded_state_dict[self.shared_experts_gate.weight.name] = ShardedTensor(
key=self.shared_experts_gate.weight.name,
org_key=self.shared_experts_gate.weight.name,
dtype=self.shared_experts_gate.weight.dtype,
local_shape=(1, self.hidden_size),
global_shape=(1, self.hidden_size),
global_offset=(0, 0),
axis_fragmentations=(1, 1),
)
return sharded_state_dict


class SharedExpertMLPInterleaved(MLPInterleaved):
"""


+ 3
- 0
mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py View File

@@ -541,6 +541,9 @@ class MtpSharedVocabParallelEmbedding(VocabParallelEmbedding):
output = self.embedding_morph(input_ids, weight)
return output

def sharded_state_dict(self):
return {}


class MtpSharedLanguageModelEmbedding(LanguageModelEmbedding):
"""Embedding layer used in Multi-Token Prediction module, same to standard LanguageModelEmbedding."""


+ 89
- 14
mindformers/parallel_core/training_graph/transformer/norm.py View File

@@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import MeanExt, Sqrt, Rsqrt, SubExt, AddExt, Mu
from mindspore.common.initializer import initializer
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.checkpoint.sharded_tensor import ShardedTensor
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.device_matrix import layout

@@ -36,7 +37,7 @@ def get_strategy(config: TransformerConfig):


class LayerNorm(nn.Cell):
r"""
"""
Layer norm operation.

Args:
@@ -52,9 +53,10 @@ class LayerNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-5):
super(LayerNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim

self.gamma = Parameter(initializer('ones', dim, self.params_dtype), name="gamma",
parallel_optimizer=False)
@@ -115,9 +117,32 @@ class LayerNorm(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for LayerNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.gamma.name] = ShardedTensor(
key=self.gamma.name,
org_key=self.gamma.name,
dtype=self.gamma.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
sharded_state_dict[self.beta.name] = ShardedTensor(
key=self.beta.name,
org_key=self.beta.name,
dtype=self.beta.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class FusedLayerNorm(nn.Cell):
r"""
"""
Layer norm operation.

Args:
@@ -133,10 +158,10 @@ class FusedLayerNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-5):
super(FusedLayerNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
begin_params_axis=-1,
epsilon=eps)
@@ -170,17 +195,39 @@ class FusedLayerNorm(nn.Cell):
strategy = (cp, dp, 1)

if strategy[-1] != 1:
raise TypeError(
'The last dim in FusedLayerNorm can not equal to 1! Strategy {} not supported!'.format(strategy))
raise TypeError(f'The last dim in FusedLayerNorm can not equal to 1! Strategy {strategy} not supported!')

self.layer_norm.shard((strategy, (strategy[-1],), (strategy[-1],)))

def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for FusedLayerNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.gamma.name] = ShardedTensor(
key=self.gamma.name,
org_key=self.gamma.name,
dtype=self.gamma.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
sharded_state_dict[self.beta.name] = ShardedTensor(
key=self.beta.name,
org_key=self.beta.name,
dtype=self.beta.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class RMSNorm(nn.Cell):
r"""
"""
A self-defined RMSNorm operation using reduce mean.

Args:
@@ -196,10 +243,10 @@ class RMSNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-6):
super(RMSNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.eps = eps
self.weight = Parameter(initializer('ones', (dim), self.params_dtype))

@@ -249,9 +296,23 @@ class RMSNorm(nn.Cell):
def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for RMSNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class FusedRMSNorm(nn.Cell):
r"""
"""
FusedRMSNorm operation

Args:
@@ -267,10 +328,10 @@ class FusedRMSNorm(nn.Cell):
"""

def __init__(self, config, dim, eps=1e-6):
super(FusedRMSNorm, self).__init__()
super().__init__()
self.params_dtype = config.params_dtype
self.compute_type = config.layernorm_compute_dtype
self.dim = dim
self.eps = eps
self.weight = Parameter(initializer('ones', (dim), self.params_dtype))

@@ -294,11 +355,25 @@ class FusedRMSNorm(nn.Cell):
if in_strategy:
self.norm.shard(in_strategy)
else:
self.norm.shard((layout("cp", "dp", "None"), layout("None",)))
self.norm.shard((layout("cp", "dp", "None"), layout("None", )))

def sharding_propagation(self, config: TransformerConfig):
pass

def sharded_state_dict(self):
"""Return sharding metadata for FusedRMSNorm parameters."""
sharded_state_dict = {}
sharded_state_dict[self.weight.name] = ShardedTensor(
key=self.weight.name,
org_key=self.weight.name,
dtype=self.weight.dtype,
local_shape=(self.dim,),
global_shape=(self.dim,),
global_offset=(0,),
axis_fragmentations=(1,),
)
return sharded_state_dict


class Norm:
"""


+ 1
- 1
mindformers/parallel_core/training_graph/transformer/utils.py View File

@@ -77,7 +77,7 @@ ATTNMASK_FUNC_MAP = {


def get_attn_mask_func(mask_func_type):
r"""
"""
Get attention mask function.

Args:


+ 7
- 2
mindformers/parallel_core/utils/model_mixin.py View File

@@ -562,6 +562,11 @@ class TrainModelMixin:
model = self.check_and_get_model()
return model.get_model_parameters()

def get_max_attention_logit(self):
"""Get max attention logit values from the model."""
model = self.check_and_get_model()
return model.get_max_attention_logit()

def make_model_muon_fns(self):
"""Make model muon functions."""
model = self.check_and_get_model()
@@ -577,10 +582,10 @@ class TrainModelMixin:
model = self.check_and_get_model()
return model.get_tp_dims(parameters)

def get_op_groups_info(self, parameters, op_size, tp_group, op_group):
def get_op_groups_info(self, parameters, op_size):
"""Get operation groups information for parameters."""
model = self.check_and_get_model()
return model.get_op_groups_info(parameters, op_size, tp_group, op_group)
return model.get_op_groups_info(parameters, op_size)

def get_parallel_config_for_muon(self):
"""Get parallel configuration for Muon optimizer."""


+ 59
- 37
mindformers/tools/ckpt_transform/transform_checkpoint.py View File

@@ -58,6 +58,7 @@ __all__ = ['TransformCkpt']

class TransformCkpt:
"""Transform src_checkpoint from src_strategy to dst_strategy."""

def __init__(self,
auto_trans_ckpt: bool = False,
rank_id: Optional[int] = None,
@@ -212,23 +213,29 @@ class TransformCkpt:
if self.world_size > 1:
dst_strategy_list = glob(os.path.join(self.dst_strategy_dir, f"*_rank_{self.rank_id}.ckpt"))
if not dst_strategy_list:
raise RuntimeError(f"The `dst_strategy`={self.dst_strategy_dir} \
does not contain strategy file of rank_{self.rank_id}.")
err_msg = (f"The `dst_strategy`={self.dst_strategy_dir} "
f"does not contain strategy file of rank_{self.rank_id}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
if len(dst_strategy_list) > 1:
raise RuntimeError(f"There can only be one strategy file corresponding to rank_{self.rank_id}, \
but multiple strategy files corresponding to rank_{self.rank_id} were found \
in {self.dst_strategy_dir}.")
err_msg = (f"There can only be one strategy file corresponding to rank_{self.rank_id}, "
f"but multiple strategy files corresponding to rank_{self.rank_id} "
f"were found in {self.dst_strategy_dir}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
dst_strategy = dst_strategy_list[0]
else:
dst_strategy = None

if check_in_modelarts():
if not mox.file.exists(self.transformed_checkpoint_dir_obs):
raise ValueError(f"transformed_checkpoint_dir_obs: "
f"{self.transformed_checkpoint_dir_obs} is not found!")
err_msg = f"transformed_checkpoint_dir_obs: {self.transformed_checkpoint_dir_obs} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
if self.world_size > 1 and not mox.file.exists(self.dst_strategy_dir_obs):
raise ValueError(f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!")

err_msg = f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)

# Get final dst_strategy in auto_trans_ckpt mode.
dst_strategy = self.get_dst_strategy(dst_strategy)
@@ -247,13 +254,7 @@ class TransformCkpt:
barrier_world(f"Remake {dst_ckpt_dir} by main rank.")

logger.info("The transformed checkpoint will be saved under %s.", dst_ckpt_dir)
self.transform_ckpt(
src_checkpoint=src_ckpt_dir,
dst_checkpoint_dir=dst_ckpt_dir,
src_strategy=src_strategy,
dst_strategy=dst_strategy,
prefix=prefix
)
self.transform_ckpt(src_ckpt_dir, dst_ckpt_dir, src_strategy, dst_strategy, prefix)

self.clear_cache()
return dst_checkpoint_dir
@@ -267,7 +268,9 @@ class TransformCkpt:
"""Transform ckpt using mindspore.transform_checkpoint"""
self.check_src_checkpoint_and_strategy(src_checkpoint, src_strategy)
if src_strategy is None and dst_strategy is None:
raise ValueError("`src_strategy` and `dst_strategy` cannot both be None!")
err_msg = "`src_strategy` and `dst_strategy` cannot both be None!"
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts():
dst_checkpoint_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs,
os.path.basename(dst_checkpoint_dir))
@@ -339,7 +342,7 @@ class TransformCkpt:
dst_strategy):
"""transform checkpoints using mindspore.transform_checkpoint_by_rank"""
for current_transform_rank_id in \
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
logger.info(".........Transforming Ckpt For Rank: %d.........", current_transform_rank_id)
src_rank_list = ms.rank_list_for_transform(current_transform_rank_id,
src_strategy,
@@ -349,7 +352,9 @@ class TransformCkpt:
checkpoint_rank_dir = os.path.join(src_checkpoint, f"rank_{src_rank_id}")
checkpoint_file_list = glob(os.path.join(checkpoint_rank_dir, "*.ckpt"))
if not checkpoint_file_list:
raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!")
err_msg = f"The checkpoint of rank_{src_rank_id} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime)
checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1]
save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}")
@@ -371,11 +376,15 @@ class TransformCkpt:
def build_soft_link_of_checkpoint(checkpoint, soft_link_dir):
"""Build softlink of src checkpoint"""
if os.path.isdir(checkpoint) and not check_rank_folders(checkpoint, 0) and \
not check_ckpt_file_exist(checkpoint):
raise ValueError(f"No rank_0 folder or ckpt files are found under {checkpoint}.")
not check_ckpt_file_exist(checkpoint):
err_msg = f"No rank_0 folder or ckpt files are found under {checkpoint}."
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isfile(checkpoint) and not checkpoint.endswith('.ckpt'):
raise ValueError(f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
err_msg = (f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isdir(checkpoint):
if check_rank_folders(checkpoint, 0):
@@ -418,7 +427,9 @@ class TransformCkpt:
return None

if not os.path.exists(strategy_path):
raise ValueError(f'strategy_path: {strategy_path} not found!')
err_msg = f'strategy_path: {strategy_path} not found!'
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isfile(strategy_path):
return strategy_path
@@ -454,8 +465,9 @@ class TransformCkpt:

if not (dst_strategy.endswith(f"_rank_{self.rank_id}.ckpt") and
os.path.exists(dst_strategy)):
raise ValueError(f"dst_strategy: {dst_strategy} is not found!")

err_msg = f"dst_strategy: {dst_strategy} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)

logger.info(".........Collecting strategy.........")
if check_in_modelarts():
@@ -521,16 +533,19 @@ class TransformCkpt:
"""
# Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size.
if transform_process_num < 1:
raise ValueError("transform_process_num should not smaller than 1,"
f"but got {transform_process_num}.")
err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
logger.error(err_msg)
raise ValueError(err_msg)
if transform_process_num > self.world_size:
logger.warning(f"transform_process_num: {transform_process_num} should not "
f"bigger than world_size: {self.world_size}. "
f"transform_process_num is set to {self.world_size}.")
transform_process_num = self.world_size
if self.world_size % transform_process_num != 0:
raise ValueError(f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
err_msg = (f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
logger.error(err_msg)
raise ValueError(err_msg)

if check_in_modelarts() and 1 < transform_process_num < self.node_num:
logger.warning("transform_process_num: %d should not smaller than \
@@ -551,15 +566,15 @@ class TransformCkpt:

return transform_rank_id_list


@staticmethod
def check_src_checkpoint_and_strategy(src_checkpoint, src_strategy):
"""check src checkpoint and strategy"""
check_path(src_checkpoint, "src_checkpoint")
if not os.path.isdir(src_checkpoint) or not glob(os.path.join(src_checkpoint, "rank_*")):
raise ValueError("The load_checkpoint must be a dir and "
"ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt,"
f"but get {src_checkpoint}.")
err_msg = ("The load_checkpoint must be a dir and ckpt should be stored "
f"in the format of load_checkpoint/rank_x/xxx.ckpt, but get {src_checkpoint}.")
logger.error(err_msg)
raise ValueError(err_msg)
# Check rank_dirs is continuous.
# For example, rank_0, rank_1, rank_4 is not continuous because it is missing rank_3
src_checkpoint_rank_dir_list = glob(os.path.join(src_checkpoint, "rank_*"))
@@ -568,7 +583,9 @@ class TransformCkpt:
src_checkpoint_rank_num = len(src_checkpoint_rank_id_list)
for i in range(src_checkpoint_rank_num):
if src_checkpoint_rank_id_list[i] != i:
raise FileNotFoundError(f"The rank_{i} folder was not found under src_checkpoint folder.")
err_msg = f"The rank_{i} folder was not found under src_checkpoint folder."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# A full checkpoint do not require a strategy.
if len(src_checkpoint_rank_id_list) == 1 and src_strategy:
@@ -576,7 +593,9 @@ class TransformCkpt:
src_strategy = None
# Distributed checkpoints must be accompanied by strategy.
if len(src_checkpoint_rank_id_list) > 1 and src_strategy is None:
raise ValueError("`src_strategy` should not be None when `src_checkpoint` is sliced.")
err_msg = "`src_strategy` should not be None when `src_checkpoint` is sliced."
logger.error(err_msg)
raise ValueError(err_msg)

def send_strategy_to_obs(self, strategy):
"""Local rank send strategy file to obs."""
@@ -623,6 +642,7 @@ class TransformCkpt:
last_strategy_num = dst_strategy_num
if dst_strategy_num < self.world_size:
if time.time() - start_time > 7200:
logger.error("Timeout while collecting all strategy!")
raise TimeoutError("Timeout while collecting all strategy!")
time.sleep(5)
else:
@@ -642,7 +662,9 @@ class TransformCkpt:
transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt'))
if transform_failed_txts:
raise ValueError(f"Transform failed, find {transform_failed_txts}.")
err_msg = f"Transform failed, find {transform_failed_txts}."
logger.error(err_msg)
raise ValueError(err_msg)
current_count = len(transform_succeed_txts)
progress = (current_count / self.transform_process_num) * 100
if current_count != last_count:
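The change applied throughout this file (and the files below) is the same two-line pattern: log the message with logger.error so it lands in error.log, then raise. A hypothetical helper that captures the pattern might look like this; the repository inlines the two statements rather than factoring them out:

from mindformers.tools.logger import logger  # import path as used elsewhere in MindFormers

def log_and_raise(exc_type, err_msg):
    """Write err_msg to error.log via logger.error, then raise it."""
    logger.error(err_msg)
    raise exc_type(err_msg)

# Usage sketch:
#   if not os.path.exists(strategy_path):
#       log_and_raise(ValueError, f"strategy_path: {strategy_path} not found!")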


+ 39
- 18
mindformers/tools/resume_ckpt.py View File

@@ -84,7 +84,9 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt',
if check_in_modelarts():
resume_record_dir = os.path.join(get_remote_save_url(), "resume_record")
if not Validator.is_obs_url(resume_record_dir):
raise ValueError(f"{resume_record_dir} is not a valid obs path.")
err_meg = f"{resume_record_dir} is not a valid obs path."
logger.error(err_meg)
raise ValueError(err_meg)
else:
resume_record_dir = os.path.join(get_output_root_path(), "resume_record")
remake_folder(resume_record_dir, permissions=0o750)
@@ -151,7 +153,9 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list):
health_ckpts.append(item)

if not health_ckpts:
raise ValueError("The training has no healthy checkpoints yet, please start training again.")
err_msg = "The training has no healthy checkpoints yet, please start training again."
logger.error(err_msg)
raise ValueError(err_msg)

if not not_health_ckpts:
not_health_ckpts_set = set(not_health_ckpts)
@@ -167,12 +171,16 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):

if not check_in_modelarts():
if not os.path.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f:
resume_info = [line.strip() for line in f.readlines()]
else:
if not mox.file.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"OBS: Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"OBS: Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with mox.file.File(latest_checkpointed_iteration_txt, 'r') as f:
resume_info = [line.strip() for line in f.readlines()]

@@ -182,7 +190,9 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
return True

if resume_info[0].startswith("Error"):
raise ValueError(f"Get resume-able checkpoint failed, due to {resume_info[0]}")
err_msg = f"Get resume-able checkpoint failed, due to {resume_info[0]}"
logger.error(err_msg)
raise ValueError(err_msg)

resume_ckpt = replace_rank_id_in_ckpt_name(resume_info[-1], rank_id)
logger.info("Get resume checkpoint: %s", resume_ckpt)
@@ -242,7 +252,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_prefix_tmp = ckpt_prefix.replace(f"rank_{original_rank}", f"rank_{rank_id_tmp}")
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
if not os.path.exists(checkpoint_rank_dir):
raise FileNotFoundError(f"{checkpoint_rank_dir} is not found!")
err_msg = f"{checkpoint_rank_dir} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
for ckpt_file in os.listdir(checkpoint_rank_dir):
health_ckpt_match = (ckpt_file.startswith(ckpt_prefix_tmp[:ckpt_prefix_tmp.rfind("_")])
and use_checkpoint_health_monitor)
@@ -262,10 +274,14 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id)
resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file)
if not os.path.exists(resume_ckpt):
raise FileNotFoundError(f"{resume_ckpt} is not found!")
err_msg = f"{resume_ckpt} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_ckpt_list.append(resume_ckpt)
if not resume_ckpt_list:
raise RuntimeError("No checkpoint could be resumed.")
err_msg = "No checkpoint could be resumed."
logger.error(err_msg)
raise RuntimeError(err_msg)

if use_checkpoint_health_monitor:
resume_ckpt_list.sort(key=lambda x: get_times_epoch_and_step_from_ckpt_name(x, ckpt_format))
@@ -325,8 +341,9 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not last_checkpoint:
raise ValueError(f"Checkpoint not found under {checkpoint_rank_dir} "
f"with config.load_ckpt_format:{ckpt_format}.")
err_msg = f"Checkpoint not found under {checkpoint_rank_dir} with config.load_ckpt_format:{ckpt_format}."
logger.error(err_msg)
raise ValueError(err_msg)
if check_ckpt_file_name(last_checkpoint, ckpt_format):
compared_checkpoint_name = replace_rank_id_in_ckpt_name(last_checkpoint, 0)
compared_original_checkpoint_name = os.path.basename(last_checkpoint)
@@ -353,12 +370,16 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
compared_checkpoint_name = current_checkpoint_name
compared_original_checkpoint_name = original_checkpoint_name
elif compared_checkpoint_name != current_checkpoint_name:
raise ValueError(f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
err_msg = (f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
logger.error(err_msg)
raise ValueError(err_msg)
if find_diff_ckpt:
raise ValueError(f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
err_msg = (f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
logger.error(err_msg)
raise ValueError(err_msg)

+ 11
- 6
mindformers/trainer/base_trainer.py View File

@@ -1133,7 +1133,9 @@ class BaseTrainer:
logger.info(".............Start load resume context from common.json..................")
common_file = os.path.join(config.load_checkpoint, 'common.json')
if not os.path.exists(common_file):
raise FileNotFoundError(f"No common.json found in directory '{config.load_checkpoint}'.")
error_msg = f"No common.json found in directory '{config.load_checkpoint}'."
logger.error(error_msg)
raise FileNotFoundError(error_msg)
common_info = CommonInfo.load_common(common_file)
step_scale = common_info.global_batch_size / config.runner_config.global_batch_size
config.runner_config.initial_step = int(common_info.step_num * step_scale)
@@ -1167,9 +1169,11 @@ class BaseTrainer:
logger.info("..............Start resume checkpoint path from strategy..............")
resume_ckpt_path = self.resume_ckpt_path_with_strategy(config)
if resume_ckpt_path is None:
raise ValueError(f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
err_msg = (f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
logger.error(err_msg)
raise ValueError(err_msg)
config.load_checkpoint = resume_ckpt_path
load_resume_context_from_checkpoint(config, dataset)
resume_dict = {
@@ -1253,8 +1257,9 @@ class BaseTrainer:
if hasattr(network, "get_model_parameters"):
model_params.update(network.get_model_parameters())
else:
raise NotImplementedError(f"The {type(network)} has not implemented the interface: "
f"get_model_parameters.")
err_msg = f"The {type(network)} has not implemented the interface: `get_model_parameters`."
logger.error(err_msg)
raise NotImplementedError(err_msg)

is_moe_model = False
is_mtp_model = False


+ 0
- 12
mindformers/trainer/optimizer_grouped_parameters.py View File

@@ -177,18 +177,6 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
# Append parameter to its group
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(param.name)
for param in model.get_parameters():
if "max_logits_val" in param.name:
param.requires_grad = False
if parameter_group_vars.get("max_logits") is None:
parameter_group_names["max_logits"] = {
"params": [],
}
parameter_group_vars["max_logits"] = {
"params": [],
}
parameter_group_vars["max_logits"]["params"].append(param)
parameter_group_names["max_logits"]["params"].append(param.name)
param_groups = json.dumps(parameter_group_names, indent=2)
logger.info("Param groups = %s", param_groups)
return list(parameter_group_vars.values())

+ 37
- 15
mindformers/trainer/utils.py View File

@@ -58,9 +58,9 @@ class BaseEnum(str, Enum):
@classmethod
def _missing_(cls, value):
"""Enum with more explicit error message for missing values."""
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
err_msg = f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
logger.error(err_msg)
raise ValueError(err_msg)


class IntervalStrategy(BaseEnum):
@@ -132,7 +132,9 @@ def preload_ckpt(config):
set_mindio_server_info(mindio_pool_capacity)

if not os.path.realpath(ckpt_path) or not os.path.exists(ckpt_path):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {ckpt_path}")
err_log = f"The load_checkpoint must be correct, but get {ckpt_path}"
logger.error(err_log)
raise FileNotFoundError(err_log)

preload_ok = False
if os.path.isfile(ckpt_path):
@@ -153,7 +155,9 @@ def preload_ckpt(config):
logger.info(f"MindIO preloading `{checkpoint_path}`...")
preload_ok = mindio_preload(checkpoint_path)
else:
raise ValueError(f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False.")
err_msg = f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)

if preload_ok:
logger.info("MindIO preload checkpoint successfully!")
@@ -197,8 +201,10 @@ def check_runner_config(config, dataset):
if config.runner_config.sink_mode:
if config.runner_config.sink_size != -1:
if config.runner_config.sink_size <= 0:
raise ValueError("per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
err_log = (f"per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
logger.error(err_log)
raise ValueError(err_log)
if data_size < config.runner_config.sink_size:
logger.warning("The data size %s (get from dataset.get_dataset_size()) is smaller "
"than the sink_size %s (get from config.runner_config.sink_size), "
@@ -320,7 +326,9 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck
logger.info("Your load_checkpoint is file, it will be load in network.")
distribute_checkpoint_path = checkpoint_dir
else:
raise FileNotFoundError(f"{checkpoint_dir} is not found.")
err_msg = f"{checkpoint_dir} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return distribute_checkpoint_path


@@ -348,7 +356,9 @@ def load_resume_context_from_checkpoint(config, dataset):
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(config.load_checkpoint) or \
not os.path.exists(config.load_checkpoint):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_log = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_log)
raise FileNotFoundError(err_log)

if os.path.isdir(config.load_checkpoint):
# When graceful exit is enabled or auto checkpoint transformation is disabled,
@@ -375,6 +385,10 @@ def load_resume_context_from_checkpoint(config, dataset):
else:
checkpoint_tmp = os.path.join(config.load_checkpoint, f"rank_{rank_id}",
replace_rank_id_in_ckpt_name(config.resume_training, rank_id))
if not os.path.isfile(checkpoint_tmp):
err_msg = f"{checkpoint_tmp} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_dict = load_checkpoint(
checkpoint_tmp,
choice_func=lambda x: x in ["loss_scale", "epoch_num", "step_num", "global_batch_size"],
@@ -512,15 +526,20 @@ def transform_and_load_checkpoint(config, model, network, dataset, optimizer=Non


def check_checkpoint_config_valid(config):
"""Check valid load checkpoint path in config."""
# check valid load checkpoint path
if not config.only_save_strategy and (not os.path.realpath(config.load_checkpoint) or
not os.path.exists(config.load_checkpoint)):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_msg = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# check valid format
if config.load_ckpt_format is not None and config.load_ckpt_format not in CkptFormat.support_type():
raise ValueError(
f"config.load_ckpt_format only support for 'ckpt' or 'safetensors', but got {config.load_ckpt_format}.")
err_msg = ("config.load_ckpt_format only support for 'ckpt' or 'safetensors', "
f"but got {config.load_ckpt_format}.")
logger.error(err_msg)
raise ValueError(err_msg)


def check_path_include_total_ckpt(path):
@@ -570,7 +589,9 @@ def load_slora_ckpt(checkpoint_dict, config, network):
logger.info("............Start load slora checkpoint ............")
adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json")
if not os.path.exists(adapter_path):
raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}")
err_msg = f"The adapter_path must be correct, but get {adapter_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
with open(adapter_path, 'r', encoding='utf-8') as file:
path_dict = json.load(file)
adapter_list = []
@@ -686,8 +707,9 @@ def get_load_checkpoint_result(config):
else:
checkpoint_dict = load_distributed_checkpoint(config.load_checkpoint)
else:
raise ValueError(f"{config.load_checkpoint} is not a valid path to load checkpoint "
f"when auto_trans_ckpt is False.")
err_msg = f"{config.load_checkpoint} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)
return checkpoint_dict if checkpoint_dict else checkpoint_future




+ 56
- 22
mindformers/utils/load_checkpoint_utils.py View File

@@ -91,9 +91,13 @@ def get_load_path_after_hf_convert(config, network):
def _check_checkpoint_path(path):
"""check checkpoint path."""
if not isinstance(path, str) or isinstance(path, os.PathLike):
raise ValueError(f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`.")
err_msg = f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`."
logger.error(err_msg)
raise ValueError(err_msg)
if not os.path.exists(path):
raise FileNotFoundError(f"config.load_checkpoint `{path}` does not exist.")
err_msg = f"config.load_checkpoint `{path}` does not exist."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if path[-1] == '/': # remove last '/' in path
return path[:-1]
@@ -113,7 +117,9 @@ def _get_checkpoint_mode(config):

# check path is dir
if not os.path.isdir(checkpoint_path):
raise ValueError("Provided path is neither a file nor a directory.")
err_msg = "Provided path is neither a file nor a directory."
logger.error(err_msg)
raise ValueError(err_msg)

dir_files = os.listdir(checkpoint_path)
if any(folder_name.startswith('rank_') for folder_name in dir_files):
@@ -122,7 +128,9 @@ def _get_checkpoint_mode(config):
if any(file_name.endswith(config.load_ckpt_format) for file_name in dir_files):
return CheckpointFileMode.MULTI_CHECKPOINT_FILE.value

raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def _get_src_strategy(config):
@@ -139,8 +147,10 @@ def _get_src_strategy(config):
src_strategy_path = os.path.join(upper_dir, 'strategy')
logger.info(f"src_strategy_path_or_dir is empty, load source strategy from {src_strategy_path}.")
else:
raise ValueError("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
err_msg = ("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
logger.error(err_msg)
raise ValueError(err_msg)
logger.info(f"load source strategy from {src_strategy_path}.")
return src_strategy_path

@@ -246,7 +256,9 @@ def _get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'):
else:
ckpt_path = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"{ckpt_path} is not found.")
err_msg = f"{ckpt_path} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return ckpt_path


@@ -262,7 +274,9 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval

pet_config = config.model.model_config.get("pet_config")
if pet_config and pet_config.pet_type == "slora" and network.lora_list:
raise ValueError(f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon.")
err_msg = f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon."
logger.error(err_msg)
raise ValueError(err_msg)
ckpt_file_mode = _get_checkpoint_mode(config)
validate_config_with_file_mode(ckpt_file_mode, config.use_parallel, config.auto_trans_ckpt)
# reduce compile time in prediction
@@ -366,18 +380,26 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt
"""validate use_parallel and auto_trans_ckpt config with different file mode"""
if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value:
if use_parallel:
raise ValueError("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
err_msg = ("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value:
if use_parallel and not auto_trans_ckpt:
raise ValueError("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True to enable automatic slicing function.")
err_msg = ("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True, to enable automatic slicing function.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value:
if not use_parallel:
raise ValueError("when input checkpoint file is rank dir, Please set use_parallel: True to enable "
"distributed ckpt load.")
err_msg = ("when input checkpoint file is rank dir, Please set use_parallel: "
"True, to enable distributed ckpt load.")
logger.error(err_msg)
raise ValueError(err_msg)
else:
raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False,
@@ -422,6 +444,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
logger.info("......obtain name map for HF safetensors.....")
name_map = origin_network.obtain_name_map(load_checkpoint_files)
except Exception as e:
logger.error(f"Please complete abstract function obtain_name_map. Details: {e}")
raise TypeError(f"Please complete abstract function obtain_name_map. Details: {e}") from e
if is_main_rank():
_convert_index_json(load_ckpt_path, load_ckpt_path, origin_network.convert_map_dict, False)
@@ -439,7 +462,9 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
hyper_param_file = os.path.join(load_ckpt_path, 'hyper_param.safetensors')
if optimizer and config.resume_training:
if not os.path.exists(hyper_param_file):
raise FileNotFoundError(rf"No hyper_param.safetensors in given dir: {load_ckpt_path}")
err_msg = rf"No hyper_param.safetensors in given dir: {load_ckpt_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
logger.info("......Start load hyper param into optimizer......")
hyper_param_dict = ms.load_checkpoint(ckpt_file_name=hyper_param_file, format='safetensors')
update_global_step(config, hyper_param_dict)
@@ -505,7 +530,9 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):

# If the child process exits abnormally, the main process should also throw an exception.
if p.exitcode != 0:
raise RuntimeError("convert HuggingFace weight failed.")
err_msg = "convert HuggingFace weight failed."
logger.error(err_msg)
raise RuntimeError(err_msg)

# If used parallel mode, other cards are waiting for the main card to complete the weight conversion.
barrier_world("for the main rank to convert HuggingFace weight...")
@@ -518,11 +545,14 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):
def get_last_checkpoint(checkpoint_dir, ckpt_format='ckpt'):
"""get last checkpoint for resuming or finetune."""
if not os.path.isdir(checkpoint_dir):
raise NotADirectoryError(
err_msg = (
f"{checkpoint_dir} is not a real directory,"
f"When distributed loads are sliced weights,"
f"load_checkpoint should be a checkpoint directory containing the directory of rank_{{0-*}},"
f"The directory structure is as follows: **checkpoint_root_dir/rank_{{0-*}}/**.{ckpt_format}")
logger.error(err_msg)
raise NotADirectoryError(err_msg)

output_checkpoint_path = [
checkpoint
for checkpoint in os.listdir(checkpoint_dir)
@@ -562,11 +592,15 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin
break

if is_qkv_concat and not qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been concatenated,"
" but qkv_concat is set to false.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been concatenated, "
"but qkv_concat is set to false.")
logger.error(err_msg)
raise ValueError(err_msg)
if not is_qkv_concat and qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been not concatenated,"
" but qkv_concat is set to true.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been not concatenated, "
"but qkv_concat is set to true.")
logger.error(err_msg)
raise ValueError(err_msg)
if is_qkv_concat and qkv_concat_config:
logger.info("The qkv concat check succeed! The qkv in the model weights has been concatenated and "
"qkv_concat is set to true.")


+ 3
- 1
mindformers/utils/resume_ckpt_utils.py View File

@@ -57,7 +57,9 @@ def load_resume_checkpoint(load_checkpoint_path, remove_redundancy, load_ckpt_fo
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(load_checkpoint_path) or \
not os.path.exists(load_checkpoint_path):
raise FileNotFoundError(f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}")
err_msg = f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if os.path.isdir(load_checkpoint_path):
hyper_param_file = os.path.join(load_checkpoint_path, 'hyper_param.safetensors')


+ 1
- 1
mindformers/wrapper/wrapper.py View File

@@ -207,7 +207,7 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
**kwargs (Any): Additional parameters.

Inputs:
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
- **\\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \\ldots)`.

Outputs:
Tuple of 5 or 7 Tensor, the loss, overflow flag, current loss scale value, learning rate,


+ 6
- 3
tests/st/test_safetensors/test_checkpoint_utils.py View File

@@ -102,8 +102,6 @@ def mock_file():
return mock_f




class TestCommonCheckpointMethod:
"""A test class for testing common methods"""

@@ -485,7 +483,8 @@ class TestCommonCheckpointMethod:
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file(self):
@patch('mindformers.utils.load_checkpoint_utils.logger')
def test__get_src_file(self, mock_logger):
"""test _get_src_file function"""
# setup mocks using context managers
with patch('os.path.exists') as mock_exists, \
@@ -510,6 +509,9 @@ class TestCommonCheckpointMethod:
with pytest.raises(FileNotFoundError):
_get_src_file("/test", "non_existent.ckpt", "ckpt")

# Verify that logger.error has been called.
mock_logger.error.assert_called()

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -813,6 +815,7 @@ class TestCommonCheckpointMethod:
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()


class TestBuildModel:
"""A test class for testing build_model"""
runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1}


+ 9
- 8
tests/st/test_ut/test_api_compatibility.py View File

@@ -101,7 +101,7 @@ def is_not_compatibility(base_str, new_str):
def set_failure_list(api_str, value, signature, failure_list):
"""set failure info list"""
failure_list.append(f"# {api_str}:")
failure_list.append(f" - function signature is different: ")
failure_list.append(" - function signature is different: ")
failure_list.append(f" - the base signature is {value}.")
failure_list.append(f" - now it is {signature}.")

@@ -170,12 +170,12 @@ def api_signature(obj, api_str, content, base_schema, failure_list, is_update=Fa
else:
tmp_len = -1
signature = None
for i in range(len(signature_list)):
if signature_list[i] == "(*args, **kwargs)":
for _, sig in enumerate(signature_list):
if sig == "(*args, **kwargs)":
continue
if len(signature_list[i]) > tmp_len:
tmp_len = len(signature_list[i])
signature = signature_list[i]
if len(sig) > tmp_len:
tmp_len = len(sig)
signature = sig
else:
signature = str(inspect.signature(obj))

@@ -293,7 +293,8 @@ class TestApiStability:
def check_one_element(elem, mod_name, mod, is_public):
obj = getattr(mod, elem)
if hasattr(obj, "__module__"):
if obj.__module__ not in ['sentencepiece_model_pb2']: # cannot use __import__ module list
# cannot use __import__ module list
if obj.__module__ not in ['sentencepiece_model_pb2', 'node_strategy_pb2']:
mod_source = str(__import__(obj.__module__))
if "mindformers" not in mod_source:
return
@@ -337,4 +338,4 @@ class TestApiStability:
with open(self.api_json_path, "w", encoding="utf-8") as w:
w.write(json.dumps(self.content, ensure_ascii=False, indent=4))

assert not self.is_update, f"self.is_update should be set to False"
assert not self.is_update, "self.is_update should be set to False"

+ 1279
- 1
tests/st/test_ut/test_core/test_callback/test_checkpoint_monitor.py View File

@@ -13,13 +13,22 @@
# limitations under the License.
# ============================================================================
"""Test test_checkpoint_monitor.py"""
import json
import os
import unittest
import shutil
import tempfile
from unittest.mock import Mock, patch

import numpy as np
import pytest
from mindspore import ModelCheckpoint
from mindformers.core.callback.callback import CheckpointMonitor

# pylint: disable=protected-access
# pylint: disable=unused-argument # for mock logic


class TestCheckpointMonitor(unittest.TestCase):
"""Test cases for CheckpointMonitor class"""

@@ -102,5 +111,1274 @@ class TestCheckpointMonitor(unittest.TestCase):

self.assertTrue(monitor.save_trainable_params)


class TestCheckpointMonitorExtended:
"""Extended tests for CheckpointMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/tmp/checkpoint')
def test_checkpoint_monitor_init(self, *mocks):
"""Test CheckpointMonitor initialization"""

monitor = CheckpointMonitor(
prefix='TEST',
save_checkpoint_steps=100,
global_batch_size=32
)

assert monitor.global_batch_size == 32
assert monitor.save_checkpoint_steps == 100

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/tmp/checkpoint')
def test_checkpoint_monitor_remove_redundancy(self, *mocks):
"""Test CheckpointMonitor with remove_redundancy parameter"""

monitor = CheckpointMonitor(
prefix='TEST',
remove_redundancy=True
)

assert monitor.need_remove_redundancy

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/tmp/checkpoint')
def test_checkpoint_monitor_save_network_params(self, *mocks):
"""Test CheckpointMonitor with save_network_params parameter"""

monitor = CheckpointMonitor(
prefix='TEST',
save_network_params=True
)

assert monitor.save_network_params

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/tmp/checkpoint')
def test_checkpoint_monitor_save_trainable_params(self, *mocks):
"""Test CheckpointMonitor with save_trainable_params parameter"""

monitor = CheckpointMonitor(
prefix='TEST',
save_trainable_params=True
)

assert monitor.save_trainable_params


class TestCheckpointMonitorHelpers:
"""Test CheckpointMonitor helper methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_output_subpath')
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_record_last_ckpt_to_json(self, mock_get_real_rank, mock_get_output_subpath):
"""Test record_last_ckpt_to_json method"""

with tempfile.TemporaryDirectory() as tmpdir:
mock_get_output_subpath.return_value = tmpdir

monitor = CheckpointMonitor(prefix='TEST')
monitor._directory = tmpdir
monitor.meta_json = os.path.join(tmpdir, 'meta.json')

monitor.record_last_ckpt_to_json(5, 10, 'test_ckpt.ckpt')

# Verify file was created
assert os.path.exists(monitor.meta_json)

# Verify content
with open(monitor.meta_json, 'r', encoding='utf-8') as f:
data = json.load(f)
assert data['last_epoch'] == 5
assert data['last_step'] == 10
assert data['last_ckpt_file'] == 'test_ckpt.ckpt'


class TestCheckpointMonitorSaveAndHealth:
"""Test CheckpointMonitor save and health check methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
@patch('mindformers.core.callback.callback.auto_parallel_context')
@patch('mindformers.core.callback.callback.get_embedding_info', return_value=1.5)
@patch('mindformers.core.callback.callback.AllReduceNet')
@patch('mindformers.core.callback.callback.create_group')
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback.ms.set_auto_parallel_context')
def test_get_checkpoint_health_info_healthy(
self, mock_set_context, mock_get_parallel_mode,
mock_create_group, mock_allreduce_net, mock_get_embedding,
mock_auto_context, mock_group_size, mock_get_rank, mock_real_rank):
"""Test get_checkpoint_health_info when checkpoint is healthy"""

mock_auto_context.return_value.get_pipeline_stages.return_value = 2

# Mock AllReduce result
mock_allreduce_instance = Mock()
mock_health_tensor = Mock()
mock_health_tensor.asnumpy.return_value = np.array([1.0])
mock_allreduce_instance.return_value = mock_health_tensor
mock_allreduce_net.return_value = mock_allreduce_instance

# Mock create_group to avoid distributed communication initialization
mock_create_group.return_value = None

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
embedding_size=4096,
embedding_local_norm_threshold=1.0,
use_checkpoint_health_monitor=True
)

cb_params = Mock()
cb_params.cur_step_num = 10

is_health = monitor.get_checkpoint_health_info(cb_params)

assert is_health == 1

# Verify create_group was called
mock_create_group.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.os.path.getmtime', return_value=1005.0)
def test_print_savetime(self, mock_getmtime, mock_exists, mock_time):
"""Test print_savetime method"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Setup save_info_list
monitor.save_info_list[10] = {
'ckpt': {
'save_start_time': 1000.0,
'ckpt_file_path': '/path/to/ckpt.ckpt',
'save_end_time': None
},
'network': {
'save_start_time': None,
'ckpt_file_path': None,
'save_end_time': None
},
'trainable_params': {
'save_start_time': None,
'ckpt_file_path': None,
'save_end_time': None
}
}

monitor.print_savetime(10, 100)

# Verify the save_end_time was set
assert monitor.save_info_list[10]['ckpt']['save_end_time'] is not None
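# Judging from the mocks above, print_savetime is assumed to read the checkpoint
# file's mtime (mocked as 1005.0) and record it as the entry's save_end_time.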


class TestCheckpointMonitorStepEndAndTrainEnd:
"""Test CheckpointMonitor.step_end and on_train_end"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_step_end_legacy_format(self, mock_real_rank):
"""Test step_end with legacy format"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=True
)

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 10
run_context.original_args.return_value = cb_params

# Should call parent's step_end
with patch.object(CheckpointMonitor.__bases__[0], 'step_end') as mock_parent_step_end:
monitor.step_end(run_context)
mock_parent_step_end.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_on_train_end_new_format(self, mock_real_rank):
"""Test on_train_end with new format"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False,
async_save=False
)

monitor._save_megatron_ckpt_file_format = Mock()

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 100
run_context.original_args.return_value = cb_params

monitor.on_train_end(run_context)

monitor._save_megatron_ckpt_file_format.assert_called_once_with(cb_params)


class TestCheckpointMonitorSaveCkpt:
"""Test CheckpointMonitor._save_ckpt method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.check_arf_status', return_value=False)
def test_save_ckpt_skip_same_step(self, mock_arf, mock_real_rank):
"""Test _save_ckpt skips when called twice for same step"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

cb_params = Mock()
cb_params.cur_step_num = 10

monitor._last_triggered_step = 10
monitor.save_checkpoint = Mock()

# Should return early without saving
monitor._save_ckpt(cb_params, force_to_save=False)

monitor.save_checkpoint.assert_not_called()


class TestCheckpointMonitorSaveCheckpoint:
"""Test CheckpointMonitor.save_checkpoint and related methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.logger')
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.os.path.join',
side_effect=lambda *args: '/'.join(args))
@patch('mindformers.core.callback.callback.set_safe_mode_for_file_or_dir')
@patch('mindformers.core.callback.callback.json.dump')
@patch('mindformers.core.callback.callback.json.load', return_value=[])
@patch('builtins.open', create=True)
@patch('mindformers.core.callback.callback.os.path.exists', return_value=False)
def test_save_checkpoint_with_health_monitor(self, mock_exists, mock_open_file, mock_json_load,
mock_json_dump, mock_safe_mode, mock_path_join,
mock_time, mock_get_rank, mock_logger, mock_real_rank):
"""Test save_checkpoint with health monitoring enabled"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_checkpoint_health_monitor=True,
health_ckpts_record_dir='./health'
)

monitor.save_info_list[10] = {
'ckpt': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'network': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'trainable_params': {
'save_start_time': None, 'ckpt_file_path': None,
'save_end_time': None}
}

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.batch_num = 100
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10
cb_params.train_network = Mock()

monitor.get_checkpoint_health_info = Mock(return_value=1)
monitor.remove_redundancy = Mock()
monitor._manager = Mock()
monitor._manager.ckpoint_num = 0

# Should not raise error
monitor.save_checkpoint(cb_params)

# Verify health info was checked
monitor.get_checkpoint_health_info.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
def test_filter_ckpt_not_save(self, mock_time, mock_real_rank):
"""Test _filter_ckpt_not_save method"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

monitor.filter_list = ['optimizer', 'temp']

# Should filter out parameters starting with filter_list items
assert not monitor._filter_ckpt_not_save('optimizer.weight', monitor.filter_list)
assert not monitor._filter_ckpt_not_save('model.temp', monitor.filter_list)
assert monitor._filter_ckpt_not_save('model.weight', monitor.filter_list)
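# Convention implied by these assertions: the helper returns False for parameter
# names that start with an entry in filter_list (i.e. they are skipped when
# saving) and True for parameters that should be kept.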

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='stand_alone')
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
def test_remove_redundancy_standalone(
self, mock_save_ckpt, mock_context, mock_time, mock_real_rank):
"""Test remove_redundancy in standalone mode"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

network = Mock()
cur_file = './test.ckpt'
append_dict = {}

# In standalone mode, should use simple save
monitor.remove_redundancy(network, cur_file, append_dict, None)

mock_save_ckpt.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
def test_get_cur_dp(self, mock_time, mock_real_rank):
"""Test _get_cur_dp method"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Test with simple redundancy dict
param_redundancy_dict = {
'layer.weight': [(0, 1, 2, 3), (4, 5, 6, 7)],
'layer.bias': [(0, 1, 2, 3), (4, 5, 6, 7)]
}

cur_dp = monitor._get_cur_dp(0, param_redundancy_dict)

# Should return the tuple containing rank 0
assert 0 in cur_dp


class TestCheckpointMonitorTFTSave:
"""Test CheckpointMonitor TFT save methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
def test_tft_save_ckpt(self, mock_save_ckpt, mock_real_rank):
"""Test _tft_save_ckpt method"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

param_layout_set = {'layer.weight', 'layer.bias'}
save_param_names = {'layer.weight'}
cur_file = './test_rank_0/ckpt_1.ckpt'
append_dict = {'epoch_num': 1}
network = Mock()

monitor._tft_save_ckpt(param_layout_set, save_param_names, cur_file, append_dict, network)

mock_save_ckpt.assert_called_once()
# Verify choice_func filters correctly
call_args = mock_save_ckpt.call_args
choice_func = call_args[1].get('choice_func') if call_args[1] else None
if choice_func:
# layer.weight is in save_param_names
assert choice_func('layer.weight')
# layer.bias is in param_layout_set but not in save_param_names
assert not choice_func('layer.bias')
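# Note: the choice_func checks above only run when it was passed as a keyword
# argument; if the implementation passed it positionally they would be skipped.
# This keeps the test tolerant of either call style.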

@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='semi_auto_parallel')
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.logger')
@patch('mindformers.core.callback.callback.re.sub')
def test_do_remove_redundancy_for_tft(
self, mock_re_sub, mock_logger, mock_context, mock_real_rank):
"""Test _do_remove_redundancy_for_tft method"""

def re_sub_side_effect(pattern, repl, string):
    # Emulate re.sub: rewrite the rank directory in the checkpoint path
    if 'rank_' in pattern:
        return string.replace('rank_0', f'rank_{repl.split("_")[1]}')
    return string

mock_re_sub.side_effect = re_sub_side_effect

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True,
checkpoint_format='ckpt'
)

monitor._tft_save_ckpt = Mock()
monitor.record_last_ckpt_to_json = Mock()

redundancy_info = (
0, # rank_id
{'layer.weight': [(0, 1)]}, # param_redundancy_dict
{0: {'layer.weight'}, 1: {'layer.weight'}}, # single_params
{'layer.weight': Mock()} # param_layout
)
cur_file = './test_rank_0/ckpt_1_10.ckpt'
network = Mock()
append_dict = {'epoch_num': 1}

monitor._do_remove_redundancy_for_tft(redundancy_info, cur_file, network, append_dict)

# Should call _tft_save_ckpt for each rank in cur_dp
assert monitor._tft_save_ckpt.called
# Should set __exception_save__ flag
assert '__exception_save__' in append_dict

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
def test_tft_save_ckpt_with_filter_list(self, mock_save_ckpt, mock_real_rank):
"""Test _tft_save_ckpt filters out items in filter_list"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

param_layout_set = set()
save_param_names = {'layer.weight', 'accu_grads.weight'}
cur_file = './test.ckpt'
append_dict = {}
network = Mock()

monitor._tft_save_ckpt(param_layout_set, save_param_names, cur_file, append_dict, network)

# Verify choice_func filters out filter_list items
call_args = mock_save_ckpt.call_args
choice_func = call_args[1].get('choice_func') if call_args[1] else None
if choice_func:
# accu_grads should be filtered out
assert not choice_func('accu_grads.weight')


class TestCheckpointMonitorSkipTrainableParams:
"""Test CheckpointMonitor._check_if_skip_trainable_params method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
def test_skip_trainable_params_graph_mode_parallel(
self, mock_parallel_ctx, mock_get_ctx, mock_real_rank):
"""Test _check_if_skip_trainable_params in graph mode with auto parallel"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Mock parameter that is not sliced
mock_param = Mock()
mock_param.sliced = False
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False

result = monitor._check_if_skip_trainable_params(mock_param)
# Should skip because sliced=False in graph mode + auto parallel
assert result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
def test_skip_trainable_params_has_init(self, mock_parallel_ctx, mock_get_ctx, mock_real_rank):
"""Test _check_if_skip_trainable_params with has_init=True"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Mock parameter with has_init=True
mock_param = Mock()
mock_param.sliced = True
mock_param.has_init = True
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False

result = monitor._check_if_skip_trainable_params(mock_param)
# Should skip because has_init=True in graph mode + auto parallel
assert result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
def test_skip_trainable_params_pipeline_shared(
self, mock_parallel_ctx, mock_get_ctx, mock_real_rank):
"""Test _check_if_skip_trainable_params with pipeline shared param"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Mock parameter that is pipeline shared
mock_param = Mock()
mock_param.sliced = True
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = True

result = monitor._check_if_skip_trainable_params(mock_param)
# Should skip because is_pipeline_shared_param=True
assert result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_context',
return_value=1) # PYNATIVE_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
def test_skip_trainable_params_pynative_mode(
self, mock_parallel_ctx, mock_get_ctx, mock_real_rank):
"""Test _check_if_skip_trainable_params in pynative mode"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Mock parameter
mock_param = Mock()
mock_param.sliced = False
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False

result = monitor._check_if_skip_trainable_params(mock_param)
# Should not skip in pynative mode (not graph mode)
assert not result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='stand_alone')
def test_skip_trainable_params_standalone(
self, mock_parallel_ctx, mock_get_ctx, mock_real_rank):
"""Test _check_if_skip_trainable_params in standalone mode"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Mock parameter
mock_param = Mock()
mock_param.sliced = False
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False

result = monitor._check_if_skip_trainable_params(mock_param)
# Should not skip in standalone mode
assert not result


class TestCheckpointMonitorRemoveRedundancyBranches:
"""Test CheckpointMonitor.remove_redundancy method branches"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_auto_parallel_context',
return_value=1) # 1 stage
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback.get_parameter_redundancy')
@patch('mindformers.core.callback.callback.remove_param_redundancy')
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
@patch('mindformers.core.callback.callback.logger')
def test_remove_redundancy_with_param_layout(self, mock_logger, mock_save_ckpt,
mock_remove_redundancy, mock_get_redundancy,
mock_context, mock_pp, mock_group_size, mock_rank):
"""Test remove_redundancy with param_layout dict"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

network = Mock()
train_network = Mock()
train_network.parameter_layout_dict = {'layer.weight': Mock(), 'layer.bias': Mock()}

mock_get_redundancy.return_value = {'layer.weight': [(0, 1, 2, 3)]}
mock_remove_redundancy.return_value = {0: {'layer.weight'}}

cur_file = './test.ckpt'
append_dict = {}

monitor.remove_redundancy(network, cur_file, append_dict, train_network)

mock_save_ckpt.assert_called_once()
mock_logger.info.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_auto_parallel_context', return_value=1)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback.get_parameter_redundancy')
@patch('mindformers.core.callback.callback.remove_param_redundancy')
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
@patch('mindformers.core.callback.callback.logger')
def test_remove_redundancy_without_param_layout(self, mock_logger, mock_save_ckpt,
mock_remove_redundancy, mock_get_redundancy,
mock_context, mock_pp, mock_group_size, mock_rank):
"""Test remove_redundancy without param_layout dict"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

network = Mock()
network.parameter_layout_dict = None

mock_get_redundancy.return_value = {'layer.weight': [(0, 1, 2, 3)]}
mock_remove_redundancy.return_value = {0: {'layer.weight'}}

cur_file = './test.ckpt'
append_dict = {}

monitor.remove_redundancy(network, cur_file, append_dict, None)

mock_save_ckpt.assert_called_once()
# Verify choice_func is used correctly
call_args = mock_save_ckpt.call_args
choice_func = call_args[1].get('choice_func') if call_args[1] else None
if choice_func:
assert choice_func('layer.weight')

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_auto_parallel_context', return_value=1)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback.get_parameter_redundancy')
@patch('mindformers.core.callback.callback.remove_param_redundancy')
@patch('mindformers.core.callback.callback.logger')
def test_remove_redundancy_exception_save(self, mock_logger, mock_remove_redundancy, mock_get_redundancy,
mock_context, mock_pp, mock_group_size, mock_rank):
"""Test remove_redundancy with __exception_save__ in append_dict"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True,
checkpoint_format='ckpt'
)

network = Mock()
train_network = Mock()
train_network.parameter_layout_dict = {'layer.weight': Mock()}

mock_get_redundancy.return_value = {'layer.weight': [(0,)]}
mock_remove_redundancy.return_value = {0: {'layer.weight'}}

monitor._do_remove_redundancy_for_tft = Mock()

cur_file = './test_rank_0/ckpt_1_10.ckpt'
append_dict = {'__exception_save__': True, 'epoch_num': 1}

monitor.remove_redundancy(network, cur_file, append_dict, train_network)

# Should call _do_remove_redundancy_for_tft
monitor._do_remove_redundancy_for_tft.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_auto_parallel_context', return_value=1)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback.get_parameter_redundancy')
@patch('mindformers.core.callback.callback.remove_param_redundancy')
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
@patch('mindformers.core.callback.callback.logger')
def test_remove_redundancy_all_params_non_redundant_warning(self, mock_logger, mock_save_ckpt,
mock_remove_redundancy, mock_get_redundancy,
mock_context, mock_pp, mock_group_size, mock_rank):
"""Test remove_redundancy logs warning when all params are non-redundant"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=True
)

network = Mock()
train_network = Mock()
# param_layout.keys() returns same as save_param_names
param_layout_dict = {'layer.weight': Mock()}
train_network.parameter_layout_dict = param_layout_dict

mock_get_redundancy.return_value = {'layer.weight': [(0,)]}
mock_remove_redundancy.return_value = {0: param_layout_dict.keys()}

cur_file = './test.ckpt'
append_dict = {}

monitor.remove_redundancy(network, cur_file, append_dict, train_network)

# Should log warning about non-redundant params
mock_logger.warning.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
return_value='stand_alone')
@patch('mindformers.core.callback.callback.ms.save_checkpoint')
def test_remove_redundancy_no_config(self, mock_save_ckpt, mock_context, mock_rank):
"""Test remove_redundancy when remove_redundancy config is False"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
remove_redundancy=False
)

network = Mock()
cur_file = './test.ckpt'
append_dict = {}

monitor.remove_redundancy(network, cur_file, append_dict, None)

# Should call ms.save_checkpoint without redundancy removal
mock_save_ckpt.assert_called_once()


class TestCheckpointMonitorSaveCheckpointNetwork:
"""Test CheckpointMonitor.save_checkpoint_network method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.os.makedirs')
@patch('mindformers.core.callback.callback.context.get_context',
return_value=1) # PYNATIVE_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='stand_alone')
def test_save_checkpoint_network_trainable_params(self, mock_parallel_ctx, mock_get_ctx,
mock_makedirs, mock_time, mock_rank):
"""Test save_checkpoint_network with save_trainable_params=True"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
save_trainable_params=True
)

monitor.save_info_list[10] = {
'ckpt': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'network': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'trainable_params': {
'save_start_time': None, 'ckpt_file_path': None,
'save_end_time': None}
}

# Create mock parameter
mock_param = Mock()
mock_param.name = 'layer.weight'
mock_param.sliced = True
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False
mock_param.data = Mock()
mock_param.data.asnumpy.return_value = np.array([1.0, 2.0])

# Mock network - need optimizer to be non-None so save_obj becomes mock_network.network
mock_network = Mock()
mock_network.network = Mock()
mock_network.network.trainable_params.return_value = [mock_param]
mock_network.network.init_parameters_data = Mock()
mock_network.network.parameter_layout_dict = {}
mock_network.optimizer = Mock() # Non-None so save_obj = save_obj.network

cb_params = Mock()
cb_params.network = mock_network
cb_params.train_network = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.batch_num = 100

monitor._trainable_manager = Mock()
monitor._trainable_manager.ckpoint_num = 0
monitor.need_remove_extra_ckpt = False
monitor.remove_redundancy = Mock()

monitor.save_checkpoint_network(cb_params)

# Should call remove_redundancy with param list
monitor.remove_redundancy.assert_called_once()
# Verify save_start_time was set
assert monitor.save_info_list[10]['trainable_params']['save_start_time'] == 1000.0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.os.makedirs')
def test_save_checkpoint_network_network_params(self, mock_makedirs, mock_time, mock_rank):
"""Test save_checkpoint_network with save_network_params=True"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
save_network_params=True
)

monitor.save_info_list[10] = {
'ckpt': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'network': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'trainable_params': {
'save_start_time': None, 'ckpt_file_path': None,
'save_end_time': None}
}

mock_network = Mock()
mock_network.network = Mock()
mock_network.optimizer = None

cb_params = Mock()
cb_params.network = mock_network
cb_params.train_network = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.batch_num = 100

monitor._network_manager = Mock()
monitor._network_manager.ckpoint_num = 0
monitor.need_remove_extra_ckpt = True
monitor.remove_redundancy = Mock()

monitor.save_checkpoint_network(cb_params)

# Should call remove_redundancy
monitor.remove_redundancy.assert_called_once()
# Should remove oldest ckpt file since need_remove_extra_ckpt is True
monitor._network_manager.remove_oldest_ckpoint_file.assert_called_once()
# need_remove_extra_ckpt should be reset to False
assert not monitor.need_remove_extra_ckpt

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.os.makedirs')
@patch('mindformers.core.callback.callback.context.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.ms.get_auto_parallel_context',
return_value='semi_auto_parallel')
@patch('mindformers.core.callback.callback._get_merged_param_data')
def test_save_checkpoint_network_trainable_params_merged(
self, mock_merged_data, mock_parallel_ctx,
mock_get_ctx, mock_makedirs,
mock_time, mock_rank):
"""Test save_checkpoint_network merges param data in auto parallel"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
save_trainable_params=True
)

monitor.save_info_list[10] = {
'ckpt': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'network': {'save_start_time': None, 'ckpt_file_path': None, 'save_end_time': None},
'trainable_params': {
'save_start_time': None, 'ckpt_file_path': None,
'save_end_time': None}
}

# Create mock parameter that should be saved
mock_param = Mock()
mock_param.name = 'layer.weight'
mock_param.sliced = True
mock_param.has_init = False
mock_param.param_info = Mock()
mock_param.param_info.is_pipeline_shared_param = False
mock_param.data = Mock()
mock_param.data.asnumpy.return_value = np.array([1.0, 2.0])

# Mock network with parameter_layout_dict
mock_network = Mock()
mock_network.network = Mock()
mock_network.network.trainable_params.return_value = [mock_param]
mock_network.network.init_parameters_data = Mock()
mock_network.network.parameter_layout_dict = {'layer.weight': Mock()}
mock_network.optimizer = Mock() # Non-None so save_obj = save_obj.network

mock_merged_data.return_value = Mock()

cb_params = Mock()
cb_params.network = mock_network
cb_params.train_network = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.batch_num = 100

monitor._trainable_manager = Mock()
monitor._trainable_manager.ckpoint_num = 0
monitor.need_remove_extra_ckpt = True
monitor.remove_redundancy = Mock()

monitor.save_checkpoint_network(cb_params)

# Should call _get_merged_param_data
mock_merged_data.assert_called_once()
# Should remove oldest file
monitor._trainable_manager.remove_oldest_ckpoint_file.assert_called_once()


class TestCheckpointMonitorMegatronFormat:
"""Test CheckpointMonitor._save_megatron_ckpt_file_format method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_all_sharded_tensor')
@patch('mindformers.core.callback.callback.save_checkpoint')
def test_save_megatron_ckpt_file_format_basic(
self, mock_save_ckpt, mock_get_sharded, mock_rank):
"""Test _save_megatron_ckpt_file_format basic functionality"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False,
save_optimizer=True,
global_batch_size=64
)

monitor._last_triggered_step = 0
monitor._append_step_num = 0

mock_get_sharded.return_value = {'layer.weight': Mock()}

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.network = Mock()
cb_params.network.network = Mock()
cb_params.network.optimizer = Mock()
cb_params.network.optimizer.global_step = 10
cb_params.network.network.parameters_dict.return_value = {'layer.weight': Mock()}
cb_params.net_outputs = (Mock(), Mock(), 1024) # loss, overflow, loss_scale

monitor._save_megatron_ckpt_file_format(cb_params)

# Should call save_checkpoint
mock_save_ckpt.assert_called_once()
# Should update _last_triggered_step
assert monitor._last_triggered_step == 10
# Verify common_info was set
assert monitor.common_info.step_num == 10
assert monitor.common_info.epoch_num == 1
assert monitor.common_info.global_batch_size == 64

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_all_sharded_tensor')
@patch('mindformers.core.callback.callback.save_checkpoint')
def test_save_megatron_ckpt_file_format_skip_same_step(
self, mock_save_ckpt, mock_get_sharded, mock_rank):
"""Test _save_megatron_ckpt_file_format skips when same step"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False
)

monitor._last_triggered_step = 10 # Same as cur_step_num

cb_params = Mock()
cb_params.cur_step_num = 10

monitor._save_megatron_ckpt_file_format(cb_params)

# Should not call save_checkpoint
mock_save_ckpt.assert_not_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_all_sharded_tensor')
@patch('mindformers.core.callback.callback.save_checkpoint')
def test_save_megatron_ckpt_file_format_no_optimizer(
self, mock_save_ckpt, mock_get_sharded, mock_rank):
"""Test _save_megatron_ckpt_file_format without optimizer"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False,
save_optimizer=False,
global_batch_size=32
)

monitor._last_triggered_step = 0
monitor._append_step_num = 5

mock_get_sharded.return_value = {}

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 2
cb_params.network = Mock()
cb_params.network.network = Mock()
cb_params.network.optimizer = Mock()
cb_params.network.optimizer.global_step = 15
cb_params.network.network.parameters_dict.return_value = {'layer.weight': Mock()}
cb_params.net_outputs = (Mock(), Mock()) # No loss_scale

monitor._save_megatron_ckpt_file_format(cb_params)

# Should call save_checkpoint with optimizer=None
mock_save_ckpt.assert_called_once()
call_kwargs = mock_save_ckpt.call_args[1]
assert call_kwargs.get('optimizer') is None
# Verify step_num includes append_step_num
assert monitor.common_info.step_num == 15 # 5 + 10
# loss_scale should be None since net_outputs doesn't have 3 elements
assert monitor.common_info.loss_scale is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_all_sharded_tensor')
@patch('mindformers.core.callback.callback.save_checkpoint')
def test_save_megatron_ckpt_file_format_with_async(
self, mock_save_ckpt, mock_get_sharded, mock_rank):
"""Test _save_megatron_ckpt_file_format with async save"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False,
async_save=True,
save_optimizer=True
)

monitor._last_triggered_step = 0
monitor._append_step_num = 0

mock_get_sharded.return_value = {}

cb_params = Mock()
cb_params.cur_step_num = 20
cb_params.cur_epoch_num = 1
cb_params.network = Mock()
cb_params.network.network = Mock()
cb_params.network.optimizer = Mock()
cb_params.network.optimizer.global_step = 20
cb_params.network.network.parameters_dict.return_value = {}
cb_params.net_outputs = (Mock(), Mock(), 2048, Mock()) # Has loss_scale

monitor._save_megatron_ckpt_file_format(cb_params)

# Should pass async_save_manager to save_checkpoint
mock_save_ckpt.assert_called_once()
call_kwargs = mock_save_ckpt.call_args[1]
assert call_kwargs.get('async_save_manager') is not None
# Verify loss_scale was extracted
assert monitor.common_info.loss_scale == 2048.0
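# Based on the tuples used in these tests, net_outputs is assumed to follow the
# (loss, overflow, loss_scale, ...) layout, so index 2 is what gets recorded as loss_scale.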

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_all_sharded_tensor')
@patch('mindformers.core.callback.callback.save_checkpoint')
def test_save_megatron_ckpt_file_format_filter_func(
self, mock_save_ckpt, mock_get_sharded, mock_rank):
"""Test _save_megatron_ckpt_file_format uses filter_func when save_optimizer=False"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt',
use_legacy_format=False,
save_optimizer=False
)

monitor._last_triggered_step = 0
monitor._append_step_num = 0

# Track what filter_func was passed to get_all_sharded_tensor and test it
filter_func_tested = [False]

def capture_and_test_filter_func(*args, **kwargs):
# Get the filter_func and test it immediately
filter_func = kwargs.get('filter_func')
assert filter_func is not None, "filter_func should be passed"
assert callable(filter_func), "filter_func should be callable"
# Filter should only allow params in parameters_dict
assert filter_func('layer.weight'), "filter_func should allow 'layer.weight'"
# 'optimizer.state' is not in parameters_dict, so filter returns False
assert not filter_func('optimizer.state'), "filter_func should reject 'optimizer.state'"
filter_func_tested[0] = True
return {}

mock_get_sharded.side_effect = capture_and_test_filter_func

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.cur_epoch_num = 1
cb_params.network = Mock()
cb_params.network.network = Mock()
cb_params.network.optimizer = Mock()
cb_params.network.optimizer.global_step = 10
# Only include network params, not optimizer state
cb_params.network.network.parameters_dict.return_value = {'layer.weight': Mock()}
cb_params.net_outputs = ()

monitor._save_megatron_ckpt_file_format(cb_params)

# Verify that the filter_func was actually tested
assert filter_func_tested[0], "filter_func should have been tested"


class TestCheckpointMonitorGetCurDpEdgeCases:
"""Test CheckpointMonitor._get_cur_dp edge cases"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_get_cur_dp_no_matching_rank(self, mock_rank):
"""Test _get_cur_dp when rank not in any group"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Rank 0 does not appear in any redundancy group
param_redundancy_dict = {
'layer.weight': [(1, 2, 3, 4)],
}

cur_dp = monitor._get_cur_dp(0, param_redundancy_dict)

# When rank is not in any group, returns empty tuple (initial min_value)
assert cur_dp == ()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_get_cur_dp_skip_accu_grads(self, mock_rank):
"""Test _get_cur_dp skips accu_grads parameters"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

param_redundancy_dict = {
'accu_grads.layer.weight': [(0, 1, 2, 3)],
'inputs.data': [(0, 1)],
'layer.weight': [(0, 1)],
}

cur_dp = monitor._get_cur_dp(0, param_redundancy_dict)

# Should skip accu_grads and inputs, use layer.weight group
assert 0 in cur_dp
assert 1 in cur_dp

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_get_cur_dp_conflicting_groups(self, mock_rank):
"""Test _get_cur_dp with conflicting groups"""

monitor = CheckpointMonitor(
prefix='TEST',
directory='./test_ckpt'
)

# Conflicting groups: rank 0 is in (0, 1) but also in (0, 2, 3)
# where (0, 2, 3) is not a subset of (0, 1)
param_redundancy_dict = {
'layer1.weight': [(0, 1)],
'layer2.weight': [(0, 2, 3)],
}

cur_dp = monitor._get_cur_dp(0, param_redundancy_dict)

# Should return single rank when conflicts exist
assert cur_dp == (0,)

if __name__ == '__main__':
    unittest.main(exit=False)
    pytest.main([__file__, '-v'])

+ 303
- 0
tests/st/test_ut/test_core/test_callback/test_cold_hot_expert_monitor.py View File

@@ -0,0 +1,303 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
from unittest.mock import Mock, patch

import numpy as np
import pytest

from mindformers.core.callback.callback import ColdHotExpertMonitor

# pylint: disable=unused-argument # for mock logic


class TestColdHotExpertMonitorExtended:
"""Extended tests for ColdHotExpertMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('os.getenv')
def test_get_attribute_by_path(self, mock_getenv):
"""Test get_attribute_by_path method"""

def getenv_side_effect(x, default=None):
if x == "RANK_ID":
return "0"
if x == "RANK_SIZE":
return "8"
return default

mock_getenv.side_effect = getenv_side_effect

moe_config = Mock()
moe_config.update_step = 10
moe_config.expert_num = 8
moe_config.hot_expert_num = 1
moe_config.moe_module_name = "model.layers"

monitor = ColdHotExpertMonitor(
moe_config=moe_config,
hidden_size=128,
ffn_hidden_size=512,
expert_parallel=1,
model_parallel=1,
save_checkpoint_steps=100
)

# Create mock object with nested attributes
obj = Mock()
obj.model.layers = [Mock(), Mock()]

result = monitor.get_attribute_by_path(obj, "model.layers")
assert len(result) == 2


class TestColdHotExpertMonitorBasic:
"""Test ColdHotExpertMonitor basic functionality"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('os.getenv')
def test_init_basic(self, mock_getenv):
"""Test ColdHotExpertMonitor initialization"""

def getenv_side_effect(x, default=None):
if x == "RANK_ID":
return "0"
if x == "RANK_SIZE":
return "8"
return default

mock_getenv.side_effect = getenv_side_effect

moe_config = Mock()
moe_config.update_step = 10
moe_config.expert_num = 8
moe_config.hot_expert_num = 2
moe_config.moe_module_name = "model.layers"

monitor = ColdHotExpertMonitor(
moe_config=moe_config,
hidden_size=128,
ffn_hidden_size=512,
expert_parallel=2,
model_parallel=2,
save_checkpoint_steps=100
)

assert monitor.update_step == 10
assert monitor.expert_num == 8
assert monitor.hot_expert_num == 2
assert monitor.local_expert_num == 4 # 8 / 2


class TestColdHotExpertMonitorStepEnd:
"""Test ColdHotExpertMonitor.on_train_step_end and expert switching"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.time.time')
def test_on_train_step_end_switch_experts(self, mock_time, mock_getenv, mock_get_rank):
"""Test on_train_step_end triggers expert switching"""

def getenv_side_effect(key, default=None):
if key == "RANK_ID":
return "0"
if key == "RANK_SIZE":
return "8"
return default

mock_getenv.side_effect = getenv_side_effect

# Mock time.time() to return incrementing values
time_counter = [100.0]

def time_side_effect():
result = time_counter[0]
time_counter[0] += 1.0
return result

mock_time.side_effect = time_side_effect

# Use Mock object instead of dict to support attribute access
moe_config = Mock()
moe_config.expert_num = 8
moe_config.hot_expert_num = 1
moe_config.moe_module_name = 'network.blocks'
moe_config.update_step = 10

monitor = ColdHotExpertMonitor(
moe_config=moe_config,
hidden_size=4096,
ffn_hidden_size=16384,
expert_parallel=1,
model_parallel=1,
save_checkpoint_steps=10
)

# Mock the train_network and blocks
run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.train_network = Mock()

# Mock blocks structure
mock_block = Mock()
mock_block.output.hot_expert_index.value.return_value = [np.array([0])]

monitor.get_attribute_by_path = Mock(return_value=[mock_block])
monitor.return_back_hot_expert = Mock()
monitor.switch_hot_expert = Mock()

run_context.original_args.return_value = cb_params

monitor.on_train_step_end(run_context)

monitor.switch_hot_expert.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
def test_return_back_hot_expert_single(self, mock_getenv, mock_get_rank):
"""Test return_back_hot_expert with single hot expert"""

def getenv_side_effect(key, default=None):
if key == "RANK_ID":
return "0"
if key == "RANK_SIZE":
return "8"
return default

mock_getenv.side_effect = getenv_side_effect

# Use Mock object instead of dict to support attribute access
moe_config = Mock()
moe_config.expert_num = 8
moe_config.hot_expert_num = 1
moe_config.moe_module_name = 'network.blocks'
moe_config.update_step = 10

monitor = ColdHotExpertMonitor(
moe_config=moe_config,
hidden_size=4096,
ffn_hidden_size=16384,
expert_parallel=1,
model_parallel=1,
save_checkpoint_steps=10
)

# Mock block with hot expert - need to support subscript access
mock_block = Mock()

# old_hot_expert_index[0] needs to be subscriptable
# value()[0] should return an array-like object that supports indexing
mock_hot_expert_index = np.array([0]) # Use numpy array for proper indexing support
mock_block.output.hot_expert_index.value.return_value = [mock_hot_expert_index]

# Create mock arrays that support subscript assignment
# For weight arrays - simple list
mock_weight_array = [Mock() for _ in range(8)]

# For bias arrays - need nested structure that supports bias[0][ffn_index][0] = value
# Create a list of lists where each inner list contains Mock objects
mock_bias_inner = [[Mock()] for _ in range(8)]
mock_bias_array = [mock_bias_inner]

mock_block.output.ffn.mapping.weight = mock_weight_array
mock_block.output.ffn.mapping.bias = mock_bias_array
mock_block.output.ffn.projection.weight = [Mock() for _ in range(8)]
mock_block.output.ffn.projection.bias = [[[Mock()] for _ in range(8)]]

mock_block.output.mlp.mapping.weight = Mock()
mock_block.output.mlp.mapping.bias = Mock()
mock_block.output.mlp.projection.weight = Mock()
mock_block.output.mlp.projection.bias = Mock()

# Should not raise error
monitor.return_back_hot_expert(mock_block)


class TestColdHotExpertMonitorSwitchExpert:
"""Test ColdHotExpertMonitor.switch_hot_expert method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
def test_switch_hot_expert_single_no_change(self, mock_getenv, mock_get_rank):
"""Test switch_hot_expert when expert doesn't change"""

def getenv_side_effect(key, default=None):
if key == "RANK_ID":
return "0"
if key == "RANK_SIZE":
return "8"
return default

mock_getenv.side_effect = getenv_side_effect

moe_config = Mock()
moe_config.expert_num = 8
moe_config.hot_expert_num = 1
moe_config.moe_module_name = 'network.blocks'
moe_config.update_step = 10

monitor = ColdHotExpertMonitor(
moe_config=moe_config,
hidden_size=4096,
ffn_hidden_size=16384,
expert_parallel=1,
model_parallel=1,
save_checkpoint_steps=10
)

# Mock block where old and new expert are the same
mock_block = Mock()

# Old expert index - should be array-like supporting indexing
# value()[0] should return an array, and then [0] accesses first element
old_expert_array = np.array([0]) # Array that supports [0] indexing
mock_block.output.hot_expert_index.value.return_value = [old_expert_array]

# New expert index (same as old)
# cumsum_value.value() returns a tensor
mock_cumsum = Mock()

# Create a mock tensor that supports slicing and indexing
# new_expert_index[0:1] should return np.array([0])
# new_expert_index[1:8] should return an array
def mock_getitem(self, key):
# Need to accept self parameter since this is bound as a method
if isinstance(key, slice):
# Handle slicing
if key.start == 0 and key.stop == 1: # hot_expert_num = 1
return np.array([0])
return np.array(list(range(key.start or 0, key.stop or 8)))
# Handle single index
return 0

mock_expert_indices = Mock()
mock_expert_indices.__getitem__ = mock_getitem
mock_cumsum.topk.return_value = (Mock(), mock_expert_indices)
mock_block.output.router.router.cumsum_value.value.return_value = mock_cumsum

# Should return early without switching (since old and new expert are the same)
monitor.switch_hot_expert(mock_block, cur_step_num=2)

if __name__ == '__main__':
pytest.main([__file__, '-v'])

+ 96
- 0
tests/st/test_ut/test_core/test_callback/test_expert_migrate_callback.py View File

@@ -0,0 +1,96 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
from unittest.mock import Mock, patch

import numpy as np
import pytest

from mindformers.core.callback.callback import ExpertMigrateCallback

# pylint: disable=unused-argument # for mock logic


class TestExpertMigrateCallback:
"""Test ExpertMigrateCallback"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
def test_init(self, mock_rank):
"""Test ExpertMigrateCallback initialization and on_train_step_end."""
config = Mock()
config.pipeline_model_parallel_size = 1
config.data_parallel_size = 1
config.tensor_model_parallel_size = 1
config.expert_model_parallel_size = 1
config.num_layers = 2
config.mtp_num_layers = 0
config.num_moe_experts = 8

callback = ExpertMigrateCallback(config=config, print_expert_load=True)

run_context = Mock()
cb_params = Mock()

real_network = Mock()
# The callback unwraps the model with `while hasattr(network, 'network')`;
# deleting the auto-created attribute makes that loop terminate.
del real_network.network

layer = Mock()
layer.pipeline_stage = 0
layer.mlp.experts.num_tokens_per_expert = Mock()
layer.mlp.num_local_experts = 8
layer.mlp.expert_load_history.asnumpy.return_value = np.zeros(8)

real_network.model.decoder.layers = [layer, layer]

cb_params.train_network = real_network
cb_params.optimizer = Mock()
run_context.original_args.return_value = cb_params

ctx_patch = 'mindformers.core.callback.callback.get_auto_parallel_context'
with patch(ctx_patch, return_value="stand_alone"):
callback.on_train_step_end(run_context)

layer.mlp.update_expert_load_history.assert_called()


class TestExpertMigrateCallbackExtended:
"""Extended tests for ExpertMigrateCallback"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
def test_expert_migrate_with_mtp_layers(self, mock_rank):
"""Test ExpertMigrateCallback with MTP layers"""

config = Mock()
config.pipeline_model_parallel_size = 1
config.data_parallel_size = 1
config.tensor_model_parallel_size = 1
config.expert_model_parallel_size = 1
config.num_layers = 2
config.mtp_num_layers = 1
config.num_moe_experts = 4

callback = ExpertMigrateCallback(config=config, print_expert_load=False)

assert callback.mtp_num_layers == 1
assert callback.num_layers == 2

if __name__ == '__main__':
pytest.main([__file__, '-v'])

+ 422
- 0
tests/st/test_ut/test_core/test_callback/test_helper_functions.py View File

@@ -0,0 +1,422 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
import inspect
from unittest.mock import Mock, patch

import numpy as np
import pytest

from mindformers.core.callback.callback import (
AllReduceNet,
_check_mspti_is_on,
_get_loss_output,
_get_max_eigenvalue,
_get_optimizer_state,
_get_separate_loss,
_get_stable_rank,
_get_weight_norm,
_log_grouped_lr_info,
get_embedding_info
)

# pylint: disable=unused-argument # for mock logic


class TestHelperFunctions:
"""Test helper functions in callback.py"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_loss_output(self):
"""Test _get_loss_output function."""
# Test case 1: Simple scalar output (not Tensor)
output = 0.5
loss, overflow, scaling_sens, _, _ = _get_loss_output(output)
assert loss == 0.5
assert not overflow
assert not scaling_sens

# Test case 2: Tuple with 3 elements
output = (0.5, False, 1024.0)
loss, overflow, scaling_sens, _, _ = _get_loss_output(output)
assert loss == 0.5
assert not overflow
assert scaling_sens == 1024.0
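# Convention assumed throughout these tests: _get_loss_output always returns a
# 5-tuple (loss, overflow, scaling_sens, learning_rate, global_norm), with
# missing slots filled by falsy values or None.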

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.F')
def test_get_weight_norm(self, mock_f):
"""Test _get_weight_norm function."""
network = Mock()
param = Mock()
param.to.return_value.norm.return_value = 1.0
network.trainable_params.return_value = [param, param]

# Mock F.stack
mock_f.stack.return_value.norm.return_value.item.return_value = 1.414

norm = _get_weight_norm(network)
assert norm == pytest.approx(1.414)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_optimizer_state(self):
"""Test _get_optimizer_state function"""
param1 = Mock()
param1.name = "p1"
param1.to.return_value.norm.return_value.item.return_value = 0.1

param2 = Mock()
param2.name = "p2"
param2.to.return_value.norm.return_value.item.return_value = 0.2

optim_params = [param1, param2]

norms = _get_optimizer_state(optim_params)
assert norms['p1'] == 0.1
assert norms['p2'] == 0.2


class TestAllReduceNet:
"""Test AllReduceNet class"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init_and_construct(self):
"""Test AllReduceNet initialization"""

# Mock P.AllReduce which is used in AllReduceNet.__init__
mock_allreduce_class = Mock()
mock_allreduce_instance = Mock()
mock_allreduce_class.return_value = mock_allreduce_instance

with patch('mindformers.core.callback.callback.P.AllReduce', mock_allreduce_class):
net = AllReduceNet('test_group')
mock_allreduce_class.assert_called_once()

# Test construct method
mock_tensor = Mock()
mock_allreduce_instance.return_value = mock_tensor
result = net.construct(mock_tensor)
assert result == mock_tensor


class TestCheckMsptiIsOn:
"""Test _check_mspti_is_on function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('os.getenv')
def test_mspti_enabled(self, mock_getenv):
"""Test when libmspti.so is in LD_PRELOAD"""

mock_getenv.return_value = "/path/to/libmspti.so"
result = _check_mspti_is_on()
assert result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('os.getenv')
def test_mspti_disabled(self, mock_getenv):
"""Test when libmspti.so is not in LD_PRELOAD"""

mock_getenv.return_value = "/path/to/other.so"
result = _check_mspti_is_on()
assert not result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('os.getenv')
def test_mspti_no_ld_preload(self, mock_getenv):
"""Test when LD_PRELOAD is not set"""

mock_getenv.return_value = None
result = _check_mspti_is_on()
assert not result
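# These three cases assume _check_mspti_is_on simply looks for the substring
# 'libmspti.so' in the LD_PRELOAD environment variable.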


class TestGetSeparateLoss:
"""Test _get_separate_loss function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.parameter_register')
def test_get_separate_loss(self, mock_param_register):
"""Test _get_separate_loss retrieves and clears losses"""

# Mock parameter values
mock_aux_loss = Mock()
mock_aux_loss.asnumpy.return_value = np.array([0.1])
mock_mtp_loss = Mock()
mock_mtp_loss.asnumpy.return_value = np.array([0.2])
mock_lm_loss = Mock()
mock_lm_loss.asnumpy.return_value = np.array([0.3])

mock_param_register.get.side_effect = lambda x, default=None: {
'aux_loss': mock_aux_loss,
'mtp_loss': mock_mtp_loss,
'lm_loss': mock_lm_loss
}.get(x, default)

lm_loss, aux_loss, mtp_loss = _get_separate_loss()

assert lm_loss[0] == 0.3
assert aux_loss[0] == 0.1
assert mtp_loss[0] == 0.2

# Verify clear was called
assert mock_param_register.clear.call_count == 3


class TestLogGroupedLrInfo:
"""Test _log_grouped_lr_info function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_log_grouped_lr_info_basic(self):
"""Test _log_grouped_lr_info basic functionality"""
# With GROUPED_PARAMS empty (the default state in these tests), the function
# should return early; any exception raised here would fail the test.
_log_grouped_lr_info()


class TestGetLossOutputExtended:
"""Extended tests for _get_loss_output function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_loss_output_tuple_4(self):
"""Test _get_loss_output with 4-element tuple"""

output = (0.5, False, 1024.0, 0.001)
loss, overflow, scaling_sens, learning_rate, global_norm = _get_loss_output(output)
assert loss == 0.5
assert not overflow
assert scaling_sens == 1024.0
assert learning_rate == 0.001
assert global_norm is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_loss_output_tuple_7(self):
"""Test _get_loss_output with 7-element tuple"""

output = (0.5, False, 1024.0, 0.001, 2.5, np.array([1.0, 2.0]), 2)
loss, overflow, scaling_sens, learning_rate, global_norm = _get_loss_output(output)
assert loss == 0.5
assert not overflow
assert scaling_sens == 1024.0
assert learning_rate == 0.001
assert global_norm == 2.5


class TestGetMaxEigenvalue:
"""Test _get_max_eigenvalue function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_get_max_eigenvalue_basic(self):
"""Test _get_max_eigenvalue function - simplified test"""
# This function is complex and involves many MindSpore operations
# We'll just verify it exists and has the correct signature

# Verify the function exists
assert callable(_get_max_eigenvalue)

# Verify the function signature
sig = inspect.signature(_get_max_eigenvalue)
params = list(sig.parameters.keys())
assert 'input_tensor' in params
assert 'num_iter' in params

# Note: Full functional testing of this method would require actual MindSpore tensors
# which is beyond the scope of unit testing with mocks


class TestGetStableRankExtended:
"""Extended tests for _get_stable_rank function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.ms.ops.square')
@patch('mindformers.core.callback.callback.ms.ops.norm')
@patch('mindformers.core.callback.callback._get_max_eigenvalue')
def test_get_stable_rank_zero_eigenvalue(self, mock_eigenvalue, mock_norm, mock_square):
"""Test _get_stable_rank when eigenvalue is zero"""

# Create a more complete mock weight object
weight = Mock()
weight.name = "test_weight"
weight.ndim = 2 # add ndim attribute so the -ndim handling does not fail on a bare Mock
weight.shape = [3, 3] # add shape attribute

mock_eigenvalue.return_value = np.array(0.0)

stable_rank, eig = _get_stable_rank(weight, num_iter=5)
assert stable_rank == 0.0
assert eig == 0.0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.ms.ops.square')
@patch('mindformers.core.callback.callback.ms.ops.norm')
@patch('mindformers.core.callback.callback._get_max_eigenvalue')
def test_get_stable_rank_normal(self, mock_eigenvalue, mock_norm, mock_square):
"""Test _get_stable_rank with normal values"""

# Create a more complete mock weight object
weight = Mock()
weight.name = "test_weight"
weight.ndim = 2 # add ndim attribute so the -ndim handling does not fail on a bare Mock
weight.shape = [3, 3] # add shape attribute

mock_eigenvalue.return_value = np.array(2.0)

# Mock norm to return a Mock that can be squared
mock_norm_tensor = Mock()
mock_norm_tensor.ndim = 0 # scalar
mock_norm.return_value = mock_norm_tensor

# Mock square to return a Mock that has asnumpy method returning 16.0
mock_square_result = Mock()
mock_square_result.asnumpy.return_value = 16.0
mock_square.return_value = mock_square_result

stable_rank, eig = _get_stable_rank(weight, num_iter=5)
# stable_rank = f_norm^2 / eig = 16.0 / 2.0 = 8.0
assert stable_rank == 8.0
assert eig == 2.0


class TestGetEmbeddingInfo:
"""Test get_embedding_info function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindspore.context.get_auto_parallel_context', return_value=2)
def test_get_embedding_info(self, *mocks):
"""Test get_embedding_info extracts embedding local norm"""

cb_params = Mock()
cb_params.net_outputs = [0.5, False, 1024.0, 0.001, 2.5,
[1.0, 2.0, 3.0], [128, 256, 128]]

embedding_size = 128
result = get_embedding_info(cb_params, embedding_size)

# Should return the first local_norm with matching size
assert result == 1.0


class TestGetMaxEigenvalueComprehensive:
"""Comprehensive tests for _get_max_eigenvalue function"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.ms.ops.matmul')
@patch('mindformers.core.callback.callback.ms.ops.unsqueeze')
@patch('mindformers.core.callback.callback.ms.ops.randn')
@patch('mindformers.core.callback.callback.logger')
def test_get_max_eigenvalue_2d_tensor(self, mock_logger, mock_randn,
mock_unsqueeze, mock_matmul):
"""Test _get_max_eigenvalue with 2D tensor"""

# Create mock input tensor
input_tensor = Mock()
input_tensor.ndim = 2
input_tensor.shape = [3, 3]
input_tensor.astype.return_value = input_tensor
input_tensor.transpose.return_value = input_tensor

# Mock randn to return a tensor with positive norm
mock_u_tensor = Mock()
mock_u_norm = Mock()
mock_u_norm.asnumpy.return_value = 1.0
mock_u_tensor.norm.return_value = mock_u_norm
mock_u_tensor.__truediv__ = Mock(return_value=mock_u_tensor)
mock_randn.return_value = mock_u_tensor

# Mock unsqueeze
mock_unsqueeze.return_value = mock_u_tensor

# Mock matmul operations
mock_input_seq = Mock()
mock_unsqueeze.return_value = mock_input_seq

mock_v_tensor = Mock()
mock_v_norm = Mock()
mock_v_norm.asnumpy.return_value = 1.0

# Mock (v_norm != 0).all() - need to return a tensor-like object with .all() method
mock_comparison_result = Mock()
mock_comparison_result.all.return_value = True
mock_v_norm.__ne__ = Mock(return_value=mock_comparison_result)

mock_v_tensor.norm.return_value = mock_v_norm
mock_v_tensor.transpose.return_value = mock_v_tensor
mock_v_tensor.__truediv__ = Mock(return_value=mock_v_tensor)

mock_eigenvalue = Mock()
mock_eigenvalue.asnumpy.return_value = 2.5
mock_eigenvalue.squeeze.return_value = mock_eigenvalue

# matmul is called:
# 1. Once for input_seq (line 211)
# 2. num_iter times for v_tensor (line 216)
# 3. num_iter times for eigenvalue (line 217)
# Total: 1 + 2 + 2 = 5 times for num_iter=2
mock_matmul.side_effect = [
mock_input_seq, # Line 211: input_seq calculation
mock_v_tensor, # Line 216: iteration 1, v_tensor
mock_eigenvalue, # Line 217: iteration 1, eigenvalue
mock_v_tensor, # Line 216: iteration 2, v_tensor
mock_eigenvalue # Line 217: iteration 2, eigenvalue
]

result = _get_max_eigenvalue(input_tensor, num_iter=2)
assert result == 2.5

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.ms.ops.randn')
@patch('mindformers.core.callback.callback.logger')
def test_get_max_eigenvalue_zero_norm(self, mock_logger, mock_randn):
"""Test _get_max_eigenvalue when random vector has zero norm"""

input_tensor = Mock()
input_tensor.ndim = 2
input_tensor.shape = [3, 3]
input_tensor.astype.return_value = input_tensor

# Mock randn to always return zero norm
mock_u_tensor = Mock()
mock_u_norm = Mock()
mock_u_norm.asnumpy.return_value = 0.0
mock_u_tensor.norm.return_value = mock_u_norm
mock_randn.return_value = mock_u_tensor

result = _get_max_eigenvalue(input_tensor, num_iter=2)
assert result == 0.0
mock_logger.warning.assert_called()

if __name__ == '__main__':
pytest.main([__file__, '-v'])
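Note: the stable-rank and eigenvalue tests above mock out all of the math, but the quantity they assert (f_norm^2 / eig = 16.0 / 2.0 = 8.0) follows a standard recipe: estimate the largest eigenvalue of W^T W by power iteration, then divide the squared Frobenius norm by it. Below is a minimal NumPy sketch of that idea, assuming _get_max_eigenvalue approximates the top eigenvalue of W^T W and returns 0.0 for degenerate inputs; both assumptions are read off the mocked tests, not the real callback.py implementation.

import numpy as np

def max_eigenvalue(w, num_iter=10):
    """Power iteration on w.T @ w; estimates its largest eigenvalue."""
    v = np.random.randn(w.shape[1])
    norm = np.linalg.norm(v)
    if norm == 0:
        return 0.0  # mirrors the zero-norm warning branch in test_get_max_eigenvalue_zero_norm
    v = v / norm
    for _ in range(num_iter):
        v = w.T @ (w @ v)
        norm = np.linalg.norm(v)
        if norm == 0:
            return 0.0
        v = v / norm
    return float(v @ (w.T @ (w @ v)))

def stable_rank(w, num_iter=10):
    """Return (stable_rank, max_eigenvalue), the pair the tests above unpack."""
    eig = max_eigenvalue(w, num_iter)
    if eig == 0.0:
        return 0.0, 0.0  # matches test_get_stable_rank_zero_eigenvalue
    return float(np.linalg.norm(w) ** 2 / eig), eig  # ||W||_F**2 / lambda_max, e.g. 16.0 / 2.0 = 8.0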

+ 829  - 5    tests/st/test_ut/test_core/test_callback/test_mfloss_monitor.py

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Test test_mfloss_monitor.py"""
import builtins
from unittest.mock import Mock, patch
import unittest
import numpy as np
@@ -21,6 +22,9 @@ import pytest
from mindspore import Tensor
from mindformers.core.callback.callback import MFLossMonitor

# pylint: disable=protected-access
# pylint: disable=unused-argument # for mock logic


class TestMFLossMonitor(unittest.TestCase):
"""Test MFLossMonitor class"""
@@ -99,7 +103,6 @@ class TestMFLossMonitor(unittest.TestCase):
mock_get_context.return_value = 1 # pipeline_stages = 1

original_loss = 1.0
# pylint: disable=W0212
fixed_loss = self.monitor._fix_loss_for_parallel(original_loss)

self.assertEqual(fixed_loss, original_loss)
@@ -116,7 +119,6 @@ class TestMFLossMonitor(unittest.TestCase):
mock_get_context.return_value = 2 # pipeline_stages = 2

original_loss = 2.0
# pylint: disable=W0212
fixed_loss = self.monitor._fix_loss_for_parallel(original_loss)

# Should divide by micro_size
@@ -135,7 +137,6 @@ class TestMFLossMonitor(unittest.TestCase):
with patch('mindspore.get_context') as mock_get_context:
mock_get_context.return_value = ms.GRAPH_MODE

# pylint: disable=W0212
result = self.monitor._can_calculate_model_flops(mock_cb_params)

self.assertTrue(result)
@@ -149,7 +150,6 @@ class TestMFLossMonitor(unittest.TestCase):
mock_cb_params = Mock()
mock_cb_params.mode = 'invalid'

# pylint: disable=W0212
result = self.monitor._can_calculate_model_flops(mock_cb_params)

self.assertFalse(result)
@@ -203,5 +203,829 @@ class TestMFLossMonitorIntegration(unittest.TestCase):
self.assertEqual(monitor.loss_list[0], 0.5)


class TestMFLossMonitorBasic:
"""Test MFLossMonitor basic functionality"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_init_defaults(self, *mocks):
"""Test MFLossMonitor default initialization"""

monitor = MFLossMonitor(per_print_times=1, global_batch_size=32, dataset_size=100)
assert monitor.per_print_times == 1
assert monitor.global_batch_size == 32
assert monitor.steps_per_epoch == 100

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_init_custom_values(self, *mocks):
"""Test MFLossMonitor custom initialization"""

monitor = MFLossMonitor(
learning_rate=0.01,
per_print_times=10,
micro_batch_num=2,
micro_batch_interleave_num=2,
gradient_accumulation_steps=4
)
assert monitor.per_print_times == 10
assert monitor.mirco_size == 2
assert monitor.micro_batch_interleave_num == 2
assert monitor.gradient_accumulation_steps == 4

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('time.time')
def test_on_train_epoch_begin(self, mock_time, *mocks):
"""Test on_train_epoch_begin callback"""

mock_time.return_value = 1000.0
monitor = MFLossMonitor()
mock_run_context = Mock()

monitor.on_train_epoch_begin(mock_run_context)

assert monitor.loss_list == []
assert monitor.epoch_time == 1000.0
assert monitor.run_context == mock_run_context

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('time.time')
def test_on_train_step_begin(self, mock_time, *mocks):
"""Test on_train_step_begin callback"""

mock_time.return_value = 1000.0
monitor = MFLossMonitor()
mock_run_context = Mock()

monitor.on_train_step_begin(mock_run_context)

assert monitor.step_time == 1000.0
assert monitor.run_context == mock_run_context


class TestMFLossMonitorExtended:
"""Extended tests for MFLossMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_fix_loss_for_parallel_pipeline(self, *mocks):
"""Test _fix_loss_for_parallel with pipeline stages"""

monitor = MFLossMonitor(
micro_batch_num=2,
gradient_accumulation_steps=2,
calculate_per_token_loss=False
)

# Mock both context.get_auto_parallel_context and get_auto_parallel_context
with patch('mindspore.context.get_auto_parallel_context', return_value=2), \
patch('mindspore.get_auto_parallel_context', return_value='not_zero_bubble_v'):
loss = 8.0
fixed_loss = monitor._fix_loss_for_parallel(loss, print_warning=False)

# When pipeline_stages=2: loss = 8.0 / mirco_size(2) = 4.0
# When gradient_accumulation_steps=2: loss = 4.0 / 2 = 2.0
assert fixed_loss == 2.0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_fix_loss_for_parallel_gradient_accumulation(self, *mocks):
"""Test _fix_loss_for_parallel with gradient accumulation only"""

monitor = MFLossMonitor(
micro_batch_num=1, # No pipeline division
gradient_accumulation_steps=2,
calculate_per_token_loss=False
)

# Mock pipeline_stages=1 (no pipeline)
with patch('mindspore.context.get_auto_parallel_context', return_value=1), \
patch('mindspore.get_auto_parallel_context', return_value='not_zero_bubble_v'):
loss = 8.0
fixed_loss = monitor._fix_loss_for_parallel(loss, print_warning=False)

# When pipeline_stages=1: no division by mirco_size
# When gradient_accumulation_steps=2: loss = 8.0 / 2 = 4.0
assert fixed_loss == 4.0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_fix_loss_for_parallel_no_pipeline(self, *mocks):
"""Test _fix_loss_for_parallel without pipeline stages"""

monitor = MFLossMonitor(
micro_batch_num=2,
gradient_accumulation_steps=2,
calculate_per_token_loss=False
)

with patch('mindspore.context.get_auto_parallel_context', return_value=1), \
patch('mindspore.get_auto_parallel_context', return_value='data_parallel'):
loss = 8.0
fixed_loss = monitor._fix_loss_for_parallel(loss)
# When pipeline_stages=1: no division by mirco_size
# When gradient_accumulation_steps=2: loss = 8.0 / 2 = 4.0
assert fixed_loss == 4.0

@patch('time.time', return_value=1000.0)
@patch('mindformers.core.callback.callback.get_tensorboard_args',
return_value={'log_loss_scale_to_tensorboard': True})
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
def test_print_output_info_with_tensorboard(self, *mocks):
"""Test print_output_info with tensorboard enabled"""

monitor = MFLossMonitor(learning_rate=0.001, global_batch_size=32)
monitor.tensor_writer = Mock()

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10
cb_params.train_network = Mock()
cb_params.train_network.phase = 'train'
cb_params.train_network.set_train = Mock()

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Verify tensorboard writer was called
assert monitor.tensor_writer.add_scalar.call_count > 0


class TestMFLossMonitorOnTrainStepEnd:
"""Test MFLossMonitor.on_train_step_end comprehensive coverage"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.get_auto_parallel_context')
@patch('mindformers.core.callback.callback.set_auto_parallel_context')
@patch('time.time')
def test_on_train_step_end_with_separate_loss(
self, mock_time, mock_set_context, mock_get_context, *mocks):
"""Test on_train_step_end with print_separate_loss enabled"""

mock_time.return_value = 1000.0

def get_context_side_effect(x, *args):
return {
'parallel_mode': 'stand_alone',
'full_batch': False
}.get(x, None)

mock_get_context.side_effect = get_context_side_effect

monitor = MFLossMonitor(
origin_epochs=10,
dataset_size=100,
global_batch_size=32,
print_separate_loss=True,
is_moe_model=True
)
monitor.step_time = 999.0

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 1
cb_params.batch_num = 100
cb_params.cur_epoch_num = 1
cb_params.dataset_sink_mode = False
cb_params.net_outputs = (0.5, False, 1024.0, 0.001, 2.5)
cb_params.get.return_value = None
run_context.original_args.return_value = cb_params

separate_loss_mock = (np.array([0.3]), np.array([0.1]), np.array([0.1]))
loss_patch = 'mindformers.core.callback.callback._get_separate_loss'
with patch(loss_patch, return_value=separate_loss_mock):
monitor.on_train_step_end(run_context)

assert len(monitor.loss_list) == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.get_auto_parallel_context')
@patch('mindformers.core.callback.callback.set_auto_parallel_context')
@patch('time.time')
@patch('mindformers.core.callback.callback.check_arf_status', return_value=True)
def test_on_train_step_end_with_arf_status(
self, mock_arf, mock_time, mock_set_context, mock_get_context, *mocks):
"""Test on_train_step_end with ARF status check"""

mock_time.return_value = 1000.0

def get_context_side_effect(x, *args):
return {
'parallel_mode': 'stand_alone',
'full_batch': False
}.get(x, None)

mock_get_context.side_effect = get_context_side_effect

monitor = MFLossMonitor(
origin_epochs=10,
dataset_size=100,
global_batch_size=32
)
monitor.step_time = 999.0
monitor.mf_support = True
monitor.mf_calculated = False

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 1
cb_params.batch_num = 100
cb_params.cur_epoch_num = 1
cb_params.dataset_sink_mode = False
cb_params.net_outputs = 0.5
cb_params.get.return_value = None
cb_params.mode = 'train'
cb_params.train_network = Mock()
cb_params.train_network.current_phase = 'train_phase'
run_context.original_args.return_value = cb_params

with patch.object(monitor, '_calculate_model_flops'):
monitor.on_train_step_end(run_context)


class TestMFLossMonitorCalculateFlops:
"""Test MFLossMonitor._calculate_model_flops"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.flops_collection')
@patch('mindformers.core.callback.callback.auto_parallel_context')
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
def test_calculate_model_flops_standalone(
self, mock_group_size, mock_auto_context, mock_flops, *mocks):
"""Test _calculate_model_flops in standalone mode"""

monitor = MFLossMonitor()
monitor.current_phase = 'train_phase'

mock_flops.return_value = (1000000.0, 0, 500000.0, 0, False)
mock_auto_context.return_value.get_pipeline_stages.return_value = 1
mock_auto_context.return_value.get_parallel_mode.return_value = 'stand_alone'

monitor._calculate_model_flops()

assert monitor.mf_calculated
assert monitor.full_model_flops == 1000000.0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.flops_collection')
def test_calculate_model_flops_runtime_error(self, mock_flops, *mocks):
"""Test _calculate_model_flops with RuntimeError"""

monitor = MFLossMonitor()
monitor.current_phase = 'train_phase'
monitor.mf_support = True

mock_flops.side_effect = RuntimeError("Flops calculation failed")

monitor._calculate_model_flops()

assert not monitor.mf_support


class TestMFLossMonitorPrintOutputInfo:
"""Test MFLossMonitor.print_output_info comprehensive coverage"""

@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={
'log_loss_scale_to_tensorboard': True,
'log_timers_to_tensorboard': True
})
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
def test_print_output_info_with_all_tensorboard_options(self, mock_group_size, *mocks):
"""Test print_output_info with all tensorboard options enabled"""

monitor = MFLossMonitor(learning_rate=0.001, global_batch_size=32)
monitor.tensor_writer = Mock()
monitor.mf_calculated = True
monitor.full_model_flops = 1000000.0

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10
cb_params.train_network = Mock()
cb_params.train_network.phase = 'train'
cb_params.train_network.set_train = Mock()

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Verify tensorboard writer was called for various metrics
assert monitor.tensor_writer.add_scalar.call_count > 5

@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={
'log_timers_to_tensorboard': True
})
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
@patch('mindformers.core.callback.callback.is_legacy_model', return_value=False)
def test_print_output_info_with_separate_loss(self, mock_legacy, mock_group_size, *mocks):
"""Test print_output_info with separate loss (MoE and MTP)"""

monitor = MFLossMonitor(
learning_rate=0.001,
global_batch_size=32,
print_separate_loss=True,
is_moe_model=True,
is_mtp_model=True
)
# Explicitly set print_separate_loss to True (in case it was reset during init)
monitor.print_separate_loss = True
# Ensure tensor_writer is a Mock and tensorboard config is set
monitor.tensor_writer = Mock()
monitor.tensorboard = {'log_timers_to_tensorboard': True}

cb_params = Mock()
cb_params.dataset_sink_mode = True
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10
cb_params.train_network = Mock()
cb_params.train_network.phase = 'train'

# Test with separate losses
lm_loss = np.array([0.3])
aux_loss = np.array([0.1])
mtp_loss = np.array([0.05])

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, lm_loss, aux_loss, mtp_loss
)

# Verify separate loss was logged to tensorboard
# Check if add_scalar was called with the expected tags
call_args = monitor.tensor_writer.add_scalar.call_args_list
tags_called = [call[0][0] for call in call_args] # Extract first positional argument (tag)

assert 'lm-loss' in tags_called, f"'lm-loss' not found in {tags_called}"
assert 'mtp-loss' in tags_called, f"'mtp-loss' not found in {tags_called}"
assert 'load-balancing-loss' in tags_called, \
f"'load-balancing-loss' not found in {tags_called}"


class TestMFLossMonitorGetPipelineGroup:
"""Test MFLossMonitor._get_pipeline_group"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=2)
@patch('mindformers.core.callback.callback.auto_parallel_context')
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
def test_get_pipeline_group(self, mock_group_size, mock_auto_context, mock_get_rank):
"""Test _get_pipeline_group calculation"""

mock_auto_context.return_value.get_pipeline_stages.return_value = 2

rank_list, rank_list_str = MFLossMonitor._get_pipeline_group()

# With rank=2, stage_nums=2, device_nums=8
# per_stage_device_nums = 8 // 2 = 4
# local_stage_rank_id = 2 % 4 = 2
# rank_list = [2 + 0*4, 2 + 1*4] = [2, 6]
assert rank_list == [2, 6]
assert rank_list_str == "2-6"


class TestMFLossMonitorCanCalculateFlops:
"""Test MFLossMonitor._can_calculate_model_flops"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.get_context', return_value=0) # GRAPH_MODE
@patch('mindformers.core.callback.callback.is_legacy_model', return_value=True)
def test_can_calculate_flops_train_mode(self, mock_legacy, mock_get_context, *mocks):
"""Test _can_calculate_model_flops in train mode"""

monitor = MFLossMonitor()
monitor.is_moe_model = False

cb_params = Mock()
cb_params.mode = 'train'
cb_params.train_network = Mock()
cb_params.train_network.current_phase = 'train_phase'

result = monitor._can_calculate_model_flops(cb_params)

assert result
assert monitor.current_phase == 'train_phase'

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.get_context', return_value=1) # PYNATIVE_MODE
@patch('mindformers.core.callback.callback.logger')
def test_can_calculate_flops_pynative_mode(self, mock_logger, mock_get_context, *mocks):
"""Test _can_calculate_model_flops in pynative mode (should fail)"""

monitor = MFLossMonitor()

cb_params = Mock()
cb_params.mode = 'train'
cb_params.train_network = Mock()

result = monitor._can_calculate_model_flops(cb_params)

assert not result
mock_logger.warning.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.logger')
def test_can_calculate_flops_invalid_mode(self, mock_logger, *mocks):
"""Test _can_calculate_model_flops with invalid mode"""

monitor = MFLossMonitor()

cb_params = Mock()
cb_params.mode = 'predict' # Invalid mode

result = monitor._can_calculate_model_flops(cb_params)

assert not result
mock_logger.warning.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.get_context', return_value=0)
@patch('mindformers.core.callback.callback.logger')
def test_can_calculate_flops_no_current_phase(self, mock_logger, mock_get_context, *mocks):
"""Test _can_calculate_model_flops when network has no current_phase"""

monitor = MFLossMonitor()

cb_params = Mock()
cb_params.mode = 'train'
cb_params.train_network = Mock(spec=[]) # No current_phase attribute

result = monitor._can_calculate_model_flops(cb_params)

assert not result
mock_logger.warning.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.get_context', return_value=0)
@patch('mindformers.core.callback.callback.is_legacy_model', return_value=False)
@patch('mindformers.core.callback.callback.logger')
def test_can_calculate_flops_moe_model_non_legacy(
self, mock_logger, mock_legacy, mock_get_context, *mocks):
"""Test _can_calculate_model_flops with MoE model in non-legacy mode"""

monitor = MFLossMonitor()
monitor.is_moe_model = True

cb_params = Mock()
cb_params.mode = 'train'
cb_params.train_network = Mock()
cb_params.train_network.current_phase = 'train_phase'

result = monitor._can_calculate_model_flops(cb_params)

assert not result
mock_logger.warning.assert_called()


class TestMFLossMonitorPrintOutputInfoLearningRate:
"""Test MFLossMonitor.print_output_info learning rate scenarios"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.context.get_context', return_value='CPU')
@patch('mindformers.core.callback.callback.logger')
def test_print_output_info_lr_schedule_cpu(self, mock_logger, mock_get_context, *mocks):
"""Test print_output_info with LearningRateSchedule on CPU"""

lr_schedule = Mock(spec=['__call__'])
monitor = MFLossMonitor(learning_rate=lr_schedule, global_batch_size=32)
monitor.print_warning_flag = True

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Should log warning about CPU not supported
mock_logger.warning.assert_called()
assert not monitor.print_warning_flag

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.ms.context.get_context', return_value='Ascend')
def test_print_output_info_lr_schedule_ascend(self, mock_get_context, *mocks):
"""Test print_output_info with LearningRateSchedule on Ascend"""

# Create a simple mock that can be called
lr_schedule = Mock()
lr_result = Mock()
lr_result.asnumpy.return_value = np.array(0.001)
lr_schedule.return_value = lr_result

monitor = MFLossMonitor(learning_rate=lr_schedule, global_batch_size=32)

# Manually set the learning_rate to be recognized as LearningRateSchedule
# by patching the isinstance check in print_output_info
with patch('mindformers.core.callback.callback.isinstance') as mock_isinstance:
# Default behavior: call the real isinstance
def isinstance_side_effect(obj, classinfo):
# Special handling for our lr_schedule object
if obj is monitor.learning_rate:
# Check if classinfo is a tuple (for the first isinstance check)
if isinstance(classinfo, tuple):
return False # Not (float, Tensor, np.ndarray)
return True # Is LearningRateSchedule
# For all other cases, use built-in isinstance
return builtins.isinstance(obj, classinfo)

mock_isinstance.side_effect = isinstance_side_effect

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10
cb_params.train_network = Mock()
cb_params.train_network.phase = 'train'
cb_params.train_network.set_train = Mock()

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Verify set_train was called to temporarily disable training
cb_params.train_network.set_train.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.logger')
def test_print_output_info_invalid_lr_type(self, mock_logger, *mocks):
"""Test print_output_info with invalid learning rate type"""

# Use a list as learning rate (invalid type)
monitor = MFLossMonitor(learning_rate=[0.01, 0.02], global_batch_size=32)
monitor.print_warning_flag = True

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Should log warning about invalid type
mock_logger.warning.assert_called()
assert not monitor.print_warning_flag

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.logger')
def test_print_output_info_no_lr(self, mock_logger, *mocks):
"""Test print_output_info without learning rate"""

monitor = MFLossMonitor(learning_rate=None, global_batch_size=32)
monitor.print_warning_flag = True

cb_params = Mock()
cb_params.dataset_sink_mode = False
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# Should log warning about missing learning rate
mock_logger.warning.assert_called()
assert not monitor.print_warning_flag


class TestMFLossMonitorCalculateFlopsWithPipeline:
"""Test MFLossMonitor._calculate_model_flops with pipeline parallel"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.flops_collection')
@patch('mindformers.core.callback.callback.auto_parallel_context')
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
@patch('mindformers.core.callback.callback.create_group')
@patch('mindformers.core.callback.callback.AllReduceNet')
@patch('mindformers.core.callback.callback.Tensor')
def test_calculate_flops_with_pipeline_dynamic_shape(self, mock_tensor, mock_allreduce_net,
mock_create_group, mock_group_size,
mock_auto_context, mock_flops, *mocks):
"""Test _calculate_model_flops with pipeline and dynamic shape"""

monitor = MFLossMonitor()
monitor.current_phase = 'train_phase'

mock_flops.return_value = (1000000.0, 0, 500000.0, 0, True) # is_dynamic_shape=True
mock_auto_context.return_value.get_pipeline_stages.return_value = 2

# Mock AllReduceNet to return is_dynamic_shape > 0
mock_allreduce_instance = Mock()
mock_result = Mock()
mock_result.asnumpy.return_value = [1] # is_dynamic_shape > 0
mock_allreduce_instance.return_value = mock_result
mock_allreduce_net.return_value = mock_allreduce_instance

monitor._calculate_model_flops()

# Should set mf_support to False due to dynamic shape
assert not monitor.mf_support

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback.flops_collection')
@patch('mindformers.core.callback.callback.auto_parallel_context')
@patch('mindformers.core.callback.callback.get_group_size', return_value=8)
@patch('mindformers.core.callback.callback.create_group')
@patch('mindformers.core.callback.callback.AllReduceNet')
@patch('mindformers.core.callback.callback.Tensor')
def test_calculate_flops_with_pipeline_success(self, mock_tensor, mock_allreduce_net,
mock_create_group, mock_group_size,
mock_auto_context, mock_flops, *mocks):
"""Test _calculate_model_flops with pipeline parallel success"""

monitor = MFLossMonitor()
monitor.current_phase = 'train_phase'

mock_flops.return_value = (1000000.0, 0, 500000.0, 0, False) # is_dynamic_shape=False
mock_auto_context.return_value.get_pipeline_stages.return_value = 2
mock_auto_context.return_value.get_parallel_mode.return_value = 'semi_auto_parallel'

# Mock AllReduceNet
mock_allreduce_instance = Mock()

# First call: is_dynamic_shape check
mock_is_dynamic_result = Mock()
mock_is_dynamic_result.asnumpy.return_value = [0]

# Second call: flops aggregation
mock_flops_result = Mock()
mock_flops_result.asnumpy.return_value = [2000000.0]

mock_allreduce_instance.side_effect = [mock_is_dynamic_result, mock_flops_result]
mock_allreduce_net.return_value = mock_allreduce_instance

monitor._calculate_model_flops()

# Should aggregate flops across pipeline stages and divide by group size
assert monitor.mf_calculated
# 2000000.0 / 8 = 250000.0
assert monitor.full_model_flops == 250000.0


class TestMFLossMonitorPrintOutputInfoDataSinkMode:
"""Test MFLossMonitor.print_output_info in dataset sink mode"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
def test_print_output_info_sink_mode(self, *mocks):
"""Test print_output_info in dataset sink mode"""

monitor = MFLossMonitor(learning_rate=0.001, global_batch_size=32)

cb_params = Mock()
cb_params.dataset_sink_mode = True # Sink mode
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 10

monitor.print_output_info(
cb_params, 1, 10, 100.0, 1, 100, 0.5, 100.0,
False, 1024.0, 3600, 10.0, 2.5, None, None, None
)

# In sink mode, loss_info format is different
# This test mainly ensures no errors occur


class TestMFLossMonitorMstxEnabled:
"""Test MFLossMonitor with mstx enabled"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_tensorboard_args', return_value={})
@patch('mindformers.core.callback.callback._check_mspti_is_on', return_value=True)
@patch('mindformers.core.callback.callback.ms.profiler.mstx')
@patch('mindformers.core.callback.callback.ms.runtime')
@patch('time.time')
def test_on_train_step_with_mstx(self, mock_time, mock_runtime, mock_mstx, mock_mspti, *mocks):
"""Test on_train_step_begin and on_train_step_end with mstx enabled"""

mock_time.return_value = 1000.0
mock_mstx.range_start.return_value = 12345
mock_runtime.current_stream.return_value = Mock()

monitor = MFLossMonitor(origin_epochs=10, dataset_size=100, global_batch_size=32)

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 5
run_context.original_args.return_value = cb_params

# Test on_train_step_begin
monitor.on_train_step_begin(run_context)

mock_mstx.range_start.assert_called_once()
assert monitor.mstx_range_id == 12345

if __name__ == '__main__':
# pytest also collects the unittest.TestCase classes above, so a single runner is enough
pytest.main([__file__, '-v'])
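The rank arithmetic that TestMFLossMonitorGetPipelineGroup asserts is self-contained enough to restate directly. Here is a rough sketch consistent with those assertions, assuming the real _get_pipeline_group reads the stage and device counts from the MindSpore parallel context; they are plain arguments here so the example stays runnable.

def pipeline_group(rank, stage_nums, device_nums):
    """Ranks holding the same local slot across every pipeline stage form one group."""
    per_stage_device_nums = device_nums // stage_nums
    local_stage_rank_id = rank % per_stage_device_nums
    rank_list = [local_stage_rank_id + i * per_stage_device_nums for i in range(stage_nums)]
    return rank_list, "-".join(str(r) for r in rank_list)

# pipeline_group(2, 2, 8) -> ([2, 6], "2-6"), matching the assertion in the test above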

+ 302  - 0    tests/st/test_ut/test_core/test_callback/test_other_callbacks.py

@@ -0,0 +1,302 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
from unittest.mock import Mock, patch

import pytest

from mindformers.core.callback.callback import (
EvalCallBack,
MaxLogitsMonitor,
MoEDropRateCallback,
StressDetectCallBack,
SummaryMonitor,
TopkBiasBalanceCallback,
TrainCallBack
)

# pylint: disable=unused-argument # for mock logic


class TestSummaryMonitor:
"""Test SummaryMonitor class"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.SummaryCollector')
@patch('mindformers.core.callback.callback.get_output_subpath')
@patch('mindformers.core.callback.callback.get_real_rank')
def test_init(self, mock_get_real_rank, mock_get_output_subpath, mock_summary_collector):
"""Test initialization"""
mock_get_real_rank.return_value = 0
mock_get_output_subpath.return_value = "/tmp/summary"

SummaryMonitor(summary_dir=None)

mock_summary_collector.assert_called_once()
_, kwargs = mock_summary_collector.call_args
assert kwargs['summary_dir'] == "/tmp/summary"


class TestEvalCallBack:
"""Test EvalCallBack class"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_on_train_epoch_end(self):
"""Test on_train_epoch_end callback"""
eval_func = Mock()
callback = EvalCallBack(eval_func, epoch_interval=2)

run_context = Mock()
cb_params = Mock()

# Epoch 1: no eval
cb_params.cur_epoch_num = 1
run_context.original_args.return_value = cb_params
callback.on_train_epoch_end(run_context)
eval_func.assert_not_called()

# Epoch 2: eval
cb_params.cur_epoch_num = 2
run_context.original_args.return_value = cb_params
callback.on_train_epoch_end(run_context)
eval_func.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_on_train_step_end(self):
"""Test on_train_step_end callback"""
eval_func = Mock()
callback = EvalCallBack(eval_func, step_interval=10, epoch_interval=-1)

run_context = Mock()
cb_params = Mock()

# Step 5: no eval
cb_params.cur_step_num = 5
run_context.original_args.return_value = cb_params
callback.on_train_step_end(run_context)
eval_func.assert_not_called()

# Step 10: eval
cb_params.cur_step_num = 10
run_context.original_args.return_value = cb_params
callback.on_train_step_end(run_context)
eval_func.assert_called_once()


class TestTrainCallBack:
"""Test TrainCallBack class"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_stop_step(self):
"""Test stop_step functionality"""
callback = TrainCallBack(stop_step=10)
run_context = Mock()
cb_params = Mock()
run_context.original_args.return_value = cb_params

# Step 5
cb_params.cur_step_num = 5
callback.on_train_step_end(run_context)
run_context.request_stop.assert_not_called()

# Step 10
cb_params.cur_step_num = 10
callback.on_train_step_end(run_context)
run_context.request_stop.assert_called_once()


class TestStressDetectCallBack:
"""Test StressDetectCallBack class"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.stress_detect')
def test_stress_detect(self, mock_stress_detect):
"""Test stress detection functionality"""
callback = StressDetectCallBack(detection_interval=10, num_detections=2, dataset_size=100)
run_context = Mock()
cb_params = Mock()
run_context.original_args.return_value = cb_params

# Step 5: no detect
cb_params.cur_step_num = 5
callback.on_train_step_end(run_context)
mock_stress_detect.assert_not_called()

# Step 10: detect
cb_params.cur_step_num = 10
mock_stress_detect.return_value = 0
callback.on_train_step_end(run_context)
assert mock_stress_detect.call_count == 2


class TestMaxLogitsMonitor:
"""Test MaxLogitsMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_on_train_step_end(self):
"""Test on_train_step_end callback"""
callback = MaxLogitsMonitor()
run_context = Mock()
cb_params = Mock()

# Create a network structure where 'network' attribute chain terminates
# The leaf node MUST NOT have a 'network' attribute to break the loop.
leaf_network = Mock()
del leaf_network.network
leaf_network.reset_max_attention_logit = Mock()

# intermediate network
network = Mock()
network.network = leaf_network

cb_params.train_network = network

run_context.original_args.return_value = cb_params

with patch('mindformers.core.callback.callback.get_auto_parallel_context') \
as mock_get_parallel:
mock_get_parallel.return_value = "stand_alone"
callback.on_train_step_end(run_context)

leaf_network.reset_max_attention_logit.assert_called_once()


class TestTopkBiasBalanceCallback:
"""Test TopkBiasBalanceCallback"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindspore.context.get_auto_parallel_context')
@patch('mindformers.core.callback.callback.get_tensorboard_writer')
@patch('mindformers.core.callback.callback.get_tensorboard_args')
def test_update_topk_bias(self, mock_args, mock_writer, mock_get_parallel):
"""Test topk bias update functionality"""
mock_args.return_value = {'log_expert_load_to_tensorboard': False}
mock_get_parallel.return_value = 1 # pipeline stages

# We need to mock P.Assign etc, which are used in __init__
with patch('mindspore.ops.operations.Assign'), \
patch('mindspore.ops.operations.Sub'), \
patch('mindspore.ops.operations.Add'), \
patch('mindspore.ops.operations.Sign'), \
patch('mindspore.ops.operations.Mul'), \
patch('mindspore.ops.operations.Div'):
callback = TopkBiasBalanceCallback(balance_via_topk_bias=True, expert_num=2)

# Setup network structure for _update_topk_bias logic
# Ensure leaf_network does not have 'network' attribute to terminate loop
leaf_network = Mock()
del leaf_network.network

layer = Mock()
router_inner = Mock()
mock_expert_load = Mock()
mock_expert_load.sum.return_value = 2.0
router_inner.expert_load.value.return_value = mock_expert_load
router_inner.topk_bias.value.return_value = Mock()

router = Mock()
router.router = router_inner

routed_experts = Mock()
routed_experts.router = router

feed_forward = Mock()
feed_forward.routed_experts = routed_experts

layer.feed_forward = feed_forward
leaf_network.model.layers = [layer]

network = Mock()
network.network = leaf_network

run_context = Mock()
cb_params = Mock()
cb_params.train_network = network
run_context.original_args.return_value = cb_params

ctx_patch = 'mindformers.core.callback.callback.get_auto_parallel_context'
with patch(ctx_patch, return_value="stand_alone"):
callback.on_train_step_end(run_context)


class TestMoEDropRateCallback:
"""Test MoEDropRateCallback"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_callback_droprate(self):
"""Test MoEDropRateCallback - simplified version that just verifies initialization"""
# Test initialization
callback = MoEDropRateCallback(expert_num=8, capacity_factor=1.1, num_layers=1, mtp_depth=0)

# Verify basic attributes
assert callback.capacity_factor_over_expert_num == 1.1 / 8
assert callback.num_layers == 1

# Test with mock network that has no routed_experts (skip the callback logic)
leaf_network = Mock()
del leaf_network.network

layer = Mock()
# Make feed_forward not have routed_experts attribute
layer.feed_forward = Mock(spec=[])

leaf_network.model.layers = [layer]

network = Mock()
network.network = leaf_network

run_context = Mock()
cb_params = Mock()
cb_params.train_network = network
run_context.original_args.return_value = cb_params

# Mock to avoid entering the complex logic
ctx_patch = 'mindformers.core.callback.callback.get_auto_parallel_context'
with patch(ctx_patch, return_value="stand_alone"):
# This should not raise any errors
callback.on_train_step_end(run_context)


class TestTopkBiasBalanceCallbackExtended:
"""Extended tests for TopkBiasBalanceCallback"""

@patch('mindformers.core.callback.callback.get_tensorboard_args',
return_value={'log_expert_load_to_tensorboard': True})
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
def test_log_expert_load_to_tensorboard(self, *mocks):
"""Test logging expert load to tensorboard"""

callback = TopkBiasBalanceCallback(
balance_via_topk_bias=False,
topk_bias_update_rate=0.01,
expert_num=8,
micro_batch_num=2,
gradient_accumulation_steps=4
)

assert callback.tensor_writer is not None

if __name__ == '__main__':
pytest.main([__file__, '-v'])
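Several of these tests build train_network out of Mocks and explicitly delete the leaf's network attribute. The reason is that Mock auto-creates attributes on access, so an unwrap loop of the form while hasattr(net, "network") would never terminate on a bare Mock. A small standalone illustration of the trick follows; the unwrap loop itself is an assumption about how the callback walks the wrapper chain, not code copied from callback.py.

from unittest.mock import Mock

leaf = Mock()
del leaf.network  # after this, hasattr(leaf, "network") is False

wrapper = Mock()
wrapper.network = leaf

net = wrapper
while hasattr(net, "network"):  # terminates only because of the del above
    net = net.network
assert net is leaf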

+ 137  - 1    tests/st/test_ut/test_core/test_callback/test_profile_monitor.py

@@ -20,6 +20,10 @@ import tempfile
import pytest
from mindformers.core.callback.callback import ProfileMonitor

# pylint: disable=protected-access
# pylint: disable=unused-argument # for mock logic


class TestProfileMonitor(unittest.TestCase):
"""Test cases for ProfileMonitor class"""

@@ -160,5 +164,137 @@ class TestProfileMonitor(unittest.TestCase):
self.assertIsNone(monitor.profile_rank_ids)


class TestProfileMonitorExtended:
"""Extended tests for ProfileMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_check_step_valid(self):
"""Test _check_step with valid inputs"""

start, stop = ProfileMonitor._check_step(5, 10)
assert start == 5
assert stop == 10

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_check_step_invalid(self):
"""Test _check_step with invalid inputs"""

# start > stop
start, stop = ProfileMonitor._check_step(15, 10)
assert start == 1
assert stop == 10

# negative values
start, stop = ProfileMonitor._check_step(-1, -5)
assert start == 1
assert stop == 10

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_check_start_profile(self):
"""Test _check_start_profile"""

# start_step != 1, should return False
result = ProfileMonitor._check_start_profile(True, 5)
assert not result

# start_step == 1, should keep original value
result = ProfileMonitor._check_start_profile(True, 1)
assert result


class TestProfileMonitorInit:
"""Test ProfileMonitor initialization and configuration"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_pipeline_rank_ids', return_value=[0, 1])
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/output/profile')
@patch('mindformers.core.callback.callback.ms.get_context', return_value='Ascend')
@patch('mindformers.core.callback.callback.is_version_ge', return_value=True)
@patch('mindformers.core.callback.callback._check_mspti_is_on', return_value=False)
def test_profile_monitor_init_with_pipeline(self, mock_mspti, mock_version, mock_context,
mock_output, mock_pipeline_ids, mock_real_rank):
"""Test ProfileMonitor initialization with pipeline profiling"""

# Mock the profile function from mindspore.profiler
with patch('mindspore.profiler.profile') as mock_profile:
mock_profiler_instance = Mock()
mock_profile.return_value = mock_profiler_instance

monitor = ProfileMonitor(
start_step=1,
stop_step=10,
profile_pipeline=True,
profile_communication=True,
profile_memory=True,
profiler_level=1
)

assert monitor.profiler is not None
assert monitor.start_step == 1
assert monitor.stop_step == 10

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=5)
@patch('mindformers.core.callback.callback.get_pipeline_rank_ids', return_value=[0, 1])
def test_profile_monitor_not_required_rank(self, mock_pipeline_ids, mock_real_rank):
"""Test ProfileMonitor when current rank doesn't need profiling"""

monitor = ProfileMonitor(
start_step=1,
stop_step=10,
profile_rank_ids=[0, 1, 2]
)

# Rank 5 is not in profile_rank_ids or pipeline_rank_ids
assert monitor.profiler is None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_pipeline_rank_ids', return_value=[0])
@patch('mindformers.core.callback.callback.get_output_subpath', return_value='/output/profile')
@patch('mindformers.core.callback.callback.ms.get_context', return_value='Ascend')
@patch('mindformers.core.callback.callback.is_version_ge', return_value=True)
@patch('mindformers.core.callback.callback._check_mspti_is_on', return_value=False)
def test_on_train_step_begin_start_profiler(self, mock_mspti, mock_version, mock_context,
mock_output, mock_pipeline_ids, mock_real_rank):
"""Test on_train_step_begin starts profiler"""

# Create a mock profiler
mock_profiler = Mock()
mock_profiler.start = Mock()
mock_profiler.step = Mock()

# Create monitor - we'll manually set the profiler
monitor = ProfileMonitor(
start_step=1,
stop_step=10,
profile_rank_ids=[0] # Ensure rank 0 is profiled
)

# Manually set the profiler and is_profiler_start flag
monitor.profiler = mock_profiler
monitor.is_profiler_start = False

# Create run context
run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 1
run_context.original_args.return_value = cb_params

# Call on_train_step_begin
monitor.on_train_step_begin(run_context)

# Verify profiler.start() and profiler.step() were called
mock_profiler.start.assert_called_once()
mock_profiler.step.assert_called_once()
assert monitor.is_profiler_start

if __name__ == '__main__':
# pytest also collects the unittest.TestCase classes above, so a single runner is enough
pytest.main([__file__, '-v'])
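test_check_step_valid and test_check_step_invalid pin down the clamping behaviour of ProfileMonitor._check_step without showing it. A sketch consistent with those assertions is given below; the fallback values 1 and 10 are assumptions read off the expected outputs, not the real signature.

def check_step(start_step, stop_step, default_start=1, default_stop=10):
    """Fall back to the defaults when either bound is non-positive or the range is inverted."""
    if start_step < 1 or stop_step < 1 or start_step > stop_step:
        return default_start, default_stop
    return start_step, stop_step

# check_step(5, 10) -> (5, 10); check_step(15, 10) -> (1, 10); check_step(-1, -5) -> (1, 10)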

+ 653  - 0    tests/st/test_ut/test_core/test_callback/test_stress_test_monitor.py

@@ -0,0 +1,653 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
import os
import tempfile
from unittest.mock import Mock, patch

import numpy as np
import pytest

from mindformers.core.callback.callback import StressTestModelMonitor

# pylint: disable=unused-argument # for mock logic


class TestStressTestModelMonitorBasic:
"""Test StressTestModelMonitor basic methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindspore.communication.get_local_rank_size', return_value=8)
@patch('os.getenv')
def test_get_value_from_line(self, mock_getenv, mock_rank_size):
"""Test get_value_from_line method"""

model_dir = tempfile.mkdtemp()
dataset_dir = tempfile.mkdtemp()

# Mock MS_SCHED_PORT environment variable
def getenv_side_effect(key, default=None):
if key == "MS_SCHED_PORT":
return "8118" # Return a valid port number as string
return default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=10,
stress_model_dir=model_dir,
stress_dataset_dir=dataset_dir
)

line = "loss: 0.5234, global_norm: [1.234]"
loss = monitor.get_value_from_line(line, r"loss: (\d+\.\d+)")
assert loss == 0.5234

global_norm = monitor.get_value_from_line(line, r"global_norm: \[(\d+\.\d+)\]")
assert global_norm == 1.234

# No match
result = monitor.get_value_from_line(line, r"notfound: (\d+\.\d+)")
assert result is None


class TestStressTestModelMonitorMethods:
"""Test StressTestModelMonitor methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
def test_on_train_step_end_skip(self, mock_local_rank, mock_exists, mock_getenv, mock_get_rank):
"""Test on_train_step_end when interval not reached"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset'
)

monitor.last_checked_step = 0
monitor.check_stress_test_model = Mock()

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 50 # Less than interval
run_context.original_args.return_value = cb_params

monitor.on_train_step_end(run_context)

# Should not call check_stress_test_model
monitor.check_stress_test_model.assert_not_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
def test_extract_interval_step_results_empty(
self, mock_local_rank, mock_exists, mock_getenv, mock_get_rank):
"""Test extract_interval_step_results with no matching intervals"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=1000 # Very large interval
)

# Create a temporary log file with few steps
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log') as f:
log_line1 = "{Epoch:[ 1], step:[ 10/ 100], loss: 2.5, global_norm: [1.2]}"
f.write(f"2024-01-01 10:00:00 - INFO - {log_line1}\n")
log_line2 = "{Epoch:[ 1], step:[ 20/ 100], loss: 2.3, global_norm: [1.1]}"
f.write(f"2024-01-01 10:01:00 - INFO - {log_line2}\n")
log_file = f.name

try:
results, global_step = monitor.extract_interval_step_results(log_file)
# Should return None when interval is too large
assert results is None
assert global_step == 20
finally:
os.remove(log_file)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
def test_compare_gathered_results_consistent(
self, mock_local_rank, mock_exists, mock_getenv, mock_get_rank):
"""Test compare_gathered_results with consistent results"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset'
)

# Create consistent results from multiple ranks
gathered_results = np.array([
[[1, 10, 2.5, 1.2]],
[[1, 10, 2.5, 1.2]],
[[1, 10, 2.5, 1.2]],
[[1, 10, 2.5, 1.2]]
])

result = monitor.compare_gathered_results(gathered_results)

# Should return True for consistent results
assert result

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
def test_compare_gathered_results_inconsistent(
self, mock_local_rank, mock_exists, mock_getenv, mock_get_rank):
"""Test compare_gathered_results with inconsistent results"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset'
)

# Create inconsistent results from multiple ranks
gathered_results = np.array([
[[1, 10, 2.5, 1.2]],
[[1, 10, 2.6, 1.3]], # Different values
[[1, 10, 2.5, 1.2]],
[[1, 10, 2.5, 1.2]]
])

result = monitor.compare_gathered_results(gathered_results)

# Should return False for inconsistent results
assert not result
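# The two cases above pin down the expected semantics: every rank's gathered row must be
# identical for the check to pass. A minimal NumPy sketch of such a consistency check
# (an assumed model of the behaviour, not the callback's actual code):
def _all_ranks_consistent_sketch(gathered):
    """Return True when every rank reports the same (epoch, step, loss, global_norm) rows."""
    return bool(np.all(gathered == gathered[0]))


assert _all_ranks_consistent_sketch(np.array([[[1, 10, 2.5, 1.2]]] * 4))
assert not _all_ranks_consistent_sketch(np.array(
    [[[1, 10, 2.5, 1.2]], [[1, 10, 2.6, 1.3]], [[1, 10, 2.5, 1.2]], [[1, 10, 2.5, 1.2]]]))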


class TestStressTestModelMonitorCheckStressTest:
"""Test StressTestModelMonitor.check_stress_test_model method"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_dataset_not_exists(self, mock_logger, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when dataset_dir doesn't exist"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset'
)

# Make dataset_dir check return False
mock_exists.return_value = False
monitor.dataset_dir = '/nonexistent/path'

# Should return early without running stress test
monitor.check_stress_test_model(current_step=100)

# Should log error about dataset not found
mock_logger.error.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_dataset_dir_none(self, mock_logger, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when dataset_dir is None (line 3000)"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset'
)

# Set dataset_dir to None
monitor.dataset_dir = None

# Should return early without running stress test
monitor.check_stress_test_model(current_step=100)

# Should log error about dataset not found
mock_logger.error.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_rank0_runs_subprocess(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when rank_id % worker_num == 0 runs subprocess"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=None # Skip interval comparison
)

# Mock subprocess behavior
mock_process = Mock()
mock_process.poll.side_effect = [None, 0] # First call returns None, second returns 0
mock_process.returncode = 0
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# Mock all_gather_into_tensor result
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

# Mock readlog and extract methods
monitor.readlog = Mock(return_value="Training step 10")
monitor.extract_last_step_result = Mock(return_value=Mock())

monitor.check_stress_test_model(current_step=100)

# Should call Popen to start subprocess
mock_popen.assert_called()
# Should call barrier for synchronization
mock_barrier.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=1)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_non_rank0_skips_subprocess(self, mock_logger, mock_all_gather,
mock_barrier, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when rank_id % worker_num != 0 skips subprocess"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=None
)

# Mock all_gather_into_tensor result
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

monitor.extract_last_step_result = Mock(return_value=Mock())

monitor.check_stress_test_model(current_step=100)

# Should call barrier (synchronization happens regardless of rank)
mock_barrier.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_with_compare_interval_steps(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model with compare_interval_steps set"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=10 # Set interval comparison
)

# Mock subprocess behavior
mock_process = Mock()
mock_process.poll.side_effect = [None, 0]
mock_process.returncode = 0
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# Mock all_gather_into_tensor result for interval comparison
mock_interval_tensor = Mock()
mock_interval_tensor.asnumpy.return_value = np.array([[[1, 10, 2.5, 1.2]], [[1, 10, 2.5, 1.2]]])
mock_last_tensor = Mock()
mock_last_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0]])
mock_all_gather.side_effect = [(mock_interval_tensor, None), (mock_last_tensor, None)]

# Mock extract methods to return valid results
mock_interval_result = Mock()
monitor.extract_interval_step_results = Mock(return_value=(mock_interval_result, 100))
monitor.extract_last_step_result = Mock(return_value=Mock())
monitor.readlog = Mock(return_value="Training step 10")
monitor.compare_gathered_results = Mock(return_value=True)

monitor.check_stress_test_model(current_step=100)

# Should call compare_gathered_results for interval comparison
monitor.compare_gathered_results.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_interval_results_none(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when interval_results is None"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=1000 # Large interval
)

# Mock subprocess behavior
mock_process = Mock()
mock_process.poll.side_effect = [None, 0]
mock_process.returncode = 0
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# Mock all_gather_into_tensor result
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

# Mock extract_interval_step_results to return None (interval too large)
monitor.extract_interval_step_results = Mock(return_value=(None, 50))
monitor.extract_last_step_result = Mock(return_value=Mock())
monitor.readlog = Mock(return_value="Training step 10")

monitor.check_stress_test_model(current_step=100)

# Should log warning about interval being larger than total steps
mock_logger.warning.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_results_match(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when all results match"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=None
)

# Mock subprocess behavior
mock_process = Mock()
mock_process.poll.side_effect = [None, 0]
mock_process.returncode = 0
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# All results match
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

monitor.extract_last_step_result = Mock(return_value=Mock())
monitor.readlog = Mock(return_value="Training step 10")

monitor.check_stress_test_model(current_step=100)

# Should log STRESS TEST PASSED
info_calls = [str(call) for call in mock_logger.info.call_args_list]
passed_logged = any('STRESS TEST PASSED' in str(call) for call in info_calls)
assert passed_logged

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_results_mismatch(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when results don't match"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=None
)

# Mock subprocess behavior
mock_process = Mock()
mock_process.poll.side_effect = [None, 0]
mock_process.returncode = 0
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# Results don't match - different values
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.5, 2.5], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

monitor.extract_last_step_result = Mock(return_value=Mock())
monitor.readlog = Mock(return_value="Training step 10")

monitor.check_stress_test_model(current_step=100)

# Should log STRESS TEST FAILED warning
warning_calls = [str(call) for call in mock_logger.warning.call_args_list]
failed_logged = any('STRESS TEST FAILED' in str(call) for call in warning_calls)
assert failed_logged

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.getenv')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=True)
@patch('mindformers.core.callback.callback.ms.communication.get_local_rank_size',
return_value=8)
@patch('mindformers.core.callback.callback.os.cpu_count', return_value=16)
@patch('mindformers.core.callback.callback.barrier')
@patch('mindformers.core.callback.callback.all_gather_into_tensor')
@patch('mindformers.core.callback.callback.subprocess.Popen')
@patch('mindformers.core.callback.callback.shlex.split', side_effect=lambda x: x.split())
@patch('mindformers.core.callback.callback.logger')
def test_check_stress_test_model_subprocess_error(self, mock_logger, mock_shlex, mock_popen,
mock_all_gather, mock_barrier,
mock_cpu_count, mock_local_rank,
mock_exists, mock_getenv, mock_get_rank):
"""Test check_stress_test_model when subprocess returns error"""

def getenv_side_effect(key, default=None):
return "8118" if key == "MS_SCHED_PORT" else default

mock_getenv.side_effect = getenv_side_effect

monitor = StressTestModelMonitor(
interval_steps=100,
stress_model_dir='/path/to/model',
stress_dataset_dir='/path/to/dataset',
compare_interval_steps=None
)

# Mock subprocess with error
mock_process = Mock()
mock_process.poll.side_effect = [None, 1] # Returns non-zero exit code
mock_process.returncode = 1 # Error
mock_process.stderr = Mock()
mock_process.stderr.read.return_value = b'Error message'
mock_process.__enter__ = Mock(return_value=mock_process)
mock_process.__exit__ = Mock(return_value=False)
mock_popen.return_value = mock_process

# Mock all_gather_into_tensor result
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.0, 2.0], [1.0, 2.0]])
mock_all_gather.return_value = (mock_tensor, None)

monitor.extract_last_step_result = Mock(return_value=Mock())
monitor.readlog = Mock(return_value="Training step 10")

monitor.check_stress_test_model(current_step=100)

# Should log warning about subprocess error
warning_calls = [str(call) for call in mock_logger.warning.call_args_list]
error_logged = any('error occurred' in str(call).lower() for call in warning_calls)
assert error_logged

if __name__ == '__main__':
pytest.main([__file__, '-v'])

+ 1050
- 0
tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py View File

@@ -0,0 +1,1050 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test callback.py using pytest framework."""
from unittest.mock import Mock, patch

import numpy as np
import pytest

from mindformers.core.callback.callback import TrainingStateMonitor

# pylint: disable=protected-access
# pylint: disable=unused-argument # for mock logic


class TestTrainingStateMonitor:
"""Test TrainingStateMonitor class"""

def setup_method(self):
"""Set up test fixtures for each test method."""

# Mock context.get_auto_parallel_context to return appropriate values
def mock_get_context(key):
if key == "pipeline_stages":
return 1 # Return integer for pipeline_stages
if key == "dump_local_norm_path":
return None # No dump path
return None

with patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=mock_get_context), \
patch('mindformers.core.callback.callback.get_real_group_size', return_value=1), \
patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None):
self.monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
step_interval=1
)
# Initialize dump_path to None to avoid finish_pattern attribute error
self.monitor.dump_path = None

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_init(self):
"""Test initialization"""
assert self.monitor.origin_epochs == 10
assert self.monitor.steps_per_epoch == 100
assert self.monitor.step_interval == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('time.time')
def test_on_train_epoch_begin(self, mock_time):
"""Test on_train_epoch_begin"""
mock_time.return_value = 12345.0
run_context = Mock()
self.monitor.on_train_epoch_begin(run_context)
assert self.monitor.epoch_time == 12345.0
assert self.monitor.run_context == run_context

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('time.time')
def test_on_train_step_begin(self, mock_time):
"""Test on_train_step_begin"""
mock_time.return_value = 67890.0
run_context = Mock()
run_context.original_args.return_value = Mock()
self.monitor.on_train_step_begin(run_context)
assert self.monitor.step_time == 67890.0
assert self.monitor.run_context == run_context

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('time.time')
@patch('mindformers.core.callback.callback.set_auto_parallel_context')
@patch('mindformers.core.callback.callback.get_auto_parallel_context')
def test_on_train_step_end_basic(self, mock_get_parallel, mock_set_parallel, mock_time):
"""Test on_train_step_end basic flow"""
mock_time.side_effect = [1000.0, 1000.1] # start, end

def get_context_side_effect(attr):
if attr == "parallel_mode":
return "stand_alone"
if attr == "full_batch":
return False
if attr == "dump_local_norm_path":
return None
return None

mock_get_parallel.side_effect = get_context_side_effect

run_context = Mock()
cb_params = Mock()
cb_params.cur_step_num = 1
cb_params.batch_num = 100
cb_params.cur_epoch_num = 1
cb_params.dataset_sink_mode = False
cb_params.net_outputs = Mock() # loss
cb_params.initial_step = 0

# Mock get method for cb_params to behave like dict for 'initial_step'
def mock_get(key, default=None):
if key == 'initial_step':
return cb_params.initial_step
return getattr(cb_params, key, default)

cb_params.get.side_effect = mock_get

run_context.original_args.return_value = cb_params

with patch.object(self.monitor, '_get_loss_output', return_value=(0.5, 1.0, 0.1)):
self.monitor.step_time = 1000.0
self.monitor.on_train_step_end(run_context)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.check_device_local_loss')
def test_boundary_check(self, mock_check_loss):
"""Test _boundary_check"""
self.monitor.check_for_nan_in_loss_and_grad = True
cb_params = Mock()

# Case 1: Normal
with patch.object(self.monitor, '_get_loss_output', return_value=(0.5, 1.0, 0.1)):
self.monitor._boundary_check(cb_params)

# Case 2: NaN loss
# We need to simulate np.isnan check.
# If _get_loss_output returns nan, _check_nan_or_inf checks it
# using np.any(np.isnan(indicator))
with patch.object(self.monitor, '_get_loss_output', return_value=(float('nan'), 1.0, 0.1)):
with pytest.raises(ValueError):
self.monitor._boundary_check(cb_params)


class TestTrainingStateMonitorExtended:
"""Extended tests for TrainingStateMonitor"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
def test_global_norm_spike_detection(self, *mocks):
"""Test global norm spike detection"""

config = {
'global_norm_spike_threshold': 10.0,
'global_norm_spike_count_threshold': 3
}

# Mock context.get_auto_parallel_context to return appropriate values
def mock_get_context(key):
if key == "pipeline_stages":
return 1 # Return integer for pipeline_stages
if key == "dump_local_norm_path":
return None # No dump path
return None

with patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=mock_get_context):
monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

assert monitor.global_norm_spike_threshold == 10.0
assert monitor.global_norm_spike_count_threshold == 3

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=1)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_dump_path_initialization(self, *mocks):
"""Test dump_path initialization"""

# Mock context.get_auto_parallel_context to return appropriate values
def mock_get_context(key):
if key == "pipeline_stages":
return 1 # Return integer for pipeline_stages
if key == "dump_local_norm_path":
return "/tmp/test_path" # Return a path to trigger dump_path initialization
return None

with patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=mock_get_context), \
patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=mock_get_context):
monitor = TrainingStateMonitor(origin_epochs=10, dataset_size=100)
assert monitor.dump_path is not None
assert monitor.dump_path == "/tmp/test_path/rank_0"


class TestTrainingStateMonitorAbnormalGlobalNorm:
"""Test TrainingStateMonitor.abnormal_global_norm_check"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
@patch('mindformers.core.callback.callback.barrier_world')
@patch('mindformers.core.callback.callback.ms.runtime.synchronize')
@patch('mindformers.core.callback.callback.logger')
@patch('mindformers.core.callback.callback.os.path.exists', return_value=False)
@patch('mindformers.core.callback.callback.os.makedirs')
@patch('builtins.open', create=True)
@patch('mindformers.core.callback.callback.set_safe_mode_for_file_or_dir')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_abnormal_global_norm_check_first_spike(
self, mock_context, mock_safe_mode, mock_open, mock_makedirs,
mock_exists, mock_logger, mock_sync,
mock_barrier, mock_rank, *mocks):
"""Test abnormal_global_norm_check when first spike occurs"""

config = {
'check_for_global_norm': True,
'global_norm_spike_threshold': 10.0,
'global_norm_spike_count_threshold': 3
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

cb_params = Mock()
cb_params.cur_step_num = 5
cb_params.batch_num = 100
cb_params.dataset_sink_mode = False
cb_params.cur_epoch_num = 1
cb_params.get.return_value = None

# Mock net_outputs with high global_norm
# Need to make global_norm support >= comparison
mock_global_norm = Mock()
mock_global_norm.item.return_value = 15.0
# Mock __ge__ to support >= comparison
mock_global_norm.__ge__ = Mock(return_value=True)
cb_params.net_outputs = (Mock(), False, 1024.0, 0.001, mock_global_norm)

# Should raise RuntimeError on first spike
with pytest.raises(RuntimeError) as exc_info:
monitor.abnormal_global_norm_check(cb_params)

assert "TREError" in str(exc_info.value)
mock_barrier.assert_called_once()
mock_sync.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.logger')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_abnormal_global_norm_check_use_skip_data(self, mock_context, mock_logger, *mocks):
"""Test abnormal_global_norm_check with use_skip_data_by_global_norm"""

config = {
'check_for_global_norm': False,
'global_norm_spike_threshold': 10.0,
'global_norm_spike_count_threshold': 3
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config,
use_skip_data_by_global_norm=True
)

cb_params = Mock()
cb_params.cur_step_num = 5
cb_params.batch_num = 100
cb_params.dataset_sink_mode = False
cb_params.cur_epoch_num = 1
cb_params.get.return_value = None
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 5

# Mock net_outputs with high global_norm
mock_global_norm = Mock()
mock_global_norm.item.return_value = 15.0
# Mock __ge__ to support >= comparison (returns True for high norm)
mock_global_norm.__ge__ = Mock(return_value=True)
cb_params.net_outputs = (Mock(), False, 1024.0, 0.001, mock_global_norm)

# First spike - should log but not raise
monitor.abnormal_global_norm_check(cb_params)
assert monitor.global_norm_spike_count == 1

# Second spike
monitor.abnormal_global_norm_check(cb_params)
assert monitor.global_norm_spike_count == 2

# Third spike - should raise ValueError
with pytest.raises(ValueError) as exc_info:
monitor.abnormal_global_norm_check(cb_params)

assert "consecutive times greater than threshold" in str(exc_info.value)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_abnormal_global_norm_check_reset_count(self, mock_context, *mocks):
"""Test that global_norm_spike_count resets when norm is normal"""

config = {
'check_for_global_norm': False,
'global_norm_spike_threshold': 10.0,
'global_norm_spike_count_threshold': 3
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config,
use_skip_data_by_global_norm=True
)

cb_params = Mock()
cb_params.cur_step_num = 5
cb_params.batch_num = 100
cb_params.dataset_sink_mode = False
cb_params.cur_epoch_num = 1
cb_params.get.return_value = None
cb_params.optimizer = Mock()
cb_params.optimizer.global_step = 5

# High global_norm
mock_high_norm = Mock()
mock_high_norm.item.return_value = 15.0
# Mock __ge__ to return True (high norm >= threshold)
mock_high_norm.__ge__ = Mock(return_value=True)
cb_params.net_outputs = (Mock(), False, 1024.0, 0.001, mock_high_norm)

monitor.abnormal_global_norm_check(cb_params)
assert monitor.global_norm_spike_count == 1

# Normal global_norm - should reset count
mock_normal_norm = Mock()
mock_normal_norm.item.return_value = 5.0
# Mock __ge__ to return False (normal norm < threshold)
mock_normal_norm.__ge__ = Mock(return_value=False)
cb_params.net_outputs = (Mock(), False, 1024.0, 0.001, mock_normal_norm)

monitor.abnormal_global_norm_check(cb_params)
assert monitor.global_norm_spike_count == 0
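# Taken together, the three tests above imply a simple counter: the spike count increments
# while global_norm stays at or above the threshold, resets to zero on a normal step, and
# training aborts once the count reaches the configured limit. A minimal sketch of that
# state machine (inferred from the expected behaviour, not taken from callback.py):
def _spike_counter_sketch(norms, threshold=10.0, count_threshold=3):
    """Return the step index at which training would abort, or None if it never does."""
    count = 0
    for idx, norm in enumerate(norms):
        count = count + 1 if norm >= threshold else 0
        if count >= count_threshold:
            return idx
    return None


assert _spike_counter_sketch([15.0, 15.0, 15.0]) == 2  # three consecutive spikes abort
assert _spike_counter_sketch([15.0, 5.0, 15.0, 15.0]) is None  # a normal step resets the count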


class TestTrainingStateMonitorCalcMethods:
"""Test TrainingStateMonitor calculation methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback._get_weight_norm', return_value=2.5)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_calc_weight_state(self, mock_context, mock_get_weight_norm, *mocks):
"""Test _calc_weight_state method"""

config = {
'weight_state_format': ['log', 'tensorboard']
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

cb_params = Mock()
cb_params.cur_step_num = 5
cb_params.network = Mock()
cb_params.network.network = Mock()

monitor._calc_weight_state(cb_params)

mock_get_weight_norm.assert_called_once()
monitor.tensor_writer.add_scalar.assert_called()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_calc_throughput_linearity(self, mock_context, *mocks):
"""Test _calc_throughput_linearity method"""

config = {
'throughput_baseline': 100.0
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config,
global_batch_size=32
)

cb_params = Mock()
cb_params.cur_step_num = 5

# per_step_seconds = 100ms
monitor._calc_throughput_linearity(cb_params, 100.0)

# throughput = 32 / 8 / (100/1000) = 4 / 0.1 = 40
# linearity = 40 / 100 = 0.4
monitor.tensor_writer.add_scalar.assert_called()
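# The expected values in the comments above are plain arithmetic on the configured numbers.
# A minimal sketch of that calculation, under the assumption that throughput is
# global_batch_size / device_count / step_seconds and linearity is throughput / baseline:
def _throughput_linearity_sketch(global_batch_size, device_count, per_step_ms, baseline):
    """Return (throughput, linearity) for one training step."""
    throughput = global_batch_size / device_count / (per_step_ms / 1000.0)
    return throughput, throughput / baseline


assert _throughput_linearity_sketch(32, 8, 100.0, 100.0) == (40.0, 0.4)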

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_device_local_loss')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_calc_device_local_loss(self, mock_context, mock_get_loss, *mocks):
"""Test _calc_device_local_loss method"""

config = {
'device_local_loss_format': ['log', 'tensorboard']
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Mock device local loss
mock_loss_tensor = Mock()
mock_loss_tensor.asnumpy.return_value = np.array(0.5)
mock_get_loss.return_value = {'lm': mock_loss_tensor}

cb_params = Mock()
cb_params.cur_step_num = 5

monitor._calc_device_local_loss(cb_params)

mock_get_loss.assert_called_once()
monitor.tensor_writer.add_scalar.assert_called()


class TestTrainingStateMonitorStableRank:
"""Test TrainingStateMonitor stable rank calculation"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.ms.runtime.empty_cache')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_do_stable_rank(self, mock_context, mock_empty_cache, *mocks):
"""Test _do_stable_rank method"""

config = {
'stable_rank_config': {
'format': ['log'],
'step_interval': 10
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.train_network = Mock()
cb_params.train_network.network = Mock()
cb_params.train_network.network.trainable_params.return_value = []

with patch.object(monitor, '_calc_stable_rank'):
monitor._do_stable_rank(cb_params)

mock_empty_cache.assert_called_once()
assert monitor.sr_last_print_time == 10

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context')
def test_calc_stable_rank_standalone(self, mock_get_context, *mocks):
"""Test _calc_stable_rank in standalone mode"""

# Mock context to return 1 for pipeline_stages during init,
# then 'stand_alone' for parallel_mode
def context_side_effect(key):
if key == "pipeline_stages":
return 1
if key == "parallel_mode":
return "stand_alone"
return None

mock_get_context.side_effect = context_side_effect

config = {
'stable_rank_config': {
'format': ['log'],
'step_interval': 10
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Mock trainable params
mock_param = Mock()
mock_param.name = 'layer.weight'
mock_param.ndim = 2

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.train_network = Mock()
cb_params.train_network.network = Mock()
cb_params.train_network.network.trainable_params.return_value = [mock_param]

with patch.object(monitor, '_print_stable_rank'):
monitor._calc_stable_rank(cb_params)
monitor._print_stable_rank.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context')
def test_calc_stable_rank_parallel_no_aggregation(self, mock_get_context, *mocks):
"""Test _calc_stable_rank in parallel mode without aggregation"""

# Mock context to return 1 for pipeline_stages during init,
# then 'semi_auto_parallel' for parallel_mode
def context_side_effect(key):
if key == "pipeline_stages":
return 1
if key == "parallel_mode":
return "semi_auto_parallel"
return None

mock_get_context.side_effect = context_side_effect

config = {
'stable_rank_config': {
'format': ['log'],
'step_interval': 10,
'do_aggregation': False
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Mock trainable params
mock_param = Mock()
mock_param.name = 'layer.weight'
mock_param.ndim = 2

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.train_network = Mock()
cb_params.train_network.network = Mock()
cb_params.train_network.network.trainable_params.return_value = [mock_param]

with patch.object(monitor, '_get_remove_redundancy_param_names',
return_value=['layer.weight']):
with patch.object(monitor, '_print_stable_rank'):
monitor._calc_stable_rank(cb_params)
monitor._print_stable_rank.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context')
@patch('mindformers.core.callback.callback.Tensor')
@patch('mindformers.core.callback.callback._get_merged_param_data')
def test_calc_stable_rank_parallel_with_aggregation(self, mock_merged_data, mock_tensor,
mock_get_context, *mocks):
"""Test _calc_stable_rank in parallel mode with aggregation"""

# Mock context to return 1 for pipeline_stages during init,
# then 'semi_auto_parallel' for parallel_mode
def context_side_effect(key):
if key == "pipeline_stages":
return 1
if key == "parallel_mode":
return "semi_auto_parallel"
return None

mock_get_context.side_effect = context_side_effect

config = {
'stable_rank_config': {
'format': ['log'],
'step_interval': 10,
'do_aggregation': True,
'target': ['layer.*']
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Mock trainable params
mock_param = Mock()
mock_param.name = 'layer.weight'
mock_param.ndim = 2
mock_param.data = Mock()
mock_param.data.asnumpy.return_value = np.array([[1, 2], [3, 4]])

cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.train_network = Mock()
cb_params.train_network.network = Mock()
cb_params.train_network.network.trainable_params.return_value = [mock_param]
cb_params.train_network.parameter_layout_dict = {'layer.weight': Mock()}

mock_merged_data.return_value = Mock()

with patch.object(monitor, '_get_remove_redundancy_param_names',
return_value=['layer.weight']):
with patch.object(monitor, '_get_single_params', return_value={0: ['layer.weight']}):
with patch.object(monitor, '_get_redundancy_removed_print', return_value=True):
with patch.object(monitor, '_print_stable_rank'):
monitor._calc_stable_rank(cb_params)
# Should call _print_stable_rank with merged data
monitor._print_stable_rank.assert_called()


class TestTrainingStateMonitorCheckSrTarget:
"""Test TrainingStateMonitor._check_sr_target"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=None)
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_check_sr_target_match(self, mock_context, *mocks):
"""Test _check_sr_target with matching pattern"""

config = {
'stable_rank_config': {
'format': ['log'],
'target': ['layer\\..*', 'attention\\..*']
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Should match
assert monitor._check_sr_target('layer.weight')
assert monitor._check_sr_target('attention.query')

# Should not match
assert not monitor._check_sr_target('other.weight')

# Cache should work - second call should use cached result
assert monitor._check_sr_target('layer.weight')
assert 'layer.weight' in monitor.sr_target_cache


class TestTrainingStateMonitorPrintStableRank:
"""Test TrainingStateMonitor._print_stable_rank and related methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback._get_stable_rank')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_print_stable_rank_2d_tensor(
self, mock_context, mock_get_stable_rank,
mock_tensorboard, mock_group_size):
"""Test _print_stable_rank with 2D tensor"""

mock_get_stable_rank.return_value = (2.5, 3.0)

config = {
'stable_rank_config': {
'format': ['log', 'tensorboard'],
'step_interval': 10
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Create a 2D mock parameter
mock_param = Mock()
mock_param.ndim = 2
mock_param.asnumpy.return_value = np.random.randn(10, 10)

monitor._print_stable_rank('test_layer', mock_param, 10)

mock_get_stable_rank.assert_called_once()

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback._get_stable_rank')
@patch('mindformers.core.callback.callback.context.get_auto_parallel_context', return_value=1)
def test_print_stable_rank_3d_moe_all_mode(self, mock_context, mock_get_stable_rank,
mock_tensorboard, mock_group_size):
"""Test _print_stable_rank with 3D tensor in MoE 'all' mode"""

# Return arrays for multiple experts
mock_get_stable_rank.return_value = (
np.array([2.5, 2.6, 2.7]),
np.array([3.0, 3.1, 3.2])
)

config = {
'stable_rank_config': {
'format': ['log'],
'step_interval': 10,
'moe_show_mode': 'all'
}
}

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

# Create a 3D mock parameter (for MoE)
mock_param = Mock()
mock_param.ndim = 3
mock_param.asnumpy.return_value = np.random.randn(3, 10, 10)

monitor._print_stable_rank('moe_layer', mock_param, 10)

mock_get_stable_rank.assert_called_once()


class TestTrainingStateMonitorDumpMethods:
"""Test TrainingStateMonitor dump methods"""

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.glob.glob', return_value=[])
def test_parse_step_no_files(self, mock_glob, mock_rank, *mocks):
"""Test _parse_step with no dump files"""

config = {
'dump_path': '/tmp/dump',
'finish_pattern': 'finish_*'
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return '/tmp/dump'
if key == "pipeline_stages":
return 1
return None

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

monitor.dump_key = {}
monitor._parse_step()

# Should not add any keys when no files found
assert len(monitor.dump_key) == 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.os.listdir', return_value=[])
def test_dump_data_in_step_empty(self, mock_listdir, mock_rank, *mocks):
"""Test _dump_data_in_step with empty directory"""

config = {
'dump_path': '/tmp/dump'
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return '/tmp/dump'
if key == "pipeline_stages":
return 1
return None

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

monitor.dump_key = {0: 0, 1: 10}
monitor.dump_step = 1
monitor._parse_step = Mock()

# Should not raise error with empty directory
monitor._dump_data_in_step(1)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
def test_dump_local_loss(self, mock_rank, *mocks):
"""Test _dump_local_loss method"""

config = {
'local_loss_format': ['log', 'tensorboard']
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return None
if key == "pipeline_stages":
return 1
return None

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

monitor.dump_step = 10
monitor._output = Mock()

local_losses = {
'main': [np.array(0.5), np.array(0.6)],
'aux': [np.array(0.1), np.array(0.2)]
}

monitor._dump_local_loss(local_losses)

# Should call _output for each loss
assert monitor._output.call_count > 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_rank', return_value=0)
def test_dump_max_attention_logit(self, mock_get_rank, mock_real_rank, *mocks):
"""Test _dump_max_attention_logit method"""

config = {
'max_attention_logit_format': ['log']
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return None
if key == "pipeline_stages":
return 1
return None

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config,
micro_batch_num=2,
tensor_model_parallel_size=1
)

monitor._output = Mock()

# Mock cb_params with optimizer parameters
cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.optimizer = Mock()

# Mock parameter with max_logits_val in name
mock_param = Mock()
mock_param.name = 'layer.max_logits_val'
mock_tensor = Mock()
mock_tensor.asnumpy.return_value = np.array([[1.5, 2.0, 2.5]])
mock_param.value.return_value = mock_tensor

cb_params.optimizer._parameters = [mock_param]

monitor._dump_max_attention_logit(cb_params)

# Should call _output
assert monitor._output.call_count > 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback._get_optimizer_state',
return_value={'layer.weight': 2.5})
def test_dump_optimizer_state(self, mock_get_opt_state, mock_rank, *mocks):
"""Test _dump_optimizer_state method"""

config = {
'optimizer_state_format': ['log']
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return None
if key == "pipeline_stages":
return 1
return None

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

monitor._output = Mock()
monitor._check_param_name = Mock(return_value=True)

# Mock cb_params with optimizer
cb_params = Mock()
cb_params.cur_step_num = 10
cb_params.optimizer = Mock()
cb_params.optimizer.moment1 = Mock()
cb_params.optimizer.moment2 = Mock()

monitor._dump_optimizer_state(cb_params)

# Should call _output for adam_m and adam_v
assert monitor._output.call_count > 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@patch('mindformers.core.callback.callback.get_real_group_size', return_value=8)
@patch('mindformers.core.callback.callback.get_tensorboard_writer', return_value=Mock())
@patch('mindformers.core.callback.callback.get_real_rank', return_value=0)
@patch('mindformers.core.callback.callback.get_output_root_path',
return_value='/tmp/test_output')
@patch('mindformers.core.callback.callback.os.listdir', return_value=['file1.txt', '.nfs123'])
@patch('mindformers.core.callback.callback.os.remove')
def test_clear_dump_path(self, mock_remove, mock_listdir, mock_output_path, mock_rank, *mocks):
"""Test _clear_dump_path method"""

config = {
'dump_path': '/tmp/dump'
}

# Mock both get_auto_parallel_context and context.get_auto_parallel_context
def get_context_side_effect(key):
if key == "dump_local_norm_path":
return '/tmp/dump'
if key == "pipeline_stages":
return 1
return None

# Mock os.path.exists to return appropriate values for different paths
def exists_side_effect(path):
# Return True only for dump_path, False for global_norm_record_path
if '/tmp/dump' in path:
return True
return False

with patch('mindformers.core.callback.callback.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.context.get_auto_parallel_context',
side_effect=get_context_side_effect), \
patch('mindformers.core.callback.callback.os.path.exists',
side_effect=exists_side_effect):

monitor = TrainingStateMonitor(
origin_epochs=10,
dataset_size=100,
config=config
)

monitor._clear_dump_path()

# Should remove file1.txt but not .nfs123
mock_remove.assert_called_once()

if __name__ == '__main__':
pytest.main([__file__, '-v'])

+ 178
- 0
tests/st/test_ut/test_core/test_optim/test_get_op_group.py View File

@@ -0,0 +1,178 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test get op groups info for GPT model."""

from unittest.mock import patch

import mindspore as ms
import pytest

from mindformers import build_context
from mindformers.checkpoint.sharded_tensor import build_sharded_tensor
from mindformers.parallel_core.training_graph.base_models.gpt import gpt_model
from mindformers.parallel_core.training_graph.base_models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, \
get_gpt_mtp_block_spec
from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel, \
compute_repeat_num_and_model_parallel_size, get_op_group_name
from mindformers.parallel_core.transformer_config import TransformerConfig


def build_transformer_config() -> TransformerConfig:
"""Create a minimal transformer config for tensor-parallel unit tests."""
return TransformerConfig(
data_parallel_size=1,
pipeline_model_parallel_size=1,
tensor_model_parallel_size=1,
# model architecture
vocab_size=1024,
position_embedding_type="rope",
num_attention_heads=2,
num_layers=2,
hidden_size=128,
ffn_hidden_size=512,
# moe architecture
num_moe_experts=4,
first_k_dense_replace=1,
mtp_num_layers=1,
add_bias_linear=False,
moe_grouped_gemm=True
)


def build_gpt_model():
"""Construct a GPTModel instance with the default test configuration."""
config = build_transformer_config()
transformer_layer_spec = get_gpt_decoder_block_spec(config)
mtp_block_spec = None
if config.mtp_num_layers is not None:
mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec)
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=config.vocab_size,
max_sequence_length=config.max_position_embeddings,
position_embedding_type=config.position_embedding_type,
rotary_percent=1.0,
rotary_base=config.rotary_base,
rope_scaling=False,
mtp_block_spec=mtp_block_spec
)
return model


def build_sharded_info(local_shape, axis_fragmentations):
"""Helper to create a simple ShardedTensor descriptor."""
return build_sharded_tensor(
param_name="test",
param_dtype=ms.float32,
local_shape=local_shape,
global_shape=local_shape,
axis_fragmentations=axis_fragmentations,
global_offset=(0,) * len(local_shape),
)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gpt_model_sharded_state_dict():
"""
Feature: GPTModel
Description: Test the sharded state dict of GPT model.
Expectation: The sharded state dict contains every trainable parameter and the recorded global shapes match the parameter shapes.
"""
build_context({"use_legacy": False})
model = build_gpt_model()
sharded_state_dict = model.sharded_state_dict()

params = model.trainable_params()
for param in params:
assert param.name in sharded_state_dict
assert param.shape == sharded_state_dict[param.name].global_shape


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"axis_fragmentations, world_size, pipeline_parallel, opt_group_size, local_shape, expected",
[
# case 0: real_op_size == opt_group_size
((1, 1), 12, 2, 4, (12, 4), (4, 1)),
# case 1: real_op_size < opt_group_size
((2, 1), 16, 2, 8, (12, 4), (4, 2)),
# case 2: real_op_size falls back to 1 because the local shape is not evenly divisible (see the note after this test)
((4, 1), 32, 2, 4, (10, 4), (1, 4)),
],
)
def test_compute_repeat_num_and_model_parallel_size(axis_fragmentations, world_size, pipeline_parallel,
opt_group_size, local_shape, expected):
"""
Feature: compute_repeat_num_and_model_parallel_size()
Description: Test computing the repeat number and model parallel size from a sharded tensor descriptor.
Expectation: The returned (repeat_num, model_parallel_size) tuple matches the expected values.
"""
sharded_info = build_sharded_info(local_shape, axis_fragmentations)
assert compute_repeat_num_and_model_parallel_size(
sharded_info,
world_size=world_size,
pp=pipeline_parallel,
op=opt_group_size,
) == expected
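
# A rough sketch of the arithmetic suggested by the expected values in the cases above.
# This is an interpretation for readability, not the actual implementation, and the
# helper name below is hypothetical.
def _sketch_repeat_and_mp(axis_fragmentations, world_size, pp, op, local_shape):
    """Hypothetical mirror of the parametrized expectations above."""
    model_parallel_size = 1
    for fragments in axis_fragmentations:
        model_parallel_size *= fragments
    # repeat (optimizer-parallel) size: ranks per pipeline stage divided by mp, capped at op
    repeat = min(world_size // pp // model_parallel_size, op)
    if local_shape[0] % repeat != 0:
        repeat = 1  # e.g. case 2: 10 % 4 != 0 -> (1, 4)
    return repeat, model_parallel_size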


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_compute_repeat_num_and_model_parallel_size_multiple_axis_error():
"""
Feature: compute_repeat_num_and_model_parallel_size()
Description: Test that a tensor sharded along more than one axis is rejected.
Expectation: A ValueError is raised.
"""
sharded_info = build_sharded_info((8, 8), (2, 2))
with pytest.raises(ValueError):
compute_repeat_num_and_model_parallel_size(sharded_info, world_size=16, pp=1, op=2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@patch("mindformers.parallel_core.training_graph.base_models.gpt.gpt_model.create_communication_group")
def test_get_op_group_name_with_mock(mock_create_group):
"""
Feature: get_op_group_name()
Description: Test get_op_group_name() with a mocked communication-group creator.
Expectation: The returned group name and rank list are correct, and repeated calls reuse the cached group.
"""
mock_create_group.return_value = "mock_group"
gpt_model.OP_GROUP_NAME.clear()

# case 0: model_parallel_size > 1
result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2)
assert result == ("mock_group", [1, 3])
mock_create_group.assert_called_once_with([1, 3])

second_result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2)
assert second_result == result
mock_create_group.assert_called_once()

# case 1: model_parallel_size = 1
result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=1)
assert result == ("mock_group", [2, 3])

# case 2: model_parallel_size = 4
result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=4)
assert result == ("mock_group", [3, 7])

+ 2
- 2
tests/st/test_ut/test_model_mixin.py View File

@@ -480,7 +480,7 @@ class TestTrainModelMixin:
# Create a mock model with get_op_groups_info method
class MockModel:
# pylint: disable=W0613
def get_op_groups_info(self, parameters, op_size, tp_group, op_group):
def get_op_groups_info(self, parameters, op_size):
return f"info_{op_size}"

class TestModel(TrainModelMixin):
@@ -489,7 +489,7 @@ class TestTrainModelMixin:
self.model = MockModel()

mixin = TestModel()
assert mixin.get_op_groups_info(None, 2, None, None) == "info_2"
assert mixin.get_op_groups_info(None, 2) == "info_2"

@pytest.mark.level0
@pytest.mark.platform_x86_cpu


+ 57
- 0
tests/st/test_ut/test_models/test_base_model/test_base_model.py View File

@@ -0,0 +1,57 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test base model
"""
import os
import tempfile
import pytest

from mindformers import MindFormerConfig
from mindformers.models.base_config import BaseConfig
from mindformers.models.base_model import BaseModel

NUM_LAYERS = 1


class TestBaseModel:
"""A test class for testing model.save_pretrained() method."""

def setup_method(self):
"""init test class."""
with tempfile.TemporaryDirectory() as temp_dir_path:
self.path = temp_dir_path

@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_base_model(self):
"""
Feature: Base_model save_pretrained()
Description: Test llama save pretrained
Expectation: Run successfully.
"""
config = BaseConfig(num_layers=NUM_LAYERS)
model = BaseModel(config)
model.save_pretrained(self.path, save_name="mindspore_model")
yaml_path = self.path + "/" + "mindspore_model.yaml"
model_path = self.path + "/" + "mindspore_model.ckpt"
assert os.path.exists(yaml_path)
assert os.path.exists(model_path)

mf_config = MindFormerConfig(yaml_path)
assert mf_config.model.model_config.num_layers == NUM_LAYERS
# pylint: disable=W0212
model._get_config_args(pretrained_model_name_or_dir=self.path)

+ 79
- 0
tests/st/test_ut/test_models/test_build_models/test_build_model.py View File

@@ -0,0 +1,79 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing build_model for mindformers.
"""
import pytest

from mindformers.models.build_model import build_encoder, build_head
from mindformers.tools.register import MindFormerModuleType, MindFormerRegister


class DummyEncoder:
"""A minimal encoder stub registered for the build_encoder tests."""
def __init__(self, **kwargs):
self.kwargs = kwargs


class DummyHead:
"""A minimal head stub registered for the build_head tests."""
def __init__(self, num_classes=10):
self.num_classes = num_classes


MindFormerRegister.register(MindFormerModuleType.ENCODER, "dummy_enc")(DummyEncoder)
MindFormerRegister.register(MindFormerModuleType.HEAD, "dummy_head")(DummyHead)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_encoder():
"""
Feature: build_encoder()
Description: Test build_encoder().
Expectation: Run successfully.
"""
encoder_config = None
class_name = None
encoder = build_encoder(encoder_config, class_name=class_name)
assert encoder is None
encoder = build_encoder(class_name=DummyEncoder)
assert encoder is not None
encoder_config = {"type": DummyEncoder}
encoder = build_encoder(encoder_config)
assert encoder is not None
encoder_config = [{"type": DummyEncoder}, {"type": DummyEncoder}]
encoder = build_encoder(encoder_config)
assert encoder is not None


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_head():
"""
Feature: build_head()
Description: Test build_head().
Expectation: Run successfully.
"""
head_config = None
class_name = None
head = build_head(head_config, class_name=class_name)
assert head is None
head = build_head(class_name=DummyHead)
assert head is not None
head_config = {"type": DummyHead}
head = build_head(head_config)
assert head is not None
head_config = [{"type": DummyHead}, {"type": DummyHead}]
head = build_head(head_config)
assert head is not None

+ 147
- 0
tests/st/test_ut/test_models/test_build_models/test_utils.py View File

@@ -0,0 +1,147 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing models utils for mindformers.
"""
import unittest
from unittest.mock import MagicMock
import pytest

from mindformers.models.utils import check_use_3d_tensor_parallel_valid


# Mock helper functions and constants
class ParallelMode:
AUTO_PARALLEL = "auto_parallel"


def check_fine_grain_interleave_valid(fine_grain):
return fine_grain is not None and fine_grain > 1


class TestCheckUse3DTensorParallelValid(unittest.TestCase):
"""A class for testing CheckUse3DTensorParallelValid."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_disabled_use_3d_tp(self):
"""Branch: use_3d_tensor_parallel = False → return False"""
config = MagicMock()
config.use_3d_tensor_parallel = False
result = check_use_3d_tensor_parallel_valid(config)
self.assertFalse(result)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_config_none(self):
"""Branch: config is None → return False"""
result = check_use_3d_tensor_parallel_valid(None)
self.assertFalse(result)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_config_none(self):
"""Branch: config.parallel_config is None → return False"""
config = MagicMock()
config.use_3d_tensor_parallel = True
config.parallel_config = None
result = check_use_3d_tensor_parallel_valid(config)
self.assertFalse(result)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_use_flash_attention_false(self):
"""Raise: use_flash_attention must be True"""
config = self._create_valid_config()
config.use_flash_attention = False
with self.assertRaises(ValueError, msg="use_flash_attention must be True"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ulysses_cp_num_gt_1(self):
"""Raise: ulysses cp > 1 not supported"""
config = self._create_valid_config()
config.parallel_config.get_ulysses_cp_num.return_value = 2
with self.assertRaises(ValueError, msg="ulysses cp must be 1"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_moe_enabled(self):
"""Raise: MoE not supported"""
config = self._create_valid_config()
moe_mock = MagicMock()
moe_mock.expert_num = 8
config.moe_config = moe_mock
with self.assertRaises(ValueError, msg="MoE not supported"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_parallel_false(self):
"""Raise: use_seq_parallel must be True"""
config = self._create_valid_config()
config.parallel_config.use_seq_parallel = False
with self.assertRaises(ValueError, msg="use_seq_parallel must be True"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_fine_grain_interleave_invalid(self):
"""Raise: fine_grain_interleave not supported"""
config = self._create_valid_config()
config.fine_grain_interleave = 2 # triggers True in check_fine_grain_interleave_valid
with self.assertRaises(ValueError, msg="fine_grain_interleave not supported"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tp_product_mismatch(self):
"""Raise: tp_x * tp_y * tp_z != model_parallel"""
config = self._create_valid_config()
config.tp_x = 2
config.tp_y = 2
config.tp_z = 2
config.parallel_config.model_parallel = 7 # 2*2*2=8 ≠ 7
with self.assertRaises(ValueError, msg="tp product mismatch"):
check_use_3d_tensor_parallel_valid(config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def _create_valid_config(self):
"""Helper to create a config that passes initial checks"""
config = MagicMock()
config.use_3d_tensor_parallel = True
config.use_flash_attention = True
config.fine_grain_interleave = None # valid
config.moe_config = None

parallel_config = MagicMock()
parallel_config.get_ulysses_cp_num.return_value = 1
parallel_config.use_seq_parallel = True
parallel_config.model_parallel = 4
config.parallel_config = parallel_config

return config

+ 116
- 0
tests/st/test_ut/test_tools/test_check_rules.py View File

@@ -0,0 +1,116 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing tools check_rules for mindformers.
"""
import pytest
from mindformers.tools.check_rules import (
_restore_net_type,
_rule_fa_only_for_train,
_check_keyword_gen_dataset,
_check_context_parallel_algo_valid,
_check_recompute
)
from mindformers import MindFormerConfig


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_restore_net_type():
"""
Feature: _restore_net_type()
Description: Test restoring the network type from compute_dtype and param_init_type in the config.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.compute_dtype', 'bfloat16')
config.set_value('model.model_config.param_init_type', 'float32')
_restore_net_type(config=config)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_rule_fa_only_train():
"""
Feature: _rule_fa_only_for_train()
Description: Test that use_flash_attention is accepted when the mode is train.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.use_flash_attention', True)
_rule_fa_only_for_train(config=config, mode="train")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_context_parallel_algo_valid():
"""
Feature: _check_context_parallel_algo_valid()
Description: Test that cp * mp exceeding the attention head count is rejected for ulysses_cp.
Expectation: A ValueError is raised.
"""
config = MindFormerConfig()
config.set_value('model.model_config.n_kv_heads', None)
config.set_value('model.model_config.multi_query_group_num', 2)
config.set_value('model.model_config.num_heads', None)
config.set_value('model.model_config.num_attention_heads', 32)
config.set_value('parallel_config.context_parallel_algo.value', "ulysses_cp")
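# 32 attention heads but cp * mp = 8 * 8 = 64, so the ulysses check below should reject it.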
with pytest.raises(ValueError, match=r"cp \* mp <= attention head"):
_check_context_parallel_algo_valid(config=config, cp=8, mp=8)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_keyword_gen_dataset():
"""
Feature: _check_keyword_gen_dataset()
Description: Test the keyword-generation dataset length checks against seq_length in train mode.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.seq_length', 101)
config.set_value('do_eval', False)
config.set_value('metric', [{"type": "ADGENMetric"}, {"type": "PerplexityMetric"}])

# train dataset
config.set_value('train_dataset.data_loader.type', "ADGenDataLoader")
config.set_value('train_dataset.max_source_length', 50)
config.set_value('train_dataset.max_target_length', 50)

# eval dataset
config.set_value('eval_dataset.data_loader.type', "ADGenDataLoader")
config.set_value('eval_dataset.data_loader.phase', "eval")
config.set_value('eval_dataset.max_source_length', 101)
config.set_value('eval_dataset.max_target_length', 20)
config.set_value('eval_dataset_task.dataset_config.data_loader.phase', "eval")

_check_keyword_gen_dataset(config=config, mode='train')

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_recompute():
"""
Feature: _check_recompute()
Description: Test the recompute option checks when swap and select recompute are both enabled.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value("swap_config.swap", True)
config.set_value("recompute_config.recompute", True)
config.set_value("recompute_config.select_recompute", True)
config.set_value("recompute_config.select_comm_recompute", True)
_check_recompute(config=config)

+ 482
- 0
tests/st/test_ut/test_tools/test_utils/test_utils.py View File

@@ -0,0 +1,482 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test utils"""
import os
import stat
import tempfile
from pathlib import Path
from unittest import mock
import pytest

from mindspore import context

from mindformers.tools.utils import (
check_in_modelarts,
get_output_root_path,
is_version_le,
is_version_ge,
get_epoch_and_step_from_ckpt_name,
str2bool,
parse_value,
set_safe_mode_for_file_or_dir,
PARALLEL_MODE,
MODE,
Validator,
check_obs_url,
get_rank_id_from_ckpt_name,
replace_rank_id_in_ckpt_name,
get_ascend_log_path,
calculate_pipeline_stage,
divide,
is_pynative,
create_and_write_info_to_txt,
check_ckpt_file_name,
get_times_epoch_and_step_from_ckpt_name,
is_last_pipeline_stage,
get_dp_from_dataset_strategy,
check_shared_disk,
replace_tk_to_mindpet,
get_num_nodes_devices,
)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_in_modelarts_true():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {"MA_LOG_DIR": "/tmp"}):
assert check_in_modelarts() is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_in_modelarts_false():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, clear=True):
assert check_in_modelarts() is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_output_root_path_default():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {}, clear=True):
path = get_output_root_path()
assert path == os.path.realpath("./output")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_output_root_path_env():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {"LOCAL_DEFAULT_PATH": "/custom/output"}):
path = get_output_root_path()
assert path == "/custom/output"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_le():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert is_version_le("1.8.1", "1.11.0") is True
assert is_version_le("1.11.0", "1.11.0") is True
assert is_version_le("2.0.0", "1.11.0") is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_ge():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert is_version_ge("1.11.0", "1.8.1") is True
assert is_version_ge("1.11.0", "1.11.0") is True
assert is_version_ge("1.8.1", "1.11.0") is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_epoch_and_step_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
epoch, step = get_epoch_and_step_from_ckpt_name("model-5_100.ckpt")
assert epoch == 5
assert step == 100
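# The name appears to follow a '<prefix>-<epoch>_<step>.ckpt' convention, which is how 5 and 100 are parsed out.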

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_epoch_and_step_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(ValueError, match="Can't match epoch and step"):
get_epoch_and_step_from_ckpt_name("invalid_name.txt")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_str2bool():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert str2bool("True") is True
assert str2bool("False") is False

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_str2bool_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(Exception, match="Invalid Bool Value"):
str2bool("maybe")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parse_value():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert parse_value("123") == 123
assert parse_value("3.14") == 3.14
assert parse_value("True") is True
assert parse_value('{"a": 1}') == {"a": 1}
assert parse_value("hello") == "hello"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_set_safe_mode_for_file_or_dir():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
file_path = Path(tmpdir) / "test.txt"
dir_path = Path(tmpdir) / "subdir"
dir_path.mkdir()

file_path.write_text("test")
set_safe_mode_for_file_or_dir([str(file_path), str(dir_path)])

assert (file_path.stat().st_mode & stat.S_IRUSR) != 0
assert (file_path.stat().st_mode & stat.S_IWUSR) != 0
assert (dir_path.stat().st_mode & stat.S_IXUSR) != 0

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_mode_mapping():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert PARALLEL_MODE["DATA_PARALLEL"] == context.ParallelMode.DATA_PARALLEL
assert PARALLEL_MODE[0] == context.ParallelMode.DATA_PARALLEL

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mode_mapping():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert MODE["GRAPH_MODE"] == context.GRAPH_MODE
assert MODE[0] == context.GRAPH_MODE

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_validator_check_type():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
Validator.check_type(42, int)
with pytest.raises(TypeError):
Validator.check_type("42", int)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_obs_url_valid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
check_obs_url("obs://bucket/path")
check_obs_url("s3://bucket/path")

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_obs_url_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(TypeError, match="should be start with obs:// or s3://"):
check_obs_url("/local/path")

with pytest.raises(TypeError, match="type should be a str"):
check_obs_url(123)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_rank_id_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
rank = get_rank_id_from_ckpt_name("llama_7b_rank_3-5_100.ckpt")
assert rank == 3

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_replace_rank_id_in_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
new_name = replace_rank_id_in_ckpt_name("model_rank_2-1_50.ckpt", 5)
assert new_name == "model_rank_5-1_50.ckpt"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ascend_log_path():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
os.environ['ASCEND_PROCESS_LOG_PATH'] = '/home/log'
assert get_ascend_log_path() == '/home/log'

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_calculate_pipeline_stage():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
layers_per_stage = [4, 4]
model_layers = [6]
input_layers_per_stage = layers_per_stage.copy()
result = calculate_pipeline_stage(input_layers_per_stage, model_layers)
expected = [
{
"offset": [1, -1], # [4-3, 2-3]
"start_stage": 0,
"stage_num": 2
}
]
assert result == expected

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_divide():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
res = divide(10, 2)
assert res == 5

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_pynative():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
os.environ['ENFORCE_EAGER'] = 'true'
res = is_pynative()
assert res

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_create_and_write_info_to_txt():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
txt_path = os.path.join(tmpdir, "output.txt")
info = "Hello, world!"

create_and_write_info_to_txt(txt_path, info)

assert os.path.exists(txt_path)
with open(txt_path, 'r', encoding='utf-8') as f:
content = f.read()
assert content == "Hello, world!"

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_ckpt_file_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_name = "llama_0-3_1.ckpt"
res = check_ckpt_file_name(ckpt_name)
assert res
ckpt_name = "dsadsdsadasd"
res = check_ckpt_file_name(ckpt_name)
assert not res

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_times_epoch_and_step_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_name = "llama_0-3_1.ckpt"
res = check_ckpt_file_name(ckpt_name)
if res:
times, epoch, step = get_times_epoch_and_step_from_ckpt_name(ckpt_name)
assert times == 0
assert epoch == 3
assert step == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_last_pipeline_stage():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch("mindformers.tools.utils.get_real_group_size", return_value=8), \
mock.patch("mindformers.tools.utils.get_real_rank", return_value=6), \
mock.patch("mindspore.get_auto_parallel_context", return_value=2):
assert is_last_pipeline_stage() is True

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_dp_from_dataset_strategy():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch("mindspore.get_auto_parallel_context", return_value=[[2, 1]]):
dp = get_dp_from_dataset_strategy()
assert dp == 2

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_shared_disk():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
disk_path = "/home/workspace"
res = check_shared_disk(disk_path=disk_path)
assert not res

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_replace_tk_to_mindpet():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_dict = {"tk_delta": 1}
new_ckpt = replace_tk_to_mindpet(ckpt_dict)
assert new_ckpt['mindpet_delta'] == 1

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_num_nodes_devices():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
rank_size = 7
with mock.patch("mindformers.tools.utils.get_device_num_per_node", return_value=8):
num_nodes, num_devices = get_num_nodes_devices(rank_size=rank_size)
assert num_nodes == 1
assert num_devices == rank_size

+ 3
- 17
tests/st/test_ut/test_transformer_apis.py View File

@@ -22,14 +22,14 @@ from mindspore.ops import operations as ops
from mindspore.common.api import _cell_graph_executor

from mindformers.core import CrossEntropyLoss
from mindformers.modules import FeedForward, FixedSparseAttention, LowerTriangularMaskWithDynamic
from mindformers.modules import FixedSparseAttention, LowerTriangularMaskWithDynamic


class MyActivation(mindspore.nn.Cell):
"""An example of custom activation"""

def __init__(self):
super(MyActivation, self).__init__()
super().__init__()
self.add = ops.Add()

def construct(self, x):
@@ -43,27 +43,13 @@ class MyActivationNoShard(mindspore.nn.Cell):
"""An example of custom activation without shard"""

def __init__(self):
super(MyActivationNoShard, self).__init__()
super().__init__()
self.add = ops.Add()

def construct(self, x):
return self.add(x, 0.1)


def test_feedforward():
"""
Feature: Feedforward
Description: Test Feedforward module
Expectation: No exception
"""
model = FeedForward(hidden_size=15,
ffn_hidden_size=30,
dropout_rate=0.1,
hidden_act='relu')
tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
_cell_graph_executor.compile(model, tensor)


def test_cross_entropy_loss():
"""
Feature: CrossEntropyLoss


+ 42
- 0
tests/st/test_ut/test_utils/test_import_utils.py View File

@@ -0,0 +1,42 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing import utils for mindformers.
"""
import os
import tempfile
import pytest

from mindformers.utils.import_utils import direct_mindformers_import


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_direct_mindformers_import_success():
"""
Feature: Import utils
Description: Test direct_mindformers_import.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
init_file = os.path.join(tmpdir, "__init__.py")
with open(init_file, "w", encoding='utf-8') as f:
f.write('''
def hello():
return "Hello from mocked mindformers!"
''')
module = direct_mindformers_import(tmpdir)
assert module is not None
