2 Commits

3 changed files with 32 additions and 30 deletions
  1. research/telechat2/infer/moe.py (+8 -8)
  2. research/telechat2/infer/telechat_transformers.py (+12 -10)
  3. research/telechat2/infer/transformer.py (+12 -12)

research/telechat2/infer/moe.py (+8 -8)

@@ -16,17 +16,17 @@
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
import mindspore.common.dtype as mstype
-from mindspore import Tensor, nn, Parameter, mint
+from mindspore import Tensor, nn, Parameter, mint, ops
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer

+from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear
from mindformers.parallel_core.inference.transformer.activation import get_act_func
from mindformers.parallel_core.inference.parallel_state import default_pgs, get_moe_tensor_parallel_group
# pylint: disable=C0412
from mindformers.parallel_core.inference.utils import get_moe_ep_world_size, get_moe_tp_world_size
from mindformers.tools.utils import divide

-from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear

MOE_FUSED_OP_VALID = True

@@ -51,7 +51,7 @@ class TopkRouter(nn.Cell):
A router implementation which maps each tokens to the topk expert.
"""
def __init__(self, expert_num):
-super(TopkRouter, self).__init__()
+super().__init__()
self.topk_bias = Parameter(initializer('zeros', (expert_num), mstype.float32),
requires_grad=False, parallel_optimizer=False)

@@ -64,7 +64,7 @@ class Router(nn.Cell):
def __init__(self,
hidden_size,
moe_config):
-super(Router, self).__init__()
+super().__init__()
self.expert_num = moe_config.expert_num
self.dense = nn.Dense(in_channels=hidden_size, out_channels=self.expert_num,
has_bias=False, dtype=dtype_map.get(moe_config.router_dense_type))
@@ -94,7 +94,7 @@ class ParallelMoE(nn.Cell):
hidden_size,
moe_config,
use_fused_op=True):
-super(ParallelMoE, self).__init__()
+super().__init__()
self.hidden_size = hidden_size
self.moe_config = moe_config
self.expert_dim = moe_config.expert_num
@@ -272,7 +272,7 @@ class ColumnParallelGroupLinear(ColumnParallelLinear):
tp_group=default_pgs,
**kwargs
):
-super(ColumnParallelGroupLinear, self).__init__(
+super().__init__(
input_size=input_size,
output_size=output_size,
config=config,
@@ -357,7 +357,7 @@ class RowParallelGroupLinear(RowParallelLinear):
tp_group=default_pgs,
**kwargs
):
-super(RowParallelGroupLinear, self).__init__(
+super().__init__(
input_size=input_size,
output_size=output_size,
config=config,
@@ -476,7 +476,7 @@ class RoutedParallelMLP(nn.Cell):
"""Forward process of the FeedForward"""
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x, group_list=group_list) # dp,1 -> dp, mp # dp,1 -> dp, mp
-gate, hidden = mint.split(gate_hidden_out,
+gate, hidden = ops.function.array_func.split_ext(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x, group_list=group_list)
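
Note (reviewer sketch, not part of this change): the moe.py edits swap mint.split for ops.function.array_func.split_ext and shorten the super() calls. The toy snippet below only illustrates the split replacement; the split_ext path and its (tensor, sections, axis) signature are assumed from the diff and the targeted MindSpore build, and ffn_hidden is an illustrative stand-in for self.ffn_hidden_size_per_partition.

import numpy as np
from mindspore import Tensor, mint, ops

ffn_hidden = 4  # illustrative stand-in for the per-partition FFN width
x = Tensor(np.arange(16, dtype=np.float32).reshape(2, 2 * ffn_hidden))

# Old call (left-hand side of the diff): split the last axis into two equal sections.
gate_old, hidden_old = mint.split(x, (ffn_hidden, ffn_hidden), -1)

# New call (right-hand side of the diff): same sections and axis.
gate_new, hidden_new = ops.function.array_func.split_ext(x, (ffn_hidden, ffn_hidden), -1)

print(gate_old.shape, hidden_old.shape)  # (2, 4) (2, 4)
print(gate_new.shape, hidden_new.shape)  # expected to match the mint.split shapes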


research/telechat2/infer/telechat_transformers.py (+12 -10)

@@ -19,16 +19,16 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn, ops, mint, Tensor

+from research.telechat2.infer.moe import ParallelMoE, RoutedParallelMLP
+from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear
+from research.telechat2.infer.transformer import \
+ParallelAttention, ParallelTransformerLayer, ParallelTransformer, ParallelMLP
from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region
from mindformers.parallel_core.inference.parallel_state import get_tensor_model_parallel_group
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
from mindformers.modules.layers import FreqsMgrDynamicNTK
from mindformers.tools.logger import logger
-from research.telechat2.infer.moe import ParallelMoE, RoutedParallelMLP
-from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear
-from research.telechat2.infer.transformer import \
-ParallelAttention, ParallelTransformerLayer, ParallelTransformer, ParallelMLP


class TelechatParallelMoE(ParallelMoE):
@@ -52,7 +52,7 @@ class TelechatParallelMoE(ParallelMoE):
hidden_size,
moe_config,
use_fused_op=True):
-super(TelechatParallelMoE, self).__init__(
+super().__init__(
ffn=ffn,
hidden_size=hidden_size,
moe_config=moe_config,
@@ -167,8 +167,9 @@ class TelechatParallelMLP(ParallelMLP):
if self.mlp_has_gate:
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp
-gate, hidden = mint.split(gate_hidden_out,
-(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
+gate, hidden = ops.function.array_func.split_ext(
+gate_hidden_out,
+(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x) # dp,1 -> dp, mp
hidden = self.w3(x) # dp,1 -> dp, mp
@@ -219,7 +220,7 @@ class TelechatParallelAttention(ParallelAttention):
# [B, S, H] --> [B, S, H + 2 * kv_H]
if self.qkv_concat:
qkv = self.cast(self.w_qkv(x), self.compute_dtype)
-query, key, value = mint.split(qkv,
+query, key, value = ops.function.array_func.split_ext(qkv,
(self.hidden_size_per_partition,
self.kv_hidden_size_per_partition,
self.kv_hidden_size_per_partition), -1)
@@ -227,14 +228,15 @@ class TelechatParallelAttention(ParallelAttention):
query = self.cast(self.wq(x), self.compute_dtype)
key_value = self.cast(self.wk_v(x), self.compute_dtype)
key_value = key_value.reshape(-1, self.kv_num_heads_per_partition, self.head_dim * 2)
-key, value = mint.split(key_value, (self.head_dim, self.head_dim), -1)
+key, value = ops.function.array_func.split_ext(key_value, (self.head_dim, self.head_dim), -1)
key = key.reshape(bs, seq_len, -1)
value = value.reshape(bs, seq_len, -1)
else:
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
-key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
+key, value = ops.function.array_func.split_ext(
+kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
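
Note (reviewer sketch, not part of this change): the attention paths above now split the fused QKV / KV projections with ops.function.array_func.split_ext using unequal section sizes. The toy example below mirrors that call pattern; hidden_q and hidden_kv are illustrative stand-ins for self.hidden_size_per_partition and self.kv_hidden_size_per_partition, and the split_ext signature is assumed from the diff.

import numpy as np
from mindspore import Tensor, ops

bs, seq_len = 2, 3
hidden_q, hidden_kv = 8, 2  # query width differs from key/value width (GQA-style)
qkv = Tensor(np.ones((bs, seq_len, hidden_q + 2 * hidden_kv), dtype=np.float32))

# Same pattern as the new code path: sections (q, kv, kv) along the last axis.
query, key, value = ops.function.array_func.split_ext(qkv, (hidden_q, hidden_kv, hidden_kv), -1)

print(query.shape, key.shape, value.shape)  # (2, 3, 8) (2, 3, 2) (2, 3, 2)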


research/telechat2/infer/transformer.py (+12 -12)

@@ -22,6 +22,10 @@ import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer

+from research.telechat2.infer.norm import RMSNorm
+from research.telechat2.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
+from research.telechat2.infer.scale_mask_softmax import ScaleMaskSoftmax
+from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.parallel_core.inference.utils import divide, get_attn_mask_func
from mindformers.parallel_core.inference.transformer.activation import get_act_func
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
@@ -31,11 +35,6 @@ from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.version_control import need_nz

-from research.telechat2.infer.norm import RMSNorm
-from research.telechat2.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
-from research.telechat2.infer.scale_mask_softmax import ScaleMaskSoftmax
-from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding


__all__ = [
"ParallelMLP",
@@ -197,7 +196,7 @@ class ParallelMLP(nn.Cell):
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
-gate, hidden = mint.split(reshape_out,
+gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
@@ -239,7 +238,7 @@ class CoreAttention(nn.Cell):
"""

def __init__(self, layer_number, config, attn_mask_type=None):
-super(CoreAttention, self).__init__()
+super().__init__()
if attn_mask_type:
raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.")
self.config = config
@@ -355,7 +354,7 @@ class ParallelAttention(nn.Cell):
self.tp_group_size = self.tp.size
self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size)

-self.use_gqa = (self.num_heads != self.kv_num_heads)
+self.use_gqa = self.num_heads != self.kv_num_heads

if self.use_gqa:
self._check_gqa_valid()
@@ -426,7 +425,7 @@ class ParallelAttention(nn.Cell):
(-1,
self.kv_num_heads_per_partition,
(self.n_rep + 2) * self.head_dim))
-query, key, value = mint.split(reshape_qkv,
+query, key, value = ops.function.array_func.split_ext(reshape_qkv,
(self.head_dim * self.n_rep,
self.head_dim,
self.head_dim), -1)
@@ -448,7 +447,8 @@ class ParallelAttention(nn.Cell):
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
-key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
+key, value = ops.function.array_func.split_ext(
+kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
@@ -690,8 +690,8 @@ class ParallelTransformerLayer(nn.Cell):
raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.")
if drop_path_rate > 0.0:
raise NotImplementedError(
"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
"but got `drop_path_rate={}`".format(drop_path_rate)
f"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
f"but got `drop_path_rate={drop_path_rate}`"
)
self.config = config
self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm
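
Note (reviewer sketch, not part of this change): the ParallelMLP gate/hidden split above reshapes the concatenated projection to (..., ffn, 2) before splitting, then reshapes each half back. The toy example below walks through that pattern with small shapes; ffn is an illustrative stand-in for self.ffn_hidden_size_per_partition and the split_ext signature is assumed from the diff.

import numpy as np
from mindspore import Tensor, ops

ffn = 4
gate_hidden_out = Tensor(np.arange(2 * 3 * 2 * ffn, dtype=np.float32).reshape(2, 3, 2 * ffn))
shape = gate_hidden_out.shape

# View the last axis as (ffn, 2), split off the two interleaved halves,
# then drop the trailing size-1 axis by reshaping back to (..., ffn).
reshape_out = ops.reshape(gate_hidden_out, (*shape[:-1], ffn, 2))
gate, hidden = ops.function.array_func.split_ext(reshape_out, (1, 1), -1)
gate = ops.reshape(gate, (*shape[:-1], ffn))
hidden = ops.reshape(hidden, (*shape[:-1], ffn))

print(gate.shape, hidden.shape)  # (2, 3, 4) (2, 3, 4)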

