@@ -16,6 +16,9 @@ from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel import set_algo_parameters
from pangu_alpha import PANGUALPHAPipeline, PANGUALPHA, EvalNet
from pangu_alpha_config import PANGUALPHAConfig
from tokenization_jieba import JIEBATokenizer
from generate import generate
from extra_utils import TimeSpan
def run_predict_pipeline(args_opt):
@@ -118,6 +121,17 @@ def run_predict_pipeline(args_opt):
model_predict.predict(inputs_np)
def generate_one_sample(sample, tokenizer, model_predict, config):
    """Generate a continuation for one prompt, print it, and time the run.

    Args:
        sample: Prompt text handed to ``tokenizer.tokenize``.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``
            and ``convert_ids_to_tokens``.
        model_predict: Model handle forwarded unchanged to ``generate``.
        config: Object whose ``seq_length`` attribute is forwarded to
            ``generate``.

    Returns:
        None. The decoded output is printed to stdout.
    """
    timer = TimeSpan.start('generate for {}'.format(sample))

    # Prompt text -> tokens -> ids, reshaped into a single-row 2-D array.
    prompt_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sample))
    batch = np.array(prompt_ids).reshape(1, -1)

    # NOTE(review): the literal 9 is passed through as the last positional
    # argument; presumably an end-of-text token id -- confirm in generate().
    generated = generate(model_predict, batch, config.seq_length, 9)

    decoded = tokenizer.convert_ids_to_tokens(generated.tolist())
    print('Output is:', decoded, flush=True)

    timer.end()
def run_predict_no_pipeline(args_opt):
device_id = int(os.getenv("DEVICE_ID"))
rank_id_str = os.getenv('RANK_ID', '0')
@@ -194,6 +208,7 @@ def run_predict_no_pipeline(args_opt):
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
predict_layout = model_predict.infer_predict_layout(inputs_np)
print("======start load_distributed checkpoint", flush=True)
load_ckpt_ts = TimeSpan.start('load distributed checkpoint')
# For 2.6B and 13B models, the number of ckpt files is 512.
ckpt_name = 'filerted'
@@ -201,24 +216,34 @@ def run_predict_no_pipeline(args_opt):
print(f"Loading from path {ckpt_file_list[0]}", flush=True)
load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
print("================load param ok=================", flush=True)
load_ckpt_ts.end()
from tokenization_jieba import JIEBATokenizer
from generate import generate
tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab.vocab'),
os.path.join(args_opt.tokenizer_path, 'vocab.model'))
sample = "今天是一个好天气"
tokenized_token = tokenizer.tokenize(sample)
start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
input_ids = np.array(start_sentence).reshape(1, -1)
output_ids = generate(model_predict, input_ids, config.seq_length, 9)
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
print('Output is:', output_samples, flush=True)
generate_one_sample(sample, tokenizer, model_predict, config)
samples = ['上联:瑞风播福泽,事业具昌盛千家乐',
'四川的省会是?',
'上联:春雨润人间,社会和谐万象新',
'''书生:羌笛何须怨杨柳,春风不度玉门关。
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:''',
'张无忌拿出屠龙宝刀,手起刀落,周芷若掉了一颗门牙,身旁的赵敏喜极而泣,',
'人工智能成为国际竞争的新焦点。人工智能是引领未来的战略性技术,世界主要发达国家把发展人工智能作为提升国家竞争力、维护国家安全的重大战略,加紧出台规划和政策,围绕核心技术、顶尖人才、标准规范等强化部署,力图在新一轮国际科技竞争中掌握主导权。当前,',
'中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?']
for sample in samples:
generate_one_sample(sample, tokenizer, model_predict, config)
def run_predict(args_opt):
    """Run prediction through the pipeline or non-pipeline path and time it.

    Args:
        args_opt: Options object; when ``args_opt.stage_num`` is greater
            than 1 the pipeline path is used, otherwise the non-pipeline
            path.
    """
    timer = TimeSpan.start('run_predict')

    # Pick the entry point first, then invoke it once with the same options.
    use_pipeline = args_opt.stage_num > 1
    predict_fn = run_predict_pipeline if use_pipeline else run_predict_no_pipeline
    predict_fn(args_opt)

    timer.end()