2 Commits

9 changed files with 91 additions and 85 deletions
  1. mindformers/parallel_core/inference/transformer/attention.py (+4, -5)
  2. mindformers/parallel_core/inference/transformer/mlp.py (+2, -2)
  3. mindformers/parallel_core/inference/transformer/moe/experts.py (+2, -2)
  4. mindformers/parallel_core/inference/transformer/multi_latent_attention.py (+11, -8)
  5. research/deepseek3/deepseek3_model_infer.py (+31, -29)
  6. research/deepseek3/infer/transformer.py (+3, -3)
  7. research/deepseek3/moe.py (+13, -13)
  8. research/llama3_1/infer/transformer.py (+12, -12)
  9. research/qwen2_5/infer/transformer.py (+13, -11)

mindformers/parallel_core/inference/transformer/attention.py (+4, -5)

@@ -24,7 +24,7 @@ from dataclasses import dataclass
import math
from typing import Union, Optional

from mindspore import mint, nn, ops
from mindspore import nn, ops

from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.transformer.identity_op import IdentityOp
@@ -146,13 +146,12 @@ class Attention(nn.Cell):
self.tp_group_size = self.tp.size

self.num_attention_heads_per_partition = divide(self.num_heads, self.tp_group_size)
self.use_gqa = (self.num_heads != self.num_query_groups)
self.use_gqa = self.num_heads != self.num_query_groups

if self.use_gqa:
self._check_gqa_valid()
# Note: Special handling when kv heads is less than tp size
if self.num_query_groups < self.tp_group_size:
self.num_query_groups = self.tp_group_size
self.num_query_groups = max(self.num_query_groups, self.tp_group_size)
self.num_query_groups_per_partition = divide(self.num_query_groups, self.tp_group_size)
self.repeat_num = divide(self.num_heads, self.num_query_groups)
else:
@@ -370,7 +369,7 @@ class SelfAttention(Attention):

def get_query_key_value_tensors(self, hidden_states):
qkv = self.cast(self.linear_qkv(hidden_states), self.compute_dtype)
query, key, value = mint.split(qkv,
query, key, value = ops.function.array_func.split_ext(qkv,
(self.hidden_size_per_partition,
self.kv_hidden_size_per_partition,
self.kv_hidden_size_per_partition), -1)
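
The change repeated in this file (and in most of the files below) swaps mint.split for ops.function.array_func.split_ext with the same split sizes and axis. A minimal sketch of the pattern, assuming a MindSpore build that exposes ops.function.array_func.split_ext with mint.split-style (tensor, split_sizes, dim) semantics; the partition widths are illustrative placeholders, not the real model configuration:

# Sketch only: mirrors the mint.split -> split_ext swap from this diff.
# Assumptions: split_ext takes (tensor, split_size_or_sections, dim) like mint.split;
# the hidden/kv widths here are made-up example values.
import numpy as np
import mindspore as ms
from mindspore import ops

hidden_size_per_partition = 128    # hypothetical per-rank query width
kv_hidden_size_per_partition = 32  # hypothetical per-rank key/value width

qkv = ms.Tensor(np.zeros((4, hidden_size_per_partition + 2 * kv_hidden_size_per_partition),
                         dtype=np.float16))

# Before: query, key, value = mint.split(qkv, (hidden_size_per_partition,
#                                               kv_hidden_size_per_partition,
#                                               kv_hidden_size_per_partition), -1)
query, key, value = ops.function.array_func.split_ext(
    qkv,
    (hidden_size_per_partition, kv_hidden_size_per_partition, kv_hidden_size_per_partition),
    -1)
print(query.shape, key.shape, value.shape)  # expect (4, 128), (4, 32), (4, 32)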


mindformers/parallel_core/inference/transformer/mlp.py (+2, -2)

@@ -21,7 +21,7 @@ __all__ = [
from dataclasses import dataclass
from typing import Union, Optional

from mindspore import nn, mint
from mindspore import nn, mint, ops

from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.transformer_config import TransformerConfig
@@ -157,7 +157,7 @@ class MLP(nn.Cell):
intermediate_parallel = self.linear_fc1(hidden_states)

if self.config.gated_linear_unit:
gate, hidden = mint.split(intermediate_parallel,
gate, hidden = ops.function.array_func.split_ext(intermediate_parallel,
(self.ffn_hidden_size_per_partition,
self.ffn_hidden_size_per_partition), -1)
gate = self.activation_func(gate) if self.activation_type else gate


mindformers/parallel_core/inference/transformer/moe/experts.py (+2, -2)

@@ -18,7 +18,7 @@ __all__ = ["GroupedMLP"]

from typing import Optional

from mindspore import mint, nn
from mindspore import mint, nn, ops

from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.utils.spec_utils import build_module
@@ -125,7 +125,7 @@ class GroupedMLP(nn.Cell):
intermediate_parallel = self.linear_fc1(hidden_states, group_list=group_list)

if self.config.gated_linear_unit:
gate, hidden = mint.split(intermediate_parallel,
gate, hidden = ops.function.array_func.split_ext(intermediate_parallel,
(self.ffn_hidden_size_per_partition,
self.ffn_hidden_size_per_partition), -1)
gate = self.activation_func(gate) if self.activation_type else gate


mindformers/parallel_core/inference/transformer/multi_latent_attention.py (+11, -8)

@@ -363,7 +363,7 @@ class MLASelfAttention(MultiLatentAttention):
Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
q_absorb, out_absorb = mint.split(self.linear_kv_up_proj.weight,
q_absorb, out_absorb = ops.function.array_func.split_ext(self.linear_kv_up_proj.weight,
[self.num_attention_heads_per_partition * self.config.qk_head_dim,
self.num_attention_heads_per_partition * self.config.v_head_dim], -2)
self.q_absorb = q_absorb.reshape(self.num_attention_heads_per_partition,
@@ -384,7 +384,7 @@ class MLASelfAttention(MultiLatentAttention):
hidden_states = self.input_layernorm(hidden_states)
if self.config.q_lora_rank is not None:
qkv = self.linear_qkv_down_proj(hidden_states)
kv_compressed, k_pos_emb, q_compressed = mint.split(qkv,
kv_compressed, k_pos_emb, q_compressed = ops.function.array_func.split_ext(qkv,
[self.config.kv_lora_rank,
self.config.qk_pos_emb_head_dim,
self.config.q_lora_rank],
@@ -404,7 +404,7 @@ class MLASelfAttention(MultiLatentAttention):
if kv_combined.shape[-1] != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:
# the shape of kv_combined is [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
kv_combined = gather_from_model_parallel_region(q_compressed, self.tp)
kv_compressed, k_pos_emb = mint.split(
kv_compressed, k_pos_emb = ops.function.array_func.split_ext(
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)
# the shape of q is [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)]
@@ -413,7 +413,8 @@ class MLASelfAttention(MultiLatentAttention):
# the shape of q is [num_tokens, n, q_head_dim]
q = q.reshape(*q.shape[:-1], self.num_attention_heads_per_partition, self.q_head_dim)
# the shape of q_no_pe is [num_tokens, n, qk_head_dim], q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
q_no_pe, q_pos_emb = mint.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)
q_no_pe, q_pos_emb = ops.function.array_func.split_ext(
q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)
# the shape of kv_compressed is [num_tokens, kv_lora_rank]
kv_compressed = self.kv_layernorm(kv_compressed)

@@ -443,8 +444,9 @@ class MLASelfAttention(MultiLatentAttention):

# the shape of k_no_pe is [num_tokens, qk_head_dim * self.kv_num_heads_per_partition],
# the shape of value is [num_tokens, v_head_dim * self.kv_num_heads_per_partition]
k_no_pe, value = mint.split(kv, [self.config.qk_head_dim * self.kv_num_heads_per_partition,
self.config.v_head_dim * self.kv_num_heads_per_partition], dim=-1)
k_no_pe, value = ops.function.array_func.split_ext(
kv, [self.config.qk_head_dim * self.kv_num_heads_per_partition,
self.config.v_head_dim * self.kv_num_heads_per_partition], dim=-1)
k_no_pe = k_no_pe.reshape(-1, self.kv_num_heads_per_partition, self.config.qk_head_dim)

# the shape of value_states is [num_tokens, n, v_head_dim]
@@ -531,7 +533,7 @@ class FusedMLASelfAttention(MLASelfAttention):
self.is_modelslim = quant_config.is_modelslim
self.fa3_quant = quant_config.fa3_quant
self.fa3_quant_layer = quant_config.fa3_quant_layer
self.is_fa3_quant_layer = (layer_number - 1) in self.fa3_quant_layer # layer_number start from 1
self.is_fa3_quant_layer = layer_number - 1 in self.fa3_quant_layer # layer_number start from 1
self.input_layernorm_weight = None
self.qkv_down_proj_input_scale = None
self.q_layernorm_weight = None
@@ -542,6 +544,7 @@ class FusedMLASelfAttention(MLASelfAttention):
self.q_up_proj_input_offset = None
self.input_format = 1 if self.fa3_quant else 0
self.use_ringmla = use_ms_custom_ops() and get_tensor_model_parallel_world_size() < 16
# pylint: disable=C0415
import ms_custom_ops
self.ms_custom_ops = ms_custom_ops
self.scale_value = 1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) \
@@ -793,7 +796,7 @@ class FusedMLASelfAttention(MLASelfAttention):
k_cache = self.transpose(key_cache.reshape(-1, self.config.kv_lora_rank // 32, \
self.config.block_size, 32), (0, 2, 1, 3)).reshape( \
-1, self.config.block_size, self.config.kv_lora_rank)
k_cache = (self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale)
k_cache = self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale
else:
k_cache = self.ms_custom_ops.trans_data(key_cache, transdata_type=0)
v_cache = self.ms_custom_ops.trans_data(value_cache, transdata_type=0)


research/deepseek3/deepseek3_model_infer.py (+31, -29)

@@ -37,6 +37,13 @@ try:
from mindspore._checkparam import Validator
except ImportError:
import mindspore._checkparam as Validator

from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP
from research.deepseek3.utils import convert_model_config
from research.deepseek3.infer.norm import RMSNorm
from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import lazy_inline, check_fine_grain_interleave_valid, predict_lazy_inline,\
jit
@@ -59,13 +66,8 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import (gather
reduce_scatter_to_model_parallel_region,
scatter_to_model_parallel_region)
from mindformers.version_control import is_910b
from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group

from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP
from research.deepseek3.utils import convert_model_config
from research.deepseek3.infer.norm import RMSNorm
from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding

__all__ = ['InferenceDeepseekV3ForCausalLM', 'DeepseekV3Model']

@@ -249,7 +251,7 @@ class MLAInferAttention(nn.Cell):
prefill_head_dim=None,
config: DeepseekV3Config = None
):
super(MLAInferAttention, self).__init__()
super().__init__()
self.n_head = n_head
self.head_dim = head_dim
self.n_kv_head = n_kv_head
@@ -438,14 +440,13 @@ class DeepseekV3Attention(nn.Cell):
raise ValueError("For 'DeepseekV3Attention', the use_flash_attention must be enabled.")

if self.hidden_size % self.n_head != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
"of 'n_head', but got the hidden_size is {} and the n_head is {}."
.format(self.hidden_size, self.n_head))
raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
f"of 'n_head', but got the hidden_size is {self.hidden_size} and "
f"the n_head is {self.n_head}.")
if self.n_kv_head % parallel_config.model_parallel != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of "
"'parallel_config.model_parallel', but got the n_kv_head is {} "
"and the parallel_config.model_parallel is {}."
.format(self.n_kv_head, parallel_config.model_parallel))
raise ValueError(f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of "
f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} "
f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.")
self.shape = P.Shape()
self.cast = P.Cast()
if self.q_lora_rank == 0:
@@ -572,12 +573,13 @@ class DeepseekV3Attention(nn.Cell):
if self.q_lora_rank == 0:
q = self.q_proj(x)
latent_kv_all = self.kv2l(x)
latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_kv, k_pe = ops.function.array_func.split_ext(
latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else:
if self.qkv_concat:
qkv2l = self.qkv2l(x)
q, latent_kv, k_pe = mint.split(qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
q, latent_kv, k_pe = ops.function.array_func.split_ext(
qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
norm_q = self.lq_norm(q)
q = self.l2q_proj(norm_q)
else:
@@ -585,10 +587,11 @@ class DeepseekV3Attention(nn.Cell):
norm_q = self.lq_norm(q)
q = self.l2q_proj(norm_q)
latent_kv_all = self.kv2l(x)
latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_kv, k_pe = ops.function.array_func.split_ext(
latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

q = self.reshape(q, (-1, self.n_local_heads, self.q_head_dim))
q_nope, q_pe = mint.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope, q_pe = ops.function.array_func.split_ext(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# (T, kv_lora_rank)
i_kv = self.lkv_norm(latent_kv)
q_pe = self.reshape(q_pe, (-1, self.n_local_heads * self.qk_rope_head_dim))
@@ -663,7 +666,7 @@ class DeepseekV3ParallelMLP(ParallelMLP):
# [B, S, H] -> [B, S, ffn_H]
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp
gate, hidden = mint.split(gate_hidden_out,
gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x) # dp,1 -> dp, mp
@@ -692,7 +695,7 @@ class DeepseekV3MoE(Cell):
"""

def __init__(self, config):
super(DeepseekV3MoE, self).__init__()
super().__init__()
self.config = config
self.parallel_config = config.parallel_config
self.moe_config = config.moe_config
@@ -766,7 +769,7 @@ class DeepseekV3MoEWithMicroBatch(DeepseekV3MoE):
"""

def __init__(self, config):
super(DeepseekV3MoEWithMicroBatch, self).__init__(config=config)
super().__init__(config=config)
self.moe_tp_size = get_moe_tp_world_size()
self.moe_ep_size = get_moe_ep_world_size()
self.ep_rank_id = get_rank() // self.moe_tp_size
@@ -846,7 +849,7 @@ class AttentionReduceScatter(Cell):
"""

def __init__(self, config):
super(AttentionReduceScatter, self).__init__()
super().__init__()
self.config = config
self.compute_dtype = config.compute_dtype
self.hidden_size = config.hidden_size
@@ -1439,7 +1442,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):

@lazy_inline
def __init__(self, config: DeepseekV3Config = None):
super(InferenceDeepseekV3ForCausalLM, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)

self.config = convert_model_config(config)
@@ -1499,7 +1502,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):

self.load_checkpoint(config)
self.predict_run_mode = get_predict_run_mode()
logger.info("Predict run mode:{}".format(self.predict_run_mode))
logger.info(f"Predict run mode:{self.predict_run_mode}")
self.return_hidden_states = config.return_hidden_states

# pylint: disable=W0613
@@ -1602,7 +1605,6 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
if dp_size == 1 or q_seq_len is None:
return model_inputs

from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
tokens_len_per_dp = q_seq_len.sum().reshape(-1)
tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp)
tokens_len_per_dp = tokens_len_per_dp.asnumpy()
@@ -1749,7 +1751,7 @@ class DeepseekV3MTPLayer(nn.Cell):
"""

def __init__(self, config: DeepseekV3Config = None):
super(DeepseekV3MTPLayer, self).__init__()
super().__init__()
self.enorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
compute_type=config.layernorm_compute_type)
self.hnorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
@@ -1826,7 +1828,7 @@ class DeepseekV3MTPModel(DeepseekV3PreTrainedModel):
"""

def __init__(self, config: DeepseekV3Config = None):
super(DeepseekV3MTPModel, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
self.dtype = config.compute_dtype
self.use_past = config.use_past
self.is_first_iteration = True
@@ -1925,7 +1927,7 @@ class InferenceDeepseekV3MTPForCausalLM(DeepseekV3PreTrainedModel):
"""

def __init__(self, config: DeepseekV3Config = None):
super(InferenceDeepseekV3MTPForCausalLM, self).__init__(config, auto_prefix=True)
super().__init__(config, auto_prefix=True)
self.dtype = config.compute_dtype
self.config = convert_model_config(config)
self.parallel_config = self.config.parallel_config


research/deepseek3/infer/transformer.py (+3, -3)

@@ -17,11 +17,11 @@ import mindspore.common.dtype as mstype
from mindspore import Parameter, mint, nn, ops
from mindspore.common.initializer import initializer

from research.deepseek3.infer.activation import SiLU
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear
from mindformers.parallel_core.inference.utils import get_tp_world_size
from mindformers.parallel_core.inference.parallel_state import get_tensor_model_parallel_group
from mindformers.tools.utils import divide
from research.deepseek3.infer.activation import SiLU
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear


class VocabEmbedding(nn.Cell):
@@ -174,7 +174,7 @@ class ParallelMLP(nn.Cell):
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
gate, hidden = mint.split(reshape_out,
gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))


research/deepseek3/moe.py (+13, -13)

@@ -37,6 +37,8 @@ try:
except ImportError:
MOE_FUSED_OP_VALID = False

from research.deepseek3.infer.activation import SiLU
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear
from mindformers.modules.layers import Linear
from mindformers.parallel_core.inference.parallel_state import (default_pgs, get_moe_expert_parallel_group,
get_moe_expert_parallel_world_size,
@@ -45,8 +47,6 @@ from mindformers.parallel_core.inference.parallel_state import (default_pgs, get
from mindformers.version_control import is_910b
from mindformers.tools.utils import divide

from research.deepseek3.infer.activation import SiLU
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear

dtype_map = {
'float16': mstype.float32,
@@ -60,7 +60,7 @@ class TopkRouter(nn.Cell):
A router implementation which maps each tokens to the topk expert.
"""
def __init__(self, expert_num):
super(TopkRouter, self).__init__()
super().__init__()
self.topk_bias = Parameter(initializer('zeros', (expert_num), mstype.float32),
requires_grad=False, parallel_optimizer=False)

@@ -73,7 +73,7 @@ class Router(nn.Cell):
def __init__(self,
hidden_size,
moe_config):
super(Router, self).__init__()
super().__init__()
self.expert_num = moe_config.expert_num
self.dense = nn.Dense(in_channels=hidden_size, out_channels=self.expert_num,
has_bias=False, dtype=dtype_map.get(moe_config.router_dense_type))
@@ -103,7 +103,7 @@ class ParallelMoE(nn.Cell):
hidden_size,
moe_config,
use_fused_op=True):
super(ParallelMoE, self).__init__()
super().__init__()
self.hidden_size = hidden_size
self.moe_config = moe_config
self.expert_dim = moe_config.expert_num
@@ -290,7 +290,7 @@ class SharedMLP(nn.Cell):
""" Construct function of mlp block. """
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp
gate, hidden = mint.split(gate_hidden_out,
gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size, self.ffn_hidden_size), -1)
else:
gate = self.w1(x)
@@ -387,7 +387,7 @@ class SharedParallelMLP(nn.Cell):
""" Construct function of mlp block. """
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp
gate, hidden = mint.split(gate_hidden_out,
gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x)
@@ -455,7 +455,7 @@ class ColumnParallelGroupLinear(ColumnParallelLinear):
tp_group=default_pgs,
**kwargs
):
super(ColumnParallelGroupLinear, self).__init__(
super().__init__(
input_size=input_size,
output_size=output_size,
config=config,
@@ -541,7 +541,7 @@ class RowParallelGroupLinear(RowParallelLinear):
tp_group=default_pgs,
**kwargs
):
super(RowParallelGroupLinear, self).__init__(
super().__init__(
input_size=input_size,
output_size=output_size,
config=config,
@@ -661,7 +661,7 @@ class RoutedParallelMLP(nn.Cell):
"""Forward process of the FeedForward"""
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x, group_list=group_list) # dp,1 -> dp, mp # dp,1 -> dp, mp
gate, hidden = mint.split(gate_hidden_out,
gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x, group_list=group_list)
@@ -711,7 +711,7 @@ class ParallelMoEV2(nn.Cell):
hidden_size,
moe_config,
is_reduce_moe_output=True):
super(ParallelMoEV2, self).__init__()
super().__init__()
self.hidden_size = hidden_size
self.moe_config = moe_config
self.is_reduce_moe_output = is_reduce_moe_output
@@ -821,7 +821,7 @@ class ExpertParallelMoE(nn.Cell):
moe_config,
use_alltoall,
compute_dtype):
super(ExpertParallelMoE, self).__init__()
super().__init__()
self.compute_dtype = compute_dtype
self.hidden_size = hidden_size
self.moe_config = moe_config
@@ -860,7 +860,7 @@ class ExpertParallelMoE(nn.Cell):
self.group_list_index = Tensor([0,], mstype.int32)

if self.moe_ep_size > 1 and not self.use_alltoall:
bias_idx = [idx for idx in range(self.expert_num)]
bias_idx = list(range(self.expert_num))
self.bias_idx = bias_idx[self.in_start_expert_idx:] + bias_idx[:self.in_start_expert_idx]
self.router.e_score_correction_bias.init_data()
self.router.e_score_correction_bias = self.router.e_score_correction_bias[self.bias_idx]
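
The functional change in this hunk is only the list(range(...)) simplification; the bias_idx rotation that follows it is unchanged. For clarity, a standalone sketch of that rotation, with expert_num and in_start_expert_idx as made-up example values rather than the real expert-parallel configuration:

# Illustrative values only (not the real EP layout): rotate the expert index list so
# the experts hosted by the local EP rank come first, matching the bias_idx reorder above.
expert_num = 8
in_start_expert_idx = 6

bias_idx = list(range(expert_num))   # [0, 1, 2, 3, 4, 5, 6, 7]
rotated = bias_idx[in_start_expert_idx:] + bias_idx[:in_start_expert_idx]
print(rotated)                       # [6, 7, 0, 1, 2, 3, 4, 5]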


research/llama3_1/infer/transformer.py (+12, -12)

@@ -27,6 +27,10 @@ import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer

from research.llama3_1.infer.norm import RMSNorm
from research.llama3_1.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
from research.llama3_1.infer.scale_mask_softmax import ScaleMaskSoftmax
from research.llama3_1.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.parallel_core.inference.utils import divide, get_attn_mask_func
from mindformers.parallel_core.inference.transformer.activation import get_act_func
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
@@ -36,11 +40,6 @@ from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.version_control import need_nz

from research.llama3_1.infer.norm import RMSNorm
from research.llama3_1.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
from research.llama3_1.infer.scale_mask_softmax import ScaleMaskSoftmax
from research.llama3_1.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding


class VocabEmbedding(nn.Cell):
"""
@@ -194,7 +193,7 @@ class ParallelMLP(nn.Cell):
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
gate, hidden = mint.split(reshape_out,
gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
@@ -236,7 +235,7 @@ class CoreAttention(nn.Cell):
"""

def __init__(self, layer_number, config, attn_mask_type=None):
super(CoreAttention, self).__init__()
super().__init__()
if attn_mask_type:
raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.")
self.config = config
@@ -352,7 +351,7 @@ class ParallelAttention(nn.Cell):
self.tp_group_size = self.tp.size
self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size)

self.use_gqa = (self.num_heads != self.kv_num_heads)
self.use_gqa = self.num_heads != self.kv_num_heads

if self.use_gqa:
self._check_gqa_valid()
@@ -423,7 +422,7 @@ class ParallelAttention(nn.Cell):
(-1,
self.kv_num_heads_per_partition,
(self.n_rep + 2) * self.head_dim))
query, key, value = mint.split(reshape_qkv,
query, key, value = ops.function.array_func.split_ext(reshape_qkv,
(self.head_dim * self.n_rep,
self.head_dim,
self.head_dim), -1)
@@ -445,7 +444,8 @@ class ParallelAttention(nn.Cell):
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
key, value = ops.function.array_func.split_ext(
kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
@@ -687,8 +687,8 @@ class ParallelTransformerLayer(nn.Cell):
raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.")
if drop_path_rate > 0.0:
raise NotImplementedError(
"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
"but got `drop_path_rate={}`".format(drop_path_rate)
f"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
f"but got `drop_path_rate={drop_path_rate}`"
)
self.config = config
self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm


research/qwen2_5/infer/transformer.py (+13, -11)

@@ -21,6 +21,10 @@ import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer

from research.qwen2_5.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from research.qwen2_5.infer.norm import get_norm
from research.qwen2_5.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
from research.qwen2_5.infer.scale_mask_softmax import ScaleMaskSoftmax
from mindformers.modules.flash_attention import FlashAttention
from mindformers.modules.infer_attention import InferRotaryEmbedding
from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
@@ -30,10 +34,7 @@ from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.inference.utils import get_attn_mask_func
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
from mindformers.version_control import need_nz
from research.qwen2_5.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from research.qwen2_5.infer.norm import get_norm
from research.qwen2_5.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
from research.qwen2_5.infer.scale_mask_softmax import ScaleMaskSoftmax


__all__ = [
"ParallelMLP",
@@ -195,7 +196,7 @@ class ParallelMLP(nn.Cell):
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
gate, hidden = mint.split(reshape_out,
gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
@@ -237,7 +238,7 @@ class CoreAttention(nn.Cell):
"""

def __init__(self, layer_number, config, attn_mask_type=None):
super(CoreAttention, self).__init__()
super().__init__()
if attn_mask_type:
raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.")
self.config = config
@@ -353,7 +354,7 @@ class ParallelAttention(nn.Cell):
self.tp_group_size = self.tp.size
self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size)

self.use_gqa = (self.num_heads != self.kv_num_heads)
self.use_gqa = self.num_heads != self.kv_num_heads

if self.use_gqa:
self._check_gqa_valid()
@@ -424,7 +425,7 @@ class ParallelAttention(nn.Cell):
(-1,
self.kv_num_heads_per_partition,
(self.n_rep + 2) * self.head_dim))
query, key, value = mint.split(reshape_qkv,
query, key, value = ops.function.array_func.split_ext(reshape_qkv,
(self.head_dim * self.n_rep,
self.head_dim,
self.head_dim), -1)
@@ -446,7 +447,8 @@ class ParallelAttention(nn.Cell):
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
key, value = ops.function.array_func.split_ext(
kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
@@ -688,8 +690,8 @@ class ParallelTransformerLayer(nn.Cell):
raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.")
if drop_path_rate > 0.0:
raise NotImplementedError(
"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
"but got `drop_path_rate={}`".format(drop_path_rate)
f"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
f"but got `drop_path_rate={drop_path_rate}`"
)
self.config = config
self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm

