3 Commits

Author SHA1 Message Date
  Myhs_phz 5c0213de89
[Fix] fix OpenAISDKRollout and dump-res-length (#2351) 1 day ago
  Myhs_phz 2667d4d0ec
[Dataset] add dataset SciReasoner (#2360) 1 day ago
  zhulinJulia24 d836b49fee
[ci] add v1.8 new datasets (#2358) 1 day ago
84 changed files with 13987 additions and 48 deletions
Split View
  1. +65
    -0
      .github/scripts/eval_regression_api_rollout.py
  2. +1
    -1
      .github/scripts/eval_regression_base_models.py
  3. +2
    -2
      .github/scripts/eval_regression_chat_models.py
  4. +91
    -0
      .github/scripts/eval_regression_chat_obj_fullbench_v8.py
  5. +14
    -0
      .github/scripts/oc_score_assert.py
  6. +76
    -0
      .github/scripts/oc_score_baseline_fullbench.yaml
  7. +3
    -3
      .github/scripts/oc_score_baseline_testrange.yaml
  8. +24
    -22
      .github/workflows/daily-run-test.yml
  9. +157
    -0
      examples/eval_scireasoner.py
  10. +77
    -0
      opencompass/configs/datasets/SciReasoner/GUE_gen.py
  11. +290
    -0
      opencompass/configs/datasets/SciReasoner/LLM4Mat_gen.py
  12. +49
    -0
      opencompass/configs/datasets/SciReasoner/UMG.py
  13. +67
    -0
      opencompass/configs/datasets/SciReasoner/UPG.py
  14. +75
    -0
      opencompass/configs/datasets/SciReasoner/bio_instruction_gen.py
  15. +52
    -0
      opencompass/configs/datasets/SciReasoner/bulk_modulus_material_gen.py
  16. +64
    -0
      opencompass/configs/datasets/SciReasoner/composition_material_gen.py
  17. +110
    -0
      opencompass/configs/datasets/SciReasoner/mol_biotext_gen.py
  18. +63
    -0
      opencompass/configs/datasets/SciReasoner/mol_molecule_gen.py
  19. +83
    -0
      opencompass/configs/datasets/SciReasoner/mol_protein_gen.py
  20. +83
    -0
      opencompass/configs/datasets/SciReasoner/opi_gen.py
  21. +99
    -0
      opencompass/configs/datasets/SciReasoner/peer_gen.py
  22. +74
    -0
      opencompass/configs/datasets/SciReasoner/retrosynthesis_USPTO_gen.py
  23. +55
    -0
      opencompass/configs/datasets/SciReasoner/scireasoner_gen.py
  24. +120
    -0
      opencompass/configs/datasets/SciReasoner/smol_gen.py
  25. +51
    -0
      opencompass/configs/datasets/SciReasoner/unconditional_RNA_gen.py
  26. +50
    -0
      opencompass/configs/datasets/SciReasoner/unconditional_material_gen.py
  27. +210
    -0
      opencompass/datasets/SciReasoner/GUE.py
  28. +12
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/__init__.py
  29. +166
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/config.py
  30. +228
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/evaluator.py
  31. +449
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/retrosynthesis_evaluator.py
  32. +1
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/__input__.py
  33. +12
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/chat_generation.py
  34. +195
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/core_tagger.py
  35. +35
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/general_prompter.py
  36. +685
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/metrics.py
  37. +189
    -0
      opencompass/datasets/SciReasoner/LLM4Chem/utils/smiles_canonicalization.py
  38. +217
    -0
      opencompass/datasets/SciReasoner/LLM4Mat.py
  39. +15
    -0
      opencompass/datasets/SciReasoner/Mol_Instructions/__init__.py
  40. +331
    -0
      opencompass/datasets/SciReasoner/Mol_Instructions/biotext.py
  41. +458
    -0
      opencompass/datasets/SciReasoner/Mol_Instructions/molecule.py
  42. +150
    -0
      opencompass/datasets/SciReasoner/Mol_Instructions/normalized_SW_score.py
  43. +155
    -0
      opencompass/datasets/SciReasoner/Mol_Instructions/protein.py
  44. +471
    -0
      opencompass/datasets/SciReasoner/PEER.py
  45. +13
    -0
      opencompass/datasets/SciReasoner/__init__.py
  46. +1440
    -0
      opencompass/datasets/SciReasoner/bio_instruction.py
  47. +173
    -0
      opencompass/datasets/SciReasoner/bulk_modulus_material.py
  48. +228
    -0
      opencompass/datasets/SciReasoner/composition_material.py
  49. +7
    -0
      opencompass/datasets/SciReasoner/opi/__init__.py
  50. +136
    -0
      opencompass/datasets/SciReasoner/opi/config.py
  51. +289
    -0
      opencompass/datasets/SciReasoner/opi/evaluator.py
  52. +62
    -0
      opencompass/datasets/SciReasoner/opi/process_ec_numbers.py
  53. +45
    -0
      opencompass/datasets/SciReasoner/opi/utils/accuracy4fold_type.py
  54. +143
    -0
      opencompass/datasets/SciReasoner/opi/utils/metrics4all.py
  55. +128
    -0
      opencompass/datasets/SciReasoner/uncond_RNA.py
  56. +80
    -0
      opencompass/datasets/SciReasoner/uncond_material.py
  57. +123
    -0
      opencompass/datasets/SciReasoner/unconditional_molecule_generation/UMG.py
  58. +1
    -0
      opencompass/datasets/SciReasoner/unconditional_molecule_generation/__init__.py
  59. +195
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/UPG.py
  60. +1
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/__init__.py
  61. +71
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/main.py
  62. +38
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__init__.py
  63. +104
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__main__.py
  64. +152
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/confidence.py
  65. +118
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/config.py
  66. +371
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/decode.py
  67. +384
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/embedders.py
  68. +174
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/geoformer.py
  69. +248
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/model.py
  70. +636
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/modules.py
  71. +233
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/omegaplm.py
  72. +424
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/pipeline.py
  73. +51
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/__init__.py
  74. +38
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/__init__.py
  75. +936
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/aaframe.py
  76. +149
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/functions.py
  77. +686
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/residue_constants.py
  78. +147
    -0
      opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/torch_utils.py
  79. +1
    -0
      opencompass/datasets/__init__.py
  80. +23
    -13
      opencompass/models/openai_api.py
  81. +7
    -4
      opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
  82. +9
    -3
      opencompass/tasks/openicl_eval.py
  83. +42
    -0
      opencompass/utils/datasets_info.py
  84. +7
    -0
      requirements/extra.txt

+ 65
- 0
.github/scripts/eval_regression_api_rollout.py View File

@@ -0,0 +1,65 @@
from mmengine.config import read_base

from opencompass.models import OpenAISDKRollout, TurboMindModelwithChatTemplate
from opencompass.utils.text_postprocessors import extract_non_reasoning_content

with read_base():
from opencompass.configs.datasets.aime2025.aime2025_cascade_eval_gen_5e9f4f import \
aime2025_datasets # noqa: F401, E501

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

num_repeat = 2
for item in datasets:
item['abbr'] += f'_rollout_rep{num_repeat}'
item['n'] = num_repeat

api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
],
reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
)

models = [
dict(abbr='lmdeploy-api-test-rollout',
type=OpenAISDKRollout,
key='EMPTY',
openai_api_base='http://localhost:23333/v1',
path='Qwen/Qwen3-8B',
tokenizer_path='Qwen/Qwen3-8B',
rpm_verbose=True,
meta_template=api_meta_template,
query_per_second=128,
max_out_len=1024,
max_seq_len=4096,
temperature=0.01,
batch_size=128,
retry=20,
logprobs=True,
top_logprobs=5,
extra_body=dict(top_k=20),
openai_extra_kwargs=dict(top_p=0.95, ),
pred_postprocessor=dict(type=extract_non_reasoning_content)),
]

obj_judge_model = dict(type=TurboMindModelwithChatTemplate,
abbr='qwen-3-8b-fullbench',
path='Qwen/Qwen3-8B',
engine_config=dict(session_len=46000,
max_batch_size=1,
tp=1),
gen_config=dict(do_sample=False, enable_thinking=False),
max_seq_len=46000,
max_out_len=46000,
batch_size=1,
run_cfg=dict(num_gpus=1))

for d in datasets:
if 'judge_cfg' in d['eval_cfg']['evaluator']:
d['eval_cfg']['evaluator']['judge_cfg'] = obj_judge_model
if 'llm_evaluator' in d['eval_cfg']['evaluator'] and 'judge_cfg' in d[
'eval_cfg']['evaluator']['llm_evaluator']:
d['eval_cfg']['evaluator']['llm_evaluator'][
'judge_cfg'] = obj_judge_model

+ 1
- 1
.github/scripts/eval_regression_base_models.py View File

@@ -38,7 +38,7 @@ models = [
model_kwargs=dict(tensor_parallel_size=1, gpu_memory_utilization=0.6),
max_seq_len=8192,
max_out_len=2048,
batch_size=16,
batch_size=1,
generation_kwargs=dict(temperature=0),
run_cfg=dict(num_gpus=1),
),


+ 2
- 2
.github/scripts/eval_regression_chat_models.py View File

@@ -41,10 +41,10 @@ Qwen3_0_6B_FP8_vllm = dict(
abbr='qwen3-0_6b-fp8-vllm',
path='Qwen/Qwen3-0.6B-FP8',
model_kwargs=dict(tensor_parallel_size=1),
generation_kwargs=dict(do_sample=False), # greedy
generation_kwargs=dict(temperature=0), # greedy
max_seq_len=32768,
max_out_len=16384,
batch_size=16,
batch_size=1,
run_cfg=dict(num_gpus=1),
)



+ 91
- 0
.github/scripts/eval_regression_chat_obj_fullbench_v8.py View File

@@ -0,0 +1,91 @@
from mmengine.config import read_base

from opencompass.models import (HuggingFacewithChatTemplate,
TurboMindModelwithChatTemplate)
from opencompass.utils.text_postprocessors import extract_non_reasoning_content

with read_base():
# Datasets
from opencompass.configs.datasets.atlas.atlas_val_gen_b2d1b6 import \
atlas_datasets # noqa: F401, E501
from opencompass.configs.datasets.biodata.biodata_task_gen import \
biodata_task_datasets # noqa: F401, E501
from opencompass.configs.datasets.CMPhysBench.cmphysbench_gen import \
cmphysbench_datasets # noqa: F401, E501
from opencompass.configs.datasets.MolInstructions_chem.mol_instructions_chem_gen import \
mol_gen_selfies_datasets # noqa: F401, E501
from opencompass.configs.datasets.openswi.openswi_gen import \
openswi_datasets # noqa: F401, E501

from ...rjob import eval, infer # noqa: F401, E501

datasets = [
*atlas_datasets, *biodata_task_datasets, *cmphysbench_datasets,
*mol_gen_selfies_datasets, *openswi_datasets
]

for d in datasets:
if 'n' in d:
d['n'] = 1
if 'reader_cfg' in d:
d['reader_cfg']['test_range'] = '[0:16]'
else:
d['test_range'] = '[0:16]'
if 'eval_cfg' in d and 'dataset_cfg' in d['eval_cfg'][
'evaluator'] and 'reader_cfg' in d['eval_cfg']['evaluator'][
'dataset_cfg']:
d['eval_cfg']['evaluator']['dataset_cfg']['reader_cfg'][
'test_range'] = '[0:16]'
if 'eval_cfg' in d and 'llm_evaluator' in d['eval_cfg'][
'evaluator'] and 'dataset_cfg' in d['eval_cfg']['evaluator'][
'llm_evaluator']:
d['eval_cfg']['evaluator']['llm_evaluator']['dataset_cfg'][
'reader_cfg']['test_range'] = '[0:16]'

hf_model = dict(type=HuggingFacewithChatTemplate,
abbr='qwen-3-8b-hf-fullbench',
path='Qwen/Qwen3-8B',
max_out_len=8192,
batch_size=8,
run_cfg=dict(num_gpus=1),
pred_postprocessor=dict(type=extract_non_reasoning_content))

tm_model = dict(type=TurboMindModelwithChatTemplate,
abbr='qwen-3-8b-fullbench',
path='Qwen/Qwen3-8B',
engine_config=dict(session_len=32768, max_batch_size=1, tp=1),
gen_config=dict(do_sample=False, enable_thinking=True),
max_seq_len=32768,
max_out_len=32768,
batch_size=1,
run_cfg=dict(num_gpus=1),
pred_postprocessor=dict(type=extract_non_reasoning_content))

models = [hf_model, tm_model]

models = sorted(models, key=lambda x: x['run_cfg']['num_gpus'])

obj_judge_model = dict(
type=TurboMindModelwithChatTemplate,
abbr='qwen-3-8b-fullbench',
path='Qwen/Qwen3-8B',
engine_config=dict(session_len=46000, max_batch_size=1, tp=1),
gen_config=dict(do_sample=False, enable_thinking=True),
max_seq_len=46000,
max_out_len=46000,
batch_size=1,
run_cfg=dict(num_gpus=1),
pred_postprocessor=dict(type=extract_non_reasoning_content))

for d in datasets:
if 'eval_cfg' in d and 'evaluator' in d['eval_cfg']:
if 'atlas' in d['abbr'] and 'judge_cfg' in d['eval_cfg']['evaluator']:
d['eval_cfg']['evaluator']['judge_cfg'] = dict(
judgers=[obj_judge_model])
elif 'judge_cfg' in d['eval_cfg']['evaluator']:
d['eval_cfg']['evaluator']['judge_cfg'] = obj_judge_model
elif 'llm_evaluator' in d['eval_cfg'][
'evaluator'] and 'judge_cfg' in d[ # noqa
'eval_cfg']['evaluator']['llm_evaluator']: # noqa
d['eval_cfg']['evaluator']['llm_evaluator'][
'judge_cfg'] = obj_judge_model

+ 14
- 0
.github/scripts/oc_score_assert.py View File

@@ -139,6 +139,18 @@ class TestChatFullbench:
result_score = result_scores.get(model).get(dataset)
assert_score(model, result_score, base_score, dataset)

@pytest.mark.chat_obj_fullbench_v8
@pytest.mark.parametrize(
'model, dataset',
[(p1, p2) for p1 in ['qwen-3-8b-hf-fullbench', 'qwen-3-8b-fullbench']
for p2 in dataset_list('qwen-3-8b-hf-fullbench', 'objective_v8')])
def test_chat_obj_v8(self, baseline_scores_fullbench, result_scores, model,
dataset):
base_score = baseline_scores_fullbench.get(model).get(
'objective_v8').get(dataset, None)
result_score = result_scores.get(model).get(dataset, None)
assert_score(model, result_score, base_score, dataset)

@pytest.mark.chat_obj_fullbench_other
@pytest.mark.parametrize(
'model, dataset',
@@ -325,6 +337,8 @@ class TestCmdCase:


def assert_score(model_type, score, baseline, dataset: str = ''):
if baseline is None:
return
if score is None or score == '-':
assert False, 'value is none'



+ 76
- 0
.github/scripts/oc_score_baseline_fullbench.yaml View File

@@ -314,6 +314,44 @@ qwen-3-8b-hf-fullbench:
srbench_SymbolicMatch: 0.18
supergpqa_Electronic_Science_and_Technology_accuracy: 68.75
lcb_code_generation_v6_pass@1: 37.5
objective_v8:
atlas-val_accuracy: 0
DNA-cpd-sample_1k_mcc: -46.67
DNA-emp-sample_1k_mcc: -65.47
DNA-pd-sample_1k_mcc: -87.83
DNA-tf-h-sample_1k_mcc: -9.76
DNA-tf-m-sample_1k_mcc: -7.27
Multi_sequence-antibody_antigen-sample_1k_mcc: -15.29
Multi_sequence-promoter_enhancer_interaction-sample_1k_mcc: -38.3
Multi_sequence-rna_protein_interaction-sample_1k_mcc: -48.8
DNA-enhancer_activity-sample_1k_pcc: 7.27
RNA-CRISPROnTarget-sample_1k_spearman: null
Protein-Fluorescence-sample_1k_spearman: 20.00
Protein-Stability-sample_1k_spearman: 28.53
Protein-Thermostability-sample_1k_spearman: -18.45
RNA-Isoform-sample_1k_r^2: 6.64
RNA-MeanRibosomeLoading-sample_1k_r^2: 20.35
RNA-ProgrammableRNASwitches-sample_1k_r^2: 9.23
RNA-Modification-sample_1k_auc: 50.15
Protein-Solubility-sample_1k_acc: 31.25
RNA-NoncodingRNAFamily-sample_1k_acc: 18.75
Protein-FunctionEC-sample_1k_score: 3
Multi_sequence-sirnaEfficiency-sample_1k_Mixed: 33.13
CMPhysBench-fix_prompt_score: 30.16
RP-selfies_score: 3.2
RP-selfies_valid_score: 50
MG-selfies_score: 0.8
MG-selfies_valid_score: 25
FS-selfies_score: 30.13
FS-selfies_valid_score: 87.5
RS-selfies_score: 12.09
RS-selfies_valid_score: 50
PP-selfies_score: 3.64
MC-selfies_score: 0.14
OpenSWI-shallow-1k_score: 394.16
OpenSWI-shallow-1k_valid: 18.75
OpenSWI-deep-1k_score: 1069.07
OpenSWI-deep-1k_valid: 6.25
chat_subjective:
alignment_bench_v1_1_总分: 0.46
arenahard_score: 100
@@ -664,6 +702,44 @@ qwen-3-8b-fullbench:
srbench_SymbolicMatch: 0.14
supergpqa_Electronic_Science_and_Technology_accuracy: 56.25
lcb_code_generation_v6_pass@1: 50
objective_v8:
atlas-val_accuracy: 12.5
DNA-cpd-sample_1k_mcc: -28.94
DNA-emp-sample_1k_mcc: -42.86
DNA-pd-sample_1k_mcc: -51.64
DNA-tf-h-sample_1k_mcc: -46.67
DNA-tf-m-sample_1k_mcc: 5.1
Multi_sequence-antibody_antigen-sample_1k_mcc: -41.82
Multi_sequence-promoter_enhancer_interaction-sample_1k_mcc: -15.29
Multi_sequence-rna_protein_interaction-sample_1k_mcc: -33.33
DNA-enhancer_activity-sample_1k_pcc: 28.51
RNA-CRISPROnTarget-sample_1k_spearman: -22.28
Protein-Fluorescence-sample_1k_spearman: null
Protein-Stability-sample_1k_spearman: 11.09
Protein-Thermostability-sample_1k_spearman: 31.25
RNA-Isoform-sample_1k_r^2: 2.72
RNA-MeanRibosomeLoading-sample_1k_r^2: 30.35
RNA-ProgrammableRNASwitches-sample_1k_r^2:
RNA-Modification-sample_1k_auc: 46.15
Protein-Solubility-sample_1k_acc: 37.5
RNA-NoncodingRNAFamily-sample_1k_acc: 6.25
Protein-FunctionEC-sample_1k_score: 3
Multi_sequence-sirnaEfficiency-sample_1k_Mixed: 34.17
CMPhysBench-fix_prompt_score: 59.22
RP-selfies_score: 2.86
RP-selfies_valid_score: 68.75
MG-selfies_score: 1.31
MG-selfies_valid_score: 25
FS-selfies_score: 27.68
FS-selfies_valid_score: 93.75
RS-selfies_score: 16.55
RS-selfies_valid_score: 75
PP-selfies_score: 4.43
MC-selfies_score: 0.13
OpenSWI-shallow-1k_score: 284.37
OpenSWI-shallow-1k_valid: 6.25
OpenSWI-deep-1k_score: 861.96
OpenSWI-deep-1k_valid: 0
chat_longtext:
babilong_qa1_256k_score: 0.00
LongBench_2wikimqa_score: 5.43


+ 3
- 3
.github/scripts/oc_score_baseline_testrange.yaml View File

@@ -12,10 +12,10 @@ base:
race-high_accuracy: 62.50
winogrande_accuracy: 71.88
qwen3-8b-base-vllm:
gsm8k_accuracy: 50.00
GPQA_diamond_accuracy: 21.88
gsm8k_accuracy: 53.12
GPQA_diamond_accuracy: 18.75
race-high_accuracy: 59.38
winogrande_accuracy: 68.75
winogrande_accuracy: 71.88
qwen3-8b-base-hf:
gsm8k_accuracy: 50.00
GPQA_diamond_accuracy: 18.75


+ 24
- 22
.github/workflows/daily-run-test.yml View File

@@ -22,7 +22,7 @@ on:
required: true
description: 'regression functions'
type: string
default: "['chat_models','base_models','chat_obj_fullbench_v5', 'chat_obj_fullbench_v6', 'chat_obj_fullbench_v7', 'chat_obj_fullbench_other','chat_sub_fullbench','base_fullbench','base_longtext_fullbench','chat_longtext_fullbench']"
default: "['chat_models','base_models','chat_obj_fullbench_v5', 'chat_obj_fullbench_v6', 'chat_obj_fullbench_v7', 'chat_obj_fullbench_v8', 'chat_obj_fullbench_other','chat_sub_fullbench','base_fullbench','base_longtext_fullbench','chat_longtext_fullbench']"
baseline_result:
required: true
description: 'baseline result'
@@ -88,6 +88,8 @@ jobs:
runs-on: yidian_cu12_daily
timeout-minutes: 180 #3hours
steps:
- name: Clean workdir
run: sudo git clean -ffdx
- name: Clone repository
uses: actions/checkout@v5
with:
@@ -142,6 +144,8 @@ jobs:
runs-on: yidian_cu12_daily
timeout-minutes: 240 #4hours
steps:
- name: Clean workdir
run: sudo git clean -ffdx
- name: Clone repository
uses: actions/checkout@v5
with:
@@ -163,7 +167,7 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse --dump-res-length
- name: Run test - other
if: matrix.regression_func == 'chat_obj_fullbench_other'
env:
@@ -173,7 +177,7 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse --dump-res-length
- name: Run test - other
if: matrix.regression_func == 'chat_sub_fullbench'
env:
@@ -183,7 +187,7 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse
opencompass .github/scripts/eval_regression_${{matrix.regression_func}}.py --work-dir ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} --reuse --dump-res-length
- name: Assert result
run: |
. ${{env.CONDA_PATH}}/bin/activate
@@ -192,10 +196,6 @@ jobs:
rm regression_result_daily -f && ln -s ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}}/*/summary regression_result_daily
python -m pytest -m ${{matrix.regression_func}} -s -v --color=yes .github/scripts/oc_score_assert.py || true
python .github/scripts/compare_results.py compare_results ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} ${{env.REPORT_ROOT}}/${{env.BASELINE_DIR}}/${{matrix.regression_func}}
- name: Change code permission
if: always()
run: |
sudo chmod -R 777 .


daily_run_cmd:
@@ -204,6 +204,8 @@ jobs:
runs-on: yidian_cu12_daily
timeout-minutes: 240 #4hours
steps:
- name: Clean workdir
run: sudo git clean -ffdx
- name: Clone repository
uses: actions/checkout@v5
with:
@@ -220,7 +222,7 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
rjob submit --name=cmd-${{ env.JOB_NAME }} --charged-group=opencompass_gpu --private-machine=group --group=opencompass_gpu --gpu=2 --cpu=32 --memory=32568 --private-machine=group --image=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.6.0-45d96d5f-0607 --env=COMPASS_DATA_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/compass_data_cache --env=TIKTOKEN_CACHE_DIR=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/share_tiktoken --env=HF_ENDPOINT=https://hf-mirror.com --env=HF_DATASETS_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/qa-llm-cicd/hf_cache --env=HF_HUB_CACHE=/mnt/shared-storage-user/large-model-center-share-weights/hf_hub --env=CUDA_MODULE_LOADING=EAGER --env=HF_DATASETS_OFFLINE=1 --env=TRANSFORMERS_OFFLINE=1 --env=HF_EVALUATE_OFFLINE=1 --env=HF_HUB_OFFLINE=1 --env=VLLM_USE_MODELSCOPE=false --env=VLLM_WORKER_MULTIPROC_METHOD=spawn --mount=gpfs://gpfs1/opencompass-shared:/mnt/shared-storage-user/opencompass-shared --mount=gpfs://gpfs1/auto-eval-pipeline:/mnt/shared-storage-user/auto-eval-pipeline --mount=gpfs://gpfs1/large-model-center-share-weights:/mnt/shared-storage-user/large-model-center-share-weights --host-network=True -- bash -exc '/mnt/shared-storage-user/opencompass-shared/qa-llm-cicd/daily_cmd_test.sh ${{env.REPORT_ROOT}}/${{ github.run_id }}'
rjob submit --metadata-name=cmd-${{ env.JOB_NAME }} --charged-group=opencompass_gpu --private-machine=group --group=opencompass_gpu --gpu=2 --cpu=32 --memory=32568 --private-machine=group --image=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.6.0-45d96d5f-0607 --env=COMPASS_DATA_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/compass_data_cache --env=TIKTOKEN_CACHE_DIR=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/share_tiktoken --env=HF_ENDPOINT=https://hf-mirror.com --env=HF_DATASETS_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/qa-llm-cicd/hf_cache --env=HF_HUB_CACHE=/mnt/shared-storage-user/large-model-center-share-weights/hf_hub --env=CUDA_MODULE_LOADING=EAGER --env=HF_DATASETS_OFFLINE=1 --env=TRANSFORMERS_OFFLINE=1 --env=HF_EVALUATE_OFFLINE=1 --env=HF_HUB_OFFLINE=1 --env=VLLM_USE_MODELSCOPE=false --env=VLLM_WORKER_MULTIPROC_METHOD=spawn --mount=gpfs://gpfs1/opencompass-shared:/mnt/shared-storage-user/opencompass-shared --mount=gpfs://gpfs1/auto-eval-pipeline:/mnt/shared-storage-user/auto-eval-pipeline --mount=gpfs://gpfs1/large-model-center-share-weights:/mnt/shared-storage-user/large-model-center-share-weights --host-network=True -- bash -exc '/mnt/shared-storage-user/opencompass-shared/qa-llm-cicd/daily_cmd_test.sh ${{env.REPORT_ROOT}}/${{ github.run_id }}'

for i in {1..300}; do
current_status=$(rjob get cmd-${{ env.JOB_NAME }} | grep -oP 'rjob [^:]+: \K[^ ]+')
@@ -246,17 +248,19 @@ jobs:
python -m pytest -m case4 -s -v --color=yes .github/scripts/oc_score_assert.py
rm regression_result_daily -f && ln -s ${{env.REPORT_ROOT}}/${{ github.run_id }}/cmd5/*/summary regression_result_daily
python -m pytest -m case5 -s -v --color=yes .github/scripts/oc_score_assert.py
- name: Change code permission
if: always()
run: |
sudo chmod -R 777 .

daily_run_api:
if: ${{!cancelled() && contains(needs.prepare_env.result, 'success') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_type), 'api'))}}
needs: prepare_env
runs-on: yidian_cu12_daily
strategy:
fail-fast: false
matrix:
regression_func: ["api","api-rollout"]
timeout-minutes: 240 #4hours
steps:
- name: Clean workdir
run: sudo git clean -ffdx
- name: Clone repository
uses: actions/checkout@v5
with:
@@ -273,13 +277,13 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
rjob submit --name=api-${{ env.JOB_NAME }} --charged-group=opencompass_gpu --private-machine=group --group=opencompass_gpu --gpu=2 --cpu=32 --memory=32568 --private-machine=group --image=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.6.0-45d96d5f-0607 --env=COMPASS_DATA_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/compass_data_cache --env=TIKTOKEN_CACHE_DIR=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/share_tiktoken --env=HF_ENDPOINT=https://hf-mirror.com --env=HF_DATASETS_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/qa-llm-cicd/hf_cache --env=HF_HUB_CACHE=/mnt/shared-storage-user/large-model-center-share-weights/hf_hub --env=CUDA_MODULE_LOADING=EAGER --env=HF_DATASETS_OFFLINE=1 --env=TRANSFORMERS_OFFLINE=1 --env=HF_EVALUATE_OFFLINE=1 --env=HF_HUB_OFFLINE=1 --env=VLLM_USE_MODELSCOPE=false --env=VLLM_WORKER_MULTIPROC_METHOD=spawn --mount=gpfs://gpfs1/opencompass-shared:/mnt/shared-storage-user/opencompass-shared --mount=gpfs://gpfs1/auto-eval-pipeline:/mnt/shared-storage-user/auto-eval-pipeline --mount=gpfs://gpfs1/large-model-center-share-weights:/mnt/shared-storage-user/large-model-center-share-weights --host-network=True -- bash -exc '/mnt/shared-storage-user/opencompass-shared/qa-llm-cicd/daily_api_test.sh ${{env.REPORT_ROOT}}/${{ github.run_id }} ${{env.WORK_PATH}}'
rjob submit --metadata-name=${{matrix.regression_func}}-${{ env.JOB_NAME }} --charged-group=opencompass_gpu --private-machine=group --group=opencompass_gpu --gpu=1 --cpu=32 --memory=32568 --private-machine=group --image=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.6.0-45d96d5f-0607 --env=COMPASS_DATA_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/compass_data_cache --env=TIKTOKEN_CACHE_DIR=/mnt/shared-storage-user/auto-eval-pipeline/opencompass/llmeval/share_tiktoken --env=HF_ENDPOINT=https://hf-mirror.com --env=HF_DATASETS_CACHE=/mnt/shared-storage-user/auto-eval-pipeline/qa-llm-cicd/hf_cache --env=HF_HUB_CACHE=/mnt/shared-storage-user/large-model-center-share-weights/hf_hub --env=CUDA_MODULE_LOADING=EAGER --env=HF_DATASETS_OFFLINE=1 --env=TRANSFORMERS_OFFLINE=1 --env=HF_EVALUATE_OFFLINE=1 --env=HF_HUB_OFFLINE=1 --env=VLLM_USE_MODELSCOPE=false --env=VLLM_WORKER_MULTIPROC_METHOD=spawn --mount=gpfs://gpfs1/opencompass-shared:/mnt/shared-storage-user/opencompass-shared --mount=gpfs://gpfs1/auto-eval-pipeline:/mnt/shared-storage-user/auto-eval-pipeline --mount=gpfs://gpfs1/large-model-center-share-weights:/mnt/shared-storage-user/large-model-center-share-weights --host-network=True -- bash -exc '/mnt/shared-storage-user/opencompass-shared/qa-llm-cicd/daily_${{matrix.regression_func}}_test.sh ${{env.REPORT_ROOT}}/${{ github.run_id }} ${{env.WORK_PATH}}'

for i in {1..300}; do
current_status=$(rjob get api-${{ env.JOB_NAME }} | grep -oP 'rjob [^:]+: \K[^ ]+')
current_status=$(rjob get ${{matrix.regression_func}}-${{ env.JOB_NAME }} | grep -oP 'rjob [^:]+: \K[^ ]+')
if [[ $current_status == "Succeeded" || $current_status == "Failed" || $current_status == "Stopped" ]]; then
echo "Current status: $current_status, stop checking"
rjob logs job api-${{ env.JOB_NAME }}
rjob logs job ${{matrix.regression_func}}-${{ env.JOB_NAME }}
break
fi
sleep 6
@@ -289,9 +293,7 @@ jobs:
. ${{env.CONDA_PATH}}/bin/activate
conda activate ${{env.CONDA_ENV}}
conda info --envs
rm regression_result_daily -f && ln -s ${{env.REPORT_ROOT}}/${{ github.run_id }}/api/*/summary regression_result_daily
python -m pytest -m api -s -v --color=yes .github/scripts/oc_score_assert.py
- name: Change code permission
if: always()
run: |
sudo chmod -R 777 .
rm regression_result_daily -f && ln -s ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}}/*/summary regression_result_daily
python -m pytest -m ${{matrix.regression_func}} -s -v --color=yes .github/scripts/oc_score_assert.py
python .github/scripts/compare_results.py compare_results ${{env.REPORT_ROOT}}/${{ github.run_id }}/${{matrix.regression_func}} ${{env.REPORT_ROOT}}/${{env.BASELINE_DIR}}/${{matrix.regression_func}}


+ 157
- 0
examples/eval_scireasoner.py View File

@@ -0,0 +1,157 @@
from mmengine.config import read_base

with read_base():
# scireasoner
from opencompass.configs.datasets.SciReasoner.scireasoner_gen import scireasoner_datasets_full, scireasoner_datasets_mini

# # full set summarizer
# summarizer = dict(
# dataset_abbrs=[
# ['SciReasoner-bio_instruction-antibody_antigen', 'MCC'], ['SciReasoner-bio_instruction-rna_protein_interaction', 'MCC'], ['SciReasoner-bio_instruction-emp', 'MCC'],
# ['SciReasoner-bio_instruction-enhancer_activity', 'PCC'], ['SciReasoner-bio_instruction-tf_m', 'MCC'], ['SciReasoner-bio_instruction-Isoform', 'R2'], ['SciReasoner-bio_instruction-Modification', 'AUC'],
# ['SciReasoner-bio_instruction-MeanRibosomeLoading', 'R2'], ['SciReasoner-bio_instruction-ProgrammableRNASwitches', 'R2'], ['SciReasoner-bio_instruction-CRISPROnTarget', 'spearman'],
# ['SciReasoner-bio_instruction-promoter_enhancer_interaction', 'MCC'], ['SciReasoner-bio_instruction-sirnaEfficiency', 'mixed_score'], ['SciReasoner-bio_instruction-cpd', 'MCC'],
# ['SciReasoner-bio_instruction-pd', 'MCC'], ['SciReasoner-bio_instruction-tf_h', 'MCC'],
# ['SciReasoner-Gue_cpd-prom_core_all', 'matthews_correlation_all'],
# ['SciReasoner-Gue_cpd-prom_core_notata', 'matthews_correlation_all'],
# ['SciReasoner-Gue_cpd-prom_core_tata', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_all', 'matthews_correlation_all'],
# ['SciReasoner-Gue_pd-prom_300_notata', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_tata', 'matthews_correlation_all'],
# ['SciReasoner-Gue_tf-h-0', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-1', 'matthews_correlation_all'],
# ['SciReasoner-Gue_tf-h-2', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-3', 'matthews_correlation_all'],
# ['SciReasoner-Gue_tf-h-4', 'matthews_correlation_all'], ['SciReasoner-smol_forward_synthesis', 'top1_exact_match'],
# ['SciReasoner-smol_retrosynthesis', 'top1_exact_match'], ['SciReasoner-smol_molecule_captioning', 'meteor_score'],
# ['SciReasoner-smol_molecule_generation', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-i2f', 'top1_ele_match'],
# ['SciReasoner-smol_name_conversion-i2s', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-s2f', 'top1_ele_match'],
# ['SciReasoner-smol_name_conversion-s2i', 'top1_split_match'], ['SciReasoner-smol_property_prediction-esol', 'RMSE'],
# ['SciReasoner-smol_property_prediction-lipo', 'RMSE'], ['SciReasoner-smol_property_prediction-bbbp', 'accuracy'],
# ['SciReasoner-smol_property_prediction-clintox', 'accuracy'], ['SciReasoner-smol_property_prediction-hiv', 'accuracy'],
# ['SciReasoner-smol_property_prediction-sider', 'accuracy'], ['SciReasoner-retrosynthesis_USPTO_50K', 'Top-1 Accuracy'],
# ['SciReasoner-LLM4Mat_MP_IsStable', 'AUC'], ['SciReasoner-LLM4Mat_MP_IsGapDirect', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_IsDirect', 'AUC'],
# ['SciReasoner-LLM4Mat_SNUMAT_IsDirect_HSE', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_SOC', 'AUC'], ['SciReasoner-LLM4Mat_MP_FEPA', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_MP_Bandgap', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_EPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Efermi', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_MP_Density', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_DensityAtomic', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Volume', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_OPT', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_TotEn', 'MAD/MAE'],
#         ['SciReasoner-LLM4Mat_JARVISDFT_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_MBJ', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Kv', 'MAD/MAE'],
#         ['SciReasoner-LLM4Mat_JARVISDFT_Gv', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_SLME', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Spillage', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_Epsx_OPT', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Dielectric_DFPT', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_dij', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_eij', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_MaxEFG', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_ExfEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_AvgMe', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_nSeebeck', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_nPF', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_pSeebeck', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISDFT_pPF', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA_Optical', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE_Optical', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_GNoME_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_DEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Bandgap', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_GNoME_TotEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Volume', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Density', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_hMOF_MaxCO2', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_MinCO2', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_LCD', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_hMOF_PLD', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_VoidFraction', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_SA_m2g', 'MAD/MAE'],
#         ['SciReasoner-LLM4Mat_hMOF_SA_m2cm3', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_EPA', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_Cantor_HEA_Ehull', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_VPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_TotEn', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_QMOF_Bandgap', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_LCD', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_PLD', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISQETB_EPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_IndirBandgap', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_JARVISQETB_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_TotEn', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OQMD_Bandgap', 'MAD/MAE'],
# ['SciReasoner-LLM4Mat_OQMD_FEPA', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OMDB_Bandgap', 'MAD/MAE'],
# ['SciReasoner-composition_to_material_generation', 'smact_validity_ratio_in_all_%'],
# ['SciReasoner-bulk_modulus_to_material_generation', 'smact_validity_ratio_in_all_%'],
# ['SciReasoner-mol_instruction_chemical_disease_interaction_extraction', 'f1'],
# ['SciReasoner-mol_instruction_chemical_entity_recognition', 'f1'],
# ['SciReasoner-mol_instruction_chemical_protein_interaction_extraction', 'f1'],
# ['SciReasoner-mol_instruction_multi_choice_question', 'accuracy'],
# ['SciReasoner-mol_instruction_open_question', 'bert_score'],
# ['SciReasoner-mol_instruction_true_or_false_question', 'accuracy'],
# ['SciReasoner-mol_instruction_property_prediction_str', 'mae'],
# ['SciReasoner-mol_instruction_description_guided_molecule_design', 'exact_match_score'],
# ['SciReasoner-mol_instruction_forward_reaction_prediction', 'exact_match_score'],
# ['SciReasoner-mol_instruction_retrosynthesis', 'exact_match_score'],
# ['SciReasoner-mol_instruction_reagent_prediction', 'exact_match_score'],
# ['SciReasoner-mol_instruction_molecular_description_generation', 'rougeL'],
# ['SciReasoner-mol_instruction_catalytic_activity', 'rougeL'], ['SciReasoner-mol_instruction_domain_motif', 'rougeL'],
# ['SciReasoner-mol_instruction_general_function', 'rougeL'], ['SciReasoner-mol_instruction_protein_function', 'rougeL'],
# ['SciReasoner-mol_instruction_protein_design', 'Max SW score'], ['SciReasoner-Opi_EC_number_CLEAN_EC_number_new', 'Accuracy'],
# ['SciReasoner-Opi_EC_number_CLEAN_EC_number_price', 'Accuracy'], ['SciReasoner-Opi_Fold_type_fold_type', 'Accuracy'],
# ['SciReasoner-Opi_Function_CASPSimilarSeq_function', 'ROUGE-L'], ['SciReasoner-Opi_Function_IDFilterSeq_function', 'ROUGE-L'],
# ['SciReasoner-Opi_Function_UniProtSeq_function', 'ROUGE-L'], ['SciReasoner-Opi_gName2Cancer_gene_name_to_cancer', 'F1 Score'],
#         ['SciReasoner-Opi_GO_CASPSimilarSeq_go', 'F1 Score'], ['SciReasoner-Opi_GO_IDFilterSeq_go', 'F1 Score'], ['SciReasoner-Opi_GO_UniProtSeq_go', 'F1 Score'],
# ['SciReasoner-Opi_gSymbol2Cancer_gene_symbol_to_cancer', 'F1 Score'], ['SciReasoner-Opi_gSymbol2Tissue_gene_symbol_to_tissue', 'F1 Score'],
# ['SciReasoner-Opi_Keywords_CASPSimilarSeq_keywords', 'F1 Score'], ['SciReasoner-Opi_Keywords_IDFilterSeq_keywords', 'F1 Score'],
# ['SciReasoner-Opi_Keywords_UniProtSeq_keywords', 'F1 Score'], ['SciReasoner-Opi_Subcellular_localization_subcell_loc', 'Accuracy'],
# ['SciReasoner-PEER_solubility', 'accuracy'], ['SciReasoner-PEER_stability', 'accuracy'], ['SciReasoner-PEER_human_ppi', 'accuracy'], ['SciReasoner-PEER_yeast_ppi', 'accuracy'],
# ['SciReasoner-unconditional_material_generation', 'smact_validity_ratio_in_all'],
# ['SciReasoner-unconditional_RNA_generation', 'average_mfe'], ['SciReasoner-unconditional_protein_generation', 'valid_rate'],
# ['SciReasoner-unconditional_molecule_generation', 'validity']
# ]
# )

# Mini-set summarizer: one [dataset_abbr, metric] pair per SciReasoner
# sub-task. Each abbreviation must match exactly the `abbr` field declared
# by the corresponding SciReasoner dataset config (every mini variant
# carries a '-mini' suffix), otherwise the summary row comes out empty.
summarizer = dict(
    dataset_abbrs=[
        ['SciReasoner-bio_instruction-antibody_antigen-mini', 'MCC'], ['SciReasoner-bio_instruction-rna_protein_interaction-mini', 'MCC'], ['SciReasoner-bio_instruction-emp-mini', 'MCC'],
        ['SciReasoner-bio_instruction-enhancer_activity-mini', 'PCC'], ['SciReasoner-bio_instruction-tf_m-mini', 'MCC'], ['SciReasoner-bio_instruction-Isoform-mini', 'R2'], ['SciReasoner-bio_instruction-Modification-mini', 'AUC'],
        ['SciReasoner-bio_instruction-MeanRibosomeLoading-mini', 'R2'], ['SciReasoner-bio_instruction-ProgrammableRNASwitches-mini', 'R2'], ['SciReasoner-bio_instruction-CRISPROnTarget-mini', 'spearman'],
        ['SciReasoner-bio_instruction-promoter_enhancer_interaction-mini', 'MCC'], ['SciReasoner-bio_instruction-sirnaEfficiency-mini', 'mixed_score'], ['SciReasoner-bio_instruction-cpd-mini', 'MCC'],
        ['SciReasoner-bio_instruction-pd-mini', 'MCC'], ['SciReasoner-bio_instruction-tf_h-mini', 'MCC'],
        ['SciReasoner-Gue_cpd-prom_core_all-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_cpd-prom_core_notata-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_cpd-prom_core_tata-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_all-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_pd-prom_300_notata-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_pd-prom_300_tata-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_tf-h-0-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-1-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_tf-h-2-mini', 'matthews_correlation_all'], ['SciReasoner-Gue_tf-h-3-mini', 'matthews_correlation_all'],
        ['SciReasoner-Gue_tf-h-4-mini', 'matthews_correlation_all'], ['SciReasoner-smol_forward_synthesis-mini', 'top1_exact_match'],
        ['SciReasoner-smol_retrosynthesis-mini', 'top1_exact_match'], ['SciReasoner-smol_molecule_captioning-mini', 'meteor_score'],
        ['SciReasoner-smol_molecule_generation-mini', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-i2f-mini', 'top1_ele_match'],
        ['SciReasoner-smol_name_conversion-i2s-mini', 'top1_exact_match'], ['SciReasoner-smol_name_conversion-s2f-mini', 'top1_ele_match'],
        ['SciReasoner-smol_name_conversion-s2i-mini', 'top1_split_match'], ['SciReasoner-smol_property_prediction-esol-mini', 'RMSE'],
        ['SciReasoner-smol_property_prediction-lipo-mini', 'RMSE'], ['SciReasoner-smol_property_prediction-bbbp-mini', 'accuracy'],
        ['SciReasoner-smol_property_prediction-clintox-mini', 'accuracy'], ['SciReasoner-smol_property_prediction-hiv-mini', 'accuracy'],
        ['SciReasoner-smol_property_prediction-sider-mini', 'accuracy'], ['SciReasoner-retrosynthesis_USPTO_50K-mini', 'Top-1 Accuracy'],
        ['SciReasoner-LLM4Mat_MP_IsStable-mini', 'AUC'], ['SciReasoner-LLM4Mat_MP_IsGapDirect-mini', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_IsDirect-mini', 'AUC'],
        ['SciReasoner-LLM4Mat_SNUMAT_IsDirect_HSE-mini', 'AUC'], ['SciReasoner-LLM4Mat_SNUMAT_SOC-mini', 'AUC'], ['SciReasoner-LLM4Mat_MP_FEPA-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_MP_Bandgap-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_EPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Efermi-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_MP_Density-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_DensityAtomic-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_MP_Volume-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_OPT-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_TotEn-mini', 'MAD/MAE'],
        # NOTE: the two entries below previously read 'JSciReasoner-LLM4Mat_ARVISDFT_*'
        # (the leading 'J' of JARVISDFT had migrated to the front) and could never
        # match the dataset abbrs generated in LLM4Mat_gen.py.
        ['SciReasoner-LLM4Mat_JARVISDFT_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Bandgap_MBJ-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Kv-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_Gv-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_SLME-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Spillage-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_Epsx_OPT-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Dielectric_DFPT-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_dij-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_Max_Piezo_eij-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_MaxEFG-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_ExfEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_AvgMe-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_nSeebeck-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_nPF-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISDFT_pSeebeck-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISDFT_pPF-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_GGA_Optical-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_SNUMAT_Bandgap_HSE_Optical-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_GNoME_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_DEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Bandgap-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_GNoME_TotEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Volume-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_GNoME_Density-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_hMOF_MaxCO2-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_MinCO2-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_LCD-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_hMOF_PLD-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_VoidFraction-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_hMOF_SA_m2g-mini', 'MAD/MAE'],
        # Fixed: Cantor_HEA_FEPA previously carried a doubled
        # 'SciReasoner-LLM4Mat_' prefix.
        ['SciReasoner-LLM4Mat_hMOF_SA_m2cm3-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_EPA-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_Cantor_HEA_Ehull-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_Cantor_HEA_VPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_TotEn-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_QMOF_Bandgap-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_LCD-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_QMOF_PLD-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISQETB_EPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_IndirBandgap-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_JARVISQETB_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_JARVISQETB_TotEn-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OQMD_Bandgap-mini', 'MAD/MAE'],
        ['SciReasoner-LLM4Mat_OQMD_FEPA-mini', 'MAD/MAE'], ['SciReasoner-LLM4Mat_OMDB_Bandgap-mini', 'MAD/MAE'],
        ['SciReasoner-composition_to_material_generation-mini', 'smact_validity_ratio_in_all_%'],
        ['SciReasoner-bulk_modulus_to_material_generation-mini', 'smact_validity_ratio_in_all_%'],
        ['SciReasoner-mol_instruction_chemical_disease_interaction_extraction-mini', 'f1'],
        ['SciReasoner-mol_instruction_chemical_entity_recognition-mini', 'f1'],
        ['SciReasoner-mol_instruction_chemical_protein_interaction_extraction-mini', 'f1'],
        ['SciReasoner-mol_instruction_multi_choice_question-mini', 'accuracy'],
        ['SciReasoner-mol_instruction_open_question-mini', 'bert_score'],
        ['SciReasoner-mol_instruction_true_or_false_question-mini', 'accuracy'],
        ['SciReasoner-mol_instruction_property_prediction_str-mini', 'mae'],
        ['SciReasoner-mol_instruction_description_guided_molecule_design-mini', 'exact_match_score'],
        ['SciReasoner-mol_instruction_forward_reaction_prediction-mini', 'exact_match_score'],
        ['SciReasoner-mol_instruction_retrosynthesis-mini', 'exact_match_score'],
        ['SciReasoner-mol_instruction_reagent_prediction-mini', 'exact_match_score'],
        ['SciReasoner-mol_instruction_molecular_description_generation-mini', 'rougeL'],
        ['SciReasoner-mol_instruction_catalytic_activity-mini', 'rougeL'], ['SciReasoner-mol_instruction_domain_motif-mini', 'rougeL'],
        ['SciReasoner-mol_instruction_general_function-mini', 'rougeL'], ['SciReasoner-mol_instruction_protein_function-mini', 'rougeL'],
        ['SciReasoner-mol_instruction_protein_design-mini', 'Max SW score'], ['SciReasoner-Opi_EC_number_CLEAN_EC_number_new-mini', 'Accuracy'],
        ['SciReasoner-Opi_EC_number_CLEAN_EC_number_price-mini', 'Accuracy'], ['SciReasoner-Opi_Fold_type_fold_type-mini', 'Accuracy'],
        ['SciReasoner-Opi_Function_CASPSimilarSeq_function-mini', 'ROUGE-L'], ['SciReasoner-Opi_Function_IDFilterSeq_function-mini', 'ROUGE-L'],
        ['SciReasoner-Opi_Function_UniProtSeq_function-mini', 'ROUGE-L'], ['SciReasoner-Opi_gName2Cancer_gene_name_to_cancer-mini', 'F1 Score'],
        # Fixed: GO_UniProtSeq_go was missing its 'SciReasoner-Opi_' prefix.
        ['SciReasoner-Opi_GO_CASPSimilarSeq_go-mini', 'F1 Score'], ['SciReasoner-Opi_GO_IDFilterSeq_go-mini', 'F1 Score'], ['SciReasoner-Opi_GO_UniProtSeq_go-mini', 'F1 Score'],
        ['SciReasoner-Opi_gSymbol2Cancer_gene_symbol_to_cancer-mini', 'F1 Score'], ['SciReasoner-Opi_gSymbol2Tissue_gene_symbol_to_tissue-mini', 'F1 Score'],
        ['SciReasoner-Opi_Keywords_CASPSimilarSeq_keywords-mini', 'F1 Score'], ['SciReasoner-Opi_Keywords_IDFilterSeq_keywords-mini', 'F1 Score'],
        ['SciReasoner-Opi_Keywords_UniProtSeq_keywords-mini', 'F1 Score'], ['SciReasoner-Opi_Subcellular_localization_subcell_loc-mini', 'Accuracy'],
        ['SciReasoner-PEER_solubility-mini', 'accuracy'], ['SciReasoner-PEER_stability-mini', 'accuracy'], ['SciReasoner-PEER_human_ppi-mini', 'accuracy'], ['SciReasoner-PEER_yeast_ppi-mini', 'accuracy'],
        ['SciReasoner-unconditional_material_generation-mini', 'smact_validity_ratio_in_all'],
        ['SciReasoner-unconditional_RNA_generation-mini', 'average_mfe'], ['SciReasoner-unconditional_protein_generation-mini', 'valid_rate'],
        ['SciReasoner-unconditional_molecule_generation-mini', 'validity']
    ]
)

+ 77
- 0
opencompass/configs/datasets/SciReasoner/GUE_gen.py View File

@@ -0,0 +1,77 @@
from opencompass.datasets import (
GUE_Dataset,
GUE_Evaluator,
GUE_postprocessor
)
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import ZeroRetriever

# Sub-task identifiers of the SciReasoner GUE split: cpd/pd promoter tasks
# and five 'tf-h-*' splits (presumably human TF-binding folds — confirm
# against the dataset card).
GUE_sub_tasks = [
    'cpd-prom_core_all',
    'cpd-prom_core_notata',
    'cpd-prom_core_tata',
    'pd-prom_300_all',
    'pd-prom_300_notata',
    'pd-prom_300_tata',
    'tf-h-0',
    'tf-h-1',
    'tf-h-2',
    'tf-h-3',
    'tf-h-4',
]

# Each record ships the fully rendered prompt in 'input'; 'output' is the label.
GUE_reader_cfg = dict(input_columns=['input'], output_column='output')

# The inference and evaluation configs are identical for every sub-task, so
# they are built once and shared by all dataset entries instead of being
# re-created (identically) on every loop iteration.
GUE_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[dict(role='HUMAN', prompt='{input}')]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

GUE_eval_cfg = dict(
    evaluator=dict(type=GUE_Evaluator),
    pred_role='BOT',
    pred_postprocessor=dict(type=GUE_postprocessor),
    dataset_postprocessor=dict(type=GUE_postprocessor),
)

GUE_datasets = []
mini_GUE_datasets = []

for name in GUE_sub_tasks:
    # Full-split entry.
    GUE_datasets.append(
        dict(
            abbr=f'SciReasoner-Gue_{name}',
            type=GUE_Dataset,
            path='opencompass/SciReasoner-GUE',
            task=name,
            reader_cfg=GUE_reader_cfg,
            infer_cfg=GUE_infer_cfg,
            eval_cfg=GUE_eval_cfg,
        )
    )
    # 'mini' entry: same task with mini_set=True to evaluate a subsample.
    mini_GUE_datasets.append(
        dict(
            abbr=f'SciReasoner-Gue_{name}-mini',
            type=GUE_Dataset,
            path='opencompass/SciReasoner-GUE',
            task=name,
            mini_set=True,
            reader_cfg=GUE_reader_cfg,
            infer_cfg=GUE_infer_cfg,
            eval_cfg=GUE_eval_cfg,
        )
    )

+ 290
- 0
opencompass/configs/datasets/SciReasoner/LLM4Mat_gen.py View File

@@ -0,0 +1,290 @@
from opencompass.datasets import (
LLM4MatDataset,
LLM4Mat_Evaluator,
LLM4Mat_postprocessor
)
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import ZeroRetriever

# Registry of LLM4Mat property-prediction sub-tasks, grouped by source
# dataset folder. Every folder exposes its test split at
# '<folder>/test/data.json' and its dev split at '<folder>/dev/data.json',
# so only the folder name and the target property need to be listed per task.
_LLM4MAT_TASKS_BY_FOLDER = {
    'mp': {
        'MP_FEPA': 'formation_energy_per_atom',
        'MP_Bandgap': 'band_gap',
        'MP_EPA': 'GGA-PBE-based_energy_per_atom',
        'MP_Ehull': 'energy_above_hull',
        'MP_Efermi': 'efermi',
        'MP_Density': 'density',
        'MP_DensityAtomic': 'density_atomic',
        'MP_Volume': 'volume',
        'MP_IsStable': 'is_stable',
        'MP_IsGapDirect': 'is_gap_direct',
    },
    'jarvis_dft': {
        'JARVISDFT_FEPA': 'formation_energy_peratom',
        'JARVISDFT_Bandgap_OPT': 'optb88vdw_bandgap',
        'JARVISDFT_TotEn': 'optb88vdw_total_energy',
        'JARVISDFT_Ehull': 'ehull',
        'JARVISDFT_Bandgap_MBJ': 'mbj_bandgap',
        'JARVISDFT_Kv': 'bulk_modulus_kv',
        'JARVISDFT_Gv': 'shear_modulus_gv',
        'JARVISDFT_SLME': 'slme',
        'JARVISDFT_Spillage': 'spillage',
        'JARVISDFT_Epsx_OPT': 'mepsx',
        'JARVISDFT_Dielectric_DFPT': 'dfpt_piezo_max_dielectric',
        'JARVISDFT_Max_Piezo_dij': 'dfpt_piezo_max_dij',
        'JARVISDFT_Max_Piezo_eij': 'dfpt_piezo_max_eij',
        'JARVISDFT_MaxEFG': 'max_efg',
        'JARVISDFT_ExfEn': 'exfoliation_energy',
        'JARVISDFT_AvgMe': 'avg_elec_mass',
        'JARVISDFT_nSeebeck': 'n-Seebeck',
        'JARVISDFT_nPF': 'n-powerfact',
        'JARVISDFT_pSeebeck': 'p-Seebeck',
        'JARVISDFT_pPF': 'p-powerfact',
    },
    'snumat': {
        'SNUMAT_Bandgap_GGA': 'Band_gap_GGA',
        'SNUMAT_Bandgap_HSE': 'Band_gap_HSE',
        'SNUMAT_Bandgap_GGA_Optical': 'Band_gap_GGA_optical',
        'SNUMAT_Bandgap_HSE_Optical': 'Band_gap_HSE_optical',
        'SNUMAT_IsDirect': 'Direct_or_indirect',
        'SNUMAT_IsDirect_HSE': 'Direct_or_indirect_HSE',
        'SNUMAT_SOC': 'SOC',
    },
    'gnome': {
        'GNoME_FEPA': 'Formation_Energy_Per_Atom',
        'GNoME_DEPA': 'Decomposition_Energy_Per_Atom',
        'GNoME_Bandgap': 'Bandgap',
        'GNoME_TotEn': 'Corrected_Energy',
        'GNoME_Volume': 'Volume',
        'GNoME_Density': 'Density',
    },
    'hmof': {
        'hMOF_MaxCO2': 'max_co2_adsp',
        'hMOF_MinCO2': 'min_co2_adsp',
        'hMOF_LCD': 'lcd',
        'hMOF_PLD': 'pld',
        'hMOF_VoidFraction': 'void_fraction',
        'hMOF_SA_m2g': 'surface_area_m2g',
        'hMOF_SA_m2cm3': 'surface_area_m2cm3',
    },
    'cantor_hea': {
        'Cantor_HEA_FEPA': 'Ef_per_atom',
        'Cantor_HEA_EPA': 'e_per_atom',
        'Cantor_HEA_Ehull': 'e_above_hull',
        'Cantor_HEA_VPA': 'volume_per_atom',
    },
    'qmof': {
        'QMOF_TotEn': 'energy_total',
        'QMOF_Bandgap': 'bandgap',
        'QMOF_LCD': 'lcd',
        'QMOF_PLD': 'pld',
    },
    'jarvis_qetb': {
        'JARVISQETB_EPA': 'TB-based_energy_per_atom',
        'JARVISQETB_IndirBandgap': 'indir_gap',
        'JARVISQETB_FEPA': 'f_enp',
        'JARVISQETB_TotEn': 'final_energy',
    },
    'oqmd': {
        'OQMD_Bandgap': 'bandgap',
        'OQMD_FEPA': 'e_form',
    },
    'omdb': {
        'OMDB_Bandgap': 'bandgap',
    },
}

# Flattened task table: task name -> {'property', 'test_path', 'train_path'}.
# Insertion order (and therefore dataset order) follows the grouping above.
LLM4Mat_sub_tasks = {
    task: {
        'property': prop,
        'test_path': f'{folder}/test/data.json',
        'train_path': f'{folder}/dev/data.json',
    }
    for folder, tasks in _LLM4MAT_TASKS_BY_FOLDER.items()
    for task, prop in tasks.items()
}

# Properties with a closed label set; every other property in
# LLM4Mat_sub_tasks is a numeric target.
non_numeric_props_options = {
    'Direct_or_indirect': ['Direct', 'Indirect'],
    'Direct_or_indirect_HSE': ['Direct', 'Indirect'],
    'SOC': [True, False],
    'is_gap_direct': [True, False],
    'is_stable': [True, False],
}

# Shared reader: every LLM4Mat record exposes the rendered prompt in 'input'
# and the reference value in 'output'.
LLM4Mat_reader_cfg = dict(input_columns=['input'], output_column='output')

# The prompt is the same for classification and regression tasks (the task
# wording is already baked into 'input'), so the inference config is
# loop-invariant and built once. The original code branched on
# non_numeric_props_options but both branches produced this identical
# template (and the computed options string was never used), so the dead
# branch has been removed.
LLM4Mat_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[dict(role='HUMAN', prompt='{input}')]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

LLM4Mat_datasets = []
mini_LLM4Mat_datasets = []

for name, info in LLM4Mat_sub_tasks.items():
    prop = info['property']

    # Post-processing is property-specific: the postprocessor needs the
    # target property name to parse predictions and references.
    LLM4Mat_eval_cfg = dict(
        evaluator=dict(type=LLM4Mat_Evaluator),
        pred_role='BOT',
        pred_postprocessor=dict(type=LLM4Mat_postprocessor, property=prop),
        dataset_postprocessor=dict(type=LLM4Mat_postprocessor, property=prop),
    )

    # Full-split entry.
    LLM4Mat_datasets.append(
        dict(
            abbr=f'SciReasoner-LLM4Mat_{name}',
            type=LLM4MatDataset,
            path='opencompass/SciReasoner-LLM4Mat',
            train_path=info['train_path'],
            test_path=info['test_path'],
            property=prop,
            reader_cfg=LLM4Mat_reader_cfg,
            infer_cfg=LLM4Mat_infer_cfg,
            eval_cfg=LLM4Mat_eval_cfg,
        )
    )
    # 'mini' entry: same task with mini_set=True to evaluate a subsample.
    mini_LLM4Mat_datasets.append(
        dict(
            abbr=f'SciReasoner-LLM4Mat_{name}-mini',
            type=LLM4MatDataset,
            path='opencompass/SciReasoner-LLM4Mat',
            train_path=info['train_path'],
            test_path=info['test_path'],
            property=prop,
            mini_set=True,
            reader_cfg=LLM4Mat_reader_cfg,
            infer_cfg=LLM4Mat_infer_cfg,
            eval_cfg=LLM4Mat_eval_cfg,
        )
    )

+ 49
- 0
opencompass/configs/datasets/SciReasoner/UMG.py View File

@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import UMG_Dataset, UMG_Evaluator

# Prompt stub; not referenced anywhere in this config (the instruction is
# delivered through the dataset's 'input' column).
INFER_TEMPLATE = '''Generate a molecule with <SMILES>'''

# Each record carries the instruction in 'input' and the reference in 'output'.
reader_cfg = dict(input_columns=['input'], output_column='output')

# Zero-shot generation: forward the instruction verbatim as one user turn.
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[dict(role='HUMAN', prompt='{input}')]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

# Scoring is handled entirely by the evaluator class.
eval_cfg = dict(evaluator=dict(type=UMG_Evaluator))

# Full run: no sample cap (max_cut may be set to limit sample count).
UMG_Datasets = [
    dict(
        abbr='SciReasoner-unconditional_molecule_generation',
        type=UMG_Dataset,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    ),
]
# Mini run: capped at 150 samples via max_cut.
mini_UMG_Datasets = [
    dict(
        abbr='SciReasoner-unconditional_molecule_generation-mini',
        type=UMG_Dataset,
        max_cut=150,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    ),
]

+ 67
- 0
opencompass/configs/datasets/SciReasoner/UPG.py View File

@@ -0,0 +1,67 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import UPGDataset, UPG_postprocess, UPG_Evaluator

# Each record carries the instruction in 'input' and the reference
# sequence in 'output'.
reader_cfg = dict(input_columns=['input'], output_column='output')

infer_cfg = dict(
    # Main template: an in-context-example slot (</E>) followed by the
    # instruction rendered as a single user turn.
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            begin=['</E>'],
            round=[dict(role='HUMAN', prompt='{input}')],
        ),
        ice_token='</E>',
    ),
    # Template used to render any retrieved examples into the </E> slot.
    ice_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                # dict(role='HUMAN', prompt='{input} /no_think'), # for Qwen3
                dict(role='HUMAN', prompt='{input}'),
                dict(role='BOT', prompt='{output}'),
            ],
        ),
    ),
    # Zero-shot retrieval for our trained models; switch to FixKRetriever
    # with fix_id_list=[0, 1, 2, 3, 4] for a 5-shot run.
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

eval_cfg = dict(
    evaluator=dict(type=UPG_Evaluator),
    pred_postprocessor=dict(type=UPG_postprocess),
    dataset_postprocessor=dict(type=UPG_postprocess),
)

# Full run: no sample cap (max_cut may be set to limit sample count).
UPG_datasets = [
    dict(
        abbr='SciReasoner-unconditional_protein_generation',
        type=UPGDataset,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    ),
]
# Mini run: capped at 150 samples via max_cut.
mini_UPG_datasets = [
    dict(
        abbr='SciReasoner-unconditional_protein_generation-mini',
        type=UPGDataset,
        max_cut=150,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    ),
]

+ 75
- 0
opencompass/configs/datasets/SciReasoner/bio_instruction_gen.py View File

@@ -0,0 +1,75 @@
# Config for SciReasoner bio-instruction tasks (DNA/RNA/protein instruction data).
from mmengine.config import read_base
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import Bioinstruction_Dataset, bio_instruction_Evaluator

reader_cfg = dict(
    input_columns=['input'],
    output_column='output'
)

# Model tag handed to bio_instruction_Evaluator — presumably used to name
# per-model result artifacts; verify against the evaluator implementation.
MODEL_NAME = r'model'

bio_instruction_datasets = []
mini_bio_instruction_datasets = []

# Sub-task names; one dataset entry is built per task below.
path = ['antibody_antigen', 'rna_protein_interaction', 'emp', 'enhancer_activity', 'tf_m', 'Isoform', 'Modification',
        'MeanRibosomeLoading', 'ProgrammableRNASwitches',
        'CRISPROnTarget', 'promoter_enhancer_interaction', 'sirnaEfficiency', 'cpd', 'pd', 'tf_h']
# NOTE(review): these protein tasks are defined but never iterated below —
# confirm whether they should be appended to `path`.
extra_path = ['Fluorescence', 'FunctionEC', 'Stability', 'Solubility', 'Thermostability']  # the protein ones

for task in path:
    # Zero-shot: the dataset's 'input' column is the entire user prompt.
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role='HUMAN', prompt='{input}'),
                ]),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    # Full-set evaluation config.
    eval_cfg = dict(
        evaluator=dict(type=bio_instruction_Evaluator,
                       path='opencompass/SciReasoner-bio_instruction',
                       task=task,
                       model_name=MODEL_NAME),
        pred_role='BOT',
        # num_gpus=1
    )
    # Mini-set evaluation config (evaluator also needs mini_set=True so it
    # scores against the subsampled references).
    eval_mini_cfg = dict(
        evaluator=dict(type=bio_instruction_Evaluator,
                       path='opencompass/SciReasoner-bio_instruction',
                       task=task,
                       mini_set=True,
                       model_name=MODEL_NAME),
        pred_role='BOT',
        # num_gpus=1
    )

    bio_instruction_datasets.append(
        dict(
            type=Bioinstruction_Dataset,
            abbr=f'SciReasoner-bio_instruction-{task}',
            path='opencompass/SciReasoner-bio_instruction',
            task=task,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg,
        )
    )
    mini_bio_instruction_datasets.append(
        dict(
            type=Bioinstruction_Dataset,
            abbr=f'SciReasoner-bio_instruction-{task}-mini',
            path='opencompass/SciReasoner-bio_instruction',
            task=task,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_mini_cfg,
        )
    )

+ 52
- 0
opencompass/configs/datasets/SciReasoner/bulk_modulus_material_gen.py View File

@@ -0,0 +1,52 @@
# Config for SciReasoner bulk-modulus-conditioned material generation.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.datasets import Bulk_modulus_material_Dataset, material_Evaluator, material_postprocessor

modulus_material_reader = dict(input_columns=['input'], output_column='output')

# Zero-shot generation from the raw 'input' prompt.
modulus_material_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{input}',
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

modulus_material_eval_cfg = dict(
    evaluator=dict(
        type=material_Evaluator,
        # Evaluator re-reads reference data from the dataset root.
        data_path='opencompass/SciReasoner-Conditional_generation',
    ),
    pred_postprocessor=dict(type=material_postprocessor),
)

# Full evaluation set.
modulus_material_datasets = [
    dict(
        abbr='SciReasoner-bulk_modulus_to_material_generation',
        type=Bulk_modulus_material_Dataset,
        path='opencompass/SciReasoner-Conditional_generation',
        reader_cfg=modulus_material_reader,
        infer_cfg=modulus_material_infer_cfg,
        eval_cfg=modulus_material_eval_cfg,
    )
]
# Mini evaluation set (dataset subsamples when mini_set=True).
mini_modulus_material_datasets = [
    dict(
        abbr='SciReasoner-bulk_modulus_to_material_generation-mini',
        type=Bulk_modulus_material_Dataset,
        path='opencompass/SciReasoner-Conditional_generation',
        mini_set=True,
        reader_cfg=modulus_material_reader,
        infer_cfg=modulus_material_infer_cfg,
        eval_cfg=modulus_material_eval_cfg,
    )
]

+ 64
- 0
opencompass/configs/datasets/SciReasoner/composition_material_gen.py View File

@@ -0,0 +1,64 @@
# Config for SciReasoner composition-conditioned material generation.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.datasets import Composition_material_Dataset, composition_Evaluator, material_postprocessor

# NOTE(review): defined but never passed to the inferencer below, so these
# sampling settings are currently inert — confirm whether they should be
# wired into GenInferencer (generation_kwargs=...).
generation_kwargs = dict(
    do_sample=True,
    # top_p=0.8,
    # min_p=0,
    temperature=0.40,
    # top_k=20,
    # repetition_penalty=1,
    # "<|endoftext|>": 151643 "<|im_end|>": 151645
    # eos_token_id=[151643, 151645],
)

composition_material_reader = dict(input_columns=['input'], output_column='output')

# Zero-shot generation from the raw 'input' prompt.
composition_material_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{input}',
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

composition_material_eval_cfg = dict(
    evaluator=dict(
        type=composition_Evaluator,
        data_path='opencompass/SciReasoner-Conditional_generation',
    ),
    pred_postprocessor=dict(type=material_postprocessor),
)


# Full evaluation set.
composition_material_datasets = [
    dict(
        abbr='SciReasoner-composition_to_material_generation',
        type=Composition_material_Dataset,
        path='opencompass/SciReasoner-Conditional_generation',
        reader_cfg=composition_material_reader,
        infer_cfg=composition_material_infer_cfg,
        eval_cfg=composition_material_eval_cfg,
    )
]
# Mini evaluation set (dataset subsamples when mini_set=True).
mini_composition_material_datasets = [
    dict(
        abbr='SciReasoner-composition_to_material_generation-mini',
        type=Composition_material_Dataset,
        path='opencompass/SciReasoner-Conditional_generation',
        mini_set=True,
        reader_cfg=composition_material_reader,
        infer_cfg=composition_material_infer_cfg,
        eval_cfg=composition_material_eval_cfg,
    )
]

+ 110
- 0
opencompass/configs/datasets/SciReasoner/mol_biotext_gen.py View File

@@ -0,0 +1,110 @@
# base config for LLM4Chem — Mol-Instructions biomedical-text tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import Mol_Instructions_postprocess_BioText, Mol_Instructions_Evaluator_BioText, \
    Mol_Instructions_Dataset_BioText

TASKS = [
    'chemical_disease_interaction_extraction',
    'chemical_entity_recognition',
    'chemical_protein_interaction_extraction',
    'multi_choice_question',
    'open_question',
    'true_or_false_question'
]

reader_cfg = dict(input_columns=['input'], output_column='output')

mol_biotext_datasets = []
mini_mol_biotext_datasets = []

# Default: plain zero-shot prompting.
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt='{input}')
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

# true_or_false_question: constrain the answer format for the postprocessor.
infer_cfg_true_or_false = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt="{input}Your answer should start with 'Yes' or 'Maybe' or 'No'")
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

# chemical_entity_recognition: one-shot prompting with a fixed example.
infer_cfg_CER = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            begin=[
                # Optional but recommended: A system prompt for better instructions.
                dict(role='SYSTEM', fallback_role='HUMAN',
                     prompt='There is a single choice question about chemistry. Answer the question directly.'),
                # The placeholder is the ice_token string itself, used as a direct list element.
                '</E>',
            ],
            round=[
                dict(role='HUMAN', prompt='Query: {input}'),
                dict(role='BOT', prompt=''),
            ]
        ),
        ice_token='</E>',
    ),
    ice_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='Query: {input}'),
                dict(role='BOT', prompt='{output}'),
            ]
        )
    ),
    retriever=dict(
        type=FixKRetriever,
        fix_id_list=[0, ],  # Use the first example only
    ),
    inferencer=dict(type=GenInferencer)
)

for task in TASKS:
    # Postprocessors are task-specific; the same one normalizes predictions
    # and references.
    eval_cfg = dict(
        evaluator=dict(type=Mol_Instructions_Evaluator_BioText, task=task),
        pred_postprocessor=dict(type=Mol_Instructions_postprocess_BioText, task=task),
        dataset_postprocessor=dict(type=Mol_Instructions_postprocess_BioText, task=task),
    )

    # Pick the task-appropriate inference config (see dicts above).
    if task == 'true_or_false_question':
        apply_infer_cfg = infer_cfg_true_or_false
    elif task == 'chemical_entity_recognition':
        apply_infer_cfg = infer_cfg_CER
    else:
        apply_infer_cfg = infer_cfg

    mol_biotext_datasets.append(
        dict(
            abbr=f'SciReasoner-mol_instruction_{task}',
            type=Mol_Instructions_Dataset_BioText,
            path='opencompass/SciReasoner-Mol_Instructions',
            task=task,
            reader_cfg=reader_cfg,
            infer_cfg=apply_infer_cfg,
            eval_cfg=eval_cfg)
    )
    mini_mol_biotext_datasets.append(
        dict(
            abbr=f'SciReasoner-mol_instruction_{task}-mini',
            type=Mol_Instructions_Dataset_BioText,
            path='opencompass/SciReasoner-Mol_Instructions',
            task=task,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=apply_infer_cfg,
            eval_cfg=eval_cfg)
    )

+ 63
- 0
opencompass/configs/datasets/SciReasoner/mol_molecule_gen.py View File

@@ -0,0 +1,63 @@
# base config for LLM4Chem — Mol-Instructions molecule-oriented tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import Mol_Instructions_postprocess_Mol, Mol_Instructions_Evaluator_Mol, Mol_Instructions_Dataset

TASKS = [
    'property_prediction_str',
    'description_guided_molecule_design',
    'forward_reaction_prediction',
    'retrosynthesis',
    'reagent_prediction',
    'molecular_description_generation'
]

reader_cfg = dict(input_columns=['input'], output_column='output')

mol_mol_datasets = []
mini_mol_mol_datasets = []

# Zero-shot prompting: the dataset's 'input' column is the whole user turn.
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt='{input}'),
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

# One full-set and one mini-set entry per task; the mini entry reuses the
# same reader/infer/eval configs and only adds mini_set=True.
for task in TASKS:
    eval_cfg = dict(
        evaluator=dict(type=Mol_Instructions_Evaluator_Mol, task=task),
        pred_postprocessor=dict(type=Mol_Instructions_postprocess_Mol, task=task),
        dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Mol, task=task),
    )

    common = dict(
        type=Mol_Instructions_Dataset,
        path='opencompass/SciReasoner-Mol_Instructions',
        task=task,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    mol_mol_datasets.append(
        dict(abbr=f'SciReasoner-mol_instruction_{task}', **common))
    mini_mol_mol_datasets.append(
        dict(abbr=f'SciReasoner-mol_instruction_{task}-mini',
             mini_set=True,
             **common))




+ 83
- 0
opencompass/configs/datasets/SciReasoner/mol_protein_gen.py View File

@@ -0,0 +1,83 @@
# base config for LLM4Chem — Mol-Instructions protein-oriented tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import (Mol_Instructions_postprocess_Protein, Mol_Instructions_Evaluator_Protein,
                                  Mol_Instructions_Dataset, Mol_Instructions_postprocess_Protein_Design,
                                  Mol_Instructions_Evaluator_Protein_Design, Mol_Instructions_Dataset_Protein_Design)

# Text-understanding protein tasks; protein_design is handled separately below
# because it uses its own dataset/evaluator/postprocessor classes.
TASKS = [
    'catalytic_activity',
    'domain_motif',
    'general_function',
    'protein_function',
]

reader_cfg = dict(input_columns=['input'], output_column='output')

# Zero-shot prompting: the dataset's 'input' column is the whole user turn.
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt='{input}'),
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

eval_cfg = dict(
    evaluator=dict(type=Mol_Instructions_Evaluator_Protein),
    pred_postprocessor=dict(type=Mol_Instructions_postprocess_Protein),
    dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Protein),
)

eval_cfg_protein_design = dict(
    evaluator=dict(type=Mol_Instructions_Evaluator_Protein_Design),
    pred_postprocessor=dict(type=Mol_Instructions_postprocess_Protein_Design),
    dataset_postprocessor=dict(type=Mol_Instructions_postprocess_Protein_Design),
)

mol_protein_datasets = []
mini_mol_protein_datasets = []

for task in TASKS:
    mol_protein_datasets.append(
        dict(
            abbr=f'SciReasoner-mol_instruction_{task}',
            type=Mol_Instructions_Dataset,
            path='opencompass/SciReasoner-Mol_Instructions',
            task=task,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg))
    mini_mol_protein_datasets.append(
        dict(
            abbr=f'SciReasoner-mol_instruction_{task}-mini',
            type=Mol_Instructions_Dataset,
            path='opencompass/SciReasoner-Mol_Instructions',
            task=task,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg))

# protein_design uses the dedicated *_Protein_Design classes.
task = 'protein_design'
mol_protein_datasets.append(
    dict(
        abbr='SciReasoner-mol_instruction_protein_design',
        type=Mol_Instructions_Dataset_Protein_Design,
        path='opencompass/SciReasoner-Mol_Instructions',
        task=task,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg_protein_design))
mini_mol_protein_datasets.append(
    dict(
        abbr='SciReasoner-mol_instruction_protein_design-mini',
        type=Mol_Instructions_Dataset_Protein_Design,
        path='opencompass/SciReasoner-Mol_Instructions',
        task=task,
        mini_set=True,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg_protein_design))

+ 83
- 0
opencompass/configs/datasets/SciReasoner/opi_gen.py View File

@@ -0,0 +1,83 @@
# base config for opi — Open Protein Instructions (OPI) sub-tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import opi_postprocess, Opi_Evaluator, OpiDataset


all_datasets = []
mini_all_datasets = []

# Root directory where the datasets are located
# NOTE(review): placeholder path, not referenced below (OpiDataset resolves
# its own `path` argument) — confirm it can be removed.
root_dir = '/path/OPI_test'

subtask_dirs = [
    'EC_number_CLEAN_EC_number_new',
    'EC_number_CLEAN_EC_number_price',
    'Fold_type_fold_type',
    'Function_CASPSimilarSeq_function',
    'Function_IDFilterSeq_function',
    'Function_UniProtSeq_function',
    'gName2Cancer_gene_name_to_cancer',
    'GO_CASPSimilarSeq_go',
    'GO_IDFilterSeq_go',
    'GO_UniProtSeq_go',
    'gSymbol2Cancer_gene_symbol_to_cancer',
    'gSymbol2Tissue_gene_symbol_to_tissue',
    'Keywords_CASPSimilarSeq_keywords',
    'Keywords_IDFilterSeq_keywords',
    'Keywords_UniProtSeq_keywords',
    'Subcellular_localization_subcell_loc',
]

for subtask_name in subtask_dirs:
    # Common configs for inference (rebuilt per subtask so every dataset
    # entry owns independent config dicts).

    reader_cfg = dict(input_columns=['input'], output_column='output')

    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                dict(role='HUMAN', prompt='{input}'),
            ]),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(
            type=GenInferencer,
        )
    )

    # First underscore-separated token of the subdir name, e.g. 'EC' for
    # 'EC_number_CLEAN_EC_number_new' or 'Fold' for 'Fold_type_fold_type'.
    # NOTE(review): verify that Opi_Evaluator / opi_postprocess expect this
    # short token rather than a longer prefix such as 'EC_number'.
    task_type = subtask_name.split('_')[0]

    eval_cfg = dict(
        evaluator=dict(type=Opi_Evaluator, task=task_type),
        pred_postprocessor=dict(type=opi_postprocess, task=task_type),
        dataset_postprocessor=dict(type=opi_postprocess, task=task_type),
    )

    # Create the dataset dictionary for the current subtask. The dict
    # literals are fresh objects each iteration, so no defensive copy is
    # needed (the previous `.copy()` calls were redundant).
    all_datasets.append(
        dict(
            abbr=f'SciReasoner-Opi_{subtask_name}',
            type=OpiDataset,
            path='opencompass/SciReasoner-OPI',
            task=subtask_name,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg
        )
    )
    mini_all_datasets.append(
        dict(
            abbr=f'SciReasoner-Opi_{subtask_name}-mini',
            type=OpiDataset,
            path='opencompass/SciReasoner-OPI',
            task=subtask_name,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg
        )
    )

+ 99
- 0
opencompass/configs/datasets/SciReasoner/peer_gen.py View File

@@ -0,0 +1,99 @@
# base config for LLM4Chem — PEER protein-engineering benchmark tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import PEER_postprocess, PEER_Evaluator, PEER_Dataset, PEER_postprocess_float_compare, \
    PEER_postprocess_default

TASKS = [
    'solubility',
    'stability',
    'human_ppi',
    'yeast_ppi',
]

reader_cfg = dict(input_columns=['input'], output_column='output')

# Zero-shot prompting: the dataset's 'input' column is the whole user turn.
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role='HUMAN', prompt='{input}.'),
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(
        type=GenInferencer,
        # max_out_len=2048,
    )
)

# NOTE(review): not referenced by the loop below (which uses eval_llm_cfg /
# eval_stability_cfg); kept so any external importers keep working — confirm
# before removing.
eval_cfg = dict(
    evaluator=dict(type=PEER_Evaluator),
    pred_postprocessor=dict(type=PEER_postprocess),
    dataset_postprocessor=dict(type=PEER_postprocess),
)

# use default postprocess to remain the original output for LLM judgement.
# PEER_postprocess will be used in the evaluation stage to compare the output with the ground truth as a fast comparison.
eval_llm_cfg = dict(
    evaluator=dict(type=PEER_Evaluator,
                   openai_key='EMPTY', gpt_model='gpt-4.1-mini'),
    pred_postprocessor=dict(type=PEER_postprocess_default),
    dataset_postprocessor=dict(type=PEER_postprocess_default),
)

# stability is a regression task, scored by numeric comparison instead of an
# LLM judge.
eval_stability_cfg = dict(
    evaluator=dict(type=PEER_Evaluator, task='stability'),
    pred_postprocessor=dict(type=PEER_postprocess_float_compare, compare_number=1),
    dataset_postprocessor=dict(type=PEER_postprocess_float_compare, compare_number=1),
)

PEER_datasets = []
mini_PEER_datasets = []

# One full-set and one mini-set entry per task. The two branches of the
# previous if/else were identical except for eval_cfg, so select it up front.
for task in TASKS:
    task_eval_cfg = eval_stability_cfg if task == 'stability' else eval_llm_cfg
    PEER_datasets.append(
        dict(
            abbr=f'SciReasoner-PEER_{task}',
            type=PEER_Dataset,
            path='opencompass/SciReasoner-PEER',
            task=task,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=task_eval_cfg))
    mini_PEER_datasets.append(
        dict(
            abbr=f'SciReasoner-PEER_{task}-mini',
            type=PEER_Dataset,
            path='opencompass/SciReasoner-PEER',
            task=task,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=task_eval_cfg))

+ 74
- 0
opencompass/configs/datasets/SciReasoner/retrosynthesis_USPTO_gen.py View File

@@ -0,0 +1,74 @@
# Config for SciReasoner retrosynthesis on USPTO-50K.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import RetrosynthesisEvaluator, Retrosynthesis_postprocess, LLM4ChemDataset

reader_cfg = dict(input_columns=['input'], output_column='output')



infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            begin=[
                # Empty system prompt (kept as a hook for instruction text).
                dict(role='SYSTEM', fallback_role='HUMAN', prompt=''),
                # Placeholder replaced by retrieved in-context examples (ice_token).
                '</E>',
            ],
            round = [
                # dict(role='HUMAN', prompt='Query: {input} /no_think'),  # for Qwen3
                dict(role='HUMAN', prompt='{input}'),
            ]
        ),
        ice_token='</E>',
    ),
    ice_template=dict(
        type=PromptTemplate,
        template = dict(
            round = [
                dict(role='HUMAN', prompt='{input}'),
            ]
        )
    ),
    # retriever: responsible for retrieving and formatting examples using ice_template
    retriever=dict(
        # type=FixKRetriever,
        # fix_id_list=[0, 1, 2, 3, 4],  # Use the first 5 examples
        type=ZeroRetriever,  # For our trained model, use zero-shot
    ),
    inferencer=dict(
        type=GenInferencer,
    ),
)

eval_cfg = dict(
    # beam_size/n_best=1: only the single top prediction is scored.
    evaluator=dict(type=RetrosynthesisEvaluator, beam_size=1, n_best=1),
    pred_postprocessor=dict(type=Retrosynthesis_postprocess),
    dataset_postprocessor=dict(type=Retrosynthesis_postprocess),
)

task = 'retrosynthesis_uspto50k'

# Full evaluation set.
Retrosynthesis_datasets = [
    dict(
        abbr='SciReasoner-retrosynthesis_USPTO_50K',
        type=LLM4ChemDataset,
        path='opencompass/SciReasoner-smol',
        task=task,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg= eval_cfg,
    )
]
# Mini evaluation set (dataset subsamples when mini_set=True).
mini_Retrosynthesis_datasets = [
    dict(
        abbr='SciReasoner-retrosynthesis_USPTO_50K-mini',
        type=LLM4ChemDataset,
        path='opencompass/SciReasoner-smol',
        task=task,
        mini_set=True,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg= eval_cfg,
    )
]

+ 55
- 0
opencompass/configs/datasets/SciReasoner/scireasoner_gen.py View File

@@ -0,0 +1,55 @@
# Aggregator config: collects every SciReasoner sub-benchmark into one full
# evaluation list and one mini (subsampled) list.
from mmengine.config import read_base

with read_base():
    # scireasoner
    from opencompass.configs.datasets.SciReasoner.bio_instruction_gen import bio_instruction_datasets, \
        mini_bio_instruction_datasets
    from opencompass.configs.datasets.SciReasoner.composition_material_gen import \
        composition_material_datasets, mini_composition_material_datasets
    from opencompass.configs.datasets.SciReasoner.GUE_gen import GUE_datasets, mini_GUE_datasets
    from opencompass.configs.datasets.SciReasoner.smol_gen import all_datasets as smol_datasets, \
        mini_all_datasets as mini_smol_datasets
    from opencompass.configs.datasets.SciReasoner.retrosynthesis_USPTO_gen import \
        Retrosynthesis_datasets as Retrosynthesis_uspto50k_datasets, \
        mini_Retrosynthesis_datasets as mini_Retrosynthesis_uspto50k_datasets
    from opencompass.configs.datasets.SciReasoner.LLM4Mat_gen import LLM4Mat_datasets, mini_LLM4Mat_datasets
    from opencompass.configs.datasets.SciReasoner.bulk_modulus_material_gen import modulus_material_datasets, \
        mini_modulus_material_datasets
    from opencompass.configs.datasets.SciReasoner.mol_biotext_gen import mol_biotext_datasets, mini_mol_biotext_datasets
    from opencompass.configs.datasets.SciReasoner.mol_molecule_gen import mol_mol_datasets, mini_mol_mol_datasets
    from opencompass.configs.datasets.SciReasoner.mol_protein_gen import mol_protein_datasets, mini_mol_protein_datasets
    from opencompass.configs.datasets.SciReasoner.opi_gen import all_datasets as opi_datasets, \
        mini_all_datasets as mini_opi_datasets
    from opencompass.configs.datasets.SciReasoner.peer_gen import PEER_datasets, mini_PEER_datasets
    from opencompass.configs.datasets.SciReasoner.unconditional_material_gen import uncond_material_datasets, \
        mini_uncond_material_datasets
    from opencompass.configs.datasets.SciReasoner.unconditional_RNA_gen import uncond_RNA_datasets, \
        mini_uncond_RNA_datasets
    from opencompass.configs.datasets.SciReasoner.UPG import \
        UPG_datasets as uncond_protein_datasets, mini_UPG_datasets as mini_uncond_protein_datasets
    from opencompass.configs.datasets.SciReasoner.UMG import UMG_Datasets, mini_UMG_Datasets

# full eval set
scireasoner_datasets_full = bio_instruction_datasets + composition_material_datasets + GUE_datasets + smol_datasets + \
    Retrosynthesis_uspto50k_datasets + LLM4Mat_datasets + modulus_material_datasets + \
    mol_biotext_datasets + mol_mol_datasets + mol_protein_datasets + opi_datasets + PEER_datasets + \
    uncond_material_datasets + uncond_RNA_datasets + uncond_protein_datasets + UMG_Datasets

# mini eval set
scireasoner_datasets_mini = mini_bio_instruction_datasets + mini_composition_material_datasets + mini_GUE_datasets + mini_smol_datasets + \
    mini_Retrosynthesis_uspto50k_datasets + mini_LLM4Mat_datasets + mini_modulus_material_datasets + \
    mini_mol_biotext_datasets + mini_mol_mol_datasets + mini_mol_protein_datasets + mini_opi_datasets + mini_PEER_datasets + \
    mini_uncond_material_datasets + mini_uncond_RNA_datasets + mini_uncond_protein_datasets + mini_UMG_Datasets

# Scratch selection kept for debugging individual sub-benchmarks:
# scireasoner_mini_datasets =\
#     (
#         # mini_bio_instruction_datasets +
#         # mini_composition_material_datasets +
#         # mini_modulus_material_datasets +
#         # mini_GUE_datasets +
#         # mini_LLM4Mat_datasets +
#         # mini_mol_biotext_datasets + mini_mol_mol_datasets + mini_mol_protein_datasets + mini_opi_datasets + mini_Retrosynthesis_uspto50k_datasets + mini_smol_datasets
#         # mini_UMG_Datasets + mini_uncond_material_datasets
#         mini_uncond_RNA_datasets
#         # mini_uncond_protein_datasets
#     )

+ 120
- 0
opencompass/configs/datasets/SciReasoner/smol_gen.py View File

@@ -0,0 +1,120 @@
# base config for LLM4Chem — SMolInstruct chemistry tasks.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import LLM4Chem_postprocess, LLM4Chem_Evaluator, LLM4ChemDataset

TASKS = (
    'forward_synthesis',
    'retrosynthesis',
    'molecule_captioning',
    'molecule_generation',
    'name_conversion-i2f',
    'name_conversion-i2s',
    'name_conversion-s2f',
    'name_conversion-s2i',
    'property_prediction-esol',
    'property_prediction-lipo',
    'property_prediction-bbbp',
    'property_prediction-clintox',
    'property_prediction-hiv',
    'property_prediction-sider',
)

# NOTE(review): not referenced in this file — confirm whether some config
# still imports it before removing.
TASKS_single = (
    'property_prediction-esol',
    'property_prediction-lipo',
    'property_prediction-bbbp',
    'property_prediction-clintox',
    'property_prediction-hiv',
    'property_prediction-sider',
)

# Answer-span tags per task; (None, None) means free-form text.
# NOTE(review): not referenced in this file either — presumably consumed by
# LLM4Chem_postprocess; verify.
TASK_TAGS = {
    'forward_synthesis': ('<SMILES>', '</SMILES>'),
    'retrosynthesis': ('<SMILES>', '</SMILES>'),
    'molecule_generation': ('<SMILES>', '</SMILES>'),
    'molecule_captioning': (None, None),
    'name_conversion-i2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
    'name_conversion-i2s': ('<SMILES>', '</SMILES>'),
    'name_conversion-s2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
    'name_conversion-s2i': ('<IUPAC>', '</IUPAC>'),
    'property_prediction-esol': ('<NUMBER>', '</NUMBER>'),
    'property_prediction-lipo': ('<NUMBER>', '</NUMBER>'),
    'property_prediction-bbbp': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-clintox': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-hiv': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-sider': ('<BOOLEAN>', '</BOOLEAN>'),
}

all_datasets = []
mini_all_datasets = []

for task in TASKS:

    reader_cfg = dict(input_columns=['input'], output_column='output')

    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                begin=[
                    # Optional but recommended: A system prompt for better instructions.
                    dict(role='SYSTEM', fallback_role='HUMAN', prompt=''),
                    # The placeholder is the ice_token string itself, used as a direct list element.
                    '</E>',
                ],
                round=[
                    # Plain template placeholder; the previous f'{{input}}'
                    # evaluated to this exact literal, so the f-string was
                    # redundant (and inconsistent with the sibling configs).
                    dict(role='HUMAN', prompt='{input}'),
                ]
            ),
            ice_token='</E>',
        ),
        ice_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role='HUMAN', prompt='{input}'),
                    dict(role='BOT', prompt='{output}'),
                ]
            )
        ),
        # retriever is responsible for retrieving examples and using ice_template to format them
        retriever=dict(
            # type=FixKRetriever,
            # fix_id_list=[0, 1, 2, 3, 4],  # Use the first 5 examples
            type=ZeroRetriever,  # For our trained model, use zero-shot
        ),
        inferencer=dict(
            type=GenInferencer,
        ))

    eval_cfg = dict(
        evaluator=dict(type=LLM4Chem_Evaluator, task=task),
        pred_postprocessor=dict(type=LLM4Chem_postprocess, task=task),
        dataset_postprocessor=dict(type=LLM4Chem_postprocess, task=task),
    )

    all_datasets.append(
        dict(
            abbr='SciReasoner-smol_' + task,
            type=LLM4ChemDataset,
            path='opencompass/SciReasoner-smol',
            task=task,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg
        )
    )
    mini_all_datasets.append(
        dict(
            abbr='SciReasoner-smol_' + task + '-mini',
            type=LLM4ChemDataset,
            path='opencompass/SciReasoner-smol',
            task=task,
            mini_set=True,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg
        )
    )

+ 51
- 0
opencompass/configs/datasets/SciReasoner/unconditional_RNA_gen.py View File

@@ -0,0 +1,51 @@
# Config for SciReasoner unconditional RNA sequence generation.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.datasets import Uncond_RNA_Dataset, RNA_Evaluator, RNA_postprocessor

uncond_RNA_reader_cfg = dict(input_columns=['input'], output_column='output')


# Zero-shot generation from the fixed prompt supplied by the dataset.
uncond_RNA_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{input}',
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

uncond_RNA_eval_cfg = dict(
    evaluator=dict(type=RNA_Evaluator),
    pred_postprocessor=dict(type=RNA_postprocessor),
)

# Full set: 5000 generation requests with the same prompt.
uncond_RNA_datasets = [
    dict(
        abbr='SciReasoner-unconditional_RNA_generation',
        type=Uncond_RNA_Dataset,
        num=5000,
        prompt='Please generate a novel RNA sequence. <rna>',
        reader_cfg=uncond_RNA_reader_cfg,
        infer_cfg=uncond_RNA_infer_cfg,
        eval_cfg=uncond_RNA_eval_cfg,
    )
]
# Mini set: 150 generation requests.
mini_uncond_RNA_datasets = [
    dict(
        abbr='SciReasoner-unconditional_RNA_generation-mini',
        type=Uncond_RNA_Dataset,
        num=150,
        prompt='Please generate a novel RNA sequence. <rna>',
        reader_cfg=uncond_RNA_reader_cfg,
        infer_cfg=uncond_RNA_infer_cfg,
        eval_cfg=uncond_RNA_eval_cfg,
    )
]

+ 50
- 0
opencompass/configs/datasets/SciReasoner/unconditional_material_gen.py View File

@@ -0,0 +1,50 @@
# Config for SciReasoner unconditional material generation.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.datasets import Uncond_material_Dataset, uncond_material_Evaluator, material_postprocessor

uncond_material_reader_cfg = dict(input_columns=['input'], output_column='output')

# Zero-shot generation from the fixed prompt supplied by the dataset.
uncond_material_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{input}',
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

uncond_material_eval_cfg = dict(
    evaluator=dict(type=uncond_material_Evaluator),
    pred_postprocessor=dict(type=material_postprocessor),
)

# Full set: 5000 generation requests with the same prompt.
uncond_material_datasets = [
    dict(
        abbr='SciReasoner-unconditional_material_generation',
        type=Uncond_material_Dataset,
        num=5000,
        prompt='Produce a material that has any bulk modulus or composition',
        reader_cfg=uncond_material_reader_cfg,
        infer_cfg=uncond_material_infer_cfg,
        eval_cfg=uncond_material_eval_cfg,
    )
]
# Mini set: 150 generation requests.
mini_uncond_material_datasets = [
    dict(
        abbr='SciReasoner-unconditional_material_generation-mini',
        type=Uncond_material_Dataset,
        num=150,
        prompt='Produce a material that has any bulk modulus or composition',
        reader_cfg=uncond_material_reader_cfg,
        infer_cfg=uncond_material_infer_cfg,
        eval_cfg=uncond_material_eval_cfg,
    )
]

+ 210
- 0
opencompass/datasets/SciReasoner/GUE.py View File

@@ -0,0 +1,210 @@
# flake8: noqa

import json
import os
import re
from typing import Union

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from sklearn.metrics import matthews_corrcoef

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class GUE_Dataset(BaseDataset):
    """Loader for the GUE genome-understanding benchmark splits."""

    @staticmethod
    def load(path, task, mini_set=False):
        """Load one GUE sub-task as a DatasetDict.

        Reads ``{task}/dev/data.json`` (first 5 items kept as the 'train'
        split for in-context examples) and ``{task}/test/data.json`` from
        ``path``, and appends the ground-truth label sentence to every
        ``output`` field.

        Args:
            path: Dataset root, resolved via ``get_data_path``.
            task: Sub-task directory name under ``path``.
            mini_set: If True, deterministically subsample 150 test items.

        Returns:
            DatasetDict with 'train' and 'test' splits.
        """
        path = get_data_path(path)
        train_path = os.path.join(path, f'{task}/dev/data.json')
        test_path = os.path.join(path, f'{task}/test/data.json')

        with open(train_path, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        def augment_output(data):
            # Append the gold label so every reference ends with the
            # "The prediction result is ..." sentence the GUE postprocessor
            # keys on.
            for item in data:
                label = item.get('meta_data', {}).get('label', '')
                item['output'] += f' The prediction result is {label}.'
            return data

        train_data = augment_output(train_data[:5])
        test_data = augment_output(test_data)
        if mini_set:
            import random

            # Use a dedicated, fixed-seed RNG instance: same deterministic
            # sample as seeding the module-level RNG with 1024, but without
            # clobbering global `random` state as a loading side effect.
            test_data = random.Random(1024).sample(test_data, 150)

        dataset = DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data)
        })
        return dataset


def remove_think_tags(text: str) -> str:
    """Strip ``<think>...</think>`` reasoning blocks from model output.

    Text without an opening tag passes through unchanged. An opening tag
    that is never closed yields the empty string (the model was cut off
    mid-reasoning). Otherwise every well-formed block is removed.
    """
    opened = '<think>' in text
    if not opened:
        return text
    closed = '</think>' in text
    if not closed:
        return ''
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)


@TEXT_POSTPROCESSORS.register_module()
def GUE_postprocessor(text: Union[str, None]) -> str:
    """Map a free-form model answer to 'positive', 'negative', or ''.

    Resolution order:
    1. Non-string or empty (after stripping <think> blocks) -> ''.
    2. An explicit "The prediction result is positive/negative" sentence.
    3. Keyword heuristics — negative patterns are checked BEFORE positive
       ones, so any negative cue wins over positive cues in the same text.
    4. No match -> '' (treated as unparseable).
    """
    if not isinstance(text, str):
        return ''

    text = text.strip()
    text = remove_think_tags(text)

    if text == '':
        return ''

    # Preferred: the canonical answer sentence the dataset loader appends
    # to every reference output.
    match = re.search(r'\bThe prediction result is\s+(positive|negative)\b',
                      text, re.IGNORECASE)
    if match:
        return match.group(1).lower()

    # Fallback keyword heuristics for free-form answers.
    positive_patterns = [
        r'\bpositive\b',
        r'\bpositively\b',
        r'\bpresence\b',
        r'\bdetected\b',
        r'\bidentified\b',
        r'\bidentifiable\b',
        r'\bfound\b',
        r'\byes\b',
        r'\blocated\b',
        r'\bdetectable\b',
        r'\bobservable\b',
        r'\bevident\b',
        r'\babsolutely\b',
        r'\baffirmative\b',
        r'\bcan\b',
        r'\baffirm\b',
        r'\bconfirm\b',
        r'\bconfirms\b',
        r'\breveals\b',
        r'\bexistence\b',
        r'\bcertainly\b',
        r'\bconsistent\b',
        r'\brecognizable\b',
        r'\bshows core\b',
        r'\bshows promoter\b',
        r'\bshows characteristic\b',
        r'\bevidenced by\b',
        r'\bseeing characteristic patterns\b',
        r'\bincludes\b',
        r'\bcontains sequences\b',
        r'\bexhibits clear\b',
        r'\bcontains transcription\b',
        r'\bexhibits sequences\b',
        r'\bclearly contains\b',
        r'\brecognized\b',
        r'\bexhibits features\b',
        r'\bcontains regulatory\b',
        r'\bshows clear\b',
        r'\bdisplays\b',
        r'\bdefinitely has\b',
        r'\bexhibits patterns\b',
        r'\bclear evidence\b',
        r'\bcontains a\b',
        r'\byep\b',
        r'\bcontains sites\b',
        r'\bshows sequences\b',
    ]

    negative_patterns = [
        r'\bnegative\b',
        r'\bno\b',
        r'\babsence\b',
        r'\bnot\b',
        r'\bcannot\b',
        r'\bfails\b',
        r'\babsent\b',
        r'\blacks\b',
    ]

    # Negation cues take precedence: "does not contain ..." must not be
    # classified positive just because 'contains' also matches.
    for pattern in negative_patterns:
        if re.search(pattern, text, re.IGNORECASE):
            return 'negative'

    for pattern in positive_patterns:
        if re.search(pattern, text, re.IGNORECASE):
            return 'positive'

    return ''


class GUE_Evaluator(BaseEvaluator):
    """Matthews-correlation evaluator for GUE positive/negative labels."""

    def score(self, predictions, references):
        """Return MCC over all samples and over the parseable subset.

        ``matthews_correlation_all`` maps any prediction other than
        'positive' to the negative class; ``matthews_correlation_filtered``
        drops pairs where either side is not exactly
        'positive'/'negative'. Both are reported as percentages.
        """
        total_count = len(predictions)

        # Some runners return a list of candidates; keep only the first.
        if isinstance(predictions[0], list):
            predictions = [p[0] for p in predictions]

        def to_bin(label):
            # 1 / 0 for the two valid labels, None for anything else.
            label = label.strip().lower()
            if label == 'positive':
                return 1
            if label == 'negative':
                return 0
            return None

        preds_all = [1 if to_bin(p) == 1 else 0 for p in predictions]
        refs_all = [1 if to_bin(r) == 1 else 0 for r in references]
        mcc_all = matthews_corrcoef(refs_all, preds_all)

        paired = [(to_bin(p), to_bin(r))
                  for p, r in zip(predictions, references)]
        valid = [(p, r) for p, r in paired
                 if p is not None and r is not None]
        skipped = len(paired) - len(valid)

        if valid:
            filtered_pred = [p for p, _ in valid]
            filtered_ref = [r for _, r in valid]
            mcc_filtered = matthews_corrcoef(filtered_ref, filtered_pred)
        else:
            mcc_filtered = 0.0

        return {
            'matthews_correlation_all': mcc_all * 100,
            'matthews_correlation_filtered': mcc_filtered * 100,
            'non_pos_neg_count': skipped,
            'total_count': total_count
        }

+ 12
- 0
opencompass/datasets/SciReasoner/LLM4Chem/__init__.py View File

@@ -0,0 +1,12 @@
from .config import TASK_TAGS as LLM4Chem_TASK_TAGS # noqa: F401, F403
from .config import TASKS as LLM4Chem_TASKS # noqa: F401, F403
from .config import \
TASKS_GENERATION_SETTINGS as \
LLM4Chem_TASKS_GENERATION_SETTINGS # noqa: F401, F403
from .evaluator import LLM4Chem_Evaluator # noqa: F401
from .evaluator import LLM4Chem_postprocess # noqa: F401
from .evaluator import LLM4ChemDataset # noqa: F401, F403
from .retrosynthesis_evaluator import \
Retrosynthesis_postprocess # noqa: F401, F403
from .retrosynthesis_evaluator import \
RetrosynthesisEvaluator # noqa: F401, F403

+ 166
- 0
opencompass/datasets/SciReasoner/LLM4Chem/config.py View File

@@ -0,0 +1,166 @@
# All LLM4Chem task identifiers supported by this package.
TASKS = (
    'forward_synthesis',
    'retrosynthesis',
    'molecule_captioning',
    'molecule_generation',
    'name_conversion-i2f',
    'name_conversion-i2s',
    'name_conversion-s2f',
    'name_conversion-s2i',
    'property_prediction-esol',
    'property_prediction-lipo',
    'property_prediction-bbbp',
    'property_prediction-clintox',
    'property_prediction-hiv',
    'property_prediction-sider',
)

# Fallback token budgets for tasks without an explicit override below.
DEFAULT_MAX_INPUT_TOKENS = 512
DEFAULT_MAX_NEW_TOKENS = 1024

# Per-task generation overrides: beam width, number of returned
# sequences, and optional 'max_new_tokens' / 'batch_size'.
# NOTE(review): the key is spelled 'generation_kargs' (not 'kwargs')
# throughout this codebase -- consumers must use the same spelling.
TASKS_GENERATION_SETTINGS = {
    'forward_synthesis': {
        'generation_kargs': {
            'num_return_sequences': 5,
            'num_beams': 8
        },
    },
    'retrosynthesis': {
        'max_new_tokens': 960,
        'generation_kargs': {
            'num_return_sequences': 10,
            'num_beams': 13
        },
    },
    'molecule_captioning': {
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4
        },
    },
    'molecule_generation': {
        'generation_kargs': {
            'num_return_sequences': 5,
            'num_beams': 8
        },
    },
    'name_conversion-i2f': {
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 3,
            'num_beams': 6
        },
    },
    'name_conversion-i2s': {
        'generation_kargs': {
            'num_return_sequences': 5,
            'num_beams': 8
        },
    },
    'name_conversion-s2f': {
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 3,
            'num_beams': 6
        },
    },
    'name_conversion-s2i': {
        'generation_kargs': {
            'num_return_sequences': 5,
            'num_beams': 8
        },
    },
    'property_prediction-esol': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
    'property_prediction-lipo': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
    'property_prediction-bbbp': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
    'property_prediction-clintox': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
    'property_prediction-hiv': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
    'property_prediction-sider': {
        'batch_size': 16,
        'max_new_tokens': 20,
        'generation_kargs': {
            'num_return_sequences': 1,
            'num_beams': 4,
        },
    },
}

# (left, right) tag pair wrapping the answer span in each task's output;
# (None, None) means the whole output is free text (no tagged span).
TASK_TAGS = {
    'forward_synthesis': ('<SMILES>', '</SMILES>'),
    'retrosynthesis': ('<SMILES>', '</SMILES>'),
    'molecule_generation': ('<SMILES>', '</SMILES>'),
    'molecule_captioning': (None, None),
    'name_conversion-i2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
    'name_conversion-i2s': ('<SMILES>', '</SMILES>'),
    'name_conversion-s2f': ('<MOLFORMULA>', '</MOLFORMULA>'),
    'name_conversion-s2i': ('<IUPAC>', '</IUPAC>'),
    'property_prediction-esol': ('<NUMBER>', '</NUMBER>'),
    'property_prediction-lipo': ('<NUMBER>', '</NUMBER>'),
    'property_prediction-bbbp': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-clintox': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-hiv': ('<BOOLEAN>', '</BOOLEAN>'),
    'property_prediction-sider': ('<BOOLEAN>', '</BOOLEAN>'),
}

# These tasks output SMILES, where there may be semicolons
# that separate different parts. To facilitate evaluation,
# each semicolon is replaced by a dot.
TASKS_WITH_SEMICOLON_REPLACE = (
    'forward_synthesis',
    'retrosynthesis',
    'molecule_generation',
    'name_conversion-i2s',
)

# For these tasks, one input might have multiple gold answers,
# so the gold answer should be directly obtained from the dataset
# instead of directly using the gold domain of each sample.
TASKS_WITH_READING_GOLD_FROM_DATASET = ('forward_synthesis', 'retrosynthesis',
                                        'molecule_generation',
                                        'molecule_captioning',
                                        'name_conversion-i2f',
                                        'name_conversion-i2s',
                                        'name_conversion-s2f',
                                        'name_conversion-s2i')

# Mapping from LlaSMol fine-tuned checkpoints to the base models
# they were tuned from.
BASE_MODELS = {
    'osunlp/LlaSMol-Mistral-7B': 'mistralai/Mistral-7B-v0.1',
    'osunlp/LlaSMol-Galactica-6.7B': 'facebook/galactica-6.7b',
    'osunlp/LlaSMol-Llama2-7B': 'meta-llama/Llama-2-7b-hf',
    'osunlp/LlaSMol-CodeLlama-7B': 'codellama/CodeLlama-7b-hf',
}

+ 228
- 0
opencompass/datasets/SciReasoner/LLM4Chem/evaluator.py View File

@@ -0,0 +1,228 @@
# flake8: noqa
# NC-I2S NC-S2I task
# https://github.com/OSU-NLP-Group/LLM4Chem

import json
import os
import re

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path

from .config import TASK_TAGS, TASKS_WITH_SEMICOLON_REPLACE
from .utils.metrics import (calculate_boolean_metrics,
calculate_formula_metrics,
calculate_number_metrics, calculate_smiles_metrics,
calculate_text_metrics)


@LOAD_DATASET.register_module()
class LLM4ChemDataset(BaseDataset):
    """LLM4Chem (SMolInstruct-style) dataset loader.

    Reads ``<path>/<task>/dev|test/data.json``; the first 5 dev samples
    form the few-shot pool. ``max_cut`` truncates the test split,
    ``mini_set`` deterministically subsamples 50 test items.
    """

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        root = get_data_path(path)

        def _read_split(split):
            split_file = os.path.join(root, f'{task}/{split}/data.json')
            with open(split_file, 'r', encoding='utf-8') as fp:
                return json.load(fp)

        # Limit the dataset to 5 samples for testing purposes.
        train_data = _read_split('dev')[:5]
        test_data = _read_split('test')

        if max_cut != -1:
            test_data = test_data[:max_cut]
        if mini_set:
            # Deterministic 50-item subsample, then restore RNG state.
            import random
            random.seed(1024)
            test_data = random.sample(test_data, 50)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


def extract_answer_part(outputs, left_tag, right_tag, mode='tag'):
    """Extract the tagged answer span from each raw model output.

    In 'direct' mode (or when both tags are None) the output is returned
    whole, with '<unk>'/'</s>' markers stripped. Otherwise the text
    between *left_tag* and *right_tag* is returned, or '' when either
    tag is missing.
    """
    assert mode in ('tag', 'direct')
    assert isinstance(outputs, list)

    def _extract(text):
        if mode == 'direct' or (left_tag is None and right_tag is None):
            cleaned = text.replace('<unk>', '').replace('</s>', '').strip()
            return cleaned.strip()
        lpos = text.find(left_tag)
        if lpos == -1:
            return ''
        rpos = text.find(right_tag)
        if rpos == -1:
            return ''
        return text[lpos + len(left_tag):rpos].strip()

    return [_extract(text) for text in outputs]


@TEXT_POSTPROCESSORS.register_module('LLM4Chem_postprocess')
def LLM4Chem_postprocess(text, task, *args, **kwargs):
    """Extract the task-specific answer span from a raw model output.

    First strips <think>...</think> reasoning, then extracts the span
    between the task's tags (see TASK_TAGS). When no tagged span is
    found, falls back to tag-specific heuristics over the raw text.
    """
    # Strip <think>...</think> reasoning spans before extraction.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    replace_semicolon = task in TASKS_WITH_SEMICOLON_REPLACE
    pred = extract_answer_part([text], *(TASK_TAGS[task]), mode='tag')[0]
    # task in TASKS_WITH_SEMICOLON_REPLACE needs semicolon
    # replaced with a period
    if replace_semicolon:
        pred = pred.replace(';', '.')
    # no matched tag
    if pred == '':
        tag = TASK_TAGS[task][0]

        if (tag == '<BOOLEAN>'):
            # Fall back to the LAST yes/true/no/false (case-insensitive).
            ans = re.findall(r'\b(?:yes|true|no|false)\b', text, re.IGNORECASE)
            if ans:
                # Map yes/true -> 'Yes', no/false -> 'No'.
                if ans[-1].lower() in ('yes', 'true'):
                    return 'Yes'
                else:
                    return 'No'
            else:
                return ''

        if (tag == '<NUMBER>'):
            # Fall back to the LAST number in the text, after removing
            # <SMILES>...</SMILES> spans (digits inside SMILES would
            # otherwise be picked up as the answer).
            text_2 = re.sub(r'<SMILES>.*?</SMILES>', '', text, flags=re.DOTALL)
            ans = re.findall(r'-?\d*\.\d+|-?\d+', text_2)
            if ans:
                return ans[-1]
            else:
                return ''

        if (tag == '<MOLFORMULA>'):
            # Fall back to the LAST chemical-formula-like token.
            ans = re.findall(
                r'[\[\(]?[A-Z][a-z]?\d*(?:\([A-Za-z0-9]+\)\d*)?[\]\)]?'
                r'(?:[A-Z][a-z]?\d*|\([^\)]+\)\d*|\[[^\]]+\]\d*)'
                r'*(?:[+-]{1,2})?(?:·\d*[A-Z][a-z]?\d*)*', text)
            if ans:
                return ans[-1]
            else:
                return ''

    # print(f"prediction: {pred}")
    return pred


class LLM4Chem_Evaluator(BaseEvaluator):
    """Dispatches each LLM4Chem task to its metric family and derives
    percentage scores from the raw match counters."""

    # (counter key in the metric result, derived percentage key)
    _PERCENT_KEYS = (
        ('num_t1_exact_match', 'top1_exact_match'),
        ('num_t5_exact_match', 'top5_exact_match'),
        ('num_t1_ele_match', 'top1_ele_match'),
        ('num_correct', 'accuracy'),
        ('num_t1_split_match', 'top1_split_match'),
        ('num_t5_split_match', 'top5_split_match'),
    )

    def __init__(self, task, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task = task

    def score(self, predictions, references):
        """Compute task-appropriate metrics for candidate lists.

        Both inputs are normalized to lists-of-lists (one candidate list
        per sample); property-prediction tasks keep only the first
        candidate.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not isinstance(predictions[0], list):
            predictions = [[pred] for pred in predictions]
        if not isinstance(references[0], list):
            references = [[ref] for ref in references]

        task = self.task
        pred_list = predictions
        gold_list = references

        property_tasks = ('property_prediction-esol',
                          'property_prediction-lipo',
                          'property_prediction-bbbp',
                          'property_prediction-clintox',
                          'property_prediction-hiv',
                          'property_prediction-sider')
        if task in property_tasks:
            # Property prediction only scores the first returned sequence.
            pred_list = [[pred[0]] for pred in pred_list]

        if task in ('forward_synthesis', 'molecule_generation',
                    'name_conversion-i2s'):
            r = calculate_smiles_metrics(pred_list, gold_list)
        elif task == 'retrosynthesis':
            r = calculate_smiles_metrics(pred_list,
                                         gold_list,
                                         metrics=('exact_match',
                                                  'fingerprint',
                                                  'multiple_match'))
        elif task == 'molecule_captioning':
            r = calculate_text_metrics(
                pred_list,
                gold_list,
                text_model='allenai/scibert_scivocab_uncased',
                text_trunc_length=2048,
            )
        elif task in ('name_conversion-i2f', 'name_conversion-s2f'):
            r = calculate_formula_metrics(pred_list,
                                          gold_list,
                                          metrics=('element_match', ))
        elif task == 'name_conversion-s2i':
            r = calculate_formula_metrics(pred_list,
                                          gold_list,
                                          metrics=('split_match', ))
        elif task in ('property_prediction-esol', 'property_prediction-lipo'):
            r = calculate_number_metrics(pred_list, gold_list)
        elif task in ('property_prediction-bbbp',
                      'property_prediction-clintox', 'property_prediction-hiv',
                      'property_prediction-sider'):
            r = calculate_boolean_metrics(pred_list, gold_list)
        else:
            raise ValueError(task)

        # Derive percentages (2 decimals) from whichever counters exist.
        if 'num_all' in r:
            for count_key, pct_key in self._PERCENT_KEYS:
                if count_key in r:
                    r[pct_key] = round(r[count_key] / r['num_all'] * 100, 2)

        return r

+ 449
- 0
opencompass/datasets/SciReasoner/LLM4Chem/retrosynthesis_evaluator.py View File

@@ -0,0 +1,449 @@
# dataset: USPTO-50K
# https://github.com/otori-bird/retrosynthesis
# task : retrosynthesis prediction
import multiprocessing
import re
from functools import partial
from typing import Union

try:
from rdkit import Chem, RDLogger
except Exception:
Chem, RDLogger = None, None

from tqdm import tqdm

from opencompass.openicl import BaseEvaluator
from opencompass.registry import TEXT_POSTPROCESSORS

# 关闭 RDKit 的冗余日志输出
# lg = RDLogger.logger()
# lg.setLevel(RDLogger.CRITICAL)

# ----------------------------------------------------------------------
# 1. 复用原脚本的核心函数
# 我们将这些函数放在文件顶部,以便在 Evaluator 中调用
# ----------------------------------------------------------------------


def smi_tokenizer(smi):
    """Split a SMILES string into space-separated tokens.

    Multi-character tokens (bracket atoms, Br/Cl, %NN ring closures) are
    kept whole. Raises AssertionError when the tokens do not reassemble
    into *smi*, i.e. some character was not covered by the pattern.
    """
    token_re = re.compile(
        r'(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)'
        r'|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])')
    tokens = token_re.findall(smi)
    assert smi == ''.join(tokens), f'SMILES tokenization failed for: {smi}'
    return ' '.join(tokens)


def canonicalize_smiles_clear_map(smiles, synthon=False, return_max_frag=True):
    """
    Canonicalizes a SMILES string, clears atom map numbers, and optionally
    returns the largest fragment.

    Args:
        smiles (str): The SMILES string to process.
        synthon (bool): Whether to skip the sanitization step.
        return_max_frag (bool): If True, returns a tuple of
                                (full_smiles, max_frag_smiles).
                                Otherwise, returns only the full SMILES.

    Returns:
        A tuple (str, str) or a single str depending on return_max_frag.
        Invalid input yields ('', '') / '' respectively.
    """
    # sanitize is disabled in synthon mode (synthons may be non-valid
    # molecules that RDKit sanitization would reject).
    mol = Chem.MolFromSmiles(smiles, sanitize=not synthon)
    if mol is not None:
        # Clear atom map numbers
        for atom in mol.GetAtoms():
            if atom.HasProp('molAtomMapNumber'):
                atom.ClearProp('molAtomMapNumber')
        try:
            smi = Chem.MolToSmiles(mol, isomericSmiles=True)
        except Exception:
            # Handle cases where MolToSmiles fails
            if return_max_frag:
                return '', ''
            else:
                return ''

        if return_max_frag:
            # Fragments are separated by '.' in canonical SMILES.
            sub_smi_list = smi.split('.')
            if len(sub_smi_list) > 1:
                # Find the largest fragment
                sub_mols = [(s, Chem.MolFromSmiles(s, sanitize=not synthon))
                            for s in sub_smi_list]
                # NOTE: the comprehension-local 'smi' does not leak in
                # Python 3; the outer 'smi' is still the full molecule.
                sub_mol_sizes = [(smi, len(m.GetAtoms()))
                                 for smi, m in sub_mols if m is not None]
                if sub_mol_sizes:
                    # Sort fragments by size and return the largest one
                    max_frag_smi = sorted(sub_mol_sizes,
                                          key=lambda x: x[1],
                                          reverse=True)[0][0]
                    # Recursively canonicalize the largest fragment
                    return smi, canonicalize_smiles_clear_map(
                        max_frag_smi, synthon=synthon, return_max_frag=False)
                else:
                    return smi, ''
            else:
                # If no fragments, the molecule is its own largest fragment
                return smi, smi
        else:
            return smi
    else:
        # If the molecule is invalid from the start
        if return_max_frag:
            return '', ''
        else:
            return ''


def compute_rank(prediction_group,
                 beam_size,
                 n_best,
                 score_alpha=1.0,
                 raw=False):
    """Rank predictions for one sample across its augmentations.

    Args:
        prediction_group: 2D list shaped [augmentation, beam_size] of
            (full_smi, max_frag_smi) tuples; invalid entries have an
            empty full SMILES.
        beam_size: number of beams used in generation.
        n_best: number of top predictions to return.
        score_alpha: decay factor for rank-based scores 1/(alpha*k + 1).
        raw: True when augmentation == 1 (no vote aggregation).

    Returns:
        (ranked, invalid_rates): ``ranked`` is the top-``n_best`` list of
        (prediction_tuple, score) sorted by score descending;
        ``invalid_rates`` counts invalid predictions per beam position.
    """
    scores = {}
    best_pos = {}
    invalid_rates = [0] * beam_size

    if raw:
        assert len(prediction_group) == 1, 'Raw mode requires augmentation=1'
        for k, pred in enumerate(prediction_group[0]):
            if not pred or not pred[0]:
                invalid_rates[k] += 1
            else:
                # Score by beam position only; a duplicate later in the
                # beam overwrites the earlier score.
                scores[pred] = 1 / (score_alpha * k + 1)
        ranked = list(scores.items())
    else:
        for beam_preds in prediction_group:
            # Keep the first occurrence of each valid prediction in this
            # augmentation run, counting invalid slots per beam position.
            survivors = []
            seen = set()
            for k, pred in enumerate(beam_preds):
                if not pred or not pred[0]:
                    invalid_rates[k] += 1
                elif pred not in seen:
                    seen.add(pred)
                    survivors.append(pred)

            # Accumulate vote scores over the deduplicated positions and
            # track the highest (smallest) position ever achieved.
            for k, pred in enumerate(survivors):
                scores[pred] = scores.get(pred, 0) + 1 / (score_alpha * k + 1)
                best_pos[pred] = min(k, best_pos.get(pred, float('inf')))

        # The -1e8 term heavily penalizes lower ranks, ensuring the
        # highest position is prioritized over the accumulated score.
        ranked = [(pred, s + best_pos[pred] * -1e8)
                  for pred, s in scores.items()]

    ranked.sort(key=lambda x: x[1], reverse=True)
    return ranked[:n_best], invalid_rates


# ----------------------------------------------------------------------
# 定义 Postprocessor (后处理器)
# ----------------------------------------------------------------------


@TEXT_POSTPROCESSORS.register_module()
def Retrosynthesis_postprocess(text: Union[str, None]) -> str:
    """Extract the LAST <SMILES>...</SMILES> span from a raw model output.

    <think>...</think> reasoning spans are removed first. Returns '' when
    the input is not a string or no tagged SMILES span is present.
    """
    # Guard against non-string inputs (e.g. None from a failed request).
    if not isinstance(text, str):
        return ''

    # Drop chain-of-thought spans so SMILES mentioned while reasoning
    # are not mistaken for the final answer.
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)

    # Collect every tagged SMILES span; the final answer is the last one.
    spans = re.findall(r'<SMILES>(.*?)</SMILES>', cleaned, re.DOTALL)
    if not spans:
        return ''
    return spans[-1].strip()


# ----------------------------------------------------------------------
# 定义 Evaluator (评估器) - 这是修改的核心
# ----------------------------------------------------------------------


class RetrosynthesisEvaluator(BaseEvaluator):
    """
    Evaluator for retrosynthesis models. It calculates Top-K accuracy and
    Max-Fragment accuracy based on SMILES string comparisons.
    """

    def __init__(self,
                 beam_size=10,
                 n_best=10,
                 augmentation=1,
                 score_alpha=1.0,
                 synthon=False,
                 process_number=None):
        """See ``score`` for the expected prediction layout.

        Args:
            beam_size: beams per (sample, augmentation) pair.
            n_best: number of top ranked predictions scored.
            augmentation: number of test-time augmentations per sample.
            score_alpha: rank-score decay passed to ``compute_rank``.
            synthon: forwarded to SMILES canonicalization (skips
                sanitization when True).
            process_number: worker count for the canonicalization pool;
                defaults to all available CPUs.
        """
        super().__init__()
        self.beam_size = beam_size
        self.n_best = n_best
        self.augmentation = augmentation
        self.score_alpha = score_alpha
        self.synthon = synthon
        self.process_number = process_number if process_number is not None \
            else multiprocessing.cpu_count()
        print(f'Evaluator initialized with: beam_size={beam_size},'
              f' n_best={n_best}, augmentation={augmentation},'
              f' processes={self.process_number}')

    def score(self, predictions, references):
        """
        Calculates retrosynthesis prediction accuracy.

        Args:
            predictions (list): A flat list of predicted SMILES strings.
                                Shape: [data_size * augmentation * beam_size].
            references (list): A list of ground truth SMILES strings.
                               Shape: [data_size].

        Returns:
            dict: A dictionary containing evaluation metrics, or an
                  {'error': ...} dict when lengths do not line up.
        """
        # flat predictions -> 1D
        print(f'len of predictions: {len(predictions)}')
        print(f'predictions[0]: {predictions[0]}')
        if isinstance(predictions, list):
            # Ensure predictions are a flat list
            if isinstance(predictions[0], list):
                predictions = [x for y in predictions for x in y]
            else:
                pass

        # print(f"predictions = {predictions} \nreferences = {references}")
        data_size = len(references)
        expected_preds_len = data_size * self.augmentation * self.beam_size
        if len(predictions) != expected_preds_len:
            return {
                'error':
                f'Length of predictions ({len(predictions)})'
                f' does not match expected length ({expected_preds_len})'
            }

        print('Canonicalizing predictions and references...')
        # Create a partial function for multiprocessing
        map_func = partial(canonicalize_smiles_clear_map,
                           synthon=self.synthon,
                           return_max_frag=True)

        # Each item becomes a (full_smiles, max_fragment_smiles) tuple;
        # invalid SMILES become ('', '').
        with multiprocessing.Pool(self.process_number) as pool:
            can_predictions = list(
                tqdm(pool.imap(map_func, predictions),
                     total=len(predictions),
                     desc='Canonicalizing Predictions'))
            can_references = list(
                tqdm(pool.imap(map_func, references),
                     total=len(references),
                     desc='Canonicalizing References'))

        # Reshape the flat predictions list into a 3D list:
        # data_size x augmentation x beam_size
        predictions_reshaped = [[] for _ in range(data_size)]
        for i in range(data_size):
            for j in range(self.augmentation):
                start_idx = (i * self.augmentation + j) * self.beam_size
                end_idx = start_idx + self.beam_size
                predictions_reshaped[i].append(
                    can_predictions[start_idx:end_idx])

        # Initialize metric counters; accuracy[k] counts samples with a
        # hit anywhere in the top-(k+1) (cumulative, see loop below).
        accuracy = [0] * self.n_best
        max_frag_accuracy = [0] * self.n_best
        total_invalid_rates = [0] * self.beam_size

        print('Computing ranks and accuracy...')
        is_raw_mode = (self.augmentation == 1)

        for i in tqdm(range(data_size), desc='Evaluating Samples'):
            prediction_group = predictions_reshaped[i]
            target_smi_tuple = can_references[i]

            # Skip evaluation for this sample if the ground truth is invalid
            if not target_smi_tuple or not target_smi_tuple[0]:
                continue

            ranked_results, invalid_rate = compute_rank(
                prediction_group,
                beam_size=self.beam_size,
                n_best=self.n_best,
                score_alpha=self.score_alpha,
                raw=is_raw_mode)

            # Aggregate invalid rates
            for j in range(len(invalid_rate)):
                total_invalid_rates[j] += invalid_rate[j]

            # Check for full molecule match; a match at rank j counts
            # toward every Top-K with K > j (cumulative Top-K).
            found_match = False
            for j, (pred_tuple, _) in enumerate(ranked_results):
                if not found_match and pred_tuple[0] == target_smi_tuple[0]:
                    for k in range(j, self.n_best):
                        accuracy[k] += 1
                    found_match = True  # Ensure we only count the first match

            # Check for max fragment match
            found_frag_match = False
            for j, (pred_tuple, _) in enumerate(ranked_results):
                # Ensure max fragment is not empty before comparing
                if not found_frag_match and pred_tuple[1] and pred_tuple[
                        1] == target_smi_tuple[1]:
                    for k in range(j, self.n_best):
                        max_frag_accuracy[k] += 1
                    found_frag_match = True

        # Calculate final results
        results = {}
        # Usually, Top-1, 3, 5, 10 are reported
        for i in [k - 1 for k in [1, 3, 5, 10] if k <= self.n_best]:
            k = i + 1
            results[f'Top-{k} Accuracy'] = accuracy[i] / data_size * 100
            results[f'Top-{k} MaxFrag Accuracy'] = max_frag_accuracy[
                i] / data_size * 100

        # Report the invalid rate at the first beam position
        if self.beam_size > 0:
            total_predictions_at_beam1 = data_size * self.augmentation
            results['Invalid SMILES Rate (at beam 1)'] = (
                total_invalid_rates[0] / total_predictions_at_beam1 * 100) \
                if total_predictions_at_beam1 > 0 else 0

        return results


# Example Usage
# Self-test / demo driver: builds synthetic predictions and runs the
# evaluator in both augmented and raw modes. Requires RDKit at runtime.
if __name__ == '__main__':
    # --- Mock Data Generation ---
    # This simulates the kind of data the evaluator would receive.

    # Configuration
    BEAM_SIZE = 5
    N_BEST = 5
    AUGMENTATION = 3  # Use > 1 to test augmentation logic
    DATA_SIZE = 100

    # Ground truth molecules (references)
    mock_references = [
        'CCO.CN',  # Correct: CCO is largest fragment
        'c1ccccc1CC(=O)O',  # Correct
        'INVALID_SMILES',  # An invalid reference SMILES
        'CC(C)C(=O)N[C@@H](C)C(=O)O'  # Chiral molecule
    ] * (DATA_SIZE // 4)

    # Simulated model predictions (a flat list)
    mock_predictions = []
    for i in range(DATA_SIZE):
        target = mock_references[i]
        for _ in range(AUGMENTATION):
            # For each augmentation, create a beam of predictions
            beam = []
            # Make the first beam prediction correct for 20% of cases
            if i % 5 == 0:
                beam.append(target)
            else:
                beam.append('CC(C)=O')  # A common incorrect prediction

            # Add some other variations and invalid SMILES
            beam.append('c1cnccc1')
            beam.append('completely_invalid')  # Invalid
            # Add a prediction that only matches the largest fragment
            beam.append('CCO')
            # Fill the rest of the beam
            beam.extend(['C'] * (BEAM_SIZE - len(beam)))

            mock_predictions.extend(beam)

    print(f'Generated {len(mock_predictions)} '
          f'predictions for {len(mock_references)} references.')

    # --- Evaluation ---
    evaluator = RetrosynthesisEvaluator(
        beam_size=BEAM_SIZE,
        n_best=N_BEST,
        augmentation=AUGMENTATION,
        process_number=4  # Use 4 cores for the example
    )

    results = evaluator.score(mock_predictions, mock_references)

    # --- Print Results ---
    print('\n--- Evaluation Results ---')
    for key, value in results.items():
        print(f'{key}: {value:.2f}%')
    print('--------------------------\n')

    # --- Test RAW mode (no augmentation) ---
    print('Testing RAW mode (augmentation=1)...')
    evaluator_raw = RetrosynthesisEvaluator(
        beam_size=BEAM_SIZE,
        n_best=N_BEST,
        augmentation=1,  # RAW mode
        process_number=4)
    # Select only the first "augmentation" set of predictions
    mock_predictions_raw = []
    for i in range(DATA_SIZE):
        start_idx = i * AUGMENTATION * BEAM_SIZE
        end_idx = start_idx + BEAM_SIZE
        mock_predictions_raw.extend(mock_predictions[start_idx:end_idx])

    results_raw = evaluator_raw.score(mock_predictions_raw, mock_references)

    print('\n--- RAW Mode Evaluation Results ---')
    for key, value in results_raw.items():
        print(f'{key}: {value:.2f}%')
    print('---------------------------------\n')

+ 1
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/__input__.py View File

@@ -0,0 +1 @@
import smiles_canonicalization # noqa: F401, F403

+ 12
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/chat_generation.py View File

@@ -0,0 +1,12 @@
def generate_chat(input_text, output_text=None, prefix_chat=None):
    """Build an OpenAI-style chat message list from a user turn.

    Optionally appends an assistant turn (*output_text*) and prepends an
    existing conversation (*prefix_chat*). *prefix_chat* is not mutated;
    a new list is returned.
    """
    turns = [{'role': 'user', 'content': input_text}]
    if output_text is not None:
        turns.append({'role': 'assistant', 'content': output_text})
    return turns if prefix_chat is None else prefix_chat + turns

+ 195
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/core_tagger.py View File

@@ -0,0 +1,195 @@
def find_sub_sequence(whole, sub):
    """Return the index of the first occurrence of *sub* inside *whole*.

    Both arguments must be non-empty lists. Returns -1 when *sub* never
    occurs as a contiguous subsequence of *whole*.
    """
    assert isinstance(whole, list)
    assert isinstance(sub, list)
    assert len(whole) > 0
    assert len(sub) > 0

    window = len(sub)
    # Slide a window of len(sub) over whole and compare slices.
    for start in range(len(whole) - window + 1):
        if whole[start:start + window] == sub:
            return start
    return -1


class CoreTagger(object):
    """Builds a 0/1 mask over token ids marking the tagged answer span.

    The span is located between ``sample['output_core_tag_left']`` and
    ``sample['output_core_tag_right']`` inside the output region that
    starts at ``output_begin``. A ``None`` tag means "unbounded" on that
    side. Tokenized tags are cached per tag string.

    Fix: the original code computed ``find_sub_sequence(...) +
    output_begin`` BEFORE testing the -1 "not found" sentinel, so a
    missing tag was only detected when ``output_begin == 0``; for any
    other offset the mask covered a garbage region. The sentinel is now
    checked on the raw search result.
    """

    def __init__(self,
                 tokenizer,
                 core_tags_as_special_tokens=False,
                 include_tags=True):
        self.tokenizer = tokenizer
        # Only the default modes are implemented.
        if core_tags_as_special_tokens:
            raise NotImplementedError
        self.core_tags_as_special_tokens = core_tags_as_special_tokens
        if not include_tags:
            raise NotImplementedError
        self.include_tags = include_tags

        # Caches: tag string -> token-id list (None maps to None).
        self.left_tag_to_id = {}
        self.right_tag_to_id = {}

    def _tag_token_ids(self, tag, cache):
        """Tokenize *tag* once and cache the result (None -> None)."""
        if tag not in cache:
            if tag is None:
                cache[tag] = None
            else:
                cache[tag] = self.tokenizer(
                    tag,
                    add_special_tokens=False,
                    return_attention_mask=False)['input_ids']
        return cache[tag]

    def generate_mask(self, token_ids, output_begin, sample):
        """Return a mask (list of 0/1, same length as *token_ids*).

        Ones cover the span from the left tag through the right tag
        (inclusive of the tags); all zeros when the left tag is absent.
        """
        mask = [0] * len(token_ids)
        left_tag, right_tag = sample['output_core_tag_left'], sample[
            'output_core_tag_right']
        left_token_ids = self._tag_token_ids(left_tag, self.left_tag_to_id)
        right_token_ids = self._tag_token_ids(right_tag, self.right_tag_to_id)

        output_token_ids = token_ids[output_begin:]
        if left_token_ids is None:
            # No left tag: the span starts at the output region.
            left_position = output_begin
        elif len(output_token_ids) == 0:
            left_position = None
        else:
            # Check the -1 sentinel BEFORE adding the offset (bug fix).
            pos = find_sub_sequence(output_token_ids, left_token_ids)
            left_position = None if pos == -1 else pos + output_begin

        if left_position is None:
            return mask

        if right_token_ids is None:
            # No right tag: span runs to the end (excluding a final EOS).
            right_position = len(token_ids)
            if token_ids[-1] == self.tokenizer.eos_token_id:
                right_position -= 1
        else:
            pos = find_sub_sequence(output_token_ids, right_token_ids)
            if pos == -1:
                right_position = len(token_ids)
                if token_ids[-1] == self.tokenizer.eos_token_id:
                    right_position -= 1
            else:
                # Include the right tag itself, clamped to the sequence.
                right_position = min(pos + output_begin + len(right_token_ids),
                                     len(token_ids))

        for idx in range(left_position, right_position):
            mask[idx] = 1

        return mask


class CoreTaggerGeneral(object):
    """Streaming variant of CoreTagger that masks tagged answer spans.

    Instead of locating one span after a fixed output offset, it scans
    every position: the mask is switched on when the left tag starts at
    the current token (or always, when the left tag is None) and
    switched off after the right tag ends. Prompt tokens (where
    ``prompt_mask`` is 1) and BOS tokens always reset the mask.
    """

    def __init__(self,
                 tokenizer,
                 core_tags_as_special_tokens=False,
                 include_tags=True):
        self.tokenizer = tokenizer
        # Only the default modes are implemented.
        if core_tags_as_special_tokens:
            raise NotImplementedError
        self.core_tags_as_special_tokens = core_tags_as_special_tokens
        if not include_tags:
            raise NotImplementedError
        self.include_tags = include_tags

        # Caches: tag string -> token-id list (None maps to None).
        self.left_tag_to_id = {}
        self.right_tag_to_id = {}

    def generate_mask(self, token_ids, prompt_mask, sample):
        """Return a 0/1 mask over *token_ids* covering tagged spans.

        Args:
            token_ids: full token-id sequence (prompt + output).
            prompt_mask: per-token 0/1 list; 1 marks prompt tokens that
                must never be masked as answer content.
            sample: dict providing 'output_core_tag_left' /
                'output_core_tag_right' (either may be None).
        """
        mask = [0] * len(token_ids)
        left_tag, right_tag = sample['output_core_tag_left'], sample[
            'output_core_tag_right']
        # Tokenize and cache the left tag (None -> None).
        if left_tag not in self.left_tag_to_id:
            if left_tag is None:
                left_token_ids = None
            else:
                left_token_ids = self.tokenizer(
                    left_tag,
                    add_special_tokens=False,
                    return_attention_mask=False)['input_ids']
            self.left_tag_to_id[left_tag] = left_token_ids
        else:
            left_token_ids = self.left_tag_to_id[left_tag]
        # Tokenize and cache the right tag (None -> None).
        if right_tag not in self.right_tag_to_id:
            if right_tag is None:
                right_token_ids = None
            else:
                right_token_ids = self.tokenizer(
                    right_tag,
                    add_special_tokens=False,
                    return_attention_mask=False)['input_ids']
            self.right_tag_to_id[right_tag] = right_token_ids
        else:
            right_token_ids = self.right_tag_to_id[right_tag]

        # cur_ is the running in-span flag carried across positions.
        cur_ = 0
        for idx in range(len(token_ids)):
            if prompt_mask[idx] == 1 or token_ids[
                    idx] == self.tokenizer.bos_token_id:
                cur_ = 0
                continue

            # Does the left tag start at idx? (always True when None.)
            if left_token_ids is None:
                match_left = True
            else:
                match_left = True
                try:
                    for offset in range(len(left_token_ids)):
                        if token_ids[idx + offset] != left_token_ids[offset]:
                            match_left = False
                            break
                except IndexError:
                    match_left = False

            if match_left:
                cur_ = 1

            mask[idx] = cur_

            if right_token_ids is None:
                continue

            # Does the right tag END at idx? If so, close the span so the
            # tag itself stays masked but following tokens do not.
            # NOTE(review): for idx < len(right_token_ids) - 1 the index
            # 'idx - len(right_token_ids) + offset' is negative and wraps
            # to the sequence tail -- presumably harmless in practice,
            # but worth confirming for very short sequences.
            match_right = True
            try:
                for offset in range(len(right_token_ids)):
                    if token_ids[idx - len(right_token_ids) +
                                 offset] != right_token_ids[offset]:
                        match_right = False
                        break
            except IndexError:
                match_right = False

            if match_right:
                cur_ = 0

        return mask

+ 35
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/general_prompter.py View File

@@ -0,0 +1,35 @@
def get_chat_content(conversation, tokenize=False):
    """Render a user/assistant conversation as a Mistral-style prompt.

    Even-indexed turns must come from the user and are emitted as
    ``<s>[INST] ... [/INST]``; odd-indexed turns must come from the
    assistant and are appended as `` ...</s>``. Tokenized output is not
    supported.
    """
    if tokenize:
        raise NotImplementedError
    valid_roles = ('user', 'assistant')
    pieces = []
    for turn_idx, turn in enumerate(conversation):
        turn_role = turn['role']
        assert turn_role in valid_roles, turn_role
        if turn_idx % 2 == 0:
            assert turn_role == 'user'
            pieces.append('<s>')
            pieces.append('[INST] %s [/INST]' % turn['content'])
        else:
            assert turn_role == 'assistant'
            pieces.append(' %s</s>' % turn['content'])
    return ''.join(pieces)


class GeneralPrompter(object):
    """Thin wrapper around a chat-template function.

    Delegates prompt construction to ``apply_chat_template_func`` and
    extracts the model response as the text after the last
    ``response_split`` marker.
    """

    def __init__(self, apply_chat_template_func, response_split='[/INST]'):
        self.apply_chat_template_func = apply_chat_template_func
        self.response_split = response_split

    def generate_prompt(self, chat, tokenize=False, *args, **kargs) -> str:
        """Build the prompt via the wrapped template function."""
        return self.apply_chat_template_func(chat,
                                             tokenize=tokenize,
                                             *args,
                                             **kargs)

    def get_response(self, output: str) -> str:
        """Return the text after the final response separator, stripped."""
        _, _, tail = output.rpartition(self.response_split)
        return tail.strip()

+ 685
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/metrics.py View File

@@ -0,0 +1,685 @@
# flake8: noqa

import re
from collections import defaultdict

import numpy as np
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score

try:
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, MACCSkeys
except Exception:
Chem, DataStructs, RDLogger, AllChem, MACCSkeys = None, None, None, None, None

from rouge_score import rouge_scorer
from sklearn.metrics import (f1_score, matthews_corrcoef, precision_score,
recall_score, roc_auc_score)
from tqdm.auto import tqdm
from transformers import BertTokenizerFast

from .smiles_canonicalization import (canonicalize_molecule_smiles,
get_molecule_id)

# RDLogger.DisableLog('rdApp.*')


def convert_smiles_list_into_mol_list(smiles_list,
                                      raise_error_when_error=False):
    """Convert SMILES strings into RDKit mol objects.

    Empty strings map to the placeholder 'NA' and unparsable strings to
    'INVALID'; with ``raise_error_when_error`` a ValueError is raised
    instead.

    Returns:
        tuple: ``(mol_list, no_answer_labels, invalid_labels)`` where the
        two numpy arrays record the empty / invalid entries encountered.
    """
    mol_list = []
    no_answer_labels = []
    invalid_labels = []
    for smiles in smiles_list:
        if smiles == '':
            mol = 'NA'
            no_answer_labels.append(True)
            if raise_error_when_error:
                raise ValueError('SMILES is empty.')
        else:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                mol = 'INVALID'
                invalid_labels.append(True)
                if raise_error_when_error:
                    raise ValueError('SMILES is not valid: %s' % smiles)
        mol_list.append(mol)

    no_answer_labels = np.array(no_answer_labels)
    # Bug fix: np.arange(list) is invalid (arange expects scalar bounds);
    # mirror the conversion used for no_answer_labels.
    invalid_labels = np.array(invalid_labels)

    return mol_list, no_answer_labels, invalid_labels


def judge_exact_match(pred_can_smiles_list, gold_can_smiles_list):
    """Per-sample exact-match labels based on canonical molecule ids.

    A prediction matches when its InChI-based id equals the id of any
    gold SMILES for that sample; None/blank predictions never match.
    """
    assert len(pred_can_smiles_list) == len(gold_can_smiles_list)
    labels = []
    for pred, golds in zip(pred_can_smiles_list, gold_can_smiles_list):
        if pred is None or not pred.strip():
            labels.append(False)
            continue
        pred_id = get_molecule_id(pred)
        matched = False
        for gold in golds:
            assert gold is not None
            if get_molecule_id(gold) == pred_id:
                matched = True
                break
        labels.append(matched)
    return np.array(labels)


def calculate_fingerprint_similarity(pred_mol_list,
                                     gold_mols_list,
                                     morgan_r=2):
    """Mean best-of-gold Tanimoto similarity over three fingerprints.

    For each prediction the maximum similarity against its gold mols is
    taken, for MACCS keys, RDKit fingerprints and Morgan fingerprints
    (radius ``morgan_r``); the three per-sample maxima are then averaged.
    Placeholder (str) or None predictions are rejected with ValueError.
    """
    assert len(pred_mol_list) == len(gold_mols_list)
    maccs_best, rdk_best, morgan_best = [], [], []
    for pred_mol, gold_mols in zip(pred_mol_list, gold_mols_list):
        if pred_mol is None or type(pred_mol) == str:
            raise ValueError(type(pred_mol))
        # Pred fingerprints are loop-invariant; compute them once.
        pred_maccs = MACCSkeys.GenMACCSKeys(pred_mol)
        pred_rdk = Chem.RDKFingerprint(pred_mol)
        pred_morgan = AllChem.GetMorganFingerprint(pred_mol, morgan_r)
        best_maccs = best_rdk = best_morgan = 0
        for gold_mol in gold_mols:
            best_maccs = max(
                best_maccs,
                DataStructs.FingerprintSimilarity(
                    MACCSkeys.GenMACCSKeys(gold_mol),
                    pred_maccs,
                    metric=DataStructs.TanimotoSimilarity))
            best_rdk = max(
                best_rdk,
                DataStructs.FingerprintSimilarity(
                    Chem.RDKFingerprint(gold_mol),
                    pred_rdk,
                    metric=DataStructs.TanimotoSimilarity))
            best_morgan = max(
                best_morgan,
                DataStructs.TanimotoSimilarity(
                    AllChem.GetMorganFingerprint(gold_mol, morgan_r),
                    pred_morgan))
        maccs_best.append(best_maccs)
        rdk_best.append(best_rdk)
        morgan_best.append(best_morgan)
    return np.mean(maccs_best), np.mean(rdk_best), np.mean(morgan_best)


def judge_multiple_match(pred_can_smiles_list, golds_can_smiles_list):
    """Set-level matching between predicted and gold multi-part SMILES.

    Each SMILES is split on '.' into component molecule ids. For every
    sample two labels are produced: whether the prediction shares at
    least one component with some gold answer (intersection), and
    whether every predicted component appears in some single gold answer
    (subset).
    """
    assert len(pred_can_smiles_list) == len(golds_can_smiles_list)
    subset_labels = []
    intersection_labels = []
    for pred, golds in zip(pred_can_smiles_list, golds_can_smiles_list):
        if pred is None:
            subset_labels.append(False)
            intersection_labels.append(False)
            continue

        pred_ids = {
            get_molecule_id(part, remove_duplicate=False)
            for part in pred.split('.')
        }

        has_intersection = False
        is_subset = False
        for gold in golds:
            assert gold is not None
            gold_ids = {
                get_molecule_id(part, remove_duplicate=False)
                for part in gold.split('.')
            }
            if pred_ids & gold_ids:
                has_intersection = True
            # Subset holds when no predicted component is missing from
            # this gold answer.
            if not (pred_ids - gold_ids):
                is_subset = True
                break
        intersection_labels.append(has_intersection)
        subset_labels.append(is_subset)

    return intersection_labels, subset_labels


def calculate_smiles_metrics(preds_smiles_list,
                             golds_smiles_list,
                             metrics=('exact_match', 'fingerprint')):
    """Evaluate top-k SMILES predictions against gold answers.

    Args:
        preds_smiles_list: per-sample list of k candidate SMILES, or
            None when the model produced no answer for the sample.
        golds_smiles_list: per-sample list of acceptable gold SMILES.
        metrics: any of 'exact_match', 'fingerprint', 'multiple_match'.

    Returns:
        dict of counters/scores; ``num_t<d>_*`` entries are cumulative
        over the top-d candidates.
    """
    num_all = len(preds_smiles_list)
    assert num_all > 0
    assert num_all == len(golds_smiles_list)
    k = len(preds_smiles_list[0])

    # Canonicalize predictions per candidate rank dk, tracking which
    # entries are missing ('no answer') or unparsable ('invalid').
    dk_pred_smiles_list_dict = {}
    dk_pred_no_answer_labels_dict = {}
    dk_pred_invalid_labels_dict = {}
    for dk in range(k):
        dk_pred_smiles_list_dict[dk] = []
        dk_pred_no_answer_labels_dict[dk] = []
        dk_pred_invalid_labels_dict[dk] = []
    for pred_smiles_list in tqdm(preds_smiles_list):
        if pred_smiles_list is None:
            for dk in range(k):
                dk_pred_no_answer_labels_dict[dk].append(True)
                dk_pred_invalid_labels_dict[dk].append(False)
                dk_pred_smiles_list_dict[dk].append(None)
            continue
        assert len(pred_smiles_list) == k
        for dk, item in enumerate(pred_smiles_list):
            if item == '' or item is None:
                item = None
                dk_pred_no_answer_labels_dict[dk].append(True)
                dk_pred_invalid_labels_dict[dk].append(False)
            else:
                dk_pred_no_answer_labels_dict[dk].append(False)
                item = canonicalize_molecule_smiles(item)
                dk_pred_invalid_labels_dict[dk].append(item is None)
            dk_pred_smiles_list_dict[dk].append(item)

    # Canonicalize the gold answers as well (errors propagate).
    new_list = []
    for gold_smiles_list in tqdm(golds_smiles_list):
        sample_gold_smiles_list = []
        for gold in gold_smiles_list:
            new_item = canonicalize_molecule_smiles(
                gold.strip(), return_none_for_error=False)
            sample_gold_smiles_list.append(new_item)
        new_list.append(sample_gold_smiles_list)
    golds_smiles_list = new_list

    metric_results = {'num_all': num_all}

    # A sample counts as top-d no-answer/invalid only when ALL of its
    # first d candidates are.
    tk_pred_no_answer_labels = np.array([True] * num_all)
    tk_pred_invalid_labels = np.array([True] * num_all)
    for dk in range(k):
        tk_pred_no_answer_labels = tk_pred_no_answer_labels & \
            np.array(dk_pred_no_answer_labels_dict[dk])
        tk_pred_invalid_labels = tk_pred_invalid_labels & \
            np.array(dk_pred_invalid_labels_dict[dk])
        metric_results['num_t%d_no_answer' %
                       (dk + 1)] = tk_pred_no_answer_labels.sum().item()
        metric_results['num_t%d_invalid' %
                       (dk + 1)] = tk_pred_invalid_labels.sum().item()

    for metric in metrics:
        if metric == 'exact_match':
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_exact_match_labels = judge_exact_match(
                    dk_pred_smiles_list_dict[dk], golds_smiles_list)
                tk_exact_match_labels = tk_exact_match_labels | \
                    dk_exact_match_labels
                metric_results['num_t%d_exact_match' %
                               (dk + 1)] = tk_exact_match_labels.sum().item()
        elif metric == 'fingerprint':
            # Fingerprint similarity only looks at the top-1 candidate
            # and skips samples with missing/invalid predictions.
            d1_pred_mol_list = []
            gold_mols_list = []
            for pred_smiles, gold_smiles_list, no_answer, invalid in zip(
                    dk_pred_smiles_list_dict[0], golds_smiles_list,
                    dk_pred_no_answer_labels_dict[0],
                    dk_pred_invalid_labels_dict[0]):
                if pred_smiles is None or pred_smiles.strip(
                ) == '' or no_answer is True or invalid is True:
                    continue
                pred_mol = Chem.MolFromSmiles(pred_smiles)
                if pred_mol is None:
                    continue
                gold_mol_list = []
                for gold_smiles in gold_smiles_list:
                    gold_mol = Chem.MolFromSmiles(gold_smiles)
                    assert gold_mol is not None, gold_smiles
                    gold_mol_list.append(gold_mol)
                d1_pred_mol_list.append(pred_mol)
                gold_mols_list.append(gold_mol_list)
            maccs_sims_score, rdk_sims_score, morgan_sims_score = \
                calculate_fingerprint_similarity(
                    d1_pred_mol_list, gold_mols_list)
            metric_results['t1_maccs_fps'] = maccs_sims_score
            metric_results['t1_rdk_fps'] = rdk_sims_score
            metric_results['t1_morgan_fps'] = morgan_sims_score
        elif metric == 'multiple_match':
            tk_intersection_labels = np.array([False] * num_all)
            tk_subset_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_intersection_labels, dk_subset_labels = \
                    judge_multiple_match(
                        dk_pred_smiles_list_dict[dk], golds_smiles_list)
                tk_intersection_labels = tk_intersection_labels | \
                    np.array(dk_intersection_labels)
                tk_subset_labels = tk_subset_labels | \
                    np.array(dk_subset_labels)
                # Bug fix: the subset counter previously reported the
                # intersection labels.
                metric_results['num_t%d_subset' %
                               (dk + 1)] = tk_subset_labels.sum().item()
                metric_results['num_t%d_intersection' %
                               (dk + 1)] = tk_intersection_labels.sum().item()
        else:
            raise ValueError(metric)

    return metric_results


def judge_string_exact_match(pred_string_list, golds_string_list):
    """Label each prediction True when it equals any of its gold strings."""
    labels = [
        any(pred == gold for gold in golds)
        for pred, golds in zip(pred_string_list, golds_string_list)
    ]
    return np.array(labels)


def judge_string_split_match(pred_string_list,
                             golds_string_list,
                             separator=';'):
    """Order-insensitive match: split on separator, compare sorted parts."""
    labels = []
    for pred, golds in zip(pred_string_list, golds_string_list):
        pred_key = tuple(sorted(pred.split(separator)))
        labels.append(
            any(pred_key == tuple(sorted(gold.split(separator)))
                for gold in golds))
    return np.array(labels)


def parse_molecule(molecular_formula):
    """Parse a molecular formula (e.g. 'C5H8', 'SO4-2') into a dict of
    element -> count, plus a '+'/'-' -> magnitude entry when a trailing
    charge is present.

    NOTE(review): the inner parser handles brackets, but the validation
    regex below rejects any bracketed input such as 'Ca(OH)2' — confirm
    whether brackets should be accepted.

    Raises:
        ValueError: if the formula does not match the expected syntax.
    """
    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
    if valid is None:
        raise ValueError("Molecular formula \"%s\" is not valid." %
                         molecular_formula)

    # Stack of per-nesting-level counters; brackets push/pop levels.
    stack = [defaultdict(int)]

    def _parse_formula(formula, _stack):

        # Set remainder equal to 'None'
        r = None

        # Regular expression matching for each of the three cases:
        atom = re.match(r'([A-Z][a-z]?)(\d+)?', formula)
        opening = re.match(r'[\(\[\{]', formula)
        closing = re.match(r'[\)\]\}](\d+)?', formula)

        # If atom is identified:
        if atom:
            r = formula[len(atom.group()):]
            _stack[-1][atom.group(1)] += int(atom.group(2) or 1)

        # If opening brackets encountered:
        elif opening:
            # this sets the remainder equal
            # to everything after the opening brackets
            r = formula[len(opening.group()):]
            _stack.append(defaultdict(int))

        # If closing brackets encountered:
        elif closing:
            r = formula[len(closing.group()):]
            # this sets the remainder equal to
            # everything after the closing brackets
            for (k, v) in _stack.pop().items():
                _stack[-1][k] += v * int(
                    closing.group(1)
                    # v times amount of molecule k,
                    # depending on nesting
                    or 1)

        # If anything remains,
        # process remainders recursively as nested formulas:
        if r:
            _parse_formula(r, _stack)

        return dict(_stack[0])

    result = _parse_formula(molecular_formula, stack)

    # A trailing charge like '+2' or '-' is stored under its sign, with
    # the magnitude defaulting to 1.
    charge = re.search(r'[\+\-]\d*', molecular_formula)
    if charge is not None:
        charge_str = charge.group()
        charge_type = charge_str[0]
        if len(charge_str) == 1:
            charge_num = 1
        else:
            charge_num = int(charge_str[1:])
        result[charge_type] = charge_num

    return result


def count_element_match(pred_formula_list, golds_formula_list):
    """Compare predicted and gold formulas by parsed element counts.

    Returns two parallel lists: per-sample match labels (True when the
    parsed prediction equals any parsed gold) and invalid labels (True
    when the prediction could not be parsed). Empty/None predictions are
    labelled neither matched nor invalid.
    """
    assert len(pred_formula_list) == len(golds_formula_list)
    match_labels = []
    invalid_labels = []
    for pred, golds in zip(pred_formula_list, golds_formula_list):
        if not pred:
            invalid_labels.append(False)
            match_labels.append(False)
            continue
        try:
            pred_counts = parse_molecule(pred)
        except KeyboardInterrupt:
            raise
        except Exception:
            invalid_labels.append(True)
            match_labels.append(False)
            continue
        invalid_labels.append(False)
        match_labels.append(
            any(parse_molecule(gold) == pred_counts for gold in golds))
    return match_labels, invalid_labels


def calculate_formula_metrics(preds_formula_list,
                              golds_formula_list,
                              metrics=('element_match', )):
    """Calculate metrics for molecular formulas.

    ``element_match`` (the exact_match used in the paper) compares
    element counts and ignores ordering, so C5H8 == H8C5. Predictions
    are per-sample top-k lists; every ``num_t<d>_*`` counter is
    cumulative over the top-d candidates. A None sample counts as k
    empty answers.
    """
    num_all = len(preds_formula_list)
    assert len(preds_formula_list) == len(golds_formula_list)
    try:
        k = len(preds_formula_list[0])
    except IndexError:
        print(preds_formula_list)
        raise
    # Re-group predictions by candidate rank dk.
    dk_pred_formula_list_dict = dict()
    for dk in range(k):
        dk_pred_formula_list_dict[dk] = []
    for sample_formula_list in preds_formula_list:
        if sample_formula_list is None:
            for dk in range(k):
                dk_pred_formula_list_dict[dk].append('')
            continue
        assert len(sample_formula_list) == k
        for dk in range(k):
            item = sample_formula_list[dk]
            dk_pred_formula_list_dict[dk].append(item)
    # Strip gold formulas and reject empty ones. (A previous redundant
    # strip-only pass over the same data has been removed.)
    new_golds_formula_list = []
    for item in golds_formula_list:
        new_item = []
        for small_item in item:
            small_item = small_item.strip()
            assert small_item != ''
            new_item.append(small_item)
        new_golds_formula_list.append(new_item)
    golds_formula_list = new_golds_formula_list

    metric_results = {'num_all': num_all}

    # A sample is top-d no-answer only when ALL of its first d candidates
    # are empty.
    tk_no_answer_labels = np.array([True] * num_all)
    for dk in range(k):
        dk_pred_formula_list = dk_pred_formula_list_dict[dk]
        dk_no_answer_labels = []
        for item in dk_pred_formula_list:
            if item == '' or item is None:
                dk_no_answer_labels.append(True)
            else:
                dk_no_answer_labels.append(False)
        dk_no_answer_labels = np.array(dk_no_answer_labels)
        tk_no_answer_labels = tk_no_answer_labels & dk_no_answer_labels
        metric_results['num_t%d_no_answer' %
                       (dk + 1)] = tk_no_answer_labels.sum().item()

    for metric in metrics:
        if metric == 'exact_match':
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_pred_formula_list = dk_pred_formula_list_dict[dk]
                dk_exact_match_labels = judge_string_exact_match(
                    dk_pred_formula_list, golds_formula_list)
                tk_exact_match_labels = tk_exact_match_labels | \
                    dk_exact_match_labels
                metric_results['num_t%d_exact_match' %
                               (dk + 1)] = tk_exact_match_labels.sum().item()
        elif metric == 'element_match':
            tk_ele_match_labels = np.array([False] * num_all)
            tk_formula_invalid_labels = np.array([True] * num_all)
            for dk in range(k):
                dk_pred_formula_list = dk_pred_formula_list_dict[dk]
                dk_ele_match_labels, dk_formula_invalid_labels = \
                    count_element_match(
                        dk_pred_formula_list, golds_formula_list)
                tk_ele_match_labels = tk_ele_match_labels | dk_ele_match_labels
                tk_formula_invalid_labels = tk_formula_invalid_labels & \
                    dk_formula_invalid_labels
                metric_results['num_t%d_ele_match' %
                               (dk + 1)] = tk_ele_match_labels.sum().item()
                metric_results['num_t%d_formula_invalid' %
                               (dk +
                                1)] = tk_formula_invalid_labels.sum().item()
        elif metric == 'split_match':
            tk_exact_match_labels = np.array([False] * num_all)
            for dk in range(k):
                dk_pred_formula_list = dk_pred_formula_list_dict[dk]
                dk_exact_match_labels = judge_string_split_match(
                    dk_pred_formula_list, golds_formula_list)
                tk_exact_match_labels = tk_exact_match_labels | \
                    dk_exact_match_labels
                metric_results['num_t%d_split_match' %
                               (dk + 1)] = tk_exact_match_labels.sum().item()
        else:
            raise ValueError(metric)

    return metric_results


def calculate_text_metrics(pred_text_list,
                           gold_text_list,
                           text_model='allenai/scibert_scivocab_uncased',
                           text_trunc_length=512):
    """BLEU-2/4, ROUGE-1/2/L and METEOR between predicted and gold texts.

    Each entry is a single-element list holding the text (predictions may
    be None). Empty predictions are counted as ``num_no_answer`` and
    skipped in all corpus statistics. BLEU/METEOR tokenization uses the
    given BERT-style tokenizer checkpoint truncated to
    ``text_trunc_length``.
    """
    assert len(pred_text_list) == len(gold_text_list)
    pred_text_list = [(item[0].strip() if item is not None else '')
                      for item in pred_text_list]
    gold_text_list = [item[0].strip() for item in gold_text_list]

    num_no_answer = 0
    for pred_formula in pred_text_list:
        if pred_formula == '':
            num_no_answer += 1

    text_tokenizer = BertTokenizerFast.from_pretrained(text_model)

    meteor_scores = []

    references = []
    hypotheses = []

    for i, (gt, out) in enumerate(zip(gold_text_list, pred_text_list)):
        if out == '':
            continue

        # Tokenize and drop the BERT special tokens before scoring.
        gt_tokens = text_tokenizer.tokenize(gt,
                                            truncation=True,
                                            max_length=text_trunc_length,
                                            padding='max_length')
        gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
        gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens))
        gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens))

        out_tokens = text_tokenizer.tokenize(out,
                                             truncation=True,
                                             max_length=text_trunc_length,
                                             padding='max_length')
        out_tokens = list(filter(('[PAD]').__ne__, out_tokens))
        out_tokens = list(filter(('[CLS]').__ne__, out_tokens))
        out_tokens = list(filter(('[SEP]').__ne__, out_tokens))

        references.append([gt_tokens])
        hypotheses.append(out_tokens)

        mscore = meteor_score([gt_tokens], out_tokens)
        meteor_scores.append(mscore)

    bleu2 = corpus_bleu(references, hypotheses, weights=(.5, .5))
    bleu4 = corpus_bleu(references, hypotheses, weights=(.25, .25, .25, .25))

    _meteor_score = np.mean(meteor_scores)

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])

    rouge_scores = []

    references = []
    hypotheses = []

    for i, (gt, out) in enumerate(zip(gold_text_list, pred_text_list)):
        if out == '':
            continue

        # NOTE(review): RougeScorer.score takes (target, prediction); it
        # is called here as (out, gt), i.e. with the model output as the
        # target — confirm this ordering is intended.
        rs = scorer.score(out, gt)
        rouge_scores.append(rs)

    rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores])
    rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores])
    rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores])

    result = {
        'num_all': len(pred_text_list),
        'num_no_answer': num_no_answer,
        'bleu2': bleu2,
        'bleu4': bleu4,
        'rouge_1': rouge_1,
        'rouge_2': rouge_2,
        'rouge_l': rouge_l,
        'meteor_score': _meteor_score,
    }

    return result


def calculate_number_metrics(pred_text_list, gold_text_list):
    """RMSE between predicted and gold numeric answers.

    Each entry is a single-element list holding a string. Missing (None
    or empty) predictions are counted as ``num_no_answer``; predictions
    that cannot be parsed as floats are counted as ``num_invalid``.
    Both are excluded from the RMSE; when nothing remains, RMSE is None.
    """
    assert len(pred_text_list) == len(gold_text_list)
    num_all = len(pred_text_list)
    metrics = {}
    metrics['num_all'] = num_all
    num_no_answer = 0
    num_invalid = 0
    pred_values, gold_values = [], []
    for (pred_item, gold_item) in zip(pred_text_list, gold_text_list):
        if pred_item is None:
            num_no_answer += 1
            continue
        assert len(pred_item) == 1
        assert len(gold_item) == 1
        pred_item = pred_item[0]
        gold_item = gold_item[0]
        if pred_item == '':
            num_no_answer += 1
            continue
        try:
            pred_value = float(pred_item)
        except (TypeError, ValueError):
            # float() raises ValueError for bad strings and TypeError for
            # non-numeric objects; SyntaxError (previously caught) is
            # never raised here.
            num_invalid += 1
            continue
        pred_values.append(pred_value)
        gold_values.append(float(gold_item))

    pred_values = np.array(pred_values)
    gold_values = np.array(gold_values)
    if len(pred_values) == 0:
        # Avoid a NaN (and RuntimeWarning) from the mean of an empty array.
        score = None
    else:
        score = np.sqrt(((pred_values - gold_values)**2).mean())

    metrics['num_no_answer'] = num_no_answer
    metrics['num_invalid'] = num_invalid
    metrics['RMSE'] = score

    return metrics


def calculate_boolean_metrics(pred_text_list, gold_text_list):
    """Classification metrics for yes/no answers.

    Each entry is a single-element list holding 'yes' or 'no' (case
    insensitive). Missing answers count as ``num_no_answer``, other
    strings as ``num_invalid``; both are excluded from the sklearn
    metrics.
    """
    assert len(pred_text_list) == len(gold_text_list)
    num_all = len(pred_text_list)
    metrics = {}
    metrics['num_all'] = num_all
    num_no_answer = 0
    num_invalid = 0
    num_correct = 0
    new_pred_text_list, new_gold_text_list = [], []
    for (pred_item, gold_item) in zip(pred_text_list, gold_text_list):
        # NOTE(review): pred_item is a list at this point, so the == ''
        # comparison can only hold for a non-list input — confirm what
        # the callers actually pass.
        if pred_item is None or pred_item == '':
            num_no_answer += 1
            continue
        assert len(pred_item) == 1
        assert len(gold_item) == 1
        pred_item = pred_item[0].strip().lower()
        gold_item = gold_item[0].strip().lower()
        if pred_item == '':
            num_no_answer += 1
            continue
        if pred_item not in ('yes', 'no'):
            num_invalid += 1
            continue
        pred_item = 1 if pred_item == 'yes' else 0
        gold_item = 1 if gold_item == 'yes' else 0
        new_pred_text_list.append(pred_item)
        new_gold_text_list.append(gold_item)
        if gold_item == pred_item:
            num_correct += 1

    metrics['num_no_answer'] = num_no_answer
    metrics['num_invalid'] = num_invalid
    metrics['num_correct'] = num_correct

    new_gold_text_list = np.array(new_gold_text_list)
    new_pred_text_list = np.array(new_pred_text_list)

    # NOTE(review): roc_auc_score raises when the golds contain only one
    # class; there is no guard here — confirm inputs always contain both.
    macro_roc_auc_score = roc_auc_score(new_gold_text_list, new_pred_text_list)
    f1 = f1_score(new_gold_text_list, new_pred_text_list)
    metrics['roc_auc_score'] = macro_roc_auc_score
    metrics['precision'] = precision_score(new_gold_text_list,
                                           new_pred_text_list)
    metrics['recall'] = recall_score(new_gold_text_list, new_pred_text_list)
    metrics['f1_score'] = f1

    # Remap 0 -> -1 in both arrays before MCC (matthews_corrcoef is
    # label-invariant, so this relabeling does not change the score).
    no_mask = (new_gold_text_list == 0)
    new_gold_text_list[no_mask] = -1
    no_mask = (new_pred_text_list == 0)
    new_pred_text_list[no_mask] = -1
    metrics['mcc'] = matthews_corrcoef(new_gold_text_list, new_pred_text_list)

    return metrics

+ 189
- 0
opencompass/datasets/SciReasoner/LLM4Chem/utils/smiles_canonicalization.py View File

@@ -0,0 +1,189 @@
try:
from rdkit import Chem, RDLogger
from rdkit.Chem.AllChem import AssignStereochemistry
except Exception:
Chem, RDLogger, AssignStereochemistry = None, None, None

# RDLogger.DisableLog('rdApp.*')


def canonicalize(smiles, isomeric=False, canonical=True, kekulize=False):
    """Return a canonical SMILES string for ``smiles``.

    Builds a cleaned copy of the molecule (dropping explicit Hs while
    preserving charges, aromatic H counts, bond stereo and chirality),
    sanitizes it without kekulization unless requested, reassigns
    stereochemistry, and writes it back out via RDKit.
    """
    # When canonicalizing a SMILES string, we typically want to
    # run Chem.RemoveHs(mol), but this will try to kekulize the mol
    # which is not required for canonical SMILES. Instead, we make a
    # copy of the mol retaining only the information we desire
    # (not explicit Hs)
    # Then, we sanitize the mol without kekulization.
    # copy_atom and copy_edit_mol
    # Are used to create this clean copy of the mol.
    def copy_atom(atom):
        # Copy symbol/charge; keep explicit H counts only for aromatic
        # atoms with no implicit Hs.
        new_atom = Chem.Atom(atom.GetSymbol())
        new_atom.SetFormalCharge(atom.GetFormalCharge())
        if atom.GetIsAromatic() and atom.GetNoImplicit():
            new_atom.SetNumExplicitHs(atom.GetNumExplicitHs())
        return new_atom

    def copy_edit_mol(mol):
        from rdchiral.chiral import copy_chirality

        # Rebuild the molecule atom-by-atom / bond-by-bond, carrying over
        # bond direction, bond stereo and atom chirality.
        new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
        for atom in mol.GetAtoms():
            new_atom = copy_atom(atom)
            new_mol.AddAtom(new_atom)
        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom().GetIdx()
            a2 = bond.GetEndAtom().GetIdx()
            bt = bond.GetBondType()
            new_mol.AddBond(a1, a2, bt)
            new_bond = new_mol.GetBondBetweenAtoms(a1, a2)
            new_bond.SetBondDir(bond.GetBondDir())
            new_bond.SetStereo(bond.GetStereo())
        for new_atom in new_mol.GetAtoms():
            atom = mol.GetAtomWithIdx(new_atom.GetIdx())
            copy_chirality(atom, new_atom)
        return new_mol

    smiles = smiles.replace(' ', '')
    tmp = Chem.MolFromSmiles(smiles, sanitize=False)
    tmp.UpdatePropertyCache()
    new_mol = copy_edit_mol(tmp)
    # Sanitize with (or without) kekulization depending on the flag;
    # errors are caught by RDKit rather than raised.
    if not kekulize:
        Chem.SanitizeMol(new_mol,
                         sanitizeOps=Chem.SanitizeFlags.SANITIZE_SETAROMATICITY
                         | Chem.SanitizeFlags.SANITIZE_PROPERTIES
                         | Chem.SanitizeFlags.SANITIZE_ADJUSTHS,
                         catchErrors=True)
    else:
        Chem.SanitizeMol(new_mol,
                         sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE
                         | Chem.SanitizeFlags.SANITIZE_PROPERTIES
                         | Chem.SanitizeFlags.SANITIZE_ADJUSTHS,
                         catchErrors=True)

    AssignStereochemistry(new_mol,
                          cleanIt=False,
                          force=True,
                          flagPossibleStereoCenters=True)

    new_smiles = Chem.MolToSmiles(new_mol,
                                  isomericSmiles=isomeric,
                                  canonical=canonical)
    return new_smiles


def canonicalize_molecule_smiles(smiles,
                                 return_none_for_error=True,
                                 skip_mol=False,
                                 sort_things=True,
                                 isomeric=True,
                                 kekulization=True):
    """Canonicalize a (possibly multi-part) molecule SMILES.

    The input is split on '.', each part is canonicalized (atom maps
    cleared, several RDKit round-trips, then ``canonicalize`` and
    optional kekulization), parts are optionally sorted, and the result
    is re-joined with '.'.

    Returns None on failure when ``return_none_for_error`` is True;
    otherwise the error propagates.
    """
    things = smiles.split('.')
    if skip_mol:
        new_things = things
    else:
        new_things = []
        for thing in things:
            try:
                if thing == '' and not allow_empty_part:
                    raise ValueError('SMILES contains empty part.')

                mol = Chem.MolFromSmiles(thing)
                # NOTE(review): an unparsable part returns the raw part
                # string immediately (not None), skipping the remaining
                # parts — confirm this early return is intended.
                if mol is None:
                    return thing
                assert mol is not None
                for atom in mol.GetAtoms():
                    atom.SetAtomMapNum(0)
                # Round-trip through SMILES twice to stabilize the form
                # before the final canonicalization.
                thing_smiles = Chem.MolToSmiles(mol,
                                                kekuleSmiles=False,
                                                isomericSmiles=isomeric)
                thing_smiles = Chem.MolFromSmiles(thing_smiles)
                thing_smiles = Chem.MolToSmiles(thing_smiles,
                                                kekuleSmiles=False,
                                                isomericSmiles=isomeric)
                thing_smiles = Chem.MolFromSmiles(thing_smiles)
                thing_smiles = Chem.MolToSmiles(thing_smiles,
                                                kekuleSmiles=False,
                                                isomericSmiles=isomeric)
                assert thing_smiles is not None
                can_in = thing_smiles
                can_out = canonicalize(thing_smiles, isomeric=isomeric)
                assert can_out is not None, can_in
                thing_smiles = can_out
                if kekulization:
                    thing_smiles = keku_mid = Chem.MolFromSmiles(thing_smiles)
                    assert keku_mid is not None, \
                        'Before can: %s\nAfter can: %s' % (
                            can_in, can_out)
                    thing_smiles = Chem.MolToSmiles(thing_smiles,
                                                    kekuleSmiles=True,
                                                    isomericSmiles=isomeric)
            except KeyboardInterrupt:
                raise
            except Exception:
                if return_none_for_error:
                    return None
                else:
                    raise
            new_things.append(thing_smiles)
    if sort_things:
        new_things = sorted(new_things)
    new_things = '.'.join(new_things)
    return new_things


def canonicalize_reaction_smiles(smiles,
                                 return_none_for_error=True,
                                 return_segs=False,
                                 skip_mol=False,
                                 sort_things=True,
                                 isomeric=True,
                                 kekulization=True):
    """Canonicalize a 'reactants>agents>products' reaction SMILES.

    Each non-empty segment is canonicalized independently via
    ``canonicalize_molecule_smiles``; empty segments are preserved.
    Returns the joined string, a 3-tuple of segments when
    ``return_segs`` is set, or None when a segment fails and
    ``return_none_for_error`` is True.
    """
    parts = smiles.split('>')
    assert len(parts) == 3
    canonical_parts = []
    for part in parts:
        if part == '':
            canonical_parts.append('')
            continue
        canonical = canonicalize_molecule_smiles(
            part,
            return_none_for_error=return_none_for_error,
            skip_mol=skip_mol,
            sort_things=sort_things,
            isomeric=isomeric,
            kekulization=kekulization)
        if return_none_for_error and canonical is None:
            return None
        canonical_parts.append(canonical)

    if return_segs:
        return tuple(canonical_parts)
    return '>'.join(canonical_parts)


def get_molecule_id(smiles, remove_duplicate=True):
    """Identifier for a molecule, based on InChI.

    With ``remove_duplicate`` the SMILES is split on '.' and a sorted
    tuple of unique component InChIs is returned; otherwise a single
    InChI string ('' for unparsable SMILES).
    """
    if not remove_duplicate:
        mol = Chem.MolFromSmiles(smiles)
        return '' if mol is None else Chem.MolToInchi(mol)
    assert ';' not in smiles
    unique_ids = {
        get_molecule_id(part, remove_duplicate=False)
        for part in smiles.split('.')
    }
    return tuple(sorted(unique_ids))

+ 217
- 0
opencompass/datasets/SciReasoner/LLM4Mat.py View File

@@ -0,0 +1,217 @@
# flake8: noqa

import json
import os
import re
from typing import Union

import numpy as np
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from sklearn.metrics import (mean_absolute_error, mean_squared_error,
roc_auc_score)

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class LLM4MatDataset(BaseDataset):
    """Loads LLM4Mat JSON splits, keeping only samples whose output
    mentions the requested property key."""

    @staticmethod
    def load(path,
             property,
             train_path,
             test_path,
             mini_set=False) -> DatasetDict:

        def _read_split(split_path, prop, limit=None):
            """Read one JSON split, filter by property, cap at limit."""
            with open(split_path, 'r', encoding='utf-8') as f:
                records = json.load(f)
            if isinstance(records, dict):
                records = [records]
            # Keep only records whose output contains '{<prop> :'.
            kept = [{
                'input': rec['input'],
                'output': rec['output'],
            } for rec in records if '{' + f'{prop} :' in rec['output']]
            if limit:
                kept = kept[:limit]
            return Dataset.from_list(kept)

        root = get_data_path(path)
        test_limit = 150 if mini_set else None
        return DatasetDict({
            'train':
            _read_split(os.path.join(root, train_path), property, limit=5),
            'test':
            _read_split(os.path.join(root, test_path),
                        property,
                        limit=test_limit)
        })


# Closed option sets for the non-numeric (classification) properties;
# every other property is treated as a regression target.
non_numeric_props_options = {
    'Direct_or_indirect': ['Indirect', 'Direct'],
    'Direct_or_indirect_HSE': ['Indirect', 'Direct'],
    'SOC': [True, False],
    'is_gap_direct': [True, False],
    'is_stable': [True, False],
}


def remove_think_tags(text: str) -> str:
    """Strip <think>...</think> reasoning spans from model output.

    Text without an opening tag is returned unchanged; an unterminated
    <think> discards the whole string.
    """
    has_open = '<think>' in text
    if not has_open:
        return text
    has_close = '</think>' in text
    if not has_close:
        return ''
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)


def extract_strict_value(text: str, property: str, options_map=None):
    """Extract the value of ``property`` from a model answer.

    First tries to parse the answer (minus any markdown code fence) as
    JSON; falls back to a regex over a ``{property : value}`` pattern.
    Numeric values come back as float, categorical values as the
    canonical option spelling, and '' when nothing valid is found.

    Args:
        options_map: mapping of property name -> list of allowed
            categorical options. Defaults to the module-level
            ``non_numeric_props_options``.
    """
    if options_map is None:
        options_map = non_numeric_props_options
    text_clean = re.sub(r'^```(?:json)?\s*|\s*```$',
                        '',
                        text.strip(),
                        flags=re.IGNORECASE | re.MULTILINE)

    def _match_option(raw_value):
        # Map a raw value onto its canonical option spelling ('' if none
        # matches). Boolean options are rendered as 'True'/'False'.
        for opt in options_map[property]:
            if str(raw_value).lower() == str(opt).lower():
                return str(opt) if isinstance(opt, bool) else opt
        return ''

    try:
        data = json.loads(text_clean)
        if property in data:
            raw_value = data[property]
            # Bug fix: bool is a subclass of int, so it must be excluded
            # here — otherwise JSON true/false would be returned as
            # 1.0/0.0 instead of the categorical option.
            if isinstance(raw_value,
                          (int, float)) and not isinstance(raw_value, bool):
                return float(raw_value)
            if property in options_map:
                return _match_option(raw_value)
            return str(raw_value)
    except Exception:
        pass

    # Regex fallback: grab the value from a '{property : value}' span.
    pattern = rf'\{{[^{{}}]*"?{re.escape(property)}"?\s*:\s*(.*?)\s*\}}'
    match = re.search(pattern, text_clean, flags=re.DOTALL | re.IGNORECASE)
    if not match:
        return ''
    raw_value = match.group(1).strip().strip('"')
    if property in options_map:
        return _match_option(raw_value)
    try:
        return float(raw_value)
    except ValueError:
        return ''


@TEXT_POSTPROCESSORS.register_module()
def LLM4Mat_postprocessor(text: Union[str, None], property):
    """Normalize raw model output and extract the requested property.

    Non-string input (including None) and output that is empty after
    <think> removal yield ''.
    """
    if not isinstance(text, str):
        return ''
    cleaned = remove_think_tags(text.strip())
    if cleaned == '':
        return ''
    return extract_strict_value(cleaned, property)


class LLM4Mat_Evaluator(BaseEvaluator):
    """Scores LLM4Mat predictions.

    Regression targets report MAE/RMSE/MAD and the MAD/MAE ratio;
    classification targets report ROC-AUC over the label pairs that
    could be mapped to {0, 1}.
    """

    def score(self, predictions, references):
        # Numeric (but non-bool) references mark a regression task.
        is_regression = isinstance(
            references[0],
            (int, float)) and not isinstance(references[0], bool)
        if is_regression:
            return self._score_regression(predictions, references)
        return self._score_classification(predictions, references)

    @staticmethod
    def _score_regression(predictions, references):
        # Keep only pairs where both sides parse to finite floats.
        y_true = []
        y_pred = []
        total = len(references)
        for t, p in zip(references, predictions):
            try:
                t_val = float(t)
                p_val = float(p)
            except Exception:
                continue
            if np.isfinite(t_val) and np.isfinite(p_val):
                y_true.append(t_val)
                y_pred.append(p_val)
        if not y_true:
            return {
                'total': total,
                'filtered': 0,
                'MAE': None,
                'RMSE': None,
                'MAD': None,
                'MAD/MAE': None
            }
        mae = mean_absolute_error(y_true, y_pred)
        rmse = mean_squared_error(y_true, y_pred, squared=False)
        # MAD is the error of the constant mean-value baseline; a
        # MAD/MAE ratio above 1 means the model beats that baseline.
        mean_value = np.mean(y_true)
        mad = mean_absolute_error(y_true, [mean_value] * len(y_true))
        return {
            'total': total,
            'filtered': len(y_true),
            'MAE': mae,
            'RMSE': rmse,
            'MAD': mad,
            'MAD/MAE': mad / mae if mae != 0 else None
        }

    @staticmethod
    def _score_classification(predictions, references):
        # Map reference/prediction labels to {0, 1}; drop pairs where
        # either side is unmappable.
        y_true = []
        y_pred = []
        auc = None
        for t, p in zip(references, predictions):
            if t in ['Null']:
                continue
            if t in ['Direct', 'True', True]:
                y_true.append(1)
            elif t in ['Indirect', 'False', False]:
                y_true.append(0)
            else:
                continue
            if p in ['Direct', 'True', True]:
                y_pred.append(1)
            elif p in ['Indirect', 'False', False]:
                y_pred.append(0)
            else:
                y_true.pop()
        try:
            # Bug fix: AUC was previously computed inside the filtering
            # loop, doing O(n^2) work and aborting the loop on the first
            # exception (e.g. while only one class had been seen).
            auc = roc_auc_score(y_true, y_pred)
        except Exception:
            pass
        return {'AUC': auc}

+ 15
- 0
opencompass/datasets/SciReasoner/Mol_Instructions/__init__.py View File

@@ -0,0 +1,15 @@
from .biotext import Mol_Instructions_Dataset_BioText # noqa: F401, F403
from .biotext import Mol_Instructions_Evaluator_BioText # noqa: F401, F403
from .biotext import Mol_Instructions_postprocess_BioText # noqa: F401, F403
from .molecule import Mol_Instructions_Dataset # noqa: F401, F403
from .molecule import Mol_Instructions_Evaluator_Mol # noqa: F401, F403
from .molecule import Mol_Instructions_postprocess_Mol # noqa: F401, F403
from .normalized_SW_score import normalized_smith_waterman # noqa: F401, F403
from .protein import \
Mol_Instructions_Dataset_Protein_Design # noqa: F401, F403
from .protein import Mol_Instructions_Evaluator_Protein # noqa: F401, F403
from .protein import \
Mol_Instructions_Evaluator_Protein_Design # noqa: F401, F403
from .protein import Mol_Instructions_postprocess_Protein # noqa: F401, F403
from .protein import \
Mol_Instructions_postprocess_Protein_Design # noqa: F401, F403

+ 331
- 0
opencompass/datasets/SciReasoner/Mol_Instructions/biotext.py View File

@@ -0,0 +1,331 @@
# flake8: noqa
# molecule task
# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule

import json
import os
import re
from typing import List

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from sklearn.metrics import precision_recall_fscore_support

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


def CER_calculate_f1_score(true_entities, predicted_entities):
    """Entity-level F1 between two comma-separated entity strings."""
    gold = set(true_entities.split(', '))
    guessed = set(predicted_entities.split(', '))
    overlap = len(gold & guessed)
    precision = overlap / len(guessed) if len(guessed) > 0 else 0
    recall = overlap / len(gold) if len(gold) > 0 else 0
    if precision + recall > 0:
        return 2 * (precision * recall) / (precision + recall)
    return 0


def calculate_f1_score(true_entities, predicted_entities):
    """F1 over parenthesised ``(...)`` spans extracted from both strings."""
    pattern = r'\(.*?\)'
    gold_spans = re.findall(pattern, true_entities)
    pred_spans = re.findall(pattern, predicted_entities)
    if not pred_spans:
        # Bare prediction without parentheses: wrap it so it still
        # yields one candidate span.
        pred_spans = re.findall(pattern, f'({predicted_entities})')

    gold = set(gold_spans)
    guessed = {span.strip() for span in pred_spans}

    overlap = len(gold & guessed)
    precision = overlap / len(guessed) if len(guessed) > 0 else 0
    recall = overlap / len(gold) if len(gold) > 0 else 0
    if precision + recall > 0:
        return 2 * (precision * recall) / (precision + recall)
    return 0


def calculate_accuracy_(predictions, references):
    """Mean per-sample parenthesised-tuple F1 (see ``calculate_f1_score``)."""
    total_count = len(references)
    f1_sum = sum(
        calculate_f1_score(ref[0].lower(), pred[0].lower())
        for pred, ref in zip(predictions, references))
    return f1_sum / total_count


def CER_calculate_accuracy_(predictions, references):
    """Mean per-sample entity F1 (see ``CER_calculate_f1_score``)."""
    total_count = len(references)
    f1_sum = sum(
        CER_calculate_f1_score(ref[0].lower(), pred[0].lower())
        for pred, ref in zip(predictions, references))
    return f1_sum / total_count


def ture_or_false_calculate_accuracy_(predictions, references):
    """Accuracy (in %) for yes/no/maybe answers.

    Returns ``(accuracy, other_answers)`` where ``other_answers`` counts
    predictions that could not be mapped to any of the three labels.
    """
    total_count = len(references)
    correct_count = 0
    other_answers = 0
    for pred, ref in zip(predictions, references):
        pred_text = pred[0].lower().strip()
        ref_text = ref[0].lower()
        # Gold label is the token before the first comma.
        gold_label = ref_text.split(',')[0].strip()

        if 'yes' in pred_text:
            guess = 'yes'
        elif 'no' in pred_text:
            guess = 'no'
        elif ('maybe' in pred_text or 'may be' in pred_text
              or 'might' in pred_text):
            guess = 'maybe'
        else:
            other_answers += 1
            guess = 'other'
            print(f'Other answer: {pred_text}, reference: {ref_text}')

        if gold_label == guess:
            correct_count += 1

    return (correct_count / total_count) * 100, other_answers


def calculate_macro_f1_(predictions, references):
    """Macro-averaged F1 over the fixed yes/no/maybe label set."""
    gold = [ref[0].split(',')[0].strip().lower() for ref in references]
    guesses = [pred[0].split(',')[0].strip().lower() for pred in predictions]
    # Per-class scores first, then average F1 over the three classes.
    _, _, per_class_f1, _ = precision_recall_fscore_support(
        gold, guesses, labels=['yes', 'no', 'maybe'], average=None)
    return sum(per_class_f1) / len(per_class_f1)


def multi_choice_question_calculate_accuracy(question_data):
    """Accuracy (in %) for A-D multiple-choice answers stored in dicts.

    Each item needs ``output`` (gold, e.g. ``"(B) ..."``) and
    ``my_output`` (model response).
    """
    total_count = len(question_data)
    correct_count = 0
    for item in question_data:
        gold = item['output'].split('(')[1].split(')')[0]
        response = item['my_output']
        # Fallback: first character of the raw response.
        guess = response[0]
        for letter in 'ABCD':
            if (f'({letter}' in response or f'{letter})' in response
                    or f' {letter} ' in response):
                guess = letter
                break
        if gold == guess:
            correct_count += 1
    return (correct_count / total_count) * 100


def multi_choice_question_calculate_accuracy_(predictions, references):
    """Accuracy (in %) for A-D multiple-choice predictions/references."""
    total_count = len(references)
    correct_count = 0
    for pred, ref in zip(predictions, references):
        gold = ref[0].split('(')[1].split(')')[0]
        response = pred[0]
        # Fallback: the whole response string (matches gold only if the
        # model answered with the bare letter).
        guess = response
        for letter in 'ABCD':
            if (f'({letter}' in response or f'{letter})' in response
                    or f' {letter} ' in response):
                guess = letter
                break
        if gold == guess:
            correct_count += 1
    return (correct_count / total_count) * 100


@LOAD_DATASET.register_module()
class Mol_Instructions_Dataset_BioText(BaseDataset):
    """Mol-Instructions biotext tasks loaded from local JSON splits."""

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        """Build a train/test ``DatasetDict`` for *task*.

        Args:
            path: dataset root, resolved through ``get_data_path``.
            task: sub-directory containing ``dev``/``test`` splits.
            max_cut: keep only the first N test samples (-1 = keep all).
            mini_set: use a fixed 50-sample test subset instead.
            hf_hub: unused; kept for signature compatibility.
        """
        root = get_data_path(path)
        with open(os.path.join(root, f'{task}/dev/data.json'),
                  'r',
                  encoding='utf-8') as f:
            # Only five examples are kept as the few-shot train split.
            train_data = json.load(f)[:5]
        with open(os.path.join(root, f'{task}/test/data.json'),
                  'r',
                  encoding='utf-8') as f:
            test_data = json.load(f)

        if max_cut != -1:
            test_data = test_data[:max_cut]
        if mini_set:
            import random

            # Deterministic subset; the global RNG is re-seeded afterwards.
            random.seed(1024)
            test_data = random.sample(test_data, 50)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


@TEXT_POSTPROCESSORS.register_module('Mol_Instructions_postprocess_BioText')
def Mol_Instructions_postprocess_BioText(text, task, *args, **kwargs):
    """Clean a raw model answer for a biotext task.

    Strips chat end-tokens, ``<think>`` reasoning blocks and boilerplate
    answer prefixes; tasks outside the known set are only stripped.
    """
    text = text.strip()
    known_tasks = (
        'chemical_disease_interaction_extraction',
        'chemical_protein_interaction_extraction',
        'chemical_entity_recognition',
        'true_or_false_question',
        'multi_choice_question',
        'open_question',
    )
    if task not in known_tasks:
        return text

    for token in (r'<\|endoftext\|>', r'<\|im_end\|>'):
        text = re.sub(token, '', text)

    # Boilerplate answer prefixes emitted by some chat models (e.g. qwen3).
    text = re.sub(r'^(Response:|Answer:)\s*', '', text, flags=re.IGNORECASE)

    # Drop reasoning blocks; also anything before a dangling </think>
    # (seen with gpt-oss-120b).
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    text = re.sub(r'.*?</think>\s*', '', text, flags=re.DOTALL)

    text = re.sub(r'^(I would say that|I would like to say that)\s*',
                  '',
                  text,
                  flags=re.IGNORECASE)
    return text.strip()


class Mol_Instructions_Evaluator_BioText(BaseEvaluator):
    """Dispatches each biotext task to its matching metric."""

    def __init__(self, task='protein_design', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task = task

    def score(self, predictions: List[str], references: List[str]):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        # Normalise both inputs to lists of single-element lists.
        if not isinstance(predictions[0], list):
            predictions = [[p] for p in predictions]
        if not isinstance(references[0], list):
            references = [[r] for r in references]

        task = self.task
        if task in ('chemical_disease_interaction_extraction',
                    'chemical_protein_interaction_extraction'):
            return {'f1': calculate_accuracy_(predictions, references)}
        if task == 'chemical_entity_recognition':
            return {'f1': CER_calculate_accuracy_(predictions, references)}
        if task == 'true_or_false_question':
            acc, other = ture_or_false_calculate_accuracy_(
                predictions, references)
            return {'accuracy': acc, 'other_answers': other}
        if task == 'multi_choice_question':
            acc = multi_choice_question_calculate_accuracy_(
                predictions, references)
            return {'accuracy': acc}
        if task == 'open_question':
            # BERTScore with a pinned scorer config for reproducibility.
            from bert_score import score
            golds = [ref[0] for ref in references]
            answers = [pred[0] for pred in predictions]
            _, _, F1 = score(answers,
                             golds,
                             lang='en',
                             verbose=False,
                             num_layers=14,
                             model_type='FacebookAI/roberta-large')
            return {'bert_score': sum(F1).item() / len(F1)}
        raise ValueError(f'Unknown task: {self.task}')

+ 458
- 0
opencompass/datasets/SciReasoner/Mol_Instructions/molecule.py View File

@@ -0,0 +1,458 @@
# flake8: noqa
# molecule task
# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule

import json
import re

import numpy as np
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from Levenshtein import distance as lev
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score

try:
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, MACCSkeys
except Exception:
Chem, DataStructs, RDLogger, AllChem, MACCSkeys = None, None, None, None, None

try:
import selfies as sf
except Exception:
sf = None

import os

from rouge_score import rouge_scorer
from sklearn.metrics import mean_absolute_error
from transformers import BertTokenizerFast

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path

# RDLogger.DisableLog('rdApp.*')


@LOAD_DATASET.register_module()
class Mol_Instructions_Dataset(BaseDataset):
    """Mol-Instructions molecule tasks loaded from local JSON splits."""

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        """Build a train/test ``DatasetDict`` for *task*.

        Args:
            path: dataset root, resolved through ``get_data_path``.
            task: sub-directory containing ``dev``/``test`` splits.
            max_cut: keep only the first N test samples (-1 = keep all).
            mini_set: use a fixed 150-sample test subset instead.
            hf_hub: unused; kept for signature compatibility.
        """
        root = get_data_path(path)
        with open(os.path.join(root, f'{task}/dev/data.json'),
                  'r',
                  encoding='utf-8') as f:
            # Only five examples are kept as the few-shot train split.
            train_data = json.load(f)[:5]
        with open(os.path.join(root, f'{task}/test/data.json'),
                  'r',
                  encoding='utf-8') as f:
            test_data = json.load(f)

        if max_cut != -1:
            test_data = test_data[:max_cut]
        if mini_set:
            import random

            # Deterministic subset; the global RNG is re-seeded afterwards.
            random.seed(1024)
            test_data = random.sample(test_data, 150)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


def convert_to_canonical_smiles(smiles):
    """Canonicalise a SMILES string; ``None`` if RDKit cannot parse it."""
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None
    return Chem.MolToSmiles(molecule, isomericSmiles=False, canonical=True)


@TEXT_POSTPROCESSORS.register_module()
def Mol_Instructions_postprocess_Mol(text, task, *args, **kwargs):
    """Task-specific cleanup of a raw model answer.

    * property_prediction_str -> first numeric value found (0 if none)
    * reaction/design tasks   -> canonical SMILES from <SMILES> tags
                                 (``None`` when no tag is present)
    * description generation  -> de-tokenised free text
    Unknown tasks are returned untouched.
    """
    smiles_tasks = (
        'description_guided_molecule_design',
        'forward_reaction_prediction',
        'retrosynthesis',
        'reagent_prediction',
    )

    def _strip_chat_artifacts(raw):
        # Remove end-of-turn tokens and <think> reasoning blocks (plus
        # anything left before a dangling </think>).
        raw = raw.strip()
        raw = re.sub(r'<\|endoftext\|>', '', raw)
        raw = re.sub(r'<\|im_end\|>', '', raw)
        raw = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL)
        raw = re.sub(r'.*?</think>\s*', '', raw, flags=re.DOTALL)
        return raw

    if task == 'property_prediction_str':
        text = _strip_chat_artifacts(text)
        # Join digits split by stray spaces, then grab the first number.
        text = re.sub(r'(?<=\d) +(?=\d)|(?<=\.) +(?=\d)', '', text)
        num_match = re.search(r'[-+]?\d*\.\d+|\d+', text)
        text = num_match.group(0) if num_match else 0
    elif task in smiles_tasks:
        text = _strip_chat_artifacts(text)
        match = re.search(r'<SMILES>(.*?)</SMILES>', text)
        if match:
            text = convert_to_canonical_smiles(match.group(1).strip())
        else:
            # Missing tag: signal an invalid molecule to the evaluator.
            text = None
    elif task in [
            'molecular_description_generation',
    ]:
        text = _strip_chat_artifacts(text)

    return text


def compute_MAE_property_prediction_str(predictions, references):
    """MAE between predicted and reference scalars, scaled by 1000."""
    y_pred = np.asarray([float(p[0]) for p in predictions])
    y_true = np.asarray([float(r[0]) for r in references])
    # x1000 so the magnitude matches the OpenCompass presentation.
    return {'mae': mean_absolute_error(y_true, y_pred) * 1000}


def compute_fingerprint_metricts(
    predictions,
    references,
    morgan_r=2,
):
    """Average fingerprint similarities between predicted and gold SMILES.

    Missing or unparsable predictions lower ``validity_score``; the MACCS /
    RDK / Morgan Tanimoto similarities are averaged over valid pairs only.

    Args:
        predictions: list of 1-element lists holding predicted SMILES
            (``None`` marks a generation with no usable SMILES).
        references: list of 1-element lists holding reference SMILES.
        morgan_r: radius of the Morgan fingerprint.

    Returns:
        dict with ``validity_score`` and the three mean similarity scores.
    """
    bad_mols = 0
    pairs = []

    for pred, refer in zip(predictions, references):
        try:
            if pred[0] is None:
                bad_mols += 1
                continue
            pred_mol = Chem.MolFromSmiles(pred[0])
            refer_mol = Chem.MolFromSmiles(refer[0])
            if pred_mol is None:
                bad_mols += 1
                continue
            pairs.append((refer_mol, pred_mol))
        except Exception:
            # Previously dropped into pdb here; count any RDKit failure as
            # an invalid molecule so unattended evaluation can finish.
            bad_mols += 1

    if not pairs:
        # No valid pair at all: report zeros instead of dividing by zero
        # or averaging empty lists.
        return {
            'validity_score': 0.0,
            'maccs_sims_score': 0.0,
            'rdk_sims_score': 0.0,
            'morgan_sims_score': 0.0
        }

    validity_score = len(pairs) / (len(pairs) + bad_mols)

    MACCS_sims = []
    morgan_sims = []
    RDK_sims = []

    for gt_m, ot_m in pairs:
        MACCS_sims.append(
            DataStructs.FingerprintSimilarity(
                MACCSkeys.GenMACCSKeys(gt_m),
                MACCSkeys.GenMACCSKeys(ot_m),
                metric=DataStructs.TanimotoSimilarity))
        RDK_sims.append(
            DataStructs.FingerprintSimilarity(
                Chem.RDKFingerprint(gt_m),
                Chem.RDKFingerprint(ot_m),
                metric=DataStructs.TanimotoSimilarity))
        morgan_sims.append(
            DataStructs.TanimotoSimilarity(
                AllChem.GetMorganFingerprint(gt_m, morgan_r),
                AllChem.GetMorganFingerprint(ot_m, morgan_r)))

    return {
        'validity_score': validity_score,
        'maccs_sims_score': np.mean(MACCS_sims),
        'rdk_sims_score': np.mean(RDK_sims),
        'morgan_sims_score': np.mean(morgan_sims)
    }


def compute_mol_translation_selfies(predictions, references):
    """BLEU / exact-match / Levenshtein metrics on SMILES pairs.

    Pairs whose prediction is missing (``None``) or cannot be encoded to
    SELFIES are skipped.  Exact match compares InChI strings so SMILES
    aliases of the same molecule still count as equal.

    Returns:
        dict with ``bleu_self_scores`` (character-level corpus BLEU on
        SELFIES), ``exact_match_score`` and ``levenshtein_score_smi``.
    """
    outputs = []
    bad_mols = 0
    # NOTE: the old debug ``print`` of the full prediction/reference lists
    # was removed — it flooded stdout on every evaluation.
    for pred, refer in zip(predictions, references):
        if pred[0] is None:
            bad_mols += 1
            continue
        pred_smi = pred[0]
        refer_smi = refer[0]
        try:
            pred_sf = sf.encoder(pred_smi)
            refer_sf = sf.encoder(refer_smi)
        except Exception:
            # SELFIES could not encode one side: skip the pair.
            bad_mols += 1
            continue
        outputs.append((refer_sf, pred_sf, refer_smi, pred_smi))

    # Character-level corpus BLEU on the SELFIES strings.
    references_self = [[list(gt_self)] for gt_self, _, _, _ in outputs]
    hypotheses_self = [list(ot_self) for _, ot_self, _, _ in outputs]
    if not references_self or not hypotheses_self:
        bleu_score_self = 0.0
    else:
        bleu_score_self = corpus_bleu(references_self, hypotheses_self)

    levs_smi = []
    num_exact = 0
    for gt_self, ot_self, gt_smi, ot_smi in outputs:
        try:
            m_out = Chem.MolFromSmiles(ot_smi)
            m_gt = Chem.MolFromSmiles(gt_smi)
            # InChI comparison is robust to different-but-equivalent SMILES.
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
                num_exact += 1
        except Exception:
            bad_mols += 1
        levs_smi.append(lev(ot_smi, gt_smi))

    # Guard the empty case explicitly instead of relying on a stale loop
    # index (the old ``num_exact / (i + 1)``) and np.mean([]).
    exact_match_score = num_exact / len(outputs) if outputs else 0.0
    levenshtein_score_smi = np.mean(levs_smi) if levs_smi else 0.0

    return {
        'bleu_self_scores': bleu_score_self,
        'exact_match_score': exact_match_score,
        'levenshtein_score_smi': levenshtein_score_smi,
    }


def fix_smiles_brackets(smiles):
    """Append any missing closing parentheses to a SMILES string.

    Non-string input is returned unchanged.
    """
    if not isinstance(smiles, str):
        return smiles
    missing = smiles.count('(') - smiles.count(')')
    return smiles + ')' * missing if missing > 0 else smiles


class Mol_Instructions_Evaluator_Mol(BaseEvaluator):
    """Routes each Mol-Instructions molecule task to its metric suite."""

    def __init__(self, task, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task = task

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        # Normalise both inputs to lists of single-element lists.
        if not isinstance(predictions[0], list):
            predictions = [[pred] for pred in predictions]
        if not isinstance(references[0], list):
            references = [[ref] for ref in references]

        task = self.task
        if task == 'property_prediction_str':
            return compute_MAE_property_prediction_str(
                predictions, references)
        if task in ('description_guided_molecule_design',
                    'forward_reaction_prediction', 'retrosynthesis',
                    'reagent_prediction'):
            fingerprint = compute_fingerprint_metricts(
                predictions, references)
            translation = compute_mol_translation_selfies(
                predictions, references)
            # Fixed presentation order: exact, bleu, levenshtein, RDK,
            # MACCS, Morgan, validity.
            return {
                'exact_match_score': translation['exact_match_score'],
                'bleu_self_scores': translation['bleu_self_scores'],
                'levenshtein_score_smi':
                translation['levenshtein_score_smi'],
                'rdk_sims_score': fingerprint['rdk_sims_score'],
                'maccs_sims_score': fingerprint['maccs_sims_score'],
                'morgan_sims_score': fingerprint['morgan_sims_score'],
                'validity_score': fingerprint['validity_score'],
            }
        if task == 'molecular_description_generation':
            return compute_text_translation_metrics(predictions, references)
        raise ValueError(task)


def compute_text_translation_metrics(
        predictions,
        references,
        text_model='allenai/scibert_scivocab_uncased',
        text_trunc_length=512):
    """BLEU-2/4, METEOR and ROUGE-1/2/L between generated and gold texts.

    Generated texts are truncated after their last full sentence; both
    sides are tokenised with the SciBERT tokenizer so they share one
    vocabulary.

    Args:
        predictions: list of 1-element lists with generated descriptions.
        references: list of 1-element lists with gold descriptions.
        text_model: HF tokenizer checkpoint used for tokenisation.
        text_trunc_length: tokenizer truncation length.
    """
    outputs = []
    for pred, refer in zip(predictions, references):
        out_text = pred[0]
        if isinstance(out_text, str):
            # Drop a trailing sentence fragment after the last full stop.
            # (The old try/except around this dropped into pdb on error.)
            out_text = out_text.rsplit('.', 1)[0] + '.'
        outputs.append((refer[0], out_text))

    text_tokenizer = BertTokenizerFast.from_pretrained(text_model)

    special_tokens = ('[PAD]', '[CLS]', '[SEP]')
    meteor_scores = []
    bleu_references = []
    bleu_hypotheses = []

    for gt, out in outputs:
        gt_tokens = [
            tok for tok in text_tokenizer.tokenize(
                gt,
                truncation=True,
                max_length=text_trunc_length,
                padding='max_length') if tok not in special_tokens
        ]
        out_tokens = [
            tok for tok in text_tokenizer.tokenize(
                out,
                truncation=True,
                max_length=text_trunc_length,
                padding='max_length') if tok not in special_tokens
        ]
        bleu_references.append([gt_tokens])
        bleu_hypotheses.append(out_tokens)
        meteor_scores.append(meteor_score([gt_tokens], out_tokens))

    bleu2 = corpus_bleu(bleu_references, bleu_hypotheses, weights=(.5, .5))
    bleu4 = corpus_bleu(bleu_references,
                        bleu_hypotheses,
                        weights=(.25, .25, .25, .25))

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
    # NOTE(review): arguments are passed as (target=out, prediction=gt);
    # only the symmetric f-measure is used below, so the swap does not
    # change the reported numbers.
    rouge_scores = [scorer.score(out, gt) for gt, out in outputs]

    return {
        'bleu2': bleu2,
        'bleu4': bleu4,
        'meteor_score': np.mean(meteor_scores),
        'rouge1': np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]),
        'rouge2': np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]),
        'rougeL': np.mean([rs['rougeL'].fmeasure for rs in rouge_scores])
    }

+ 150
- 0
opencompass/datasets/SciReasoner/Mol_Instructions/normalized_SW_score.py View File

@@ -0,0 +1,150 @@
import math


def normalized_smith_waterman(seq1,
                              seq2,
                              matrix_name='BLOSUM45',
                              open_gap=-10,
                              extend_gap=-0.5):
    """
    Compute normalized Smith-Waterman score for protein sequences.

    Args:
        seq1, seq2 (str): Protein sequences (uppercase letters)
        matrix_name (str): Name of substitution matrix (default: BLOSUM45)
        open_gap (float): Gap opening penalty
        extend_gap (float): Gap extension penalty

    Returns:
        float: Normalized score between 0.0 and 1.0
    """

    from Bio.Align import PairwiseAligner, substitution_matrices

    # Initialize aligner (local mode == Smith-Waterman)
    aligner = PairwiseAligner()
    aligner.mode = 'local'
    aligner.open_gap_score = open_gap
    aligner.extend_gap_score = extend_gap

    # Load substitution matrix
    try:
        matrix = substitution_matrices.load(matrix_name)
    except ValueError:
        raise ValueError(f'Matrix {matrix_name} not available.'
                         f' Try: {substitution_matrices.load()}')

    # Set substitution matrix
    aligner.substitution_matrix = matrix

    # Calculate raw alignment score
    raw_score = aligner.score(seq1, seq2)
    if raw_score <= 0:
        return 0.0

    # Calculate self-alignment scores
    def calc_self_score(seq):
        """Calculate maximum possible self-alignment score."""
        score = 0
        for aa in seq:
            try:
                score += matrix[aa, aa]
            except KeyError:
                # Residue not covered by the matrix (e.g. 'X'): count it
                # as 0 instead of crashing.  (The old fallback retried the
                # exact same lookup, which always re-raised KeyError.)
                continue
        return score

    self_score1 = calc_self_score(seq1)
    self_score2 = calc_self_score(seq2)

    # Handle invalid self-scores
    if self_score1 <= 0 or self_score2 <= 0:
        return 0.0

    # Compute normalization factor (geometric mean of self-scores)
    norm_factor = math.sqrt(self_score1 * self_score2)

    return min(raw_score / norm_factor, 1.0)


# Example usage.
# NOTE(review): everything below is a leftover manual debugging harness —
# it reads predictions from a hard-coded absolute path and intentionally
# drops into pdb.  It only runs when this module is executed directly.
if __name__ == '__main__':
    # Example sequences (replace with real protein sequences)
    # target_sequence = "MGGKWSKSSIVGWPAVRERIRQTEPRTEPAA"  # target sequence
    # generated_sequence = "MGGKWSKSSIVGWPAVRERIRRTEPAA"  # generated sequence
    #
    # # target_sequence = 'MSTNPKPQRKTKRNTNRRPQDVKFPGGG'
    # # generated_sequence = 'MSTNPKPQRKTKRNTNRRPQDVK'
    #
    # # Compute the normalized SW score
    # normalized_score = calculate_normalized_sw_score(
    #     target_sequence,
    #     generated_sequence,
    #     gap_open=-10,
    #     gap_extend=-0.5,
    #     match_score=2,
    #     mismatch_score=-1
    # )
    #
    # print(f"Normalized SW score: {normalized_score:.3f}")
    #
    # # Compute the normalized Smith-Waterman score
    # normalized_sw_score = normalized_smith_waterman(
    #     target_sequence,
    #     generated_sequence,
    # )
    # print(f"Normalized Smith-Waterman score: {normalized_sw_score:.4f}")
    import json
    import os
    import re

    # Local duplicate of the registered postprocessor in protein.py,
    # kept here so the harness has no package dependencies.
    def Mol_Instructions_postprocess_Protein_Design(text, *args, **kwargs):
        """
        Extract the protein str between
        <protein> and </protein> in the sentences
        """
        text = text.strip()
        pattern = r'<protein>(.*?)</protein>'
        match = re.search(pattern, text)
        if match:
            text = match.group(1)
            # filter to make sure letters are all in the alphabet
            valid_letters = set('ACDEFGHIKLMNPQRSTVWY')
            text = ''.join(filter(lambda x: x in valid_letters, text))
        else:
            text = ''
        return text

    pred_list = []
    gt_list = []
    scores = []
    # NOTE(review): hard-coded experiment output path — update before reuse.
    json_dir = (
        '/root/code/opencompass-sci/outputs/protein/mol_instructions/'
        '20250619_185027/predictions/qwen3-1.7B-sft-protein_0.7T_0.9p_50k')
    for filename in os.listdir(json_dir):
        if filename.endswith('.json'):
            file_path = os.path.join(json_dir, filename)
            with open(file_path, 'r') as f:
                data = json.load(f)
            for key, value in data.items():
                pred = Mol_Instructions_postprocess_Protein_Design(
                    value['prediction'])
                gt = Mol_Instructions_postprocess_Protein_Design(
                    value['gold'])
                pred_list.append(pred)
                gt_list.append(gt)
                if not pred or not gt:
                    scores.append(0.0)
                else:
                    # Calculate the normalized Smith-Waterman score
                    try:
                        score = normalized_smith_waterman(pred, gt)
                        scores.append(score)
                    except Exception:
                        # NOTE(review): deliberate breakpoint for
                        # interactive inspection of failing pairs.
                        import pdb

                        pdb.set_trace()
    # NOTE(review): final breakpoint to inspect the collected scores.
    import pdb

    pdb.set_trace()

+ 155
- 0
opencompass/datasets/SciReasoner/Mol_Instructions/protein.py View File

@@ -0,0 +1,155 @@
# flake8: noqa
# molecule task
# https://github.com/zjunlp/Mol-Instructions/tree/main/evaluation/molecule

import json
import os
import re
from typing import List, Optional

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from mmengine.config import ConfigDict

from opencompass.datasets.base import BaseDataset
from opencompass.datasets.SciReasoner.Mol_Instructions.normalized_SW_score import \
normalized_smith_waterman
from opencompass.openicl import BaseEvaluator, RougeEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class Mol_Instructions_Dataset_Protein_Design(BaseDataset):
    """Mol-Instructions protein-design tasks loaded from local JSON splits."""

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        """Build a train/test ``DatasetDict`` for *task*.

        Args:
            path: dataset root, resolved through ``get_data_path``.
            task: sub-directory containing ``dev``/``test`` splits.
            max_cut: keep only the first N test samples (-1 = keep all).
            mini_set: use a fixed 150-sample test subset instead.
            hf_hub: unused; kept for signature compatibility.
        """
        root = get_data_path(path)
        with open(os.path.join(root, f'{task}/dev/data.json'),
                  'r',
                  encoding='utf-8') as f:
            # Only five examples are kept as the few-shot train split.
            train_data = json.load(f)[:5]
        with open(os.path.join(root, f'{task}/test/data.json'),
                  'r',
                  encoding='utf-8') as f:
            test_data = json.load(f)

        if max_cut != -1:
            test_data = test_data[:max_cut]
        if mini_set:
            import random

            # Deterministic subset; the global RNG is re-seeded afterwards.
            random.seed(1024)
            test_data = random.sample(test_data, 150)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


@TEXT_POSTPROCESSORS.register_module('Mol_Instructions_postprocess_Protein')
def Mol_Instructions_postprocess_Protein(text, *args, **kwargs):
    """Strip chat end-tokens and ``<think>`` reasoning blocks from *text*."""
    cleaned = text.strip()
    for pattern, flags in (
        (r'<\|endoftext\|>', 0),
        (r'<\|im_end\|>', 0),
        (r'<think>.*?</think>', re.DOTALL),
            (r'.*?</think>\s*', re.DOTALL),
    ):
        cleaned = re.sub(pattern, '', cleaned, flags=flags)
    return cleaned.strip()


class Mol_Instructions_Evaluator_Protein(RougeEvaluator):
    # Thin wrapper around RougeEvaluator: scoring is inherited unchanged,
    # only the task name is recorded on the instance.

    def __init__(self,
                 task='catalytic_activity',
                 pred_postprocessor: Optional[ConfigDict] = None):
        super().__init__(pred_postprocessor=pred_postprocessor, )
        # Kept for symmetry with the other Mol-Instructions evaluators.
        self.task = task


@TEXT_POSTPROCESSORS.register_module(
    'Mol_Instructions_postprocess_Protein_Design')
def Mol_Instructions_postprocess_Protein_Design(text, *args, **kwargs):
    """Extract and sanitise the sequence between ``<protein>`` tags.

    Returns '' when no tag pair is found; otherwise keeps only the 20
    standard amino-acid letters.
    """
    text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
    text = re.sub(r'.*?</think>\s*', '', text, flags=re.DOTALL)
    match = re.search(r'<protein>(.*?)</protein>', text)
    if not match:
        return ''
    valid_letters = set('ACDEFGHIKLMNPQRSTVWY')
    return ''.join(ch for ch in match.group(1) if ch in valid_letters)


class Mol_Instructions_Evaluator_Protein_Design(BaseEvaluator):
    """Scores designed protein sequences with normalized Smith-Waterman."""

    def __init__(self, task='protein_design', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task = task

    def score(self, predictions: List[str], references: List[str]):
        """Return max/min/mean SW scores (as percentages) over all pairs.

        Empty predictions or references score 0.0; the 'valid average'
        is taken over strictly positive scores only.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            # Guard: max()/min()/division below would crash on empty input.
            return {'error': 'empty predictions'}
        if not isinstance(predictions[0], list):
            predictions = [[pred] for pred in predictions]
        if not isinstance(references[0], list):
            references = [[ref] for ref in references]

        scores = []
        for pred, refer in zip(predictions, references):
            pred = pred[0].strip()
            refer = refer[0].strip()
            if not pred or not refer:
                # An empty sequence on either side gets no credit.
                scores.append(0.0)
            else:
                # Normalized Smith-Waterman score, as a percentage.
                score = normalized_smith_waterman(pred, refer) * 100
                scores.append(score)

        averaged_valid_scores = [score for score in scores if score > 0]

        results = {
            'Max SW score':
            max(scores),
            'Min SW score':
            min(scores),
            'Average SW score':
            sum(scores) / len(scores),
            'valid average SW score':
            sum(averaged_valid_scores) /
            len(averaged_valid_scores) if averaged_valid_scores else 0.0,
        }
        return results

+ 471
- 0
opencompass/datasets/SciReasoner/PEER.py View File

@@ -0,0 +1,471 @@
# flake8: noqa
# dataset: PEER
# task : solubility prediction

import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Union

import numpy as np
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from openai import OpenAI
from sklearn.metrics import f1_score, precision_score, recall_score

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class PEER_Dataset(BaseDataset):
    """PEER protein tasks loaded from local JSON splits."""

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        """Build a train/test ``DatasetDict`` for *task*.

        Args:
            path: dataset root, resolved through ``get_data_path``.
            task: sub-directory containing ``dev``/``test`` splits.
            max_cut: keep only the first N test samples (-1 = keep all).
            mini_set: use a fixed 150-sample test subset instead.
            hf_hub: unused; kept for signature compatibility.
        """
        root = get_data_path(path)
        with open(os.path.join(root, f'{task}/dev/data.json'),
                  'r',
                  encoding='utf-8') as f:
            # Only five examples are kept as the few-shot train split.
            train_data = json.load(f)[:5]
        with open(os.path.join(root, f'{task}/test/data.json'),
                  'r',
                  encoding='utf-8') as f:
            test_data = json.load(f)

        if max_cut != -1:
            test_data = test_data[:max_cut]
        if mini_set:
            import random

            # Deterministic subset; the global RNG is re-seeded afterwards.
            random.seed(1024)
            test_data = random.sample(test_data, 150)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


@TEXT_POSTPROCESSORS.register_module()
def PEER_postprocess_default(text: Union[str, None]) -> str:
    """Strip chat special tokens and <think> blocks from a raw model output.

    Returns '' for non-string input: the annotation admits None, and the
    sibling PEER_postprocess guards the same way, but the original body
    would raise AttributeError on None.
    """
    if not isinstance(text, str):
        return ''
    text = text.strip()
    text = re.sub(r'<\|endoftext\|>', '', text)
    text = re.sub(r'<\|im_end\|>', '', text)
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    return text


@TEXT_POSTPROCESSORS.register_module()
def PEER_postprocess(text: Union[str, None]) -> str:
    """Extract a Yes/No prediction from a raw model output.

    Prefers an explicit "The answer is Yes/No" statement; otherwise falls
    back to keyword matching of common affirmative / negative phrasings.
    Returns '' when no decision can be recovered.
    """
    # Robustness: tolerate non-string input.
    if not isinstance(text, str):
        return ''
    # Drop chain-of-thought blocks before matching.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    text = re.sub(r'.*?</think>\s*', '', text, flags=re.DOTALL)
    # Explicit "The answer is Yes/No" takes priority.
    match = re.search(r'The answer is\s+(Yes|No)', text, re.IGNORECASE)
    if match:
        return match.group(1)

    # Common affirmative phrasings.
    positive_patterns = [
        r'will be soluble',
        r'will dissolve',
        r'is soluble',
        r'can be predicted',
        r'positive',
        r'Yes',
        r'correct',
        r'valid',
        r'accurate',
        r'certainly',
        r'indeed',
        r'affirmative',
        r'highly soluble',
        r'easily soluble',
        r'dissolves easily',
        r'is assured',
        # r'likely',
        r'be soluble'
    ]

    # Common negative phrasings.
    negative_patterns = [
        r'will not be soluble',
        r'is not soluble',
        r'will not dissolve',
        r'low solubility',
        r'low',
        r'cannot be predicted',
        r'negative',
        r'No',
        r'incorrect',
        r'invalid',
        r'inaccurate',
        r'impossible',
        r'not possible',
        r'denied',
        r'be insoluble',
    ]

    # Word boundaries prevent substring false positives such as 'low'
    # matching inside 'follow'/'below' or 'No' matching inside 'know'.
    for pattern in positive_patterns:
        if re.search(rf'\b{pattern}\b', text, re.IGNORECASE):
            return 'Yes'

    for pattern in negative_patterns:
        if re.search(rf'\b{pattern}\b', text, re.IGNORECASE):
            return 'No'

    # Nothing recognizable: give up.
    return ''


@TEXT_POSTPROCESSORS.register_module()
def PEER_postprocess_float_compare(text: Union[str, None],
                                   compare_number: float) -> str:
    """Map a numeric model output to 'Yes'/'No' via a threshold comparison.

    The first number found in the (de-<think>-ed) text is compared with
    ``compare_number``: strictly greater yields 'Yes', otherwise 'No'.
    Returns '' when the input is not a string or no number is found.
    """
    if not isinstance(text, str):
        return ''
    try:
        # Remove chain-of-thought content before looking for a number.
        cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        cleaned = re.sub(r'.*?</think>\s*', '', cleaned, flags=re.DOTALL)
        # First float or int in the remaining text.
        found = re.search(r'[-+]?\d*\.\d+|\d+', cleaned)
        if found is None:
            return ''
        return 'Yes' if float(found.group(0)) > compare_number else 'No'
    except ValueError:
        # Conversion failure: treat as unparseable.
        return ''


def calculate_accuracy(pred_text_list, gold_text_list):
    """Compute accuracy / precision / recall / F1 for yes-no predictions.

    Each element of the two lists is a single-item list holding the answer
    text. Empty/None predictions are counted as 'no answer', non-yes/no
    predictions as 'invalid'; both are excluded from P/R/F1.

    Returns:
        dict with raw counts and percentage metrics.
    """
    assert len(pred_text_list) == len(gold_text_list)
    num_all = len(pred_text_list)
    metrics = {}
    metrics['num_all'] = num_all
    num_no_answer = 0
    num_invalid = 0
    num_correct = 0
    new_pred_text_list, new_gold_text_list = [], []
    for (pred_item, gold_item) in zip(pred_text_list, gold_text_list):
        if pred_item is None or pred_item == '':
            num_no_answer += 1
            continue
        assert len(pred_item) == 1
        assert len(gold_item) == 1
        pred_item = pred_item[0].strip().lower()
        gold_item = gold_item[0].strip().lower()
        if pred_item == '':
            num_no_answer += 1
            continue
        if pred_item not in ('yes', 'no'):
            num_invalid += 1
            continue
        # Binarize: yes -> 1, anything else (for gold, any non-'yes') -> 0.
        pred_item = 1 if pred_item == 'yes' else 0
        gold_item = 1 if gold_item == 'yes' else 0
        new_pred_text_list.append(pred_item)
        new_gold_text_list.append(gold_item)
        if gold_item == pred_item:
            num_correct += 1

    metrics['num_no_answer'] = num_no_answer
    metrics['num_invalid'] = num_invalid
    metrics['num_correct'] = num_correct

    new_gold_text_list = np.array(new_gold_text_list)
    new_pred_text_list = np.array(new_pred_text_list)

    # zero_division=0 keeps P/R/F1 well defined when only one class (or no
    # valid prediction at all) is present — consistent with
    # PEER_Evaluator.score. Guard num_all to avoid ZeroDivisionError on
    # empty inputs.
    metrics['accuracy'] = (num_correct / num_all * 100) if num_all > 0 else 0
    metrics['acc_wo_no_answer_invalid'] = num_correct / (
        num_all - num_no_answer - num_invalid) * 100 if (
            num_all - num_no_answer - num_invalid) > 0 else 0
    metrics['precision'] = precision_score(
        new_gold_text_list, new_pred_text_list, zero_division=0) * 100
    metrics['recall'] = recall_score(
        new_gold_text_list, new_pred_text_list, zero_division=0) * 100
    metrics['f1_score'] = f1_score(
        new_gold_text_list, new_pred_text_list, zero_division=0) * 100

    return metrics


# ----------------------------------------------------------------------
# 定义 Evaluator (评估器) - 这是修改的核心
# ----------------------------------------------------------------------

# Maximum number of attempts for a flaky judge API call (see _retry_api).
MAX_RETRIES = 3
# Base of the exponential backoff: sleep = BACKOFF_SEC ** attempt seconds.
BACKOFF_SEC = 2


class PEER_Evaluator(BaseEvaluator):
    """Yes/No evaluator for PEER tasks with an optional GPT judge.

    Predictions are parsed to yes/no by ``PEER_postprocess``; samples that
    cannot be parsed, or whose parsed label disagrees with the reference,
    can be re-judged by a GPT model that answers 'True'/'False'.
    """

    def __init__(self,
                 task='solubility',
                 gpt_model='gpt-4',
                 openai_key='xxx',
                 use_gpt=True,
                 max_workers=8,
                 *args,
                 **kwargs):
        """
        Args:
            task: PEER sub-task name; 'stability' disables the GPT recheck.
            gpt_model: Judge model name passed to the OpenAI client.
            openai_key: API key for the judge client.
            use_gpt: Whether to recheck ambiguous samples with GPT.
            max_workers: Thread pool size for batched judge calls.
        """
        super().__init__(*args, **kwargs)
        self.task = task
        self.gpt_model = gpt_model
        self.use_gpt = use_gpt
        self.max_workers = max_workers

        # These tasks are never rechecked by the GPT judge.
        if task in [
                'stability',
        ]:
            self.use_gpt = False

        if self.use_gpt:
            if not openai_key:
                raise ValueError('OpenAI API key is missing.')
            # NOTE(review): base_url='url' looks like a placeholder — confirm
            # it is configured properly before running with use_gpt=True.
            self.client = OpenAI(base_url='url', api_key=openai_key)

    def _retry_api(self, fn, *args, **kwargs):
        """Call ``fn`` with exponential backoff; re-raise the last error
        after MAX_RETRIES failed attempts. A None result counts as failure."""
        last_exc = None
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                result = fn(*args, **kwargs)
                if result is not None:
                    return result
                raise ValueError('Received None')
            except Exception as e:
                last_exc = e
                sleep_time = BACKOFF_SEC**attempt
                print(f'[retry] attempt {attempt} failed ({e}),'
                      f' retrying in {sleep_time}s…')
                time.sleep(sleep_time)
        raise last_exc

    def ask_gpt25(self, question, answer, prediction):
        """Ask the judge whether ``prediction`` matches ``answer``.

        Returns the upper-cased judge output ('TRUE'/'FALSE' expected), or
        '' when all retries fail. ``question`` is currently unused by the
        prompt.
        """
        # Bug fix: the 'Incorrect' definition previously said the answer
        # "is consistent" with the reference — it must be "is inconsistent".
        prompt = (
            'Please determine whether this answer is correct. Definition:'
            "'Correct': The core conclusion of the model's answer (if any) is "
            'completely consistent with the reference answer (literal identity'
            " is not required). 'Incorrect': The core conclusion of the"
            " model's answer is inconsistent with the reference answer, or"
            ' the core conclusion is not clearly expressed. Reference answer'
            f': {answer}'
            f'Model answer: {prediction}'
            "If correct, answer 'True'; if incorrect, answer 'False'."
            "Please only answer 'True' or 'False'.")

        def _call():
            response = self.client.chat.completions.create(
                model=self.gpt_model,
                messages=[{
                    'role': 'user',
                    'content': prompt
                }],
                temperature=0)

            result = response.choices[0].message.content.strip().upper()
            print('=== GPT 判断结果 ===')
            print(f'Prompt:\n{prompt}')
            print(f'Output:\n{result}')
            return result

        try:
            return self._retry_api(_call)
        except Exception as e:
            print(f'[GPT ERROR] Exception: {e}')
            return ''

    def ask_gpt25_batch(self, questions, answers, predictions):
        """Judge many samples concurrently; failures yield '' per sample."""
        results = [None] * len(questions)

        def task(index, q, a, p):
            try:
                result = self.ask_gpt25(q, a, p)
                results[index] = result
            except Exception as e:
                results[index] = ''
                print(f'[GPT ERROR] 批次样本 {index} 出错: {e}')

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(task, i, q, a, p)
                for i, (q, a,
                        p) in enumerate(zip(questions, answers, predictions))
            ]
            for future in as_completed(futures):
                pass

        return results

    def score(self, predictions, references):
        """Compute accuracy / precision / recall / F1 for the yes-no task.

        Args:
            predictions: Raw model outputs (strings or single-item lists).
            references: Gold answers in the same layout.

        Returns:
            dict of counts and percentage metrics. When the GPT judge is
            enabled, samples it marks 'True' are counted as correct.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        # Normalize both inputs to a list-of-lists layout.
        if not isinstance(predictions[0], list):
            predictions = [[pred] for pred in predictions]
        if not isinstance(references[0], list):
            references = [[ref] for ref in references]

        postprocessed_references = [[PEER_postprocess(r[0]).strip().lower()]
                                    for r in references]
        postprocessed_predictions = [[PEER_postprocess(p[0]).strip().lower()]
                                     for p in predictions]

        # Majority vote over multiple samples per item.
        # NOTE(review): only len(voted_prediction) is used below; the voted
        # labels themselves are currently unused.
        voted_prediction = []
        for pred in postprocessed_predictions:
            valid_pred = [p for p in pred if p in ['yes', 'no']]
            cnt = valid_pred.count('yes')
            if cnt > len(valid_pred) / 2:
                voted = 'yes'
            elif cnt < len(valid_pred) / 2:
                voted = 'no'
            else:
                voted = ''
            voted_prediction.append([voted])

        num_all = len(voted_prediction)
        # num_no_answer / num_invalid are never incremented in this method,
        # so they always report 0; unparseable samples go to the GPT recheck.
        num_correct, num_no_answer, num_invalid = 0, 0, 0
        num_gpt_called = 0
        new_pred, new_gold = [], []

        to_recheck_indices = []
        to_recheck_golds = []
        to_recheck_preds = []

        for i, (pred_item, gold_item) in enumerate(
                zip(postprocessed_predictions, postprocessed_references)):
            pred = pred_item[0]
            gold = gold_item[0]

            if pred not in ('yes', 'no'):
                to_recheck_indices.append(i)
                to_recheck_golds.append(references[i][0])
                to_recheck_preds.append(predictions[i][0])
                continue

            if pred == 'yes':
                pred_bin = 1
            elif pred == 'no':
                pred_bin = 0
            else:
                to_recheck_indices.append(i)
                to_recheck_golds.append(references[i][0])
                to_recheck_preds.append(predictions[i][0])
                continue

            if gold == 'yes':
                gold_bin = 1
            elif gold == 'no':
                gold_bin = 0
            else:
                to_recheck_indices.append(i)
                to_recheck_golds.append(references[i][0])
                to_recheck_preds.append(predictions[i][0])
                continue

            if pred_bin == gold_bin:
                num_correct += 1
                print(references[i][0], '\n', predictions[i][0], '----')
                new_pred.append(pred_bin)
                new_gold.append(gold_bin)
            else:
                to_recheck_indices.append(i)
                to_recheck_golds.append(references[i][0])
                to_recheck_preds.append(predictions[i][0])

        if to_recheck_indices and self.use_gpt:
            rechecked_preds = self.ask_gpt25_batch(
                ['' for _ in to_recheck_indices], to_recheck_golds,
                to_recheck_preds)
            num_gpt_called += len(rechecked_preds)

            for i, result in enumerate(rechecked_preds):
                result = result.strip().lower()
                if 'true' in result:
                    # Judge accepted the prediction: count as correct.
                    num_correct += 1
                    pred_bin = 1
                    gold_bin = 1
                elif 'false' in result:
                    pred_bin = 0
                    gold_bin = 1
                else:
                    # Unparseable judge output: recorded as a mismatch.
                    pred_bin = 1
                    gold_bin = 0

                new_pred.append(pred_bin)
                new_gold.append(gold_bin)

        new_pred = np.array(new_pred)
        new_gold = np.array(new_gold)

        metrics = {
            'num_all':
            num_all,
            'num_correct':
            num_correct,
            'num_no_answer':
            num_no_answer,
            'num_invalid':
            num_invalid,
            'num_gpt_called':
            num_gpt_called,
            'accuracy':
            num_correct / num_all * 100,
            'acc_wo_no_answer_invalid':
            num_correct / (num_all - num_no_answer - num_invalid) * 100 if
            (num_all - num_no_answer - num_invalid) > 0 else 0,
            'precision':
            precision_score(new_gold, new_pred, zero_division=0) * 100,
            'recall':
            recall_score(new_gold, new_pred, zero_division=0) * 100,
            'f1_score':
            f1_score(new_gold, new_pred, zero_division=0) * 100,
        }

        return metrics

+ 13
- 0
opencompass/datasets/SciReasoner/__init__.py View File

@@ -0,0 +1,13 @@
from .bio_instruction import * # noqa: F401, F403
from .bulk_modulus_material import * # noqa: F401, F403
from .composition_material import * # noqa: F401, F403
from .GUE import * # noqa: F401, F403
from .LLM4Chem import * # noqa: F401, F403
from .LLM4Mat import * # noqa: F401, F403
from .Mol_Instructions import * # noqa: F401, F403
from .opi import * # noqa: F401, F403
from .PEER import * # noqa: F401, F403
from .uncond_material import * # noqa: F401, F403
from .uncond_RNA import * # noqa: F401, F403
from .unconditional_molecule_generation import * # noqa: F401, F403
from .unconditional_protein_generation import * # noqa: F401, F403

+ 1440
- 0
opencompass/datasets/SciReasoner/bio_instruction.py View File

@@ -0,0 +1,1440 @@
# flake8: noqa

import json
import os
import re
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download

try:
    from scipy.stats import pearsonr, spearmanr
except Exception:
    # Bug fix: the fallback previously assigned 'spermanr' (typo), leaving
    # 'spearmanr' undefined and causing a NameError later when scipy is
    # unavailable.
    pearsonr, spearmanr = None, None
from sklearn.metrics import (accuracy_score, matthews_corrcoef,
mean_absolute_error, mean_squared_error,
precision_score, recall_score, roc_auc_score)
from tqdm import tqdm
from transformers import pipeline

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.utils import get_data_path

# current_working_directory = os.getcwd()
# path_bioinstruction = os.path.join(current_working_directory, 'OpenCompass_SciReasoner_extra_data',
# 'datasets', 'bioinstruction')


# @LOAD_DATASET.register_module()
class Bioinstruction_Dataset(BaseDataset):
    """Loader for bio-instruction tasks.

    Reads ``<path>/<task>/dev/data.json`` (first 5 records become the
    few-shot 'train' split) and ``<path>/<task>/test/data.json``.
    """

    @staticmethod
    def load(path, task, mini_set=False, hf_hub=False):
        """Return a DatasetDict with 'train' and 'test' splits.

        Args:
            path: Root data directory (resolved through ``get_data_path``).
            task: Sub-task name, used as the sub-directory under ``path``.
            mini_set: If True and the test split has more than 150 records,
                sub-sample it to 150 with a fixed seed.
            hf_hub: Unused; kept for config backward compatibility.
        """
        root = get_data_path(path)

        def _read_split(split):
            # Each split lives at <root>/<task>/<split>/data.json.
            split_file = os.path.join(root, f'{task}/{split}/data.json')
            with open(split_file, 'r', encoding='utf-8') as fh:
                return json.load(fh)

        raw_train = _read_split('dev')[:5]
        raw_test = _read_split('test')

        def _project(records):
            # Keep only the fields the prompt template consumes.
            return [{
                'input': rec['input'],
                'output': rec['output']
            } for rec in records]

        train_records = _project(raw_train)
        test_records = _project(raw_test)

        if mini_set and len(test_records) > 150:
            import random
            random.seed(1024)
            test_records = random.sample(test_records, 150)
            random.seed()

        return DatasetDict({
            'train': Dataset.from_list(train_records),
            'test': Dataset.from_list(test_records)
        })


def extract_answer_part(outputs, left_tag, right_tag, mode='tag'):
    """Extract the answer span from each raw output string.

    In 'tag' mode, the text between ``left_tag`` and the first ``right_tag``
    AFTER it is returned ('' when either tag is missing). In 'direct' mode
    (or when both tags are None), the whole text is returned with ``<unk>``
    and ``</s>`` tokens stripped.

    Args:
        outputs: List of raw output strings.
        left_tag / right_tag: Delimiters for 'tag' mode (may be None).
        mode: 'tag' or 'direct'.

    Returns:
        List of extracted answer strings, same length as ``outputs``.
    """
    assert mode in ('tag', 'direct')

    assert isinstance(outputs, list)
    answers = []
    for text in outputs:
        if mode == 'direct' or (left_tag is None and right_tag is None):
            text = text.replace('<unk>', '').replace('</s>', '').strip()
            answers.append(text.strip())
            continue

        left_tag_pos = text.find(left_tag)
        if left_tag_pos == -1:
            answers.append('')
            continue
        start = left_tag_pos + len(left_tag)
        # Bug fix: search for the closing tag only AFTER the opening tag, so
        # a stray right_tag earlier in the text no longer yields ''.
        right_tag_pos = text.find(right_tag, start)
        if right_tag_pos == -1:
            answers.append('')
            continue
        answers.append(text[start:right_tag_pos].strip())
    return answers


def extract_numeric_values(text):
    """Pull every number out of *text*, rounded to 6 significant digits.

    RNA-style "5'" / "3'" spellings are rewritten first so their digits are
    not mistaken for measurements; the lookbehind also suppresses a match
    that starts immediately after a letter. Returns a list of floats.
    """
    text = text.replace("5'", "five'")
    text = text.replace("3'", 'three')

    # Optional (possibly non-breaking) hyphen, digits, optional decimal part;
    # a letter directly before the first digit blocks the match there.
    raw_tokens = re.findall(r'(?<![a-zA-Z])[-‑]?\d+\.?\d*', str(text))

    def _round6(token):
        # Normalize the non-breaking hyphen before parsing.
        value = np.float64(token.replace('‑', '-'))
        # Limit to 6 significant digits (integers formatted from int()).
        if value.is_integer():
            formatted = f'{int(value):.6g}'
        else:
            formatted = f'{value:.6g}'
        return float(formatted)

    return [_round6(token) for token in raw_tokens]


# ncRNA family labels, sorted longest-first so that substring matching
# finds the most specific class name before a shorter one that could be
# contained in it (e.g. '5_8S_rRNA' is tried before '5S_rRNA').
RNA_CLASSES = sorted([
    '5S_rRNA', '5_8S_rRNA', 'tRNA', 'ribozyme', 'CD-box', 'miRNA',
    'Intron_gpI', 'Intron_gpII', 'HACA-box', 'riboswitch', 'IRES', 'leader',
    'scaRNA'
],
                     key=len,
                     reverse=True)

# RNA modification class labels; 'none' marks an unmodified site.
modification_classes = [
    'AtoI', 'm6Am', 'm1A', 'm5C', 'm5U', 'm6A', 'm7G', 'Psi', 'Am', 'Cm', 'Gm',
    'Um', 'none'
]


def generic_replace(m):
    """``re.sub`` callback: wrap group(1) in <SMILES> tags when long enough.

    Candidates shorter than 4 characters are returned unchanged (too short
    to plausibly be a SMILES string worth tagging).
    """
    candidate = m.group(1)
    if len(candidate) < 4:
        return candidate
    return f'<SMILES> {candidate} </SMILES>'


# Use the sentiment analysis model as fallback
# if classification by keywords fails
def classify_by_sentiment_model(text):
    """Fallback binary classifier using a zero-shot NLI pipeline.

    Used when keyword classification fails. Returns a list of
    ``(class, score)`` tuples: class 1 for the affirmative / soluble labels,
    0 otherwise, with the pipeline's top-label confidence.
    """
    cleaned = [
        str(t).replace('</s>', '').replace('<pad>', '').strip() for t in text
    ]

    candidate_labels = [
        'Yes,I can positively identify', 'No,My answer is negative',
        'This protein is expected to dissolve in water',
        'This protein is not expected to dissolve in water'
    ]

    classifier = pipeline('zero-shot-classification',
                          model='facebook/bart-large-mnli',
                          device=0)

    positive_labels = {
        'Yes,I can positively identify',
        'This protein is expected to dissolve in water',
    }

    processed_results = []
    for output in classifier(cleaned, candidate_labels, batch_size=64):
        # The zero-shot pipeline returns labels sorted by score, best first.
        best_label = output['labels'][0]
        best_score = output['scores'][0]
        label_class = 1 if best_label in positive_labels else 0
        processed_results.append((label_class, best_score))
    return processed_results


def classify_by_keywords(text):
    """Classify *text* as positive (1), negative (0), 'dont_know', or None.

    Negative keywords are checked first (so e.g. 'not soluble' is not
    misread via its 'soluble' substring), then positive, then
    "don't know" phrasings; None means no keyword matched.

    Bug fix: the text is lower-cased but the patterns were searched
    case-sensitively, which made capitalized keywords with no lowercase
    duplicate ('Indeed', 'Solubility is expected') dead entries. Searching
    with re.IGNORECASE restores them.
    """
    positive_keywords = [
        'Yes', 'yes', 'positive', 'Positive', 'empirical', 'plausible',
        'confirms', 'have detected', 'are discernible', 'are supported',
        'is supported', 'display', 'detected the presence', 'shows evidence',
        'has been identified', 'shows', 'has identified', 'contains ',
        'exhibits evidence', 'is plausible', 'contains identifiable', 'Indeed',
        'reveals the presence', 'include', 'are present', 'definitely has',
        'soluble', 'displays regions', 'has a high solubility',
        'dissolves easily', 'Solubility is expected',
        'is expected to dissolve', 'is predicted', 'is likely', 'is expected',
        'is expected to dissolve', 'will dissolve', 'dissolves easily'
    ]

    negative_keywords = [
        'No', 'no', 'negative', 'Negative', 'insoluble', 'does not',
        'unlikely', 'absence', 'not found', 'not detected', 'not associated',
        'not inferred', 'not linked', 'does not indicate', 'no evidence',
        'not predicted', 'absent', 'not present', 'no indicators',
        'not exhibit', 'are absent', 'found none', 'did not reveal', 'lacks',
        'exhibits no', 'insolubility', 'low solubility', 'not soluble',
        'not be soluble', 'does not display regions', 'cannot confirm'
    ]

    dont_know_keywords = [
        'don\'t know', 'unknown', 'unsure', 'uncertain', 'not applicable',
        'cannot confirm'
    ]

    text_lower = text.lower()

    # Escape keywords for safety and join with '|'; \b keeps whole-word
    # matching so 'no' does not fire inside 'know'.
    negative_pattern = r'\b(' + '|'.join(
        re.escape(kw) for kw in negative_keywords) + r')\b'
    positive_pattern = r'\b(' + '|'.join(
        re.escape(kw) for kw in positive_keywords) + r')\b'
    dont_know_pattern = r'\b(' + '|'.join(
        re.escape(kw) for kw in dont_know_keywords) + r')\b'

    # 1. Negative keywords first.
    if re.search(negative_pattern, text_lower, re.IGNORECASE):
        return 0
    # 2. Positive keywords.
    elif re.search(positive_pattern, text_lower, re.IGNORECASE):
        return 1
    # 3. "Don't know" phrasings.
    elif re.search(dont_know_pattern, text_lower, re.IGNORECASE):
        return 'dont_know'
    else:
        return None


# Save the processed data for each task in a separate file
# def save_processed_data(model_name, task_name, task_processed_data):
#
# dir_path = path_bioinstruction + f'/processed_data/{model_name}'
# file_path = f'{dir_path}/{task_name}_processed_data.json'
# os.makedirs(dir_path, exist_ok=True)
# with open(file_path, 'w') as outfile:
# json.dump(task_processed_data, outfile, indent=4)
#
# print(f'Task {task_name} procssed data saved in {file_path}')


# Process regression task
def process_regression_task(task_name, task_entries, model_name):
    """Parse numeric predictions out of raw model outputs for a regression task.

    Each entry is expected to carry 'model_output', 'label' and 'input'.
    Chain-of-thought is stripped (text after '<summary>'/'</think>' is kept);
    outputs with an unterminated '<think>' are counted as over-length and
    dropped. Entries whose output does not yield exactly one number are
    recorded as np.inf.

    Args:
        task_name: Task identifier; 'Isoform' gets an extra >80 sanity cut.
        task_entries: Iterable of entry dicts (mutated in place: the
            'model_output' field is overwritten with the stripped text).
        model_name: Unused here (kept for the commented-out save path).

    Returns:
        (label_values, result_values): parallel lists of floats, with
        np.inf marking entries that produced no usable prediction.
    """
    result_values = []
    label_values = []
    task_processed_data = []
    over_len = 0  # outputs with '<think>' but no closing tag (truncated)
    miss_len = 0  # outputs without any think tags
    for index, entry in enumerate(task_entries):
        # print(entry)
        if '<summary>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split(
                '<summary>')[-1]
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
            extracted_result = extract_numeric_values(entry['model_output'])
        else:
            if '<think>' in entry['model_output']:
                # Unclosed think block: the answer was cut off; discard.
                over_len += 1
                extracted_result = []
            else:
                # No think tags at all: still try to extract a number.
                miss_len += 1
                extracted_result = extract_numeric_values(
                    entry['model_output'])

        label = float(entry['label'])
        print('label', label)
        print('extracted_result', extracted_result)

        # Isoform-specific sanity filter: values above 80 are implausible.
        if len(extracted_result
               ) != 0 and extracted_result[0] > 80 and task_name == 'Isoform':
            print(entry['model_output'])
            extracted_result = []

        # Exactly one number is required; anything else is treated as invalid.
        if len(extracted_result) != 1:
            print('not one:', entry['model_output'])
            extracted_result = []

        if len(extracted_result) == 0:
            result_values.append(
                np.inf)  # Assign infinity if no valid result is extracted
        else:
            result_values.append(
                extracted_result[0])  # Take the first valid extracted result

        label_values.append(label)

        # NOTE(review): the key 'processed_model_ouput' is misspelled, but
        # this dict is currently only built locally (save is commented out).
        task_processed_data.append({
            'input':
            entry['input'],
            'label':
            entry['label'],
            'processed_model_ouput':
            extracted_result[0] if len(extracted_result) > 0 else np.inf,
            'original_model_output':
            entry['model_output'],
        })

    # save_processed_data(model_name, task_name, task_processed_data)
    print('over_len: ', over_len)
    print('miss_len: ', miss_len)
    return label_values, result_values


# Compute spearman correlation
def compute_spearman(label_values, result_values):
    """Spearman correlation between labels and predictions.

    np.inf entries in ``result_values`` (the 'no valid prediction' marker)
    are scored 0 and folded into a weighted average; predictions above 300
    are additionally dropped as outliers before correlating.

    Returns:
        {'spearman': float} or {'spearman': '<error message>'} on bad input.
    """
    if len(result_values) == 0:
        return {'spearman': 'Error: Empty data'}
    elif len(result_values) != len(label_values):
        return {
            'spearman':
            'Error: Mismatch in the number of extracted numeric values'
        }

    # Convert the label and result values to numpy arrays
    result_values = np.array(result_values).flatten()
    label_values = np.array(label_values).flatten()

    # Identify explicitly assigned infinity values
    near_infinity_mask = np.isinf(result_values)

    # Exclude near-infinity pairs from the main calculation
    valid_mask = ~near_infinity_mask & np.isfinite(
        result_values) & np.isfinite(label_values)
    valid_result_values = result_values[valid_mask]
    valid_label_values = label_values[valid_mask]

    # Hard-coded outlier cut: keep only predictions <= 300.
    outlier_mask = valid_result_values <= 300

    valid_result_values = valid_result_values[outlier_mask]
    valid_label_values = valid_label_values[outlier_mask]

    # Initialize metrics
    spearman = 0.0
    rmse = 0.0

    # Compute Spearman correlation for valid values
    if len(valid_result_values) > 0:
        spearman, _ = spearmanr(valid_label_values, valid_result_values)
        mse = mean_squared_error(valid_label_values, valid_result_values)
        # Take the square root of MSE to get RMSE (printed only).
        rmse = np.sqrt(mse)

    else:
        spearman = 0  # Fallback if no valid pairs

    # NOTE(review): the weighting below uses valid_mask.sum(), which does
    # NOT account for the outlier filtering above — confirm intended.
    total_data_points = len(result_values)
    total_valid_points = valid_mask.sum()
    num_infinity_values = near_infinity_mask.sum()

    if num_infinity_values > 0:
        final_spearman_score = (spearman * total_valid_points +
                                0 * num_infinity_values) / total_data_points
    else:
        final_spearman_score = spearman  # Edge case: no near-infinity values
    print('rmse:', rmse)

    return {'spearman': final_spearman_score}


# Compute R2
def compute_R2(label_values, result_values):
    """R² computed as the squared Pearson correlation of label/prediction.

    np.inf entries in ``result_values`` (the 'no valid prediction' marker)
    contribute 0 to a weighted average over all data points.

    Returns:
        {'R2': float} or {'R2': '<error message>'} on bad input.
    """
    # from sklearn.metrics import r2_score

    # y_true = np.asarray(label_values, dtype=float).flatten()
    # y_pred = np.asarray(result_values, dtype=float).flatten()

    # Check for empty data
    if len(result_values) == 0:
        return {'R2': 'Error: Empty data.'}

    # Check for equal length of arrays
    elif len(result_values) != len(label_values):
        return {
            'R2': 'Error: Mismatch in the number of extracted numeric values.'
        }

    # Convert the label and result values to numpy arrays
    result_values = np.array(result_values).flatten()
    label_values = np.array(label_values).flatten()

    # Identify explicitly assigned infinity values
    near_infinity_mask = np.isinf(result_values)

    # Exclude near-infinity pairs from the main calculation
    valid_mask = ~near_infinity_mask & np.isfinite(
        result_values) & np.isfinite(label_values)
    valid_result_values = result_values[valid_mask]
    valid_label_values = label_values[valid_mask]

    # Compute Pearson correlation coefficient for valid values
    if len(valid_result_values) > 0:
        try:
            pcc, _ = pearsonr(valid_label_values, valid_result_values)
            R2 = pcc**2
            # mse = mean_squared_error(valid_label_values, valid_result_values)
            # Then take the square root to get RMSE
            # rmse = np.sqrt(mse)
        except Exception:
            R2 = np.inf  # Fallback to inf if computation fails
    else:
        R2 = 0  # Fallback if no valid pairs

    # Combine R2 score for valid and infinity values
    total_data_points = len(result_values)
    total_valid_points = valid_mask.sum()
    num_infinity_values = near_infinity_mask.sum()

    if num_infinity_values > 0:
        final_R2_score = (R2 * total_valid_points +
                          0 * num_infinity_values) / total_data_points
    else:
        final_R2_score = R2  # Edge case: no near-infinity values
    # print("RMSE:",rmse)
    return {'R2': final_R2_score}


# Compute mixed score
def compute_mixed_score(label_values,
                        result_values,
                        threshold=30,
                        max_value=1e3):
    """Mixed regression/classification score used for degradation-style tasks.

    Combines an MAE term over all valid pairs with an F1-weighted MAE term
    restricted to predictions in [0, threshold]. Values with magnitude
    above ``max_value`` score 0 and are folded into a weighted average.

    Args:
        label_values / result_values: Parallel number-like sequences
            (coerced via pandas; non-numeric entries become NaN).
        threshold: Binarization cutoff for the F1 component.
        max_value: Magnitude above which a prediction is treated as junk.

    Returns:
        {'mixed_score': float} or {'mixed_score': '<error message>'}.
    """
    rmse = 0.0
    if len(result_values) == 0:
        return {'mixed_score': 'Error: Empty data.'}
    elif len(result_values) != len(label_values):
        return {
            'mixed_score':
            'Error: Mismatch in the number of extracted numeric values'
        }

    # Convert the label and result values to numeric arrays
    # using pandas to handle non-numeric entries
    result_values = pd.to_numeric(result_values, errors='coerce').flatten()
    label_values = pd.to_numeric(label_values, errors='coerce').flatten()

    # Identify near-infinity values
    near_infinity_mask = np.abs(result_values) > max_value
    if near_infinity_mask.any():
        print(
            f'Warning: Found {sum(near_infinity_mask)} result values too large'
            ' will be assigned a mixed score of 0. '
            f'Large result values: {result_values[near_infinity_mask]} ')

    # Exclude near-infinity pairs from the main calculation
    valid_mask = ~near_infinity_mask & np.isfinite(
        result_values) & np.isfinite(label_values)
    valid_result_values = result_values[valid_mask]
    valid_label_values = label_values[valid_mask]

    # Assign a mixed score of 0 to near-infinity pairs
    num_infinity_values = near_infinity_mask.sum()
    if num_infinity_values > 0:
        mixed_score_infinity = 0

    # Convert to binary based on the threshold for valid values
    label_binary = (valid_label_values < threshold).astype(int)
    result_binary = (valid_result_values < threshold).astype(int)

    # Compute precision, recall, F1 score for valid values
    precision = precision_score(label_binary, result_binary, average='binary')
    recall = recall_score(label_binary, result_binary, average='binary')
    f1 = 2 * precision * recall / (precision + recall) if (precision +
                                                           recall) != 0 else 0

    try:
        # Compute mean absolute error (MAE) for valid values
        mae = mean_absolute_error(valid_label_values, valid_result_values)
        mse = mean_squared_error(valid_label_values, valid_result_values)
        rmse = np.sqrt(mse)

    except ValueError:
        mae = np.inf  # Fallback to infinity if error occurs

    # Mask to keep only values in the range [0, threshold] for valid values
    mask = (valid_result_values >= 0) & (valid_result_values <= threshold)
    if mask.sum() > 0:
        range_mae = mean_absolute_error(valid_label_values[mask],
                                        valid_result_values[mask])
    else:
        range_mae = 100  # Fallback if no values within the range

    # Ensure MAE and range_mae are within reasonable bounds to avoid overflow
    mae = min(mae, 100)
    range_mae = min(range_mae, 100)

    # Compute mixed score for valid values
    mixed_score_valid = (1 - mae / 100) * 0.5 + (1 -
                                                 range_mae / 100) * f1 * 0.5
    print(
        f'(1 - mae / 100) * 0.5={(1 - mae / 100) * 0.5}\n '
        f'(1 - range_mae / 100)={(1 - range_mae / 100)}\n '
        f'(1 - range_mae / 100) * f1 * 0.5={(1 - range_mae / 100) * f1 * 0.5}')

    # Compute the final mixed score,
    # averaging in the score for the near-infinity pairs
    total_data_points = len(result_values)
    total_valid_points = valid_mask.sum()

    if num_infinity_values > 0:
        final_mixed_score = (
            mixed_score_valid * total_valid_points +
            mixed_score_infinity * num_infinity_values) / total_data_points
    else:
        # Edge case: no near-infinity values
        final_mixed_score = mixed_score_valid
    print('RMSE', rmse)

    return {'mixed_score': final_mixed_score}


# Programmable Switch task:
# multilabel regression output one average correlation
def compute_R2_for_ProgrammableRNASwitches_task(task_name, task_entries,
                                                model_name):
    """Average R² over the ON / OFF / ON_OFF targets of the RNA-switch task.

    Each entry's 'model_output' must yield exactly three numbers (ON, OFF,
    ON_OFF); otherwise all three predictions become NaN for that entry.
    Per-target R² is computed on the finite pairs and down-weighted by the
    fraction of invalid entries, then the three scores are averaged.

    Args:
        task_name: Task identifier (only used for bookkeeping).
        task_entries: Iterable of dicts with 'model_output', 'label'
            (a dict with 'ON'/'OFF'/'ON_OFF'), and 'input'.
        model_name: Unused here (kept for the commented-out save path).

    Returns:
        {'R2': float} averaged over the three targets.
    """
    on_result_values = []
    off_result_values = []
    on_off_result_values = []

    on_label_values = []
    off_label_values = []
    on_off_label_values = []

    task_processed_data = []
    over_len = 0  # outputs with an unclosed '<think>' block
    miss_len = 0  # outputs without any think tags
    # Loop through each entry in the task
    for entry in task_entries:
        label = entry['label']
        # label = ast.literal_eval(label)
        on_label = float(label['ON'])
        off_label = float(label['OFF'])
        on_off_label = float(label['ON_OFF'])

        # Extract numeric values from the model output
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1
        extracted_result = extract_numeric_values(entry['model_output'])
        print('extracted_result', extracted_result)

        # Handle missing or invalid data by assigning np.nan
        if len(extracted_result) != 3:
            on_result_values.append(np.nan)
            off_result_values.append(np.nan)
            on_off_result_values.append(np.nan)
        else:
            on_result = extracted_result[0]
            off_result = extracted_result[1]
            on_off_result = extracted_result[2]
            on_result_values.append(on_result)
            off_result_values.append(off_result)
            on_off_result_values.append(on_off_result)

        # Append the label values
        on_label_values.append(on_label)
        off_label_values.append(off_label)
        on_off_label_values.append(on_off_label)

        # Save processed task data for this entry
        task_processed_data.append({
            'input':
            entry['input'],
            'label':
            entry['label'],
            'processed_model_output': {
                'ON': on_result if len(extracted_result) == 3 else np.nan,
                'OFF': off_result if len(extracted_result) == 3 else np.nan,
                'ON_Off':
                on_off_result if len(extracted_result) == 3 else np.nan
            },
            'original_model_output':
            entry['model_output']
        })

    # Save the processed task data
    # save_processed_data(model_name, task_name, task_processed_data)

    # Convert to numpy arrays for easier manipulation
    on_result_values = np.array(on_result_values)
    off_result_values = np.array(off_result_values)
    on_off_result_values = np.array(on_off_result_values)

    on_label_values = np.array(on_label_values)
    off_label_values = np.array(off_label_values)
    on_off_label_values = np.array(on_off_label_values)

    # Filter out NaN values in ON, OFF, and ON/OFF result/label pairs
    on_valid_mask = np.isfinite(on_result_values) & np.isfinite(
        on_label_values)
    off_valid_mask = np.isfinite(off_result_values) & np.isfinite(
        off_label_values)
    on_off_valid_mask = np.isfinite(on_off_result_values) & np.isfinite(
        on_off_label_values)

    # Filter the valid ON, OFF, and ON/OFF values
    on_result_values = on_result_values[on_valid_mask]
    off_result_values = off_result_values[off_valid_mask]
    on_off_result_values = on_off_result_values[on_off_valid_mask]

    on_label_values = on_label_values[on_valid_mask]
    off_label_values = off_label_values[off_valid_mask]
    on_off_label_values = on_off_label_values[on_off_valid_mask]

    # NOTE(review): compute_R2's signature is (label_values, result_values)
    # but results are passed first here; the pcc**2 result is symmetric so
    # the score is unaffected — confirm intended.
    try:
        on_R2 = compute_R2(
            on_result_values,
            on_label_values)['R2'] if len(on_result_values) > 0 else 0
    except Exception:
        on_R2 = 0  # Assign 0 in case of error

    try:
        off_R2 = compute_R2(
            off_result_values,
            off_label_values)['R2'] if len(off_result_values) > 0 else 0
    except Exception:
        off_R2 = 0  # Assign 0 in case of error

    try:
        on_off_R2 = compute_R2(
            on_off_result_values,
            on_off_label_values)['R2'] if len(on_off_result_values) > 0 else 0
    except Exception:
        on_off_R2 = 0  # Assign 0 in case of error

    # Combine R2 scores for ON, OFF, and ON/OFF values
    # (max(..., 1) guards the division when a target has no data at all)
    total_on_points = max(len(on_result_values) + np.sum(~on_valid_mask), 1)
    total_off_points = max(len(off_result_values) + np.sum(~off_valid_mask), 1)
    total_on_off_points = max(
        len(on_off_result_values) + np.sum(~on_off_valid_mask), 1)

    # Assign average R2 with 0 for invalid entries
    final_on_R2 = (on_R2 * len(on_result_values)) / total_on_points if len(
        on_result_values) > 0 else 0
    final_off_R2 = (off_R2 * len(off_result_values)) / total_off_points if len(
        off_result_values) > 0 else 0
    final_on_off_R2 = (on_off_R2 *
                       len(on_off_result_values)) / total_on_off_points if len(
                           on_off_result_values) > 0 else 0

    avg_R2 = (final_on_R2 + final_off_R2 + final_on_off_R2) / 3
    print('over_len: ', over_len)
    print('miss_len: ', miss_len)
    print('123', final_on_R2, final_off_R2, final_on_off_R2)
    return {'R2': avg_R2}


# Enhancer Activity Task:
# multilabel regression output two individual correlation
def compute_PCC_for_enhancer_activity_task(task_name, task_entries,
                                           model_name):
    """Score the enhancer-activity multilabel-regression task with PCC.

    Two numeric activities — housekeeping (``hk``) and developmental
    (``dev``) — are extracted from each model output and correlated
    (Pearson) against the labels. Outputs that do not contain exactly
    two numbers are recorded as ``np.inf`` and excluded from the
    correlation, but they still enlarge the denominator so the final
    score is down-weighted by the invalid fraction.

    Args:
        task_name: Task identifier (only used by the commented-out
            save step).
        task_entries: Iterable of dicts with ``input``, ``label`` (a
            mapping with ``hk``/``dev`` keys) and ``model_output``.
        model_name: Model identifier (only used by the commented-out
            save step).

    Returns:
        dict with averaged ``PCC`` plus individual ``hk_PCC`` and
        ``dev_PCC``, or an error-message dict when one channel has no
        valid data points at all.
    """
    hk_result_values = []
    dev_result_values = []

    hk_label_values = []
    dev_label_values = []

    task_processed_data = []
    # Counters for outputs with an unterminated (<think> but no
    # </think>) or missing chain-of-thought block.
    over_len = 0
    miss_len = 0
    # Loop through each entry in the task
    for entry in task_entries:
        label = entry['label']
        # label = ast.literal_eval(label)
        # Keep only the text after the chain-of-thought, if present.
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1
        model_output = entry['model_output']
        print('model_output', model_output)
        hk_label = float(label['hk'])
        dev_label = float(label['dev'])

        # Extract model output values for HK and Dev enhancer activity
        extracted_result = extract_numeric_values(model_output)

        # Handle missing or invalid data by assigning np.inf
        if len(extracted_result) != 2:

            hk_result_values.append(np.inf)
            dev_result_values.append(np.inf)
        else:
            hk_result = extracted_result[0]
            dev_result = extracted_result[1]
            hk_result_values.append(hk_result)
            dev_result_values.append(dev_result)

        # Append the label values
        hk_label_values.append(hk_label)
        dev_label_values.append(dev_label)

        # Save processed task data for this entry
        task_processed_data.append({
            'input':
            entry['input'],
            'label':
            entry['label'],
            'processed_model_output': {
                'hk': hk_result if len(extracted_result) == 2 else np.inf,
                'dev': dev_result if len(extracted_result) == 2 else np.inf
            },
            'original_model_output':
            entry['model_output']
        })

    # Save the processed task data
    # save_processed_data(model_name, task_name, task_processed_data)

    # Convert to numpy arrays for easier manipulation
    hk_result_values = np.array(hk_result_values)
    dev_result_values = np.array(dev_result_values)
    hk_label_values = np.array(hk_label_values)
    dev_label_values = np.array(dev_label_values)

    # Filter out NaN or inf values in both HK and Dev result/label pairs
    hk_valid_mask = np.isfinite(hk_result_values) & np.isfinite(
        hk_label_values)
    dev_valid_mask = np.isfinite(dev_result_values) & np.isfinite(
        dev_label_values)

    # Filter the valid HK and Dev values
    hk_result_values = hk_result_values[hk_valid_mask]
    hk_label_values = hk_label_values[hk_valid_mask]
    dev_result_values = dev_result_values[dev_valid_mask]
    dev_label_values = dev_label_values[dev_valid_mask]

    # Compute Pearson correlation for valid HK and Dev enhancer activities
    if len(hk_result_values) > 0:
        try:
            hk_pcc, _ = pearsonr(hk_result_values, hk_label_values)
        except Exception:
            hk_pcc = np.inf  # Set to inf in case of errors
    else:
        return {
            'PCC':
            'Error: HK has insufficient valid data '
            'after removing NaNs and infs.'
        }
    if len(dev_result_values) > 0:
        try:
            dev_pcc, _ = pearsonr(dev_result_values, dev_label_values)
        except Exception:
            dev_pcc = np.inf  # Set to inf in case of errors
    else:
        return {
            'PCC':
            'Error: Dev has insufficient valid data '
            'after removing NaNs and infs.'
        }

    # Combine results with NaN/inf values consideration
    # Denominator includes invalid entries so they act as zero-score
    # penalties in the weighted average below.
    total_hk_points = len(hk_result_values) + np.sum(~hk_valid_mask)
    total_dev_points = len(dev_result_values) + np.sum(~dev_valid_mask)

    # Assign mixed score with 0 for invalid entries
    final_hk_pcc = (hk_pcc * len(hk_result_values) + 0 * np.sum(~hk_valid_mask)
                    ) / total_hk_points if len(hk_result_values) > 0 else 0
    final_dev_pcc = (dev_pcc * len(dev_result_values) +
                     0 * np.sum(~dev_valid_mask)) / total_dev_points if len(
                         dev_result_values) > 0 else 0
    print('over_len:', over_len)
    print('miss_len: ', miss_len)
    return {
        'PCC': (final_hk_pcc + final_dev_pcc) / 2,
        'hk_PCC': final_hk_pcc,
        'dev_PCC': final_dev_pcc
    }


# Process binary classification task
def process_binary_classification_task(task_name, task_entries, model_name):
    """Map each entry's free-text model output to a 0/1 class.

    Resolution order per entry:
      1. Strip an optional ``<summary>`` wrapper and any
         chain-of-thought prefix ending in ``</think>``.
      2. ``classify_by_keywords``; a ``'dont_know'`` answer or an empty
         output is scored as wrong by assigning the opposite of the
         ground-truth class.
      3. Outputs keywords cannot resolve are batched through a single
         ``classify_by_sentiment_model`` call and back-filled.

    Args:
        task_name: Task identifier (only used by the commented-out
            save step).
        task_entries: Iterable of dicts with ``input``, ``label``
            (``'positive'`` or anything else) and ``model_output``.
        model_name: Model identifier (only used by the commented-out
            save step).

    Returns:
        Tuple ``(label_classes, result_classes)`` of equal-length 0/1
        lists aligned with ``task_entries``.
    """
    label_classes = []
    result_classes = []
    task_processed_data = []
    entries_for_model = []
    # Counters for unterminated / missing chain-of-thought blocks.
    over_len = 0
    miss_len = 0
    for index, entry in enumerate(tqdm(task_entries)):
        if '<summary>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split(
                '<summary>')[-1]

        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1

        label_class = 1 if entry['label'] == 'positive' else 0
        # str() normalizes non-string outputs; the previous duplicate
        # assignment and the dead `is None` branch were removed.
        model_output = str(entry['model_output'])
        result_class = None

        keyword_result = classify_by_keywords(model_output)
        if keyword_result == 'dont_know':
            # Explicit refusal: count as a wrong prediction.
            result_class = 1 - label_class
        elif keyword_result is not None:
            result_class = keyword_result
        elif model_output.strip():
            # Defer to the sentiment model; batched after the loop.
            entries_for_model.append({'index': index, 'text': model_output})
        else:
            # Empty output: count as a wrong prediction.
            result_class = 1 - label_class

        # Store the processed entry now; entries deferred to the
        # sentiment model are back-filled into this slot below.
        task_processed_data.append({
            'input': entry['input'],
            'original_label': entry['label'],
            'processed_label': label_class,
            'original_model_output': model_output,
            'processed_model_output': result_class,  # None until back-filled
            'score': 'N/A'  # replaced with model confidence if used
        })
    print(len(entries_for_model))

    if entries_for_model:

        texts_to_classify = [item['text'] for item in entries_for_model]

        # Classify all deferred outputs in a single batched call.
        model_results = classify_by_sentiment_model(texts_to_classify)

        for i, model_item in enumerate(tqdm(entries_for_model)):
            original_index = model_item['index']
            result_class, score = model_results[i]

            # (Optional) low-confidence predictions could be counted as
            # wrong here; kept disabled to preserve current scoring.
            # if score < 0.5:
            #     result_class =
            #     1 - task_processed_data[original_index]['processed_label']

            # Back-fill the model's decision at the correct position.
            task_processed_data[original_index][
                'processed_model_output'] = result_class
            task_processed_data[original_index]['score'] = str(score)

    result_classes = [d['processed_model_output'] for d in task_processed_data]
    label_classes = [d['processed_label'] for d in task_processed_data]
    print('miss_len:', miss_len)
    print('over_len:', over_len)

    # save_processed_data(model_name, task_name, task_processed_data)

    return label_classes, result_classes


# Compute matthews correlation coefficient (MCC)
def compute_MCC(label_classes, result_classes):
    """Return the MCC between labels and predictions.

    Produces an error-message dict when the prediction list is empty or
    the two lists differ in length.
    """
    if not result_classes:
        return {'MCC': 'Error: Empty data.'}
    if len(result_classes) != len(label_classes):
        return {
            'MCC': 'Error: Mismatch in the number of extracted numeric values.'
        }
    return {'MCC': matthews_corrcoef(label_classes, result_classes)}


# Compute accuracy score (Acc)
def compute_Acc(label_classes, result_classes):
    """Return classification accuracy, or an error-message dict when
    the predictions are empty or misaligned with the labels."""
    if not result_classes:
        return {
            'Acc':
            'Error: Insufficient data for classification. '
            'Number of model outputs is 0.'
        }
    if len(result_classes) != len(label_classes):
        return {
            'Acc':
            'Error: Mismatched labels. '
            'The number of model outputs does not match the number of labels.'
        }
    return {'Acc': accuracy_score(label_classes, result_classes)}


# Extract RNA family from the text
def extract_rna_family(text):
    """Return the first RNA_CLASSES entry occurring in *text*, else None."""
    return next((family for family in RNA_CLASSES if family in text), None)


# Compute ACC metric for NoncodingRNAFamily multiclass classification task
def compute_Acc_for_NoncodingRNAFamily_task(task_name, task_entries,
                                            model_name):
    """Score the NoncodingRNAFamily multiclass task with accuracy.

    The RNA family is extracted from the (chain-of-thought-stripped)
    model output and compared to the ground-truth label by string
    equality. Entries with an unterminated ``<think>`` block are
    counted in ``over_len``; those with no ``<think>`` at all in
    ``miss_len``.

    Args:
        task_name: Task identifier (only used by the commented-out
            save step).
        task_entries: Iterable of dicts with ``input``, ``label`` and
            ``model_output``.
        model_name: Model identifier (only used by the commented-out
            save step).

    Returns:
        dict with overall ``Acc`` (0 when there are no entries).
    """
    correct_count = 0
    total_count = 0
    task_processed_data = []
    over_len = 0
    miss_len = 0
    for entry in task_entries:
        # Both branches extract the family; they only differ in the
        # over/miss bookkeeping for unterminated chain-of-thought.
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
            result_family = extract_rna_family(entry['model_output'])
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1
            # result_family = "None"
            result_family = extract_rna_family(entry['model_output'])

        label_family = entry['label']
        # result_family = extract_rna_family(entry["model_output"])
        # Compare extracted family with the ground truth label
        if result_family == label_family:
            correct_count += 1

        total_count += 1

        # Store original and processed data
        task_processed_data.append({
            'input':
            entry['input'],
            'label':
            entry['label'],
            'processed_model_output':
            result_family,
            'original_model_output':
            entry['model_output']
        })

    # save_processed_data(model_name, task_name, task_processed_data)
    print('over_len:', over_len)
    print('miss_len:', miss_len)
    # Calculate accuracy
    accuracy = correct_count / total_count if total_count > 0 else 0
    # Diagnostic accuracy excluding truncated-generation entries.
    if (total_count - over_len) != 0:
        print('true_acc:', correct_count / (total_count - over_len))

    return {'Acc': accuracy}


# Extract RNA modification labels from the output text
def extract_modifications(text):
    """Return every modification class that occurs in *text* as a whole
    word, in the order of ``modification_classes``."""
    return [
        mod for mod in modification_classes
        if re.search(rf'\b{mod}\b', text)
    ]


# Convert modification labels to a binary multihot vector
def convert_to_binary_vector(modifications, classes=modification_classes):
    """Return a 0/1 vector marking which *classes* occur in
    *modifications*; ``None`` is treated as an empty collection."""
    present = modifications if modifications is not None else []
    return [1 if cls in present else 0 for cls in classes]


# Compute AUC metrics for Modification task
def compute_AUC_for_Modification_task(task_name, task_entries, model_name):
    """Score the RNA-modification multilabel task with macro AUC.

    Modification names are extracted from each model output. Outputs
    with no recognised modification fall back to keyword
    classification and, as a last resort, the sentiment model, mapping
    a negative verdict to the explicit ``'none'`` class and anything
    else to an empty prediction. Predictions and labels are encoded as
    multi-hot vectors and scored with ``roc_auc_score``.

    Args:
        task_name: Task identifier (bookkeeping only).
        task_entries: Iterable of dicts with ``input``, ``label``
            (comma-separated modification names) and ``model_output``.
        model_name: Model identifier (bookkeeping only).

    Returns:
        dict with macro-averaged ``AUC``, or ``None`` when sklearn
        raises ``ValueError`` (e.g. a degenerate class column).
    """
    y_true = []
    y_pred = []
    task_processed_data = []
    over_len = 0
    miss_len = 0
    for entry in task_entries:
        # MARK: strip an optional <summary> wrapper first (was "gaile")
        if '<summary>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split(
                '<summary>')[-1]
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1
        predicted_modifications = extract_modifications(entry['model_output'])
        # print(predicted_modifications)
        true_modifications = entry['label'].split(',')

        # Handle case where result is empty and label is "none"
        if not predicted_modifications:
            # Classify by keyword
            # NOTE(review): classify_by_keywords appears to yield
            # 0 / 1 / None here (a 'dont_know' string would fall
            # through unhandled) — confirm against its definition.
            predicted_modifications = classify_by_keywords(
                entry['model_output'])

            # If keyword negative,
            # assigned to prediction to be the "none" class
            if predicted_modifications == 0:
                predicted_modifications = ['none']

            elif predicted_modifications == 1:
                predicted_modifications = []

            # If the result cannot be classified, use the sentiment model
            elif predicted_modifications is None:

                sentiment_result, sentiment_score = \
                    classify_by_sentiment_model(
                        [entry['model_output']])[0]

                # If classified as negative, manually label as 'none'
                if sentiment_result == 0:
                    predicted_modifications = ['none']

                else:
                    predicted_modifications = []

        # Convert the predicted and true modifications to binary vectors
        y_true.append(convert_to_binary_vector(true_modifications))
        y_pred.append(convert_to_binary_vector(predicted_modifications))

        # Store the processed data
        # NOTE(review): 'processed_model_ouput' key keeps its original
        # typo; renaming would change the saved-record schema.
        task_processed_data.append({
            'input':
            entry['input'],
            'label':
            entry['label'],
            'processed_model_ouput':
            predicted_modifications,
            'original_model_output':
            entry['model_output']
        })
        print('label', entry['label'])
        print('predication', predicted_modifications)

    # save_processed_data(model_name, task_name, task_processed_data)
    print('over_len:', over_len)
    print('miss_len: ', miss_len)
    # Compute the AUC for each class, then average the AUC across all classes
    try:
        auc = roc_auc_score(y_true, y_pred, average='macro')
        print('auc', auc)
    except ValueError:
        auc = None

    return {'AUC': auc}


# FunctionEC Task
# Modified from
# SaProt https://github.com/westlake-repl/SaProt/blob/main/utils/metrics.py
def count_f1_max(pred, target):
    """
    F1 score with the optimal threshold.
    Handles cases where either predictions or targets are empty.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        float: The maximum F1 score or 0.0 if inputs are empty.
        NOTE(review): on the success path this actually returns a
        0-dim tensor (``all_f1.max()``), while the fallbacks return the
        plain float 0.0 — callers must handle both.
    """
    # Check if either pred or target is empty
    if pred.numel() == 0 or target.numel() == 0:
        return 0.0

    # Proceed with the original logic if inputs are not empty
    # Rank predictions per sample (descending) and reorder targets to
    # match, giving running per-sample precision/recall curves.
    order = pred.argsort(descending=True, dim=1, stable=True)
    # print(f"order: {order}")
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)

    # Mark where each sample's own ranking begins once flattened.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)
    # Global ordering over all (sample, class) scores: a single shared
    # threshold is swept across the whole batch.
    all_order = pred.flatten().argsort(descending=True, stable=True)
    order = order + torch.arange(
        order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]

    precision = precision.flatten()
    recall = recall.flatten()

    # Difference-then-cumsum turns per-sample running values into
    # averages over all samples at each global threshold position.
    all_precision = precision[all_order] - \
        torch.where(
            is_start, torch.zeros_like(precision),
            precision[all_order - 1])
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
        torch.where(
            is_start, torch.zeros_like(recall),
            recall[all_order - 1])
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall +
                                               1e-10)

    # Any NaN (e.g. from degenerate inputs) invalidates the sweep.
    if torch.isnan(all_f1).any():
        return 0.0

    return all_f1.max()


def round_and_scale_results(data, decimal_places=3, scale_factor=100):
    """Scale and round numeric leaves of a (possibly nested) dict in place.

    Each int/float value becomes ``float(round(value * scale_factor,
    decimal_places))``; nested dicts are handled recursively; all other
    values are left untouched.
    """
    for key in data:
        value = data[key]
        if isinstance(value, dict):
            # Descend into nested metric dicts.
            round_and_scale_results(value, decimal_places, scale_factor)
        elif isinstance(value, (float, int)):
            data[key] = float(round(value * scale_factor, decimal_places))


# Convert EC number to binary multihot vectors
def ec_to_multihot(ec_list, ec_labels):
    """Encode *ec_list* as a multi-hot tensor over *ec_labels*.

    Args:
        ec_list: EC-number strings for one protein (may be empty).
        ec_labels: Ordered list of all known EC numbers.

    Returns:
        1-D float tensor of length ``len(ec_labels)`` with 1 at the
        positions whose label occurs in ``ec_list``. EC numbers absent
        from ``ec_labels`` are ignored.
    """
    # Single allocation (the original allocated the zero tensor twice
    # and special-cased the empty list, which the loop handles anyway).
    multihot = torch.zeros(len(ec_labels))
    for ec in ec_list:
        if ec in ec_labels:
            # list.index is O(n); fine for the label-set sizes here.
            multihot[ec_labels.index(ec)] = 1
    return multihot


# Compute Fmax metric for FunctionEC task
def compute_Fmax_for_FunctionEC_task(task_name, task_entries, ec_labels,
                                     model_name):
    """Score the FunctionEC multilabel task with the Fmax metric.

    EC numbers are regex-parsed from both the label and the model
    output, converted to multi-hot vectors over ``ec_labels`` and
    scored with ``count_f1_max``.

    Args:
        task_name: Task identifier (bookkeeping only).
        task_entries: Iterable of dicts with ``input``, ``label`` and
            ``model_output``.
        ec_labels: Ordered list of all EC numbers (multi-hot
            vocabulary).
        model_name: Model identifier (bookkeeping only).

    Returns:
        dict with ``Fmax`` as a Python float, or ``None`` when the
        score could not be computed.
    """
    all_preds = []
    all_labels = []
    task_processed_data = []
    over_len = 0
    miss_len = 0
    for entry in task_entries:
        # Strip chain-of-thought / summary wrappers before parsing.
        if '</think>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split('</think>')[-1]
        else:
            if '<think>' in entry['model_output']:
                over_len += 1
            else:
                miss_len += 1
        if '<summary>' in entry['model_output']:
            entry['model_output'] = entry['model_output'].split(
                '<summary>')[-1]
        # Parse the EC numbers from 'output' and 'label'
        label_ec = re.findall(r'\d+\.\d+\.\d+\.\-?\d*', entry['label'])
        result_ec = re.findall(r'\d+\.\d+\.\d+\.\-?\d*',
                               str(entry['model_output']))

        # Convert EC numbers to multi-hot vectors
        pred_multihot = ec_to_multihot(result_ec, ec_labels)
        label_multihot = ec_to_multihot(label_ec, ec_labels)

        # Store the results
        all_preds.append(pred_multihot)
        all_labels.append(label_multihot)

        # Save processed task data
        task_processed_data.append({
            'input': entry['input'],
            'label': entry['label'],
            'processed_label': label_ec,
            'original_model_output': entry['model_output'],
            'processed_model_output': result_ec,
        })
        print('label_ec', label_ec)
        print('result_ec', result_ec)

    # save_processed_data(model_name, task_name, task_processed_data)

    # Stack the predictions and targets for batch processing
    all_preds = torch.stack(all_preds)
    all_labels = torch.stack(all_labels)
    print('miss_len: ', miss_len)
    print('over_len: ', over_len)
    # Compute the Fmax score
    try:
        fmax_score = count_f1_max(all_preds, all_labels)
    except ValueError:
        fmax_score = None

    # count_f1_max returns a 0-dim tensor on success but a plain float
    # 0.0 on its empty/NaN fallbacks, and the except-branch yields
    # None; the previous unconditional .item() crashed for the latter
    # two cases.
    if fmax_score is None:
        return {'Fmax': None}
    if hasattr(fmax_score, 'item'):
        return {'Fmax': fmax_score.item()}
    return {'Fmax': float(fmax_score)}


def preprocess_input_data(input_file_path, prediction, mini_set=False):
    """Load the task JSON, attach predictions and group entries by task.

    Args:
        input_file_path: Path to a JSON list of entries with ``input``,
            ``output``, ``label`` and ``task`` fields.
        prediction: Model predictions aligned one-to-one with the file
            entries (after the optional mini-set subsampling).
        mini_set: When True and the file holds more than 150 entries,
            keep a fixed, seeded 150-entry subsample.

    Returns:
        ``defaultdict(list)`` mapping a task-name prefix (text before
        the first ``-``) to its list of record dicts.
    """
    with open(input_file_path, 'r') as f:
        raw_entries = json.load(f)
    if mini_set and len(raw_entries) > 150:
        import random
        random.seed(1024)  # reproducible subsample
        raw_entries = random.sample(raw_entries, 150)
        random.seed()  # restore unseeded RNG for other callers

    records = []
    if len(prediction) == len(raw_entries):
        for raw, pred in zip(raw_entries, prediction):
            try:
                records.append({
                    'input': raw['input'],
                    'output': raw['output'],
                    'model_output': pred,
                    'label': raw['label'],
                    'task': raw['task'],
                })
            except json.JSONDecodeError:
                print(f'Skipping invalid line: {raw}')
    else:
        print('len(prediction)!=len(data_in) !!!')

    df = pd.DataFrame(records)
    print(f'Number of data samples: {len(df)}')
    df.rename(columns={'result': 'model_output'}, inplace=True)
    print(df['task'])
    # Canonicalise legacy task names (exact-value replacement).
    df['task'] = df['task'].replace('rna_protein_interaction',
                                    'ncRNAProteinInter')
    df['task'] = df['task'].replace('antibody_antigen', 'AntibodyAntigen')

    # 'tf-h'/'tf-m' would otherwise be split at the wrong '-' below.
    df['task'] = df['task'].str.replace('tf-h', 'tf_h')
    df['task'] = df['task'].str.replace('tf-m', 'tf_m')

    # Keep only labelled entries and re-index.
    df = df[df['label'].notna()]
    df.reset_index(inplace=True, drop=True)

    # Group records by the task-name prefix.
    grouped = defaultdict(list)
    for record in df.to_dict(orient='records'):
        grouped[record['task'].split('-')[0]].append(record)

    return grouped


class bio_instruction_Evaluator(BaseEvaluator):
    """Evaluator dispatching SciReasoner bio-instruction tasks.

    Groups predictions by task, looks the task's type/metric up in
    ``register_tasks.json`` and routes each group to the matching
    scoring function (regression, binary / multiclass / multilabel
    classification, multilabel regression).
    """

    def __init__(self,
                 path,
                 task,
                 model_name,
                 mini_set=False,
                 *args,
                 **kwargs):
        """Resolve the dataset path for *task* and store settings.

        Args:
            path: Dataset root (resolved via ``get_data_path``).
            task: Task directory name; data is read from
                ``{task}/test/data.json``.
            model_name: Model identifier forwarded to the scorers.
            mini_set: Forwarded to ``preprocess_input_data`` to score
                only a seeded 150-entry subsample.
        """
        super().__init__(*args, **kwargs)

        path = get_data_path(path)
        self.dataset_path = os.path.join(path, f'{task}/test/data.json')
        self.model_name = model_name
        self.mini_set = mini_set

    def score(self, predictions):
        """Score *predictions* for every task found in the data file.

        Returns:
            The scaled metrics dict of the last task processed (see
            NOTE at the bottom).
        """
        test_path = self.dataset_path
        # The dataset root is three levels above .../<task>/test/data.json.
        repo_id = '/'.join(test_path.split('/')[:-3])
        ec_path = 'ec_labels.json'
        ec_file_path = os.path.join(repo_id, ec_path)
        # ec_file_path = hf_hub_download(repo_id, ec_path, repo_type="dataset")

        with open(ec_file_path, 'r') as f:
            ec_labels = json.load(f)

        test_path = test_path.split(repo_id + '/')[1]
        input_file_path = self.dataset_path
        # input_file_path =
        # hf_hub_download(repo_id, test_path, repo_type="dataset")

        grouped_data = preprocess_input_data(input_file_path,
                                             predictions,
                                             mini_set=self.mini_set)

        print(f'Grouped data for tasks: {list(grouped_data.keys())}')

        register_tasks_path = 'register_tasks.json'
        register_tasks_file_path = os.path.join(repo_id, register_tasks_path)
        # register_tasks_file_path =
        # hf_hub_download(repo_id, register_tasks_path, repo_type="dataset")
        with open(register_tasks_file_path, 'r') as f:
            task_type_data = json.load(f)

        metrics = {}

        # Loop over tasks
        for task_name, task_entries in grouped_data.items():
            task_type = task_type_data[task_name]['type']
            task_metrics = task_type_data[task_name]['metrics']
            # NOTE(review): 'Prosessing' typo in this runtime log
            # message is kept as-is.
            print(f'Prosessing {task_name} task...')
            print(task_type)
            sys.stdout.flush()

            if task_type == 'regression':
                # task_processed_data, label_values, result_values
                # = process_regression_task(task_name, task_entries)
                label_values, result_values = process_regression_task(
                    task_name, task_entries, self.model_name)
                if task_metrics == 'spearman':
                    metrics[task_name] = compute_spearman(
                        label_values, result_values)

                elif task_metrics == 'R2':
                    metrics[task_name] = compute_R2(label_values,
                                                    result_values)
                    # print(metrics[task_name])

                elif task_metrics == 'mixed_score':
                    metrics[task_name] = compute_mixed_score(label_values,
                                                             result_values,
                                                             threshold=30)

            elif task_type == 'binary classification':
                # task_processed_data, label_classes, result_classes
                # = process_binary_classification_task(task_name, task_entries)
                label_classes, result_classes = \
                    process_binary_classification_task(
                        task_name, task_entries, self.model_name)
                print(f'label_classes: {label_classes}')
                print(f'result_classes: {result_classes}')
                if task_metrics == 'MCC':
                    metrics[task_name] = compute_MCC(label_classes,
                                                     result_classes)

                elif task_metrics == 'Acc':
                    metrics[task_name] = compute_Acc(label_classes,
                                                     result_classes)

            elif task_type == 'multilabel regression':

                if task_name == 'ProgrammableRNASwitches':
                    metrics[task_name] = \
                        compute_R2_for_ProgrammableRNASwitches_task(
                            task_name, task_entries, self.model_name)

                elif task_name == 'enhancer_activity':
                    metrics[
                        task_name] = compute_PCC_for_enhancer_activity_task(
                            task_name, task_entries, self.model_name)

            elif task_type == 'multiclass classification':

                if task_name == 'NoncodingRNAFamily':
                    metrics[
                        task_name] = compute_Acc_for_NoncodingRNAFamily_task(
                            task_name, task_entries, self.model_name)

            elif task_type == 'multilabel classification':
                if task_name == 'FunctionEC':
                    metrics[task_name] = compute_Fmax_for_FunctionEC_task(
                        task_name, task_entries, ec_labels, self.model_name)

                elif task_name == 'Modification':
                    metrics[task_name] = compute_AUC_for_Modification_task(
                        task_name, task_entries, self.model_name)

            print(f'The metrics {task_metrics} for task {task_name}'
                  f' is {str(metrics[task_name][task_metrics])}')
            sys.stdout.flush()

        metrics_grouped_by_omics = defaultdict(dict)

        for task_name, task_metrics in metrics.items():
            # Get the omics type from task_type_data
            omics = task_type_data[task_name]['omics']

            # Scale the metrics
            scaled_metrics = task_metrics.copy(
            )  # Make a copy to avoid modifying the original
            round_and_scale_results(
                scaled_metrics)  # Apply scaling to the metrics

            # Add the scaled metrics to the grouped dictionary
            metrics_grouped_by_omics[omics][task_name] = scaled_metrics

        # Save the metrics (results) to a new JSON file
        # metrics_file_path = (
        #     path_bioinstruction + f'/metrics_result/{omics}/' +
        #     f'metrics_result_{self.model_name}_{task_name}.json')
        # output_directory = os.path.dirname(metrics_file_path)
        # os.makedirs(output_directory, exist_ok=True)
        # with open(metrics_file_path, 'w') as outfile:
        #     json.dump(metrics_grouped_by_omics[omics], outfile, indent=4)
        #     print(f'Metrics saved to {metrics_file_path}')

        # NOTE(review): uses the loop variables after the loop, so only
        # the LAST task's scaled metrics are returned — confirm this is
        # intentional when a config evaluates multiple tasks.
        return metrics_grouped_by_omics[omics][task_name]

+ 173
- 0
opencompass/datasets/SciReasoner/bulk_modulus_material.py View File

@@ -0,0 +1,173 @@
# flake8: noqa

import json
import os
import re
from collections import Counter
from typing import List, Union

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class Bulk_modulus_material_Dataset(BaseDataset):
    """Dataset for the SciReasoner bulk-modulus material-generation task."""

    @staticmethod
    def load(path, mini_set=False):
        """Load train (dev) and test splits from local JSON files.

        Args:
            path: Dataset root, resolved via ``get_data_path``.
            mini_set: When True and the test split holds more than 150
                examples, keep a fixed, seeded 150-example subsample.

        Returns:
            DatasetDict with ``train`` (first 5 dev examples, used as
            in-context examples) and ``test`` splits.
        """
        path = get_data_path(path)
        # Pointless f-string prefixes removed (no placeholders).
        train_path = os.path.join(path, 'bulk_modulus_material/dev/data.json')
        test_path = os.path.join(path, 'bulk_modulus_material/test/data.json')

        # load from local json file
        with open(train_path, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        # Keep only 5 dev samples as in-context examples.
        train_data = train_data[:5]
        # Guard on len: random.sample raises ValueError when the
        # population is smaller than the requested sample size.
        if mini_set and len(test_data) > 150:
            import random
            random.seed(1024)  # fixed subsample for comparable mini runs
            test_data = random.sample(test_data, 150)
            random.seed()  # restore unseeded RNG for other callers

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


@TEXT_POSTPROCESSORS.register_module()
def material_postprocessor(text: Union[str, None]) -> str:
    """Extract the content of the first <material>...</material> tag.

    Returns an empty string when *text* is falsy or no tag is present.
    """
    if not text:
        return ''
    found = re.search(r'<material>(.*?)</material>', text,
                      re.DOTALL | re.IGNORECASE)
    return found.group(1).strip() if found else ''


class material_Evaluator(BaseEvaluator):
    """
    Evaluator for:
    - SMAct validity
    - Composition precision (based on output-extracted elements)
    - Exact match (between prediction and reference <material> block)
    """

    def __init__(self, data_path=None, **kwargs):
        """Resolve the test-data path and preload ground truths.

        Args:
            data_path: Dataset root; the bulk-modulus test JSON is read
                from ``bulk_modulus_material/test/data.json`` under it.
        """
        super().__init__()
        self.data_path = os.path.join(get_data_path(data_path),
                                      'bulk_modulus_material/test/data.json')
        self.prompt_elements_list = []  # elements extracted from the GT
        self.reference_materials = []  # reference answers for exact match

        if self.data_path:
            self._load_ground_truths()

    def _load_ground_truths(self):
        """Load ground-truth elements and material strings from the test JSON."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for item in data:
            output = item.get('output', '')
            # Extract the constituent element symbols.
            elements = re.findall(r'\b[A-Z][a-z]?\b',
                                  material_postprocessor(output))
            self.prompt_elements_list.append(elements)
            # Extract the full material block for exact match.
            self.reference_materials.append(material_postprocessor(output))

    def _normalize(self, formula: str) -> str:
        """Normalise a formula: sort element symbols alphabetically, keep counts."""
        # NOTE(review): appears unused within this class — confirm
        # before removing.
        tokens = re.findall(r'([A-Z][a-z]?)(\d*)', formula)
        tokens.sort(key=lambda x: x[0])
        return ''.join(f"{el}{cnt or ''}" for el, cnt in tokens)

    def score(self, predictions: List[dict]):
        """Score predictions; returns counts plus percentage metrics."""

        from smact.screening import smact_validity

        total = len(predictions)
        format_valid = 0
        smact_valid = 0
        precision_sum = 0.0
        exact_match_count = 0

        for i, item in enumerate(predictions):
            # Predictions may arrive as bare strings; normalise to dict.
            if isinstance(item, str):
                item = {'prediction': item}
            text = item.get('prediction', '').strip()

            # --- SMAct validity ---
            # Expected shape: space-separated elements followed by a
            # space-group tag like "<sg123>".
            match = re.match(
                r'([A-Z][a-z]?(?: [A-Z][a-z]?)*?)\s*(?:<sg>\s*)?<sg(\d+)>',
                text)
            if match:
                elements_str, _ = match.groups()
                elements = elements_str.split()
                counter = Counter(elements)
                formula = ''.join(f"{el}{cnt or ''}"
                                  for el, cnt in sorted(counter.items()))
                try:
                    if smact_validity(formula):
                        smact_valid += 1
                    format_valid += 1
                except Exception:
                    pass

            # --- Composition precision ---
            if i < len(self.prompt_elements_list):
                gt_elements = set(self.prompt_elements_list[i])
                pred_elements = set(re.findall(r'\b[A-Z][a-z]?\b', text))
                correct = len(gt_elements & pred_elements)
                if gt_elements:
                    precision_sum += correct / len(gt_elements)

            # --- Exact Match ---
            if i < len(self.reference_materials):
                pred_mat = material_postprocessor(text)
                gt_mat = self.reference_materials[i]
                if pred_mat == gt_mat:
                    exact_match_count += 1

        avg_precision = (precision_sum / total * 100) if total else 0.0
        smact_in_format = (smact_valid / format_valid *
                           100) if format_valid else 0.0
        smact_in_all = (smact_valid / total * 100) if total else 0.0
        exact_match_ratio = (exact_match_count / total * 100) if total else 0.0

        return {
            'total_samples': total,
            'format_valid_count': format_valid,
            'smact_valid_count': smact_valid,
            'smact_validity_ratio_in_format_valid_%': smact_in_format,
            'smact_validity_ratio_in_all_%': smact_in_all,
            'average_precision_%': avg_precision,
            'exact_match_ratio_%': exact_match_ratio
        }

+ 228
- 0
opencompass/datasets/SciReasoner/composition_material.py View File

@@ -0,0 +1,228 @@
# flake8: noqa

import json
import os
import re
from collections import Counter
from typing import Union

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path


def extract_elements_from_prompt(prompt: str) -> list:
    """Extract element symbols from diverse prompt instructions.

    Recognised lead-in phrases: "composed of", "that has",
    "characterized by", "with the composition", "based on",
    "featuring", "whose makeup is". If none is found, fall back to
    collecting every standalone element-like token in the prompt.
    """
    lead_ins = (r'composed of', r'that has', r'characterized by',
                r'with the composition', r'based on', r'featuring',
                r'whose makeup is')
    pattern = rf"(?:{'|'.join(lead_ins)})\s+(.*?)(?:[\.。\n]|$)"
    found = re.search(pattern, prompt, re.IGNORECASE)

    if found:
        tokens = re.split(r'[,\s]+', found.group(1))
        return [
            tok.strip() for tok in tokens
            if re.fullmatch(r'[A-Z][a-z]?', tok.strip())
        ]

    # Fallback: any capitalised one/two-letter token could be an element.
    return re.findall(r'\b[A-Z][a-z]?\b', prompt)


def composition_precision(elements: list[str], prediction: str) -> float:
"""计算元素命中率"""
E_pi = set(elements)
clean = re.sub(r'<[^>]+>', ' ', prediction)
E_gi = set(re.findall(r'\b[A-Z][a-z]?\b', clean))
if not E_pi:
return 0.0
return len(E_pi & E_gi) / len(E_pi)


@LOAD_DATASET.register_module()
class Composition_material_Dataset(BaseDataset):
    """Dataset for the SciReasoner composition-conditioned material
    generation task."""

    @staticmethod
    def load(path, mini_set=False):
        """Load dev (few-shot) and test splits from local JSON files.

        Args:
            path: Dataset root, resolved via ``get_data_path``.
            mini_set: When True and the test split holds more than 150
                examples, keep a fixed, seeded 150-example subsample.

        Returns:
            DatasetDict with ``train`` (first 5 dev examples) and
            ``test`` splits.
        """
        path = get_data_path(path)
        # Pointless f-string prefixes removed (no placeholders).
        train_path = os.path.join(
            path, 'conditional_generation/composition_material/dev/data.json')
        test_path = os.path.join(
            path,
            'conditional_generation/composition_material/test/data.json')

        # load from local json file
        with open(train_path, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        # Keep only 5 dev samples as in-context examples.
        train_data = train_data[:5]
        # Guard on len: random.sample raises ValueError when the
        # population is smaller than the requested sample size.
        if mini_set and len(test_data) > 150:
            import random
            random.seed(1024)  # fixed subsample for comparable mini runs
            test_data = random.sample(test_data, 150)
            random.seed()  # restore unseeded RNG for other callers

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


@TEXT_POSTPROCESSORS.register_module()
def material_postprocessor(text: Union[str, None]) -> str:
    """Extract the content of the first <material>...</material> tag.

    Returns an empty string when *text* is falsy or no tag is present.
    """
    if not text:
        return ''
    found = re.search(r'<material>(.*?)</material>', text,
                      re.DOTALL | re.IGNORECASE)
    return found.group(1).strip() if found else ''


class composition_Evaluator(BaseEvaluator):
    """Evaluator for composition-conditioned material generation.

    Scores SMAct validity, element-composition precision against the
    prompt, and novelty of the generated material relative to the
    test-set ground truths.
    """

    def __init__(self, data_path, tuning_data=None, **kwargs):
        """Resolve the test-data path and preload prompts and GT materials.

        Args:
            data_path: Dataset root; the composition-material test JSON
                is read from a fixed sub-path under it.
            tuning_data: Unused.  NOTE(review): accepted only for
                config compatibility — confirm before removing.
        """
        super().__init__()
        self.data_path = os.path.join(
            get_data_path(data_path),
            'conditional_generation/composition_material/test/data.json')
        self.prompts = []
        self.gt_materials = set()

        if self.data_path:
            self._load_original_inputs()

    def _load_original_inputs(self):
        """Load prompts and the set of ground-truth materials from the
        test JSON (the set backs the novelty check)."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.prompts = [item.get('input', '') for item in data]

        for item in data:
            output = item.get('output', '')
            mat = material_postprocessor(output)
            if mat:
                self.gt_materials.add(mat.strip())

    def _normalize(self, formula):
        """Normalise a formula: sort element symbols alphabetically, keep counts."""
        # NOTE(review): appears unused within this class — confirm
        # before removing.
        tokens = re.findall(r'([A-Z][a-z]?)(\d*)', formula)
        tokens.sort(key=lambda x: x[0])
        return ''.join(f"{el}{cnt or ''}" for el, cnt in tokens)

    def score(self, predictions):
        """Score predictions; returns counts plus percentage metrics."""
        from smact.screening import smact_validity

        total = len(predictions)
        format_valid = 0
        smact_valid = 0
        precision_sum = 0.0
        novel_count = 0

        for i, item in enumerate(predictions):
            # Predictions may arrive as bare strings; normalise to dict.
            if isinstance(item, str):
                item = {'prediction': item}

            text = item.get('prediction', '').strip()
            prompt = item.get('input', '').strip()
            # Fall back to the preloaded prompt when absent.
            if not prompt and i < len(self.prompts):
                prompt = self.prompts[i]

            prompt_elements = extract_elements_from_prompt(prompt)

            print('== Sample ==')
            print('Prompt:', prompt)
            print('Prompt Elements:', prompt_elements)
            print('Prediction Text:', text[:200])

            # --- SMAct validity ---
            # Expected shape: space-separated elements followed by a
            # space-group tag like "<sg123>".
            match = re.match(
                r'([A-Z][a-z]?(?: [A-Z][a-z]?)*?)\s*(?:<sg>\s*)?<sg(\d+)>',
                text)
            if match:
                elements_str, _ = match.groups()
                elements = elements_str.split()
                counter = Counter(elements)
                formula = ''.join(f"{el}{cnt or ''}"
                                  for el, cnt in sorted(counter.items()))
                try:
                    if smact_validity(formula):
                        smact_valid += 1
                    format_valid += 1
                except Exception:
                    pass

            # --- Composition precision ---
            if prompt_elements:
                precision_sum += composition_precision(prompt_elements, text)

            # --- Novelty ---
            predicted_material = material_postprocessor(text)
            if not predicted_material:
                predicted_material = text.strip()

            if predicted_material:
                print(f'[Novelty Check] GT materials: {self.gt_materials}')
                print(
                    f'[Novelty Check] Predicted material: {predicted_material}'
                )
                if predicted_material not in self.gt_materials:
                    novel_count += 1
                    print('[Novelty] Novel')
                else:
                    print('[Novelty] Seen before')

        avg_precision = (precision_sum / total * 100) if total else 0.0
        smact_in_format = (smact_valid / format_valid *
                           100) if format_valid else 0.0
        smact_in_all = (smact_valid / total * 100) if total else 0.0
        novelty_ratio = (novel_count / total * 100) if total else 0.0

        return {
            'total_samples': total,
            'format_valid_count': format_valid,
            'smact_valid_count': smact_valid,
            'smact_validity_ratio_in_format_valid_%': smact_in_format,
            'smact_validity_ratio_in_all_%': smact_in_all,
            'average_precision_%': avg_precision,
            'novel_material_ratio_%': novelty_ratio,
        }

+ 7
- 0
opencompass/datasets/SciReasoner/opi/__init__.py View File

@@ -0,0 +1,7 @@
# flake8: noqa
from .config import TASKS as opi_TASKS # noqa: F401, F403
from .config import \
TASKS_GENERATION_SETTINGS as \
opi_TASKS_GENERATION_SETTINGS # noqa: F401, F403
from .evaluator import opi_postprocess # noqa: F401, F403
from .evaluator import Opi_Evaluator, OpiDataset

+ 136
- 0
opencompass/datasets/SciReasoner/opi/config.py View File

@@ -0,0 +1,136 @@
# Names of all OPI (Open Protein Instructions) evaluation tasks.
TASKS = (
    'EC_number',
    'Subcellular_localization',
    'Fold_type',
    'Keywords',
    'GO',
    'Function',
    'gSymbol2Tissue',
    'gSymbol2Cancer',
    'gName2Cancer',
)

DEFAULT_MAX_INPUT_TOKENS = 512
DEFAULT_MAX_NEW_TOKENS = 400

# Every task currently shares the same sampling configuration, so build the
# mapping programmatically instead of repeating the same literal nine times.
# Each task receives its own copy of the kwargs dict so that later per-task
# tweaks cannot accidentally leak into other tasks.
_DEFAULT_GENERATION_KARGS = {
    'num_return_sequences': 1,
    'num_beams': 1,
    'temperature': 0.2,
    'top_k': 50,
    'top_p': 0.75,
    'do_sample': True,
}

TASKS_GENERATION_SETTINGS = {
    task: {
        'generation_kargs': dict(_DEFAULT_GENERATION_KARGS),
    }
    for task in TASKS
}

# (left_tag, right_tag) pair wrapping each task's answer in the model output.
TASK_TAGS = {
    'EC_number': ('<EC_NUMBER>', '</EC_NUMBER>'),
    'Subcellular_localization': ('<LOCATION>', '</LOCATION>'),
    'Fold_type': ('<FOLD>', '</FOLD>'),
    'Keywords': ('<KEYWORDS>', '</KEYWORDS>'),
    'GO': ('<GO_TERMS>', '</GO_TERMS>'),
    'Function': ('<FUNCTION>', '</FUNCTION>'),
    'gSymbol2Tissue': ('<TISSUE>', '</TISSUE>'),
    'gSymbol2Cancer': ('<CANCER>', '</CANCER>'),
    'gName2Cancer': ('<CANCER>', '</CANCER>'),
}

# These tasks output SMILES, where there may be semicolons
# that separate different parts.
# To facilitate evaluation, each semicolon is replaced by a dot.
TASKS_WITH_SEMICOLON_REPLACE = ('Keywords', 'GO')

# For these tasks, one input might have multiple gold answers,
# so the gold answer should be directly obtained from the dataset
# instead of directly using the gold domain of each sample.
TASKS_WITH_READING_GOLD_FROM_DATASET = TASKS

BASE_MODELS = {
    'osunlp/LlaSMol-Mistral-7B': 'mistralai/Mistral-7B-v0.1',
    'osunlp/LlaSMol-Galactica-6.7B': 'facebook/galactica-6.7b',
    'osunlp/LlaSMol-Llama2-7B': 'meta-llama/Llama-2-7b-hf',
    'osunlp/LlaSMol-CodeLlama-7B': 'codellama/CodeLlama-7b-hf',
}

+ 289
- 0
opencompass/datasets/SciReasoner/opi/evaluator.py View File

@@ -0,0 +1,289 @@
# flake8: noqa
# opencompass/datasets/opi/evaluator.py

import json
import os
import re

from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path

from .utils.metrics4all import calculate_metrics, calculate_rouge_l


@LOAD_DATASET.register_module()
class OpiDataset(BaseDataset):
    """Loader for the OPI (Open Protein Instructions) benchmark tasks."""

    @staticmethod
    def load(path, task, max_cut=-1, mini_set=False, hf_hub=False):
        """Load the train/test splits for one OPI task.

        Args:
            path (str): Dataset root; resolved through ``get_data_path``.
            task (str): Task sub-directory name (see ``config.TASKS``).
            max_cut (int): Kept for interface compatibility; unused here.
            mini_set (bool): If True, score on a fixed random subset of 50
                test samples (seeded, reproducible across runs).
            hf_hub (bool): Kept for interface compatibility; unused here.

        Returns:
            DatasetDict: ``{'train': ..., 'test': ...}`` splits.
        """
        path = get_data_path(path)
        train_path = os.path.join(path, f'{task}/dev/data.json')
        test_path = os.path.join(path, f'{task}/test/data.json')

        with open(train_path, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        if mini_set:
            # Use a private RNG instance so we do not reseed (and thereby
            # perturb) the process-wide ``random`` state; Random(1024)
            # yields exactly the same sample as the old seed/reseed dance.
            import random
            test_data = random.Random(1024).sample(test_data, 50)

        return DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data),
        })


def extract_answer_part(outputs, left_tag, right_tag, mode='tag'):
    """Extract the answer substring between ``left_tag`` and ``right_tag``.

    Args:
        outputs (list[str]): Raw model outputs.
        left_tag (str | None): Opening tag, e.g. ``'<EC_NUMBER>'``.
        right_tag (str | None): Closing tag, e.g. ``'</EC_NUMBER>'``.
        mode (str): ``'tag'`` extracts between the tags; ``'direct'`` only
            strips special tokens and whitespace.

    Returns:
        list[str]: One extracted answer per output ('' when tags missing).
    """
    assert mode in ('tag', 'direct')
    assert isinstance(outputs, list)

    answers = []
    for text in outputs:
        if mode == 'direct' or (left_tag is None and right_tag is None):
            text = text.replace('<unk>', '').replace('</s>', '').strip()
            answers.append(text)
            continue

        start = text.find(left_tag)
        if start == -1:
            answers.append('')
            continue
        # BUG FIX: search for the closing tag *after* the opening tag.
        # Previously a stray closing tag earlier in the text made the
        # slice empty (or reversed) and the answer was silently lost.
        end = text.find(right_tag, start + len(left_tag))
        if end == -1:
            answers.append('')
            continue
        answers.append(text[start + len(left_tag):end].strip())
    return answers


@TEXT_POSTPROCESSORS.register_module('opi_postprocess')
def opi_postprocess(text, task, *args, **kwargs):
    """Strip whitespace and chat special tokens from a model prediction.

    Args:
        text (str): Raw model output.
        task (str): Task name; unused, kept for the postprocessor API.

    Returns:
        str: The cleaned prediction text.
    """
    # The tokens are literal strings, so plain ``str.replace`` is clearer
    # and cheaper than regex substitution. The per-sample debug print has
    # been removed: it flooded stdout during evaluation.
    text = text.strip()
    text = text.replace('<|endoftext|>', '')
    text = text.replace('<|im_end|>', '')
    return text


class Opi_Evaluator(BaseEvaluator):
    """Evaluator for OPI (Open Protein Instructions) tasks.

    Dispatches to a task-specific metric:
      * ``Function``                 -> ROUGE-L
      * ``Subcellular_localization`` -> accuracy (via ``calculate_metrics``)
      * ``Fold_type``                -> exact-match accuracy
      * ``EC_number`` / ``GO`` / ``Keywords`` / ``gSymbol2Tissue`` /
        ``gSymbol2Cancer`` / ``gName2Cancer``
                                     -> micro precision / recall / F1
      * anything else                -> exact-match accuracy
    """

    def __init__(self, task, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Task name; selects which metric ``score`` computes.
        self.task = task

    def score(self, predictions, references):
        """Route to the metric appropriate for ``self.task``.

        Both arguments are normalised to ``list[list[str]]``; only the
        first candidate of each sample is scored by the helpers below.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        # Normalise flat lists to one-candidate-per-sample nested lists.
        if not isinstance(predictions[0], list):
            predictions = [[pred] for pred in predictions]
        if not isinstance(references[0], list):
            references = [[ref] for ref in references]

        if self.task == 'Function':
            return self._evaluate_function(predictions, references)
        elif self.task == 'Subcellular_localization':
            return self._evaluate_subcellular_localization(
                predictions, references)
        elif self.task == 'Fold_type':
            return self._evaluate_fold_type(predictions, references)
        elif self.task in ('EC_number', 'GO', 'Keywords', 'gSymbol2Tissue',
                           'gSymbol2Cancer', 'gName2Cancer'):
            return self._evaluate_multilabel(predictions, references)
        else:
            return self._evaluate_general(predictions, references)

    def _evaluate_function(self, predictions, references):
        """Evaluate the function-description task with ROUGE-L.

        (Docstring translated from Chinese.)
        """
        rouge_ls = []
        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].strip()
            ref = ref_list[0].strip()

            # ``calculate_rouge_l`` expects token lists; wrap raw strings.
            if isinstance(pred, str):
                pred = [pred]
            if isinstance(ref, str):
                ref = [ref]

            rouge_l = calculate_rouge_l(pred, ref)
            rouge_ls.append(rouge_l)

        mean_rouge_l = sum(rouge_ls) / len(rouge_ls) if rouge_ls else 0
        return {
            'ROUGE-L': round(mean_rouge_l, 4),
        }

    def _evaluate_subcellular_localization(self, predictions, references):
        """Evaluate subcellular localization with accuracy.

        (Docstring translated from Chinese.)
        """
        accuracies = []
        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].strip()
            ref = ref_list[0].strip()

            # ``calculate_metrics`` expects label lists; wrap raw strings.
            if isinstance(pred, str):
                pred = [pred]
            if isinstance(ref, str):
                ref = [ref]

            accuracy, _, _, _ = calculate_metrics(pred, ref)
            accuracies.append(accuracy)

        mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
        return {
            'Accuracy': round(mean_accuracy, 4),
        }

    def _evaluate_fold_type(self, predictions, references):
        """Evaluate fold type with exact-match accuracy, mirroring
        ``accuracy4fold_type.py``.

        (Docstring translated from Chinese.)
        """
        correct_predictions = 0
        total_predictions = 0

        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].strip()
            ref = ref_list[0].strip()

            # Direct string comparison of prediction and reference.
            if pred == ref:
                correct_predictions += 1
            total_predictions += 1

        accuracy = correct_predictions / total_predictions \
            if total_predictions > 0 else 0

        return {
            'Accuracy': round(accuracy, 4),
        }

    def _evaluate_multilabel(self, predictions, references):
        """Evaluate multi-label tasks (EC_number, GO, Keywords, ...) with
        micro precision / recall / F1.

        (Docstring translated from Chinese.)
        """
        precisions = []
        recalls = []
        f1_scores = []

        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].strip()
            ref = ref_list[0].strip()

            # Split on ASCII/CJK semicolons and commas, dropping empties.
            if isinstance(pred, str):
                pred = [
                    p.strip() for p in re.split(r'[;,,;]\s*', pred)
                    if p.strip()
                ]
            if isinstance(ref, str):
                ref = [
                    r.strip() for r in re.split(r'[;,,;]\s*', ref)
                    if r.strip()
                ]

            # Only score samples whose reference label set is non-empty.
            if ref:
                _, precision, recall, f1 = calculate_metrics(pred, ref)
                precisions.append(precision)
                recalls.append(recall)
                f1_scores.append(f1)

        mean_precision = sum(precisions) / len(precisions) if precisions else 0
        mean_recall = sum(recalls) / len(recalls) if recalls else 0
        mean_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0

        return {
            'Precision': round(mean_precision, 4),
            'Recall': round(mean_recall, 4),
            'F1 Score': round(mean_f1, 4),
        }

    def _evaluate_text_similarity(self, predictions, references):
        """Loose text-similarity fallback (used when ROUGE is unavailable).

        (Docstring translated from Chinese.)
        """
        correct = 0
        total = len(predictions)

        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].lower().strip()
            ref = ref_list[0].lower().strip()

            # Count equality or substring containment in either direction.
            if pred == ref or pred in ref or ref in pred:
                correct += 1

        accuracy = correct / total if total > 0 else 0
        return {
            'Text_Similarity': round(accuracy, 4),
        }

    def _evaluate_general(self, predictions, references):
        """Generic case-insensitive exact-match accuracy.

        (Docstring translated from Chinese.)
        """
        correct = 0
        total = len(predictions)

        for pred_list, ref_list in zip(predictions, references):
            pred = pred_list[0].lower().strip()
            ref = ref_list[0].lower().strip()

            if pred == ref:
                correct += 1

        accuracy = correct / total if total > 0 else 0
        return {
            'Accuracy': round(accuracy, 4),
        }

+ 62
- 0
opencompass/datasets/SciReasoner/opi/process_ec_numbers.py View File

@@ -0,0 +1,62 @@
import json
import re
from typing import Any


def add_spaces_to_ec_number(text: str) -> str:
    """Insert spaces around the dots of every EC number in ``text``.

    ``2.7.10.2`` becomes ``2 . 7 . 10 . 2``.
    (Docstring translated from Chinese.)
    """
    # EC number format: four dot-separated integer fields.
    pattern = r'\b(\d+)\.(\d+)\.(\d+)\.(\d+)\b'

    def replace_ec(match):
        # BUG FIX: the replacement used to return a *tuple* of two
        # f-strings (a stray comma inside the parentheses), which makes
        # re.sub raise TypeError. The repl callable must return one str.
        return (f'{match.group(1)} . {match.group(2)} .'
                f' {match.group(3)} . {match.group(4)}')

    return re.sub(pattern, replace_ec, text)


def process_json_value(value: Any) -> Any:
    """Recursively walk a decoded JSON value, applying
    ``add_spaces_to_ec_number`` to every string leaf.

    Dicts and lists are rebuilt with processed children; any other type
    is returned unchanged. (Docstring translated from Chinese.)
    """
    if isinstance(value, str):
        return add_spaces_to_ec_number(value)
    if isinstance(value, dict):
        return {key: process_json_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return [process_json_value(element) for element in value]
    return value


def process_ec_json_file(input_file: str, output_file: str) -> None:
    """Read a JSON file, reformat every EC number to the spaced form, and
    write the result to ``output_file``.

    (Docstring translated from Chinese. User-facing messages remain in
    Chinese by design.)
    """
    try:
        # Load the input JSON.
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Recursively process every string leaf.
        processed_data = process_json_value(data)

        # Write the transformed data to a new file.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, ensure_ascii=False, indent=2)

        print(f'处理完成!已保存到 {output_file}')

    except Exception as e:
        # NOTE(review): broad catch turns any failure (missing file, bad
        # JSON, ...) into a printed message; consider re-raising.
        print(f'处理文件时出错: {e}')


if __name__ == '__main__':
    # CLI entry point: reformat EC numbers in the CoT training file.
    input_file = \
        'cot_data/EC_number_train_CLEAN_EC_number_train_train.json_final.json'
    output_file = \
        'cot_data/EC_number_train_CLEAN_EC_number_train_train_spaced.json'

    process_ec_json_file(input_file, output_file)

+ 45
- 0
opencompass/datasets/SciReasoner/opi/utils/accuracy4fold_type.py View File

@@ -0,0 +1,45 @@
import json
import os

import tqdm


def load_json(file_path):
    """Read ``file_path`` and return the decoded JSON object."""
    with open(file_path, 'r') as handle:
        return json.load(handle)


def compute_accuracy4fold_type(eval_file, test_files):
    """Compute fold-type accuracy of ``eval_file`` per holdout split.

    For each file in ``test_files``, only evaluation items whose input
    sequence appears in that split are scored (exact match of the
    prediction against ``target``).

    Returns:
        dict: split name -> accuracy rounded to 4 decimals.
    """
    # Load evaluation data
    eval_data = load_json(eval_file)
    acc_dict = {}
    # Iterate over each test file
    for test_file in test_files:
        # Load test data
        test_data = load_json(test_file)

        # Create a set of test sequences
        test_seq_set = {item['primary'] for item in test_data}

        # Initialize counters
        correct_predictions = 0
        total_predictions = 0

        # Evaluate each item in the evaluation data
        for item in tqdm.tqdm(eval_data):
            # Skip items not belonging to this holdout split.
            if item['input'] not in test_seq_set:
                continue
            # Predictions may be stored under 'output' or 'predict'.
            predict = item.get('output', item.get('predict', []))
            label = item['target']
            if predict == label:
                correct_predictions += 1
            total_predictions += 1

        # Calculate and print accuracy
        accuracy = correct_predictions / total_predictions \
            if total_predictions > 0 else 0
        # [21:] strips a fixed filename prefix (presumably
        # 'remote_homology_test_'), leaving e.g. 'fold_holdout' as the
        # key -- TODO(review): confirm this holds for any new test files.
        acc_dict[os.path.basename(test_file).split('.')[0][21:]] = round(
            accuracy, 4)
    return acc_dict

+ 143
- 0
opencompass/datasets/SciReasoner/opi/utils/metrics4all.py View File

@@ -0,0 +1,143 @@
import argparse
import json
import os

import tqdm
from rouge_score import rouge_scorer
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
recall_score)
from sklearn.preprocessing import MultiLabelBinarizer

from .accuracy4fold_type import compute_accuracy4fold_type


def calculate_metrics(output, target):
    """Accuracy / precision / recall / F1 for two label lists.

    Labels are binarised over the union of both lists, and the
    precision/recall/F1 are micro-averaged over this single
    multi-label sample.
    """
    label_space = sorted(set(output + target))
    binarizer = MultiLabelBinarizer(classes=label_space)
    y_true = binarizer.fit_transform([target])
    y_pred = binarizer.transform([output])

    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='micro', zero_division=0),
        recall_score(y_true, y_pred, average='micro', zero_division=0),
        f1_score(y_true, y_pred, average='micro', zero_division=0),
    )


def calculate_rouge_l(output, target):
    """ROUGE-L F-measure between whitespace-joined token lists."""
    reference = ' '.join(target)
    candidate = ' '.join(output)
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    return scorer.score(reference, candidate)['rougeL'].fmeasure


def process_json_file(json_file_path):
    """Score one result JSON file, choosing the metric from its filename.

    * names containing 'function' ...... ROUGE-L
    * names containing 'subcell_loc' ... accuracy
    * everything else .................. micro precision / recall / F1

    Returns:
        tuple: ``(metrics_dict, None)`` -- the second slot is unused.
    """
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []

    with open(json_file_path, 'r') as file:
        data = json.load(file)

    for entry in tqdm.tqdm(data):
        # Predictions may be stored under 'output' or 'predict'.
        output = entry.get('output', entry.get('predict', []))
        target = entry.get('target', [])

        # Ensure both output and target are lists: multi-label tasks are
        # '; '-separated strings, free-text tasks become 1-element lists.
        if isinstance(output, str):
            if any(keyword in json_file_path for keyword in
                   ['EC_number', 'go_terms', 'keywords', 'gene', 'domain']):
                output = output.split('; ')
            elif any(keyword in json_file_path
                     for keyword in ['function', 'subcell_loc', 'ss']):
                output = [output]
        if isinstance(target, str):
            if any(keyword in json_file_path for keyword in
                   ['EC_number', 'go_terms', 'keywords', 'gene', 'domain']):
                target = target.split('; ')
            elif any(keyword in json_file_path
                     for keyword in ['function', 'subcell_loc', 'ss']):
                target = [target]

        if 'function' in json_file_path:
            rouge_l = calculate_rouge_l(output, target)
            rouge_ls.append(rouge_l)
        elif 'subcell_loc' in json_file_path:
            accuracy, _, _, _ = calculate_metrics(output, target)
            accuracies.append(accuracy)
        else:
            _, precision, recall, f1 = calculate_metrics(output, target)
            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)

    if 'function' in json_file_path:
        mean_rouge_l = sum(rouge_ls) / len(rouge_ls) if rouge_ls else 0
        return {'ROUGE-L': round(mean_rouge_l, 4)}, None
    elif 'subcell_loc' in json_file_path:
        mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
        return {'Accuracy': round(mean_accuracy, 4)}, None
    else:
        mean_precision = sum(precisions) / len(precisions) if precisions else 0
        mean_recall = sum(recalls) / len(recalls) if recalls else 0
        mean_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
        return {
            'Precision': round(mean_precision, 4),
            'Recall': round(mean_recall, 4),
            'F1 Score': round(mean_f1, 4)
        }, None


def main(eval_res_path):
    """Score every result JSON in ``eval_res_path`` and write a
    ``metrics_result.json`` summary into the same directory.

    The metric for each file is selected from its filename
    (function / subcell / fold_type / multi-label).
    """
    results = {}

    # List all JSON files in the directory
    for file_name in sorted(os.listdir(eval_res_path)):
        if file_name.endswith('.json') and 'metrics_result' not in file_name:
            print(f'Processing {file_name}')
            file_path = os.path.join(eval_res_path, file_name)
            if 'function' in file_path:
                metrics, _ = process_json_file(file_path)
                results[file_name] = {'ROUGE-L': metrics['ROUGE-L']}
            elif 'subcell' in file_path:
                metrics, _ = process_json_file(file_path)
                results[file_name] = {'Accuracy': metrics['Accuracy']}
            elif 'fold_type' in file_path:
                # Fold-type results are split by holdout level; paths are
                # hard-coded relative to the working directory.
                test_files = [
                    'compute_scores/remote_homology_test_fold_holdout.json',
                    ('compute_scores/'
                     'remote_homology_test_superfamily_holdout.json'),
                    'compute_scores/remote_homology_test_family_holdout.json'
                ]
                acc_dict = compute_accuracy4fold_type(file_path, test_files)
                results[file_name] = acc_dict
            else:
                metrics, _ = process_json_file(file_path)
                results[file_name] = {
                    'Precision': metrics['Precision'],
                    'Recall': metrics['Recall'],
                    'F1 Score': metrics['F1 Score']
                }
            print(results[file_name])
    with open(f'{eval_res_path}/metrics_result.json', 'w') as result_file:
        json.dump(results, result_file, indent=4)

    print(f'Results saved to: {eval_res_path}/metrics_result.json')


if __name__ == '__main__':
    # CLI entry point: score all result files inside --indir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--indir',
                        required=True,
                        help='Path to the result file dir')
    args = parser.parse_args()

    main(args.indir)

+ 128
- 0
opencompass/datasets/SciReasoner/uncond_RNA.py View File

@@ -0,0 +1,128 @@
import os
import re
import subprocess
from tempfile import TemporaryDirectory
from typing import Union

from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


@LOAD_DATASET.register_module()
class Uncond_RNA_Dataset(BaseDataset):
    """Prompt-only dataset for unconditional RNA generation: the same
    ``prompt`` is repeated ``num`` times with empty references."""

    @staticmethod
    def load(num, prompt):
        rows = [{'input': prompt, 'output': ''} for _ in range(num)]
        return Dataset.from_list(rows)


@TEXT_POSTPROCESSORS.register_module()
def RNA_postprocessor(text: Union[str, None]) -> str:
    """Extract the RNA sequence enclosed in ``<rna>...</rna>`` tags.

    DNA letters T/t are first transliterated to U/u. Returns '' when the
    input is empty/None or no tag pair is found.
    """
    if not text:
        return ''

    normalized = text.replace('T', 'U').replace('t', 'u')

    found = re.search(r'<rna>(.*?)</rna>', normalized,
                      re.DOTALL | re.IGNORECASE)
    return found.group(1).strip() if found else ''


class RNA_Evaluator(BaseEvaluator):
    """Evaluator for unconditionally generated RNA sequences.

    Validity is checked against the AUCG alphabet. For the valid
    sequences, minimum free energy (MFE) is computed with the external
    ``RNAfold`` binary, and Rfam family coverage with Infernal's
    ``cmscan`` (both must be on PATH, with Rfam data under ``Rfam/``).
    """

    def score(self, predictions, references):
        """Score generated RNA sequences.

        Args:
            predictions (list[str]): Post-processed RNA sequences.
            references: Unused; kept for the evaluator interface.

        Returns:
            dict: Sequence counts, average MFE, and the number of
            distinct Rfam families retrieved.
        """
        invalid_count = 0
        overlength_count = 0
        valid_rnas = []
        valid_bases = set('AUCG')
        avg_mfe = None
        rfam_families = []

        for idx, seq in enumerate(predictions):
            seq = seq.strip().upper()
            if not seq or any(base not in valid_bases for base in seq):
                invalid_count += 1
            else:
                valid_rnas.append((f'seq{idx}', seq))
            # Overlength is tracked for every prediction, valid or not.
            if len(seq) > 1024:
                overlength_count += 1

        # BUG FIX: a leftover debug assignment (``tmpdir = 'tmp'``) used
        # to shadow the TemporaryDirectory here, writing into a
        # hard-coded './tmp' (which may not exist) and defeating cleanup.
        with TemporaryDirectory() as tmpdir:
            fasta_path = os.path.join(tmpdir, 'valid_sequences.fasta')
            with open(fasta_path, 'w') as f:
                for seq_id, seq in valid_rnas:
                    f.write(f'>{seq_id}\n{seq}\n')

            mfe_file = self.run_rnafold(fasta_path, tmpdir)
            mfe_values = self.parse_mfe(mfe_file)
            avg_mfe = sum(mfe_values) / len(mfe_values) if mfe_values else None

            # Rfam data paths are relative to the working directory.
            rfam_cm = 'Rfam/Rfam.cm'
            rfam_clanin = 'Rfam/Rfam.clanin'
            rfam_tblout = self.run_cmscan(fasta_path, tmpdir, rfam_cm,
                                          rfam_clanin)
            rfam_families = self.parse_unique_families(rfam_tblout)

        return {
            'total_samples': len(predictions),
            'invalid_prediction_count': invalid_count,
            'overlength_prediction_count': overlength_count,
            'valid_sequence_count': len(valid_rnas),
            'average_mfe': avg_mfe,
            'retrieved_rfam_family_count': len(rfam_families),
        }

    def run_rnafold(self, input_fasta, output_dir):
        """Run ``RNAfold`` on ``input_fasta``; return the output path."""
        output_file = os.path.join(output_dir, 'mfe_results.txt')
        # shell=True is required for the 'cd && <redirect>' pipeline; the
        # interpolated paths are locally generated, not untrusted input.
        cmd = (
            f'cd {output_dir} && RNAfold < '
            f'{os.path.abspath(input_fasta)} > {os.path.basename(output_file)}'
        )
        ret = subprocess.run(cmd, shell=True)
        if ret.returncode != 0:
            print(ret)
            raise RuntimeError('RNAfold execution failed!')
        return output_file

    def parse_mfe(self, output_file):
        """Parse the trailing '(energy)' value from each RNAfold line."""
        mfe_values = []
        with open(output_file) as f:
            for line in f:
                match = re.search(r'\s\(([-\d\.]+)\)\s*$', line.strip())
                if match:
                    mfe_values.append(float(match.group(1)))
        return mfe_values

    def run_cmscan(self, fasta_file, output_dir, rfam_cm, rfam_clanin):
        """Run Infernal ``cmscan`` against Rfam; return the tblout path."""
        tblout_path = os.path.join(output_dir, 'cmscan_results.tblout')
        cmscan_cmd = [
            'cmscan', '--rfam', '--cut_ga', '--nohmmonly', '--tblout',
            tblout_path, '--fmt', '2', '--clanin', rfam_clanin, rfam_cm,
            fasta_file
        ]
        result = subprocess.run(cmscan_cmd,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise RuntimeError(f'cmscan failed:\n{result.stderr.decode()}')
        return tblout_path

    def parse_unique_families(self, tblout_file):
        """Collect distinct family IDs (first column) from a tblout file."""
        families = set()
        with open(tblout_file, 'r') as f:
            for line in f:
                if line.startswith('#'):
                    continue
                parts = line.strip().split()
                if parts:
                    families.add(parts[0])
        return families

+ 80
- 0
opencompass/datasets/SciReasoner/uncond_material.py View File

@@ -0,0 +1,80 @@
import re
from typing import Union

from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


@LOAD_DATASET.register_module()
class Uncond_material_Dataset(BaseDataset):
    """Prompt-only dataset for unconditional material generation: the
    same ``prompt`` is repeated ``num`` times with empty references."""

    @staticmethod
    def load(num, prompt):
        rows = [{'input': prompt, 'output': ''} for _ in range(num)]
        return Dataset.from_list(rows)


@TEXT_POSTPROCESSORS.register_module()
def material_postprocessor(text: Union[str, None]) -> str:
    """Extract the content of the first ``<material>...</material>`` pair.

    Returns '' for empty/None input or when no tag pair is present.
    """
    if not text:
        return ''

    found = re.search(r'<material>(.*?)</material>', text,
                      re.DOTALL | re.IGNORECASE)
    return found.group(1).strip() if found else ''


class uncond_material_Evaluator(BaseEvaluator):
    """Scores unconditionally generated crystal compositions with SMACT.

    A prediction is format-valid when it matches the
    'El El ... <sg> <sgNNN>' pattern; format-valid compositions are then
    screened with ``smact_validity``.
    """

    def score(self, predictions):
        """Return SMACT validity statistics over ``predictions``."""
        from collections import Counter

        from smact.screening import smact_validity

        total = len(predictions)
        format_valid = 0
        smact_valid = 0

        # Element list followed by the space-group tags; '<'/'⟨' variants
        # of the tag brackets are both accepted.
        pattern = re.compile(
            r'([A-Z][a-z]?(?: [A-Z][a-z]?)*?)'
            r'\s*(?:<|⟨)sg(?:>|⟩)\s*(?:<|⟨)sg(\d+)(?:>|⟩)')

        for raw in predictions:
            parsed = pattern.match(raw.strip())
            if parsed is None:
                continue

            element_tokens = parsed.group(1).split()
            parts = []
            for symbol, count in sorted(Counter(element_tokens).items()):
                parts.append(symbol if count == 1 else f'{symbol}{count}')
            formula = ''.join(parts)

            try:
                if smact_validity(formula):
                    smact_valid += 1
                format_valid += 1
            except Exception:
                continue

        ratio_in_valid = smact_valid / format_valid if format_valid else 0
        ratio_in_all = smact_valid / total if total else 0

        return {
            'total_samples':
            total,
            'format_valid_count':
            format_valid,
            'smact_valid_count':
            smact_valid,
            'smact_validity_ratio_in_format_valid':
            ratio_in_valid * 100,
            'smact_validity_ratio_in_all':
            ratio_in_all * 100,
        }

+ 123
- 0
opencompass/datasets/SciReasoner/unconditional_molecule_generation/UMG.py View File

@@ -0,0 +1,123 @@
import re

from datasets import Dataset, DatasetDict

try:
from rdkit import Chem
except Exception:
Chem = None

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET


@LOAD_DATASET.register_module()
class UMG_Dataset(BaseDataset):
    """Synthetic dataset for unconditional molecule (SMILES) generation."""

    @staticmethod
    def load(max_cut=-1):
        """Build few-shot train examples plus up to 800 identical, empty
        test prompts (optionally truncated to ``max_cut``)."""
        gen_inst = 'Generate a molecule with <SMILES> '

        output_samples = [
            '<SMILES>CN1C=NC2=C1C(=O)N(C)C(=O)N2C</SMILES>',
            '<SMILES>c1ccccc1C(=O)O</SMILES>', '<SMILES>CCO</SMILES>',
            '<SMILES>CC(=O)Oc1ccccc1C(=O)O</SMILES>', '<SMILES>CCO</SMILES>'
        ]

        train_split = [{
            'input': gen_inst,
            'output': sample,
        } for sample in output_samples]

        num_test = 800 if max_cut == -1 else min(800, max_cut)
        test_split = [{
            'input': gen_inst,
            'output': ''
        } for _ in range(num_test)]

        return DatasetDict({
            'train': Dataset.from_list(train_split),
            'test': Dataset.from_list(test_split)
        })


class UMG_Evaluator(BaseEvaluator):
    """Evaluator for unconditional molecule generation (RDKit-based).

    Metrics:
        validity:   fraction of generations RDKit parses as SMILES.
        uniqueness: fraction of distinct strings among the valid SMILES.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def is_valid_smiles_rdkit(self, s):
        """Return True if ``s`` parses as a SMILES string under RDKit.

        Strings containing '<' or '>' are rejected so that text already
        wrapped in <SMILES>-style tags is not processed a second time.
        (Docstring translated from Chinese.)
        """
        if not isinstance(s, str) or not s:
            return False
        if '<' in s or '>' in s:
            return False
        # sanitize=False: accept syntactically parseable SMILES even if
        # they are chemically invalid.
        mol = Chem.MolFromSmiles(s, sanitize=False)
        return mol is not None

    def extract_smiles_simple(self, text: str) -> str | None:
        """Extract a SMILES payload from raw model output.

        If the output lacks explicit <SMILES> tags, any sufficiently long
        token that RDKit accepts as SMILES is first wrapped in tags; the
        first tagged payload is then returned. Falls back to returning
        ``text`` unchanged when nothing can be extracted.
        """
        if '<SMILES>' not in text:
            generic_pat = re.compile(r'(?<!\*)([A-Za-z0-9\u2080-\u2089'
                                     r'\(\)\.\+\-\=\#\$\:\@\*/%\\]{2,})(?!\*)')

            def generic_replace(m):
                candidate = m.group(1)
                # Only tag reasonably long, RDKit-parseable candidates.
                if len(candidate) >= 4 and self.is_valid_smiles_rdkit(
                        candidate):
                    return f'<SMILES> {candidate} </SMILES>'
                return candidate

            text = generic_pat.sub(generic_replace, text)
        match = re.search(r'<SMILES> ([A-Za-z0-9()=#+@\\/\.-]+) </SMILES>',
                          text)
        if match:
            return match.group(1)
        return text

    def score(self, predictions):
        """Compute validity and uniqueness over ``predictions``."""
        if Chem is None:
            # rdkit is imported at module level inside try/except; fail
            # with a clear message instead of an AttributeError below.
            raise RuntimeError(
                'RDKit is required by UMG_Evaluator but is not installed.')
        if not predictions:
            return {'validity': 0.0, 'uniqueness': 0.0, 'valid_smiles': []}

        valid_smiles = []
        for smiles in predictions:
            # Guard against None / non-string entries.
            if not smiles or not isinstance(smiles, str):
                continue
            smiles = self.extract_smiles_simple(smiles)
            # Core check: full (sanitizing) RDKit parse of the candidate.
            if Chem.MolFromSmiles(smiles.strip()) is not None:
                valid_smiles.append(smiles)

        total_generated = len(predictions)
        total_valid = len(valid_smiles)

        # Validity = valid SMILES / total generations.
        validity = (float(total_valid) /
                    float(total_generated)) if total_generated > 0 else 0.0

        # Uniqueness = distinct valid SMILES / valid SMILES.
        if total_valid > 0:
            uniqueness = float(len(set(valid_smiles))) / float(total_valid)
        else:
            uniqueness = 0.0

        # BUG FIX: this key was previously misspelled 'uniquness',
        # inconsistent with the empty-input branch above; both branches
        # now report 'uniqueness'. Debug prints removed.
        return {'validity': validity, 'uniqueness': uniqueness}

+ 1
- 0
opencompass/datasets/SciReasoner/unconditional_molecule_generation/__init__.py View File

@@ -0,0 +1 @@
from .UMG import UMG_Dataset, UMG_Evaluator # noqa: F401, F403

+ 195
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/UPG.py View File

@@ -0,0 +1,195 @@
import re

from datasets import Dataset, DatasetDict

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS


@LOAD_DATASET.register_module()
class UPGDataset(BaseDataset):
    """Prompt-only dataset for unconditional protein generation."""

    @staticmethod
    def load(tag_bool=True, max_cut=-1):
        """Build few-shot train examples and up to 1000 empty test prompts.

        Args:
            tag_bool (bool): If True, the instruction asks the model to
                wrap the sequence in ``<protein></protein>`` tags.
            max_cut (int): Optional cap on the number of test prompts;
                -1 keeps the full 1000.

        Returns:
            DatasetDict: 'train' (5 exemplars) and 'test' splits.
        """
        if tag_bool:
            gen_inst = 'Generate a protein sequence with <protein> </protein>.'
        else:
            gen_inst = 'Generate a protein sequence.'
        # Few-shot exemplars: tagged protein sequences used as the train
        # split for in-context demonstration.
        output_samples = [
            '<protein>MGDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLHGLFGRKTGQAPGYSYTAANKNK'
            'GIIWGEDTLMEYLENPKKYIPGTKMIFVGIKKKEERADLIAYLKKATNE</protein>',
            '<protein>MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKL'
            'PVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFE'
            'GDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIED'
            'GSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAG'
            'ITLGMDELYK</protein>',
            '<protein>MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGY'
            'NTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRD'
            'PQGIRAWVAWRNRCQNRDVRQYVQGCGV</protein>',
            '<protein>MLEVKERIAQAKAEIPAPVELAPEEIERLLWKLGWRPVAYGSEEKARELDELYGHP'
            'FAQEHPKEGAAGPVLAAARGGLEEYGAVEWGWGLGEREWAAAGRVAADVVRRLDGEAREGTLPA'
            'EAEAFPALAAALEHHHHHH</protein>',
            '<protein>MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLCGSHLVEALYLVCGERGFFYTPKTRR'
            'EAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN</protein>',
        ]

        train_data = [{
            'input': gen_inst,
            'output': output,
        } for output in output_samples]

        # Test split: the same empty-reference prompt repeated.
        len_test_data = 1000

        if (max_cut != -1):
            len_test_data = min(len_test_data, max_cut)

        test_data = [{
            'input': gen_inst,
            'output': ''
        } for i in range(len_test_data)]

        dataset = DatasetDict({
            'train': Dataset.from_list(train_data),
            'test': Dataset.from_list(test_data)
        })
        return dataset


@TEXT_POSTPROCESSORS.register_module('UPG_postprocess')
def UPG_postprocess(text):
    """Extract and sanitise a protein sequence from model output.

    Takes the LAST ``<protein>...</protein>`` span, drops semicolons and
    spaces, upper-cases it, and keeps only canonical amino-acid letters
    (plus '-'). Returns '' for non-string input or when no span exists.
    """
    # Non-string input (e.g. None) yields an empty sequence.
    if not isinstance(text, str):
        return ''

    # '.' matches newlines too, so multi-line sequences are captured.
    spans = re.findall(r'<protein>(.*?)</protein>', text, re.DOTALL)
    if not spans:
        return ''

    candidate = spans[-1].strip().replace(';', '').replace(' ', '')

    # Keep only valid upper-case amino-acid characters (and gaps).
    allowed = set('ACDEFGHIKLMNPQRSTVWY-')
    return ''.join(ch for ch in candidate.upper() if ch in allowed)


class UPG_Evaluator(BaseEvaluator):
    """Evaluator for unconditionally generated protein sequences.

    Reports the valid-sequence rate, average length, a greedy-clustering
    diversity score, and (for sequences shorter than 100 residues) an
    average pLDDT computed with a project-local OmegaFold entry point.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calculate_sequence_identity(self, seq1, seq2):
        """
        Calculate sequence identity between two sequences.
        This is a simplified implementation for sequences of equal length,
        computed by direct position-wise comparison.
        More accurate methods may require alignment algorithms
        (e.g., Smith-Waterman).
        """
        if len(seq1) != len(seq2) or not seq1:
            # For unequal-length or empty sequences, treat identity as 0
            # or adopt a more complex alignment strategy if needed.
            # Here we return 0 for simplicity.
            return 0
        matches = sum(1 for a, b in zip(seq1, seq2) if a == b)
        return matches / len(seq1)

    def score(self, predictions, references=None):
        """
        Evaluate the generated protein sequences.

        Args:
            predictions (list[str]): List of model-generated protein sequences.
            references (list[str], optional):
                Reference sequences; ignored here.

        Returns:
            dict: Dictionary containing evaluation metrics.
        """
        if not predictions:
            return {
                'average_length': 0,
                'diversity': 0,
                'average_plddt': 0,
                'info': 'Input predictions list is empty.'
            }

        ori_len = len(predictions)

        # Drop empty / whitespace-only predictions before scoring.
        predictions = [pred for pred in predictions if len(pred) > 0]
        predictions = [
            pred for pred in predictions if not (pred.strip() == '')
        ]
        valid_rate = len(predictions) / ori_len

        # --- 1. Compute Average Length ---
        total_length = sum(len(seq) for seq in predictions)
        avg_length = total_length / len(predictions)

        # --- 2. Compute Diversity ---
        # Use a greedy clustering algorithm with 99%
        # sequence identity threshold
        clusters_representatives = []
        for seq in predictions:
            is_in_existing_cluster = False
            for representative in clusters_representatives:
                # Note: This uses simplified equal-length identity calculation.
                # For sequences of different lengths,
                # use sequence alignment tools.
                # As a simple strategy, compare only if lengths are close.
                if abs(
                        len(seq) - len(representative)
                ) < 20:  # Only compare sequences with small length differences
                    if self._calculate_sequence_identity(
                            seq,
                            representative) >= 0.99:  # 99% sequence identity
                        is_in_existing_cluster = True
                        break
            if not is_in_existing_cluster:
                clusters_representatives.append(seq)

        num_clusters = len(clusters_representatives)
        diversity = num_clusters / len(predictions)

        # --- 3. Compute Average pLDDT ---
        # Only compute for sequences shorter than 100 residues
        plddt_scores = []
        sequences_for_plddt = [
            seq for seq in predictions if (len(seq) < 100 and len(seq) > 0)
        ]

        for s in sequences_for_plddt:
            print(s)

        if sequences_for_plddt:
            # NOTE(review): project-local OmegaFold entry point; assumed to
            # return one pLDDT score per input sequence -- TODO confirm.
            from .omegafold.__main__ import main as plddt_main
            plddt_scores = plddt_main(sequences_for_plddt)
            avg_plddt = sum(plddt_scores) / len(plddt_scores)
        else:
            avg_plddt = 0.0  # If no sequences shorter than 100, set to 0

        return {
            'num_length_less_100': len(sequences_for_plddt),
            'valid_rate': round(valid_rate, 4),
            'average_length': round(avg_length, 2),
            'diversity': round(diversity, 4),
            'average_plddt': round(avg_plddt, 2)
        }

+ 1
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/__init__.py View File

@@ -0,0 +1 @@
from .UPG import UPG_Evaluator, UPG_postprocess, UPGDataset # noqa: F401, F403

+ 71
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/main.py View File

@@ -0,0 +1,71 @@
from omegafold.__main__ import main

if __name__ == '__main__':
    # Smoke-test driver: fold a fixed batch of sequences and print the
    # per-chain pLDDT list returned by ``main``.
    # The batch mixes many copies of a tiny 'MSS' peptide, a short
    # 'MKTIIL' fragment, and repeats of one long chain whose literal is
    # split across adjacent strings (implicit concatenation) at varying
    # points -- the concatenated sequence is the same each time.
    protein_list = [
        'MSS',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQ'
        'CQDFTRSAAKTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDF'
        'GGYVGSIYQVWGRGTVGIVG',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MSS',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFT'
        'RSAAKTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAA'
        'KTGTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKT'
        'GTATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTG'
        'TATVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA'
        'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA'
        'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTA'
        'TVGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
        'MKTIIL',
        'MKTIIALSYIFCLVFADYKDDDDKIVGGYTCAEDEKGTYTLVGDEKPYNGTQCQDFTRSAAKTGTAT'
        'VGVNQVRDGIVVGIVSWGSIAGSSENRIVGPLGILGDFGGYVGSIYQVWGRGTVGIVG',
    ]
    print(main(protein_list))

+ 38
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__init__.py View File

@@ -0,0 +1,38 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
# =============================================================================
# Imports
# =============================================================================
from .config import make_config # noqa: F401, F403
from .model import OmegaFold # noqa: F401, F403

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 104
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/__main__.py View File

@@ -0,0 +1,104 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
The main function to run the prediction
"""
# =============================================================================
# Imports
# =============================================================================
import gc
import logging
import sys
import time

import torch

from . import OmegaFold, make_config, pipeline

# =============================================================================
# Functions
# =============================================================================


@torch.no_grad()
def main(protein_list):
    """Fold each sequence in ``protein_list`` with OmegaFold and collect
    the mean confidence (pLDDT-style, scaled by 100) of each chain.

    Args:
        protein_list (list[str]): amino-acid sequences to predict.

    Returns:
        list[float]: one mean-confidence value per successfully predicted
        chain; chains whose forward pass raised an exception are skipped
        and contribute no entry.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    # args/state_dict/forward_config come from the pipeline's CLI-style
    # argument parsing; presumably includes device and weight path --
    # see pipeline.get_args.
    args, state_dict, forward_config = pipeline.get_args()
    # create the output directory
    # os.makedirs(args.output_dir, exist_ok=True)
    # get the model
    logging.info('Constructing OmegaFold')
    model = OmegaFold(make_config(args.model))
    if state_dict is None:
        logging.warning('Inferencing without loading weight')
    else:
        # Checkpoints may nest the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict.pop('model')
        model.load_state_dict(state_dict)
    model.eval()
    model.to(args.device)

    # logging.info(f"Reading {args.input_file}")

    pLDDT_list = []

    for i, (input_data, save_path) in enumerate(
            pipeline.list2inputs(
                protein_list,
                num_pseudo_msa=args.num_pseudo_msa,
                output_dir='./',
                device=args.device,
                mask_rate=args.pseudo_msa_mask_rate,
                num_cycle=args.num_cycle,
            )):
        logging.info(f'Predicting {i + 1}th chain')
        logging.info(
            f"{len(input_data[0]['p_msa'][0])} residues in this chain.")
        ts = time.time()
        try:
            output = model(input_data,
                           predict_with_confidence=True,
                           fwd_cfg=forward_config)
        except Exception as e:
            logging.error(
                f'Failed to generate {save_path} due to an exception: {e}',
                exc_info=True)
            logging.info('Skipping...')
            # The failed chain produced no output, so skip it; it simply
            # contributes no pLDDT entry to the result list.
            continue
        logging.info(f'Finished prediction in {time.time() - ts:.2f} seconds.')

        # logging.info(f"Saving prediction to {save_path}")

        # print(output['confidence'] * 100)

        # Mean per-residue confidence, scaled to the conventional 0-100
        # pLDDT range.
        pLDDT_list.append(output['confidence'].mean().item() * 100)

        # pipeline.save_pdb(
        #     pos14=output["final_atom_positions"],
        #     b_factors=output["confidence"] * 100,
        #     sequence=input_data[0]["p_msa"][0],
        #     mask=input_data[0]["p_msa_mask"][0],
        #     save_path=save_path,
        #     model=0
        # )
        # logging.info(f"Saved")
        # Free GPU memory between chains.
        del output
        torch.cuda.empty_cache()
        gc.collect()
    logging.info('Done!')

    return pLDDT_list

+ 152
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/confidence.py View File

@@ -0,0 +1,152 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
Code for confidence-relevant things
"""

# =============================================================================
# Imports
# =============================================================================
import argparse

import torch
from torch import nn

from . import modules, utils

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================


def get_all_confidence(
    lddt_per_residue: torch.Tensor,
    ca_coordinates: torch.Tensor,
    ca_mask: torch.Tensor,
    cutoff=15.,
) -> float:
    """
    Aggregate per-residue lDDT scores into one confidence value for the
    whole chain, weighting each residue by the number of neighbour pairs
    it participates in (within ``cutoff``).

    lDDT reference:
    Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
    superposition-free score for comparing protein structures and models
    using distance difference tests. Bioinformatics 29, 2722-2728 (2013).

    Adapted from
    https://github.com/deepmind/alphafold/blob/1109480e6f38d71b3b265a4a25039e51e2343368/alphafold/model/lddt.py#L19

    Args:
        lddt_per_residue: per-residue lddt scores, of shape [num_res]
        ca_coordinates: C-alpha coordinates, of shape [num_res, 3]
        ca_mask: mask of the C-alpha atoms, of shape [num_res]
        cutoff: distance cutoff for a residue pair to be included

    Returns:
        The overall confidence for the entire prediction

    """

    assert ca_coordinates.ndim == 2
    assert lddt_per_residue.ndim == 1

    # Pairwise C-alpha distances; the epsilon keeps sqrt well-behaved on
    # the (zero-distance) diagonal.
    deltas = ca_coordinates[:, None] - ca_coordinates[None, :]
    dmat_true = torch.sqrt(torch.sum(deltas**2, dim=-1).add(1e-10))

    # A pair contributes when it is within cutoff, both residues are
    # present, and it is not a self-pair.
    not_self = 1. - torch.eye(dmat_true.shape[1], device=ca_mask.device)
    dists_to_score = (torch.lt(dmat_true, cutoff) * ca_mask[..., :, None] *
                      ca_mask[..., None, :] * not_self)

    # Weight each residue's lDDT by its number of contributing pairs,
    # then normalize by the total pair count.
    per_residue_weight = torch.sum(dists_to_score, dim=(-1, )).add(1e-10)
    weighted_sum = (lddt_per_residue * per_residue_weight).sum(-1)
    total_weight = 1e-10 + torch.sum(dists_to_score, dim=(-1, -2))

    return (weighted_sum / total_weight).item()


def _compute_confidence(logits: torch.Tensor) -> torch.Tensor:
"""
Computes per-residue pLDDT from logits.

Code below adopted from
https://github.com/deepmind/alphafold/blob/0be2b30b98f0da7aecb973bde04758fae67eb913/alphafold/common/confidence.py#L22

Args:
logits: the logits into the softmax, of shape [num_res, num_bins]

Returns:
predicted_lddt_ca: the predicted CA lddt, of shape [num_res]

"""
num_bins = logits.shape[-1]
bin_width = 1.0 / num_bins
bin_centers = torch.arange(start=0.5 * bin_width,
end=1.0,
step=bin_width,
device=logits.device)
probs = torch.softmax(logits, dim=-1)
confidence = torch.mv(probs, bin_centers)
return confidence


# =============================================================================
# Classes
# =============================================================================


class ConfidenceHead(modules.OFModule):
    """
    The pLDDT head from AF2, providing a per-residue confidence measure
    of the model's prediction.

    """

    def __init__(self, cfg: argparse.Namespace):
        super(ConfidenceHead, self).__init__(cfg)
        # 3-layer MLP: node_dim -> hidden -> hidden -> num_bins, with a
        # ReLU after each hidden layer but not after the final projection.
        # Built layer-by-layer so the nn.Sequential child indices (and
        # hence state_dict keys) are 0..4, as in a hand-written stack.
        widths = [cfg.node_dim, cfg.hidden_dim, cfg.hidden_dim, cfg.num_bins]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        del layers[-1]  # drop the activation after the last Linear
        self.network = nn.Sequential(*layers)

    def forward(self, node_repr: torch.Tensor) -> torch.Tensor:
        """Map node representations to per-residue confidence values."""
        normed = utils.normalize(node_repr)
        bin_logits = self.network(normed)
        return _compute_confidence(bin_logits)


# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 118
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/config.py View File

@@ -0,0 +1,118 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
Static configuration reside in this file
"""
# =============================================================================
# Imports
# =============================================================================
import argparse


# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
def _make_config(input_dict: dict) -> argparse.Namespace:
"""Recursively go through dictionary"""
new_dict = {}
for k, v in input_dict.items():
if type(v) == dict:
new_dict[k] = _make_config(v)
else:
new_dict[k] = v
return argparse.Namespace(**new_dict)


def make_config(model_idx: int = 1) -> argparse.Namespace:
    """Build the static OmegaFold model configuration.

    Args:
        model_idx: which released model variant to configure (1 or 2);
            variant 2 additionally enables the structure embedder.

    Returns:
        A nested ``argparse.Namespace`` of model hyper-parameters.

    Raises:
        ValueError: if ``model_idx`` is neither 1 nor 2.
    """
    if model_idx not in (1, 2):
        raise ValueError('model_idx must be 1 or 2')

    # Pretrained language-model (PLM) sub-config.
    plm = dict(
        alphabet_size=23,
        node=1280,
        padding_idx=21,
        edge=66,
        proj_dim=1280 * 2,
        attn_dim=256,
        num_head=1,
        num_relpos=129,
        masked_ratio=0.12,
    )
    # Structure-module sub-config.
    struct = dict(
        node_dim=384,
        edge_dim=128,
        num_cycle=8,
        num_transition=3,
        num_head=12,
        num_point_qk=4,
        num_point_v=8,
        num_scalar_qk=16,
        num_scalar_v=16,
        num_channel=128,
        num_residual_block=2,
        hidden_dim=128,
        num_bins=50,
    )
    cfg = {
        'alphabet_size': 21,
        'plm': plm,
        'node_dim': 256,
        'edge_dim': 128,
        'relpos_len': 32,
        'prev_pos': dict(
            first_break=3.25,
            last_break=20.75,
            num_bins=16,
            ignore_index=0,
        ),
        'rough_dist_bin': dict(x_min=3.25, x_max=20.75, x_bins=16),
        'dist_bin': dict(x_bins=64, x_min=2, x_max=65),
        'pos_bin': dict(x_bins=64, x_min=-32, x_max=32),
        'c': 16,
        'geo_num_blocks': 50,
        'gating': True,
        'attn_c': 32,
        'attn_n_head': 8,
        'transition_multiplier': 4,
        'activation': 'ReLU',
        'opm_dim': 32,
        'geom_count': 2,
        'geom_c': 32,
        'geom_head': 4,
        'struct': struct,
        # Only the second released model uses the structure embedder.
        'struct_embedder': model_idx == 2,
    }
    return _make_config(cfg)


# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 371
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/decode.py View File

@@ -0,0 +1,371 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
For generating the final coordinates of the amino acids of the predicted
"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import math
import typing

import torch
from torch import nn

from . import modules, utils

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================


class InvariantPointAttention(modules.OFModule):
    """
    This is the Invariant Point Attention from Jumper et al. (2021) that
    performs transformer-like operation on frames

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(InvariantPointAttention, self).__init__(cfg)
        node_dim = cfg.node_dim
        edge_dim = cfg.edge_dim
        num_head = cfg.num_head
        num_scalar_qk = cfg.num_scalar_qk
        num_point_qk = cfg.num_point_qk
        num_scalar_v = cfg.num_scalar_v
        num_point_v = cfg.num_point_v

        # For scalar parts
        self.q_scalar = nn.Linear(node_dim, num_head * num_scalar_qk)
        self.k_scalar = nn.Linear(node_dim, num_head * num_scalar_qk)
        self.v_scalar = nn.Linear(node_dim, num_head * num_scalar_v)

        # to reason about the spatial relationships
        # (each point is an (x, y, z) triple, hence the factor of 3)
        self.q_point = nn.Linear(node_dim, num_head * 3 * num_point_qk)
        self.k_point = nn.Linear(node_dim, num_head * 3 * num_point_qk)
        self.v_point = nn.Linear(node_dim, num_head * 3 * num_point_v)

        # trainable weights for edge bias
        # initialized at softplus^-1(1) so the effective per-head point
        # weight starts at exactly 1 after the softplus in forward()
        self.trainable_point_weights = nn.Parameter(
            torch.full([cfg.num_head],
                       fill_value=math.log(math.exp(1.) - 1)), )
        self.bias_2d = nn.Linear(edge_dim, num_head)

        # per-head output features: edge values + scalar values + point
        # values (3 coordinates + 1 norm per point => num_point_v * 4)
        final_input_dim = edge_dim + num_scalar_v + num_point_v * 4
        final_input_dim *= num_head
        # output projection
        self.output_projection = nn.Linear(final_input_dim, node_dim)
        self.softplus = torch.nn.Softplus()

        # weighting of each component, chosen so each of the three logit
        # terms contributes comparable variance (as in AF2)
        num_logit_terms = 3
        scalar_variance = max(num_scalar_qk, 1) * 1.
        point_variance = max(num_point_qk, 1) * 9. / 2
        self.scalar_weight = math.sqrt(1 / (num_logit_terms * scalar_variance))
        self.point_weight = math.sqrt(1 / (num_logit_terms * point_variance))
        self.edge_logits_weight = math.sqrt(1 / num_logit_terms)

    def forward(self, node_repr: torch.Tensor, edge_repr: torch.Tensor,
                frames: utils.AAFrame) -> torch.Tensor:
        """
        From Jumper et al. (2021), Invariant Point Attention

        Args:
            node_repr: the node representation,
                of shape [num_res, dim_node]
            edge_repr: the edge representation,
                of shape [num_res, num_res, dim_edge]
            frames: the backbone frames of the amino acids,
                of shape [num_res]

        Returns:
            the node representation update of shape [num_res, dim_node]

        """
        n_head = self.cfg.num_head

        # acquire the scalar part of the attention logits
        _q_scalar = self._get_scalar(self.q_scalar, node_repr, n_head)
        _k_scalar = self._get_scalar(self.k_scalar, node_repr, n_head)
        _v_scalar = self._get_scalar(self.v_scalar, node_repr, n_head)
        scalar_logits = torch.einsum('qhc,khc->qkh', _q_scalar, _k_scalar)
        scalar_logits *= self.scalar_weight

        # acquire the 2-dimensional bias from the edge representation
        edge_logits = self.bias_2d(edge_repr) * self.edge_logits_weight

        # acquire the spatial part of the logits from the frames
        # (points are produced in local frames, then mapped to the global
        # frame so the attention is invariant to the global pose)
        _q_point = self._get_point(self.q_point, node_repr, n_head, frames)
        _k_point = self._get_point(self.k_point, node_repr, n_head, frames)
        _v_point = self._get_point(self.v_point, node_repr, n_head, frames)
        dist = ((_q_point[:, None, ...] - _k_point[None, ...])**2)
        point_logits = dist.sum([-1, -2]) * self.point_weight
        point_logits *= self.softplus(self.trainable_point_weights) / 2

        # Combine them and take the softmax
        # (point_logits is subtracted: larger distance -> lower attention)
        logits = scalar_logits + edge_logits - point_logits
        # masked-out residues get a large negative bias before the softmax
        logits += utils.mask2bias(frames.mask[None, ..., None])
        attn_w = modules.softmax(logits, dim=-2, in_place=True)

        # get the output
        ret_edge = torch.einsum('...qkh,...qkc->...qhc', attn_w, edge_repr)
        ret_scalar = torch.einsum('...qkh,...khc->...qhc', attn_w, _v_scalar)
        ret_point = torch.einsum('...qkh,...khpc->...qhpc', attn_w, _v_point)
        # map the attended points back into each residue's local frame
        ret_point = frames.position_in_frame(ret_point)

        feat = torch.cat([
            ret_scalar.flatten(start_dim=-2),
            ret_point.flatten(start_dim=-3),
            utils.get_norm(ret_point).flatten(start_dim=-2),
            ret_edge.flatten(start_dim=-2),
        ],
                         dim=-1)

        return self.output_projection(feat)

    @staticmethod
    def _get_scalar(linear: nn.Linear, inputs: torch.Tensor,
                    num_head: int) -> torch.Tensor:
        """
        Pass the input through linear and then perform reshaping for the
        multi-headed attention

        Args:
            linear: the linear module to pass the input into
            inputs: the input tensor to the linear module
            num_head: the number of heads

        Returns:
            key, query, or value for the multi-headed attention,
            [num_res, num_head, dim]

        """
        return linear(inputs).unflatten(dim=-1, sizes=[num_head, -1])

    @staticmethod
    def _get_point(linear: nn.Linear, inputs: torch.Tensor, n_head: int,
                   transformation: utils.AAFrame) -> torch.Tensor:
        """
        Pass the input through the linear and perform reshaping for the
        multi-headed attention, then transform the points by the transformation

        Args:
            linear: the linear module to compute the local points
            inputs: the inputs into the linear module, of shape
            n_head: the number of head
            transformation: the transformation to make local global

        Returns:
            points in global frame, [num_res, n_head, -1, 3]

        """
        local_points = linear(inputs).unflatten(dim=-1, sizes=[n_head, -1, 3])
        global_points = transformation.transform(local_points)
        return global_points


class TorsionAngleHead(modules.OFModule):
    """
    Predict the torsion angles of each of the amino acids from
    node representation following Jumper et al. (2021)
    """

    def __init__(self, cfg: argparse.Namespace):
        super(TorsionAngleHead, self).__init__(cfg)

        # One projection per input representation fed to ``forward``.
        self.input_projection = nn.ModuleList(
            [nn.Linear(cfg.node_dim, cfg.num_channel) for _ in range(2)])

        # Paired linear layers forming pre-activation residual blocks.
        self.resblock1 = nn.ModuleList([
            nn.Linear(cfg.num_channel, cfg.num_channel)
            for _ in range(cfg.num_residual_block)
        ])
        self.resblock2 = nn.ModuleList([
            nn.Linear(cfg.num_channel, cfg.num_channel)
            for _ in range(cfg.num_residual_block)
        ])

        # 7 torsion angles per residue, each encoded as a (sin, cos) pair.
        self.unnormalized_angles = nn.Linear(cfg.num_channel, 14)

    def forward(
        self, representations_list: typing.Sequence[torch.Tensor]
    ) -> torch.Tensor:
        """
        Predict side chains using multi-rigid representations.

        Args:
            representations_list: A list of activations to
                predict side chains from.
        Returns:
            The normalized sin-cos representation of the torsion angles
        """
        # Sum the projected (ReLU'd) representations into one activation.
        act = sum(
            proj(torch.relu(rep))
            for rep, proj in zip(representations_list, self.input_projection))

        for block_a, block_b in zip(self.resblock1, self.resblock2):
            residual = act
            act = block_a(torch.relu(act))
            act = block_b(torch.relu(act))
            act = residual + act

        raw_angles = self.unnormalized_angles(torch.relu(act))

        # Reshape to [..., 7, 2] and normalize each (sin, cos) pair onto
        # the unit circle.
        raw_angles = raw_angles.unflatten(dim=-1, sizes=[7, 2])
        return utils.robust_normalize(raw_angles)


class StructureCycle(modules.OFModule):
    """
    Each of the cycles from
    Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(StructureCycle, self).__init__(cfg)
        self.ipa = InvariantPointAttention(cfg)
        self.input_norm = nn.LayerNorm(cfg.node_dim)
        self.transition = nn.ModuleList([
            nn.Linear(cfg.node_dim, cfg.node_dim)
            for _ in range(cfg.num_transition)
        ])
        self.update_norm = nn.LayerNorm(cfg.node_dim)

        # 6 values per residue, decoded by AAFrame.from_tensor --
        # presumably 3 for rotation and 3 for translation; confirm in
        # utils.AAFrame.
        self.affine_update = nn.Linear(cfg.node_dim, 6)

    def forward(
        self, node_repr: torch.Tensor, edge_repr: torch.Tensor,
        backbone_frames: utils.AAFrame
    ) -> typing.Tuple[torch.Tensor, utils.AAFrame]:
        """
        Perform one backbone update and node representation update

        Args:
            node_repr: the node representation,
                of shape [num_res, dim_node]
            edge_repr: the edge representation,
                of shape [num_res, dim_edge]
            backbone_frames: the backbone frames of the amino acids,
                of shape [num_res]

        Returns:
            the updated node representation and the updated backbone
            frames after one IPA + transition + affine-update step

        """
        # NOTE: ``+=`` adds the IPA update to node_repr in place.
        node_repr += self.ipa(node_repr, edge_repr, backbone_frames)
        node_repr = self.input_norm(node_repr)
        # Transition: a small MLP with ReLU between (but not after) layers
        input_repr = node_repr
        for layer in self.transition:
            node_repr = layer(node_repr)
            if layer is not self.transition[-1]:
                node_repr = torch.relu(node_repr)

        node_repr += input_repr  # Shortcut residual connection
        node_repr = self.update_norm(node_repr)
        # Decode a rigid-frame update and compose it onto the current
        # backbone frames.
        backbone_update = self.affine_update(node_repr)
        frame_update = utils.AAFrame.from_tensor(backbone_update, unit='nano')
        backbone_frames = backbone_frames * frame_update

        return node_repr, backbone_frames


class StructureModule(modules.OFModule):
    """Jumper et al. (2021) Suppl. Alg. 20 'StructureModule'"""

    def __init__(self, cfg: argparse.Namespace):
        super(StructureModule, self).__init__(cfg)
        self.node_norm = nn.LayerNorm(cfg.node_dim)
        self.edge_norm = nn.LayerNorm(cfg.edge_dim)
        self.init_proj = nn.Linear(cfg.node_dim, cfg.node_dim)

        self.cycles = nn.ModuleList(
            [StructureCycle(cfg) for _ in range(cfg.num_cycle)])

        self.torsion_angle_pred = TorsionAngleHead(cfg)

    def forward(
        self, node_repr: torch.Tensor, edge_repr: torch.Tensor,
        fasta: torch.Tensor, mask: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, typing.Dict[str, typing.Union[
            utils.AAFrame, torch.Tensor]]]:
        """
        Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

        Args:
            node_repr: node representation tensor of shape [num_res, dim_node]
            edge_repr: edge representation tensor of shape [num_res, dim_edge]
            fasta: the tokenized sequence of the input protein sequence
            mask: residue mask used to initialize the backbone frames,
                of shape [num_res]

        Returns:
            node_repr: The current node representation tensor for confidence
                of shape [num_res, dim_node]
            dictionary containing:
                final_frames: the final expanded amino-acid frames
                final_atom_positions: the final atom14 positions,
                    of shape [num_res, 14, 3]
                final_atom_mask: the final atom14 mask,
                    of shape [num_res, 14]

        """
        node_repr = self.node_norm(node_repr)
        edge_repr = self.edge_norm(edge_repr)

        # Keep the pre-projection representation for torsion prediction.
        init_node_repr = node_repr
        node_repr = self.init_proj(node_repr)
        # Initialize the initial frames with Black-hole Jumper et al. (2021)
        backbone_frames = utils.AAFrame.default_init(*node_repr.shape[0:1],
                                                     unit='nano',
                                                     device=self.device,
                                                     mask=mask.bool())

        # Iteratively refine node representation and backbone frames.
        for layer in self.cycles:
            node_repr, backbone_frames = layer(node_repr, edge_repr,
                                               backbone_frames)

        torsion_angles_sin_cos = self.torsion_angle_pred(
            representations_list=[node_repr, init_node_repr], )

        torsion_angles_mask = torch.ones_like(torsion_angles_sin_cos[..., 0],
                                              dtype=torch.bool)
        # Convert from nanometers to angstroms, then expand the backbone
        # frames with the predicted torsions into per-atom positions.
        backbone_frames = backbone_frames.to_angstrom(in_place=False)
        frames8 = backbone_frames.expand_w_torsion(
            torsion_angles=torsion_angles_sin_cos,
            torsion_angles_mask=torsion_angles_mask,
            fasta=fasta)
        pos14, mask14 = frames8.expanded_to_pos(fasta)
        return node_repr, {
            'final_frames': frames8,
            'final_atom_positions': pos14,
            'final_atom_mask': mask14
        }


# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 384
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/embedders.py View File

@@ -0,0 +1,384 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import typing

import torch
from torch import nn

from . import modules, utils
from .utils import residue_constants as rc


# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
def _get_pos(shape: torch.Size, device: torch.device, dtype: torch.dtype,
seq_dim: typing.Tuple[int, ...]) -> torch.Tensor:
"""Get the position of the tokens given

Args:
shape: the shape of the tensor to be applied with RoPE
device: the device on which the tensor reside
dtype: the datatype of the tensor
seq_dim: dimensions of the tensor that reference the sequence length

Returns:
The position tensor of the shape from ~shape indexed by seq_dim

"""
spatial_shape = [shape[i] for i in seq_dim]
total_len = 1
for i in spatial_shape:
total_len *= i
position = torch.arange(total_len, dtype=dtype, device=device)
position = position.reshape(*spatial_shape)

return position


def _apply_embed(inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor,
seq_dim: typing.Tuple[int, ...]) -> torch.Tensor:
"""Applies RoPE to ~inputs

Args:
inputs: the tensor to which RoPE is applied, the dimensions indexed by
~seq_dim indicates the spatial dimensions
sin: the sine tensor that constitutes parts of the RoPE,
of spatial shape + vector dimension
cos: the cosine tensor that constitutes parts of the RoPE,
of spatial shape + vector dimension
seq_dim: the dimensions indicating the spatial dimensions,
must be consecutive

Returns:
tensor with RoPE applied.

"""
gaps = [(seq_dim[i + 1] - seq_dim[i]) == 1
for i in range(len(seq_dim) - 1)]
if len(gaps) > 0:
if not all(gaps):
raise ValueError(f'seq_dim must be consecutive, but got {seq_dim}')

# Align dimensions of sine and cosine
seq_dim = sorted(seq_dim)
end = seq_dim[-1]
for _ in range(seq_dim[0]):
sin = sin.unsqueeze(0)
cos = cos.unsqueeze(0)
end += 1

for _ in range(end, inputs.ndim - 1):
sin = sin.unsqueeze(_)
cos = cos.unsqueeze(_)

# Apply RoPE
x1, x2 = torch.split(inputs, inputs.shape[-1] // 2, dim=-1)
return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


# =============================================================================
# Classes
# =============================================================================
class EdgeEmbedder(modules.OFModule):
    """
    Embed the input into node and edge representations

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(EdgeEmbedder, self).__init__(cfg)

        # Separate per-residue embeddings for the i- and j-ends of a pair.
        self.proj_i = nn.Embedding(cfg.alphabet_size, cfg.edge_dim)
        self.proj_j = nn.Embedding(cfg.alphabet_size, cfg.edge_dim)
        # Relative-position embedding over 2 * relpos_len + 1 offsets.
        self.relpos = RelPosEmbedder(cfg.relpos_len * 2 + 1, cfg.edge_dim)

    def forward(self, fasta_sequence: torch.Tensor,
                out: torch.Tensor) -> torch.Tensor:
        """Accumulate sequence and positional edge features into ``out``.

        Args:
            fasta_sequence: tokenized sequence, of shape [num_res]
            out: edge representation buffer -- NOTE: modified in place
                and also returned

        Returns:
            ``out`` with the i/j sequence embeddings (broadcast along the
            opposite axis) and relative-position embedding added.
        """
        out += self.proj_i(fasta_sequence).unsqueeze(-2)
        out += self.proj_j(fasta_sequence).unsqueeze(-3)
        out += self.relpos(fasta_sequence.size(-1))

        return out


class RoPE(nn.Module):
    """Rotary positional embedding (RoPE) module.

    Attributes:
        input_dim: the dimension of the input vectors.

    """

    def __init__(self, input_dim: int) -> None:
        super(RoPE, self).__init__()
        if input_dim % 2:
            raise ValueError(
                f'Input dimension for RoPE must be a multiple of 2,'
                f' but got {input_dim}')
        self.input_dim = input_dim
        self.half_size = input_dim // 2
        # inv_freq[k] = 10000 ** (-k / half_size)
        exponent = -torch.arange(self.half_size, dtype=torch.float32)
        exponent = exponent.div(float(self.half_size))

        self.register_buffer('inv_freq',
                             torch.pow(10000., exponent),
                             persistent=False)

    def forward(self, tensor: torch.Tensor,
                seq_dim: typing.Union[int, tuple]) -> torch.Tensor:
        """Apply RoPE to ``tensor`` along the given spatial dimension(s).

        Args:
            tensor: the tensor to apply rope onto
            seq_dim: the dimension(s) that represent the sequence dimension

        Returns:
            ``tensor`` with the rotary embedding applied.

        """
        dims = [seq_dim] if isinstance(seq_dim, int) else seq_dim
        sin, cos = self._compute_sin_cos(tensor, dims)

        return _apply_embed(tensor, sin, cos, dims)

    def _compute_sin_cos(
        self, tensor: torch.Tensor, seq_dim: typing.Tuple[int]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Compute the sine and cosine tensors for ``tensor``'s positions.

        Args:
            tensor: the tensors to apply RoPE to
            seq_dim: the dimension indices of the spatial dimensions

        Returns:
            A (sine, cosine) pair of tensors of spatial shape + vector dim.

        """
        position = _get_pos(tensor.shape, tensor.device, tensor.dtype, seq_dim)
        angles = torch.einsum('..., d->...d', position, self.inv_freq)
        return torch.sin(angles), torch.cos(angles)


class RelPosEmbedder(nn.Embedding):
    """Relative positional embedding.

    Same algorithm as Jumper et al. (2021) Suppl. Alg. 4 "relpos": embed the
    clamped signed offset between every pair of residue indices.
    """

    def forward(self, num_res: int) -> torch.Tensor:
        """Embed all pairwise relative positions for ``num_res`` residues.

        Args:
            num_res: number of residues in input sequence.

        Returns:
            A tensor of shape [num_res, num_res, embedding_dim].

        """
        device = next(self.parameters()).device
        positions = torch.arange(num_res, device=device)
        half = self.num_embeddings // 2
        # offsets[i, j] = j - i, clamped to [-half, half], shifted to >= 0.
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
        offsets = torch.clamp(offsets, -half, half) + half
        return super(RelPosEmbedder, self).forward(offsets)


class StructEmbedder(modules.OFModule):
    """
    Encoder for pair wise atom distance without distance clamp
    but a sublinear-function with ord encoder.

    Embeds pairwise residue geometry — atom-atom distances at two
    resolutions plus frame-local atom coordinates — into the edge
    dimension, gated by an embedding of the ordered amino-acid pair.
    """

    def __init__(self, cfg: argparse.Namespace):
        super(StructEmbedder, self).__init__(cfg)
        # Soft (continuous) binning modules for coarse distances, fine
        # distances and frame-local coordinates.
        self.rough_dist_bin = modules.Val2ContBins(cfg.rough_dist_bin)
        self.dist_bin = modules.Val2ContBins(cfg.dist_bin)
        self.pos_bin = modules.Val2ContBins(cfg.pos_bin)

        # One embedding per ordered amino-acid pair (21 * 21 tokens).
        self.aa_embedding = nn.Embedding(21 * 21, embedding_dim=cfg.c)

        # 8 rigid frames and 14 heavy atoms per residue (atom14 layout).
        frame_num = 8
        atom_num = 14

        self.dist_bin_embedding = nn.Linear(cfg.dist_bin.x_bins, cfg.c)
        self.rough_dist_bin_embedding = nn.Linear(cfg.rough_dist_bin.x_bins,
                                                  cfg.c)

        self.dist_bin_linear = nn.Linear(atom_num * atom_num * cfg.c, cfg.c)
        self.rough_dist_bin_linear = nn.Linear(atom_num * atom_num * cfg.c,
                                               cfg.c)

        self.pos_bin_embedding = nn.Linear(cfg.pos_bin.x_bins, cfg.c)
        self.pos_linear = nn.Linear(frame_num * atom_num * 3 * cfg.c, cfg.c)

        # Final bilinear projection into the edge dimension.
        self.linear_z_weights = nn.Parameter(
            torch.zeros([cfg.c, cfg.c, cfg.edge_dim]))
        self.linear_z_bias = nn.Parameter(torch.zeros([cfg.edge_dim]))

    def forward(
        self,
        fasta1: torch.Tensor,
        fasta2: torch.Tensor,
        pos14_a: torch.Tensor,
        mask14_a: torch.Tensor,
        pos14_b: torch.Tensor,
        mask14_b: torch.Tensor,
        frame8: utils.AAFrame,
    ):
        """Compute the pairwise structural embedding.

        Args:
            fasta1: tokenized sequence indexing the first axis of the map.
            fasta2: tokenized sequence indexing the second axis of the map.
            pos14_a: atom14 coordinates for side a.
            mask14_a: atom14 validity mask for side a.
            pos14_b: atom14 coordinates for side b.
            mask14_b: atom14 validity mask for side b.
            frame8: rigid frames of side a used to express local positions.

        Returns:
            The pairwise embedding of shape [num_res, num_res, edge_dim].
        """
        # Ordered pair token: fasta1 selects the row, fasta2 the column.
        pairwise_fasta = fasta1.unsqueeze(-1) * 21 + fasta2.unsqueeze(-2)
        # All-against-all atom distances between the two sides.
        d = torch.norm(pos14_b[None, :, None] - pos14_a[:, None, :, None],
                       p=2,
                       dim=-1,
                       keepdim=False)
        d_mask = mask14_b[None, :, None] * mask14_a[:, None, :, None]
        d_mask = d_mask.unsqueeze(-1)
        local_mask = torch.mul(mask14_b[None, :, None], frame8.mask[:, None, :,
                                                                    None])
        local_mask = local_mask.unsqueeze(-1)

        # Side-b atom coordinates expressed in side-a local frames.
        local_vec = frame8.unsqueeze(1).unsqueeze(-1).position_in_frame(
            pos14_b[None, :, None, :])

        return self._sharded_compute(pairwise_fasta, d, local_vec, d_mask,
                                     local_mask)

    def _sharded_compute(self, pairwise_fasta: torch.Tensor, d: torch.Tensor,
                         local_vec: torch.Tensor, d_mask: torch.Tensor,
                         local_mask: torch.Tensor) -> torch.Tensor:
        # Soft-bin each geometric signal, embed it, mask invalid atoms, then
        # flatten the per-atom axes into one feature vector per residue pair.
        pairwise_fasta = self.aa_embedding(pairwise_fasta)
        d1 = self.rough_dist_bin(d)
        d2 = self.dist_bin(d)
        d3 = self.pos_bin(local_vec)

        d1 = self.rough_dist_bin_embedding(d1)
        d1 = d1 * d_mask
        d1 = self.rough_dist_bin_linear(d1.flatten(start_dim=-3))

        d2 = self.dist_bin_embedding(d2)
        d2 = d2 * d_mask
        d2 = self.dist_bin_linear(d2.flatten(start_dim=-3))

        d3 = self.pos_bin_embedding(d3)
        d3 = d3 * (local_mask.unsqueeze(-1))
        d3 = self.pos_linear(d3.flatten(start_dim=-4))

        final_d = d1 + d2 + d3  # + d4
        # Outer product with the amino-acid-pair embedding, then a bilinear
        # map into the edge dimension.
        O_ = torch.einsum('...sdi,...sdj->...sdij', pairwise_fasta, final_d)
        Z = torch.einsum('...sdij,ijh->...sdh', O_,
                         self.linear_z_weights) + self.linear_z_bias
        return Z


class PairStructEmbedder(StructEmbedder):
    """StructEmbedder specialization that embeds a structure against itself."""

    def forward(
        self,
        fasta: torch.Tensor,
        pos14: torch.Tensor,
        pos14_mask: torch.Tensor,
        frame8: utils.AAFrame,
    ):
        """Reuse StructEmbedder.forward with identical a/b side inputs."""
        return super(PairStructEmbedder, self).forward(
            fasta1=fasta,
            fasta2=fasta,
            pos14_a=pos14,
            pos14_b=pos14,
            mask14_a=pos14_mask,
            mask14_b=pos14_mask,
            frame8=frame8,
        )


class RecycleEmbedder(modules.OFModule):
    """
    The recycle embedder from Jumper et al. (2021)

    Feeds the previous iteration's representations and predicted geometry
    back into the current iteration's node and edge representations.
    """

    def __init__(self, cfg: argparse.Namespace):
        super(RecycleEmbedder, self).__init__(cfg)

        self.layernorm_node = nn.LayerNorm(cfg.node_dim)
        self.layernorm_edge = nn.LayerNorm(cfg.edge_dim)
        # Discretizer + embedding for the previous pseudo-beta distogram.
        self.dgram = modules.Val2Bins(cfg.prev_pos)
        self.prev_pos_embed = nn.Embedding(
            cfg.prev_pos.num_bins,
            cfg.edge_dim,
        )
        if cfg.struct_embedder:
            self.embed_struct = PairStructEmbedder(cfg)

    def forward(
        self,
        fasta: torch.Tensor,
        prev_node: torch.Tensor,
        prev_edge: torch.Tensor,
        prev_x: torch.Tensor,
        node_repr: torch.Tensor,
        edge_repr: torch.Tensor,
        atom14_mask: torch.Tensor,
        prev_frames: utils.AAFrame,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Recycle the last run

        Args:
            fasta: the tokenized input sequence.
            prev_node: node representations from the previous cycle
                of shape [num_res, node_repr_dim]
            prev_edge: edge representations from the previous cycle
                of shape [num_res, num_res, edge_repr_dim]
            prev_x: pseudo beta coordinates from the previous cycle.
                of shape [num_res, 3]
            node_repr: the node representation to put stuff in
            edge_repr: the edge representation to put stuff in
            atom14_mask: the mask for the 14 atoms
            prev_frames: the frames from the previous cycle

        Returns:
            The updated (node_repr, edge_repr); both are modified in place.

        """
        atom_mask = rc.restype2atom_mask.to(self.device)[fasta]
        prev_beta = utils.create_pseudo_beta(prev_x, atom_mask)
        # Distogram of the previous cycle's pairwise pseudo-beta distances.
        d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3))
        d = self.dgram(d)
        # Only the first "row" of the node stack receives the normalized
        # previous node state; the edge track is updated residually.
        node_repr[..., 0, :, :] = (node_repr[..., 0, :, :] +
                                   self.layernorm_node(prev_node))
        edge_repr += self.prev_pos_embed(d)
        edge_repr += self.layernorm_edge(prev_edge)
        if self.cfg.struct_embedder:
            edge_repr += self.embed_struct(fasta, prev_x, atom14_mask,
                                           prev_frames)

        return node_repr, edge_repr


# =============================================================================
# Tests
# =============================================================================
# Library module: nothing to run standalone; the guard is a placeholder.
if __name__ == '__main__':
    pass

+ 174
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/geoformer.py View File

@@ -0,0 +1,174 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
The code for GeoFormer, the main trunk
"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import typing

import torch
from torch import nn

from . import modules, utils

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================


class GeoFormerBlock(modules.OFModule):
    """
    One iteration of GeoFormer

    Updates the node track (edge-biased attention, column attention,
    transition) and then the edge track (outer product, geometric attention,
    transition). Both representations are updated in place.
    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(GeoFormerBlock, self).__init__(cfg)
        self.attention_w_edge_bias = modules.AttentionWEdgeBias(
            d_node=cfg.node_dim,
            d_edge=cfg.edge_dim,
            n_head=cfg.attn_n_head,
            attn_gating=cfg.gating,
            attn_c=cfg.attn_c)
        self.column_attention = modules.Attention(q_dim=cfg.node_dim,
                                                  kv_dim=cfg.node_dim,
                                                  gating=cfg.gating,
                                                  n_head=cfg.attn_n_head,
                                                  c=cfg.attn_c,
                                                  out_dim=cfg.node_dim,
                                                  n_axis=1)
        self.node_transition = modules.Transition(d=cfg.node_dim,
                                                  n=cfg.transition_multiplier,
                                                  activation=cfg.activation)
        self.out_product = modules.Node2Edge(in_dim=cfg.node_dim,
                                             out_dim=cfg.edge_dim,
                                             proj_dim=cfg.opm_dim)
        self.geometric_attention = nn.ModuleList([
            modules.GeometricAttention(d_edge=cfg.edge_dim,
                                       n_axis=2,
                                       c=cfg.geom_c,
                                       n_head=cfg.geom_head)
            for _ in range(cfg.geom_count)
        ])
        self.edge_transition = modules.Transition(d=cfg.edge_dim,
                                                  n=cfg.transition_multiplier,
                                                  activation=cfg.activation)

    def forward(
        self,
        node_repr: torch.Tensor,
        edge_repr: torch.Tensor,
        mask: torch.Tensor,
        *,
        fwd_cfg: typing.Optional[argparse.Namespace] = None
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Run one GeoFormer iteration.

        Args:
            node_repr: node representation, updated in place.
            edge_repr: edge representation, updated in place.
            mask: residue validity mask.
            fwd_cfg: optional forward configuration; when provided, its
                ``subbatch_size`` bounds the transitions' peak memory.

        Returns:
            The updated (node_repr, edge_repr) pair.

        """
        # Bug fix: fwd_cfg is declared Optional (default None) but was
        # previously dereferenced unconditionally, crashing with
        # AttributeError when omitted. Transition accepts None and falls
        # back to a single full-size batch.
        subbatch_size = None if fwd_cfg is None else fwd_cfg.subbatch_size

        node_repr += self.attention_w_edge_bias(node_repr,
                                                edge_repr,
                                                mask,
                                                fwd_cfg=fwd_cfg)
        node_repr = self._column_attention(node_repr, mask, fwd_cfg=fwd_cfg)
        node_repr += self.node_transition(node_repr,
                                          subbatch_size=subbatch_size)

        edge_repr += self.out_product(node_repr, mask)
        for layer in self.geometric_attention:
            edge_repr += layer(edge_repr, mask[..., 0, :], fwd_cfg=fwd_cfg)

        edge_repr += self.edge_transition(edge_repr, subbatch_size)

        return node_repr, edge_repr

    def _column_attention(self, node_repr, mask, fwd_cfg):
        # Attend along the column axis: transpose, normalize, attend, and
        # add the result back residually.
        node_repr_col = utils.normalize(
            node_repr.transpose(-2, -3).contiguous())
        node_repr_col = self.column_attention(node_repr_col,
                                              node_repr_col,
                                              bias=utils.mask2bias(
                                                  mask.T[..., None, None, :]),
                                              fwd_cfg=fwd_cfg)
        node_repr += node_repr_col.transpose(-2, -3)
        return node_repr


class GeoFormer(modules.OFModule):
    """Stack of GeoFormerBlocks followed by a final node projection."""

    def __init__(self, cfg: argparse.Namespace):
        super(GeoFormer, self).__init__(cfg)
        self.blocks = nn.ModuleList(
            [GeoFormerBlock(cfg) for _ in range(cfg.geo_num_blocks)])
        self.node_final_proj = nn.Linear(cfg.node_dim, cfg.struct.node_dim)

    def forward(
        self,
        node_repr: torch.Tensor,
        edge_repr: torch.Tensor,
        mask: torch.Tensor,
        *,
        fwd_cfg: typing.Optional[argparse.Namespace] = None
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run every block in sequence and project the node track.

        Args:
            node_repr: the node representation from the
                pretrained language model, of shape [num_res, dim].
            edge_repr: the edge representation from the
                pretrained language model, of shape [num_res, num_res, dim].
            mask: the mask indicating the validity of the amino acid,
                of shape [num_res].
            fwd_cfg: optional forward configuration.

        Returns:
            node_repr: the node representation used for recycling.
            edge_repr: the edge representation used for recycling.
            final_node: the node representation for structure generation.

        """
        for layer in self.blocks:
            node_repr, edge_repr = layer(node_repr,
                                         edge_repr,
                                         mask,
                                         fwd_cfg=fwd_cfg)

        final_node = self.node_final_proj(node_repr)
        return node_repr, edge_repr, final_node


# =============================================================================
# Tests
# =============================================================================
# Library module: nothing to run standalone; the guard is a placeholder.
if __name__ == '__main__':
    pass

+ 248
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/model.py View File

@@ -0,0 +1,248 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import typing

import torch
from torch import nn

from . import (confidence, decode, embedders, geoformer, modules, omegaplm,
utils)
from .utils import residue_constants as rc

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================


class OmegaFoldCycle(modules.OFModule):
    """One recycling iteration: GeoFormer trunk, structure module, and
    confidence head."""

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(OmegaFoldCycle, self).__init__(cfg)

        self.geoformer = geoformer.GeoFormer(cfg)
        self.structure_module = decode.StructureModule(cfg.struct)
        self.confidence_head = confidence.ConfidenceHead(cfg.struct)

    def forward(
        self,
        fasta: torch.Tensor,
        mask: torch.Tensor,
        node_repr: torch.Tensor,
        edge_repr: torch.Tensor,
        fwd_cfg: typing.Optional[argparse.Namespace],
    ) -> typing.Tuple[typing.Dict[str, torch.Tensor], typing.Dict[
            str, typing.Union[torch.Tensor, utils.AAFrame]]]:
        """
        The forward method for one iteration of OmegaFold

        Args:
            fasta: the tokenized sequence of the protein,
                of shape [num_res]
            mask: validity mask, of shape [num_res]
            node_repr: of shape [num_res, node_repr_dim]
            edge_repr: of shape [num_res, num_res, edge_repr_dim]
            fwd_cfg: optional forward configuration

        Returns:
            ret: per-cycle outputs, including the confidence score of the
                output protein structure
            prev_dict: tensors handed to the next recycling iteration

        """
        # Trunk: returns (recycle node, edge, final node) in that order.
        prev_node, edge_repr, node_repr = self.geoformer(node_repr=node_repr,
                                                         edge_repr=edge_repr,
                                                         mask=mask,
                                                         fwd_cfg=fwd_cfg)

        node_repr, ret = self.structure_module(
            node_repr=node_repr[..., 0, :, :],
            edge_repr=edge_repr,
            fasta=fasta,
            mask=mask[..., 0, :],
        )

        ret['confidence'] = self.confidence_head(node_repr)

        prev_dict = {
            'prev_node': prev_node[..., 0, :, :],
            'prev_edge': edge_repr,
            'prev_x': ret['final_atom_positions'],
            'prev_frames': ret['final_frames'],
        }
        return ret, prev_dict


_INPUTS = typing.List[typing.Dict[typing.Union[str, int], typing.Any]]


class OmegaFold(modules.OFModule):
    """
    The Entire OmegaFold model that comprises a pretrained Protein Language
    Model, an encoder of the primary sequence, as well as a structure module
    for decoding

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(OmegaFold, self).__init__(cfg)
        self.omega_plm = omegaplm.OmegaPLM(cfg.plm)
        # Project the language model's features into the trunk dimensions.
        self.plm_node_embedder = nn.Linear(cfg.plm.node, cfg.node_dim)
        self.plm_edge_embedder = nn.Linear(cfg.plm.edge, cfg.edge_dim)
        self.input_embedder = embedders.EdgeEmbedder(cfg)
        self.recycle_embedder = embedders.RecycleEmbedder(cfg)
        self.omega_fold_cycle = OmegaFoldCycle(cfg)

    def forward(
        self,
        inputs: _INPUTS,
        predict_with_confidence: typing.Optional[bool] = True,
        *,
        fwd_cfg: typing.Optional[argparse.Namespace] = None
    ) -> typing.Dict[str, typing.Union[torch.Tensor, float]]:
        """
        The forward implementation of OmegaFold

        Args:
            inputs: one dict per recycling iteration (see ~_INPUTS)
            predict_with_confidence: if True, return the cycle with the
                highest overall confidence; otherwise return the last cycle
            fwd_cfg: forward configuration

        Returns:
            The result dict of the selected cycle.

        """
        # Preparation before entering the cycles
        primary_sequence = inputs[0]['p_msa'][..., 0, :]
        # NOTE(review): if every cycle's confidence were <= 0 this would
        # return None under predict_with_confidence — presumably confidence
        # is always positive; confirm against confidence.get_all_confidence.
        max_confidence = 0
        prev_dict = self.create_initial_prev_dict(len(primary_sequence))
        final_result = None

        # Start cycling
        residx_atom14_mask = rc.restype_atom14_mask.to(
            device=primary_sequence.device)[primary_sequence]
        for cycle_data in inputs:
            p_msa, p_msa_mask = cycle_data['p_msa'], cycle_data['p_msa_mask']
            fasta, mask = p_msa[..., 0, :], p_msa_mask[..., 0, :]
            node_repr, edge_repr = self.deep_sequence_embed(
                p_msa, p_msa_mask, fwd_cfg)
            # pop() drains prev_dict so it can be rebound by this cycle.
            node_recycle, edge_repr = self.recycle_embedder(
                fasta=fasta,
                prev_node=prev_dict.pop('prev_node'),
                prev_edge=prev_dict.pop('prev_edge'),
                prev_x=prev_dict.pop('prev_x'),
                node_repr=node_repr,
                edge_repr=edge_repr,
                atom14_mask=residx_atom14_mask,
                prev_frames=prev_dict.pop('prev_frames'))

            result, prev_dict = self.omega_fold_cycle(fasta=fasta,
                                                      mask=p_msa_mask,
                                                      node_repr=node_repr,
                                                      edge_repr=edge_repr,
                                                      fwd_cfg=fwd_cfg)

            # Score the cycle by overall confidence over atom index 1 (CA
            # in atom14 — TODO confirm against residue_constants).
            confidence_overall = confidence.get_all_confidence(
                result['confidence'],
                result['final_atom_positions'][..., 1, :], mask)
            result['confidence_overall'] = confidence_overall
            if predict_with_confidence:
                if confidence_overall > max_confidence:
                    max_confidence = confidence_overall
                    final_result = result
            else:
                final_result = result

        return final_result

    def deep_sequence_embed(
        self,
        fasta: torch.Tensor,
        mask: torch.Tensor,
        fwd_cfg: typing.Optional[argparse.Namespace],
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the forward method of the pretrained-language model

        Args:
            fasta: the fasta sequence
            mask: the mask indicating the validity of the token
            fwd_cfg: forward configuration passed to the language model

        Returns:
            The projected (node_repr, edge_repr) pair.

        """
        node_repr, edge_repr = self.omega_plm(fasta, mask, fwd_cfg=fwd_cfg)
        node_repr = self.plm_node_embedder(
            utils.normalize(node_repr, in_place=True))
        # [d, num_res, num_res] -> [num_res, num_res, d]
        edge_repr = edge_repr.permute(1, 2, 0)
        edge_repr = self.plm_edge_embedder(
            utils.normalize(edge_repr, in_place=True))
        edge_repr = self.input_embedder(fasta[..., 0, :], out=edge_repr)

        return node_repr, edge_repr

    def create_initial_prev_dict(
            self, num_res: int) -> typing.Dict[str, torch.Tensor]:
        """
        Generate 'previous' (filling with 0's) features for the model

        Args:
            num_res: the number of residues

        Returns:
            Zero-filled recycling features for the first iteration.

        """
        return {
            'prev_node':
                torch.zeros([num_res, self.cfg.node_dim],
                            device=self.device,
                            dtype=torch.float),
            'prev_edge':
                torch.zeros([num_res, num_res, self.cfg.edge_dim],
                            device=self.device,
                            dtype=torch.float),
            'prev_x':
                torch.zeros([num_res, 14, 3],
                            device=self.device,
                            dtype=torch.float),
            'prev_frames':
                utils.AAFrame.default_init(num_res,
                                           8,
                                           unit='Angstrom',
                                           device=self.device)
        }


# =============================================================================
# Tests
# =============================================================================
# Library module: nothing to run standalone; the guard is a placeholder.
if __name__ == '__main__':
    pass

+ 636
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/modules.py View File

@@ -0,0 +1,636 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import numbers
import typing

import torch
from torch import nn

from . import utils


# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
def softmax(x: torch.Tensor,
            dim: int,
            *,
            dtype: typing.Optional[torch.dtype] = None,
            in_place: bool = False) -> torch.Tensor:
    """
    In-place or normal softmax

    The in-place path reuses ``x``'s storage (stabilized by subtracting the
    per-``dim`` maximum); note that it ignores ``dtype``.

    Args:
        x: the input tensor
        dim: the dimension along which to perform the softmax
        dtype: the data type (out-of-place path only)
        in_place: if to perform inplace

    Returns:
        The softmax of ``x`` along ``dim``.

    """
    if not in_place:
        return torch.softmax(input=x, dim=dim, dtype=dtype)

    # Stabilize, exponentiate and normalize, all within x's storage.
    x.sub_(x.max(dim=dim, keepdim=True).values)
    x.exp_()
    x.div_(x.sum(dim=dim, keepdim=True))
    return x


def _attention(
        query: torch.Tensor, key: torch.Tensor, scale: torch.Tensor,
        value: torch.Tensor, bias: torch.Tensor, return_edge: bool,
        edge_reduction: str, edge_reduction_dim: int
) -> typing.Tuple[torch.Tensor, typing.Optional[torch.Tensor]]:
    """Normal attention

    Args:
        query: positive tensor of shape (*_q, dim_qk)
        key: positive tensor of shape (*_k, dim_qk)
        scale: the scaling of logits
        value: tensor of shape (*_k, dim_v)
        bias: the bias acting as either mask or relative positional encoding
        return_edge: if to return the (reduced) attention weights
        edge_reduction: name of the tensor method used to reduce the weights
        edge_reduction_dim: dimension along which the reduction runs

    Returns:
        The aggregated tensor of shape (*_q, dim_v) and, when requested,
        the reduced attention weights (otherwise None).

    """
    scores = torch.einsum('...id, ...jd -> ...ij', query * scale, key)
    scores.add_(bias)
    # The weights may only be computed in place when they are not needed
    # again for the edge output.
    weights = softmax(scores, dim=-1, in_place=not return_edge)
    aggregated = torch.einsum('...ij, ...jd -> ...id', weights, value)
    if not return_edge:
        return aggregated, None
    reduced = getattr(weights, edge_reduction)(dim=edge_reduction_dim)
    return aggregated, reduced


def attention(
    query: torch.Tensor,
    key: torch.Tensor,
    scale: typing.Union[torch.Tensor, float],
    value: torch.Tensor,
    bias: torch.Tensor,
    subbatch_size: typing.Optional[int] = None,
    *,
    return_edge: bool = False,
    edge_reduction: str = 'sum',
    edge_reduction_dim: int = 0,
) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor]]:
    """Computes attention with q, k, v in query-side sub-batches.

    Args:
        query: positive tensor of shape (*_q, dim_qk)
        key: positive tensor of shape (*_k, dim_qk)
        scale: the scaling of logits
        value: tensor of shape (*_k, dim_v)
        bias: the bias acting as either mask or relative positional encoding
        subbatch_size: query chunk size bounding peak memory (None or 0
            means process all queries at once)
        return_edge: if to return the (reduced) attention weights
        edge_reduction:
        edge_reduction_dim:

    Returns:
        The aggregated tensor of shape (*_q, dim_v) plus the collected
        attention weights (None unless return_edge).

    """
    q_len = query.shape[-2]
    k_len = key.shape[-2]
    v_dim = value.shape[-1]
    chunk = subbatch_size or q_len

    lead_shape = list(query.shape[:-2])
    factory_kwargs = nn.factory_kwargs({
        'device': query.device,
        'dtype': query.dtype
    })
    # Pre-allocated outputs filled chunk by chunk.
    output = torch.empty(*lead_shape, q_len, v_dim, **factory_kwargs)
    if return_edge:
        lead_shape.pop(edge_reduction_dim + 2)
        attns = torch.empty(*lead_shape, q_len, k_len, **factory_kwargs)
    else:
        attns = None

    for idx, q_chunk in enumerate(query.split(chunk, dim=-2)):
        lo = idx * chunk
        hi = lo + chunk
        # Slice the bias only when its query axis actually spans q_len
        # (otherwise it is broadcast as-is).
        if bias.shape[-2] != q_len:
            bias_chunk = bias
        else:
            bias_chunk = bias[..., lo:hi, :]

        res, attn = _attention(q_chunk, key, scale, value, bias_chunk,
                               return_edge, edge_reduction,
                               edge_reduction_dim)
        output[..., lo:hi, :] = res
        if return_edge:
            attns[..., lo:hi, :] = attn

    return output, attns


# =============================================================================
# Classes
# =============================================================================


class OFModule(nn.Module):
    """
    The OmegaFold modules
    args: The arguments used for each of the modules

    Stores the configuration namespace and exposes the device/dtype of the
    module's parameters.
    """

    def __init__(self, cfg: typing.Optional[argparse.Namespace]) -> None:
        super(OFModule, self).__init__()
        self.cfg = cfg

    @property
    def device(self) -> torch.device:
        """Device of the module's first parameter."""
        return next(iter(self.parameters())).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the module's first parameter."""
        return next(iter(self.parameters())).dtype


class Transition(OFModule):
    """Position-wise two-layer MLP with pre-normalization, applied in
    sub-batches to bound peak memory."""

    def __init__(self, d: int, n: int, activation: str) -> None:
        super(Transition, self).__init__(None)
        expand = nn.Linear(d, n * d)
        contract = nn.Linear(n * d, d)
        try:
            act = getattr(nn, activation)(inplace=True)
        except TypeError:
            # Some activations do not accept an `inplace` argument.
            act = getattr(nn, activation)()
        self.network = nn.Sequential(expand, act, contract)

    def forward(self, x: torch.Tensor,
                subbatch_size: typing.Optional[int]) -> torch.Tensor:
        """Apply the MLP chunk-wise along dim 0 after normalizing."""
        chunk = subbatch_size or x.shape[-2]

        out = torch.empty_like(x)
        for i, piece in enumerate(x.split(chunk, dim=0)):
            lo = i * chunk
            out[lo:lo + chunk] = self.network(utils.normalize(piece))
        return out


class MultiHeadedScaling(OFModule):
    """
    Perform an element wise scale shift

    Broadcasts ``num_heads`` independent (weight, bias) pairs over the
    input and returns one scaled-and-shifted copy per head.
    """

    def __init__(
        self,
        shape: typing.Union[int, typing.List[int], torch.Size],
        num_heads: int,
        on_out_ready: typing.Optional[typing.Callable[[torch.Tensor],
                                                      torch.Tensor]],
        dtype: typing.Optional[torch.dtype] = None,
    ) -> None:
        """

        Args:
            shape: the shape of the input dimensions
            num_heads: the number of dimensions to squeeze to
            on_out_ready: the function called on the scaled output
            dtype: the dtype of the parameters at generation
        """
        super(MultiHeadedScaling, self).__init__(None)
        factory_kwargs = nn.factory_kwargs({'dtype': dtype})
        if isinstance(shape, numbers.Integral):
            shape = (shape, )
        shape = list(shape)
        # The head axis is inserted just before the trailing `shape` dims.
        self.unsqueeze_dim = -(len(shape) + 1)
        shape.insert(0, num_heads)
        self.shape = shape
        self.split_dims = [1] * num_heads
        self.weight = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.call_on_out_ready = on_out_ready

        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> typing.List[torch.Tensor]:
        """
        Element wise multiplication followed by addition

        Args:
            x: the input tensor with the trailing dimensions following
                ~self.shape

        Returns:
            One tensor per head, each of the input's shape

        """
        scaled = x.unsqueeze(self.unsqueeze_dim) * self.weight + self.bias
        head_axis = scaled.ndim + self.unsqueeze_dim
        if self.call_on_out_ready is not None:
            scaled = self.call_on_out_ready(scaled)

        per_head = scaled.split(self.split_dims, dim=head_axis)

        return [piece.squeeze(head_axis) for piece in per_head]

    def reset_parameters(self) -> None:
        """Initialize with small random scales and zero shifts."""
        nn.init.normal_(self.weight, std=0.02)
        nn.init.zeros_(self.bias)


class Val2ContBins(OFModule):
    """Soft-bin continuous values into a normalized Gaussian histogram over
    ``cfg.x_bins`` bin centers spanning [x_min, x_max] (plus margins)."""

    def __init__(
        self,
        cfg: argparse.Namespace,
    ):
        super(Val2ContBins, self).__init__(cfg)

        bin_width = (cfg.x_max - cfg.x_min) / (cfg.x_bins - 2)

        self.register_buffer('x_offset',
                             torch.linspace(cfg.x_min - bin_width / 2,
                                            cfg.x_max + bin_width / 2,
                                            cfg.x_bins),
                             persistent=False)
        # Gaussian sharpness; the 0.2 factor keeps bins from blurring.
        self.coeff = -0.5 / ((bin_width * 0.2)**2)

    def forward(self, dist_x):
        """Map values of shape (*) to soft bin weights of shape (*, x_bins)."""
        view_shape = [1] * len(dist_x.size()) + [len(self.x_offset)]
        delta = dist_x.unsqueeze(-1) - self.x_offset.view(*view_shape)
        scores = self.coeff * torch.pow(delta, 2)
        # Subtract the per-element max for numerical stability.
        scores = scores - scores.max(-1, keepdim=True)[0]

        return torch.softmax(scores, dim=-1)


class Val2Bins(OFModule):
    """
    Convert continuous values to bins

    Attributes:
        breaks: ascending bin boundaries (num_bins - 1 of them).
    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(Val2Bins, self).__init__(cfg)
        self.register_buffer('breaks',
                             torch.linspace(cfg.first_break, cfg.last_break,
                                            cfg.num_bins - 1),
                             persistent=False)

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        """Return, per element, the count of boundaries strictly exceeded.

        Args:
            dist: distances in the euclidean space.

        Returns:
            Long tensor of bin indices in [0, num_bins - 1].

        """
        exceeded = dist.unsqueeze(-1) > self.breaks
        return exceeded.sum(dim=-1, dtype=torch.long)


class Node2Edge(OFModule):
    """Communicate between tracks

    Outer-product-mean style update from the node track to the edge track;
    faster than OutProductMean mostly due to a better implementation.
    """

    def __init__(self, in_dim: int, proj_dim: int, out_dim: int) -> None:
        super(Node2Edge, self).__init__(None)
        self.input_proj = nn.Linear(in_dim, proj_dim * 2)
        self.proj_dim = proj_dim
        self.out_weights = nn.Parameter(
            torch.empty(proj_dim, proj_dim, out_dim))
        self.out_bias = nn.Parameter(torch.empty(out_dim))

    def forward(self, node_repr: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Project normalized nodes and combine them pairwise."""
        act = self.input_proj(utils.normalize(node_repr))
        mask = mask[..., None]
        act = act * mask
        # Pairwise count of valid contributions, for normalization below.
        norm = torch.einsum('...sid, ...sjd->...ijd', mask, mask)

        left, right = act.split(self.proj_dim, dim=-1)
        # We found this implementation to work significantly faster
        out = torch.einsum('...sid, def, ...sje-> ...ijf', left,
                           self.out_weights, right) + self.out_bias

        return out / (norm + 1e-3)


class Attention(OFModule):
    """
    Widely used attention mechanism

    Attributes:
        qg_weights (nn.Parameter): weight matrices for queries and gates
        qg_bias (nn.Parameter): biases for queries and gates
        kv_weights (nn.Parameter): weight matrices for keys and values
        kv_bias (nn.Parameter): biases for keys and values

        o_weights (nn.Parameter): the output weight matrix
        o_bias (nn.Parameter): the output bias
    """

    def __init__(self, q_dim: int, kv_dim: int, n_head: int, gating: bool,
                 c: int, out_dim: int, n_axis: int) -> None:
        super(Attention, self).__init__(None)
        self.c = c
        self.n_head = n_head
        self.gating = gating
        self.q_dim = q_dim
        self.n_axis = n_axis

        # Packed projections: queries (+ optional gates) share one
        # parameter, keys and values another. `gating + 1` relies on
        # Python's bool -> int coercion.
        self.qg_weights = nn.Parameter(
            torch.empty(q_dim, n_axis, n_head, (gating + 1) * c))
        self.kv_weights = nn.Parameter(
            torch.empty(kv_dim, n_axis, n_head, 2 * c))
        self.qg_bias = nn.Parameter(
            torch.empty(n_axis, n_head, 1, c * (1 + gating)))
        self.kv_bias = nn.Parameter(torch.empty(n_axis, n_head, 1, c * 2))

        self.o_weights = nn.Parameter(torch.empty(n_axis, n_head, c, out_dim))
        self.o_bias = nn.Parameter(torch.empty([out_dim, n_axis]))

    def forward(
        self,
        q_inputs: torch.Tensor,
        kv_inputs: torch.Tensor,
        bias: torch.Tensor,
        *,
        fwd_cfg: typing.Optional[argparse.Namespace] = None
    ) -> typing.Union[typing.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Perform the standard multi-headed attention with added gating with some
        biases

        Args:
            q_inputs: the inputs to generate query vectors,
                of shape (*, q_len, q_dim, (n_axis))
            kv_inputs: the inputs to generate key and value vectors,
                of shape (*, kv_len, kv_dim, (n_axis))
            bias: the bias for the logits
                of shape (*, n_head, q_len, kv_len)
            fwd_cfg: forward configuration (used for sub-batching)

        Return:
            output tensor (*, seq_len, o_dim, (n_axis))
            attention logits (Optional) (q_len, kv_len, num_head)
        """

        # Acquire the q, k, v tensors
        # If the caller omitted the trailing n_axis dim, add it (and an
        # axis dim on the bias) and strip it again before returning.
        to_unsqueeze = (q_inputs.shape[-1] != self.n_axis
                        and q_inputs.shape[-1] == self.q_dim)
        if to_unsqueeze:
            q_inputs = q_inputs.unsqueeze(-1)
            kv_inputs = kv_inputs.unsqueeze(-1)
            if bias is not None:
                bias = bias.unsqueeze(-4)

        attn_out = self._get_attn_out(q_inputs, kv_inputs, fwd_cfg, bias)

        # Per-axis, per-head output projection, then the per-axis bias.
        output = torch.einsum('...rhqc,rhco->...qor', attn_out, self.o_weights)
        output += self.o_bias

        if to_unsqueeze:
            output = output.squeeze(-1)
        return output

    def _get_attn_out(self, q_inputs, kv_inputs, fwd_cfg, bias):
        # Project inputs into (query, gate) and (key, value) in one einsum
        # each, then split along the channel dimension.
        qg = torch.einsum('...qar,arhc->...rhqc', q_inputs, self.qg_weights)
        qg += self.qg_bias
        q_out = qg.split(self.c, dim=-1)
        q = q_out[0]

        kv = torch.einsum('...kar,arhc->...rhkc', kv_inputs, self.kv_weights)
        kv += self.kv_bias
        k, v = kv.split([self.c, self.c], dim=-1)

        # Attention
        subbatch_size = (q.shape[-4]
                         if fwd_cfg is None else fwd_cfg.subbatch_size)
        attn_out, _ = attention(query=q,
                                key=k,
                                value=v,
                                subbatch_size=subbatch_size,
                                bias=bias,
                                scale=self.c**(-0.5))
        # get the gating
        # q_out[1] only exists when gating is enabled (see __init__).
        if self.gating:
            g = torch.sigmoid(q_out[1])
            attn_out *= g

        return attn_out


class AttentionWEdgeBias(OFModule):
    """Node self-attention whose logits are biased by a projection of the
    edge representation plus the validity mask."""

    def __init__(self, d_node: int, d_edge: int, n_head: int,
                 attn_gating: bool, attn_c: int) -> None:
        super(AttentionWEdgeBias, self).__init__(None)
        self.proj_edge_bias = nn.Linear(
            in_features=d_edge,
            out_features=n_head  # , bias=False
        )
        self.attention = Attention(q_dim=d_node,
                                   kv_dim=d_node,
                                   n_head=n_head,
                                   gating=attn_gating,
                                   c=attn_c,
                                   out_dim=d_node,
                                   n_axis=1)

    def forward(
        self,
        node_repr: torch.Tensor,
        edge_repr: torch.Tensor,
        mask: torch.Tensor,
        *,
        fwd_cfg: typing.Optional[argparse.Namespace] = None
    ) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
        """Self-attend over normalized nodes with an edge-derived bias.

        Args:
            node_repr: node representation (queries, keys and values).
            edge_repr: edge representation providing the per-head bias.
            mask: residue validity mask.
            fwd_cfg: optional forward configuration.

        Returns:
            The attention output over the node representation.

        """
        normed_nodes = utils.normalize(node_repr)
        normed_edges = utils.normalize(edge_repr)
        # check dim
        # [i, j, head] -> [head, i, j]
        edge_bias = self.proj_edge_bias(normed_edges).permute(2, 0, 1)

        edge_bias = edge_bias + utils.mask2bias(mask[..., None, None, :])
        return self.attention(normed_nodes,
                              normed_nodes,
                              bias=edge_bias,
                              fwd_cfg=fwd_cfg)


def _get_sharded_stacked(edge_repr: torch.Tensor, subbatch_size: int):
subbatch_size = subbatch_size or edge_repr.shape[-2]
idx = 0
start, end = 0, subbatch_size
while start < edge_repr.shape[-2]:
yield start, end, torch.stack(
[edge_repr[start:end],
edge_repr.transpose(-2, -3)[start:end]],
dim=-1)
idx += 1
start, end = idx * subbatch_size, (idx + 1) * subbatch_size


class GeometricAttention(OFModule):
    """Edge-representation update combining attention and a gated branch.

    "GRAM reduction" refers to the sub-batched (sharded) evaluation via
    _get_sharded_stacked, trading speed for lower peak memory.
    """

    def __init__(self, d_edge: int, c: int, n_head: int, n_axis: int) -> None:
        super(GeometricAttention, self).__init__(None)
        self.d_edge = d_edge
        self.n_axis = n_axis
        self.n_head = n_head
        # Projects edge features into one additive logit bias per axis/head.
        self.linear_b_weights = nn.Parameter(
            torch.empty([d_edge, n_axis, n_head]))
        self.linear_b_bias = nn.Parameter(torch.empty([n_axis, n_head, 1, 1]))

        # 5 * d_edge output channels: the first 4 * d_edge feed the two
        # GLU-gated projections (split into even/odd groups by
        # _get_sliced_weight), the last d_edge is the output gate slice
        # used in _get_gated.
        self.act_w = nn.Parameter(torch.empty([d_edge, n_axis, d_edge * 5]))
        self.act_b = nn.Parameter(torch.empty([n_axis, d_edge * 5]))

        self.out_proj_w = nn.Parameter(torch.empty([n_axis, d_edge, d_edge]))
        self.out_proj_b = nn.Parameter(torch.empty([n_axis, d_edge]))
        self.glu = nn.GLU()

        self.attention = Attention(q_dim=d_edge,
                                   kv_dim=d_edge,
                                   n_head=n_head,
                                   c=c,
                                   gating=True,
                                   out_dim=d_edge,
                                   n_axis=n_axis)

    def _get_attended(self, edge_repr: torch.Tensor, mask: torch.Tensor,
                      fwd_cfg) -> torch.Tensor:
        """Row- and column-view attention over the stacked edge tensor."""
        attended = torch.empty(*edge_repr.shape,
                               self.n_axis,
                               dtype=edge_repr.dtype,
                               device=edge_repr.device)
        b = torch.zeros(self.n_axis,
                        self.n_head,
                        *edge_repr.shape[:2],
                        dtype=edge_repr.dtype,
                        device=edge_repr.device)
        b += utils.mask2bias(mask)
        # NOTE(review): the loop below assigns (=) into every row range of b,
        # which overwrites the mask2bias contribution added just above --
        # confirm whether the mask bias was meant to be added to the
        # projected bias instead.
        for s, e, edge_r in _get_sharded_stacked(
                edge_repr, subbatch_size=fwd_cfg.subbatch_size):
            b[..., s:e, :] = torch.einsum(
                '...qkcr,crh->...rhqk', edge_r,
                self.linear_b_weights) + self.linear_b_bias
        for s, e, edge_r in _get_sharded_stacked(
                edge_repr, subbatch_size=fwd_cfg.subbatch_size):
            attended[s:e] = self.attention(edge_r, edge_r, b, fwd_cfg=fwd_cfg)
        # Fuse the row view and the (transposed back) column view.
        return attended[..., 0] + attended[..., 1].transpose(-2, -3)

    def _get_gated(self, edge_repr: torch.Tensor, mask: torch.Tensor, fwd_cfg):
        """Gated pairwise update contracting row/col activations over k."""
        gated = torch.empty(*edge_repr.shape[:2],
                            self.n_axis,
                            self.d_edge,
                            device=edge_repr.device,
                            dtype=edge_repr.dtype)
        for s_row, e_row, edge_row in _get_sharded_stacked(
                edge_repr, subbatch_size=fwd_cfg.subbatch_size):
            act_row = self._get_act_row(edge_row, mask[s_row:e_row])
            # Output gate from the last d_edge slice of act_w / act_b.
            act_g = torch.sigmoid(
                torch.einsum('...dr,drc->...rc', edge_row, self.act_w[
                    ..., -self.d_edge:]) + self.act_b[..., -self.d_edge:])
            for s_col, e_col, edge_col, in _get_sharded_stacked(
                    edge_repr, subbatch_size=fwd_cfg.subbatch_size):
                act_col = self._get_act_col(edge_col, mask[s_col:e_col])
                # Contract the shared index k: (i, k) x (j, k) -> (i, j).
                ab = torch.einsum('...ikrd,...jkrd->...ijrd', act_row, act_col)
                ab = utils.normalize(ab.contiguous())
                gated[s_row:e_row,
                      s_col:e_col] = torch.einsum('...rd,rdc->...rc', ab,
                                                  self.out_proj_w)
                gated[s_row:e_row, s_col:e_col].add_(self.out_proj_b)
                gated[s_row:e_row, s_col:e_col] *= act_g[:, s_col:e_col]

        # Sum the contributions of the row/column axes.
        return gated.sum(-2)

    def _get_sliced_weight(self, weight: torch.Tensor, shift=0):
        """Take every other d_edge-sized group (even groups for shift=0,
        odd for shift=1) from the first 4*d_edge channels, re-flattened."""
        w = weight[..., :-self.d_edge].unflatten(-1, sizes=(4, -1))
        w = w[..., shift::2, :]
        w = w.flatten(start_dim=-2)
        return w

    def _get_act_row(self, edge_row: torch.Tensor,
                     mask: torch.Tensor) -> torch.Tensor:
        """GLU-gated row activations (even weight groups), masked."""
        w = self._get_sliced_weight(self.act_w)
        b = self._get_sliced_weight(self.act_b)
        act = torch.einsum('...dr,drc->...rc', edge_row, w) + b
        # GLU halves the channel dim; zero out activations of invalid rows.
        act = self.glu(act) * mask[..., None, None, None]
        return act

    def _get_act_col(self, edge_row: torch.Tensor,
                     mask: torch.Tensor) -> torch.Tensor:
        """GLU-gated column activations (odd weight groups), masked."""
        w = self._get_sliced_weight(self.act_w, shift=1)
        b = self._get_sliced_weight(self.act_b, shift=1)
        act = torch.einsum('...dr,drc->...rc', edge_row, w) + b
        act = self.glu(act) * mask[..., None, None, None]
        return act

    def forward(self, edge_repr: torch.Tensor, mask: torch.Tensor,
                fwd_cfg) -> torch.Tensor:
        """Normalize the edges, then add the attended and gated branches."""
        edge_repr = utils.normalize(edge_repr)
        out = self._get_attended(edge_repr, mask, fwd_cfg)
        out += self._get_gated(edge_repr, mask, fwd_cfg)

        return out


# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
    pass  # import-only module: no standalone test entry point

+ 233
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/omegaplm.py View File

@@ -0,0 +1,233 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
# =============================================================================
# Imports
# =============================================================================
import argparse
import math
import typing

import torch
from torch import nn

from . import embedders, modules, utils


# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
def _get_qk_scaling(num_res: torch.Tensor, attn_dim: int) -> torch.Tensor:
"""
https://kexue.fm/archives/8823

Args:
num_res: [num_chunks]
attn_dim

Returns:

"""
return num_res.clamp(min=4e-5).log() / (math.log(512) * attn_dim**0.5)


# =============================================================================
# Classes
# =============================================================================
class GatedAttentionUnit(modules.OFModule):
    """Gated attention unit: attention whose output is element-wise
    multiplied by a gate projected from the same input.
    """

    def __init__(self, cfg: argparse.Namespace):
        super(GatedAttentionUnit, self).__init__(cfg)
        self.cfg = cfg
        # One fused projection producing gate, value and q/k base vectors.
        self.gva_proj = nn.Sequential(
            nn.Linear(cfg.node, cfg.proj_dim * 2 + cfg.attn_dim), nn.SiLU())
        # Two per-head scalings of the shared base -> queries and keys; the
        # rotary embedding is applied through the on_out_ready hook.
        # (self.rope is assigned below but only called at forward time.)
        self.multi_headed_scaling = modules.MultiHeadedScaling(
            cfg.attn_dim,
            num_heads=2,
            on_out_ready=lambda x: self.rope(x, x.ndim - 3))
        self.rope = embedders.RoPE(cfg.attn_dim)
        self.relpos = embedders.RelPosEmbedder(cfg.num_relpos, embedding_dim=1)
        self.output_proj = nn.Linear(cfg.proj_dim, cfg.node)

    def forward(
            self, node: torch.Tensor, scaling: torch.Tensor, bias: torch.Tensor,
            fwd_cfg: typing.Optional[argparse.Namespace]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward method of this class

        Args:
            node: the node representation
            scaling: logits scaling
            bias: additive bias for the attention logits
            fwd_cfg: forward-time configuration. NOTE(review): typed Optional
                but fwd_cfg.subbatch_size is read unconditionally below, so
                passing None raises AttributeError -- confirm callers always
                provide it.

        Returns:
            the gated attention output and the reduced attention map (edge)

        """
        cfg = self.cfg
        # initial projection, split into gate / value / q-k base parts
        gates, values, base = self.gva_proj(node).split(
            [cfg.proj_dim, cfg.proj_dim, cfg.attn_dim], dim=-1)
        queries, keys = self.multi_headed_scaling(base)

        node, edge = modules.attention(
            query=queries,
            key=keys,
            scale=scaling,
            value=values,
            # relative-position embedding added to the logit bias
            bias=bias + self.relpos(base.shape[-2])[..., 0],
            subbatch_size=fwd_cfg.subbatch_size,
            return_edge=True,
            edge_reduction='sum',
            edge_reduction_dim=-3,
        )

        # unflatten the values, base will be unflattened in self._forward
        node = node * gates
        node = self.output_proj(node)
        return node, edge


class OmegaPLMLayer(modules.OFModule):
    """One OmegaPLM layer in pre-layernorm configuration.

    Attributes:
        gau: the underlying GAU layer containing most of the computations

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(OmegaPLMLayer, self).__init__(cfg)
        self.gau = GatedAttentionUnit(cfg)

    def forward(
            self, node: torch.Tensor, qk_scaling: torch.Tensor, bias: torch.Tensor,
            fwd_cfg: typing.Optional[argparse.Namespace]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Pre-layernorm GAU with a residual connection.

        Args:
            node: the node representation
            qk_scaling: the scaling of logits before attention
            bias: the bias for logits before attention
            fwd_cfg: forward-time configuration

        Returns:
            updated node representation and the edge representation

        """
        residual = node
        normed = utils.normalize(node)
        gau_out, edge = self.gau(normed, qk_scaling, bias, fwd_cfg)
        return gau_out + residual, edge


class OmegaPLM(modules.OFModule):
    """Encoder GAU model

    This is the OmegaPLM model in Wu et al. 2022.

    Attributes:
        input_embedding: This is an embedding layer
        layers: the trunk of the network containing modified GAU layers
        output_norm: an output normalization layer

    """

    def __init__(self, cfg: argparse.Namespace) -> None:
        super(OmegaPLM, self).__init__(cfg)
        self.input_embedding = nn.Embedding(cfg.alphabet_size,
                                            cfg.node,
                                            padding_idx=cfg.padding_idx)
        # NOTE(review): cfg.edge is used as the *number of layers* here --
        # presumably the config reuses the field; confirm against the config.
        self.layers = nn.ModuleList(
            [OmegaPLMLayer(cfg) for _ in range(cfg.edge)])
        self.output_norm = nn.LayerNorm(cfg.node)

    def forward(
            self, tokens: torch.Tensor, mask: torch.Tensor,
            fwd_cfg: typing.Optional[argparse.Namespace]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Forward method

        Args:
            tokens: A tensor of input tokens,
                of shape [*, seq_len]
            mask: mask indicating the validity of the tokens,
                of shape [*, seq_len]
            fwd_cfg: forward-time configuration (subbatch size etc.)

        Returns:
            the final node representation and the per-layer edge
            (attention) maps, normalized as below

        """
        # Length-adaptive logit scaling (see _get_qk_scaling).
        qk_scaling = _get_qk_scaling(mask.sum(-1), self.cfg.attn_dim)
        qk_scaling = qk_scaling[..., None, None]
        bias = utils.mask2bias(mask[..., None, :])

        node = self.input_embedding(tokens)
        # Token-dropout rescaling (Rives et al. 2021), see below.
        node *= self._get_finetuning_scale(mask, tokens)
        edges = torch.empty(len(self.layers),
                            mask.shape[-1],
                            mask.shape[-1],
                            dtype=node.dtype,
                            device=node.device)
        for i, layer in enumerate(self.layers):
            node, edges[i] = layer(node, qk_scaling, bias, fwd_cfg)
        node = self.output_norm(node)

        # Taking the average
        # NOTE(review): the divisor is the number of sequences containing any
        # valid token (mask.any(-1).sum()), not the layer count -- confirm
        # this is the intended normalization.
        edges /= (mask.any(-1).sum() + 1e-5)

        return node, edges

    def _get_finetuning_scale(self, mask: torch.Tensor,
                              tokens: torch.Tensor) -> torch.Tensor:
        """Token dropout scaling

        This computes the scaling from Rives et al. 2021

        Args:
            mask: the mask indicating the validity of the input sequence
            tokens: the input token ids; id 21 is counted as masked here --
                presumably the <mask> token id, confirm against the vocab

        Returns:
            per-sequence scaling factors, shape [batch, 1, 1]

        """
        un_masked_ratio_train = 1 - self.cfg.masked_ratio
        src_lengths = mask.sum(-1)
        mask_ratio_observed = tokens.eq(21).sum(-1).float() / src_lengths
        # Cap fully-masked sequences at 0.99 to avoid dividing by zero below.
        mask_ratio_observed = torch.where(
            mask_ratio_observed == 1.,
            torch.full_like(mask_ratio_observed, 0.99), mask_ratio_observed)
        return un_masked_ratio_train / (1 - mask_ratio_observed)[:, None, None]


# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
    pass  # import-only module: no standalone test entry point

+ 424
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/pipeline.py View File

@@ -0,0 +1,424 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
This file contains the utilities that we use for the entire inference pipeline
"""
# =============================================================================
# Imports
# =============================================================================
from __future__ import annotations

import collections
import logging
import ntpath
import os
import os.path
import pathlib
import types
import typing

import torch
from Bio import PDB as PDB
from Bio.PDB import StructureBuilder
from huggingface_hub import hf_hub_download
from torch import hub
from torch.backends import cuda, cudnn

from . import utils
from .utils.protein_utils import residue_constants as rc

try:
    from torch.backends import mps  # Compatibility with earlier versions

    _mps_is_available = mps.is_available
except ImportError:
    # Older torch builds have no mps backend module: report unavailable.
    def _mps_is_available():
        return False


# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
def _set_precision(allow_tf32: bool) -> None:
"""Set precision (mostly to do with tensorfloat32)

This allows user to go to fp32

Args:
allow_tf32: if allowing

Returns:

"""
if int(torch.__version__.split('.')[1]) < 12:
cuda.matmul.allow_tf32 = allow_tf32
cudnn.allow_tf32 = allow_tf32
else:
precision = 'high' if allow_tf32 else 'highest'
torch.set_float32_matmul_precision(precision)


def path_leaf(path: str) -> str:
    """
    Get the filename from the path (Windows-style separator rules).

    Handles trailing separators: for 'a/b/' the result is 'b'.

    Args:
        path: the absolute or relative path to the file

    Returns:
        the filename

    """
    name = ntpath.basename(path)
    if name:
        return name
    # Trailing separator: fall back to the parent directory's name.
    return ntpath.basename(ntpath.dirname(path))


def fasta2inputs(
        fasta_path: str,
        output_dir: typing.Optional[str] = None,
        num_pseudo_msa: int = 15,
        device: typing.Optional[torch.device] = torch.device('cpu'),
        mask_rate: float = 0.12,
        num_cycle: int = 10,
        deterministic: bool = True
) -> typing.Generator[typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                   str], None, None]:
    """
    Parse a fasta file and yield per-chain pseudo-MSA model inputs.

    Chains are processed in order of increasing sequence length. For each
    chain this yields ``(data, out_fname)``: ``data`` is a list of
    ``num_cycle`` dicts with keys ``p_msa`` / ``p_msa_mask`` (moved to
    ``device``), and ``out_fname`` is the pdb path to write results to.

    Args:
        fasta_path: the path to the fasta file
        output_dir: the output directory; defaults to a directory named
            after the fasta file, placed next to it
        num_pseudo_msa: number of pseudo-MSA rows generated per chain
        device: the device to move the tensors to
        mask_rate: per-position probability of masking a pseudo-MSA entry
        num_cycle: number of recycling iterations to prepare inputs for
        deterministic: seed the RNG by sequence length so runs repeat

    Returns:
        a generator over (data, out_fname) pairs as described above

    """
    chain_ids: list[str] = []
    aastr: list[str] = []
    with open(fasta_path, 'r') as file:
        lines = file.readlines()
    name = False
    for line in lines:
        if len(line) == 0:
            continue
        # Both '>' and ':' are accepted as header markers.
        if line.startswith('>') or line.startswith(':'):
            name = True
            chain_ids.append(line[1:].strip('\n'))
        else:
            if name:
                # First sequence line after a header starts a new record.
                aastr.append(line.strip('\n').upper())
                name = False
            else:
                # Continuation of a multi-line sequence.
                # NOTE(review): a sequence line before any header would hit
                # this branch with an empty list and raise IndexError.
                aastr[-1] = aastr[-1] + line.strip('\n').upper()

    # Shortest chains first.
    combined = sorted(list(zip(chain_ids, aastr)), key=lambda x: len(x[1]))
    if output_dir is None:
        parent = pathlib.Path(fasta_path).parent
        folder_name = path_leaf(fasta_path).split('.')[0]
        output_dir = os.path.join(parent, folder_name)
    os.makedirs(output_dir, exist_ok=True)
    try:
        name_max = os.pathconf(output_dir, 'PC_NAME_MAX') - 4
    except AttributeError:
        # os.pathconf is UNIX specific. Set to 32 for now.
        name_max = 32

    for i, (ch, fas) in enumerate(combined):
        # Map ambiguous amino-acid codes to concrete ones (Z->E, B->D, U->C).
        fas = fas.replace('Z', 'E').replace('B', 'D').replace('U', 'C')
        # '-' (gap) maps to token 21, other residues to their restype index.
        aatype = torch.LongTensor(
            [rc.restypes_with_x.index(aa) if aa != '-' else 21 for aa in fas])
        mask = torch.ones_like(aatype).float()
        assert torch.all(aatype.ge(0)) and torch.all(aatype.le(21)), \
            'Only take 0-20 amino acids as inputs with unknown amino acid ' \
            'indexed as 20'
        if len(ch) < name_max:
            out_fname = ch.replace(os.path.sep, '-')
        else:
            # Header too long for the filesystem: use a positional name.
            out_fname = f'{i}th chain'
        out_fname = os.path.join(output_dir, out_fname + '.pdb')

        num_res = len(aatype)
        data = list()
        g = None
        if deterministic:
            # Seed by length so the same sequence gets the same pseudo-MSA.
            g = torch.Generator()
            g.manual_seed(num_res)
        for _ in range(num_cycle):
            p_msa = aatype[None, :].repeat(num_pseudo_msa, 1)
            p_msa_mask = torch.rand([num_pseudo_msa, num_res],
                                    generator=g).gt(mask_rate)
            # Row 0 is the real sequence with its full validity mask.
            p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0)
            p_msa = torch.cat((aatype[None, :], p_msa), dim=0)
            # Masked-out pseudo-MSA positions are set to token 21.
            p_msa[~p_msa_mask.bool()] = 21
            data.append({'p_msa': p_msa, 'p_msa_mask': p_msa_mask})

        yield utils.recursive_to(data, device=device), out_fname


# modify fasta2inputs to list2inputs
def list2inputs(
        protein_list: typing.List[str],
        output_dir: typing.Optional[str] = None,
        num_pseudo_msa: int = 15,
        device: typing.Optional[torch.device] = torch.device('cpu'),
        mask_rate: float = 0.12,
        num_cycle: int = 10,
        deterministic: bool = True
) -> typing.Generator[typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                   str], None, None]:
    """
    Build pseudo-MSA model inputs from an in-memory list of sequences.

    Variant of fasta2inputs that takes sequences directly instead of
    reading a fasta file; chains are named 'chain_<i>' and processed in
    order of increasing sequence length.

    Args:
        protein_list: the amino-acid sequences, one string per chain
        output_dir: directory the per-chain pdb paths are rooted in.
            NOTE(review): unlike fasta2inputs there is no default and no
            makedirs here -- passing None breaks os.path.join below;
            confirm callers always supply an existing directory.
        num_pseudo_msa: number of pseudo-MSA rows generated per chain
        device: the device to move the tensors to
        mask_rate: per-position probability of masking a pseudo-MSA entry
        num_cycle: number of recycling iterations to prepare inputs for
        deterministic: seed the RNG by sequence length so runs repeat

    Returns:
        a generator over (data, out_fname) pairs, as in fasta2inputs

    """
    chain_ids: list[str] = []
    aastr: list[str] = []

    chain_ids = [f'chain_{i}' for i in range(len(protein_list))]
    aastr = protein_list

    # Shortest chains first.
    combined = sorted(list(zip(chain_ids, aastr)), key=lambda x: len(x[1]))
    name_max = 32

    for i, (ch, fas) in enumerate(combined):
        # Map ambiguous amino-acid codes to concrete ones (Z->E, B->D, U->C).
        fas = fas.replace('Z', 'E').replace('B', 'D').replace('U', 'C')
        # '-' (gap) maps to token 21, other residues to their restype index.
        aatype = torch.LongTensor(
            [rc.restypes_with_x.index(aa) if aa != '-' else 21 for aa in fas])
        mask = torch.ones_like(aatype).float()
        assert torch.all(aatype.ge(0)) and torch.all(aatype.le(21)), \
            'Only take 0-20 amino acids as inputs with unknown amino acid ' \
            'indexed as 20'
        if len(ch) < name_max:
            out_fname = ch.replace(os.path.sep, '-')
        else:
            out_fname = f'{i}th chain'
        out_fname = os.path.join(output_dir, out_fname + '.pdb')

        num_res = len(aatype)
        data = list()
        g = None
        if deterministic:
            # Seed by length so the same sequence gets the same pseudo-MSA.
            g = torch.Generator()
            g.manual_seed(num_res)
        for _ in range(num_cycle):
            p_msa = aatype[None, :].repeat(num_pseudo_msa, 1)
            p_msa_mask = torch.rand([num_pseudo_msa, num_res],
                                    generator=g).gt(mask_rate)
            # Row 0 is the real sequence with its full validity mask.
            p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0)
            p_msa = torch.cat((aatype[None, :], p_msa), dim=0)
            # Masked-out pseudo-MSA positions are set to token 21.
            p_msa[~p_msa_mask.bool()] = 21
            data.append({'p_msa': p_msa, 'p_msa_mask': p_msa_mask})

        yield utils.recursive_to(data, device=device), out_fname


def save_pdb(pos14: torch.Tensor,
             b_factors: torch.Tensor,
             sequence: torch.Tensor,
             mask: torch.Tensor,
             save_path: str,
             model: int = 0,
             init_chain: str = 'A') -> None:
    """
    saves the pos14 as a pdb file

    Args:
        pos14: the atom14 representation of the coordinates
        b_factors: per-residue b-factors of the amino acids
        sequence: the amino-acid indices of the pos14 residues
        mask: per-residue validity; masked residues are skipped
        save_path: the path to save the pdb file
        model: the model id of the pdb file
        init_chain: the chain id the residues are written under

    return:
        the structure saved to ~save_path

    """
    builder = StructureBuilder.StructureBuilder()
    builder.init_structure(0)
    builder.init_model(model)
    builder.init_chain(init_chain)
    builder.init_seg(' ')
    for i, (aa_idx, p_res, b,
            m_res) in enumerate(zip(sequence, pos14, b_factors, mask.bool())):
        if not m_res:
            continue
        aa_idx = aa_idx.item()
        p_res = p_res.clone().detach().cpu()
        # 21 is the gap/unknown token: nothing to write for it.
        if aa_idx == 21:
            continue
        try:
            three = rc.residx_to_3(aa_idx)
        except IndexError:
            # Index outside the residue table: skip silently.
            continue
        builder.init_residue(three, ' ', int(i), icode=' ')
        for j, (atom_name, ) in enumerate(
                zip(rc.restype_name_to_atom14_names[three])):
            if len(atom_name) > 0:  # skip empty atom14 padding slots
                builder.init_atom(atom_name,
                                  p_res[j].tolist(),
                                  b.item(),
                                  1.0,
                                  ' ',
                                  # pad with spaces -> pdb atom fullname
                                  atom_name.join([' ', ' ']),
                                  element=atom_name[0])
    structure = builder.get_structure()
    io = PDB.PDBIO()
    io.set_structure(structure)
    os.makedirs(pathlib.Path(save_path).parent, exist_ok=True)
    io.save(save_path)


def _load_weights(
weights_url: str,
weights_file: str,
) -> collections.OrderedDict:
"""
Loads the weights from either a url or a local file. If from url,

Args:
weights_url: a url for the weights
weights_file: a local file

Returns:
state_dict: the state dict for the model

"""

weights_file = os.path.expanduser(weights_file)
use_cache = os.path.exists(weights_file)
if weights_file and weights_url and not use_cache:
logging.info(
f'Downloading weights from {weights_url} to {weights_file}')
os.makedirs(os.path.dirname(weights_file), exist_ok=True)
hub.download_url_to_file(weights_url, weights_file)
else:
logging.info(f'Loading weights from {weights_file}')

return torch.load(weights_file, map_location='cpu')


def _get_device(device) -> str:
"""
Infer the accelerator

Args:
device: the device type

Returns:

"""
if device is None:
if torch.cuda.is_available():
return 'cuda'
elif _mps_is_available():
return 'mps'
else:
return 'cpu'
elif device == 'cpu':
return device
elif device.startswith('cuda'):
if torch.cuda.is_available():
return device
else:
raise ValueError('Device cuda is not available')
elif device == 'mps':
if _mps_is_available():
return device
else:
raise ValueError('Device mps is not available')
else:
raise ValueError(f'Device type {device} is not available')


def get_args() -> typing.Tuple[types.SimpleNamespace, collections.OrderedDict,
                               types.SimpleNamespace]:
    """Build hard-coded inference arguments, model weights and forward config.

    Returns:
        args: the runtime arguments (SimpleNamespace)
        weights: the loaded model state dict
        forward_config: subbatch/recycle settings for the forward pass
    """

    # Construct the args object directly instead of parsing argparse flags.
    args = types.SimpleNamespace()
    args.num_cycle = 10
    args.subbatch_size = 448
    args.device = None
    # TODO: Modify the path of weights_file
    # NOTE(review): weights_file is always set here, so the
    # `if args.weights_file is None` branches below are dead code.
    args.weights_file = hf_hub_download('SciReason/OmegaFold-release',
                                        'release2.pt',
                                        repo_type='dataset')
    # NOTE(review): args.weights is never read again -- the url actually
    # used (weights_url) is chosen from args.model below.
    args.weights = 'https://helixon.s3.amazonaws.com/release1.pt'
    args.model = 2
    args.pseudo_msa_mask_rate = 0.12
    args.num_pseudo_msa = 15
    args.allow_tf32 = True

    _set_precision(args.allow_tf32)

    if args.model == 1:
        weights_url = 'https://helixon.s3.amazonaws.com/release1.pt'
        if args.weights_file is None:
            args.weights_file = os.path.expanduser(
                '~/.cache/omegafold_ckpt/model.pt')
    elif args.model == 2:
        weights_url = 'https://helixon.s3.amazonaws.com/release2.pt'
        if args.weights_file is None:
            args.weights_file = os.path.expanduser(
                '~/.cache/omegafold_ckpt/model2.pt')
    else:
        raise ValueError(
            f'Model {args.model} is not available, only 1 or 2 supported.')

    # Load the weights.
    weights = _load_weights(weights_url, args.weights_file)
    # Unwrap checkpoints that nest the state dict under a 'model' key.
    weights = weights.pop('model', weights)

    # Build the forward config.
    forward_config = types.SimpleNamespace(
        subbatch_size=args.subbatch_size,
        num_recycle=args.num_cycle,
    )

    # Auto-select the device.
    args.device = _get_device(args.device)

    return args, weights, forward_config


# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
    pass  # import-only module: no standalone test entry point

+ 51
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/__init__.py View File

@@ -0,0 +1,51 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# flake8: noqa
"""

"""
# =============================================================================
# Imports
# =============================================================================
from typing import Dict, Union # noqa: F401, F403

import torch # noqa: F401, F403

from ..utils.protein_utils import residue_constants # noqa: F401, F403
from ..utils.protein_utils.aaframe import AAFrame # noqa: F401, F403
from ..utils.protein_utils.functions import bit_wise_not # noqa: F401, F403
from ..utils.protein_utils.functions import \
robust_normalize # noqa: F401, F403
from ..utils.protein_utils.functions import create_pseudo_beta, get_norm
from ..utils.torch_utils import masked_mean # noqa: F401, F403
from ..utils.torch_utils import normalize # noqa: F401, F403
from ..utils.torch_utils import mask2bias, recursive_to

# =============================================================================
# Constants
# =============================================================================
# Type alias for the feature dictionaries passed around the pipeline.
DATA = Dict[str, Union[str, bool, torch.Tensor, AAFrame]]
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
    pass  # import-only package init: no standalone test entry point

+ 38
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/__init__.py View File

@@ -0,0 +1,38 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""

"""
from ...utils.protein_utils import residue_constants # noqa: F401, F403
# =============================================================================
# Imports
# =============================================================================
from ...utils.protein_utils.aaframe import AAFrame # noqa: F401, F403

# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
    pass  # import-only package init: no standalone test entry point

+ 936
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/aaframe.py View File

@@ -0,0 +1,936 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
This script contains the Frame object, that acts as an essential part to
convert to full atom coordinates for amino acids.
This is inspired by Jumper et al. (2021), where the authors refer to this
object as rigid group/affine update, and we unify the two notions here.

Some codes adopted from
https://github.com/deepmind/alphafold/blob/main/alphafold/model/all_atom.py
"""
# =============================================================================
# Imports
# =============================================================================
from typing import List, Tuple, Union

import torch
from torch.nn import functional as F

from ...utils.protein_utils import functions as f
from ...utils.protein_utils import residue_constants as rc

# =============================================================================
# Functions
# =============================================================================
# =============================================================================
# Constant
# =============================================================================
# 180-degree rotation about the y-axis: diag(-1, 1, -1).
_BACKBONE_ROTATE = torch.tensor([
    [-1, 0., 0.],
    [0., 1., 0.],
    [0., 0., -1],
])


# =============================================================================
# Classes
# =============================================================================
class AAFrame(object):
"""
The transformation object that holds translation and rotation
"""

def __init__(self,
             translation: torch.Tensor = None,
             rotation: torch.Tensor = None,
             mask: torch.Tensor = None,
             safe: bool = True,
             unit: str = 'Angstrom',
             *,
             expanded: bool = False) -> None:
    """
    Initialize the transformation

    Args:
        translation (): the translation vector of shape (*, 3)
        rotation (): the rotation matrix of shape (*, 3, 3)
        mask (): the mask tensor indicating the presence (validity) of
            each frame
        safe (): if to use safe initialization; safe goes through the
            property setters, which zero/identity-fill masked entries,
            while unsafe assigns the raw tensors and is faster
        unit (): the unit of the translation ('Angstrom' or 'nano')
        expanded (): if this frame is expanded to per-residue frames
    """
    super(AAFrame, self).__init__()
    self.orig = None  # provenance info; may be filled in elsewhere
    if safe:
        # Property setters; mask goes first because the translation and
        # rotation setters read self.mask.
        self.mask = mask
        self.translation = translation
        self.rotation = rotation
    else:
        # Raw assignment: no masking applied.
        self._mask = mask
        self._translation = translation
        self._rotation = rotation

    self.expanded_ = expanded
    self._unit = unit

@property
def unit(self) -> str:
    """
    The current unit of the translation ('Angstrom' or 'nano').

    Returns:
        the current unit of this frame

    """
    return self._unit

def _assign(self, translation: torch.Tensor, rotation: torch.Tensor,
            unit: str, mask: torch.Tensor, in_place: bool,
            orig: str) -> 'AAFrame':
    """
    Write the given attributes into self, or build a fresh frame.

    Args:
        translation: the translation (center) of the frame
        rotation: the rotation of the frame
        unit: the unit in which the frame operates
        mask: the mask of the frames indicating which components are valid
        in_place: if True mutate self, otherwise construct a new frame
        orig: provenance info recorded on a newly constructed frame

    Returns:
        self (mutated) when in_place, otherwise a new frame

    """
    if not in_place:
        return self._construct_frame(translation,
                                     rotation,
                                     mask,
                                     orig=orig,
                                     safe=True,
                                     unit=unit)
    # Bypass the property setters: raw assignment, no masking applied.
    self._translation = translation
    self._rotation = rotation
    self._unit = unit
    self._mask = mask
    return self

def to_nanometers(self, in_place: bool = True) -> 'AAFrame':
    """
    Convert the frame's translation unit to nanometers.

    Args:
        in_place: if to perform the operation in place.

    Returns:
        the converted frame (self when in_place)

    """
    if self._unit == 'Angstrom':
        new_translation = self._translation / 10
    else:
        # Already in nanometers: keep the translation untouched.
        new_translation = self._translation
    return self._assign(translation=new_translation,
                        rotation=self._rotation,
                        unit='nano',
                        mask=self._mask,
                        orig=f'To nano from {self}',
                        in_place=in_place)

def to_angstrom(self, in_place: bool) -> 'AAFrame':
    """
    Convert the frame's translation unit to Angstrom.

    Args:
        in_place: if to use in_place operation

    Returns:
        the converted frame (self when in_place)

    """
    if self._unit == 'nano':
        _translation = self._translation * 10
    else:
        # Already in Angstrom: keep the translation untouched.
        _translation = self._translation
    _unit = 'Angstrom'
    _rotation = self._rotation
    _mask = self._mask
    # Bug fix: the provenance string previously read 'To nano from ...',
    # copy-pasted from to_nanometers, even though this converts to Angstrom.
    return self._assign(translation=_translation,
                        rotation=_rotation,
                        unit=_unit,
                        mask=_mask,
                        orig=f'To Angstrom from {self}',
                        in_place=in_place)

    @property
    def translation(self) -> torch.Tensor:
        """
        The raw translation tensor of shape (*, 3).

        Note:
            Despite the original wording ("Mask the translation"), this
            accessor returns ``self._translation`` unchanged; zeroing of
            masked-out entries happens in the setter, not here.

        Returns:
            the stored translation tensor

        """
        return self._translation

@translation.setter
def translation(self, value: torch.Tensor) -> None:
"""
Assign the translation in the frame with masked values set to 0"s.

Args:
value: the translation value

"""
m = f.bit_wise_not(self.mask.unsqueeze(-1).expand_as(value))
self._translation = value.masked_fill(m, 0)

@property
def rotation(self) -> torch.Tensor:
"""
The rotation matrix

Returns:

"""
return self._rotation

@rotation.setter
def rotation(self, value: torch.Tensor) -> None:
"""
Assign the rotation in the frame with masked values set to identity
matrices.

Args:
value: the rotational matrices

"""
mask = f.bit_wise_not(self.mask[..., None, None].expand_as(value))
value = value.masked_fill(mask, 0.)
value = value.masked_fill(
mask * torch.eye(3, dtype=torch.bool).to(mask.device), 1)
self._rotation = value

@property
def mask(self) -> torch.Tensor:
"""
Hope this protects the attribute

Returns:

"""
return self._mask

@mask.setter
def mask(self, value: torch.Tensor):
self._mask = value.bool()

@classmethod
def default_init(
cls,
*shape,
unit: str = 'Angstrom',
safe: bool = True,
device: torch.device = torch.device('cpu'),
mask: Union[torch.Tensor, torch.Tensor] = None,
) -> 'AAFrame':
"""
partially initialize a bunch of frames, for now only supports one
dimensional

Args:
shape (): the shape of the frames,
mask (): the mask, if not provided, will be all true
device (): on which will the frame reside
safe (): if to safe init
unit (): the unit

Returns:

"""
if mask is not None:
assert tuple(mask.shape) == shape
translation = torch.zeros(list(shape) + [3], device=device)
rotation = torch.eye(3, dtype=translation.dtype,
device=device) * torch.ones(list(shape) + [1, 1],
device=device)
if mask is None:
mask = torch.ones_like(translation[..., 0], dtype=torch.bool)

return cls._construct_frame(trans=translation,
rots=rotation,
mask=mask,
orig='partially initialized',
safe=safe,
unit=unit)

@classmethod
def _neg_dim(cls, dim: int) -> Tuple[int, int, int]:
if dim < 0:
return dim, dim - 1, dim - 2
else:
return dim, dim, dim

def unsqueeze(self, dim: int) -> 'AAFrame':
"""
see torch.squeeze

Args:
dim ():

Returns:

"""
return self.dim_apply(torch.unsqueeze, dim=dim)

def sum(self, dim: int, keepdim: bool = False) -> 'AAFrame':
"""
see torch.sum

Args:
dim ():
keepdim ():

Returns:

"""
dim0, dim1, dim2 = self._neg_dim(dim)
m = torch.sum(self.mask, dim=dim0, keepdim=keepdim)
t = torch.sum(self.translation, dim=dim1, keepdim=keepdim)
r = torch.sum(self.rotation, dim=dim2, keepdim=keepdim)
return self._construct_frame(t,
r,
m,
f'Created by {torch.sum} at dim {dim}',
safe=False,
unit=self.unit) # from self

def dim_apply(self, func: callable, dim: int) -> 'AAFrame':
"""
Apply torch functionals to the translation and rotations

Args:
func (): the functional to apply to
dim (): the dimension to which the function will be applied

Returns:

"""
dim0, dim1, dim2 = self._neg_dim(dim)
m = func(self.mask, dim0)
t = func(self.translation, dim1)
r = func(self.rotation, dim2)
u = self.unit
return self._construct_frame(t,
r,
m,
f'Created by {func} at dim {dim}',
safe=False,
unit=u) # from self

@classmethod
def _construct_frame(
cls,
trans: torch.Tensor,
rots: torch.Tensor,
mask: Union[torch.Tensor, torch.Tensor],
orig: str,
safe: bool,
unit: str,
) -> 'AAFrame':
"""
Construct a frame

Args:
trans: the absolute position in the bigger frame
rots: the rotation of the frame
mask: the mask indicating the validity of the frame
orig: the message information about the origin of the frame
unit: the unit for initialize
safe: if use safe init

Returns:

"""
# assert t.shape[:-1] == r.shape[:-2] == m.shape
transformation = AAFrame(translation=trans,
rotation=rots,
mask=mask,
safe=safe,
unit=unit)
transformation.orig = orig

return transformation

@classmethod
def from_4x4(cls, m: torch.Tensor, mask: torch.Tensor,
unit: str) -> 'AAFrame':
"""
get the frames from 4x4 matrix

Args:
m (): the transformation in homogeneous coordinates
should be of shape (*, 4, 4)
mask (): the masking tensor
unit ():

Returns:
A transformation

"""

return cls._construct_frame(m[..., 0:3, 3],
m[..., 0:3, 0:3],
mask=mask,
orig='from matrix',
safe=True,
unit=unit)

    def transform(self, pos: torch.Tensor) -> torch.Tensor:
        """
        Apply the transformation on the input coordinates

        Args:
            pos (): the 3-D coordinates to transforms,
                of shape (*, 3)

        Note:
            if we are using batched dims, we simply assume that the
            dimensions of pos can be split into three parts
                1. the batched_dims
                2. the ones to do the outer-product-like expansion
                3. the 3 xyz coordinate value

        Returns:
            transformed coordinates of the same shape as the coordinates,
            of shape (N, 3)

        Examples:
            >>> frames = AAFrame(
            ...     translation=torch.zeros(10,3),
            ...     rotation=torch.eye(3)[None, ...].repeat(10, 1, 1),
            ...     mask=torch.ones(10, dtype=torch.bool)
            ... )
            >>> frames.shape
            torch.Size([10])
            >>> frames.transform(torch.randn(10, 3)).shape
            torch.Size([10, 3])  # one-to-one
            >>> frames.transform(torch.randn(10, 1, 3)).shape
            torch.Size([10, 1, 3])  # it is still one-to-one
            >>> frames.transform(torch.randn(1, 4, 3)).shape
            torch.Size([10, 4, 3])  # this broadcasts to every pair,
            # with the first dimension being the
            # frames
            >>> frames.transform(torch.randn(4, 1, 3)).shape
            torch.Size([10, 1, 3])
            >>> frames = AAFrame(
            ...     translation=torch.zeros(10, 9, 3),
            ...     rotation=torch.eye(3)[None, ...].repeat(10, 9, 1, 1),
            ...     mask=torch.ones(10, 9, dtype=torch.bool)
            ... )
            >>> frames.shape
            torch.Size([10, 9])
            >>> frames.transform(torch.randn(10, 9, 3)).shape
            torch.Size([10, 9, 3])
            >>> frames.transform(torch.randn(10, 1, 3)).shape
            torch.Size([10, 9, 3])  # this broadcasts to 9, but does not
            # work with shape (1, 9, 3)
            >>> frames.transform(torch.randn(1, 1, 3)).shape
            torch.Size([10, 9, 3])  #
            >>> frames.transform(torch.randn(10, 9, 4, 3)).shape
            torch.Size([10, 9, 4, 3])
            >>> frames.transform(torch.randn(10, 1, 9, 4, 3)).shape
            torch.Size([10, 9, 9, 4, 3])  # the 1st, 2nd dim are from frames
            >>> frames.transform(torch.randn(10, 1, 1, 3)).shape
            torch.Size([10, 9, 1, 3])
        """
        # split pos dims into (batch, expansion, xyz); singleton dims are
        # inserted into the frame so its batch dims broadcast against the
        # expansion dims of ``pos``
        batched_dims = len(self.shape)
        shape1 = self.shape[:batched_dims]
        # NOTE(review): shape1 is the full frame shape and self_shape2 is
        # always empty given batched_dims == len(self.shape); presumably
        # kept for symmetry / future non-batch frame dims — confirm.
        shape2 = pos.shape[batched_dims:-1]  # the ones to cross
        self_shape2 = self.shape[batched_dims:]
        out = self.view(*shape1, *[1 for _ in range(len(shape2))],
                        *self_shape2)
        # rotate then translate; broadcasting pairs frames with positions
        return f.batch_matrix_vector(out.rotation, pos) + out.translation

@classmethod
def from_torsion(
cls,
unit: str,
torsion_angles: torch.Tensor,
mask: Union[torch.Tensor, torch.Tensor],
translation: torch.Tensor = None,
) -> 'AAFrame':
"""
Create a transformation that rotates around the x-axis

Args:
unit ():
torsion_angles (): the torsion angle to create the axis with,
should be of shape (*, 2)
mask (): the masking tensor
translation (): optional, if provided will be passed in to the
transformation

Returns:
A rotation matrix around the x axis

"""
device = torsion_angles.device
_make_rot_mat = torch.tensor(
[
[0., 0., 0., 0., 0., -1, 0., 1., 0.], # sin
[0., 0., 0., 0., 1., 0., 0., 0., 1.], # cos
],
dtype=torsion_angles.dtype,
device=device)
rot_mat = torch.matmul(torsion_angles, _make_rot_mat)

rot_mat = rot_mat.unflatten(dim=-1, sizes=[3, 3])
rot_mat[..., 0, 0] = 1

if translation is None:
shape = list(torsion_angles.shape)
shape[-1] = 3
translation = torch.zeros(*shape, device=device)

return cls._construct_frame(translation,
rot_mat,
mask,
'from torsion',
safe=True,
unit=unit)

def __getitem__(self, idx: Union[slice, int, torch.Tensor]) -> 'AAFrame':
"""
Select the frame

Args:
idx (): the index of the selection

Returns:
selected transformation

"""
if isinstance(idx, (slice, int)):
return self._construct_frame(self.translation[..., idx, :],
self.rotation[..., idx, :, :],
self.mask[..., idx],
f'selected from {self} at {idx}',
unit=self.unit,
safe=False)
elif isinstance(idx, torch.Tensor):
return self._construct_frame(self.translation[idx, :],
self.rotation[idx, :, :],
self.mask[idx],
f'selected from {self} by tensor',
unit=self.unit,
safe=False)
else:
raise IndexError(f'Type {type(idx)} not supported for indexing')

    def __setitem__(self, key: Union[int, torch.Tensor, List[int]],
                    value: Union[torch.Tensor, 'AAFrame']) -> None:
        """
        Assign a sub-frame (or a scalar-like value) at the given index.

        Works on clones and reassigns through the property setters, per the
        inline note, so gradients are not broken by in-place mutation.

        Args:
            key: an int (indexes the last batch dim) or a tensor/list
                (indexes the leading dims)
            value: an AAFrame, or a value broadcast into all components

        """
        if isinstance(value, AAFrame):
            t = value.translation.to(self._translation.dtype)
            r = value.rotation.to(self._rotation.dtype)
            m = value.mask.to(self._mask.dtype)
        else:
            t = r = value
            # NOTE(review): bool(value) on a tensor only works for
            # single-element tensors — presumably only scalars reach this
            # branch; confirm against callers.
            m = bool(value)
        mask = self.mask.clone()
        translation = self.translation.clone()
        rotation = self.rotation.clone()

        if isinstance(key, int):
            mask[..., key] = m
            translation[..., key, :] = t
            rotation[..., key, :, :] = r
        elif isinstance(key, (torch.Tensor, list)):
            # this because it cannot use in-place operations for gradients
            mask[key] = m
            translation[key, :] = t
            rotation[key, :, :] = r
        # NOTE(review): unsupported key types are silently ignored here
        # (no else branch) — the assignment below still runs.

        self.mask = mask
        self.translation = translation
        self.rotation = rotation

@property
def device(self) -> torch.device:
"""

Returns:

"""
assert (self._mask.device == self._translation.device ==
self._rotation.device)
return self._mask.device

@property
def shape(self) -> torch.Size:
"""

Returns: the shape of the tensor

"""
return self.mask.shape

def __mul__(self, other) -> 'AAFrame':
if isinstance(other, AAFrame):
return self._combine_transformation(other)
else:
return self._tensor_multiplication(other)

    def _tensor_multiplication(self, other: torch.Tensor) -> 'AAFrame':
        """
        Scale the frame by a tensor.

        If ``other`` contains only 0s and 1s it is treated as a mask and
        applied to mask, translation and rotation alike; otherwise only
        the translation is scaled.

        Args:
            other: the scaling tensor, broadcastable to ``self.shape``

        Returns:
            The scaled frame (constructed unsafely).

        """
        if torch.logical_or(torch.eq(other, 0), torch.eq(other, 1)).all():
            # NOTE(review): ``self.mask * other`` yields a float tensor
            # when ``other`` is float, and safe=False below skips any
            # coercion — verify downstream code tolerates a non-bool mask.
            m = self.mask * other
            t = self.translation * other[..., None]
            r = self.rotation * other[..., None, None]
        else:
            t = self.translation * other
            m = self.mask
            r = self.rotation

        return self._construct_frame(t,
                                     r,
                                     m,
                                     f'Created by multiplication from {self}',
                                     safe=False,
                                     unit=self.unit)

    def _combine_transformation(self, other: 'AAFrame') -> 'AAFrame':
        """
        Compose two frames: ``self`` applied after ``other``.

        Both frames are flattened to (N, 3) translations and (N, 3, 3)
        rotations before the batched composition; when the shapes differ,
        ``self`` is broadcast (one extra trailing batch dim) against
        ``other``.

        Args:
            other (): the frame to compose with

        Returns:
            the composed frame, in ``self``'s unit

        Note:
            ``other`` is converted to ``self``'s unit IN PLACE
            (``to_angstrom(in_place=True)`` / ``to_nanometers``), so this
            method mutates its argument's translation and unit.

        """
        # the rotation
        if self.shape != other.shape:
            t_1 = self.translation[..., None, :].expand_as(other.translation)
            r_1 = self.rotation[..., None, :, :].expand_as(other.rotation)
            m_1 = self.mask[..., None].expand_as(other.mask).reshape(-1)
            t_1, r_1 = t_1.reshape(-1, 3), r_1.reshape(-1, 3, 3)
        else:
            t_1, r_1, m_1 = self.translation, self.rotation, self.mask
            t_1, r_1, m_1 = t_1.view(-1, 3), r_1.view(-1, 3, 3), m_1.view(-1)

        # bring ``other`` into self's unit (in-place, see Note above)
        if self.unit == 'Angstrom':
            other.to_angstrom(in_place=True)
        else:
            other.to_nanometers(in_place=True)
        t_2, r_2 = other.translation.view(-1, 3), other.rotation.view(-1, 3, 3)
        m_2 = other.mask.view(-1)

        # composed rotation: R = R1 @ R2
        r_out = torch.bmm(r_1, r_2)
        # the transition: t = t1 + R1 @ t2
        t_out = t_1 + f.batch_matrix_vector(r_1, t_2)
        # the torsion_angles_mask: valid only where both are valid
        m_out = m_1 * m_2

        return self._construct_frame(t_out.view(*other.shape, 3),
                                     r_out.view(*other.shape, 3, 3),
                                     m_out.view(*other.shape),
                                     f'Combination of {self} and {other}',
                                     safe=False,
                                     unit=self.unit)

def __repr__(self) -> str:
return f'Frame {id(self)}'

def view(self, *args) -> 'AAFrame':
"""
See Tensor.view

Args:
*args ():

Returns:

"""
mask = self.mask
translation = self.translation
rotation = self.rotation
return self._construct_frame(translation.view(*args, 3),
rotation.view(*args, 3, 3),
mask.view(*args),
f'view from {self}',
safe=False,
unit=self.unit)

@property
def dtype(self):
return self.translation.dtype

    def expand_w_torsion(self, torsion_angles: torch.Tensor,
                         torsion_angles_mask: torch.Tensor,
                         fasta: torch.Tensor) -> 'AAFrame':
        r"""
        Expand the backbone frames into all-rigid-group global frames.

        Lines 2-10
        Algorithm 24, Page 31 of the AlphaFold 2 supplementary material

        Args:
            torsion_angles (): the torsion angles as (sin, cos) pairs
                (\omega, \phi, \psi, \chi_1, \chi_2, \chi_3, \chi_4)
                should be of shape (N, 7, 2); a (N, 5, 2) input is
                zero-padded to 7 below
            torsion_angles_mask (): the torsion angle masks indicating
                presence, should be of shape (N, 7) (or (N, 5), padded)
            fasta (): input sequence where each place is an index indicating
                which amino acid is in each position, following ~restypes

        Note:
            ``self`` is the backbone-to-global transformation; it must be
            in Angstroms.

        Returns:
            Frame with ``expanded_`` set, mapping each of the 8 rigid
            groups to the global frame.

        """
        assert self.unit == 'Angstrom'
        # pad 5-angle inputs (missing omega/phi) with zeros up to 7
        if torsion_angles.shape[-2] == 5:
            torsion_angles = torch.cat((torch.zeros_like(
                torsion_angles[..., 0:2, :]), torsion_angles),
                                       dim=-2)
            torsion_angles_mask = torch.cat((torch.zeros_like(
                torsion_angles_mask[..., 0:2]), torsion_angles_mask),
                                            dim=-1)

        # append an identity for backbone2backbone
        shape = list(torsion_angles.shape)
        shape[-2] = 1
        angle = torch.tensor([[0, 1]], dtype=self.dtype,
                             device=self.device).expand(shape)  # (*, 1, 2)
        angle_mask = torch.tensor([True], dtype=torch.bool,
                                  device=self.device).expand(shape[:-1])
        torsion_angles = torch.cat((angle, torsion_angles), -2)  # (*, 8, 2)
        torsion_angles_mask = torch.cat((angle_mask, torsion_angles_mask), -1)

        # prepare the angles: normalize each (sin, cos) pair to unit length
        torsion_angles = f.robust_normalize(torsion_angles)
        rot_x = AAFrame.from_torsion(torsion_angles=torsion_angles,
                                     mask=torsion_angles_mask,
                                     unit='Angstrom')

        # make extra backbone frames
        # This follows the order of ~restypes
        m = rc.restype_aa_default_frame.to(self.device)[fasta]
        default_frames = AAFrame.from_4x4(m,
                                          torsion_angles_mask,
                                          unit='Angstrom')
        all_frames = default_frames * rot_x
        # make side chain frames (chain them up along the side chain);
        # groups 4..7 are chi1..chi4
        chi2_frame_to_frame = all_frames[5]
        chi3_frame_to_frame = all_frames[6]
        chi4_frame_to_frame = all_frames[7]
        # chains
        chi1_frame_to_backb = all_frames[4]
        chi2_frame_to_backb = chi1_frame_to_backb * chi2_frame_to_frame
        chi3_frame_to_backb = chi2_frame_to_backb * chi3_frame_to_frame
        chi4_frame_to_backb = chi3_frame_to_backb * chi4_frame_to_frame

        # all_frames[4] = chi1_f2bb
        all_frames[5] = chi2_frame_to_backb
        all_frames[6] = chi3_frame_to_backb
        all_frames[7] = chi4_frame_to_backb
        # get all
        # map atom literature positions to the global frame
        all_f2global = self * all_frames
        all_f2global.expanded_ = True

        return all_f2global

def rotate(self, rotation: torch.Tensor):
"""
Rotate with a rotation matrix

Note:
batched rotated not yet supported,
for now just use ~Frame._construct_transformation

Args:
rotation (): the rotation matrix of shape (d, d)

Returns:
Rotated frame

"""
if len(rotation.shape) == 2:
t = self.translation
r = torch.matmul(self.rotation, rotation)
return self._construct_frame(t,
r,
self.mask,
f'Rotated from {self}',
safe=False,
unit=self.unit)
else:
raise NotImplementedError('Not yet implemented')

    def expanded_to_pos(
            self,
            fasta: torch.Tensor,
            full: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the full atom representation

        Args:
            fasta: the sequence to compute the atoms
            full: if True, expect an already-expanded 8-group frame and
                produce all 14 atom positions; otherwise only backbone
                (group 0, first 5 atom slots) is produced

        Returns:
            the atom14 representation and the mask indicating the presence
            of the atoms

        """
        if full:
            assert self.expanded_
            num_classes = 8
            frame = self
            pos_counts = 14
        else:
            num_classes = 1
            frame = self.unsqueeze(-1)
            pos_counts = 5

        assert self._unit == 'Angstrom'

        fasta = fasta.cpu()
        # map each residue's atom14 slots to their rigid group indices
        residx2group = rc.restype_atom14_to_aa
        residx2group = residx2group[..., :pos_counts]
        residx2group = residx2group[fasta].to(self.device)
        # one-hot over the 8 groups, truncated to the groups in use
        group_mask = F.one_hot(residx2group, num_classes=8)
        group_mask = group_mask[..., :num_classes]
        group_mask = group_mask * frame.mask[..., None, :]
        # select, per atom, the frame of its rigid group (sum over groups)
        to_mask = frame.unsqueeze(-2) * group_mask
        map_atoms_to_global = to_mask.sum(-1)
        # literature (idealized) atom positions in each group's local frame
        lit_pos = rc.restype_atom14_aa_positions
        lit_pos = lit_pos[..., :pos_counts, :]
        lit_pos = lit_pos[fasta].to(self.device)
        pred_pos = map_atoms_to_global.transform(lit_pos)
        # mask = c.restype_atom14_mask[sequence]  # (N, 14)
        # mask |= self.mask[..., None]
        pred_pos = pred_pos * map_atoms_to_global.mask[..., None]

        return pred_pos, torsion_mask_to_atom14_mask(frame.mask,
                                                     group_mask,
                                                     fasta=fasta)

def __len__(self):
return len(self.mask)

@property
def inverse(self) -> 'AAFrame':
"""
The inverse of the transformation

Returns:

"""
r = self.rotation.transpose(-1, -2)
t = f.batch_matrix_vector(r, self.translation)
return self._construct_frame(-t,
r,
self.mask,
f'inversed from {self}',
safe=False,
unit=self.unit)

def position_in_frame(self, pos: torch.Tensor) -> torch.Tensor:
"""
Get the frame-based position of the given global position

Args:
pos (): the global position of shape (*, 3)

Returns:
the result

"""
return self.inverse.transform(pos)

    @classmethod
    def from_tensor(cls, tensor, unit: str) -> 'AAFrame':
        """
        Build frames from a packed (quaternion, translation) tensor.

        Args:
            tensor: of shape (*, 7) — 4-component quaternion followed by
                (tx, ty, tz) — or (*, 6) with a 3-component quaternion
                (implicit real part, see ``f.quaternion_to_matrix``)
            unit: the unit of the translation
        """
        # 4 quaternion components when the last dim is 7, else 3
        q_dim = 4 if tensor.shape[-1] == 7 else 3
        quaternion, tx, ty, tz = torch.split(tensor, [q_dim, 1, 1, 1], dim=-1)
        rotation = f.quaternion_to_matrix(quaternion)
        translation = torch.stack([tx[..., 0], ty[..., 0], tz[..., 0]], dim=-1)

        # NOTE(review): the mask here is a float all-ones tensor, not bool;
        # presumably the safe init coerces it — confirm against __init__.
        return cls._construct_frame(trans=translation,
                                    rots=rotation,
                                    mask=torch.ones_like(translation[..., 0]),
                                    orig='from tensor',
                                    safe=True,
                                    unit=unit)


def torsion_mask_to_atom14_mask(torsion_mask: torch.Tensor,
                                group_mask: torch.Tensor,
                                fasta: torch.Tensor) -> torch.Tensor:
    """
    Expand the torsion-angle mask into an atom14 existence mask.

    Args:
        torsion_mask (): the mask for torsion angles, of shape (*, 8)
        group_mask (): the group mask to add on, of shape (*, 14, 8)
        fasta (): the sequence for this operation

    Returns:
        Expanded mask of shape (*, 14)

    """
    # an atom exists if it belongs to any non-backbone rigid group
    exist = group_mask[..., 1:].sum(-1)
    # atom slot 4 (CB) exists for every residue except restype index 7 (GLY)
    exist[..., 4] = fasta != 7
    # backbone atoms (slots 0-2) follow the backbone torsion mask entry
    exist[..., 0:3] = torsion_mask[..., 0:1]
    return exist.bool()


# =============================================================================
# Functions
# =============================================================================

# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 149
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/functions.py View File

@@ -0,0 +1,149 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
This script contains some functions that may be handy somewhere
"""
# =============================================================================
# Constants
# =============================================================================
# =============================================================================
# Imports
# =============================================================================
import typing

import torch


# =============================================================================
# Functions
# =============================================================================
def get_norm(x: torch.Tensor) -> torch.Tensor:
    """
    L2 norm over the last dimension (LA.norm substitute, since MPS does
    not support it yet).

    Args:
        x: the input tensor

    Returns:
        the Euclidean norm along the last dimension

    """
    return (x * x).sum(dim=-1).sqrt()


def robust_normalize(x: torch.Tensor,
                     dim: int = -1,
                     p: typing.Union[int, str] = 2) -> torch.Tensor:
    """
    Normalize with the denominator clamped away from zero.

    Args:
        x (): tensor to normalize
        dim (): the dimension along which to perform the normalization
        p (): the p in l-p

    Returns:
        the normalized result

    """
    denominator = x.norm(p=p, dim=dim, keepdim=True).clamp(min=4e-5)
    return x / denominator


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    # The following from PyTorch3d
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4) or (..., 3); a 3-vector is
            treated as having an implicit real part of 1.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if quaternions.shape[-1] == 3:
        # prepend the implicit real component
        ones = torch.ones_like(quaternions[..., 0:1])
        quaternions = torch.cat((ones, quaternions), dim=-1)
    w, x, y, z = torch.unbind(quaternions, -1)
    # normalization factor: 2 / |q|^2
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - two_s * (y * y + z * z),
        two_s * (x * y - z * w),
        two_s * (x * z + y * w),
        two_s * (x * y + z * w),
        1 - two_s * (x * x + z * z),
        two_s * (y * z - x * w),
        two_s * (x * z - y * w),
        two_s * (y * z + x * w),
        1 - two_s * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))


def batch_matrix_vector(matrix: torch.Tensor,
                        vector: torch.Tensor) -> torch.Tensor:
    """
    Batched matrix-vector product on the last dimensions.

    Args:
        matrix (): of shape (*, d, d)
        vector (): of shape (*, d)

    Returns:
        the product of the two, of shape (*, d)

    """
    # batch ranks must line up exactly (broadcasting of values is fine,
    # mismatched ranks are not)
    assert matrix.dim() == vector.dim() + 1

    return torch.einsum('...cd, ...d -> ...c', matrix, vector)


def create_pseudo_beta(atom_pos: torch.Tensor,
                       atom_mask: torch.Tensor) -> torch.Tensor:
    """
    Pick the pseudo-beta coordinate for each residue.

    Args:
        atom_pos: the atom position in atom14 format,
            of shape [*, num_res, 14, 3]
        atom_mask: the atom mask in atom14 format,
            of shape [*, num_res, 14]

    Returns:
        CB coordinate (when available) and CA coordinate (when not available)

    Raises:
        ValueError: if the inputs are not in atom14 format

    """
    if atom_mask.shape[-1] != 14 or atom_pos.shape[-2] != 14:
        raise ValueError('Only supports atom 14')
    # atom14 slot 4 is CB, slot 1 is CA
    has_cb = atom_mask[..., 4:5].expand(
        list(atom_mask.shape[:-1]) + [3]).bool()
    return torch.where(has_cb, atom_pos[..., 4, :], atom_pos[..., 1, :])


def bit_wise_not(boolean_tensor: torch.Tensor) -> torch.Tensor:
    """For MPS devices that have no support for yet bit-wise not"""
    # logical NOT via float arithmetic: True -> 0.0 -> False, and vice versa
    return (1 - boolean_tensor.float()).bool()


# =============================================================================
# Tests
# =============================================================================
if __name__ == '__main__':
pass

+ 686
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/protein_utils/residue_constants.py View File

@@ -0,0 +1,686 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
# This file is adopted from DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Constants used in OmegaFold."""
import Bio.PDB
import torch
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
from Bio.Data import PDBData

ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don"t have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
'ALA': [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'CYS': [['N', 'CA', 'CB', 'SG']],
'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLY': [],
'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
['CB', 'CG', 'SD', 'CE']],
'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
'SER': [['N', 'CA', 'CB', 'OG']],
'THR': [['N', 'CA', 'CB', 'OG1']],
'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'VAL': [['N', 'CA', 'CB', 'CG1']],
}

# If chi angles given in fixed-length array, this matrix determines how to
# torsion_angles_mask them for each AA type. The order is as per
# restype_order (see below).
chi_angles_mask = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER (was mislabeled "SET")
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
])

# Atoms positions relative to the 8 rigid groups, defined by the pre-omega,
# phi, psi and chi angles:
# 0: "backbone group",
# 1: "pre-omega-group", (empty)
# 2: "phi-group", (currently empty, because it defines only hydrogens)
# 3: "psi-group",
# 4,5,6,7: "chi1,2,3,4-group"
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the
# y-axis is defined such that the dihedral-angle-definiting atom (the last
# entry in chi_angles_atoms above) is in the xy-plane (with a positive
# y-coordinate). format: [atomname, group_idx, rel_position]
aa_atom_positions = {
'ALA': [
['N', 0, (-0.525, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.529, -0.774, -1.205)],
['O', 3, (0.627, 1.062, 0.000)],
],
'ARG': [
['N', 0, (-0.524, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.524, -0.778, -1.209)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.616, 1.390, -0.000)],
['CD', 5, (0.564, 1.414, 0.000)],
['NE', 6, (0.539, 1.357, -0.000)],
['NH1', 7, (0.206, 2.301, 0.000)],
['NH2', 7, (2.078, 0.978, -0.000)],
['CZ', 7, (0.758, 1.093, -0.000)],
],
'ASN': [
['N', 0, (-0.536, 1.357, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.531, -0.787, -1.200)],
['O', 3, (0.625, 1.062, 0.000)],
['CG', 4, (0.584, 1.399, 0.000)],
['ND2', 5, (0.593, -1.188, 0.001)],
['OD1', 5, (0.633, 1.059, 0.000)],
],
'ASP': [
['N', 0, (-0.525, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, 0.000, -0.000)],
['CB', 0, (-0.526, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.593, 1.398, -0.000)],
['OD1', 5, (0.610, 1.091, 0.000)],
['OD2', 5, (0.592, -1.101, -0.003)],
],
'CYS': [
['N', 0, (-0.522, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, 0.000)],
['CB', 0, (-0.519, -0.773, -1.212)],
['O', 3, (0.625, 1.062, -0.000)],
['SG', 4, (0.728, 1.653, 0.000)],
],
'GLN': [
['N', 0, (-0.526, 1.361, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.525, -0.779, -1.207)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.615, 1.393, 0.000)],
['CD', 5, (0.587, 1.399, -0.000)],
['NE2', 6, (0.593, -1.189, -0.001)],
['OE1', 6, (0.634, 1.060, 0.000)],
],
'GLU': [
['N', 0, (-0.528, 1.361, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.526, -0.781, -1.207)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.615, 1.392, 0.000)],
['CD', 5, (0.600, 1.397, 0.000)],
['OE1', 6, (0.607, 1.095, -0.000)],
['OE2', 6, (0.589, -1.104, -0.001)],
],
'GLY': [
['N', 0, (-0.572, 1.337, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.517, -0.000, -0.000)],
['O', 3, (0.626, 1.062, -0.000)],
],
'HIS': [
['N', 0, (-0.527, 1.360, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.525, -0.778, -1.208)],
['O', 3, (0.625, 1.063, 0.000)],
['CG', 4, (0.600, 1.370, -0.000)],
['CD2', 5, (0.889, -1.021, 0.003)],
['ND1', 5, (0.744, 1.160, -0.000)],
['CE1', 5, (2.030, 0.851, 0.002)],
['NE2', 5, (2.145, -0.466, 0.004)],
],
'ILE': [
['N', 0, (-0.493, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.536, -0.793, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.534, 1.437, -0.000)],
['CG2', 4, (0.540, -0.785, -1.199)],
['CD1', 5, (0.619, 1.391, 0.000)],
],
'LEU': [
['N', 0, (-0.520, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.522, -0.773, -1.214)],
['O', 3, (0.625, 1.063, -0.000)],
['CG', 4, (0.678, 1.371, 0.000)],
['CD1', 5, (0.530, 1.430, -0.000)],
['CD2', 5, (0.535, -0.774, 1.200)],
],
'LYS': [
['N', 0, (-0.526, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.524, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.619, 1.390, 0.000)],
['CD', 5, (0.559, 1.417, 0.000)],
['CE', 6, (0.560, 1.416, 0.000)],
['NZ', 7, (0.554, 1.387, 0.000)],
],
'MET': [
['N', 0, (-0.521, 1.364, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.210)],
['O', 3, (0.625, 1.062, -0.000)],
['CG', 4, (0.613, 1.391, -0.000)],
['SD', 5, (0.703, 1.695, 0.000)],
['CE', 6, (0.320, 1.786, -0.000)],
],
'PHE': [
['N', 0, (-0.518, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, -0.000)],
['CB', 0, (-0.525, -0.776, -1.212)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.607, 1.377, 0.000)],
['CD1', 5, (0.709, 1.195, -0.000)],
['CD2', 5, (0.706, -1.196, 0.000)],
['CE1', 5, (2.102, 1.198, -0.000)],
['CE2', 5, (2.098, -1.201, -0.000)],
['CZ', 5, (2.794, -0.003, -0.001)],
],
'PRO': [
['N', 0, (-0.566, 1.351, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, 0.000)],
['CB', 0, (-0.546, -0.611, -1.293)],
['O', 3, (0.621, 1.066, 0.000)],
['CG', 4, (0.382, 1.445, 0.0)],
# ["CD", 5, (0.427, 1.440, 0.0)],
['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
'SER': [
['N', 0, (-0.529, 1.360, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.518, -0.777, -1.211)],
['O', 3, (0.626, 1.062, -0.000)],
['OG', 4, (0.503, 1.325, 0.000)],
],
'THR': [
['N', 0, (-0.517, 1.364, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, -0.000)],
['CB', 0, (-0.516, -0.793, -1.215)],
['O', 3, (0.626, 1.062, 0.000)],
['CG2', 4, (0.550, -0.718, -1.228)],
['OG1', 4, (0.472, 1.353, 0.000)],
],
'TRP': [
['N', 0, (-0.521, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.212)],
['O', 3, (0.627, 1.062, 0.000)],
['CG', 4, (0.609, 1.370, -0.000)],
['CD1', 5, (0.824, 1.091, 0.000)],
['CD2', 5, (0.854, -1.148, -0.005)],
['CE2', 5, (2.186, -0.678, -0.007)],
['CE3', 5, (0.622, -2.530, -0.007)],
['NE1', 5, (2.140, 0.690, -0.004)],
['CH2', 5, (3.028, -2.890, -0.013)],
['CZ2', 5, (3.283, -1.543, -0.011)],
['CZ3', 5, (1.715, -3.389, -0.011)],
],
'TYR': [
['N', 0, (-0.522, 1.362, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, -0.000, -0.000)],
['CB', 0, (-0.522, -0.776, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG', 4, (0.607, 1.382, -0.000)],
['CD1', 5, (0.716, 1.195, -0.000)],
['CD2', 5, (0.713, -1.194, -0.001)],
['CE1', 5, (2.107, 1.200, -0.002)],
['CE2', 5, (2.104, -1.201, -0.003)],
['OH', 5, (4.168, -0.002, -0.005)],
['CZ', 5, (2.791, -0.001, -0.003)],
],
'VAL': [
['N', 0, (-0.494, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.533, -0.795, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.540, 1.429, -0.000)],
['CG2', 4, (0.533, -0.776, 1.203)],
],
}

# Convert every atom's idealized position from a plain list to a torch
# tensor, mutating the nested lists of `aa_atom_positions` in place.
# (Fixes a stray trailing comma in the for-target and removes the redundant
# `aa_atom_positions[aa_k] = aa_dict` reassignment: `aa_dict` is the same
# list object already stored in the dict, so writing it back was a no-op.)
for atom_entries in aa_atom_positions.values():
    for entry in atom_entries:
        entry[-1] = torch.tensor(entry[-1])

# This mapping is used when we need to store atom data in a format that
# requires fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
# Inverse lookup: atom name -> its fixed column index.
atom_order = dict(zip(atom_types, range(len(atom_types))))
atom_type_num = len(atom_types)  # := 37.

# A compact atom encoding with 14 columns: each residue type lists its heavy
# atoms in a fixed order; shorter lists are padded with '' up to 14 slots.
_atom14_compact = {
    'ALA': ['N', 'CA', 'C', 'O', 'CB'],
    'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
    'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2'],
    'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
    'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG'],
    'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2'],
    'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
    'GLY': ['N', 'CA', 'C', 'O'],
    'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'],
    'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1'],
    'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2'],
    'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ'],
    'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE'],
    'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2',
            'CZ'],
    'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD'],
    'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG'],
    'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2'],
    'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2',
            'CE3', 'CZ2', 'CZ3', 'CH2'],
    'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2',
            'CZ', 'OH'],
    'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2'],
    'UNK': [],
}
restype_name_to_atom14_names = {
    name: atoms + [''] * (14 - len(atoms))
    for name, atoms in _atom14_compact.items()
}

# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = list('ARNDCQEGHILKMFPSTWYV')
restype_order = {aa: idx for idx, aa in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

# Extended alphabet: 'X' for unknown residues, '-' for gaps (e.g. in MSAs).
restypes_with_x = restypes + ['X', '-']
restype_order_with_x = {
    aa: idx
    for idx, aa in enumerate(restypes_with_x)
}
# One-letter -> three-letter residue code, including 'X' -> 'UNK'.
restype_1to3 = {
    'A': 'ALA', 'R': 'ARG', 'N': 'ASN', 'D': 'ASP', 'C': 'CYS',
    'Q': 'GLN', 'E': 'GLU', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE',
    'L': 'LEU', 'K': 'LYS', 'M': 'MET', 'F': 'PHE', 'P': 'PRO',
    'S': 'SER', 'T': 'THR', 'W': 'TRP', 'Y': 'TYR', 'V': 'VAL',
    'X': 'UNK',
}

# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 inverse of restype_1to3. The Biopython table contains many more, and
# less common, three-letter names as keys and maps several of them onto the
# same one-letter name (including "X" and "U", which we don't use here).
restype_3to1 = {three: one for one, three in restype_1to3.items()}

# restype2atom_mask[restype, slot] is 1.0 where the residue type has an atom
# in that atom14 slot, 0.0 for '' padding slots. Rows for 'X' and '-' in
# restypes_with_x that have no atom14 entry stay all-zero.
restype2atom_mask = torch.zeros([len(restypes_with_x), 14])
for _resname, _atoms14 in restype_name_to_atom14_names.items():
    _row = restype_order_with_x[restype_3to1[_resname]]
    for _slot, _atom in enumerate(_atoms14):
        restype2atom_mask[_row, _slot] = bool(_atom)

# Per-residue-type mask over the 8 rigid groups (0=backbone, 1=pre-omega,
# 2=phi, 3=psi, 4..7=chi1..chi4; see _make_aa_constants below). Groups 0 and
# 3 are marked present for every residue type; chi groups come from
# chi_angles_mask.
restype_rigidgroup_mask = torch.zeros([21, 8], dtype=torch.float)
for _always_present_group in (0, 3):
    restype_rigidgroup_mask[:, _always_present_group] = 1
restype_rigidgroup_mask[:, 4:] = chi_angles_mask


# NOTE: restype_rigidgroup_mask above is the (21, 8) mask of which rigid
# groups exist per residue type.
def residx_to_3(idx):
    """Return the 3-letter residue name for a numeric residue-type index."""
    one_letter = restypes[idx]
    return restype_1to3[one_letter]


# Define a restype name for all unknown residues.
unk_restype = 'UNK'

# All 21 three-letter names in numeric order (standard 20 + UNK last).
resnames = [restype_1to3[letter] for letter in restypes] + [unk_restype]
resname_to_idx = {name: idx for idx, name in enumerate(resnames)}


def get_chi_angle_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue
        types are in the order specified in residue_constants.restypes +
        unknown residue type at the end. For chi angles which are not defined
        on the residue, the positions indices are by default set to 0.
    """
    all_indices = []
    for one_letter in restypes:
        resname = restype_1to3[one_letter]
        per_residue = [[atom_order[atom] for atom in chi_atoms]
                       for chi_atoms in chi_angles_atoms[resname]]
        # Pad chi angles not defined on this residue with zeros.
        per_residue += [[0, 0, 0, 0]] * (4 - len(per_residue))
        all_indices.append(per_residue)

    all_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return torch.tensor(all_indices)


chi_angle_atom_indices = get_chi_angle_atom_indices()


def _make_rigid_transformation_4x4(ex: torch.Tensor, ey: torch.Tensor,
translation: torch.Tensor) -> torch.Tensor:
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / torch.linalg.norm(ex)

# make ey perpendicular to ex
ey_normalized = ey - torch.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= torch.linalg.norm(ey_normalized)

# compute ez as cross product
eznorm = torch.cross(ex_normalized, ey_normalized)
m = torch.stack([ex_normalized, ey_normalized, eznorm, translation]).T
m = torch.cat([m, torch.tensor([[0., 0., 0., 1.]])], dim=0)
return m


# create an array with (restype, atomtype) --> aa_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_aa = torch.zeros(21, 37, dtype=torch.long)
restype_atom37_mask = torch.zeros(21, 37, dtype=torch.float32)
restype_atom37_aa_positions = torch.zeros(21, 37, 3, dtype=torch.float32)
restype_atom14_to_aa = torch.zeros(21, 14, dtype=torch.long)
restype_atom14_mask = torch.zeros(21, 14, dtype=torch.float32)
restype_atom14_aa_positions = torch.zeros(21, 14, 3, dtype=torch.float32)
restype_aa_default_frame = torch.zeros(21, 8, 4, 4, dtype=torch.float32)


def _make_aa_constants():
    """Fill the arrays above.

    First pass: for every standard residue type, record each atom's rigid
    group index, presence mask, and idealized position in both the atom37
    and the atom14 encodings.

    Second pass: build `restype_aa_default_frame`, the 4x4 transform from
    each rigid group (0=backbone, 1=pre-omega, 2=phi, 3=psi, 4..7=chi1..chi4)
    to its parent frame. The UNK row (index 20) is left all-zero.
    """
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_pos in aa_atom_positions[resname]:
            atomtype = atom_order[atomname]
            restype_atom37_to_aa[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_aa_positions[restype, atomtype, :] = atom_pos

            # The same atom also gets an entry in the compact 14-column
            # encoding, at the slot given by this residue's atom14 layout.
            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_aa[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_aa_positions[restype, atom14idx, :] = atom_pos

    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        # Idealized positions of this residue's atoms, keyed by atom name.
        atom_positions = {
            name: pos
            for name, _, pos in aa_atom_positions[resname]
        }

        # backbone to backbone is the identity transforms
        restype_aa_default_frame[restype, 0, :, :] = torch.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_aa_default_frame[restype, 1, :, :] = torch.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(ex=atom_positions['N'] -
                                             atom_positions['CA'],
                                             ey=torch.tensor([1., 0., 0.]),
                                             translation=atom_positions['N'])
        restype_aa_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions['C'] - atom_positions['CA'],
            ey=atom_positions['CA'] - atom_positions['N'],
            translation=atom_positions['C'])
        restype_aa_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [
                atom_positions[name] for name in base_atom_names
            ]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2])
            restype_aa_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=torch.tensor([-1., 0., 0.]),
                    translation=axis_end_atom_position)
                restype_aa_default_frame[restype, 4 + chi_idx, :, :] = mat


_make_aa_constants()
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14

for rt in restypes:
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]

restype_atom14_to_atom37.append([(atom_order[name] if name else 0)
for name in atom_names])

atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in atom_types
])

# Add dummy mapping for restype "UNK"
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)

restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37,
dtype=torch.long)
restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14,
dtype=torch.long)
# chi_pi_periodic[restype][chi] == 1.0 flags chi angles that are periodic
# with period pi rather than 2*pi: ASP chi2, GLU chi3, PHE chi2, TYR chi2 —
# the same residues whose terminal atoms appear in
# residue_atom_renaming_swaps below. Rows follow the `restypes` order with
# UNK appended last.
chi_pi_periodic = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
])

# Atom-name pairs that are interchangeable within a residue (ambiguous
# naming in symmetric terminal groups); used below to build the ambiguity
# mask and the renaming permutation matrices.
residue_atom_renaming_swaps = {
    'ASP': {'OD1': 'OD2'},
    'GLU': {'OE1': 'OE2'},
    'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
    'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
}

# Create an ambiguous atoms mask. shape: (21, 14).
# True in both atom14 slots of every interchangeable atom pair.
mask_ambiguous = torch.zeros((21, 14), dtype=torch.bool)
for _resname, _swaps in residue_atom_renaming_swaps.items():
    _row = restype_order[restype_3to1[_resname]]
    _names14 = restype_name_to_atom14_names[_resname]
    for _atom_a, _atom_b in _swaps.items():
        mask_ambiguous[_row, _names14.index(_atom_a)] = True
        mask_ambiguous[_row, _names14.index(_atom_b)] = True

restype_3 = [restype_1to3[res] for res in restypes]
restype_3 += ['UNK']

# Start from the identity (no renaming) for every residue type, then
# overwrite the residues with ambiguous atoms by the permutation matrix that
# swaps the interchangeable atom14 slots.
all_matrices = {res: torch.eye(14, dtype=torch.float32) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
    correspondences = torch.arange(14)
    names14 = restype_name_to_atom14_names[resname]
    for source_atom_swap, target_atom_swap in swap.items():
        source_index = names14.index(source_atom_swap)
        target_index = names14.index(target_atom_swap)
        correspondences[source_index] = target_index
        correspondences[target_index] = source_index
    renaming_matrix = torch.zeros((14, 14), dtype=torch.float32)
    for row, col in enumerate(correspondences):
        renaming_matrix[row, col] = 1.
    all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
    [all_matrices[restype] for restype in restype_3], dim=0)


def substitute(res: str):
    """Normalize a PDB residue name to one of the 21 canonical 3-letter names.

    Returns the name unchanged when it is already canonical; otherwise maps
    other amino-acid codes through Biopython's 3-to-1 table and back, with
    'X' becoming 'UNK'. Returns None when the input is not an amino acid or
    cannot be resolved.
    """
    if not Bio.PDB.is_aa(res):
        # Not an amino acid at all.
        return None
    if res in resnames:
        return res
    one_letter = PDBData.protein_letters_3to1[res]
    if one_letter in restype_1to3:
        return restype_1to3[one_letter]
    if one_letter == 'X':
        return 'UNK'
    # did not get anything that works
    return None

+ 147
- 0
opencompass/datasets/SciReasoner/unconditional_protein_generation/omegafold/utils/torch_utils.py View File

@@ -0,0 +1,147 @@
# =============================================================================
# Copyright 2022 HeliXon Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""
PyTorch utilities
"""
# =============================================================================
# Imports
# =============================================================================
import numbers
import typing

import torch
from torch.nn import functional as F

# =============================================================================
# Constants
# =============================================================================

T = typing.TypeVar('T')


# =============================================================================
# Functions
# =============================================================================
def mask2bias(mask: torch.Tensor, *, inf: float = 1e9) -> torch.Tensor:
    """Convert mask to attention bias

    Args:
        mask: the mask to convert to bias representation
        inf: the floating point number to represent infinity

    Returns:
        bias representation for masking in attention: 0 where the mask is 1
        and -inf where the mask is 0

    """
    return (mask.float() - 1.0) * inf


def normalize(inputs: torch.Tensor,
              normalized_shape: typing.Optional[typing.Union[
                  int, typing.List[int], torch.Size]] = None,
              in_place: bool = False) -> torch.Tensor:
    """Layer normalization without a module (and weight)

    Args:
        inputs: the input tensor to be normalized
        normalized_shape: the trailing shape over which to normalize;
            defaults to the last dimension of ``inputs``
        in_place: if to perform the operations in-place

    Returns:
        normalized tensor

    """
    if normalized_shape is None:
        normalized_shape = inputs.shape[-1]
    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape = (normalized_shape, )

    if in_place:
        dim = list(range(len(inputs.shape))[-len(normalized_shape):])
        inputs -= inputs.mean(dim=dim, keepdim=True)
        # Use the biased (population) variance so the in-place path matches
        # F.layer_norm exactly; Tensor.var defaults to the unbiased estimate,
        # which made the two paths disagree.
        inputs *= torch.rsqrt(
            inputs.var(dim=dim, unbiased=False, keepdim=True) + 1e-5)
        return inputs
    else:
        # F.layer_norm seems a bit faster
        return F.layer_norm(inputs, normalized_shape, None, None, 1e-5)


def masked_mean(values: torch.Tensor,
                mask: torch.Tensor,
                dim: typing.Union[int, typing.Sequence[int], None],
                keepdim: typing.Optional[bool] = False,
                eps: typing.Optional[float] = 4e-5) -> torch.Tensor:
    """Mean operation with mask

    Args:
        values: the values to take the mean for
        mask: the mask to take the mean with
        dim: the dimension along which to take the mean
        keepdim: to keep the dimension
        eps: the epsilon to compute mean for

    Returns:
        mean result

    """
    # Zero out masked entries, then divide the sum by the (eps-padded) count.
    zeroed = values.masked_fill(~mask.bool(), 0)
    total = zeroed.sum(dim, keepdim=keepdim)
    count = mask.sum(dim, keepdim=keepdim, dtype=values.dtype)
    return total / (count + eps)


def recursive_to(obj: typing.Any, **kwargs) -> typing.Any:
    r"""
    Just to move things to space
    *args is removed because it brings problems in using .cpu()

    Args:
        obj (): the object to move
        kwargs (): different keyword arguments

    Returns:
        cuda tensors in its original construct

    """
    if isinstance(obj, torch.Tensor):
        try:
            return obj.to(**kwargs)
        except RuntimeError:
            # Retry without `non_blocking` (some targets reject it). Use a
            # pop default so a RuntimeError raised for any other reason is
            # not masked by a KeyError when the key was never passed.
            kwargs.pop('non_blocking', None)
            return obj.to(**kwargs)
    elif isinstance(obj, list):
        return [recursive_to(o, **kwargs) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(recursive_to(o, **kwargs) for o in obj)
    elif isinstance(obj, set):
        return set(recursive_to(o, **kwargs) for o in obj)
    elif isinstance(obj, dict):
        return {k: recursive_to(v, **kwargs) for k, v in obj.items()}
    elif hasattr(obj, 'to'):
        # this takes care of classes that implements the ~to method
        return obj.to(**kwargs)
    else:
        return obj


# =============================================================================
# Classes
# =============================================================================
# =============================================================================
# Tests
# =============================================================================
# Import-only module: nothing to run when executed as a script.
if __name__ == '__main__':
    pass

+ 1
- 0
opencompass/datasets/__init__.py View File

@@ -156,6 +156,7 @@ from .scibench import ScibenchDataset, scibench_postprocess # noqa: F401, F403
from .scicode import * # noqa: F401, F403
from .SciEval import SciEvalDataset # noqa: F401
from .SciKnowEval import * # noqa: F401, F403
from .SciReasoner import * # noqa: F401, F403
from .SeedBench import * # noqa: F401, F403
from .simpleqa import * # noqa: F401, F403
from .siqa import * # noqa: F401, F403


+ 23
- 13
opencompass/models/openai_api.py View File

@@ -821,7 +821,6 @@ class OpenAISDKRollout(OpenAI):
think_tag: str = '</think>',
max_workers: Optional[int] = None,
openai_extra_kwargs: Dict | None = None,
dump_rollout_inf: bool = False,
):
super().__init__(
path,
@@ -931,18 +930,28 @@ class OpenAISDKRollout(OpenAI):
self.logger.info('Start calling OpenAI API')

responses = self.openai_client.chat.completions.create(
**query_data, timeout=timeout,
logprobs=True) # timeout in seconds
**query_data,
timeout=timeout,
logprobs=True,
top_logprobs=self.top_logprobs) # timeout in seconds

if not responses.choices[0].logprobs or not responses.choices[
0].logprobs.content:
token_logprobs = None
token_logprobs = [
c.logprob for c in responses.choices[0].logprobs.content
]
sum_neg_logprob = -float(sum(token_logprobs))
num_tokens = len(token_logprobs)
finish_reason = responses.choices[0].finish_reason
sum_neg_logprob = 0.0
num_tokens = 0
else:
token_logprobs = [
c.logprob
for c in responses.choices[0].logprobs.content
]
sum_neg_logprob = -float(sum(token_logprobs))
num_tokens = len(token_logprobs)

if not responses.choices[0].finish_reason:
finish_reason = 'error'
else:
finish_reason = responses.choices[0].finish_reason
rollout = dict(
token_logprobs=token_logprobs,
sum_neg_logprob=sum_neg_logprob,
@@ -1006,14 +1015,15 @@ class OpenAISDKRollout(OpenAI):
content,
)
if content:
return dict(output=reasoning_content + self.think_tag +
content,
return dict(prediction=reasoning_content +
self.think_tag + content,
rollout=rollout)
else:
return dict(output=reasoning_content, rollout=rollout)
return dict(prediction=reasoning_content,
rollout=rollout)

else:
return dict(output=content, rollout=rollout)
return dict(prediction=content, rollout=rollout)

except (BadRequestError, APIStatusError) as e:
# Handle BadRequest status


+ 7
- 4
opencompass/openicl/icl_inferencer/icl_gen_inferencer.py View File

@@ -1,5 +1,5 @@
"""Direct Generation Inferencer."""
import copy
import inspect
import json
import os
@@ -175,12 +175,15 @@ class GenInferencer(BaseInferencer):
prompt[i]['prompt'])
input_length += prompt[i]['input_length']

pred_str = copy.deepcopy(prediction)
if isinstance(pred_str, dict):
pred_str = pred_str['prediction']

if num_return_sequences == 1:
res_length = self.model.get_token_len(prediction)
res_length = self.model.get_token_len(pred_str)
else:
res_length = [
self.model.get_token_len(pred)
for pred in prediction
self.model.get_token_len(pred) for pred in pred_str
]
output_handler.save_results(prompt,
prediction,


+ 9
- 3
opencompass/tasks/openicl_eval.py View File

@@ -47,7 +47,12 @@ class OpenICLEvalTask(BaseTask):
'judge_cfg', {}).get('run_cfg', {}).get('num_gpus', 0),
c.get('eval_cfg', {}).get('evaluator', {}).get(
'llm_evaluator', {}).get('judge_cfg', {}).get(
'run_cfg', {}).get('num_gpus', 0))
'run_cfg', {}).get('num_gpus', 0),
next(
iter(
c.get('eval_cfg', {}).get('evaluator', {}).get(
'judge_cfg', {}).get('judgers', [])), {}).get(
'run_cfg', {}).get('num_gpus', 0))
for c in sum(self.dataset_cfgs, []))
self.num_procs = max(
c.get('eval_cfg', {}).get('evaluator', {}).get(
@@ -130,13 +135,14 @@ class OpenICLEvalTask(BaseTask):
test_set = build_dataset_from_cfg(self.dataset_cfg).test
# Postprocess dataset if necessary
if 'dataset_postprocessor' in self.eval_cfg:
proc = self.eval_cfg['dataset_postprocessor']['type']
kwargs = copy.deepcopy(self.eval_cfg['dataset_postprocessor'])
proc = kwargs.pop('type')
if isinstance(proc, str):
proc = TEXT_POSTPROCESSORS.get(proc)

def postprocess(sample):
s = sample[self.output_column]
sample[self.output_column] = proc(s)
sample[self.output_column] = proc(s, **kwargs)
return sample

test_set = test_set.map(postprocess)


+ 42
- 0
opencompass/utils/datasets_info.py View File

@@ -542,6 +542,48 @@ DATASETS_MAPPING = {
"hf_id": "",
"local": "./data/phybench",
},

# SciReasoner
"opencompass/SciReasoner-bio_instruction":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/bio_instruction",
},
"opencompass/SciReasoner-Conditional_generation":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/Conditional_generation",
},
"opencompass/SciReasoner-GUE":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/GUE-test",
},
"opencompass/SciReasoner-LLM4Mat":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/LLM4Mat-test",
},
"opencompass/SciReasoner-Mol_Instructions":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/Mol-Instructions-test",
},
"opencompass/SciReasoner-OPI":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/OPI_test",
},
"opencompass/SciReasoner-PEER":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/PEER-test",
},
"opencompass/SciReasoner-smol":{
"ms_id": "",
"hf_id": "",
"local": "./data/SciReasoner/smol-test",
},
}

DATASETS_URL = {


+ 7
- 0
requirements/extra.txt View File

@@ -2,6 +2,9 @@
alpaca-eval==0.6
# OlympiadBench
antlr4-python3-runtime==4.11
# UPG
Bio
# OlympiadBench
cn2an
# Dingo
dingo-python==1.5.0
@@ -23,10 +26,14 @@ pint
pyext
# Law Bench
pypinyin
# LLM4Chem
rdchiral
# Smolinstruct
rdkit
# Molinstructions
selfies
# SciReasoner Composition Material
smact
# IFBench
syllapy
# RULER


Loading…
Cancel
Save
Baidu
map