@@ -22,6 +22,10 @@ import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer
+from research.telechat2.infer.norm import RMSNorm
+from research.telechat2.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
+from research.telechat2.infer.scale_mask_softmax import ScaleMaskSoftmax
+from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.parallel_core.inference.utils import divide, get_attn_mask_func
from mindformers.parallel_core.inference.transformer.activation import get_act_func
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
@@ -31,11 +35,6 @@ from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.version_control import need_nz
-from research.telechat2.infer.norm import RMSNorm
-from research.telechat2.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
-from research.telechat2.infer.scale_mask_softmax import ScaleMaskSoftmax
-from research.telechat2.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
__all__ = [
"ParallelMLP",
@@ -197,7 +196,7 @@ class ParallelMLP(nn.Cell):
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
-        gate, hidden = mint.split(reshape_out,
+        gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
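# NOTE (editorial sketch, not part of the diff): the hunk above swaps
# `mint.split` for `ops.function.array_func.split_ext`, which is assumed here
# to mirror `mint.split`'s (input, split_size_or_sections, dim) convention, as
# the new call suggests. Names and shapes below are illustrative only.
import numpy as np
import mindspore as ms
from mindspore import ops

demo = ms.Tensor(np.zeros((2, 3, 2), dtype=np.float32))
gate_demo, hidden_demo = ops.function.array_func.split_ext(demo, (1, 1), -1)
# Both chunks have shape (2, 3, 1); the reshape afterwards drops the size-1 axis.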
@@ -239,7 +238,7 @@ class CoreAttention(nn.Cell):
"""
def __init__(self, layer_number, config, attn_mask_type=None):
-        super(CoreAttention, self).__init__()
+        super().__init__()
if attn_mask_type:
raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.")
self.config = config
@@ -355,7 +354,7 @@ class ParallelAttention(nn.Cell):
self.tp_group_size = self.tp.size
self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size)
-        self.use_gqa = (self.num_heads != self.kv_num_heads)
+        self.use_gqa = self.num_heads != self.kv_num_heads
if self.use_gqa:
self._check_gqa_valid()
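# NOTE (editorial sketch, not part of the diff): `use_gqa` is set whenever the
# query and KV head counts differ; both must divide evenly across the
# tensor-parallel group. Hypothetical head counts, for illustration only:
num_heads, kv_num_heads, tp_group_size = 32, 8, 4

use_gqa = num_heads != kv_num_heads                          # True -> GQA path
num_heads_per_partition = num_heads // tp_group_size         # 8 query heads per rank
kv_num_heads_per_partition = kv_num_heads // tp_group_size   # 2 KV heads per rank
n_rep = num_heads // kv_num_heads          # each KV head is shared by 4 query heads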
@@ -426,7 +425,7 @@ class ParallelAttention(nn.Cell):
(-1,
self.kv_num_heads_per_partition,
(self.n_rep + 2) * self.head_dim))
-        query, key, value = mint.split(reshape_qkv,
+        query, key, value = ops.function.array_func.split_ext(reshape_qkv,
(self.head_dim * self.n_rep,
self.head_dim,
self.head_dim), -1)
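# NOTE (editorial sketch, not part of the diff): the reshape/split pair above
# relies on the fused QKV projection packing, per KV group, n_rep query heads
# followed by one key head and one value head along the last axis.
# Hypothetical dimensions for illustration only:
n_rep, head_dim = 4, 128

group_width = (n_rep + 2) * head_dim                  # (4 + 2) * 128 = 768
split_sizes = (head_dim * n_rep, head_dim, head_dim)  # (512, 128, 128)
assert sum(split_sizes) == group_width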
@@ -448,7 +447,8 @@ class ParallelAttention(nn.Cell):
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
-            key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
+            key, value = ops.function.array_func.split_ext(
+                kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
@@ -690,8 +690,8 @@ class ParallelTransformerLayer(nn.Cell):
raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.")
if drop_path_rate > 0.0:
raise NotImplementedError(
"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
"but got `drop_path_rate={}`".format(drop_path_rate)
f "For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
f"but got `drop_path_rate={drop_path_rate}`"
)
self.config = config
self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm