@@ -37,6 +37,13 @@ try:
from mindspore._checkparam import Validator
except ImportError:
import mindspore._checkparam as Validator
from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP
from research.deepseek3.utils import convert_model_config
from research.deepseek3.infer.norm import RMSNorm
from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import lazy_inline, check_fine_grain_interleave_valid, predict_lazy_inline,\
jit
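# A hedged usage sketch for the Validator shim above: both import paths are
# expected to expose the same check_* helpers, so call sites stay identical
# across MindSpore versions. check_positive_int is one such helper; the
# value 16 is illustrative, not taken from the model config.
_n_head_checked = Validator.check_positive_int(16, "n_head")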
@@ -59,13 +66,8 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import (gather
reduce_scatter_to_model_parallel_region,
scatter_to_model_parallel_region)
from mindformers.version_control import is_910b
from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP
from research.deepseek3.utils import convert_model_config
from research.deepseek3.infer.norm import RMSNorm
from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
__all__ = ['InferenceDeepseekV3ForCausalLM', 'DeepseekV3Model']
@@ -249,7 +251,7 @@ class MLAInferAttention(nn.Cell):
prefill_head_dim=None,
config: DeepseekV3Config = None
):
        super(MLAInferAttention, self).__init__()
super().__init__()
self.n_head = n_head
self.head_dim = head_dim
self.n_kv_head = n_kv_head
@@ -438,14 +440,13 @@ class DeepseekV3Attention(nn.Cell):
raise ValueError("For 'DeepseekV3Attention', the use_flash_attention must be enabled.")
if self.hidden_size % self.n_head != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
"of 'n_head', but got the hidden_size is {} and the n_head is {}. "
                             .format(self.hidden_size, self.n_head))
raise ValueError(f "For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
f"of 'n_head', but got the hidden_size is {self.hidden_size} and "
f"the n_head is {self.n_head}." )
if self.n_kv_head % parallel_config.model_parallel != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of "
"'parallel_config.model_parallel', but got the n_kv_head is {} "
"and the parallel_config.model_parallel is {}."
.format(self.n_kv_head, parallel_config.model_parallel))
raise ValueError(f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of "
f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} "
f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.")
self.shape = P.Shape()
self.cast = P.Cast()
if self.q_lora_rank == 0:
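# What the two divisibility checks above guarantee: attention heads (and KV
# heads) shard evenly across tensor-parallel ranks. A standalone sketch with
# illustrative numbers (hidden_size=7168 and n_head=128 match DeepSeek-V3
# defaults; model_parallel=16 is an assumption):
_hidden_size, _n_head, _mp = 7168, 128, 16
assert _hidden_size % _n_head == 0 and _n_head % _mp == 0
_head_dim = _hidden_size // _n_head   # 56 per head
_n_local_heads = _n_head // _mp       # 8 heads owned by each TP rank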
@@ -572,12 +573,13 @@ class DeepseekV3Attention(nn.Cell):
if self.q_lora_rank == 0:
q = self.q_proj(x)
latent_kv_all = self.kv2l(x)
latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_kv, k_pe = ops.function.array_func.split_ext(
latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else:
if self.qkv_concat:
qkv2l = self.qkv2l(x)
q, latent_kv, k_pe = mint.split(qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
q, latent_kv, k_pe = ops.function.array_func.split_ext(
qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
norm_q = self.lq_norm(q)
q = self.l2q_proj(norm_q)
else:
@@ -585,10 +587,11 @@ class DeepseekV3Attention(nn.Cell):
norm_q = self.lq_norm(q)
q = self.l2q_proj(norm_q)
latent_kv_all = self.kv2l(x)
latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_kv, k_pe = ops.function.array_func.split_ext(
latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
q = self.reshape(q, (-1, self.n_local_heads, self.q_head_dim))
        q_nope, q_pe = mint.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_nope, q_pe = ops.function.array_func.split_ext(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# (T, kv_lora_rank)
i_kv = self.lkv_norm(latent_kv)
q_pe = self.reshape(q_pe, (-1, self.n_local_heads * self.qk_rope_head_dim))
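# Hedged sketch of the split these hunks migrate from mint.split to the
# internal ops.function.array_func.split_ext; the call sites above imply it
# keeps torch-style section/dim semantics. Sizes follow DeepSeek-V3 defaults
# (kv_lora_rank=512, qk_rope_head_dim=64); the token count 4 is arbitrary.
import numpy as np
import mindspore as ms
from mindspore import ops
_latent_kv_all = ms.Tensor(np.zeros((4, 512 + 64)), ms.float16)
_latent_kv, _k_pe = ops.function.array_func.split_ext(
    _latent_kv_all, [512, 64], dim=-1)
assert _latent_kv.shape == (4, 512) and _k_pe.shape == (4, 64)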
@@ -663,7 +666,7 @@ class DeepseekV3ParallelMLP(ParallelMLP):
# [B, S, H] -> [B, S, ffn_H]
if self.ffn_concat:
            gate_hidden_out = self.w_gate_hidden(x)  # dp,1 -> dp, mp
            gate, hidden = mint.split(gate_hidden_out,
            gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x) # dp,1 -> dp, mp
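# Shape bookkeeping for the fused branch above: w_gate_hidden packs the gate
# and hidden projections into a single matmul, and the split recovers two
# equal halves of ffn_hidden_size_per_partition each (2048 below is an
# assumed partition size; T=4 is arbitrary):
#   x: [T, H] -> w_gate_hidden -> [T, 2 * ffn_pp] -> split -> 2 x [T, ffn_pp]
_ffn_pp = 2048
_fused_shape = (4, 2 * _ffn_pp)             # fused matmul output
_gate_shape = _hidden_shape = (4, _ffn_pp)  # after the dim=-1 split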
@@ -692,7 +695,7 @@ class DeepseekV3MoE(Cell):
"""
def __init__(self, config):
        super(DeepseekV3MoE, self).__init__()
super().__init__()
self.config = config
self.parallel_config = config.parallel_config
self.moe_config = config.moe_config
@@ -766,7 +769,7 @@ class DeepseekV3MoEWithMicroBatch(DeepseekV3MoE):
"""
def __init__(self, config):
        super(DeepseekV3MoEWithMicroBatch, self).__init__(config=config)
super().__init__(config=config)
self.moe_tp_size = get_moe_tp_world_size()
self.moe_ep_size = get_moe_ep_world_size()
self.ep_rank_id = get_rank() // self.moe_tp_size
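# Rank layout sketch for the MoE groups above: ep_rank_id = rank // moe_tp_size,
# so consecutive global ranks share a tensor-parallel group within one
# expert-parallel rank. The tp_rank complement (rank % moe_tp_size) is
# conventional and assumed here; sizes (8 devices, moe_tp_size=4) are
# illustrative.
_moe_tp_size = 4
_layout = [(r // _moe_tp_size, r % _moe_tp_size) for r in range(8)]
# _layout -> [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]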
@@ -846,7 +849,7 @@ class AttentionReduceScatter(Cell):
"""
def __init__(self, config):
        super(AttentionReduceScatter, self).__init__()
super().__init__()
self.config = config
self.compute_dtype = config.compute_dtype
self.hidden_size = config.hidden_size
@@ -1439,7 +1442,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
@lazy_inline
def __init__(self, config: DeepseekV3Config = None):
        super(InferenceDeepseekV3ForCausalLM, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.config = convert_model_config(config)
@@ -1499,7 +1502,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
self.load_checkpoint(config)
self.predict_run_mode = get_predict_run_mode()
logger.info("Predict run mode:{}".format(self.predict_run_mode) )
logger.info(f"Predict run mode:{self.predict_run_mode}" )
self.return_hidden_states = config.return_hidden_states
# pylint: disable=W0613
@@ -1602,7 +1605,6 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
if dp_size == 1 or q_seq_len is None:
return model_inputs
from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
tokens_len_per_dp = q_seq_len.sum().reshape(-1)
tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp)
tokens_len_per_dp = tokens_len_per_dp.asnumpy()
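# Single-process illustration of the exchange above, with numpy standing in
# for AllGather over the data-parallel group (dp_size=2 and the counts are
# assumed): each DP rank contributes its local q_seq_len sum, and the gather
# leaves every rank holding all per-rank token counts in rank order.
import numpy as np
_per_rank_counts = [np.array([7]), np.array([5])]      # rank 0, rank 1
_tokens_len_per_dp = np.concatenate(_per_rank_counts)  # AllGather result
# _tokens_len_per_dp -> array([7, 5])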
@@ -1749,7 +1751,7 @@ class DeepseekV3MTPLayer(nn.Cell):
"""
def __init__(self, config: DeepseekV3Config = None):
        super(DeepseekV3MTPLayer, self).__init__()
super().__init__()
self.enorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
compute_type=config.layernorm_compute_type)
self.hnorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
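# RMSNorm, as used for enorm/hnorm here, normalizes by the root mean square
# without mean-centering: y = x / sqrt(mean(x^2) + eps) * weight. A numpy
# reference sketch (eps mirrors config.rms_norm_eps; this is not the
# project's implementation in research.deepseek3.infer.norm):
import numpy as np
def _rms_norm(x, weight, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight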
@@ -1826,7 +1828,7 @@ class DeepseekV3MTPModel(DeepseekV3PreTrainedModel):
"""
def __init__(self, config: DeepseekV3Config = None):
        super(DeepseekV3MTPModel, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
self.dtype = config.compute_dtype
self.use_past = config.use_past
self.is_first_iteration = True
@@ -1925,7 +1927,7 @@ class InferenceDeepseekV3MTPForCausalLM(DeepseekV3PreTrainedModel):
"""
def __init__(self, config: DeepseekV3Config = None):
        super(InferenceDeepseekV3MTPForCausalLM, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
self.dtype = config.compute_dtype
self.config = convert_model_config(config)
self.parallel_config = self.config.parallel_config