2 Commits

Author SHA1 Message Date
  yangb de52001cca add conf 2 years ago
  yangb eb8a1aadd2 reconstruct infer 2 years ago
4 changed files with 36 additions and 10 deletions
Split View
  1. +27
    -8
      applications/Chat/infer.py
  2. +5
    -0
      applications/Chat/infer_conf/ckpt/bloom_7b1_sft.yaml
  3. +2
    -2
      applications/Chat/infer_conf/config.yaml
  4. +2
    -0
      applications/Chat/infer_conf/info/zh_vi_idea.yaml

+ 27
- 8
applications/Chat/infer.py View File

@@ -3,6 +3,8 @@ from pathlib import Path
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
from transformers import BloomTokenizerFast, BloomConfig
import sacrebleu

from coati.models.bloom import BLOOMRM, BLOOMLM


@@ -37,19 +39,19 @@ class Instructor:
OmegaConf.resolve(cfg)
print(OmegaConf.to_yaml(cfg))
self.cfg = cfg
self.infer_ckpt = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.infer
self.tokenizer = BloomTokenizerFast.from_pretrained(self.infer_ckpt)
self.infer_model = BLOOMLM(self.infer_ckpt)
self.ckpt_path = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.base
self.tokenizer = BloomTokenizerFast.from_pretrained(self.ckpt_path)
config = BloomConfig.from_json_file(Path(self.ckpt_path) / 'config.json')
self.infer_model = BLOOMLM(config=config)
self.infer_model.load_state_dict(torch.load(Path(self.ckpt_path) / self.cfg.ckpt.infer), strict=True)
self.infer_model.to('cuda')
# self.infer_model.configure_sharded_model()
# self.infer_model = AutoModelForCausalLM.from_pretrained(self.infer_ckpt, torch_dtype="auto", device_map="auto")
if self.cfg.ckpt.reward:
reward_ckpt = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.reward
config = BloomConfig.from_json_file(Path(reward_ckpt) / 'config.json')
self.reward_model = BLOOMRM(config=config)
self.reward_model.load_state_dict(torch.load(f'{reward_ckpt}/pytorch_model.bin'), strict=True)
self.reward_model.load_state_dict(torch.load(Path(self.ckpt_path) / self.cfg.ckpt.reward), strict=True)
self.reward_model.to('cuda')
Path(self.cfg.hyp_file).parent.mkdir(parents=True, exist_ok=True)
self.hyps = []
self.refs = []

def build_translation_prompts(self, inputs):
src_lang = iso2language[self.cfg.info.src]
@@ -91,6 +93,7 @@ class Instructor:
for line in f1.readlines():
line = line.strip()
prompt = self.build_translation_prompts(line)
prompt = PROMPT_DICT["prompt_no_input"].format_map({'instruction': prompt})
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
num_input_tokens = inputs.size(dim=1)
outputs = self.infer_model.model.generate(inputs, max_new_tokens=200, num_beams=5, num_return_sequences=5)
@@ -110,6 +113,21 @@ class Instructor:
else:
func = getattr(self, 'bloomlm_translation', None)
func()
def compute_bleu(self):
    """Compute and print the corpus-level BLEU score for the generated hypotheses.

    Reads hypotheses from ``cfg.hyp_file`` (written during inference) and
    references from ``cfg.info.src_file``, then reports the score together
    with the system identity and translation direction.

    NOTE(review): references are loaded from ``info.src_file`` (the *source*
    side) — confirm this is intended and not a target-side reference file.
    """
    # Reset the accumulators so repeated calls do not double-count lines.
    self.hyps = []
    self.refs = []
    with open(self.cfg.hyp_file, 'r', encoding='utf-8') as hyp_file:
        # Iterate the file object directly; readlines() materializes the
        # whole file only to loop over it once.
        self.hyps.extend(hyp_file)
    with open(self.cfg.info.src_file, 'r', encoding='utf-8') as ref_file:
        self.refs.extend(ref_file)
    # lo/th/zh have language-specific sacrebleu tokenizers; everything else
    # uses the default "13a" tokenizer. Selecting the tokenizer up front
    # removes the previously duplicated corpus_bleu call.
    tokenize = self.cfg.info.tgt if self.cfg.info.tgt in ("lo", "th", "zh") else "13a"
    bleu = sacrebleu.corpus_bleu(
        self.hyps,
        [self.refs],
        q2b=self.cfg.info.q2b,
        strict_bp=self.cfg.info.strict_bp,
        tokenize=tokenize,
    ).score
    print(f"system : {self.cfg.ckpt.identity}, direction: {self.cfg.info.src}->{self.cfg.info.tgt}, bleu score: {bleu}")



if __name__ == '__main__':
@@ -117,3 +135,4 @@ if __name__ == '__main__':
cfg = compose(config_name='config')
ins = Instructor(cfg)
ins.inference()
ins.compute_bleu()

+ 5
- 0
applications/Chat/infer_conf/ckpt/bloom_7b1_sft.yaml View File

@@ -0,0 +1,5 @@
family: bloom
base: bloom_7b1
infer: bloom_7b1_zh2vi_4w_sft.bin
reward:
identity: bloom_7b1_zh2vi_4w_sft

+ 2
- 2
applications/Chat/infer_conf/config.yaml View File

@@ -1,7 +1,7 @@
defaults:
- info: zh_vi_idea
- ckpt: bloomz_1b1_sft
- ckpt: bloom_7b1_sft
- _self_
prompt: text
hyp_file: results/${info.testset}/${ckpt.identity}/prompt_${.prompt}_${info.src}2${info.tgt}.hyp.sft.new_prompt
hyp_file: results/${info.testset}/infer_${ckpt.infer}_reward_${ckpt.reward}_prompt_${.prompt}_${info.src}2${info.tgt}.hyp
ckpt_dir: /userhome/rl4mt/checkpoints

+ 2
- 0
applications/Chat/infer_conf/info/zh_vi_idea.yaml View File

@@ -1,4 +1,6 @@
src: zh
tgt: vi
q2b: True
strict_bp: False
testset: idea
src_file: "/userhome/rl4mt/data/${.testset}/${.src}.test"

Loading…
Cancel
Save
Baidu
map