@@ -3,6 +3,8 @@ from pathlib import Path
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
from transformers import BloomTokenizerFast, BloomConfig
import sacrebleu
from coati.models.bloom import BLOOMRM, BLOOMLM
@@ -37,19 +39,19 @@ class Instructor:
OmegaConf.resolve(cfg)
print(OmegaConf.to_yaml(cfg))
self.cfg = cfg
self.infer_ckpt = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.infer
self.tokenizer = BloomTokenizerFast.from_pretrained(self.infer_ckpt)
self.infer_model = BLOOMLM(self.infer_ckpt)
self.ckpt_path = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.base
self.tokenizer = BloomTokenizerFast.from_pretrained(self.ckpt_path)
config = BloomConfig.from_json_file(Path(self.ckpt_path) / 'config.json')
self.infer_model = BLOOMLM(config=config)
self.infer_model.load_state_dict(torch.load(Path(self.ckpt_path) / self.cfg.ckpt.infer), strict=True)
self.infer_model.to('cuda')
# self.infer_model.configure_sharded_model()
# self.infer_model = AutoModelForCausalLM.from_pretrained(self.infer_ckpt, torch_dtype="auto", device_map="auto")
if self.cfg.ckpt.reward:
reward_ckpt = Path(cfg.ckpt_dir) / cfg.ckpt.family / cfg.ckpt.reward
config = BloomConfig.from_json_file(Path(reward_ckpt) / 'config.json')
self.reward_model = BLOOMRM(config=config)
self.reward_model.load_state_dict(torch.load(f'{reward_ckpt}/pytorch_model.bin' ), strict=True)
self.reward_model.load_state_dict(torch.load(Path(self.ckpt_path) / self.cfg.ckpt.reward), strict=True)
self.reward_model.to('cuda')
Path(self.cfg.hyp_file).parent.mkdir(parents=True, exist_ok=True)
self.hyps = []
self.refs = []
def build_translation_prompts(self, inputs):
src_lang = iso2language[self.cfg.info.src]
@@ -91,6 +93,7 @@ class Instructor:
for line in f1.readlines():
line = line.strip()
prompt = self.build_translation_prompts(line)
prompt = PROMPT_DICT["prompt_no_input"].format_map({'instruction': prompt})
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
num_input_tokens = inputs.size(dim=1)
outputs = self.infer_model.model.generate(inputs, max_new_tokens=200, num_beams=5, num_return_sequences=5)
@@ -110,6 +113,21 @@ class Instructor:
else:
func = getattr(self, 'bloomlm_translation', None)
func()
def compute_bleu(self):
    """Score generated hypotheses against references with sacreBLEU and print the result.

    Reads hypotheses from ``cfg.hyp_file`` and references from
    ``cfg.info.src_file``, accumulating them into ``self.hyps`` /
    ``self.refs``, then prints corpus BLEU for this system and direction.

    NOTE(review): the reference side is read from ``cfg.info.src_file`` —
    the name suggests the *source*-language file, but BLEU references
    should be target-language text; confirm the config key is correct.
    """
    # Iterate the file objects lazily instead of materializing readlines().
    with open(self.cfg.hyp_file, 'r', encoding='utf-8') as hyp_file:
        self.hyps.extend(hyp_file)
    with open(self.cfg.info.src_file, 'r', encoding='utf-8') as ref_file:
        self.refs.extend(ref_file)
    # Unsegmented scripts (Lao/Thai/Chinese) need their language-specific
    # tokenizer; everything else uses sacreBLEU's default "13a".
    tokenize = self.cfg.info.tgt if self.cfg.info.tgt in ("lo", "th", "zh") else "13a"
    # NOTE(review): q2b / strict_bp are not upstream sacrebleu kwargs —
    # presumably a custom fork; verify against the installed package.
    bleu = sacrebleu.corpus_bleu(
        self.hyps,
        [self.refs],
        q2b=self.cfg.info.q2b,
        strict_bp=self.cfg.info.strict_bp,
        tokenize=tokenize,
    ).score
    print(f"system : {self.cfg.ckpt.identity}, direction: {self.cfg.info.src}->{self.cfg.info.tgt}, bleu score: {bleu}")
if __name__ == '__main__':
@@ -117,3 +135,4 @@ if __name__ == '__main__':
cfg = compose(config_name='config')
ins = Instructor(cfg)
ins.inference()
ins.compute_bleu()