Spaces:
Sleeping
Sleeping
Upload 162 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- __init__.py +3 -0
- __pycache__/get_molecular_agent.cpython-310.pyc +0 -0
- __pycache__/get_reaction_agent.cpython-310.pyc +0 -0
- __pycache__/main.cpython-310.pyc +0 -0
- app.ipynb +295 -0
- app.py +239 -0
- chemiener/__init__.py +1 -0
- chemiener/__pycache__/__init__.cpython-310.pyc +0 -0
- chemiener/__pycache__/__init__.cpython-38.pyc +0 -0
- chemiener/__pycache__/dataset.cpython-310.pyc +0 -0
- chemiener/__pycache__/dataset.cpython-38.pyc +0 -0
- chemiener/__pycache__/interface.cpython-310.pyc +0 -0
- chemiener/__pycache__/interface.cpython-38.pyc +0 -0
- chemiener/__pycache__/model.cpython-310.pyc +0 -0
- chemiener/__pycache__/model.cpython-38.pyc +0 -0
- chemiener/__pycache__/utils.cpython-310.pyc +0 -0
- chemiener/__pycache__/utils.cpython-38.pyc +0 -0
- chemiener/dataset.py +172 -0
- chemiener/interface.py +124 -0
- chemiener/main.py +345 -0
- chemiener/model.py +14 -0
- chemiener/utils.py +23 -0
- chemietoolkit/__init__.py +1 -0
- chemietoolkit/__pycache__/__init__.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/__init__.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/interface.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/interface.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/tableextractor.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/utils.cpython-310.pyc +0 -0
- chemietoolkit/chemrxnextractor.py +107 -0
- chemietoolkit/interface.py +749 -0
- chemietoolkit/tableextractor.py +340 -0
- chemietoolkit/utils.py +1018 -0
- examples/exp.png +3 -0
- examples/image.webp +0 -0
- examples/rdkit.png +0 -0
- examples/reaction1.jpg +0 -0
- examples/reaction2.png +0 -0
- examples/reaction3.png +0 -0
- examples/reaction4.png +3 -0
- get_molecular_agent.py +599 -0
- get_reaction_agent.py +507 -0
- main.py +546 -0
- main_Rgroup_debug.ipynb +993 -0
- molscribe/__init__.py +1 -0
- molscribe/__pycache__/__init__.cpython-310.pyc +0 -0
- molscribe/__pycache__/augment.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/exp.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/reaction4.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
molscribe/indigo/lib/Linux/x64/libbingo.so filter=lfs diff=lfs merge=lfs -text
|
39 |
+
molscribe/indigo/lib/Linux/x64/libindigo-renderer.so filter=lfs diff=lfs merge=lfs -text
|
40 |
+
molscribe/indigo/lib/Linux/x64/libindigo.so filter=lfs diff=lfs merge=lfs -text
|
__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.1.0"
|
2 |
+
__author__ = 'Alex Wang'
|
3 |
+
__credits__ = 'CSAIL'
|
__pycache__/get_molecular_agent.cpython-310.pyc
ADDED
Binary file (9.08 kB). View file
|
|
__pycache__/get_reaction_agent.cpython-310.pyc
ADDED
Binary file (6.94 kB). View file
|
|
__pycache__/main.cpython-310.pyc
ADDED
Binary file (8.97 kB). View file
|
|
app.ipynb
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "d13d3631",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"* Running on local URL: http://127.0.0.1:7866\n",
|
14 |
+
"\n",
|
15 |
+
"To create a public link, set `share=True` in `launch()`.\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"data": {
|
20 |
+
"text/html": [
|
21 |
+
"<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
22 |
+
],
|
23 |
+
"text/plain": [
|
24 |
+
"<IPython.core.display.HTML object>"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
"metadata": {},
|
28 |
+
"output_type": "display_data"
|
29 |
+
}
|
30 |
+
],
|
31 |
+
"source": [
|
32 |
+
"import os\n",
|
33 |
+
"import gradio as gr\n",
|
34 |
+
"import json\n",
|
35 |
+
"from main import ChemEagle # 支持 API key 通过环境变量\n",
|
36 |
+
"from rdkit import Chem\n",
|
37 |
+
"from rdkit.Chem import rdChemReactions\n",
|
38 |
+
"from rdkit.Chem import Draw\n",
|
39 |
+
"from rdkit.Chem import AllChem\n",
|
40 |
+
"from rdkit.Chem.Draw import rdMolDraw2D\n",
|
41 |
+
"import cairosvg\n",
|
42 |
+
"import re\n",
|
43 |
+
"import torch\n",
|
44 |
+
"\n",
|
45 |
+
"example_diagram = \"examples/exp.png\"\n",
|
46 |
+
"rdkit_image = \"examples/rdkit.png\"\n",
|
47 |
+
"# 解析 ChemEagle 返回的结构化数据\n",
|
48 |
+
"def parse_reactions(output_json):\n",
|
49 |
+
" \"\"\"\n",
|
50 |
+
" 解析 JSON 格式的反应数据并格式化输出,包含颜色定制。\n",
|
51 |
+
" \"\"\"\n",
|
52 |
+
" if isinstance(output_json, str):\n",
|
53 |
+
" reactions_data = json.loads(output_json)\n",
|
54 |
+
" elif isinstance(output_json, dict):\n",
|
55 |
+
" reactions_data = output_json # 转换 JSON 字符串为字典\n",
|
56 |
+
" reactions_list = reactions_data.get(\"reactions\", [])\n",
|
57 |
+
" detailed_output = []\n",
|
58 |
+
" smiles_output = [] \n",
|
59 |
+
"\n",
|
60 |
+
" for reaction in reactions_list:\n",
|
61 |
+
" reaction_id = reaction.get(\"reaction_id\", \"Unknown ID\")\n",
|
62 |
+
" reactants = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"reactants\", [])]\n",
|
63 |
+
" conditions = [\n",
|
64 |
+
" f\"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
|
65 |
+
" for c in reaction.get(\"condition\", [])\n",
|
66 |
+
" ]\n",
|
67 |
+
" conditions_1 = [\n",
|
68 |
+
" f\"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
|
69 |
+
" for c in reaction.get(\"condition\", [])\n",
|
70 |
+
" ]\n",
|
71 |
+
" products = [f\"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
|
72 |
+
" products_1 = [f\"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
|
73 |
+
" products_2 = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"products\", [])]\n",
|
74 |
+
" \n",
|
75 |
+
" additional = reaction.get(\"additional_info\", [])\n",
|
76 |
+
" additional_str = [str(x) for x in additional if x is not None]\n",
|
77 |
+
"\n",
|
78 |
+
" tail = conditions_1 + additional_str\n",
|
79 |
+
" tail_str = \", \".join(tail)\n",
|
80 |
+
"\n",
|
81 |
+
" # 构造反应的完整字符串,定制字体颜色\n",
|
82 |
+
" full_reaction = f\"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}\"\n",
|
83 |
+
" full_reaction = f\"<span style='color:black'>{full_reaction}</span>\"\n",
|
84 |
+
" \n",
|
85 |
+
" # 详细反应格式化输出\n",
|
86 |
+
" reaction_output = f\"<b>Reaction: </b> {reaction_id}<br>\"\n",
|
87 |
+
" reaction_output += f\" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>\"\n",
|
88 |
+
" reaction_output += f\" Conditions: {', '.join(conditions)}<br>\"\n",
|
89 |
+
" reaction_output += f\" Products: {', '.join(products)}<br>\"\n",
|
90 |
+
" reaction_output += f\" additional_info: {', '.join(additional_str)}<br>\"\n",
|
91 |
+
" reaction_output += f\" <b>Full Reaction:</b> {full_reaction}<br>\"\n",
|
92 |
+
" reaction_output += \"<br>\"\n",
|
93 |
+
" detailed_output.append(reaction_output)\n",
|
94 |
+
"\n",
|
95 |
+
" reaction_smiles = f\"{'.'.join(reactants)}>>{'.'.join(products_2)}\"\n",
|
96 |
+
" smiles_output.append(reaction_smiles)\n",
|
97 |
+
" return detailed_output, smiles_output\n",
|
98 |
+
"\n",
|
99 |
+
"\n",
|
100 |
+
"# 核心处理函数,仅使用 API Key 和图像\n",
|
101 |
+
"def process_chem_image(api_key, image):\n",
|
102 |
+
" # 设置 API Key 环境变量,供 ChemEagle 使用\n",
|
103 |
+
" os.environ[\"CHEMEAGLE_API_KEY\"] = api_key\n",
|
104 |
+
"\n",
|
105 |
+
" # 保存上传图片\n",
|
106 |
+
" image_path = \"temp_image.png\"\n",
|
107 |
+
" image.save(image_path)\n",
|
108 |
+
"\n",
|
109 |
+
" # 调用 ChemEagle(实现内部读取 os.getenv)\n",
|
110 |
+
" chemeagle_result = ChemEagle(image_path)\n",
|
111 |
+
"\n",
|
112 |
+
" # 解析输出\n",
|
113 |
+
" detailed, smiles = parse_reactions(chemeagle_result)\n",
|
114 |
+
"\n",
|
115 |
+
" # 写出 JSON\n",
|
116 |
+
" json_path = \"output.json\"\n",
|
117 |
+
" with open(json_path, 'w') as jf:\n",
|
118 |
+
" json.dump(chemeagle_result, jf, indent=2)\n",
|
119 |
+
"\n",
|
120 |
+
" # 返回 HTML、SMILES 合并文本、示意图、JSON 下载\n",
|
121 |
+
" return \"\\n\\n\".join(detailed), smiles, example_diagram, json_path\n",
|
122 |
+
"\n",
|
123 |
+
"# 构建 Gradio 界面\n",
|
124 |
+
"with gr.Blocks() as demo:\n",
|
125 |
+
" gr.Markdown(\n",
|
126 |
+
" \"\"\"\n",
|
127 |
+
" <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>\n",
|
128 |
+
" Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.\n",
|
129 |
+
" \"\"\"\n",
|
130 |
+
" )\n",
|
131 |
+
"\n",
|
132 |
+
" with gr.Row():\n",
|
133 |
+
" # ———— 左侧:上传 + API Key + 按钮 ————\n",
|
134 |
+
" with gr.Column(scale=1):\n",
|
135 |
+
" image_input = gr.Image(type=\"pil\", label=\"Upload a multimodal reaction image\")\n",
|
136 |
+
" api_key_input = gr.Textbox(\n",
|
137 |
+
" label=\"Your API-Key\",\n",
|
138 |
+
" placeholder=\"Type your OpenAI_API_KEY\",\n",
|
139 |
+
" type=\"password\"\n",
|
140 |
+
" )\n",
|
141 |
+
" with gr.Row():\n",
|
142 |
+
" clear_btn = gr.Button(\"Clear\")\n",
|
143 |
+
" run_btn = gr.Button(\"Run\", elem_id=\"submit-btn\")\n",
|
144 |
+
"\n",
|
145 |
+
" # ———— 中间:解析结果 + 示意图 ————\n",
|
146 |
+
" with gr.Column(scale=1):\n",
|
147 |
+
" gr.Markdown(\"### Parsed Reactions\")\n",
|
148 |
+
" reaction_output = gr.HTML(label=\"Detailed Reaction Output\")\n",
|
149 |
+
" gr.Markdown(\"### Schematic Diagram\")\n",
|
150 |
+
" schematic_diagram = gr.Image(value=example_diagram, label=\"示意图\")\n",
|
151 |
+
"\n",
|
152 |
+
" # ———— 右侧:SMILES 拆分 & RDKit 渲染 + JSON 下载 ————\n",
|
153 |
+
" with gr.Column(scale=1):\n",
|
154 |
+
" gr.Markdown(\"### Machine-readable Output\")\n",
|
155 |
+
" smiles_output = gr.Textbox(\n",
|
156 |
+
" label=\"Reaction SMILES\",\n",
|
157 |
+
" show_copy_button=True,\n",
|
158 |
+
" interactive=False,\n",
|
159 |
+
" visible=False\n",
|
160 |
+
" )\n",
|
161 |
+
"\n",
|
162 |
+
" @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑\n",
|
163 |
+
" def show_split(inputs): # 定义处理和展示分割文本的函数\n",
|
164 |
+
" if not inputs or isinstance(inputs, str) and inputs.strip() == \"\": # 检查输入文本是否为空\n",
|
165 |
+
" return gr.Textbox(label= \"SMILES of Reaction i\"), gr.Image(value=rdkit_image, label= \"RDKit Image of Reaction i\",height=100)\n",
|
166 |
+
" else:\n",
|
167 |
+
" # 假设输入是逗号分隔的 SMILES 字符串\n",
|
168 |
+
" smiles_list = inputs.split(\",\")\n",
|
169 |
+
" smiles_list = [re.sub(r\"^\\s*\\[?'?|'\\]?\\s*$\", \"\", item) for item in smiles_list]\n",
|
170 |
+
" components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件\n",
|
171 |
+
" for i, smiles in enumerate(smiles_list): \n",
|
172 |
+
" smiles.replace('\"', '').replace(\"'\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\n",
|
173 |
+
" rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)\n",
|
174 |
+
" \n",
|
175 |
+
" if rxn:\n",
|
176 |
+
"\n",
|
177 |
+
" new_rxn = AllChem.ChemicalReaction()\t\n",
|
178 |
+
" for mol in rxn.GetReactants():\n",
|
179 |
+
" mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
|
180 |
+
" new_rxn.AddReactantTemplate(mol)\n",
|
181 |
+
" for mol in rxn.GetProducts():\n",
|
182 |
+
" mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
|
183 |
+
" new_rxn.AddProductTemplate(mol)\n",
|
184 |
+
"\n",
|
185 |
+
" rxn = new_rxn\n",
|
186 |
+
"\n",
|
187 |
+
" def atom_mapping_remover(rxn):\n",
|
188 |
+
" for reactant in rxn.GetReactants():\n",
|
189 |
+
" for atom in reactant.GetAtoms():\n",
|
190 |
+
" atom.SetAtomMapNum(0)\n",
|
191 |
+
" for product in rxn.GetProducts():\n",
|
192 |
+
" for atom in product.GetAtoms():\n",
|
193 |
+
" atom.SetAtomMapNum(0)\n",
|
194 |
+
" return rxn\n",
|
195 |
+
" \n",
|
196 |
+
" atom_mapping_remover(rxn)\n",
|
197 |
+
"\n",
|
198 |
+
" reactant1 = rxn.GetReactantTemplate(0)\n",
|
199 |
+
" print(reactant1.GetNumBonds)\n",
|
200 |
+
" reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None\n",
|
201 |
+
"\n",
|
202 |
+
" if reactant1.GetNumBonds() > 0:\n",
|
203 |
+
" bond_length_reference = Draw.MeanBondLength(reactant1)\n",
|
204 |
+
" elif reactant2 and reactant2.GetNumBonds() > 0:\n",
|
205 |
+
" bond_length_reference = Draw.MeanBondLength(reactant2)\n",
|
206 |
+
" else:\n",
|
207 |
+
" bond_length_reference = 1.0 \n",
|
208 |
+
"\n",
|
209 |
+
"\n",
|
210 |
+
" drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)\n",
|
211 |
+
" dopts = drawer.drawOptions()\n",
|
212 |
+
" dopts.padding = 0.1 \n",
|
213 |
+
" dopts.includeRadicals = True\n",
|
214 |
+
" Draw.SetACS1996Mode(dopts, bond_length_reference*0.55)\n",
|
215 |
+
" dopts.bondLineWidth = 1.5\n",
|
216 |
+
" drawer.DrawReaction(rxn)\n",
|
217 |
+
" drawer.FinishDrawing()\n",
|
218 |
+
" svg_content = drawer.GetDrawingText()\n",
|
219 |
+
" svg_file = f\"reaction{i+1}.svg\"\n",
|
220 |
+
" with open(svg_file, \"w\") as f:\n",
|
221 |
+
" f.write(svg_content)\n",
|
222 |
+
" png_file = f\"reaction_{i+1}.png\"\n",
|
223 |
+
" cairosvg.svg2png(url=svg_file, write_to=png_file)\n",
|
224 |
+
"\n",
|
225 |
+
"\n",
|
226 |
+
" \n",
|
227 |
+
" components.append(gr.Textbox(value=smiles,label= f\"SMILES of Reaction {i}\", show_copy_button=True, interactive=False))\n",
|
228 |
+
" components.append(gr.Image(value=png_file,label= f\"RDKit Image of Reaction {i}\")) \n",
|
229 |
+
" return components # 返回包含所有 SMILES Textbox 组件的列表\n",
|
230 |
+
"\n",
|
231 |
+
" download_json = gr.File(label=\"Download JSON File\")\n",
|
232 |
+
"\n",
|
233 |
+
"\n",
|
234 |
+
" gr.Examples(\n",
|
235 |
+
" examples=[\n",
|
236 |
+
" [\"examples/reaction1.jpg\", \"\"],\n",
|
237 |
+
" [\"examples/reaction2.png\", \"\"],\n",
|
238 |
+
" [\"examples/reaction3.png\", \"\"],\n",
|
239 |
+
" [\"examples/reaction4.png\", \"\"],\n",
|
240 |
+
" \n",
|
241 |
+
" \n",
|
242 |
+
" ],\n",
|
243 |
+
" inputs=[image_input, api_key_input],\n",
|
244 |
+
" outputs=[reaction_output, smiles_output, schematic_diagram, download_json],\n",
|
245 |
+
" cache_examples=False,\n",
|
246 |
+
" examples_per_page=4,\n",
|
247 |
+
" )\n",
|
248 |
+
"\n",
|
249 |
+
" # ———— 清空与运行 绑定 ————\n",
|
250 |
+
" clear_btn.click(\n",
|
251 |
+
" lambda: (None, None, None, None, None),\n",
|
252 |
+
" inputs=[],\n",
|
253 |
+
" outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]\n",
|
254 |
+
" )\n",
|
255 |
+
" run_btn.click(\n",
|
256 |
+
" process_chem_image,\n",
|
257 |
+
" inputs=[api_key_input, image_input],\n",
|
258 |
+
" outputs=[reaction_output, smiles_output, schematic_diagram, download_json]\n",
|
259 |
+
" )\n",
|
260 |
+
"\n",
|
261 |
+
" # 自定义按钮样式\n",
|
262 |
+
" demo.css = \"\"\"\n",
|
263 |
+
" #submit-btn {\n",
|
264 |
+
" background-color: #FF914D;\n",
|
265 |
+
" color: white;\n",
|
266 |
+
" font-weight: bold;\n",
|
267 |
+
" }\n",
|
268 |
+
" \"\"\"\n",
|
269 |
+
"\n",
|
270 |
+
" demo.launch()"
|
271 |
+
]
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"metadata": {
|
275 |
+
"kernelspec": {
|
276 |
+
"display_name": "openchemie",
|
277 |
+
"language": "python",
|
278 |
+
"name": "python3"
|
279 |
+
},
|
280 |
+
"language_info": {
|
281 |
+
"codemirror_mode": {
|
282 |
+
"name": "ipython",
|
283 |
+
"version": 3
|
284 |
+
},
|
285 |
+
"file_extension": ".py",
|
286 |
+
"mimetype": "text/x-python",
|
287 |
+
"name": "python",
|
288 |
+
"nbconvert_exporter": "python",
|
289 |
+
"pygments_lexer": "ipython3",
|
290 |
+
"version": "3.10.14"
|
291 |
+
}
|
292 |
+
},
|
293 |
+
"nbformat": 4,
|
294 |
+
"nbformat_minor": 5
|
295 |
+
}
|
app.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
from main import ChemEagle # 支持 API key 通过环境变量
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem import rdChemReactions
|
7 |
+
from rdkit.Chem import Draw
|
8 |
+
from rdkit.Chem import AllChem
|
9 |
+
from rdkit.Chem.Draw import rdMolDraw2D
|
10 |
+
import cairosvg
|
11 |
+
import re
|
12 |
+
import torch
|
13 |
+
|
14 |
+
example_diagram = "examples/exp.png"
|
15 |
+
rdkit_image = "examples/rdkit.png"
|
16 |
+
# Parse the structured data returned by ChemEagle
|
17 |
+
def parse_reactions(output_json):
    """Format ChemEagle's structured reaction data for display.

    Args:
        output_json: Either a JSON string or an already-decoded ``dict``.
            Reactions are read from its ``"reactions"`` key; each reaction
            may carry ``"reaction_id"``, ``"reactants"``, ``"condition"``,
            ``"products"`` and ``"additional_info"``.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple. ``detailed_output`` is
        a list with one colour-coded HTML fragment per reaction;
        ``smiles_output`` is a list with one plain
        ``"reactants>>products"`` string per reaction.

    Raises:
        TypeError: If the (decoded) input is not a ``dict``.
        json.JSONDecodeError: If a string input is not valid JSON.
    """
    import html  # local import so the escaping fix needs no file-level change

    if isinstance(output_json, str):
        reactions_data = json.loads(output_json)
    else:
        reactions_data = output_json
    if not isinstance(reactions_data, dict):
        # Previously this path fell through and failed on an unbound local.
        raise TypeError(
            "parse_reactions expects a JSON string or a dict, got "
            f"{type(output_json).__name__}"
        )

    def text_of(entry, *keys):
        """Return the first non-None value among *keys* as str, else 'Unknown'."""
        for key in keys:
            value = entry.get(key)
            if value is not None:
                return str(value)
        return "Unknown"

    def span(colour, content):
        """Wrap already-escaped *content* in a coloured <span>."""
        return f"<span style='color:{colour}'>{content}</span>"

    detailed_output = []
    smiles_output = []

    # ``or []`` tolerates keys that are present but null in the JSON.
    for reaction in reactions_data.get("reactions") or []:
        reaction_id = html.escape(str(reaction.get("reaction_id", "Unknown ID")))

        # Raw values feed the machine-readable SMILES; escaped copies feed the HTML.
        reactants = [text_of(r, "smiles") for r in reaction.get("reactants") or []]
        products_2 = [text_of(p, "smiles") for p in reaction.get("products") or []]
        shown_reactants = [html.escape(s) for s in reactants]
        shown_products = [html.escape(s) for s in products_2]

        condition_labels = [
            f"{html.escape(text_of(c, 'smiles', 'text'))}[{html.escape(text_of(c, 'role'))}]"
            for c in reaction.get("condition") or []
        ]
        conditions = [span("red", label) for label in condition_labels]
        conditions_1 = [span("black", label) for label in condition_labels]
        products = [span("orange", s) for s in shown_products]
        products_1 = [span("black", s) for s in shown_products]

        additional_str = [
            html.escape(str(x))
            for x in reaction.get("additional_info") or []
            if x is not None
        ]

        tail_str = ", ".join(conditions_1 + additional_str)

        # Full reaction line: reactants>>products | conditions and extra info.
        full_reaction = f"{'.'.join(shown_reactants)}>>{'.'.join(products_1)} | {tail_str}"
        full_reaction = span("black", full_reaction)

        # Detailed, colour-coded description of this reaction.
        reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
        reaction_output += f" Reactants: <span style='color:blue'>{', '.join(shown_reactants)}</span><br>"
        reaction_output += f" Conditions: {', '.join(conditions)}<br>"
        reaction_output += f" Products: {', '.join(products)}<br>"
        reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br>"
        reaction_output += "<br>"
        detailed_output.append(reaction_output)

        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(products_2)}")

    return detailed_output, smiles_output
|
67 |
+
|
68 |
+
|
69 |
+
# Core processing function; takes only the API key and the image
|
70 |
+
def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and prepare the UI outputs.

    Args:
        api_key: API key typed by the user. When non-empty it is exported as
            the ``CHEMEAGLE_API_KEY`` environment variable before ChemEagle
            is called.
        image: The uploaded image; only its ``save(path)`` method is used.

    Returns:
        A 4-tuple: the per-reaction HTML joined by blank lines, the list of
        reaction SMILES, the path of the schematic diagram, and the path of
        the JSON file written for download.

    Raises:
        gr.Error: If no image was supplied.
    """
    import tempfile  # local import so this fix needs no file-level change

    if image is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise gr.Error("Please upload a reaction image before running.")

    # Only override the variable when a key was actually typed: the bundled
    # examples submit an empty key, and assigning "" (or None, which raises)
    # would clobber a key already configured on the server.
    # NOTE(review): the environment is process-wide, so concurrent users share
    # whichever key was set last -- confirm this is acceptable.
    if api_key:
        os.environ["CHEMEAGLE_API_KEY"] = api_key

    # A private directory per request keeps simultaneous runs from overwriting
    # each other's image and JSON while preserving the original file names.
    work_dir = tempfile.mkdtemp(prefix="chemeagle_")

    # Save the uploaded image so ChemEagle can read it from disk.
    image_path = os.path.join(work_dir, "temp_image.png")
    image.save(image_path)

    # ChemEagle is called with the image path only.
    chemeagle_result = ChemEagle(image_path)

    # Turn the structured result into display HTML and reaction SMILES.
    detailed, smiles = parse_reactions(chemeagle_result)

    # Write the raw result out for the download widget.
    json_path = os.path.join(work_dir, "output.json")
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)

    # HTML, SMILES list, schematic diagram, JSON download path.
    return "\n\n".join(detailed), smiles, example_diagram, json_path
|
91 |
+
|
92 |
+
# Build the Gradio interface
|
93 |
+
with gr.Blocks() as demo:
|
94 |
+
gr.Markdown(
|
95 |
+
"""
|
96 |
+
<center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>
|
97 |
+
Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
|
98 |
+
"""
|
99 |
+
)
|
100 |
+
|
101 |
+
with gr.Row():
|
102 |
+
# ———— 左侧:上传 + API Key + 按钮 ————
|
103 |
+
with gr.Column(scale=1):
|
104 |
+
image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
|
105 |
+
api_key_input = gr.Textbox(
|
106 |
+
label="Your API-Key",
|
107 |
+
placeholder="Type your OpenAI_API_KEY",
|
108 |
+
type="password"
|
109 |
+
)
|
110 |
+
with gr.Row():
|
111 |
+
clear_btn = gr.Button("Clear")
|
112 |
+
run_btn = gr.Button("Run", elem_id="submit-btn")
|
113 |
+
|
114 |
+
# ———— 中间:解析结果 + 示意图 ————
|
115 |
+
with gr.Column(scale=1):
|
116 |
+
gr.Markdown("### Parsed Reactions")
|
117 |
+
reaction_output = gr.HTML(label="Detailed Reaction Output")
|
118 |
+
gr.Markdown("### Schematic Diagram")
|
119 |
+
schematic_diagram = gr.Image(value=example_diagram, label="示意图")
|
120 |
+
|
121 |
+
# ———— 右侧:SMILES 拆分 & RDKit 渲染 + JSON 下载 ————
|
122 |
+
with gr.Column(scale=1):
|
123 |
+
gr.Markdown("### Machine-readable Output")
|
124 |
+
smiles_output = gr.Textbox(
|
125 |
+
label="Reaction SMILES",
|
126 |
+
show_copy_button=True,
|
127 |
+
interactive=False,
|
128 |
+
visible=False
|
129 |
+
)
|
130 |
+
|
131 |
+
@gr.render(inputs=smiles_output)  # bind the SMILES textbox as the render input
def show_split(inputs):
    """Render a SMILES textbox and an RDKit depiction for each reaction.

    Args:
        inputs: Current value of ``smiles_output``. It is expected to be the
            string form of a list of reaction SMILES (for example
            ``"['A>>B', 'C>>D']"``) or a comma-separated string.

    Returns:
        For empty input, a placeholder ``(Textbox, Image)`` pair; otherwise
        the list of created components, alternating Textbox and Image.
    """
    import ast  # local import: only used to decode the stringified list

    if not inputs or (isinstance(inputs, str) and inputs.strip() == ""):
        # Nothing extracted yet: show placeholder widgets.
        return (
            gr.Textbox(label="SMILES of Reaction i"),
            gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100),
        )

    # Prefer decoding the text as a Python list literal: that also undoes the
    # backslash doubling that repr() applies to SMILES such as C/C=C\C.
    try:
        decoded = ast.literal_eval(inputs)
    except (ValueError, SyntaxError):
        decoded = None
    if isinstance(decoded, (list, tuple)):
        smiles_list = [str(item) for item in decoded]
    else:
        # Fallback: comma-separated text, trimming list brackets and quotes
        # from the ends of each item.
        smiles_list = [
            re.sub(r"^\s*\[?'?|'\]?\s*$", "", item) for item in inputs.split(",")
        ]

    def draw_reaction(smiles, number):
        """Depict one reaction SMILES as a PNG; return its path, or None."""
        try:
            rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
        except ValueError:
            # NOTE(review): RDKit is expected to report unparsable reaction
            # strings as ValueError -- confirm against the installed version.
            return None
        if not rxn:
            return None

        # Rebuild every template through a mol block round trip, as the
        # original code did, and collect them in a fresh reaction object.
        clean_rxn = AllChem.ChemicalReaction()
        for mol in rxn.GetReactants():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:
                return None
            clean_rxn.AddReactantTemplate(rebuilt)
        for mol in rxn.GetProducts():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:
                return None
            clean_rxn.AddProductTemplate(rebuilt)

        # Clear atom-map numbers on both sides of the reaction.
        for template in list(clean_rxn.GetReactants()) + list(clean_rxn.GetProducts()):
            for atom in template.GetAtoms():
                atom.SetAtomMapNum(0)

        # Scale the drawing from the first of the first two reactants that
        # has any bonds; default to 1.0 when neither does (or none exist).
        bond_length_reference = 1.0
        for position in range(min(2, clean_rxn.GetNumReactantTemplates())):
            template = clean_rxn.GetReactantTemplate(position)
            if template.GetNumBonds() > 0:
                bond_length_reference = Draw.MeanBondLength(template)
                break

        drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
        dopts = drawer.drawOptions()
        dopts.padding = 0.1
        dopts.includeRadicals = True
        Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
        dopts.bondLineWidth = 1.5
        drawer.DrawReaction(clean_rxn)
        drawer.FinishDrawing()
        svg_content = drawer.GetDrawingText()

        # Write the SVG, then convert it to PNG for the Image component.
        svg_file = f"reaction{number}.svg"
        with open(svg_file, "w", encoding="utf-8") as f:
            f.write(svg_content)
        png_file = f"reaction_{number}.png"
        cairosvg.svg2png(url=svg_file, write_to=png_file)
        return png_file

    components = []  # one Textbox and one Image per reaction
    for i, smiles in enumerate(smiles_list):
        # None when the reaction cannot be depicted; the image is then left
        # empty instead of reusing a previous reaction's picture.
        png_file = draw_reaction(smiles, i + 1)
        components.append(gr.Textbox(value=smiles, label=f"SMILES of Reaction {i}", show_copy_button=True, interactive=False))
        components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}"))
    return components
|
199 |
+
|
200 |
+
download_json = gr.File(label="Download JSON File")
|
201 |
+
|
202 |
+
|
203 |
+
gr.Examples(
|
204 |
+
examples=[
|
205 |
+
["examples/reaction1.jpg", ""],
|
206 |
+
["examples/reaction2.png", ""],
|
207 |
+
["examples/reaction3.png", ""],
|
208 |
+
["examples/reaction4.png", ""],
|
209 |
+
|
210 |
+
|
211 |
+
],
|
212 |
+
inputs=[image_input, api_key_input],
|
213 |
+
outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
|
214 |
+
cache_examples=False,
|
215 |
+
examples_per_page=4,
|
216 |
+
)
|
217 |
+
|
218 |
+
# ———— 清空与运行 绑定 ————
|
219 |
+
clear_btn.click(
|
220 |
+
lambda: (None, None, None, None, None),
|
221 |
+
inputs=[],
|
222 |
+
outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]
|
223 |
+
)
|
224 |
+
run_btn.click(
|
225 |
+
process_chem_image,
|
226 |
+
inputs=[api_key_input, image_input],
|
227 |
+
outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
|
228 |
+
)
|
229 |
+
|
230 |
+
# 自定义按钮样式
|
231 |
+
demo.css = """
|
232 |
+
#submit-btn {
|
233 |
+
background-color: #FF914D;
|
234 |
+
color: white;
|
235 |
+
font-weight: bold;
|
236 |
+
}
|
237 |
+
"""
|
238 |
+
|
239 |
+
demo.launch()
|
chemiener/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .interface import ChemNER
|
chemiener/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (187 Bytes). View file
|
|
chemiener/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (185 Bytes). View file
|
|
chemiener/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (5.37 kB). View file
|
|
chemiener/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (5.35 kB). View file
|
|
chemiener/__pycache__/interface.cpython-310.pyc
ADDED
Binary file (4.46 kB). View file
|
|
chemiener/__pycache__/interface.cpython-38.pyc
ADDED
Binary file (4.47 kB). View file
|
|
chemiener/__pycache__/model.cpython-310.pyc
ADDED
Binary file (684 Bytes). View file
|
|
chemiener/__pycache__/model.cpython-38.pyc
ADDED
Binary file (680 Bytes). View file
|
|
chemiener/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.67 kB). View file
|
|
chemiener/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.53 kB). View file
|
|
chemiener/dataset.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import copy
|
4 |
+
import random
|
5 |
+
import json
|
6 |
+
import contextlib
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data import DataLoader, Dataset
|
12 |
+
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
13 |
+
|
14 |
+
from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast
|
15 |
+
|
16 |
+
from .utils import get_class_to_index
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class NERDataset(Dataset):
    """Dataset pairing tokenized text with per-token entity class labels.

    Records are read from a JSON file and looked up by the string form of
    their integer index (``self.data[str(idx)]``); each record supplies a
    ``'text'`` string and an ``'entities'`` mapping.
    """

    def __init__(self, args, data_file, split='train'):
        """Load the data file (if any), the tokenizer and the label maps.

        Args:
            args: Configuration object; ``data_path``, ``roberta_checkpoint``,
                ``cache_dir``, ``corpus`` and ``max_seq_length`` are read.
            data_file: File name relative to ``args.data_path``. When falsy,
                no file is read and the dataset is empty.
            split: Name of the split; ``is_train`` is true for ``'train'``.
        """
        super().__init__()
        self.args = args
        # Defaults so that len() works even when no data file is given.
        self.data = {}
        self.name = None
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path, encoding="utf-8") as f:
                self.data = json.load(f)
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        # An earlier revision used
        # BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased').
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.roberta_checkpoint, cache_dir=self.args.cache_dir
        )
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {
            index: name for name, index in self.class_to_index.items()
        }

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized text and its labels for record *idx*.

        Returns:
            A 3-tuple: the truncated tokenization, the labels aligned to the
            truncated tokenization, and the labels aligned to the untruncated
            tokenization of the same text.
        """
        record = self.data[str(idx)]
        text = record['text']
        entities = record['entities']

        text_tokenized = self.tokenizer(
            text, truncation=True, max_length=self.args.max_seq_length
        )
        # Diagnostic kept from the original: report sequences over 512 tokens.
        if len(text_tokenized['input_ids']) > 512:
            print(len(text_tokenized['input_ids']))
        text_tokenized_untruncated = self.tokenizer(text)

        return (
            text_tokenized,
            self.align_labels(text_tokenized, entities, len(text)),
            self.align_labels(text_tokenized_untruncated, entities, len(text)),
        )

    def align_labels(self, text_tokenized, entities, length):
        """Map character-level entity spans onto tokens.

        Args:
            text_tokenized: Tokenizer output; ``text_tokenized[0]`` gives the
                token count and ``token_to_chars(i)`` the character span of
                token ``i`` (``None`` for tokens with no span).
            entities: Mapping whose values have a ``"span"`` list of
                ``[start, end)`` character ranges and a ``"type"``.
            length: Length of the text in characters. Kept for interface
                compatibility; characters outside every entity default to
                class 0 without needing it.

        Returns:
            ``torch.LongTensor`` with one class index per token: ``-100`` for
            tokens without a character span, otherwise the class of the
            token's first character.
        """
        char_to_class = {}

        for entity in entities:
            entity_type = str(entities[entity]["type"])
            for span in entities[entity]["span"]:
                start, end = span[0], span[1]
                for i in range(start, end):
                    # The first character of a span is tagged B-, the rest I-.
                    prefix = 'B-' if i == start else 'I-'
                    char_to_class[i] = self.class_to_index[prefix + entity_type]

        classes = []
        for token_index in range(len(text_tokenized[0])):
            span = text_tokenized.token_to_chars(token_index)
            if span is None:
                classes.append(-100)
            else:
                # Characters outside every entity fall back to class 0.
                classes.append(char_to_class.get(span.start, 0))

        return torch.LongTensor(classes)
|
67 |
+
|
68 |
+
def make_html(word_tokens, predictions):
|
69 |
+
|
70 |
+
toreturn = '''<!DOCTYPE html>
|
71 |
+
<html>
|
72 |
+
<head>
|
73 |
+
<title>Named Entity Recognition Visualization</title>
|
74 |
+
<style>
|
75 |
+
.EXAMPLE_LABEL {
|
76 |
+
color: red;
|
77 |
+
text-decoration: underline red;
|
78 |
+
}
|
79 |
+
.REACTION_PRODUCT {
|
80 |
+
color: orange;
|
81 |
+
text-decoration: underline orange;
|
82 |
+
}
|
83 |
+
.STARTING_MATERIAL {
|
84 |
+
color: gold;
|
85 |
+
text-decoration: underline gold;
|
86 |
+
}
|
87 |
+
.REAGENT_CATALYST {
|
88 |
+
color: green;
|
89 |
+
text-decoration: underline green;
|
90 |
+
}
|
91 |
+
.SOLVENT {
|
92 |
+
color: cyan;
|
93 |
+
text-decoration: underline cyan;
|
94 |
+
}
|
95 |
+
.OTHER_COMPOUND {
|
96 |
+
color: blue;
|
97 |
+
text-decoration: underline blue;
|
98 |
+
}
|
99 |
+
.TIME {
|
100 |
+
color: purple;
|
101 |
+
text-decoration: underline purple;
|
102 |
+
}
|
103 |
+
.TEMPERATURE {
|
104 |
+
color: magenta;
|
105 |
+
text-decoration: underline magenta;
|
106 |
+
}
|
107 |
+
.YIELD_OTHER {
|
108 |
+
color: palegreen;
|
109 |
+
text-decoration: underline palegreen;
|
110 |
+
}
|
111 |
+
.YIELD_PERCENT {
|
112 |
+
color: pink;
|
113 |
+
text-decoration: underline pink;
|
114 |
+
}
|
115 |
+
</style>
|
116 |
+
</head>
|
117 |
+
<body>
|
118 |
+
<p>'''
|
119 |
+
last_label = None
|
120 |
+
for idx, item in enumerate(word_tokens):
|
121 |
+
decoded = self.tokenizer.decode(item, skip_special_tokens = True)
|
122 |
+
if len(decoded)>0:
|
123 |
+
if idx!=0 and decoded[0]!='#':
|
124 |
+
toreturn+=" "
|
125 |
+
label = predictions[idx]
|
126 |
+
if label == last_label:
|
127 |
+
|
128 |
+
toreturn+=decoded if decoded[0]!="#" else decoded[2:]
|
129 |
+
else:
|
130 |
+
if last_label is not None and last_label>0:
|
131 |
+
toreturn+="</u>"
|
132 |
+
if label >0:
|
133 |
+
toreturn+="<u class=\""
|
134 |
+
toreturn+=self.index_to_class[label]
|
135 |
+
toreturn+="\">"
|
136 |
+
toreturn+=decoded if decoded[0]!="#" else decoded[2:]
|
137 |
+
if label == 0:
|
138 |
+
toreturn+=decoded if decoded[0]!="#" else decoded[2:]
|
139 |
+
if idx==len(word_tokens) and label>0:
|
140 |
+
toreturn+="</u>"
|
141 |
+
last_label = label
|
142 |
+
|
143 |
+
toreturn += ''' </p>
|
144 |
+
</body>
|
145 |
+
</html>'''
|
146 |
+
return toreturn
|
147 |
+
|
148 |
+
|
149 |
+
def get_collate_fn():
    """Return a collate function that pads a batch of examples.

    Each example is indexed as ``example[0]['input_ids']``,
    ``example[0]['attention_mask']`` and ``example[1]`` (a label tensor).
    The returned callable yields ``(sentences, masks, refs)``, all padded
    batch-first: ids and masks with 0, labels with -100.
    """
    def collate(batch):
        examples = list(batch)
        token_ids = [torch.LongTensor(example[0]['input_ids']) for example in examples]
        attention = [torch.Tensor(example[0]['attention_mask']) for example in examples]
        labels = [example[1] for example in examples]

        padded_ids = pad_sequence(token_ids, batch_first=True, padding_value=0)
        padded_attention = pad_sequence(attention, batch_first=True, padding_value=0)
        padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        return padded_ids, padded_attention, padded_labels

    return collate
|
170 |
+
|
171 |
+
|
172 |
+
|
chemiener/interface.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from typing import List
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .model import build_model
|
8 |
+
|
9 |
+
from .dataset import NERDataset, get_collate_fn
|
10 |
+
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
|
13 |
+
from .utils import get_class_to_index
|
14 |
+
|
15 |
+
class ChemNER:
    """Inference wrapper around a trained token-classification checkpoint.

    Args:
        model_path: path of a checkpoint readable by ``torch.load`` that holds
            a ``'state_dict'`` entry.
        device: torch device for inference; defaults to CPU.
        cache_dir: stored on the argument namespace as ``cache_dir``.
    """

    def __init__(self, model_path, device = None, cache_dir = None):
        self.args = self._get_args(cache_dir)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        states = torch.load(model_path, map_location = torch.device('cpu'))

        if device is None:
            device = torch.device('cpu')
        self.device = device

        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        self.dataset = NERDataset(self.args, data_file = None)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: name for name, index in self.class_to_index.items()}

    def _get_args(self, cache_dir):
        """Build the fixed argument namespace used for inference."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--roberta_checkpoint', default = 'dmis-lab/biobert-large-cased-v1.1', type=str, help='which roberta config to use')
        parser.add_argument('--corpus', default = "chemdner", type=str, help="which corpus should the tags be from")
        # Parse an empty argv so the defaults are used regardless of sys.argv.
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args

    @staticmethod
    def _strip_model_prefix(state_dict, prefix = 'model.'):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

        Only a prefix at the start of a key is stripped; occurrences elsewhere
        in the key are left untouched.
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @staticmethod
    def _merge_entity_spans(char_span, prediction):
        """Merge per-token tags into ``(type, [start, end])`` entities.

        A ``B-`` tag opens a new entity. Any other tag whose type (the tag
        without its two-character prefix) equals the type of the most recent
        entity extends that entity's end.
        NOTE(review): this also extends across intervening ``O`` tokens
        (``B-X, O, I-X``); kept as in the original implementation.
        """
        entities = []
        for span, tag in zip(char_span, prediction):
            if tag[0] == 'B':
                entities.append((tag[2:], [span.start, span.end]))
            elif len(entities) > 0 and tag[2:] == entities[-1][0]:
                entities[-1] = (entities[-1][0], [entities[-1][1][0], span.end])
        return entities

    def get_model(self, args, device, model_states):
        """Build the model, load ``model_states`` into it and switch to eval mode."""
        model = build_model(args)
        model.load_state_dict(self._strip_model_prefix(model_states), strict = False)
        model.to(device)
        model.eval()
        return model

    def predict_strings(self, strings: List, batch_size = 8):
        """Tag each input string and return its entities.

        Args:
            strings: texts to tag.
            batch_size: number of texts per forward pass.

        Returns:
            One list per input string, each holding ``(type, [start, end])``
            tuples with character offsets into that string.
        """
        device = self.device
        tokenizer = self.dataset.tokenizer

        output = []
        for idx in range(0, len(strings), batch_size):
            batch_strings = strings[idx:idx+batch_size]
            batch_strings_tokenized = [(tokenizer(s, truncation = True, max_length = 512), torch.Tensor([-1]), torch.Tensor([-1])) for s in batch_strings]

            sentences, masks, _ = self.collate(batch_strings_tokenized)

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = self.model(input_ids = sentences.to(device), attention_mask = masks.to(device))[0]
            predictions = logits.argmax(dim = 2).to('cpu')

            for (encoding, _, _), sentence, sentence_preds in zip(batch_strings_tokenized, sentences, predictions):
                # Keep only tokens that decode to visible text; this drops
                # special tokens and padding. Computed once per sentence.
                keep = [
                    len(tokenizer.decode(int(word.item()), skip_special_tokens = True)) > 0
                    for word in sentence
                ]
                char_span = [encoding.token_to_chars(i) for i, kept in enumerate(keep) if kept]
                tags = [
                    self.index_to_class[int(pred.item())]
                    for pred, kept in zip(sentence_preds, keep) if kept
                ]
                output.append(self._merge_entity_spans(char_span, tags))

        return output
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
chemiener/main.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import time
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
12 |
+
import torch.distributed as dist
|
13 |
+
import pytorch_lightning as pl
|
14 |
+
from pytorch_lightning import LightningModule, LightningDataModule
|
15 |
+
from pytorch_lightning.callbacks import LearningRateMonitor
|
16 |
+
from pytorch_lightning.strategies.ddp import DDPStrategy
|
17 |
+
from transformers import get_scheduler
|
18 |
+
import transformers
|
19 |
+
|
20 |
+
from dataset import NERDataset, get_collate_fn
|
21 |
+
|
22 |
+
from model import build_model
|
23 |
+
|
24 |
+
from utils import get_class_to_index
|
25 |
+
|
26 |
+
import evaluate
|
27 |
+
|
28 |
+
from seqeval.metrics import accuracy_score
|
29 |
+
from seqeval.metrics import classification_report
|
30 |
+
from seqeval.metrics import f1_score
|
31 |
+
from seqeval.scheme import IOB2
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def get_args(notebook=False):
    """Parse the command-line options of the training script.

    Args:
        notebook: when true, parse an empty argument list so only defaults
            are returned instead of reading ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    flag = dict(action='store_true')
    # Registration order is kept so the generated help text is unchanged.
    option_specs = [
        ('--do_train', flag),
        ('--do_valid', flag),
        ('--do_test', flag),
        ('--fp16', flag),
        ('--seed', dict(type=int, default=42)),
        ('--gpus', dict(type=int, default=1)),
        ('--print_freq', dict(type=int, default=200)),
        ('--debug', flag),
        ('--no_eval', flag),
        # Data
        ('--data_path', dict(type=str, default=None)),
        ('--image_path', dict(type=str, default=None)),
        ('--train_file', dict(type=str, default=None)),
        ('--valid_file', dict(type=str, default=None)),
        ('--test_file', dict(type=str, default=None)),
        ('--vocab_file', dict(type=str, default=None)),
        ('--format', dict(type=str, default='reaction')),
        ('--num_workers', dict(type=int, default=8)),
        ('--input_size', dict(type=int, default=224)),
        # Training
        ('--epochs', dict(type=int, default=8)),
        ('--batch_size', dict(type=int, default=256)),
        ('--lr', dict(type=float, default=1e-4)),
        ('--weight_decay', dict(type=float, default=0.05)),
        ('--max_grad_norm', dict(type=float, default=5.)),
        ('--scheduler', dict(type=str, choices=['cosine', 'constant'], default='cosine')),
        ('--warmup_ratio', dict(type=float, default=0)),
        ('--gradient_accumulation_steps', dict(type=int, default=1)),
        ('--load_path', dict(type=str, default=None)),
        ('--load_encoder_only', flag),
        ('--train_steps_per_epoch', dict(type=int, default=-1)),
        ('--eval_per_epoch', dict(type=int, default=10)),
        ('--save_path', dict(type=str, default='output/')),
        ('--save_mode', dict(type=str, default='best', choices=['best', 'all', 'last'])),
        ('--load_ckpt', dict(type=str, default='best')),
        ('--resume', flag),
        ('--num_train_example', dict(type=int, default=None)),
        # Model / corpus
        ('--roberta_checkpoint', dict(type=str, default="roberta-base")),
        ('--corpus', dict(type=str, default="chemu")),
        ('--cache_dir', dict()),
        ('--eval_truncated', flag),
        ('--max_seq_length', dict(type=int, default=512)),
    ]

    parser = argparse.ArgumentParser()
    for option, spec in option_specs:
        parser.add_argument(option, **spec)

    if notebook:
        return parser.parse_args([])
    return parser.parse_args()
|
95 |
+
|
96 |
+
|
97 |
+
class ChemIENERecognizer(LightningModule):
    """Lightning module that trains and evaluates the token classifier.

    ``eval_dataset`` must be assigned by the caller before validation runs;
    its ``name`` is used for the predictions file. ``trainer.num_training_steps``
    must likewise be set before ``configure_optimizers`` is called.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = build_model(args)
        # Per-batch outputs collected during validation, consumed at epoch end.
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        # Only the first three items are needed; accept batches with or
        # without a trailing "untruncated labels" entry.
        sentences, masks, refs = batch[:3]

        loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs)
        self.log('train/loss', loss)
        self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        sentences, masks, refs = batch[:3]
        # When the collate function supplies no untruncated labels, fall back
        # to the (truncated) reference labels.
        untruncated = batch[3] if len(batch) > 3 else refs

        logits = self.model(input_ids=sentences, attention_mask=masks)[0]
        self.validation_step_outputs.append((
            sentences.to("cpu"),
            logits.argmax(dim=2).to("cpu"),
            refs.to('cpu'),
            untruncated.to("cpu"),
        ))

    def on_validation_epoch_end(self):
        """Score the collected predictions, dump them to JSON and log the F1."""
        if self.trainer.num_devices > 1:
            gathered_outputs = [None for _ in range(self.trainer.num_devices)]
            dist.all_gather_object(gathered_outputs, self.validation_step_outputs)
            gathered_outputs = sum(gathered_outputs, [])
        else:
            gathered_outputs = self.validation_step_outputs

        class_to_index = get_class_to_index(self.args.corpus)
        index_to_class = {index: name for name, index in class_to_index.items()}

        kept_token_ids = []
        kept_predictions = []
        kept_groundtruth = []
        untruncateds = []
        for batch_sentences, batch_preds, batch_labels, batch_full in gathered_outputs:
            for sentence, preds, labels in zip(batch_sentences, batch_preds, batch_labels):
                # Positions labelled -100 are excluded from the output.
                keep = [int(label.item()) != -100 for label in labels]
                kept_token_ids.append(
                    [int(word.item()) for word, kept in zip(sentence, keep) if kept])
                kept_predictions.append(
                    [index_to_class[int(pred.item())] for pred, kept in zip(preds, keep) if kept])
                kept_groundtruth.append(
                    [index_to_class[int(label.item())] for label, kept in zip(labels, keep) if kept])
            for full_labels in batch_full:
                untruncateds.append(
                    [index_to_class[int(label.item())]
                     for label in full_labels if int(label.item()) != -100])

        output = {
            "sentences": kept_token_ids,
            "predictions": kept_predictions,
            "groundtruth": kept_groundtruth,
        }

        name = self.eval_dataset.name
        scores = [0]

        if self.trainer.is_global_zero:
            if not self.args.no_eval:
                if self.args.eval_truncated:
                    report = classification_report(
                        output['groundtruth'], output['predictions'],
                        mode='strict', scheme=IOB2, output_dict=True)
                else:
                    # Tokens lost to truncation are scored as 'O' predictions.
                    padded_predictions = [
                        preds + ['O'] * (len(full_groundtruth) - len(preds))
                        for preds, full_groundtruth in zip(output['predictions'], untruncateds)
                    ]
                    # seqeval expects (y_true, y_pred) in that order.
                    report = classification_report(
                        untruncateds, padded_predictions,
                        mode='strict', scheme=IOB2, output_dict=True)
                self.print(report)
                scores = [report['micro avg']['f1-score']]
            with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
                json.dump(output, f)

        # Share rank zero's score with the other ranks; skip when no process
        # group exists (e.g. a non-distributed run).
        if dist.is_available() and dist.is_initialized():
            dist.broadcast_object_list(scores)

        self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        num_training_steps = self.trainer.num_training_steps

        self.print(f'Num training steps: {num_training_steps}')
        num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
|
237 |
+
|
238 |
+
class NERDataModule(LightningDataModule):
    """Data module that builds the NER datasets and their data loaders.

    Which datasets exist depends on the ``do_train`` / ``do_valid`` /
    ``do_test`` flags of ``args``; they are created by ``prepare_data``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.collate_fn = get_collate_fn()

    def prepare_data(self):
        cfg = self.args
        if cfg.do_train:
            self.train_dataset = NERDataset(cfg, cfg.train_file, split='train')
        if self.args.do_train or self.args.do_valid:
            self.val_dataset = NERDataset(cfg, cfg.valid_file, split='valid')
        if self.args.do_test:
            self.test_dataset = NERDataset(cfg, cfg.test_file, split='valid')

    def print_stats(self):
        """Print the size of every dataset enabled by the flags."""
        cfg = self.args
        if cfg.do_train:
            train_size = len(self.train_dataset)
            print(f'Train dataset: {train_size}')
        if cfg.do_train or cfg.do_valid:
            valid_size = len(self.val_dataset)
            print(f'Valid dataset: {valid_size}')
        if cfg.do_test:
            test_size = len(self.test_dataset)
            print(f'Test dataset: {test_size}')

    def _build_loader(self, dataset):
        # All three loaders share the same batch size, workers and collate_fn.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback whose file path is the plain formatted name.

    The override ignores ``trainer`` and ``del_filepath`` and returns
    ``format_checkpoint_name(monitor_candidates)`` unchanged.
    NOTE(review): presumably this keeps a fixed file name such as
    ``best.ckpt`` across saves - confirm against the Lightning version in use.
    """

    def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
        return self.format_checkpoint_name(monitor_candidates)
|
285 |
+
|
286 |
+
def main():
    """Entry point: parse arguments, then train / validate / test as requested."""
    transformers.utils.logging.set_verbosity_error()
    args = get_args()

    pl.seed_everything(args.seed, workers = True)

    # Train from scratch, or evaluate the best checkpoint under save_path.
    if args.do_train:
        model = ChemIENERecognizer(args)
    else:
        best_ckpt = os.path.join(args.save_path, 'checkpoints/best.ckpt')
        model = ChemIENERecognizer.load_from_checkpoint(best_ckpt, strict=False, args=args)

    dm = NERDataModule(args)
    dm.prepare_data()
    dm.print_stats()

    checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')

    trainer_options = dict(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator='gpu',
        precision = 16,
        devices=args.gpus,
        logger=logger,
        default_root_dir=args.save_path,
        callbacks=[checkpoint, lr_monitor],
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        check_val_every_n_epoch=args.eval_per_epoch,
        log_every_n_steps=10,
        deterministic='warn',
    )
    trainer = pl.Trainer(**trainer_options)

    if args.do_train:
        # Optimizer steps per epoch across all devices, times the epoch count;
        # read back by the module in configure_optimizers.
        examples_per_step = args.batch_size * args.gpus * args.gradient_accumulation_steps
        steps_per_epoch = math.ceil(len(dm.train_dataset) / examples_per_step)
        trainer.num_training_steps = steps_per_epoch * args.epochs

        model.eval_dataset = dm.val_dataset
        if args.resume:
            ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt')
        else:
            ckpt_path = None
        trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
        model = ChemIENERecognizer.load_from_checkpoint(checkpoint.best_model_path, args=args)

    if args.do_valid:
        model.eval_dataset = dm.val_dataset
        trainer.validate(model, datamodule=dm)

    if args.do_test:
        model.test_dataset = dm.test_dataset
        trainer.test(model, datamodule=dm)
|
341 |
+
|
342 |
+
|
343 |
+
if __name__ == "__main__":
|
344 |
+
main()
|
345 |
+
|
chemiener/model.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification
|
6 |
+
|
7 |
+
|
8 |
+
def build_model(args):
    """Create the token-classification model for ``args.corpus``.

    The checkpoint is ``args.roberta_checkpoint``; the label count is 21 for
    ``"chemu"``, 17 for ``"chemdner"`` and 3 for ``"chemdner-mol"``.
    Any other corpus yields ``None``.
    """
    labels_per_corpus = (
        ("chemu", 21),
        ("chemdner", 17),
        ("chemdner-mol", 3),
    )
    for corpus_name, label_count in labels_per_corpus:
        if args.corpus == corpus_name:
            return AutoModelForTokenClassification.from_pretrained(
                args.roberta_checkpoint,
                num_labels = label_count,
                cache_dir = args.cache_dir,
                return_dict = False,
            )
    return None
|
chemiener/utils.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def merge_predictions(results):
    """Merge per-batch prediction sequences into one list, by position.

    Each batch is enumerated from 0, so an entry at position ``i`` in a later
    batch replaces the entry at position ``i`` from earlier batches.
    NOTE(review): batches therefore overwrite rather than concatenate -
    confirm this is the intended merge.

    Returns:
        ``[]`` for empty ``results``, otherwise the merged list.
    """
    if len(results) == 0:
        return []

    latest_by_position = {}
    for batch_preds in results:
        latest_by_position.update(enumerate(batch_preds))

    return [latest_by_position[position] for position in range(len(latest_by_position))]
|
12 |
+
|
13 |
+
def get_class_to_index(corpus):
    """Return the tag-name -> class-index mapping for ``corpus``.

    Supported corpora are ``"chemu"``, ``"chemdner"`` and ``"chemdner-mol"``;
    ``'O'`` is always index 0. Any other value yields ``None``.
    """
    def tagged(prefix, entity_types, first_index):
        # Consecutive indices for "<prefix>-<type>" starting at first_index.
        return {
            f'{prefix}-{entity_type}': first_index + offset
            for offset, entity_type in enumerate(entity_types)
        }

    if corpus == "chemu":
        chemu_types = (
            'EXAMPLE_LABEL', 'REACTION_PRODUCT', 'STARTING_MATERIAL',
            'REAGENT_CATALYST', 'SOLVENT', 'OTHER_COMPOUND', 'TIME',
            'TEMPERATURE', 'YIELD_OTHER', 'YIELD_PERCENT',
        )
        return {**tagged('B', chemu_types, 1), 'O': 0, **tagged('I', chemu_types, 11)}

    if corpus == "chemdner":
        chemdner_types = (
            'ABBREVIATION', 'FAMILY', 'FORMULA', 'IDENTIFIER',
            'MULTIPLE', 'SYSTEMATIC', 'TRIVIAL', 'NO CLASS',
        )
        return {'O': 0, **tagged('B', chemdner_types, 1), **tagged('I', chemdner_types, 9)}

    if corpus == "chemdner-mol":
        return {'O': 0, 'B-MOL': 1, 'I-MOL': 2}

    return None
|
21 |
+
|
22 |
+
|
23 |
+
|
chemietoolkit/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .interface import ChemIEToolkit
|
chemietoolkit/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes). View file
|
|
chemietoolkit/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (189 Bytes). View file
|
|
chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc
ADDED
Binary file (3.66 kB). View file
|
|
chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc
ADDED
Binary file (3.62 kB). View file
|
|
chemietoolkit/__pycache__/interface.cpython-310.pyc
ADDED
Binary file (29.3 kB). View file
|
|
chemietoolkit/__pycache__/interface.cpython-38.pyc
ADDED
Binary file (30 kB). View file
|
|
chemietoolkit/__pycache__/tableextractor.cpython-310.pyc
ADDED
Binary file (10.3 kB). View file
|
|
chemietoolkit/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (25 kB). View file
|
|
chemietoolkit/chemrxnextractor.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PyPDF2 import PdfReader, PdfWriter
|
2 |
+
import pdfminer.high_level
|
3 |
+
import pdfminer.layout
|
4 |
+
from operator import itemgetter
|
5 |
+
import os
|
6 |
+
import pdftotext
|
7 |
+
from chemrxnextractor import RxnExtractor
|
8 |
+
|
9 |
+
class ChemRxnExtractor(object):
    """Extract reactions from the text of a PDF, page by page.

    Args:
        pdf: path of the PDF file; a falsy value defers loading to
            ``set_pdf_file``.
        pn: number of leading pages to process, or ``None`` for all pages.
        model_dir: directory containing the ``cre_models_v0.1`` folder.
        device: ``'cuda'`` enables CUDA for the extractor; any other value
            disables it.
    """

    def __init__(self, pdf, pn, model_dir, device):
        self.pdf_file = pdf
        self.pages = pn
        self.model_dir = os.path.join(model_dir, "cre_models_v0.1")  # directory saving both prod and role models
        # Remembered so that set_model_dir rebuilds the extractor on the same device.
        self.use_cuda = (device == 'cuda')
        self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=self.use_cuda)
        self.text_file = "info.txt"
        self.pdf_text = ""
        if self.pdf_file:
            with open(self.pdf_file, "rb") as f:
                self.pdf_text = pdftotext.PDF(f)

    def set_pdf_file(self, pdf):
        """Switch to another PDF and load its text."""
        self.pdf_file = pdf
        with open(self.pdf_file, "rb") as f:
            self.pdf_text = pdftotext.PDF(f)

    def set_pages(self, pn):
        self.pages = pn

    def set_model_dir(self, md):
        """Point at another model directory and rebuild the extractor."""
        self.model_dir = md
        self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=self.use_cuda)

    def set_text_file(self, tf):
        self.text_file = tf

    def extract_reactions_from_text(self):
        """Extract reactions from the configured pages (all pages when unset)."""
        if self.pages is None:
            return self.extract_all(len(self.pdf_text))
        else:
            return self.extract_all(self.pages)

    def extract_all(self, pages):
        """Return one ``{'page', 'reactions'}`` dict per processed page."""
        ans = []
        text = self.get_paragraphs_from_pdf(pages)
        for data in text:
            sentences = [sent for paragraph in data['paragraphs'] for sent in paragraph]
            reactions = self.get_reactions(sentences, page_number=data['page'])
            ans.append(reactions)
        return ans

    def get_reactions(self, sents, page_number=None):
        """Run the extractor on ``sents`` and keep results that contain reactions."""
        rxns = self.rxn_extractor.get_reactions(sents)

        ret = [r for r in rxns if len(r['reactions']) != 0]
        return {'page': page_number, 'reactions': ret}

    def get_paragraphs_from_pdf(self, pages):
        """Split the first ``pages`` pages into paragraphs of sentences.

        Paragraphs are separated by blank lines; paragraphs containing a form
        feed are skipped. Every sentence ends with ``"\\n"`` and a trailing
        fragment without a period gets one appended.

        Args:
            pages: number of leading pages to read; ``None`` means all pages.
                Values beyond the document length are clamped.

        Returns:
            A list of ``{'paragraphs': [[sentence, ...], ...], 'page': n}``
            dicts with 1-based page numbers.
        """
        current_page_num = 1
        available = len(self.pdf_text)
        if pages is None:
            pages = available
        # Never index past the last page of the document.
        pages = min(pages, available)
        result = []
        for page in range(pages):
            content = self.pdf_text[page]
            pg = content.split("\n\n")
            L = []
            for line in pg:
                paragraph = []
                if '\x0c' in line:
                    continue
                text = line
                text = text.replace("\n", " ")
                text = text.replace("- ", "-")
                curind = 0
                i = 0
                while i < len(text):
                    if text[i] == '.':
                        # A period ends a sentence unless it follows a digit,
                        # or when it is followed by whitespace.
                        if i != 0 and not text[i-1].isdigit() or i != len(text) - 1 and (text[i+1] == " " or text[i+1] == "\n"):
                            paragraph.append(text[curind:i+1] + "\n")
                            # Skip to the next space; the following sentence
                            # starts right after it.
                            while(i < len(text) and text[i] != " "):
                                i += 1
                            curind = i + 1
                    i += 1
                if curind != i:
                    # Leftover text after the last detected sentence.
                    if text[i - 1] == " ":
                        if i != 1:
                            i -= 1
                        else:
                            break
                    if text[i - 1] != '.':
                        paragraph.append(text[curind:i] + ".\n")
                    else:
                        paragraph.append(text[curind:i] + "\n")
                L.append(paragraph)

            result.append({
                'paragraphs': L,
                'page': current_page_num
            })
            current_page_num += 1
        return result
|
chemietoolkit/interface.py
ADDED
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import re
|
3 |
+
from functools import lru_cache
|
4 |
+
import layoutparser as lp
|
5 |
+
import pdf2image
|
6 |
+
from PIL import Image
|
7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
8 |
+
from molscribe import MolScribe
|
9 |
+
from rxnscribe import RxnScribe, MolDetect
|
10 |
+
from chemiener import ChemNER
|
11 |
+
from .chemrxnextractor import ChemRxnExtractor
|
12 |
+
from .tableextractor import TableExtractor
|
13 |
+
from .utils import *
|
14 |
+
|
15 |
+
class ChemIEToolkit:
|
16 |
+
def __init__(self, device=None):
|
17 |
+
if device is None:
|
18 |
+
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
19 |
+
else:
|
20 |
+
self.device = torch.device(device)
|
21 |
+
|
22 |
+
self._molscribe = None
|
23 |
+
self._rxnscribe = None
|
24 |
+
self._pdfparser = None
|
25 |
+
self._moldet = None
|
26 |
+
self._chemrxnextractor = None
|
27 |
+
self._chemner = None
|
28 |
+
self._coref = None
|
29 |
+
|
30 |
+
@property
|
31 |
+
def molscribe(self):
|
32 |
+
if self._molscribe is None:
|
33 |
+
self.init_molscribe()
|
34 |
+
return self._molscribe
|
35 |
+
|
36 |
+
def init_molscribe(self, ckpt_path=None):
    """
    Load the MolScribe model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use, if None then will use default
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth")
    self._molscribe = MolScribe(ckpt_path, device=self.device)
|
46 |
+
|
47 |
+
|
48 |
+
@property
|
49 |
+
def rxnscribe(self):
|
50 |
+
if self._rxnscribe is None:
|
51 |
+
self.init_rxnscribe()
|
52 |
+
return self._rxnscribe
|
53 |
+
|
54 |
+
def init_rxnscribe(self, ckpt_path=None):
    """
    Load the RxnScribe model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use, if None then will use default
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("yujieq/RxnScribe", "pix2seq_reaction_full.ckpt")
    self._rxnscribe = RxnScribe(ckpt_path, device=self.device)
|
64 |
+
|
65 |
+
|
66 |
+
@property
|
67 |
+
def pdfparser(self):
|
68 |
+
if self._pdfparser is None:
|
69 |
+
self.init_pdfparser()
|
70 |
+
return self._pdfparser
|
71 |
+
|
72 |
+
def init_pdfparser(self, ckpt_path=None):
    """
    Load the layout detection model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use; it is forwarded as `model_path`, so None leaves the choice to layoutparser
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    config_path = "lp://efficientdet/PubLayNet/tf_efficientdet_d1"
    self._pdfparser = lp.AutoLayoutModel(config_path, model_path=ckpt_path, device=self.device.type)
|
81 |
+
|
82 |
+
|
83 |
+
@property
|
84 |
+
def moldet(self):
|
85 |
+
if self._moldet is None:
|
86 |
+
self.init_moldet()
|
87 |
+
return self._moldet
|
88 |
+
|
89 |
+
def init_moldet(self, ckpt_path=None):
    """
    Load the MolDetect model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use, if None then will use default
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "best_hf.ckpt")
    self._moldet = MolDetect(ckpt_path, device=self.device)
|
99 |
+
|
100 |
+
|
101 |
+
@property
|
102 |
+
def coref(self):
|
103 |
+
if self._coref is None:
|
104 |
+
self.init_coref()
|
105 |
+
return self._coref
|
106 |
+
|
107 |
+
def init_coref(self, ckpt_path=None):
    """
    Load the MolDetect coreference model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use, if None then will use default
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "coref_best_hf.ckpt")
    self._coref = MolDetect(ckpt_path, device=self.device, coref=True)
|
117 |
+
|
118 |
+
|
119 |
+
@property
|
120 |
+
def chemrxnextractor(self):
|
121 |
+
if self._chemrxnextractor is None:
|
122 |
+
self.init_chemrxnextractor()
|
123 |
+
return self._chemrxnextractor
|
124 |
+
|
125 |
+
def init_chemrxnextractor(self, ckpt_path=None):
    """
    Create the ChemRxnExtractor wrapper, replacing any instance created before
    Parameters:
        ckpt_path: directory holding the models, if None then the default snapshot is downloaded
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = snapshot_download(repo_id="amberwang/chemrxnextractor-training-modules")
    # No pdf and no page limit yet; both are set per call by the extract_* methods.
    self._chemrxnextractor = ChemRxnExtractor("", None, ckpt_path, self.device.type)
|
135 |
+
|
136 |
+
|
137 |
+
@property
|
138 |
+
def chemner(self):
|
139 |
+
if self._chemner is None:
|
140 |
+
self.init_chemner()
|
141 |
+
return self._chemner
|
142 |
+
|
143 |
+
def init_chemner(self, ckpt_path=None):
    """
    Load the ChemNER model, replacing any model loaded before
    Parameters:
        ckpt_path: path to checkpoint to use, if None then will use default
    """
    # Not memoised: a cache keyed on (self, ckpt_path) turns a repeated call
    # into a no-op, so reloading after the slot was reset, or switching back
    # to an earlier checkpoint, would silently do nothing. It would also keep
    # every toolkit instance alive for the lifetime of the cache.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/ChemNERckpt", "best.ckpt")
    self._chemner = ChemNER(ckpt_path, device=self.device)
|
153 |
+
|
154 |
+
|
155 |
+
@property
def tableextractor(self):
    """A new TableExtractor; nothing is stored, so every access builds a fresh instance"""
    extractor = TableExtractor()
    return extractor
|
158 |
+
|
159 |
+
|
160 |
+
def extract_figures_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
    """
    Find and return all figures from a pdf
    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
        output_bbox: whether to output bounding boxes for each individual entry of a table
        output_image: whether to include PIL image for figures. default is True
    Returns:
        list of content in the following format
        [
            {   # first figure
                'title': str,
                'figure': {
                    'image': PIL image or None,
                    'bbox': list in form [x1, y1, x2, y2],
                }
                'table': {
                    'bbox': list in form [x1, y1, x2, y2] or empty list,
                    'content': {
                        'columns': list of column headers,
                        'rows': list of list of row content,
                    } or None
                }
                'footnote': str or empty,
                'page': int
            }
            # more figures
        ]
    """
    # Render the requested pages to images for the layout model.
    page_images = pdf2image.convert_from_path(pdf, last_page=num_pages)

    extractor = self.tableextractor
    extractor.set_pdf_file(pdf)
    extractor.set_output_image(output_image)
    extractor.set_output_bbox(output_bbox)

    return extractor.extract_all_tables_and_figures(page_images, self.pdfparser, content='figures')
|
199 |
+
|
200 |
+
def extract_tables_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
    """
    Find and return all tables from a pdf
    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
        output_bbox: whether to include bboxes for individual entries of the table
        output_image: whether to include PIL image for figures. default is True
    Returns:
        list of content in the following format
        [
            {   # first table
                'title': str,
                'figure': {
                    'image': PIL image or None,
                    'bbox': list in form [x1, y1, x2, y2] or empty list,
                }
                'table': {
                    'bbox': list in form [x1, y1, x2, y2] or empty list,
                    'content': {
                        'columns': list of column headers,
                        'rows': list of list of row content,
                    }
                }
                'footnote': str or empty,
                'page': int
            }
            # more tables
        ]
    """
    # Render the requested pages to images for the layout model.
    page_images = pdf2image.convert_from_path(pdf, last_page=num_pages)

    extractor = self.tableextractor
    extractor.set_pdf_file(pdf)
    extractor.set_output_image(output_image)
    extractor.set_output_bbox(output_bbox)

    return extractor.extract_all_tables_and_figures(page_images, self.pdfparser, content='tables')
|
239 |
+
|
240 |
+
def extract_molecules_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None):
    """
    Get all molecules and their information from a pdf
    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of figures and corresponding molecule info in the following format
        [
            {   # first figure
                'image': ndarray of the figure image,
                'molecules': [
                    {   # first molecule
                        'bbox': tuple in the form (x1, y1, x2, y2),
                        'score': float,
                        'image': ndarray of cropped molecule image,
                        'smiles': str,
                        'molfile': str
                    },
                    # more molecules
                ],
                'page': int
            },
            # more figures
        ]
    """
    per_figure = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    results = self.extract_molecules_from_figures(
        [entry['figure']['image'] for entry in per_figure],
        batch_size=batch_size,
    )
    # Results come back in figure order, so the page numbers can be copied across pairwise.
    for entry, result in zip(per_figure, results):
        result['page'] = entry['page']
    return results
|
273 |
+
|
274 |
+
def extract_molecule_bboxes_from_figures(self, figures, batch_size=16):
    """
    Return bounding boxes of molecules in images
    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
    Returns:
        list of results for each figure in the following format
        [
            [   # first figure
                {   # first bounding box
                    'category': str,
                    'bbox': tuple in the form (x1, y1, x2, y2),
                    'category_id': int,
                    'score': float
                },
                # more bounding boxes
            ],
            # more figures
        ]
    """
    pil_figures = [convert_to_pil(figure) for figure in figures]
    detector = self.moldet
    return detector.predict_images(pil_figures, batch_size=batch_size)
|
297 |
+
|
298 |
+
def extract_molecules_from_figures(self, figures, batch_size=16):
    """
    Get all molecules and their information from list of figures
    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
    Returns:
        list of results for each figure in the following format
        [
            {   # first figure
                'image': ndarray of the figure image,
                'molecules': [
                    {   # first molecule
                        'bbox': tuple in the form (x1, y1, x2, y2),
                        'score': float,
                        'image': ndarray of cropped molecule image,
                        'smiles': str,
                        'molfile': str
                    },
                    # more molecules
                ],
            },
            # more figures
        ]
    """
    boxes = self.extract_molecule_bboxes_from_figures(figures, batch_size=batch_size)
    cv2_figures = [convert_to_cv2(figure) for figure in figures]
    results, crops, refs = clean_bbox_output(cv2_figures, boxes)
    predictions = self.molscribe.predict_images(crops, batch_size=batch_size)
    # Each ref is updated in place with the prediction for its crop.
    for prediction, ref in zip(predictions, refs):
        ref.update(prediction)
    return results
|
330 |
+
|
331 |
+
def extract_molecule_corefs_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe = True, ocr = True):
    """
    Get all molecule bboxes and corefs from figures in pdf
    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
        molscribe: forwarded unchanged to `extract_molecule_corefs_from_figures`
        ocr: forwarded unchanged to `extract_molecule_corefs_from_figures`
    Returns:
        list of results for each figure in the following format:
        [
            {
                'bboxes': [
                    {   # first bbox
                        'category': '[Sup]',
                        'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
                        'category_id': 4,
                        'score': -0.07593922317028046
                    },
                    # More bounding boxes
                ],
                'corefs': [
                    [0, 1],  # molecule bbox index, identifier bbox index
                    [3, 4],
                    # More coref pairs
                ],
                'page': int
            },
            # More figures
        ]
    """
    per_figure = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    results = self.extract_molecule_corefs_from_figures(
        [entry['figure']['image'] for entry in per_figure],
        batch_size=batch_size,
        molscribe=molscribe,
        ocr=ocr,
    )
    # Results come back in figure order, so the page numbers can be copied across pairwise.
    for entry, result in zip(per_figure, results):
        result['page'] = entry['page']
    return results
|
367 |
+
|
368 |
+
def extract_molecule_corefs_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
    """
    Get all molecule bboxes and corefs from list of figures
    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
        molscribe: forwarded unchanged to the coref model's `predict_images`
        ocr: forwarded unchanged to the coref model's `predict_images`
    Returns:
        list of results for each figure in the following format:
        [
            {
                'bboxes': [
                    {   # first bbox
                        'category': '[Sup]',
                        'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
                        'category_id': 4,
                        'score': -0.07593922317028046
                    },
                    # More bounding boxes
                ],
                'corefs': [
                    [0, 1],  # molecule bbox index, identifier bbox index
                    [3, 4],
                    # More coref pairs
                ],
            },
            # More figures
        ]
    """
    pil_figures = [convert_to_pil(figure) for figure in figures]
    model = self.coref
    return model.predict_images(pil_figures, batch_size=batch_size, coref=True, molscribe = molscribe, ocr = ocr)
|
398 |
+
|
399 |
+
def extract_reactions_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
    """
    Get reaction information from figures in pdf
    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of figures and corresponding molecule info in the following format
        [
            {
                'figure': PIL image
                'reactions': [
                    {
                        'reactants': [
                            {
                                'category': str,
                                'bbox': tuple (x1,x2,y1,y2),
                                'category_id': int,
                                'smiles': str,
                                'molfile': str,
                            },
                            # more reactants
                        ],
                        'conditions': [
                            {
                                'category': str,
                                'bbox': tuple (x1,x2,y1,y2),
                                'category_id': int,
                                'text': list of str,
                            },
                            # more conditions
                        ],
                        'products': [
                            # same structure as reactants
                        ]
                    },
                    # more reactions
                ],
                'page': int
            },
            # more figures
        ]
    """
    per_figure = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    results = self.extract_reactions_from_figures(
        [entry['figure']['image'] for entry in per_figure],
        batch_size=batch_size,
        molscribe=molscribe,
        ocr=ocr,
    )
    # Results come back in figure order, so the page numbers can be copied across pairwise.
    for entry, result in zip(per_figure, results):
        result['page'] = entry['page']
    return results
|
451 |
+
|
452 |
+
def extract_reactions_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
    """
    Get reaction information from list of figures
    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference in all models
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of figures and corresponding molecule info in the following format
        [
            {
                'figure': the figure exactly as it was passed in
                'reactions': [
                    {
                        'reactants': [
                            {
                                'category': str,
                                'bbox': tuple (x1,x2,y1,y2),
                                'category_id': int,
                                'smiles': str,
                                'molfile': str,
                            },
                            # more reactants
                        ],
                        'conditions': [
                            {
                                'category': str,
                                'bbox': tuple (x1,x2,y1,y2),
                                'category_id': int,
                                'text': list of str,
                            },
                            # more conditions
                        ],
                        'products': [
                            # same structure as reactants
                        ]
                    },
                    # more reactions
                ],
            },
            # more figures
        ]

    """
    # The model is fed PIL images; the output keeps the caller's own objects.
    pil_figures = [convert_to_pil(figure) for figure in figures]
    reactions = self.rxnscribe.predict_images(pil_figures, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    return [
        {'figure': figure, 'reactions': rxn}
        for figure, rxn in zip(figures, reactions)
    ]
|
507 |
+
|
508 |
+
def extract_molecules_from_text_in_pdf(self, pdf, batch_size=16, num_pages=None):
    """
    Get molecules in text of given pdf

    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of sentences and found molecules in the following format
        [
            {
                'molecules': [
                    {   # first paragraph
                        'text': str,
                        'labels': [
                            (str, int, int),  # tuple of label, range start (inclusive), range end (exclusive)
                            # more labels
                        ]
                    },
                    # more paragraphs
                ]
                'page': int
            },
            # more pages
        ]
    """
    extractor = self.chemrxnextractor
    extractor.set_pdf_file(pdf)
    extractor.set_pages(num_pages)

    result = []
    for data in extractor.get_paragraphs_from_pdf(num_pages):
        # One input string per paragraph: sentences joined, newlines removed.
        model_inp = [' '.join(paragraph).replace('\n', '') for paragraph in data['paragraphs']]
        output = self.chemner.predict_strings(model_inp, batch_size=batch_size)
        found = [{'text': t, 'labels': labels} for t, labels in zip(model_inp, output)]
        result.append({'molecules': found, 'page': data['page']})
    return result
|
553 |
+
|
554 |
+
|
555 |
+
def extract_reactions_from_text_in_pdf(self, pdf, num_pages=None):
    """
    Get reaction information from text in pdf
    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of pages and corresponding reaction info in the following format
        [
            {
                'page': page number
                'reactions': [
                    {
                        'tokens': list of words in relevant sentence,
                        'reactions' : [
                            {
                                # key, value pairs where key is the label and value is a tuple
                                # or list of tuples of the form (tokens, start index, end index)
                                # where indices are for the corresponding token list and start and end are inclusive
                            }
                            # more reactions
                        ]
                    }
                    # more reactions in other sentences
                ]
            },
            # more pages
        ]
    """
    extractor = self.chemrxnextractor
    extractor.set_pdf_file(pdf)
    extractor.set_pages(num_pages)
    return extractor.extract_reactions_from_text()
|
587 |
+
|
588 |
+
def extract_reactions_from_text_in_pdf_combined(self, pdf, num_pages=None):
    """
    Get reaction information from text in pdf and combined with corefs from figures
    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of pages and corresponding reaction info in the following format
        [
            {
                'page': page number
                'reactions': [
                    {
                        'tokens': list of words in relevant sentence,
                        'reactions' : [
                            {
                                # key, value pairs where key is the label and value is a tuple
                                # or list of tuples of the form (tokens, start index, end index)
                                # where indices are for the corresponding token list and start and end are inclusive
                            }
                            # more reactions
                        ]
                    }
                    # more reactions in other sentences
                ]
            },
            # more pages
        ]
    """
    text_reactions = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
    figure_corefs = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
    return associate_corefs(text_reactions, figure_corefs)
|
620 |
+
|
621 |
+
def extract_reactions_from_figures_and_tables_in_pdf(self, pdf, num_pages=None, batch_size=16, molscribe=True, ocr=True):
    """
    Get reaction information from figures and combine with table information in pdf
    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of figures and corresponding molecule info in the following format
        [
            {
                'figure': PIL image
                'reactions': [
                    {
                        'reactants': [
                            {
                                'category': str,
                                'bbox': tuple (x1,x2,y1,y2),
                                'category_id': int,
                                'smiles': str,
                                'molfile': str,
                            },
                            # more reactants
                        ],
                        'conditions': [
                            {
                                'category': str,
                                'text': list of str,
                            },
                            # more conditions
                        ],
                        'products': [
                            # same structure as reactants
                        ]
                    },
                    # more reactions
                ],
                'page': int
            },
            # more figures
        ]
    """
    figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    images = [figure['figure']['image'] for figure in figures]
    results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
    # Forward batch_size so the coreference pass honours the caller's limit
    # like every other model here; `molscribe`/`ocr` stay at their defaults
    # for this pass, exactly as before.
    results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, batch_size=batch_size, num_pages=num_pages)
    results = replace_rgroups_in_figure(figures, results, results_coref, self.molscribe, batch_size=batch_size)
    results = expand_reactions_with_backout(results, results_coref, self.molscribe)
    return results
|
673 |
+
|
674 |
+
def extract_reactions_from_pdf(self, pdf, num_pages=None, batch_size=16):
    """
    Get reaction information from both the figures/tables and the text of a pdf
    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
        batch_size: batch size for inference in the figure models
    Returns:
        dictionary of reactions from multimodal sources
        {
            'figures': [
                {
                    'figure': PIL image
                    'reactions': [
                        {
                            'reactants': [
                                {
                                    'category': str,
                                    'bbox': tuple (x1,x2,y1,y2),
                                    'category_id': int,
                                    'smiles': str,
                                    'molfile': str,
                                },
                                # more reactants
                            ],
                            'conditions': [
                                {
                                    'category': str,
                                    'text': list of str,
                                },
                                # more conditions
                            ],
                            'products': [
                                # same structure as reactants
                            ]
                        },
                        # more reactions
                    ],
                    'page': int
                },
                # more figures
            ]
            'text': [
                {
                    'page': page number
                    'reactions': [
                        {
                            'tokens': list of words in relevant sentence,
                            'reactions' : [
                                {
                                    # key, value pairs where key is the label and value is a tuple
                                    # or list of tuples of the form (tokens, start index, end index)
                                    # where indices are for the corresponding token list and start and end are inclusive
                                }
                                # more reactions
                            ]
                        }
                        # more reactions in other sentences
                    ]
                },
                # more pages
            ]
        }

    """
    figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    images = [figure['figure']['image'] for figure in figures]
    results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=True, ocr=True)
    table_expanded_results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
    text_results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
    # Forward batch_size so the coreference pass honours the caller's limit
    # like the other figure models.
    results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, batch_size=batch_size, num_pages=num_pages)
    rgroup_results = replace_rgroups_in_figure(figures, table_expanded_results, results_coref, self.molscribe, batch_size=batch_size)
    figure_results = expand_reactions_with_backout(rgroup_results, results_coref, self.molscribe)
    coref_expanded_results = associate_corefs(text_results, results_coref)
    return {
        'figures': figure_results,
        'text': coref_expanded_results,
    }
|
747 |
+
|
748 |
+
if __name__ == "__main__":
    # The toolkit class defined in this module is ChemIEToolkit; the name
    # used here before is not defined in this file and raised NameError.
    model = ChemIEToolkit()
|
chemietoolkit/tableextractor.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdf2image
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import layoutparser as lp
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from PyPDF2 import PdfReader, PdfWriter
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
import pdfminer.high_level
|
12 |
+
import pdfminer.layout
|
13 |
+
from operator import itemgetter
|
14 |
+
|
15 |
+
# inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox
|
16 |
+
class TableExtractor(object):
    """Locate tables and figures on PDF pages and extract their contents.

    A layout-detection model (assigned in ``extract_all_tables_and_figures``)
    labels regions of a rendered page image; the text inside table regions is
    then read from the PDF with pdfminer and arranged into header/row cells.

    Args:
        output_bbox: When true, every returned title, footnote, header and
            cell carries its bounding box alongside its text.
    """

    # Two-character markers searched for to recognise a footnote block.  The
    # order matters: the first marker found decides where the footnote text
    # starts.
    _FOOTNOTE_MARKERS = tuple(
        'a' + char for char in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
    )

    def __init__(self, output_bbox=True):
        self.pdf_file = ""
        self.page = ""
        # Used only as the ratio pdf_dpi / image_dpi when scaling detections.
        # NOTE(review): assumes page images were rendered at image_dpi - confirm.
        self.image_dpi = 200
        self.pdf_dpi = 72
        self.output_bbox = output_bbox
        self.blocks = {}
        self.title_y = 0
        self.column_header_y = 0
        self.model = None
        self.img = None
        self.output_image = True
        # Header tagging vocabulary.  Matching is a case-sensitive substring
        # test and the first tag (in insertion order) with a hit wins.
        self.tagging = {
            'substance': ['compound', 'salt', 'base', 'solvent', 'CBr4', 'collidine', 'InX3', 'substrate', 'ligand', 'PPh3', 'PdL2', 'Cu', 'compd', 'reagent', 'reagant', 'acid', 'aldehyde', 'amine', 'Ln', 'H2O', 'enzyme', 'cofactor', 'oxidant', 'Pt(COD)Cl2', 'CuBr2', 'additive'],
            'ratio': [':'],
            'measurement': ['μM', 'nM', 'IC50', 'CI', 'excitation', 'emission', 'Φ', 'φ', 'shift', 'ee', 'ΔG', 'ΔH', 'TΔS', 'Δ', 'distance', 'trajectory', 'V', 'eV'],
            'temperature': ['temp', 'temperature', 'T', '°C'],
            'time': ['time', 't(', 't ('],
            'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'],
            'alkyl group': ['R', 'Ar', 'X', 'Y'],
            'solvent': ['solvent'],
            'counter': ['entry', 'no.'],
            'catalyst': ['catalyst', 'cat.'],
            'conditions': ['condition'],
            'reactant': ['reactant'],
        }

    def set_output_image(self, oi):
        """Choose whether cropped region images are attached to the results."""
        self.output_image = oi

    def set_pdf_file(self, pdf):
        """Set the PDF that pdfminer / PyPDF2 read from."""
        self.pdf_file = pdf

    def set_page_num(self, pn):
        """Set the zero-based page index used for subsequent extraction."""
        self.page = pn

    def set_output_bbox(self, ob):
        """Choose whether bounding boxes are included in the results."""
        self.output_bbox = ob

    def run_model(self, page_info):
        """Run layout detection on one page image and cache blocks by type.

        Args:
            page_info: Page image; converted with ``np.asarray``.
        """
        img = np.asarray(page_info)
        self.img = img

        layout_result = self.model.detect(img)

        for key, label in (('text', 'Text'), ('title', 'Title'), ('list', 'List'),
                           ('table', 'Table'), ('figure', 'Figure')):
            self.blocks[key] = lp.Layout([b for b in layout_result if b.type == label])

    # type is what coordinates you want to get. it comes in text, title, list, table, and figure
    def convert_to_pdf_coordinates(self, type):
        """Convert cached image-space blocks of one type to PDF coordinates.

        Args:
            type: Key into ``self.blocks`` ('text', 'title', 'list', 'table'
                or 'figure').

        Returns:
            A list of ``(x_1, y_low, x_2, y_high)`` tuples, one per block, with
            the y axis flipped against the top edge of the page media box.
        """
        ratio = self.pdf_dpi / self.image_dpi
        scaled_blocks = [block.scale(ratio) for block in self.blocks[type]]

        page = PdfReader(self.pdf_file).pages[self.page]
        page_top = pd.to_numeric(page.mediabox.upper_left[1])

        return [
            (blk.block.x_1, page_top - blk.block.y_2, blk.block.x_2, page_top - blk.block.y_1)
            for blk in scaled_blocks
        ]

    def _tag_header(self, text):
        """Return the first tag whose keyword occurs in ``text``, else 'unknown'."""
        for tag, keywords in self.tagging.items():
            if any(word in text for word in keywords):
                return tag
        return 'unknown'

    @staticmethod
    def _take_row(by_top, start):
        """Collect one visual row of text lines starting at ``by_top[start]``.

        Consecutive lines belong to the same row while the top of the next line
        is above the bottom of the previous one.  The row is returned sorted
        left to right with horizontally overlapping cells merged, together with
        the index of the first line after the row.
        """
        row = [by_top[start]]
        cursor = start + 1
        while cursor < len(by_top) and by_top[cursor][3] > by_top[cursor - 1][1]:
            row.append(by_top[cursor])
            cursor += 1
        row = sorted(row, key=itemgetter(0))

        # Walk right-to-left so popping a merged cell never shifts unvisited ones.
        for a in range(len(row) - 1, 0, -1):
            right, left = row[a], row[a - 1]
            if right[0] < left[2]:
                left[0] = min(right[0], left[0])
                left[1] = min(right[1], left[1])
                left[2] = max(right[2], left[2])
                left[3] = max(right[3], left[3])
                left[4] = left[4].strip() + " " + right[4]
                row.pop(a)
        return row, cursor

    def _assemble_table(self, lines, new_coords):
        """Arrange text lines into a ``{'columns': [...], 'rows': [...]}`` table.

        Args:
            lines: ``[x0, y0, x1, y1, text]`` lists for the text lines inside
                the table box (the lists may be modified while merging cells).
            new_coords: Table bounding box; only its left and right x values
                (indices 0 and 2) are used here.

        Returns:
            The table dict, or ``None`` when fewer than two lines are given.
            Also updates ``self.column_header_y`` from the header cells.
        """
        lines = sorted(lines, key=itemgetter(0))
        by_top = sorted(lines, key=itemgetter(3), reverse=True)
        if len(by_top) <= 1:
            return None

        header, cursor = self._take_row(by_top, 0)

        columns = []
        for cell in header:
            text = cell[4].strip()
            column = {'text': text, 'tag': self._tag_header(text)}
            if self.output_bbox:
                column['bbox'] = cell[:4]
            columns.append(column)
            self.column_header_y = max(cell[1], cell[3])

        # Sentinels on both sides so every header has a left and right neighbour.
        bounds = [[0, 0, new_coords[0], 0, '']] + header + [[new_coords[2], 0, 0, 0, '']]

        rows = []
        while cursor < len(by_top):
            cells, cursor = self._take_row(by_top, cursor)

            # Pad the row with empty cells wherever a cell does not fit between
            # the neighbouring headers of the column it would occupy.
            slot = 1
            while slot < len(bounds) - 1:
                if slot > len(cells):
                    cells.append([0, 0, 0, 0, '\n'])
                else:
                    cell = cells[slot - 1]
                    fits = cell[0] >= bounds[slot - 1][2] and cell[2] <= bounds[slot + 1][0]
                    if not fits:
                        cells.insert(slot - 1, [0, 0, 0, 0, '\n'])
                slot += 1

            if self.output_bbox:
                rows.append([{'text': c[4].strip(), 'bbox': c[:4]} for c in cells])
            else:
                rows.append([c[4].strip() for c in cells])

        if rows and len(rows[0]) != len(columns):
            # The detected header does not line up with the body, so the first
            # body row is used as the header instead and tagged afresh.
            promoted, rows = rows[0], rows[1:]
            columns = []
            for cell in promoted:
                if isinstance(cell, dict):
                    cell['tag'] = self._tag_header(cell['text'])
                    columns.append(cell)
                else:
                    # Without bounding boxes the cells are plain strings; wrap
                    # them so headers always have the same {'text', 'tag'} shape.
                    columns.append({'text': cell, 'tag': self._tag_header(cell)})

        return {"columns": columns, "rows": rows}

    # input: new_coords is singular table bounding box in pdf coordinates
    def extract_singular_table(self, new_coords):
        """Read the text lines inside one table box and build the table.

        Args:
            new_coords: Table bounding box in PDF coordinates.

        Returns:
            ``{'columns': [...], 'rows': [...]}``, or ``None`` when fewer than
            two horizontal text lines lie strictly inside the box.
        """
        x_lo, x_hi = min(new_coords[0], new_coords[2]), max(new_coords[0], new_coords[2])
        y_lo, y_hi = min(new_coords[1], new_coords[3]), max(new_coords[1], new_coords[3])

        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            lines = []
            for element in page_layout:
                if not isinstance(element, pdfminer.layout.LTTextBox):
                    continue
                for line in element:
                    x0, y0, x1, y1 = line.bbox
                    inside = (x_lo < x0 < x_hi and y_lo < y0 < y_hi
                              and x_lo < x1 < x_hi and y_lo < y1 < y_hi)
                    if inside and isinstance(line, pdfminer.layout.LTTextLineHorizontal):
                        lines.append([x0, y0, x1, y1, line.get_text()])

            table = self._assemble_table(lines, new_coords)
            if table is not None:
                return table
        return None

    def _format_captions(self, title, footnote):
        """Shape title/footnote 5-tuples according to ``self.output_bbox``."""
        if self.output_bbox:
            return ({'text': title[4], 'bbox': list(title[:4])},
                    {'text': footnote[4], 'bbox': list(footnote[:4])})
        return (title[4], footnote[4])

    def get_title_and_footnotes(self, tb_coords):
        """Find the caption and footnote text closest to a table/figure box.

        Only horizontal text boxes that overlap the box horizontally are
        considered.  A title must contain 'Table' or 'Scheme' and start within
        30 units of ``tb_coords[3]``; a footnote must contain one of the
        footnote markers, not lie vertically inside the box, and end within
        30 units of ``tb_coords[1]``.  The nearest candidate wins.

        Args:
            tb_coords: Bounding box in PDF coordinates.

        Returns:
            ``(title, footnote)``; each is a ``{'text', 'bbox'}`` dict when
            ``output_bbox`` is set, otherwise a string.  Empty values are
            returned when nothing matches.  Updates ``self.title_y``.
        """
        empty = (0, 0, 0, 0, '')
        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            title = empty
            footnote = empty
            title_gap = 30
            footnote_gap = 30
            for element in page_layout:
                if not isinstance(element, pdfminer.layout.LTTextBoxHorizontal):
                    continue
                ex0, ey0, ex1, ey1 = element.bbox
                overlaps_horizontally = (
                    tb_coords[0] <= ex0 <= tb_coords[2]
                    or tb_coords[0] <= ex1 <= tb_coords[2]
                    or ex0 <= tb_coords[0] <= ex1
                    or ex0 <= tb_coords[2] <= ex1
                )
                if not overlaps_horizontally:
                    continue

                text = element.get_text()
                gap = abs(ey0 - tb_coords[3])
                # 'Table' is checked first, so it wins when both words occur.
                for keyword in ('Table', 'Scheme'):
                    if keyword in text and gap < title_gap:
                        title = tuple(element.bbox) + (text[text.index(keyword):].replace('\n', ' '),)
                        title_gap = gap

                # Text lying vertically inside the box is table content, not a footnote.
                if ey0 >= tb_coords[1] and ey1 <= tb_coords[3]:
                    continue

                for marker in self._FOOTNOTE_MARKERS:
                    if marker in text:
                        gap = abs(ey1 - tb_coords[1])
                        if gap < footnote_gap:
                            footnote = tuple(element.bbox) + (text[text.index(marker):].replace('\n', ' '),)
                            footnote_gap = gap
                        break

            self.title_y = min(title[1], title[3])
            return self._format_captions(title, footnote)

        # No page was produced for this page number: report "nothing found"
        # instead of returning None, which callers index into.
        return self._format_captions(empty, empty)

    def extract_table_information(self):
        """Build one result dict per table block cached by ``run_model``.

        Returns:
            A list of dicts with keys 'title', 'figure', 'table', 'footnote'
            and 'page'.  ``figure['bbox']`` is filled only when the title and
            the column header are more than 50 units apart.
            NOTE(review): presumably that gap means a graphic sits between the
            caption and the table body - confirm.
        """
        table_blocks = self.blocks['table']  # image-space layout blocks
        table_boxes = self.convert_to_pdf_coordinates('table')  # PDF-space tuples
        pad = 20  # widen the box horizontally before reading text

        ans = []
        for block, box in zip(table_blocks, table_boxes):
            coordinate = [box[0] - pad, box[1], box[2] + pad, box[3]]

            table_results = self.extract_singular_table(coordinate)
            tf = self.get_title_and_footnotes(coordinate)

            ret = {'title': tf[0], 'figure': {'image': None, 'bbox': []}}
            if self.output_image:
                ret['figure']['image'] = Image.fromarray(block.crop_image(self.img))
            ret['table'] = {'bbox': list(coordinate), 'content': table_results}
            ret['footnote'] = tf[1]
            if abs(self.title_y - self.column_header_y) > 50:
                ret['figure']['bbox'] = list(coordinate)
            ret['page'] = self.page

            ans.append(ret)

        return ans

    def extract_figure_information(self):
        """Build one result dict per figure block cached by ``run_model``.

        Returns:
            A list of dicts shaped like those of ``extract_table_information``
            but with an empty table entry and ``figure['bbox']`` always set.
        """
        figure_blocks = self.blocks['figure']
        figure_boxes = self.convert_to_pdf_coordinates('figure')

        ans = []
        for block, coordinate in zip(figure_blocks, figure_boxes):
            tf = self.get_title_and_footnotes(coordinate)

            ret = {'title': tf[0], 'figure': {'image': None, 'bbox': []}}
            if self.output_image:
                ret['figure']['image'] = Image.fromarray(block.crop_image(self.img))
            ret['table'] = {'bbox': [], 'content': None}
            ret['footnote'] = tf[1]
            ret['figure']['bbox'] = list(coordinate)
            ret['page'] = self.page

            ans.append(ret)

        return ans

    def extract_all_tables_and_figures(self, pages, pdfparser, content=None):
        """Run detection and extraction over every page.

        Args:
            pages: Page images, indexed by zero-based page number.
            pdfparser: Layout-detection model exposing ``detect(image)``.
            content: 'tables' for tables only; 'figures' for figures plus any
                table whose figure bbox was set; anything else for both.

        Returns:
            A flat list of result dicts in page order.
        """
        self.model = pdfparser
        ret = []
        for page_number, page in enumerate(pages):
            self.set_page_num(page_number)
            self.run_model(page)
            table_info = self.extract_table_information()
            if content == 'tables':
                # Figures are not requested, so skip their extraction pass.
                ret += table_info
                continue

            figure_info = self.extract_figure_information()
            if content == 'figures':
                ret += figure_info
                ret.extend(table for table in table_info if table['figure']['bbox'] != [])
            else:
                ret += table_info
                ret += figure_info
        return ret
|
chemietoolkit/utils.py
ADDED
@@ -0,0 +1,1018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
import cv2
|
4 |
+
import layoutparser as lp
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem import Draw
|
7 |
+
from rdkit.Chem import rdDepictor
|
8 |
+
rdDepictor.SetPreferCoordGen(True)
|
9 |
+
from rdkit.Chem.Draw import IPythonConsole
|
10 |
+
from rdkit.Chem import AllChem
|
11 |
+
import re
|
12 |
+
import copy
|
13 |
+
|
14 |
+
# Integer codes stored in molecule adjacency matrices for each bond type name.
BOND_TO_INT = dict(
    zip(
        ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"],
        range(7),
    )
)

# Symbols treated as R-group placeholders, first bare and then in brackets.
RGROUP_SYMBOLS = (
    ['R']
    + [f'R{n}' for n in range(1, 13)]
    + ['Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E',
       'Ar', 'Ar1', 'Ar2', 'Ari', "R'"]
    + [f'{n}*' for n in range(1, 13)]
    + ['[a*]', '[b*]', '[c*]', '[d*]']
)

RGROUP_SYMBOLS = RGROUP_SYMBOLS + [f'[{i}]' for i in RGROUP_SYMBOLS]

# SMILES tokens that stand for an R-group.
RGROUP_SMILES = (
    [f'[{n}*]' for n in range(1, 13)]
    + ['[a*]', '[b*]', '[c*]', '[d*]', '*', '[Rf]']
)
|
31 |
+
|
32 |
+
def get_figures_from_pages(pages, pdfparser):
    """Crop every region labelled "Figure" out of each page image.

    Args:
        pages: Sequence of page images (converted with ``np.asarray``).
        pdfparser: Layout model exposing ``detect(image)``.

    Returns:
        A list of ``{'image': PIL image, 'page': zero-based page index}``.
    """
    collected = []
    for page_index, page in enumerate(pages):
        page_array = np.asarray(page)
        detections = pdfparser.detect(page_array)
        figure_blocks = lp.Layout([blk for blk in detections if blk.type == "Figure"])
        collected.extend(
            {'image': Image.fromarray(blk.crop_image(page_array)), 'page': page_index}
            for blk in figure_blocks
        )
    return collected
|
45 |
+
|
46 |
+
def clean_bbox_output(figures, bboxes):
    """Crop the detected '[Mol]' boxes out of each figure.

    Args:
        figures: Image arrays indexed like ``bboxes``; each must have a
            three-element ``shape``.
        bboxes: Per-figure lists of detections with 'bbox', 'score' and
            'category' keys.  Box coordinates are multiplied by the image
            width/height, i.e. treated as fractions of the image size.

    Returns:
        ``(results, cropped, references)``: per-figure dicts holding the image
        and its molecules, the flat list of crops, and the flat list of the
        same molecule dicts that appear in ``results``.
    """
    results, cropped, references = [], [], []
    for fig_idx, detections in enumerate(bboxes):
        boxes = [det['bbox'] for det in detections if det['category'] == '[Mol]']
        scores = [det['score'] for det in detections if det['category'] == '[Mol]']
        entry = {}
        results.append(entry)
        figure = figures[fig_idx]
        entry['image'] = figure
        entry['molecules'] = []
        for bbox, score in zip(boxes, scores):
            x1, y1, x2, y2 = bbox
            height, width, _ = figure.shape
            crop = figure[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            molecule = {'bbox': bbox, 'score': score, 'image': crop}
            cropped.append(crop)
            entry['molecules'].append(molecule)
            references.append(molecule)
    return results, cropped, references
|
71 |
+
|
72 |
+
def convert_to_pil(image):
    """Return ``image`` as a PIL image.

    NumPy arrays (including ndarray subclasses) have their first and third
    channels swapped with ``cv2.COLOR_BGR2RGB`` and are wrapped with
    ``Image.fromarray``; any other object is returned unchanged.
    """
    if isinstance(image, np.ndarray):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
    return image
|
77 |
+
|
78 |
+
def convert_to_cv2(image):
    """Return ``image`` as a NumPy array in OpenCV channel order.

    NumPy arrays (including ndarray subclasses) are returned untouched.
    Anything else is converted with ``np.array`` and has its first and third
    channels swapped.
    """
    if not isinstance(image, np.ndarray):
        # RGB -> BGR; this is the same channel swap as COLOR_BGR2RGB.
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    return image
|
82 |
+
|
83 |
+
def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16):
    """Expand the first reaction of each figure for every ``R = group`` label.

    Identifier ('[Idt]') boxes of the coreference output are scanned for text
    such as ``R1 = Me``.  For each distinct (name, group) pair a copy of the
    figure's first reaction, with that R-group substituted, is appended to
    ``result['reactions']``.

    Args:
        figures: Figure dicts; ``figure['figure']['image']`` is the image.
        results: Reaction results, aligned with ``figures``; modified in place.
        coref_results: Coreference results with a 'bboxes' list per figure.
        molscribe: MolScribe model used to read and rebuild molecules.
        batch_size: Batch size forwarded to MolScribe.

    Returns:
        ``results`` (the same list that was passed in).
    """
    # Raw string: \d and \w must reach the regex engine, not be treated as
    # (invalid) string escapes.
    pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
    for figure, result, corefs in zip(figures, results, coref_results):
        r_groups = []
        seen_pairs = set()
        for bbox in corefs['bboxes']:
            if bbox['category'] != '[Idt]':
                continue
            for text in bbox['text']:
                res = pattern.search(text)
                if res is None:
                    continue
                pair = (res.group('name'), res.group('group'))
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                r_groups.append({pair[0]: pair[1]})

        if not (r_groups and result['reactions']):
            continue

        names = {name for name, _ in seen_pairs}
        orig_reaction = result['reactions'][0]
        graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)

        # Atom symbols are compared without their first and last character
        # (e.g. '[R1]' -> 'R1') against the label names found above.
        relevant_locs = {
            graph_idx: [
                (symbol[1:-1], atom_idx)
                for atom_idx, symbol in enumerate(graph['chartok_coords']['symbols'])
                if symbol[1:-1] in names
            ]
            for graph_idx, graph in enumerate(graphs)
        }

        for r_group in r_groups:
            reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe)
            result['reactions'].append({
                'reactants': reaction['reactants'][:],
                'conditions': orig_reaction['conditions'][:],
                'products': reaction['products'][:],
            })
    return results
|
121 |
+
|
122 |
+
def process_tables(figures, results, molscribe, batch_size=16):
    """Expand each figure's first reaction with the rows of its table.

    Every result receives the figure's page number.  For figures with table
    content and at least one reaction, each table row either yields a new
    reaction (when an 'alkyl group' column provides an R-group substitution)
    or an extra condition list.  Afterwards the first reaction's
    'conditions' becomes a list of lists: the original conditions followed by
    the per-row condition lists.

    Args:
        figures: Figure dicts with 'page', 'table' and 'figure' entries.
        results: Reaction results aligned with ``figures``; modified in place.
        molscribe: MolScribe model used to read and rebuild molecules.
        batch_size: Batch size forwarded to MolScribe.

    Returns:
        ``results`` (the same list that was passed in).
    """
    r_group_pattern = re.compile(r'^(\w+-)?(?P<group>[\w-]+)( \(\w+\))?$')
    for figure, result in zip(figures, results):
        result['page'] = figure['page']
        content = figure['table']['content']
        if content is None:
            continue
        if len(result['reactions']) > 1:
            print("Warning: multiple reactions detected for table")
        elif len(result['reactions']) == 0:
            continue

        orig_reaction = result['reactions'][0]
        graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
        relevant_locs = find_relevant_groups(graphs, content['columns'])
        conditions_to_extend = []
        for row in content['rows']:
            r_groups = {}
            expanded_conditions = orig_reaction['conditions'][:]
            for col, entry in zip(content['columns'], row):
                if col['tag'] != 'alkyl group':
                    expanded_conditions.append({
                        'category': '[Table]',
                        'text': entry['text'],
                        'tag': col['tag'],
                        'header': col['text'],
                    })
                    continue
                found = r_group_pattern.match(entry['text'])
                if found is not None:
                    r_groups[col['text']] = found.group('group')

            if r_groups:
                # Rebuild the molecules only when a substitution was found; the
                # result is not needed for condition-only rows.
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe)
                result['reactions'].append({
                    'reactants': reaction['reactants'][:],
                    'conditions': expanded_conditions,
                    'products': reaction['products'][:],
                })
            else:
                conditions_to_extend.append(expanded_conditions)

        orig_reaction['conditions'] = [orig_reaction['conditions']]
        orig_reaction['conditions'].extend(conditions_to_extend)
    return results
|
166 |
+
|
167 |
+
|
168 |
+
def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16):
    """Build an atom/bond graph for every '[Mol]' entry of a reaction.

    Each molecule's bounding box (scaled by the image width and height) is
    cropped from ``image`` and passed to ``molscribe.predict_images``.

    Args:
        image: Figure image; converted with ``convert_to_cv2``.
        reaction: Mapping of role name to a list of entities.
        molscribe: Model exposing ``predict_images``.
        batch_size: Batch size forwarded to MolScribe.

    Returns:
        One dict per molecule with 'image', 'chartok_coords' (atom 'coords'
        and 'symbols'), 'edges' (symmetric bond-code matrix) and 'key'
        ``(role, index)`` locating the molecule inside ``reaction``.
    """
    image = convert_to_cv2(image)
    crops = []
    graphs = []
    for role, entities in reaction.items():
        for position, entity in enumerate(entities):
            if type(entity) is not dict or entity['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = entity['bbox']
            height, width, _ = image.shape
            crop = image[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            crops.append(crop)
            graphs.append({
                'image': crop,
                'chartok_coords': {
                    'coords': [],
                    'symbols': [],
                },
                'edges': [],
                'key': (role, position),
            })

    predictions = molscribe.predict_images(crops, return_atoms_bonds=True, batch_size=batch_size)
    for prediction, graph in zip(predictions, graphs):
        atoms = prediction['atoms']
        tokens = graph['chartok_coords']
        for atom in atoms:
            tokens['coords'].append((atom['x'], atom['y']))
            tokens['symbols'].append(atom['atom_symbol'])
        graph['edges'] = [[0] * len(atoms) for _ in range(len(atoms))]
        for bond in prediction['bonds']:
            first, second = bond['endpoint_atoms']
            graph['edges'][first][second] = BOND_TO_INT[bond['bond_type']]
            graph['edges'][second][first] = BOND_TO_INT[bond['bond_type']]
    return graphs
|
200 |
+
|
201 |
+
def find_relevant_groups(graphs, columns):
    """Locate the atoms whose symbol matches an 'alkyl group' column header.

    Args:
        graphs: Molecule graphs with ``graph['chartok_coords']['symbols']``.
        columns: Table columns with 'text' and 'tag' keys.

    Returns:
        A dict mapping each graph index to a list of ``(name, atom_index)``,
        where ``name`` is the atom symbol without its surrounding brackets.
    """
    wanted = {f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group'}
    return {
        graph_idx: [
            (symbol[1:-1], atom_idx)
            for atom_idx, symbol in enumerate(graph['chartok_coords']['symbols'])
            if symbol in wanted
        ]
        for graph_idx, graph in enumerate(graphs)
    }
|
211 |
+
|
212 |
+
def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe):
    """Return a copy of a reaction with R-group symbols substituted.

    Args:
        orig_reaction: Mapping of role name to entities; an entity is a dict
            with a 'category' key or a list of such dicts.
        graphs: Molecule graphs as produced by ``get_atoms_and_bonds``.
        relevant_locs: ``{graph index: [(name, atom index), ...]}``.
        mappings: ``{name: replacement symbol}``.
        molscribe: Model exposing ``convert_graph_to_output``.

    Returns:
        A new reaction dict.  '[Mol]' entities are shallow-copied and get
        'smiles' / 'molfile' regenerated from their graph; every other entity
        is shared with ``orig_reaction``.
    """
    # Copy the symbol lists so substitutions never touch the caller's graphs.
    working_graphs = [
        {
            'image': graph['image'],
            'chartok_coords': {
                'coords': graph['chartok_coords']['coords'][:],
                'symbols': graph['chartok_coords']['symbols'][:],
            },
            'edges': graph['edges'][:],
            'key': graph['key'],
        }
        for graph in graphs
    ]

    for graph_idx, atoms in relevant_locs.items():
        for name, atom_idx in atoms:
            if name in mappings:
                working_graphs[graph_idx]['chartok_coords']['symbols'][atom_idx] = mappings[name]

    def duplicate(entity):
        # Only molecules are rewritten later, so only they need their own dict.
        if entity['category'] == '[Mol]':
            return dict(entity.items())
        return entity

    reaction_copy = {}
    for role, entities in orig_reaction.items():
        copied = []
        reaction_copy[role] = copied
        for entity in entities:
            if type(entity) is list:
                copied.append([duplicate(member) for member in entity])
            else:
                copied.append(duplicate(entity))

    for graph in working_graphs:
        output = molscribe.convert_graph_to_output([graph], [graph['image']])
        role, position = graph['key'][0], graph['key'][1]
        molecule = reaction_copy[role][position]
        molecule['smiles'] = output[0]['smiles']
        molecule['molfile'] = output[0]['molfile']
    return reaction_copy
|
254 |
+
|
255 |
+
def get_sites(tar, ref, ref_site = False):
    """Find attachment sites where ``tar`` extends beyond the matched ``ref``.

    An atom of ``tar`` that is not part of the match but has a matched
    neighbour marks that neighbour as a site.

    Args:
        tar: Target molecule (2D coordinates are recomputed in place).
        ref: Reference/template molecule (2D coordinates are recomputed).
        ref_site: Accepted for interface compatibility; it currently has no
            effect on the result.
            NOTE(review): the flag was probably meant to switch between
            reference and target indices - confirm the intended behaviour.

    Returns:
        A list holding, for each site, the first element of the matching
        index pair returned by ``GenerateDepictionMatching2DStructure``.
    """
    rdDepictor.Compute2DCoords(ref)
    rdDepictor.Compute2DCoords(tar)
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref)

    # Second element of each pair: indices compared against atoms of ``tar``.
    in_template = [pair[1] for pair in idx_pair]
    sites = []
    for atom_idx in range(tar.GetNumAtoms()):
        if atom_idx in in_template:
            continue
        for neighbor in tar.GetAtomWithIdx(atom_idx).GetNeighbors():
            neighbor_idx = neighbor.GetIdx()
            if neighbor_idx in in_template and neighbor_idx not in sites:
                sites.append(idx_pair[in_template.index(neighbor_idx)][0])
    return sites
|
270 |
+
|
271 |
+
def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = None):
    """Map atom indices of ``prod_mol`` onto the molecule of ``prod_smiles``.

    Both structures are reduced to templates in which every R-group token from
    ``RGROUP_SMILES`` is replaced by ``*`` and dummy atoms act as queries.
    Candidate substructure matches are tried in order and the first one in
    which every atom listed in ``r_sites_reversed`` lands on an R-group atom
    is used; if none qualifies, the last candidate is returned.

    Args:
        prod_mol: RDKit molecule (2D coordinates are recomputed in place).
        prod_smiles: SMILES string of the molecule to map onto.
        r_sites_reversed: Iterable of ``prod_mol`` atom indices expected to map
            to R-group atoms.  ``None`` is treated as "no sites to check".

    Returns:
        ``(prod_mol_to_query, prod_template_mol_query)``: the index mapping
        and the query molecule built from ``prod_smiles``.

    Raises:
        ValueError: If the two templates have no substructure match.
    """
    r_sites = () if r_sites_reversed is None else r_sites_reversed

    prod_template_intermediate = Chem.MolToSmiles(prod_mol)
    prod_template = prod_smiles

    for r in RGROUP_SMILES:
        if r!='*' and r!='(*)':
            prod_template = prod_template.replace(r, '*')
            prod_template_intermediate = prod_template_intermediate.replace(r, '*')

    prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate)
    prod_template_mol = Chem.MolFromSmiles(prod_template)

    p = Chem.AdjustQueryParameters.NoAdjustments()
    p.makeDummiesQueries = True

    prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p)
    prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p)
    rdDepictor.Compute2DCoords(prod_mol)
    rdDepictor.Compute2DCoords(prod_template_mol_query)
    rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query)
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query)

    prod_mol_to_intermediate = {b:a for a,b in idx_pair}

    substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify = False)
    if not substructs:
        raise ValueError("no substructure match between product molecule and product SMILES templates")

    for substruct in substructs:
        intermediate_to_query = dict(enumerate(substruct))
        prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}

        good_map = True
        for i in r_sites:
            if prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() not in RGROUP_SMILES:
                good_map = False
        if good_map:
            break

    return prod_mol_to_query, prod_template_mol_query
|
325 |
+
|
326 |
+
def clean_corefs(coref_results_dict, idx):
    """Add letter-suffixed label variants where only digit-suffixed ones were parsed.

    For every key of ``coref_results_dict`` whose label list contains no
    string matching ``<idx><letters>``, each existing label is inspected and,
    if it contains ``idx`` followed by ``1``, ``0``, ``5`` or ``9``, the
    variant ``idx`` + ``l`` / ``o`` / ``s`` / ``g`` is appended to that list.
    (Presumably this repairs OCR digit/letter confusion -- confirm with the
    label parser.)

    Args:
        coref_results_dict: mapping whose values are lists of label strings;
            the lists are modified in place.
        idx: label prefix (a string) to look for.

    Returns:
        None. The function works purely by side effect.
    """
    label_pattern = re.compile(rf'{re.escape(idx)}[a-zA-Z]+')
    # First matching digit wins, mirroring an if/elif chain.
    digit_to_letter = (('1', 'l'), ('0', 'o'), ('5', 's'), ('9', 'g'))

    for labels in coref_results_dict.values():
        if any(label_pattern.search(parsed) for parsed in labels):
            continue
        # Iterate over a snapshot: the list is extended inside the loop.
        for parsed in list(labels):
            for digit, letter in digit_to_letter:
                if idx + digit in parsed:
                    labels.append(idx + letter)
                    break
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe):
    """Substitute an R-group symbol in a stored molecule graph and rebuild the molecule.

    Args:
        res: regex match exposing the named groups ``name`` (the R-group
            symbol) and ``group`` (the text that replaces it).
        coref_smiles_to_graphs: mapping from SMILES to an entry holding
            ``'atoms'`` and ``'bonds'`` lists.
        other_prod: key into ``coref_smiles_to_graphs``.
        molscribe: object providing ``convert_graph_to_output``.

    Returns:
        The result of ``Chem.MolFromSmiles`` on the SMILES produced by
        ``molscribe`` for the substituted graph.
    """
    r_name = res.group('name')
    replacement = res.group('group')
    atom_records = coref_smiles_to_graphs[other_prod]['atoms']
    bond_records = coref_smiles_to_graphs[other_prod]['bonds']

    coords = []
    symbols = []
    for record in atom_records:
        coords.append((record['x'], record['y']))
        symbols.append(record['atom_symbol'])

    # Symmetric adjacency matrix; 0 means "no bond".
    size = len(atom_records)
    edges = [[0] * size for _ in range(size)]
    for record in bond_records:
        a, b = record['endpoint_atoms']
        edges[a][b] = BOND_TO_INT[record['bond_type']]
        edges[b][a] = BOND_TO_INT[record['bond_type']]

    # Symbols are compared without their first and last character.
    for pos, symbol in enumerate(symbols):
        if symbol[1:-1] == r_name:
            symbols[pos] = replacement

    graph = {
        'image': None,
        'chartok_coords': {
            'coords': coords,
            'symbols': symbols,
        },
        'edges': edges,
        'key': None
    }

    converted = molscribe.convert_graph_to_output([graph], [graph['image']])
    return Chem.MolFromSmiles(converted[0]['smiles'])
|
381 |
+
|
382 |
+
def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn):
    """Resolve R groups from a concrete product and substitute them into the reactants.

    Args:
        other_prod_mol: molecule to match the template against.
        query: 4-tuple ``(template_mol, r_sites, h_sites, num_r_groups)``
            where ``r_sites`` maps template atom index -> R-group label and
            ``h_sites`` is a collection of R-group labels to be replaced by
            a hydrogen.
        reactant_mols: list of reactant molecules, indexed like
            ``reactant_information``.
        reactant_information: mapping reactant index -> list of
            ``[r_group_label, r_atom_index, connection_atom_index]``.
        parsed: label text stored alongside the result.
        toreturn: list that receives
            ``(modified_reactant_smiles, [product_smiles], parsed)`` on success.

    Returns:
        True if the template matched (a result was appended to ``toreturn``),
        False if ``other_prod_mol`` contains no match for the template.
    """
    prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query
    # uniquify=False: the atom order inside each match matters for the index bookkeeping below.
    substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False)
    if len(substructs) < 1:
        return False

    # Try each match: delete the matched scaffold atoms (keeping the R-site atoms) and
    # accept the first match that leaves exactly num_r_groups fragments. If none does,
    # the last match tried is used.
    for substruct in substructs:
        query_to_other = {a: b for a, b in enumerate(substruct)}

        # [atom index in other_prod_mol, R-group label]; the index is shifted as atoms are removed.
        r_site_correspondence = [[query_to_other[r], r_sites_reversed_new[r]] for r in r_sites_reversed_new]
        # Built once per match instead of once per atom.
        r_site_atoms = {site[0] for site in r_site_correspondence}

        editable = Chem.EditableMol(other_prod_mol)
        # Highest index first so that pending indices stay valid after each removal.
        for idx in sorted(substruct, reverse = True):
            if idx not in r_site_atoms:
                editable.RemoveAtom(idx)
                for r_site in r_site_correspondence:
                    if idx < r_site[0]:
                        r_site[0] -= 1
        other_prod_removed = editable.GetMol()

        if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups:
            break

    r_site_correspondence.sort(key = lambda x: x[0])

    f = []   # filled by GetMolFrags: atom index -> fragment index
    ff = []  # filled by GetMolFrags: per fragment, the original atom indices
    frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff)

    # R-group label -> (fragment molecule, index inside the fragment of the attachment atom)
    r_group_information = {}
    for r_site in r_site_correspondence:
        r_group_information[r_site[1]] = (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0]))
    for r_site in h_sites:
        r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0)

    # For every reactant, detach each R-group placeholder and attach the resolved fragment instead.
    modify_reactants = copy.deepcopy(reactant_mols)
    modified_reactant_smiles = []
    for reactant_idx in reactant_information:
        r_infos = reactant_information[reactant_idx]
        if len(r_infos) == 0:
            modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx]))
            continue

        combined = reactant_mols[reactant_idx]
        if combined.GetNumAtoms() == 1:
            # The reactant is nothing but the placeholder: the fragment replaces it entirely.
            r_group, _, _ = r_infos[0]
            modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0]))
            continue

        for r_group, r_index, connect_index in r_infos:
            combined = Chem.CombineMols(combined, r_group_information[r_group][0])

        editable = Chem.EditableMol(combined)
        # Offset of the next appended fragment's atoms inside the combined molecule.
        atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms()
        for r_group, r_index, connect_index in r_infos:
            editable.RemoveBond(r_index, connect_index)
            editable.AddBond(connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE)
            atomIdxAdder += r_group_information[r_group][0].GetNumAtoms()

        # Remove the placeholder atoms, highest index first.
        for r_index in sorted((info[1] for info in r_infos), reverse = True):
            editable.RemoveAtom(r_index)

        # Round-trip through SMILES before emitting the final string.
        modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol()))))

    toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed))
    return True
|
473 |
+
|
474 |
+
def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups):
    """Build one template variant per subset of R sites removed from the template.

    Args:
        prod_template_mol_query: template molecule containing the R-site atoms.
        r_sites_reversed_new: mapping template atom index -> R-group label.
        num_r_groups: number of R sites.

    Returns:
        A list of ``(molecule, kept_sites, removed_labels, kept_count)`` in
        the order produced by ``generate_subsets``: ``kept_sites`` maps the
        re-indexed atom index to its label, ``removed_labels`` is the set of
        labels whose atoms were deleted.
    """
    variants = []
    for removed in generate_subsets(num_r_groups):
        # Fresh [atom_index, label] pairs for every variant: indices are shifted in place below.
        sites = sorted(([atom_idx, r_sites_reversed_new[atom_idx]] for atom_idx in r_sites_reversed_new),
                       key = lambda site: site[0])
        editor = Chem.EditableMol(prod_template_mol_query)

        # Cut every bond first, then delete the atoms.
        for slot in removed:
            atom_idx = sites[slot][0]
            anchor_idx = prod_template_mol_query.GetAtomWithIdx(atom_idx).GetNeighbors()[0].GetIdx()
            editor.RemoveBond(atom_idx, anchor_idx)
        for slot in removed:
            editor.RemoveAtom(sites[slot][0])
        trimmed = editor.GetMol()

        # Every deleted atom shifts the index of all sites listed after it down by one.
        for slot in removed:
            for later in range(slot + 1, num_r_groups):
                sites[later][0] -= 1

        kept_sites = {}
        removed_labels = set()
        for slot in range(num_r_groups):
            if slot in removed:
                removed_labels.add(sites[slot][1])
            else:
                kept_sites[sites[slot][0]] = sites[slot][1]

        variants.append((trimmed, kept_sites, removed_labels, num_r_groups - len(removed)))
    return variants
|
505 |
+
|
506 |
+
def generate_subsets(n):
    """Return every subset of ``range(n)`` as a list of descending integers.

    Subsets are ordered by increasing size; subsets of equal size are in
    descending lexicographic order. For ``n <= 0`` the result is ``[[]]``.
    """
    subsets = [[]]
    for value in range(n):
        # Prepending the new (largest so far) value keeps each subset descending.
        subsets += [[value] + smaller for smaller in subsets]
    # Negating the elements turns "descending lexicographic" into an ascending sort key.
    subsets.sort(key=lambda items: (len(items), [-item for item in items]))
    return subsets
|
517 |
+
|
518 |
+
def backout(results, coref_results, molscribe):
    """Expand an R-group reaction template into concrete reactions using coreference results.

    Builds the two lookup tables derived from ``coref_results[0]`` and hands
    the rest of the work to ``backout_without_coref``, whose body performs
    the same processing on caller-supplied tables.

    Args:
        results: reaction predictions; only ``results[0]['reactions'][0]`` is used.
        coref_results: coreference predictions; only ``coref_results[0]`` is
            used, via its ``'bboxes'`` list and ``'corefs'`` index pairs.
        molscribe: object passed through for graph-to-SMILES conversion.

    Returns:
        A list of ``(reactant_smiles_list, [product_smiles], label)`` tuples.
        The list is empty when the inputs are empty, when there is not
        exactly one labelled product, or when processing fails.
    """
    if not results or not results[0]['reactions'] or not coref_results:
        return []

    try:
        bboxes = coref_results[0]['bboxes']
        corefs = coref_results[0]['corefs']
        # molecule SMILES -> label texts of the paired text box
        coref_results_dict = {bboxes[coref[0]]['smiles']: bboxes[coref[1]]['text'] for coref in corefs}
        # molecule SMILES -> full bounding-box record of the molecule
        coref_smiles_to_graphs = {bboxes[coref[0]]['smiles']: bboxes[coref[0]] for coref in corefs}
    except Exception:
        # Best effort: malformed coreference data yields no expansions.
        return []

    backed_out = backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe)
    # Always hand a list back to the caller.
    return backed_out if backed_out is not None else []
|
739 |
+
|
740 |
+
|
741 |
+
def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe):
|
742 |
+
|
743 |
+
toreturn = []
|
744 |
+
|
745 |
+
if not results or not results[0]['reactions'] or not coref_results:
|
746 |
+
return toreturn
|
747 |
+
|
748 |
+
try:
|
749 |
+
reactants = results[0]['reactions'][0]['reactants']
|
750 |
+
products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
|
751 |
+
coref_results_dict = coref_results_dict
|
752 |
+
coref_smiles_to_graphs = coref_smiles_to_graphs
|
753 |
+
|
754 |
+
|
755 |
+
if len(products) == 1:
|
756 |
+
if products[0] not in coref_results_dict:
|
757 |
+
print("Warning: No Label Parsed")
|
758 |
+
return
|
759 |
+
product_labels = coref_results_dict[products[0]]
|
760 |
+
prod = products[0]
|
761 |
+
label_idx = product_labels[0]
|
762 |
+
'''
|
763 |
+
if len(product_labels) == 1:
|
764 |
+
# get the coreference label of the product molecule
|
765 |
+
label_idx = product_labels[0]
|
766 |
+
else:
|
767 |
+
print("Warning: Malformed Label Parsed.")
|
768 |
+
return
|
769 |
+
'''
|
770 |
+
else:
|
771 |
+
print("Warning: More than one product detected")
|
772 |
+
return
|
773 |
+
|
774 |
+
# format the regular expression for labels that correspond to the product label
|
775 |
+
numbers = re.findall(r'\d+', label_idx)
|
776 |
+
label_idx = numbers[0] if len(numbers) > 0 else ""
|
777 |
+
label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'
|
778 |
+
|
779 |
+
|
780 |
+
prod_smiles = prod
|
781 |
+
prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])
|
782 |
+
|
783 |
+
# identify the atom indices of the R groups in the product tempalte
|
784 |
+
h_counter = 0
|
785 |
+
r_sites = {}
|
786 |
+
for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
|
787 |
+
sym = atom['atom_symbol']
|
788 |
+
if sym == '[H]':
|
789 |
+
h_counter += 1
|
790 |
+
if sym[0] == '[':
|
791 |
+
sym = sym[1:-1]
|
792 |
+
if sym[0] == 'R' and sym[1:].isdigit():
|
793 |
+
sym = sym[1:]+"*"
|
794 |
+
sym = f'[{sym}]'
|
795 |
+
if sym in RGROUP_SYMBOLS:
|
796 |
+
if sym not in r_sites:
|
797 |
+
r_sites[sym] = [idx-h_counter]
|
798 |
+
else:
|
799 |
+
r_sites[sym].append(idx-h_counter)
|
800 |
+
|
801 |
+
r_sites_reversed = {}
|
802 |
+
for sym in r_sites:
|
803 |
+
for pos in r_sites[sym]:
|
804 |
+
r_sites_reversed[pos] = sym
|
805 |
+
|
806 |
+
num_r_groups = len(r_sites_reversed)
|
807 |
+
|
808 |
+
#prepare the product template and get the associated mapping
|
809 |
+
|
810 |
+
prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)
|
811 |
+
|
812 |
+
reactant_mols = []
|
813 |
+
|
814 |
+
|
815 |
+
#--------------process the reactants-----------------
|
816 |
+
|
817 |
+
reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]
|
818 |
+
|
819 |
+
for idx, reactant in enumerate(reactants):
|
820 |
+
reactant_information[idx] = []
|
821 |
+
reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
|
822 |
+
has_r = False
|
823 |
+
|
824 |
+
r_sites_reactant = {}
|
825 |
+
|
826 |
+
h_counter = 0
|
827 |
+
|
828 |
+
for a_idx, atom in enumerate(reactant['atoms']):
|
829 |
+
|
830 |
+
#go through all atoms and check if they are an R group, if so add it to reactant information
|
831 |
+
sym = atom['atom_symbol']
|
832 |
+
if sym == '[H]':
|
833 |
+
h_counter += 1
|
834 |
+
if sym[0] == '[':
|
835 |
+
sym = sym[1:-1]
|
836 |
+
if sym[0] == 'R' and sym[1:].isdigit():
|
837 |
+
sym = sym[1:]+"*"
|
838 |
+
sym = f'[{sym}]'
|
839 |
+
if sym in r_sites:
|
840 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
841 |
+
reactant_information[idx].append([sym, -1, -1])
|
842 |
+
else:
|
843 |
+
has_r = True
|
844 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
845 |
+
reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
846 |
+
r_sites_reactant[sym] = a_idx-h_counter
|
847 |
+
elif sym == '[1*]' and '[7*]' in r_sites:
|
848 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
849 |
+
reactant_information[idx].append(['[7*]', -1, -1])
|
850 |
+
else:
|
851 |
+
has_r = True
|
852 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
853 |
+
reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
854 |
+
r_sites_reactant['[7*]'] = a_idx-h_counter
|
855 |
+
elif sym == '[7*]' and '[1*]' in r_sites:
|
856 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
857 |
+
reactant_information[idx].append(['[1*]', -1, -1])
|
858 |
+
else:
|
859 |
+
has_r = True
|
860 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
861 |
+
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
862 |
+
r_sites_reactant['[1*]'] = a_idx-h_counter
|
863 |
+
|
864 |
+
elif sym == '[1*]' and '[Rf]' in r_sites:
|
865 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
866 |
+
reactant_information[idx].append(['[Rf]', -1, -1])
|
867 |
+
else:
|
868 |
+
has_r = True
|
869 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
870 |
+
reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
871 |
+
r_sites_reactant['[Rf]'] = a_idx-h_counter
|
872 |
+
|
873 |
+
elif sym == '[Rf]' and '[1*]' in r_sites:
|
874 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
875 |
+
reactant_information[idx].append(['[1*]', -1, -1])
|
876 |
+
else:
|
877 |
+
has_r = True
|
878 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
879 |
+
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
880 |
+
r_sites_reactant['[1*]'] = a_idx-h_counter
|
881 |
+
|
882 |
+
r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
|
883 |
+
# if the reactant had r groups, we had to use the molecule generated from the MolBlock.
|
884 |
+
# but the molblock may have unexpanded elemeents that are not R groups
|
885 |
+
# so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
|
886 |
+
# and adjust the indices of the r groups accordingly
|
887 |
+
if has_r:
|
888 |
+
#get the mapping
|
889 |
+
reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)
|
890 |
+
|
891 |
+
#make the adjustment
|
892 |
+
for info in reactant_information[idx]:
|
893 |
+
info[1] = reactant_mol_to_query[info[1]]
|
894 |
+
info[2] = reactant_mol_to_query[info[2]]
|
895 |
+
reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])
|
896 |
+
|
897 |
+
#go through all the molecules in the coreference
|
898 |
+
|
899 |
+
clean_corefs(coref_results_dict, label_idx)
|
900 |
+
|
901 |
+
for other_prod in coref_results_dict:
|
902 |
+
|
903 |
+
#check if they match the product label regex
|
904 |
+
found_good_label = False
|
905 |
+
for parsed in coref_results_dict[other_prod]:
|
906 |
+
if re.search(label_pattern, parsed) and not found_good_label:
|
907 |
+
found_good_label = True
|
908 |
+
other_prod_mol = Chem.MolFromSmiles(other_prod)
|
909 |
+
|
910 |
+
if other_prod != prod_smiles and other_prod_mol is not None:
|
911 |
+
|
912 |
+
#check if there are R groups to be resolved in the target product
|
913 |
+
|
914 |
+
all_other_prod_mols = []
|
915 |
+
|
916 |
+
r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
|
917 |
+
|
918 |
+
for parsed_labels in coref_results_dict[other_prod]:
|
919 |
+
res = r_group_sub_pattern.search(parsed_labels)
|
920 |
+
|
921 |
+
if res is not None:
|
922 |
+
all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))
|
923 |
+
|
924 |
+
if len(all_other_prod_mols) == 0:
|
925 |
+
if other_prod_mol is not None:
|
926 |
+
all_other_prod_mols.append((other_prod_mol, parsed))
|
927 |
+
|
928 |
+
|
929 |
+
|
930 |
+
|
931 |
+
for other_prod_mol, parsed in all_other_prod_mols:
|
932 |
+
|
933 |
+
other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)
|
934 |
+
|
935 |
+
for other_prod_frag in other_prod_frags:
|
936 |
+
substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)
|
937 |
+
|
938 |
+
if len(substructs)>0:
|
939 |
+
other_prod_mol = other_prod_frag
|
940 |
+
break
|
941 |
+
r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}
|
942 |
+
|
943 |
+
queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)
|
944 |
+
|
945 |
+
matched = False
|
946 |
+
|
947 |
+
for query in queries:
|
948 |
+
if not matched:
|
949 |
+
try:
|
950 |
+
matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
|
951 |
+
except:
|
952 |
+
pass
|
953 |
+
|
954 |
+
except:
|
955 |
+
pass
|
956 |
+
|
957 |
+
|
958 |
+
return toreturn
|
959 |
+
|
960 |
+
|
961 |
+
|
962 |
+
def associate_corefs(results, results_coref):
    """Annotate compound labels in extracted reactions with their coreferenced SMILES.

    Labels (digits followed by at most two letters) are collected from the
    text boxes of ``results_coref`` and mapped to the SMILES of the paired
    molecule box; a label seen again overwrites the earlier SMILES. Every
    ``'Reactants'`` / ``'Product'`` compound whose first element is a known
    label is replaced by a tuple whose first element reads
    ``"<label> (<smiles>)"``.

    Args:
        results: list of pages, each with ``page['reactions']`` entries that
            hold a ``'reactions'`` list of dicts. Modified in place.
        results_coref: list of dicts with ``'bboxes'`` and ``'corefs'``.

    Returns:
        The same ``results`` object.
    """
    label_regex = r'\b\d+[a-zA-Z]{0,2}\b'
    label_to_smiles = {}
    for figure in results_coref:
        boxes, pairs = figure['bboxes'], figure['corefs']
        for pair in pairs:
            mol_box, text_box = pair[0], pair[1]
            if len(boxes[text_box]['text']) == 0:
                continue
            for text in boxes[text_box]['text']:
                for label in re.findall(label_regex, text):
                    label_to_smiles[label] = boxes[mol_box]['smiles']

    def tagged(compound):
        # Same triple, with the SMILES appended to the label in parentheses.
        return (f'{compound[0]} ({label_to_smiles[compound[0]]})', compound[1], compound[2])

    for page in results:
        for group in page['reactions']:
            for reaction in group['reactions']:
                for role in ('Reactants', 'Product'):
                    if role not in reaction:
                        continue
                    if isinstance(reaction[role], tuple):
                        # A single compound stored directly as a tuple.
                        if reaction[role][0] in label_to_smiles:
                            reaction[role] = tagged(reaction[role])
                    else:
                        # A sequence of compounds, updated slot by slot.
                        for pos, compound in enumerate(reaction[role]):
                            if compound[0] in label_to_smiles:
                                reaction[role][pos] = tagged(compound)
    return results
|
995 |
+
|
996 |
+
|
997 |
+
def expand_reactions_with_backout(initial_results, results_coref, molscribe):
    """Append backed-out concrete reactions to each figure's reaction list.

    For every ``(reactions, result_coref)`` pair, ``backout`` is run on that
    single figure. Each returned ``(reactants, products, label)`` triple is
    appended to ``reactions['reactions']`` as a new reaction that reuses a
    shallow copy of the first reaction's conditions. Figures without
    reactions, or for which ``backout`` fails or yields nothing, are left
    untouched.

    Args:
        initial_results: list of dicts with a ``'reactions'`` list. Modified in place.
        results_coref: coreference results, aligned with ``initial_results``.
        molscribe: object forwarded to ``backout``.

    Returns:
        The same ``initial_results`` object.
    """
    for reactions, result_coref in zip(initial_results, results_coref):
        if not reactions['reactions']:
            continue
        try:
            backout_results = backout([reactions], [result_coref], molscribe)
        except Exception:
            # Best effort: a figure that cannot be backed out keeps its original reactions.
            continue
        if not backout_results:
            continue

        # Looked up only once there is something to append.
        conditions = reactions['reactions'][0]['conditions']
        for reactants, products, _label in backout_results:
            reactions['reactions'].append({
                'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants],
                'conditions': conditions[:],
                'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products]
            })
    return initial_results
|
1018 |
+
|
examples/exp.png
ADDED
![]() |
Git LFS Details
|
examples/image.webp
ADDED
![]() |
examples/rdkit.png
ADDED
![]() |
examples/reaction1.jpg
ADDED
![]() |
examples/reaction2.png
ADDED
![]() |
examples/reaction3.png
ADDED
![]() |
examples/reaction4.png
ADDED
![]() |
Git LFS Details
|
get_molecular_agent.py
ADDED
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from chemietoolkit import ChemIEToolkit
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
#sys.path.append('./RxnScribe-main/')
|
10 |
+
import torch
|
11 |
+
from rxnscribe import RxnScribe
|
12 |
+
import json
|
13 |
+
import sys
|
14 |
+
import torch
|
15 |
+
import json
|
16 |
+
model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
17 |
+
from molscribe.chemistry import _convert_graph_to_smiles
|
18 |
+
import base64
|
19 |
+
import torch
|
20 |
+
import json
|
21 |
+
from PIL import Image
|
22 |
+
import numpy as np
|
23 |
+
from chemietoolkit import ChemIEToolkit, utils
|
24 |
+
from openai import AzureOpenAI
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reaction-diagram parser, loaded from a local checkpoint file.
ckpt_path = "./pix2seq_reaction_full.ckpt"
model1 = RxnScribe(ckpt_path, device=device)

# Toolkit instance used by the molecule-extraction helpers below.
model = ChemIEToolkit(device=device)
|
35 |
+
|
36 |
+
def get_multi_molecular(image_path: str) -> str:
    """Extract molecule coreference results from an image as a JSON string.

    The image is converted to RGB and passed to the module-level ``model``
    (``extract_molecule_corefs_from_figures``). The keys ``category``,
    ``molfile``, ``symbols``, ``atoms``, ``bonds``, ``category_id``,
    ``score`` and ``corefs`` are then removed from every entry under
    ``bboxes``.

    Args:
        image_path: Path to the image file.

    Returns:
        The trimmed prediction serialised with ``json.dumps``. The same string
        is also printed to stdout.
    """
    # Close the file handle as soon as the pixel data has been copied.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    dropped_keys = ("category", "molfile", "symbols", 'atoms', "bonds",
                    'category_id', 'score', 'corefs')
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate keys that are already absent

    # Serialise once and reuse the string for both the log line and the return.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
|
51 |
+
|
52 |
+
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Extract molecule coreference results, without boxes, as a JSON string.

    The image is converted to RGB and passed to the module-level ``model``
    (``extract_molecule_corefs_from_figures``). The keys ``category``,
    ``bbox``, ``molfile``, ``symbols``, ``atoms``, ``bonds``,
    ``category_id``, ``score`` and ``corefs`` are then removed from every
    entry under ``bboxes``.

    Args:
        image_path: Path to the image file.

    Returns:
        The trimmed prediction serialised with ``json.dumps``. The same string
        is also printed to stdout.
    """
    # Close the file handle as soon as the pixel data has been copied.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    dropped_keys = ("category", "bbox", "molfile", "symbols", 'atoms', "bonds",
                    'category_id', 'score', 'corefs')
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate keys that are already absent

    # Serialise once and reuse the string for both the log line and the return.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
|
67 |
+
|
68 |
+
def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> str:
    """Extract molecule coreference results, keeping symbols, as a JSON string.

    The image is converted to RGB and passed to the module-level ``model``
    (``extract_molecule_corefs_from_figures``). The keys ``coords``,
    ``edges``, ``molfile``, ``atoms``, ``bonds``, ``category_id``, ``score``
    and ``corefs`` are then removed from every entry under ``bboxes``; other
    keys such as ``symbols`` are left in place.

    Args:
        image_path: Path to the image file.

    Returns:
        The trimmed prediction serialised with ``json.dumps``. The same string
        is also printed to stdout.
    """
    # Close the file handle as soon as the pixel data has been copied.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    dropped_keys = ("coords", "edges", "molfile", 'atoms', "bonds",
                    'category_id', 'score', 'corefs')
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate keys that are already absent

    # Serialise once and reuse the string for both the log line and the return.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
def process_reaction_image_with_multiple_products_and_text(image_path: str) -> list:
    """Extract molecules from an image and apply GPT-corrected atom symbols.

    Steps:
      1. Send the prompt from ``./prompt_getmolecular.txt`` plus the image to
         the ``gpt-4o`` chat model through Azure OpenAI, offering it the
         ``get_multi_molecular_text_to_correct_withatoms`` tool.
      2. Answer every tool call the model makes and ask for the final JSON
         answer. If the model made no tool call, its first answer is used.
      3. Run the module-level ``model`` on the image again to get an
         unstripped prediction, copy the ``symbols`` from the GPT answer onto
         it (position by position) and regenerate ``smiles`` / ``molfile``
         with ``_convert_graph_to_smiles``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: The prediction returned by
        ``model.extract_molecule_corefs_from_figures`` with updated
        ``symbols``, ``atom_symbol``, ``smiles`` and ``molfile`` values.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the chat model requests a tool that is not registered.
    """
    # The API key is read from the environment so it never lives in the source.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint,
    )

    # Embed the image in the request as a Base64 data URL.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool description offered to the chat model.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getmolecular.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Both requests start with the same system and user messages.
    # NOTE(review): the data URL always declares image/png, whatever the
    # actual file type is - confirm this is acceptable for non-PNG inputs.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    # Step 1: map tool names to the local functions that implement them.
    TOOL_MAP = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    # Step 2: answer every tool call. ``tool_calls`` is None when the model
    # replied directly, so fall back to an empty list instead of crashing.
    first_message = response.choices[0].message
    tool_calls = first_message.tool_calls or []
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # The tool is always run on the caller's image path, not on the
        # arguments proposed by the chat model.
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Only ask for a follow-up answer when there are tool results to send.
    if tool_calls:
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, first_message, *results],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    # Wrap the GPT answer in a list so it lines up with the prediction list.
    gpt_output = [json.loads(response.choices[0].message.content)]

    # Fresh prediction with all keys present: the tool function above strips
    # keys from its own result, so that result cannot be reused here.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    def update_symbols_in_atoms(input1, input2):
        """Copy ``symbols`` from input1 onto input2 and sync ``atom_symbol``.

        Items and bboxes are matched by position; an item whose bbox count
        differs between the two inputs is skipped with a warning.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])

            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue

            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                if 'symbols' not in bbox1:
                    continue
                symbols = bbox1['symbols']
                bbox2['symbols'] = symbols

                if 'atoms' not in bbox2:
                    continue
                atoms = bbox2['atoms']

                # Atoms are only rewritten when they pair up one-to-one.
                if len(symbols) != len(atoms):
                    print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                    continue

                for atom, symbol in zip(atoms, symbols):
                    atom['atom_symbol'] = symbol

        return input2

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)

    def update_smiles_and_molfile(input_data, conversion_function):
        """Regenerate ``smiles`` and ``molfile`` for every complete bbox.

        ``conversion_function`` is called with ``coords``, ``symbols`` and
        ``edges`` and must return ``(smiles, molfile, _)``. Bboxes missing
        any of the three keys are left untouched.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                if all(key in bbox for key in ['coords', 'symbols', 'edges']):
                    new_smiles, new_molfile, _ = conversion_function(
                        bbox['coords'], bbox['symbols'], bbox['edges'])
                    print(f"  Generated 'smiles': {new_smiles}")

                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile

        return input_data

    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)

    return updated_data
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> list:
    """Extract molecules from an image and apply GPT-corrected atom symbols.

    Same pipeline as the non-``correctR`` variant, but the prompt is read
    from ``./prompt_getmolecular_correctR.txt`` and the final data is printed.

    Steps:
      1. Send the prompt plus the image to the ``gpt-4o`` chat model through
         Azure OpenAI, offering it the
         ``get_multi_molecular_text_to_correct_withatoms`` tool.
      2. Answer every tool call the model makes and ask for the final JSON
         answer. If the model made no tool call, its first answer is used.
      3. Run the module-level ``model`` on the image again to get an
         unstripped prediction, copy the ``symbols`` from the GPT answer onto
         it (position by position) and regenerate ``smiles`` / ``molfile``
         with ``_convert_graph_to_smiles``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: The prediction returned by
        ``model.extract_molecule_corefs_from_figures`` with updated
        ``symbols``, ``atom_symbol``, ``smiles`` and ``molfile`` values.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the chat model requests a tool that is not registered.
    """
    # The API key is read from the environment so it never lives in the source.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint,
    )

    # Embed the image in the request as a Base64 data URL.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool description offered to the chat model.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getmolecular_correctR.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Both requests start with the same system and user messages.
    # NOTE(review): the data URL always declares image/png, whatever the
    # actual file type is - confirm this is acceptable for non-PNG inputs.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    # Step 1: map tool names to the local functions that implement them.
    TOOL_MAP = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    # Step 2: answer every tool call. ``tool_calls`` is None when the model
    # replied directly, so fall back to an empty list instead of crashing.
    first_message = response.choices[0].message
    tool_calls = first_message.tool_calls or []
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # The tool is always run on the caller's image path, not on the
        # arguments proposed by the chat model.
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Only ask for a follow-up answer when there are tool results to send.
    if tool_calls:
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, first_message, *results],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    # Wrap the GPT answer in a list so it lines up with the prediction list.
    gpt_output = [json.loads(response.choices[0].message.content)]

    # Fresh prediction with all keys present: the tool function above strips
    # keys from its own result, so that result cannot be reused here.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    def update_symbols_in_atoms(input1, input2):
        """Copy ``symbols`` from input1 onto input2 and sync ``atom_symbol``.

        Items and bboxes are matched by position; an item whose bbox count
        differs between the two inputs is skipped with a warning.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])

            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue

            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                if 'symbols' not in bbox1:
                    continue
                symbols = bbox1['symbols']
                bbox2['symbols'] = symbols

                if 'atoms' not in bbox2:
                    continue
                atoms = bbox2['atoms']

                # Atoms are only rewritten when they pair up one-to-one.
                if len(symbols) != len(atoms):
                    print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                    continue

                for atom, symbol in zip(atoms, symbols):
                    atom['atom_symbol'] = symbol

        return input2

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)

    def update_smiles_and_molfile(input_data, conversion_function):
        """Regenerate ``smiles`` and ``molfile`` for every complete bbox.

        ``conversion_function`` is called with ``coords``, ``symbols`` and
        ``edges`` and must return ``(smiles, molfile, _)``. Bboxes missing
        any of the three keys are left untouched.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                if all(key in bbox for key in ['coords', 'symbols', 'edges']):
                    new_smiles, new_molfile, _ = conversion_function(
                        bbox['coords'], bbox['symbols'], bbox['edges'])
                    print(f"  Generated 'smiles': {new_smiles}")

                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile

        return input_data

    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)
    print(f"updated_mol_data:{updated_data}")

    return updated_data
|
get_reaction_agent.py
ADDED
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from chemietoolkit import ChemIEToolkit
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
#sys.path.append('./RxnScribe-main/')
|
10 |
+
import torch
|
11 |
+
from rxnscribe import RxnScribe
|
12 |
+
import json
|
13 |
+
from molscribe.chemistry import _convert_graph_to_smiles
|
14 |
+
|
15 |
+
from openai import AzureOpenAI
|
16 |
+
import base64
|
17 |
+
import numpy as np
|
18 |
+
from chemietoolkit import utils
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reaction-diagram parser, loaded from a local checkpoint file.
ckpt_path = "./pix2seq_reaction_full.ckpt"
model1 = RxnScribe(ckpt_path, device=device)

# Toolkit instance kept at module level.
model = ChemIEToolkit(device=device)
|
28 |
+
|
29 |
+
def get_reaction(image_path: str) -> dict:
    """Return a trimmed view of the first reaction predicted for an image.

    The image is parsed with ``model1.predict_image_file`` (MolScribe and OCR
    enabled). Only the first entry of the prediction is used. For each of
    ``reactants``, ``conditions`` and ``products`` present in that entry:

    * reactants / products keep ``smiles``, ``bbox`` and ``symbols``
      (defaulting to ``""`` / ``[]`` / ``[]`` when missing);
    * conditions keep ``bbox`` plus ``smiles`` and ``text`` when present.

    Args:
        image_path: Path to the reaction image.

    Returns:
        A dict with the sections described above. It is empty when the
        prediction contains no reaction. The dict is also printed to stdout.
    """
    raw_prediction = model1.predict_image_file(image_path, molscribe=True, ocr=True)

    structured_output = {}
    # An empty prediction would otherwise raise IndexError on [0].
    first_reaction = raw_prediction[0] if raw_prediction else {}

    for section_key in ['reactants', 'conditions', 'products']:
        if section_key not in first_reaction:
            continue

        entries = []
        for item in first_reaction[section_key]:
            if section_key == 'conditions':
                # Conditions may be molecules (smiles), text, or both.
                condition_data = {"bbox": item.get("bbox", [])}
                if "smiles" in item:
                    condition_data["smiles"] = item.get("smiles", "")
                if "text" in item:
                    condition_data["text"] = item.get("text", [])
                entries.append(condition_data)
            else:
                entries.append({
                    "smiles": item.get("smiles", ""),
                    "bbox": item.get("bbox", []),
                    "symbols": item.get("symbols", []),
                })
        structured_output[section_key] = entries

    print(structured_output)

    return structured_output
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def get_full_reaction(image_path: str) -> dict:
    """Return the unmodified reaction prediction for an image.

    The result is whatever ``model1.predict_image_file`` produces with
    MolScribe and OCR enabled; nothing is filtered or restructured.
    """
    return model1.predict_image_file(image_path, molscribe=True, ocr=True)
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def get_reaction_withatoms(image_path: str) -> list:
    """Extract a reaction from an image and apply GPT-corrected atom symbols.

    Steps:
      1. Send the prompt from ``./prompt_getreaction.txt`` plus the image to
         the ``gpt-4o`` chat model through Azure OpenAI, offering it the
         ``get_reaction`` tool.
      2. Answer every tool call the model makes and ask for the final JSON
         answer. If the model made no tool call, its first answer is used.
      3. Run ``model1.predict_image_file`` on the image to get the full
         prediction, then, for reactants and products whose ``bbox`` matches
         one in the GPT answer, replace ``symbols`` / ``atom_symbol`` and
         regenerate ``smiles`` / ``molfile`` with ``_convert_graph_to_smiles``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: A one-element list holding the updated first reaction of the
        full prediction, or an empty list when nothing was predicted.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the chat model requests a tool that is not registered.
    """
    import os  # local import: this module has no file-level ``import os``

    # Read the credential from the environment instead of keeping a secret
    # in the source.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint,
    )

    # Embed the image in the request as a Base64 data URL.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool description offered to the chat model.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getreaction.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Both requests start with the same system and user messages.
    # NOTE(review): the data URL always declares image/png, whatever the
    # actual file type is - confirm this is acceptable for non-PNG inputs.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    # Step 1: map tool names to the local functions that implement them.
    TOOL_MAP = {
        'get_reaction': get_reaction,
    }

    # Step 2: answer every tool call. ``tool_calls`` is None when the model
    # replied directly, so fall back to an empty list instead of crashing.
    first_message = response.choices[0].message
    tool_calls = first_message.tool_calls or []
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # The tool is always run on the caller's image path, not on the
        # arguments proposed by the chat model.
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Only ask for a follow-up answer when there are tool results to send.
    if tool_calls:
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, first_message, *results],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    gpt_output = json.loads(response.choices[0].message.content)
    print(f"gpt_output1:{gpt_output}")

    # Full prediction, including the keys that ``get_reaction`` leaves out.
    input2 = model1.predict_image_file(image_path, molscribe=True, ocr=True)
    if not input2:
        # Nothing was predicted, so there is nothing to correct.
        return []

    def update_input_with_symbols(input1, input2, conversion_function):
        """Apply symbols from input1 to the bbox-matched molecules of input2.

        ``conversion_function`` is called with ``coords``, the new symbols
        and ``edges`` and must return ``(smiles, molfile, _)``.
        """
        # Index the corrected symbols by bounding box.
        symbol_mapping = {}
        for key in ['reactants', 'products']:
            for item in input1.get(key, []):
                symbol_mapping[tuple(item['bbox'])] = item['symbols']

        for key in ['reactants', 'products']:
            for item in input2.get(key, []):
                bbox = tuple(item['bbox'])
                if bbox not in symbol_mapping:
                    continue

                updated_symbols = symbol_mapping[bbox]
                item['symbols'] = updated_symbols

                # Atoms are only rewritten when they pair up one-to-one.
                if 'atoms' in item:
                    atoms = item['atoms']
                    if len(atoms) != len(updated_symbols):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox}")
                    else:
                        for atom, symbol in zip(atoms, updated_symbols):
                            atom['atom_symbol'] = symbol

                # Rebuild smiles / molfile when the graph data is available.
                if 'coords' in item and 'edges' in item:
                    new_smiles, new_molfile, _ = conversion_function(
                        item['coords'], updated_symbols, item['edges'])
                    item['smiles'] = new_smiles
                    item['molfile'] = new_molfile

        return input2

    updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)]

    return updated_data
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
|
293 |
+
def get_reaction_withatoms_correctR(image_path: str) -> list:
    """Extract the reaction in an image and patch it with GPT-corrected atom symbols.

    The image is sent to GPT-4o together with the prompt stored in
    ``./prompt_getreaction_correctR.txt``. Tool calls requested by the model
    are served by ``get_reaction`` and their output is fed back to obtain a
    JSON answer. That answer is merged, by exact bounding-box match, into
    the raw ``model1`` prediction for the same image, and SMILES / molfile
    strings are regenerated for every matched molecule that carries
    ``coords`` and ``edges``.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        list: A one-element list holding the first entry of the ``model1``
        prediction, with ``symbols``, ``atoms``, ``smiles`` and ``molfile``
        of matched reactants/products updated in place.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the model requests a tool that is not registered.
    """
    # Local import so this block does not depend on the file's import header.
    import os

    # The credential is read from the environment instead of being committed
    # to source control (same variable main.py uses).
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Load the image and encode it as Base64 for the data URL below.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool schema advertised to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getreaction_correctR.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: let GPT decide which tools to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    tool_map = {
        'get_reaction': get_reaction,
    }

    assistant_message = response.choices[0].message
    tool_calls = assistant_message.tool_calls

    # tool_calls is None when the model answers directly; in that case the
    # first response already carries the JSON answer and no second round
    # trip is needed.
    if tool_calls:
        tool_messages = []
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            if tool_name not in tool_map:
                raise ValueError(f"Unknown tool called: {tool_name}")

            # The model-supplied arguments are ignored on purpose: the tool
            # always runs on the path this function was given.
            tool_result = tool_map[tool_name](image_path)

            tool_messages.append({
                'role': 'tool',
                'content': json.dumps({
                    'image_path': image_path,
                    f'{tool_name}': (tool_result),
                }),
                'tool_call_id': tool_call.id,
            })

        # Second round: hand the tool outputs back and ask for the final JSON.
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, assistant_message, *tool_messages],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    gpt_output = json.loads(response.choices[0].message.content)
    print(f"gpt_output1:{gpt_output}")

    # Raw prediction (with atoms, coords and edges) for the same image.
    input2 = model1.predict_image_file(image_path, molscribe=True, ocr=True)

    def update_input_with_symbols(input1, input2, conversion_function):
        """Copy the symbols of ``input1`` into ``input2`` for matching bboxes."""
        # The bbox, as a tuple, is the key used to pair molecules.
        symbol_mapping = {}
        for key in ['reactants', 'products']:
            for item in input1.get(key, []):
                symbol_mapping[tuple(item['bbox'])] = item['symbols']

        for key in ['reactants', 'products']:
            for item in input2.get(key, []):
                bbox = tuple(item['bbox'])
                if bbox not in symbol_mapping:
                    continue

                updated_symbols = symbol_mapping[bbox]
                item['symbols'] = updated_symbols

                # Per-atom symbols are only rewritten when the counts agree.
                if 'atoms' in item:
                    atoms = item['atoms']
                    if len(atoms) != len(updated_symbols):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox}")
                    else:
                        for atom, symbol in zip(atoms, updated_symbols):
                            atom['atom_symbol'] = symbol

                # Rebuild SMILES and molfile from the graph when it is available.
                if 'coords' in item and 'edges' in item:
                    new_smiles, new_molfile, _ = conversion_function(
                        item['coords'], updated_symbols, item['edges']
                    )
                    item['smiles'] = new_smiles
                    item['molfile'] = new_molfile

        return input2

    updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)]
    print(f"updated_reaction_data:{updated_data}")

    return updated_data
|
main.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from chemietoolkit import ChemIEToolkit,utils
|
5 |
+
import cv2
|
6 |
+
from openai import AzureOpenAI
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import json
|
10 |
+
from get_molecular_agent import process_reaction_image_with_multiple_products_and_text_correctR
|
11 |
+
from get_reaction_agent import get_reaction_withatoms_correctR
|
12 |
+
import sys
|
13 |
+
from rxnscribe import RxnScribe
|
14 |
+
import json
|
15 |
+
import base64
|
16 |
+
import os

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Shared toolkit instance used by the module-level helpers below.
model = ChemIEToolkit(device=device)

# RxnScribe reaction-extraction model loaded from a local checkpoint.
ckpt_path = "./pix2seq_reaction_full.ckpt"
model1 = RxnScribe(ckpt_path, device=device)

# The file content is stored verbatim (no whitespace stripping).
with open('api_key.txt', 'r') as api_key_file:
    API_KEY = api_key_file.read()
|
24 |
+
|
25 |
+
def parse_coref_data_with_fallback(data):
    """Pair every molecule SMILES with its label text, with a fallback entry.

    Args:
        data: Mapping with a ``"bboxes"`` list of dicts and a ``"corefs"``
            list of index pairs into that list.

    Returns:
        A list of ``{"smiles": ..., "texts": ...}`` dicts: one per coref
        pair, followed by one per SMILES-bearing bbox that appears in no
        pair (those get a fixed "no label" notice as their text).
    """
    boxes = data["bboxes"]
    pairs = data["corefs"]
    linked = set()
    parsed = []

    # Coreferenced pairs first; either member may be the molecule or the label.
    for first, second in pairs:
        box_a, box_b = boxes[first], boxes[second]
        molecule_box = box_a if "smiles" in box_a else box_b
        label_box = box_b if "text" in box_b else box_a
        parsed.append({
            "smiles": molecule_box.get("smiles", ""),
            "texts": label_box.get("text", []),
        })
        # Remember both indices so they are not reported again below.
        linked.update((first, second))

    # Append the molecules that were never paired with a label.
    parsed.extend(
        {
            "smiles": box["smiles"],
            "texts": ["There is no label or failed to detect, please recheck the image again"],
        }
        for position, box in enumerate(boxes)
        if "smiles" in box and position not in linked
    )

    return parsed
|
57 |
+
|
58 |
+
|
59 |
+
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Return the molecules and label texts found in an image as a JSON string.

    The image is run through ``model.extract_molecule_corefs_from_figures``;
    only the first result is kept and reduced to ``smiles``/``texts`` pairs
    by ``parse_coref_data_with_fallback``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        str: JSON-encoded list of ``{"smiles": ..., "texts": ...}`` objects.
        The value is a string, not a list.
    """
    image = Image.open(image_path).convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    # Strip the bulky per-molecule fields; pop(..., None) tolerates missing keys.
    dropped_keys = ["category", "bbox", "molfile", "symbols", 'atoms', "bonds",
                    'category_id', 'score', 'corefs', "coords", "edges"]
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)

    parsed = parse_coref_data_with_fallback(coref_results[0])

    # Serialise once and reuse the string for both the log line and the result.
    payload = json.dumps(parsed)
    print(f"coref_results:{payload}")
    return payload
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
def get_reaction(image_path: str) -> dict:
    """Return a trimmed view of the first reaction predicted for an image.

    The prediction comes from ``get_reaction_withatoms_correctR``. Only the
    sections present in its first entry are included: reactants and products
    keep ``smiles`` and ``bbox``; conditions keep ``text``, ``bbox`` and
    ``smiles``.

    Args:
        image_path (str): Path to the reaction image.

    Returns:
        dict: Mapping of section name to a list of trimmed entries.
    """
    raw_prediction = get_reaction_withatoms_correctR(image_path)
    first_reaction = raw_prediction[0]

    structured_output = {}
    for section_key in ('reactants', 'conditions', 'products'):
        if section_key not in first_reaction:
            continue

        entries = first_reaction[section_key]
        if section_key == 'conditions':
            # Conditions carry text and may also carry a SMILES.
            structured_output[section_key] = [
                {
                    "text": entry.get("text", []),
                    "bbox": entry.get("bbox", []),
                    "smiles": entry.get("smiles", []),
                }
                for entry in entries
            ]
        else:
            # Molecules: keep only the SMILES and the bounding box.
            structured_output[section_key] = [
                {
                    "smiles": entry.get("smiles", ""),
                    "bbox": entry.get("bbox", []),
                }
                for entry in entries
            ]

    return structured_output
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
def process_reaction_image(image_path: str) -> dict:
    """Extract the reaction template and the per-label reactions of an image.

    GPT-4o receives the image with the prompt from ``./prompt.txt`` and may
    call ``get_multi_molecular_text_to_correct`` and ``get_reaction``. Its
    final JSON answer is combined with the coreference results of
    ``ChemIEToolkit`` and the reaction predicted by
    ``get_reaction_withatoms_correctR``, then passed to
    ``utils.backout_without_coref`` to produce one reaction per label.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: ``reaction_template`` (reactant and product SMILES lists),
        ``reactions`` (label -> reactants/products, ordered by label) and
        ``original_molecule_list`` (the parsed GPT answer).

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the model requests a tool that is not registered.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"

    model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Load the image and encode it as Base64 for the data URL below.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool schemas advertised to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct',
                'description': 'Extracts the SMILES string and text coref from molecular images.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'Path to the reaction image.'
                        }
                    },
                    'required': ['image_path'],
                    'additionalProperties': False
                }
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: let GPT decide which tools to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    tool_map = {
        'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,
        'get_reaction': get_reaction
    }

    assistant_message = response.choices[0].message
    tool_calls = assistant_message.tool_calls

    # tool_calls is None when the model answers directly; the first response
    # is then used as the final answer instead of failing on iteration.
    if tool_calls:
        tool_messages = []
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            if tool_name not in tool_map:
                raise ValueError(f"Unknown tool called: {tool_name}")

            # The model-supplied arguments are ignored on purpose: every tool
            # runs on the path this function was given.
            tool_result = tool_map[tool_name](image_path)

            tool_messages.append({
                'role': 'tool',
                'content': json.dumps({
                    'image_path': image_path,
                    f'{tool_name}': (tool_result),
                }),
                'tool_call_id': tool_call.id,
            })

        # Second round: hand the tool outputs back and ask for the final JSON.
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, assistant_message, *tool_messages],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)

    image_np = np.array(Image.open(image_path).convert('RGB'))
    coref_results = model.extract_molecule_corefs_from_figures([image_np])

    # Wrap the single predicted reaction in the nested list/dict layout that
    # is handed to utils.backout_without_coref below.
    predicted = get_reaction_withatoms_correctR(image_path)[0]
    reaction = {
        "reactants": predicted.get('reactants', []),
        "conditions": predicted.get('conditions', []),
        "products": predicted.get('products', [])
    }
    reaction_results = [{"reactions": [reaction]}]
    print(reaction_results)

    def extract_smiles_details(smiles_data, raw_details):
        """Look up the detection record of each SMILES in ``smiles_data``."""
        smiles_details = {}
        for smiles in smiles_data:
            for detail in raw_details:
                for bbox in detail.get('bboxes', []):
                    if bbox.get('smiles') == smiles:
                        smiles_details[smiles] = {
                            'category': bbox.get('category'),
                            'bbox': bbox.get('bbox'),
                            'category_id': bbox.get('category_id'),
                            'score': bbox.get('score'),
                            'molfile': bbox.get('molfile'),
                            'atoms': bbox.get('atoms'),
                            'bonds': bbox.get('bonds'),
                        }
                        # Only the first match inside this detail is used.
                        break
        return smiles_details

    # Iterating gpt_output walks its top-level keys when it is a dict.
    smiles_details = extract_smiles_details(gpt_output, coref_results)

    reactants_array = []
    for reactant in reaction['reactants']:
        if 'smiles' in reactant:
            print(f"SMILES:{reactant['smiles']}")
            reactants_array.append(reactant['smiles'])

    # Products get the same guard as reactants, so an entry without a SMILES
    # is skipped instead of raising KeyError.
    products = [product['smiles'] for product in reaction['products'] if 'smiles' in product]

    # Assemble the per-label reactions.
    backed_out = utils.backout_without_coref(reaction_results, coref_results, gpt_output, smiles_details, model.molscribe)
    backed_out.sort(key=lambda entry: entry[2])

    # Built from the label-sorted list, so the dict is already ordered by
    # label and no second sort is needed.
    extracted_rxns = {
        label: {'reactants': reactants, 'products': products_}
        for reactants, products_, label in backed_out
    }

    toadd = {
        "reaction_template": {
            "reactants": reactants_array,
            "products": products
        },
        "reactions": extracted_rxns,
        "original_molecule_list": gpt_output
    }

    print(toadd)
    return toadd
|
379 |
+
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
def ChemEagle(image_path: str) -> dict:
    """Extract reaction information from a chemical reaction image via GPT-4o.

    The image is sent with the prompt from
    ``./prompt_final_simple_version.txt``. Tool calls requested by the model
    are served by ``process_reaction_image`` and their output is fed back to
    obtain the final JSON answer.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: The JSON object parsed from the model's final answer.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If the model requests a tool that is not registered.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"

    # No ChemIEToolkit instance is built here: this function never used one,
    # and process_reaction_image creates its own.
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Load the image and encode it as Base64 for the data URL below.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool schema advertised to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'process_reaction_image',
                'description': 'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_final_simple_version.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    base_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: let GPT decide which tools to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=base_messages,
        tools=tools,
    )

    tool_map = {
        'process_reaction_image': process_reaction_image
    }

    assistant_message = response.choices[0].message
    tool_calls = assistant_message.tool_calls

    # tool_calls is None when the model answers directly; the first response
    # is then used as the final answer instead of failing on iteration.
    if tool_calls:
        tool_messages = []
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            if tool_name not in tool_map:
                raise ValueError(f"Unknown tool called: {tool_name}")

            # The model-supplied arguments are ignored on purpose: the tool
            # always runs on the path this function was given.
            tool_result = tool_map[tool_name](image_path)

            tool_messages.append({
                'role': 'tool',
                'content': json.dumps({
                    'image_path': image_path,
                    f'{tool_name}': (tool_result),
                }),
                'tool_call_id': tool_call.id,
            })

        # Second round: hand the tool outputs back and ask for the final JSON.
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*base_messages, assistant_message, *tool_messages],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)
    return gpt_output
|
main_Rgroup_debug.ipynb
ADDED
@@ -0,0 +1,993 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import sys\n",
|
10 |
+
"import torch\n",
|
11 |
+
"import json\n",
|
12 |
+
"from chemietoolkit import ChemIEToolkit\n",
|
13 |
+
"import cv2\n",
|
14 |
+
"from PIL import Image\n",
|
15 |
+
"import json\n",
|
16 |
+
"model = ChemIEToolkit(device=torch.device('cpu')) \n",
|
17 |
+
"from get_molecular_agent import process_reaction_image_with_multiple_products_and_text\n",
|
18 |
+
"from get_reaction_agent import get_reaction_withatoms\n",
|
19 |
+
"from get_reaction_agent import get_full_reaction\n",
|
20 |
+
"\n",
|
21 |
+
"\n",
|
22 |
+
"# 定义函数,接受多个图像路径并返回反应列表\n",
|
23 |
+
"def get_multi_molecular(image_path: str) -> list:\n",
|
24 |
+
" '''Returns a list of reactions extracted from the image.'''\n",
|
25 |
+
" # 打开图像文件\n",
|
26 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
27 |
+
" \n",
|
28 |
+
" # 将图像作为输入传递给模型\n",
|
29 |
+
" coref_results = model.extract_molecule_corefs_from_figures([image])\n",
|
30 |
+
" \n",
|
31 |
+
" for item in coref_results:\n",
|
32 |
+
" for bbox in item.get(\"bboxes\", []):\n",
|
33 |
+
" for key in [\"category\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
|
34 |
+
" bbox.pop(key, None) # 安全地移除键\n",
|
35 |
+
" print(json.dumps(coref_results))\n",
|
36 |
+
" # 返回反应列表,使用 json.dumps 进行格式化\n",
|
37 |
+
" \n",
|
38 |
+
" return json.dumps(coref_results)\n",
|
39 |
+
"\n",
|
40 |
+
"def get_multi_molecular_text_to_correct(image_path: str) -> list:\n",
|
41 |
+
" '''Returns a list of reactions extracted from the image.'''\n",
|
42 |
+
" # 打开图像文件\n",
|
43 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
44 |
+
" \n",
|
45 |
+
" # 将图像作为输入传递给模型\n",
|
46 |
+
" coref_results = model.extract_molecule_corefs_from_figures([image])\n",
|
47 |
+
" #coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
|
48 |
+
" for item in coref_results:\n",
|
49 |
+
" for bbox in item.get(\"bboxes\", []):\n",
|
50 |
+
" for key in [\"category\", \"bbox\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
|
51 |
+
" bbox.pop(key, None) # 安全地移除键\n",
|
52 |
+
" print(json.dumps(coref_results))\n",
|
53 |
+
" # Return the filtered coreference results serialized with json.dumps\n",
|
54 |
+
" \n",
|
55 |
+
" return json.dumps(coref_results)\n",
|
56 |
+
"\n",
|
57 |
+
"def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> list:\n",
|
58 |
+
" '''Returns a JSON string of the filtered output of process_reaction_image_with_multiple_products_and_text for the image.'''\n",
|
59 |
+
" # 打开图像文件\n",
|
60 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
61 |
+
" \n",
|
62 |
+
" # 将图像作为输入传递给模型\n",
|
63 |
+
" #coref_results = model.extract_molecule_corefs_from_figures([image])\n",
|
64 |
+
" coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
|
65 |
+
" for item in coref_results:\n",
|
66 |
+
" for bbox in item.get(\"bboxes\", []):\n",
|
67 |
+
" for key in [\"molfile\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
|
68 |
+
" bbox.pop(key, None) # 安全地移除键\n",
|
69 |
+
" print(json.dumps(coref_results))\n",
|
70 |
+
" # Return the filtered results serialized with json.dumps\n",
|
71 |
+
" return json.dumps(coref_results)\n",
|
72 |
+
"\n",
|
73 |
+
"#get_multi_molecular_text_to_correct('./acs.joc.2c00176 example 1.png')\n",
|
74 |
+
"\n",
|
75 |
+
"import sys\n",
|
76 |
+
"#sys.path.append('./RxnScribe-main/')\n",
|
77 |
+
"import torch\n",
|
78 |
+
"from rxnscribe import RxnScribe\n",
|
79 |
+
"import json\n",
|
80 |
+
"\n",
|
81 |
+
"ckpt_path = \"./pix2seq_reaction_full.ckpt\"\n",
|
82 |
+
"model1 = RxnScribe(ckpt_path, device=torch.device('cpu'))\n",
|
83 |
+
"device = torch.device('cpu')\n",
|
84 |
+
"\n",
|
85 |
+
"def get_reaction(image_path: str) -> dict:\n",
|
86 |
+
" '''\n",
|
87 |
+
" Returns a structured dictionary built from the first entry predicted for the image,\n",
|
88 |
+
" including reactants, conditions, and products, with their smiles, text, and bbox.\n",
|
89 |
+
" '''\n",
|
90 |
+
" image_file = image_path\n",
|
91 |
+
" #raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True)\n",
|
92 |
+
" raw_prediction = get_reaction_withatoms(image_path)\n",
|
93 |
+
"\n",
|
94 |
+
" # Only the first entry of raw_prediction is read below\n",
|
95 |
+
" structured_output = {}\n",
|
96 |
+
" for section_key in ['reactants', 'conditions', 'products']:\n",
|
97 |
+
" if section_key in raw_prediction[0]:\n",
|
98 |
+
" structured_output[section_key] = []\n",
|
99 |
+
" for item in raw_prediction[0][section_key]:\n",
|
100 |
+
" if section_key in ['reactants', 'products']:\n",
|
101 |
+
" # Extract smiles and bbox for molecules\n",
|
102 |
+
" structured_output[section_key].append({\n",
|
103 |
+
" \"smiles\": item.get(\"smiles\", \"\"),\n",
|
104 |
+
" \"bbox\": item.get(\"bbox\", [])\n",
|
105 |
+
" })\n",
|
106 |
+
" elif section_key == 'conditions':\n",
|
107 |
+
" # Extract smiles, text, and bbox for conditions\n",
|
108 |
+
" condition_data = {\"bbox\": item.get(\"bbox\", [])}\n",
|
109 |
+
" if \"smiles\" in item:\n",
|
110 |
+
" condition_data[\"smiles\"] = item.get(\"smiles\", \"\")\n",
|
111 |
+
" if \"text\" in item:\n",
|
112 |
+
" condition_data[\"text\"] = item.get(\"text\", [])\n",
|
113 |
+
" structured_output[section_key].append(condition_data)\n",
|
114 |
+
" print(f\"structured_output:{structured_output}\")\n",
|
115 |
+
"\n",
|
116 |
+
" return structured_output\n",
|
117 |
+
"\n",
|
118 |
+
"\n",
|
119 |
+
"\n",
|
120 |
+
"\n",
|
121 |
+
"import base64\n",
|
122 |
+
"import torch\n",
|
123 |
+
"import json\n",
|
124 |
+
"from PIL import Image\n",
|
125 |
+
"import numpy as np\n",
|
126 |
+
"from chemietoolkit import ChemIEToolkit, utils\n",
|
127 |
+
"from openai import AzureOpenAI\n",
|
128 |
+
"\n",
|
129 |
+
"def process_reaction_image_with_multiple_products(image_path: str) -> dict:\n",
|
130 |
+
" \"\"\"\n",
|
131 |
+
" Args:\n",
|
132 |
+
" image_path (str): Path to the image file.\n",
|
133 |
+
"\n",
|
134 |
+
" Returns:\n",
|
135 |
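+
" dict: The reaction template, the extracted reactions and the original molecular list; only the original molecular list if post-processing fails.\n",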
|
136 |
+
" \"\"\"\n",
|
137 |
+
" # 配置 API Key 和 Azure Endpoint\n",
|
138 |
+
" api_key = \"b038da96509b4009be931e035435e022\" # SECURITY: hardcoded credential; load it from the environment or a config file instead\n",
|
139 |
+
" azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n",
|
140 |
+
" \n",
|
141 |
+
"\n",
|
142 |
+
" model = ChemIEToolkit(device=torch.device('cpu'))\n",
|
143 |
+
" client = AzureOpenAI(\n",
|
144 |
+
" api_key=api_key,\n",
|
145 |
+
" api_version='2024-06-01',\n",
|
146 |
+
" azure_endpoint=azure_endpoint\n",
|
147 |
+
" )\n",
|
148 |
+
"\n",
|
149 |
+
" # 加载图像并编码为 Base64\n",
|
150 |
+
" def encode_image(image_path: str):\n",
|
151 |
+
" with open(image_path, \"rb\") as image_file:\n",
|
152 |
+
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
|
153 |
+
"\n",
|
154 |
+
" base64_image = encode_image(image_path)\n",
|
155 |
+
"\n",
|
156 |
+
" # GPT 工具调用配置\n",
|
157 |
+
" tools = [\n",
|
158 |
+
" {\n",
|
159 |
+
" 'type': 'function',\n",
|
160 |
+
" 'function': {\n",
|
161 |
+
" 'name': 'get_multi_molecular_text_to_correct',\n",
|
162 |
+
" 'description': 'Extracts the SMILES string and text coref from molecular images.',\n",
|
163 |
+
" 'parameters': {\n",
|
164 |
+
" 'type': 'object',\n",
|
165 |
+
" 'properties': {\n",
|
166 |
+
" 'image_path': {\n",
|
167 |
+
" 'type': 'string',\n",
|
168 |
+
" 'description': 'Path to the reaction image.'\n",
|
169 |
+
" }\n",
|
170 |
+
" },\n",
|
171 |
+
" 'required': ['image_path'],\n",
|
172 |
+
" 'additionalProperties': False\n",
|
173 |
+
" }\n",
|
174 |
+
" }\n",
|
175 |
+
" },\n",
|
176 |
+
" {\n",
|
177 |
+
" 'type': 'function',\n",
|
178 |
+
" 'function': {\n",
|
179 |
+
" 'name': 'get_reaction',\n",
|
180 |
+
" 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n",
|
181 |
+
" 'parameters': {\n",
|
182 |
+
" 'type': 'object',\n",
|
183 |
+
" 'properties': {\n",
|
184 |
+
" 'image_path': {\n",
|
185 |
+
" 'type': 'string',\n",
|
186 |
+
" 'description': 'The path to the reaction image.',\n",
|
187 |
+
" },\n",
|
188 |
+
" },\n",
|
189 |
+
" 'required': ['image_path'],\n",
|
190 |
+
" 'additionalProperties': False,\n",
|
191 |
+
" },\n",
|
192 |
+
" },\n",
|
193 |
+
" },\n",
|
194 |
+
" ]\n",
|
195 |
+
"\n",
|
196 |
+
" # 提供给 GPT 的消息内容\n",
|
197 |
+
" with open('./prompt.txt', 'r') as prompt_file:\n",
|
198 |
+
" prompt = prompt_file.read()\n",
|
199 |
+
" messages = [\n",
|
200 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
201 |
+
" {\n",
|
202 |
+
" 'role': 'user',\n",
|
203 |
+
" 'content': [\n",
|
204 |
+
" {'type': 'text', 'text': prompt},\n",
|
205 |
+
" {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n",
|
206 |
+
" ]\n",
|
207 |
+
" }\n",
|
208 |
+
" ]\n",
|
209 |
+
"\n",
|
210 |
+
" # 调用 GPT 接口\n",
|
211 |
+
" response = client.chat.completions.create(\n",
|
212 |
+
" model = 'gpt-4o',\n",
|
213 |
+
" temperature = 0,\n",
|
214 |
+
" response_format={ 'type': 'json_object' },\n",
|
215 |
+
" messages = [\n",
|
216 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
217 |
+
" {\n",
|
218 |
+
" 'role': 'user',\n",
|
219 |
+
" 'content': [\n",
|
220 |
+
" {\n",
|
221 |
+
" 'type': 'text',\n",
|
222 |
+
" 'text': prompt\n",
|
223 |
+
" },\n",
|
224 |
+
" {\n",
|
225 |
+
" 'type': 'image_url',\n",
|
226 |
+
" 'image_url': {\n",
|
227 |
+
" 'url': f'data:image/png;base64,{base64_image}'\n",
|
228 |
+
" }\n",
|
229 |
+
" }\n",
|
230 |
+
" ]},\n",
|
231 |
+
" ],\n",
|
232 |
+
" tools = tools)\n",
|
233 |
+
" \n",
|
234 |
+
"# Step 1: 工具映射表\n",
|
235 |
+
" TOOL_MAP = {\n",
|
236 |
+
" 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n",
|
237 |
+
" 'get_reaction': get_reaction\n",
|
238 |
+
" }\n",
|
239 |
+
"\n",
|
240 |
+
" # Step 2: 处理多个工具调用\n",
|
241 |
+
" tool_calls = response.choices[0].message.tool_calls\n",
|
242 |
+
" results = []\n",
|
243 |
+
"\n",
|
244 |
+
" # 遍历每个工具调用\n",
|
245 |
+
" for tool_call in tool_calls:\n",
|
246 |
+
" tool_name = tool_call.function.name\n",
|
247 |
+
" tool_arguments = tool_call.function.arguments\n",
|
248 |
+
" tool_call_id = tool_call.id\n",
|
249 |
+
" \n",
|
250 |
+
" tool_args = json.loads(tool_arguments)\n",
|
251 |
+
" \n",
|
252 |
+
" if tool_name in TOOL_MAP:\n",
|
253 |
+
" # 调用工具并获取结果\n",
|
254 |
+
" tool_result = TOOL_MAP[tool_name](image_path)\n",
|
255 |
+
" else:\n",
|
256 |
+
" raise ValueError(f\"Unknown tool called: {tool_name}\")\n",
|
257 |
+
" \n",
|
258 |
+
" # 保存每个工具调用结果\n",
|
259 |
+
" results.append({\n",
|
260 |
+
" 'role': 'tool',\n",
|
261 |
+
" 'content': json.dumps({\n",
|
262 |
+
" 'image_path': image_path,\n",
|
263 |
+
" f'{tool_name}':(tool_result),\n",
|
264 |
+
" }),\n",
|
265 |
+
" 'tool_call_id': tool_call_id,\n",
|
266 |
+
" })\n",
|
267 |
+
"\n",
|
268 |
+
"\n",
|
269 |
+
"# Prepare the chat completion payload\n",
|
270 |
+
" completion_payload = {\n",
|
271 |
+
" 'model': 'gpt-4o',\n",
|
272 |
+
" 'messages': [\n",
|
273 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
274 |
+
" {\n",
|
275 |
+
" 'role': 'user',\n",
|
276 |
+
" 'content': [\n",
|
277 |
+
" {\n",
|
278 |
+
" 'type': 'text',\n",
|
279 |
+
" 'text': prompt\n",
|
280 |
+
" },\n",
|
281 |
+
" {\n",
|
282 |
+
" 'type': 'image_url',\n",
|
283 |
+
" 'image_url': {\n",
|
284 |
+
" 'url': f'data:image/png;base64,{base64_image}'\n",
|
285 |
+
" }\n",
|
286 |
+
" }\n",
|
287 |
+
" ]\n",
|
288 |
+
" },\n",
|
289 |
+
" response.choices[0].message,\n",
|
290 |
+
" *results\n",
|
291 |
+
" ],\n",
|
292 |
+
" }\n",
|
293 |
+
"\n",
|
294 |
+
"# Generate new response\n",
|
295 |
+
" response = client.chat.completions.create(\n",
|
296 |
+
" model=completion_payload[\"model\"],\n",
|
297 |
+
" messages=completion_payload[\"messages\"],\n",
|
298 |
+
" response_format={ 'type': 'json_object' },\n",
|
299 |
+
" temperature=0\n",
|
300 |
+
" )\n",
|
301 |
+
"\n",
|
302 |
+
"\n",
|
303 |
+
" \n",
|
304 |
+
" # 获取 GPT 生成的结果\n",
|
305 |
+
" gpt_output = json.loads(response.choices[0].message.content)\n",
|
306 |
+
" print(f\"gptout:{gpt_output}\")\n",
|
307 |
+
"\n",
|
308 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
309 |
+
" image_np = np.array(image)\n",
|
310 |
+
"\n",
|
311 |
+
" #########################\n",
|
312 |
+
" #reaction_results = model.extract_reactions_from_figures([image_np])\n",
|
313 |
+
" reaction_results = get_reaction_withatoms(image_path)[0]\n",
|
314 |
+
" reactions = []\n",
|
315 |
+
" \n",
|
316 |
+
" # 将 reactants 和 products 转换为 reactions\n",
|
317 |
+
" for reactants, conditions, products in zip(reaction_results.get('reactants', []), reaction_results.get('conditions', []), reaction_results.get('products', [])):\n",
|
318 |
+
" reaction = {\n",
|
319 |
+
" \"reactants\": [reactants],\n",
|
320 |
+
" \"conditions\": [conditions],\n",
|
321 |
+
" \"products\": [products]\n",
|
322 |
+
" }\n",
|
323 |
+
" reactions.append(reaction)\n",
|
324 |
+
" reaction_results = [{\"reactions\": reactions}]\n",
|
325 |
+
" #coref_results = model.extract_molecule_corefs_from_figures([image_np])\n",
|
326 |
+
" coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
|
327 |
+
" ########################\n",
|
328 |
+
"\n",
|
329 |
+
" # 定义更新工具输出的函数\n",
|
330 |
+
" def extract_smiles_details(smiles_data, raw_details):\n",
|
331 |
+
" smiles_details = {}\n",
|
332 |
+
" for smiles in smiles_data:\n",
|
333 |
+
" for detail in raw_details:\n",
|
334 |
+
" for bbox in detail.get('bboxes', []):\n",
|
335 |
+
" if bbox.get('smiles') == smiles:\n",
|
336 |
+
" smiles_details[smiles] = {\n",
|
337 |
+
" 'category': bbox.get('category'),\n",
|
338 |
+
" 'bbox': bbox.get('bbox'),\n",
|
339 |
+
" 'category_id': bbox.get('category_id'),\n",
|
340 |
+
" 'score': bbox.get('score'),\n",
|
341 |
+
" 'molfile': bbox.get('molfile'),\n",
|
342 |
+
" 'atoms': bbox.get('atoms'),\n",
|
343 |
+
" 'bonds': bbox.get('bonds')\n",
|
344 |
+
" }\n",
|
345 |
+
" break\n",
|
346 |
+
" return smiles_details\n",
|
347 |
+
"\n",
|
348 |
+
"# 获取结果\n",
|
349 |
+
" smiles_details = extract_smiles_details(gpt_output, coref_results)\n",
|
350 |
+
"\n",
|
351 |
+
" reactants_array = []\n",
|
352 |
+
" products = []\n",
|
353 |
+
"\n",
|
354 |
+
" for reactant in reaction_results[0]['reactions'][0]['reactants']:\n",
|
355 |
+
" #for reactant in reaction_results[0]['reactions'][0]['reactants']:\n",
|
356 |
+
" if 'smiles' in reactant:\n",
|
357 |
+
" #print(reactant['smiles'])\n",
|
358 |
+
" #print(reactant)\n",
|
359 |
+
" reactants_array.append(reactant['smiles'])\n",
|
360 |
+
"\n",
|
361 |
+
" for product in reaction_results[0]['reactions'][0]['products']:\n",
|
362 |
+
" #print(product['smiles'])\n",
|
363 |
+
" #print(product)\n",
|
364 |
+
" products.append(product['smiles'])\n",
|
365 |
+
" # 输出结果\n",
|
366 |
+
" #import pprint\n",
|
367 |
+
" #pprint.pprint(smiles_details)\n",
|
368 |
+
"\n",
|
369 |
+
" # 整理反应数据\n",
|
370 |
+
" try:\n",
|
371 |
+
" backed_out = utils.backout_without_coref(reaction_results, coref_results, gpt_output, smiles_details, model.molscribe)\n",
|
372 |
+
" backed_out.sort(key=lambda x: x[2])\n",
|
373 |
+
" extracted_rxns = {}\n",
|
374 |
+
" for reactants, products_, label in backed_out:\n",
|
375 |
+
" extracted_rxns[label] = {'reactants': reactants, 'products': products_}\n",
|
376 |
+
"\n",
|
377 |
+
" toadd = {\n",
|
378 |
+
" \"reaction_template\": {\n",
|
379 |
+
" \"reactants\": reactants_array,\n",
|
380 |
+
" \"products\": products\n",
|
381 |
+
" },\n",
|
382 |
+
" \"reactions\": extracted_rxns\n",
|
383 |
+
" }\n",
|
384 |
+
" \n",
|
385 |
+
"\n",
|
386 |
+
" # 按标签排序\n",
|
387 |
+
" sorted_keys = sorted(toadd[\"reactions\"].keys())\n",
|
388 |
+
" toadd[\"reactions\"] = {i: toadd[\"reactions\"][i] for i in sorted_keys}\n",
|
389 |
+
" original_molecular_list = {'Original molecular list': gpt_output}\n",
|
390 |
+
" final_data= toadd.copy()\n",
|
391 |
+
" final_data.update(original_molecular_list)\n",
|
392 |
+
" except:\n",
|
393 |
+
" #pass\n",
|
394 |
+
" final_data = {'Original molecular list': gpt_output}\n",
|
395 |
+
"\n",
|
396 |
+
" print(final_data)\n",
|
397 |
+
" return final_data\n",
|
398 |
+
" \n",
|
399 |
+
"\n",
|
400 |
+
"\n",
|
401 |
+
"\n"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"cell_type": "code",
|
406 |
+
"execution_count": null,
|
407 |
+
"metadata": {},
|
408 |
+
"outputs": [],
|
409 |
+
"source": [
|
410 |
+
"# # image_path = './example/Replace/99.jpg'\n",
|
411 |
+
"# # result = process_reaction_image(image_path)\n",
|
412 |
+
"# # print(json.dumps(result, indent=4))\n",
|
413 |
+
"# image_path = './example/example1/replace/Nesting/283.jpg'\n",
|
414 |
+
"# image = Image.open(image_path).convert('RGB')\n",
|
415 |
+
"# image_np = np.array(image)\n",
|
416 |
+
"\n",
|
417 |
+
"# # input1 = get_multi_molecular_text_to_correct_withatoms('./example/example1/replace/Nesting/283.jpg')\n",
|
418 |
+
"# # input2 = get_reaction('./example/example1/replace/Nesting/283.jpg')\n",
|
419 |
+
"# # print(input1)\n",
|
420 |
+
"# # print(input2)\n",
|
421 |
+
"# #reaction_results = model.extract_reactions_from_figures([image_np])\n",
|
422 |
+
"# coorf = model.extract_molecule_corefs_from_figures([image_np])\n",
|
423 |
+
"# print(coorf)\n"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"cell_type": "code",
|
428 |
+
"execution_count": null,
|
429 |
+
"metadata": {},
|
430 |
+
"outputs": [],
|
431 |
+
"source": [
|
432 |
+
"import base64\n",
|
433 |
+
"import torch\n",
|
434 |
+
"import json\n",
|
435 |
+
"from PIL import Image\n",
|
436 |
+
"import numpy as np\n",
|
437 |
+
"from openai import AzureOpenAI\n",
|
438 |
+
"\n",
|
439 |
+
"def process_reaction_image_final(image_path: str) -> dict:\n",
|
440 |
+
" \"\"\"\n",
|
441 |
+
"\n",
|
442 |
+
" Args:\n",
|
443 |
+
" image_path (str): Path to the image file.\n",
|
444 |
+
"\n",
|
445 |
+
" Returns:\n",
|
446 |
+
" dict: The JSON object parsed from the model's final response.\n",
|
447 |
+
" \"\"\"\n",
|
448 |
+
" # 配置 API Key 和 Azure Endpoint\n",
|
449 |
+
" api_key = \"b038da96509b4009be931e035435e022\" # SECURITY: hardcoded credential; load it from the environment or a config file instead\n",
|
450 |
+
" azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n",
|
451 |
+
" \n",
|
452 |
+
"\n",
|
453 |
+
" model = ChemIEToolkit(device=torch.device('cpu'))\n",
|
454 |
+
" client = AzureOpenAI(\n",
|
455 |
+
" api_key=api_key,\n",
|
456 |
+
" api_version='2024-06-01',\n",
|
457 |
+
" azure_endpoint=azure_endpoint\n",
|
458 |
+
" )\n",
|
459 |
+
"\n",
|
460 |
+
" # 加载图像并编码为 Base64\n",
|
461 |
+
" def encode_image(image_path: str):\n",
|
462 |
+
" with open(image_path, \"rb\") as image_file:\n",
|
463 |
+
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
|
464 |
+
"\n",
|
465 |
+
" base64_image = encode_image(image_path)\n",
|
466 |
+
"\n",
|
467 |
+
" # GPT 工具调用配置\n",
|
468 |
+
" tools = [\n",
|
469 |
+
" {\n",
|
470 |
+
" 'type': 'function',\n",
|
471 |
+
" 'function': {\n",
|
472 |
+
" 'name': 'get_multi_molecular_text_to_correct',\n",
|
473 |
+
" 'description': 'Extracts the SMILES string and text coref from molecular sub-images from a reaction image and ready for further process.',\n",
|
474 |
+
" 'parameters': {\n",
|
475 |
+
" 'type': 'object',\n",
|
476 |
+
" 'properties': {\n",
|
477 |
+
" 'image_path': {\n",
|
478 |
+
" 'type': 'string',\n",
|
479 |
+
" 'description': 'Path to the reaction image.'\n",
|
480 |
+
" }\n",
|
481 |
+
" },\n",
|
482 |
+
" 'required': ['image_path'],\n",
|
483 |
+
" 'additionalProperties': False\n",
|
484 |
+
" }\n",
|
485 |
+
" }\n",
|
486 |
+
" },\n",
|
487 |
+
" {\n",
|
488 |
+
" 'type': 'function',\n",
|
489 |
+
" 'function': {\n",
|
490 |
+
" 'name': 'get_reaction',\n",
|
491 |
+
" 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n",
|
492 |
+
" 'parameters': {\n",
|
493 |
+
" 'type': 'object',\n",
|
494 |
+
" 'properties': {\n",
|
495 |
+
" 'image_path': {\n",
|
496 |
+
" 'type': 'string',\n",
|
497 |
+
" 'description': 'The path to the reaction image.',\n",
|
498 |
+
" },\n",
|
499 |
+
" },\n",
|
500 |
+
" 'required': ['image_path'],\n",
|
501 |
+
" 'additionalProperties': False,\n",
|
502 |
+
" },\n",
|
503 |
+
" },\n",
|
504 |
+
" },\n",
|
505 |
+
"\n",
|
506 |
+
" \n",
|
507 |
+
"\n",
|
508 |
+
" {\n",
|
509 |
+
" 'type': 'function',\n",
|
510 |
+
" 'function': {\n",
|
511 |
+
" 'name': 'process_reaction_image_with_multiple_products',\n",
|
512 |
+
" 'description': 'process the reaction image that contains a multiple products table. Get a list of reactions from the reaction image, including the reaction template and detailed reaction with detailed R-group information.',\n",
|
513 |
+
" 'parameters': {\n",
|
514 |
+
" 'type': 'object',\n",
|
515 |
+
" 'properties': {\n",
|
516 |
+
" 'image_path': {\n",
|
517 |
+
" 'type': 'string',\n",
|
518 |
+
" 'description': 'The path to the reaction image.',\n",
|
519 |
+
" },\n",
|
520 |
+
" },\n",
|
521 |
+
" 'required': ['image_path'],\n",
|
522 |
+
" 'additionalProperties': False,\n",
|
523 |
+
" },\n",
|
524 |
+
" },\n",
|
525 |
+
" },\n",
|
526 |
+
"\n",
|
527 |
+
" {\n",
|
528 |
+
" 'type': 'function',\n",
|
529 |
+
" 'function': {\n",
|
530 |
+
" 'name': 'get_full_reaction',\n",
|
531 |
+
" 'description': 'Get a list of reactions from a reaction image without any tables. A reaction contains data of the reactants, conditions, and products.',\n",
|
532 |
+
" 'parameters': {\n",
|
533 |
+
" 'type': 'object',\n",
|
534 |
+
" 'properties': {\n",
|
535 |
+
" 'image_path': {\n",
|
536 |
+
" 'type': 'string',\n",
|
537 |
+
" 'description': 'The path to the reaction image.',\n",
|
538 |
+
" },\n",
|
539 |
+
" },\n",
|
540 |
+
" 'required': ['image_path'],\n",
|
541 |
+
" 'additionalProperties': False,\n",
|
542 |
+
" },\n",
|
543 |
+
" },\n",
|
544 |
+
" },\n",
|
545 |
+
"\n",
|
546 |
+
" {\n",
|
547 |
+
" 'type': 'function',\n",
|
548 |
+
" 'function': {\n",
|
549 |
+
" 'name': 'get_multi_molecular',\n",
|
550 |
+
" 'description': 'Extracts the SMILES string and text coref from a molecular image without any reactions',\n",
|
551 |
+
" 'parameters': {\n",
|
552 |
+
" 'type': 'object',\n",
|
553 |
+
" 'properties': {\n",
|
554 |
+
" 'image_path': {\n",
|
555 |
+
" 'type': 'string',\n",
|
556 |
+
" 'description': 'The path to the reaction image.',\n",
|
557 |
+
" },\n",
|
558 |
+
" },\n",
|
559 |
+
" 'required': ['image_path'],\n",
|
560 |
+
" 'additionalProperties': False,\n",
|
561 |
+
" },\n",
|
562 |
+
" },\n",
|
563 |
+
" },\n",
|
564 |
+
" ]\n",
|
565 |
+
"\n",
|
566 |
+
" # 提供给 GPT 的消息内容\n",
|
567 |
+
" with open('./prompt_final.txt', 'r') as prompt_file:\n",
|
568 |
+
" prompt = prompt_file.read()\n",
|
569 |
+
" messages = [\n",
|
570 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
571 |
+
" {\n",
|
572 |
+
" 'role': 'user',\n",
|
573 |
+
" 'content': [\n",
|
574 |
+
" {'type': 'text', 'text': prompt},\n",
|
575 |
+
" {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n",
|
576 |
+
" ]\n",
|
577 |
+
" }\n",
|
578 |
+
" ]\n",
|
579 |
+
"\n",
|
580 |
+
" # 调用 GPT 接口\n",
|
581 |
+
" response = client.chat.completions.create(\n",
|
582 |
+
" model = 'gpt-4o',\n",
|
583 |
+
" temperature = 0,\n",
|
584 |
+
" response_format={ 'type': 'json_object' },\n",
|
585 |
+
" messages = [\n",
|
586 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
587 |
+
" {\n",
|
588 |
+
" 'role': 'user',\n",
|
589 |
+
" 'content': [\n",
|
590 |
+
" {\n",
|
591 |
+
" 'type': 'text',\n",
|
592 |
+
" 'text': prompt\n",
|
593 |
+
" },\n",
|
594 |
+
" {\n",
|
595 |
+
" 'type': 'image_url',\n",
|
596 |
+
" 'image_url': {\n",
|
597 |
+
" 'url': f'data:image/png;base64,{base64_image}'\n",
|
598 |
+
" }\n",
|
599 |
+
" }\n",
|
600 |
+
" ]},\n",
|
601 |
+
" ],\n",
|
602 |
+
" tools = tools)\n",
|
603 |
+
" \n",
|
604 |
+
"# Step 1: 工具映射表\n",
|
605 |
+
" TOOL_MAP = {\n",
|
606 |
+
" 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n",
|
607 |
+
" 'get_reaction': get_reaction,\n",
|
608 |
+
" 'process_reaction_image_with_multiple_products':process_reaction_image_with_multiple_products,\n",
|
609 |
+
"\n",
|
610 |
+
" 'get_full_reaction': get_full_reaction,\n",
|
611 |
+
" 'get_multi_molecular':get_multi_molecular,\n",
|
612 |
+
" }\n",
|
613 |
+
"\n",
|
614 |
+
" # Step 2: 处理多个工具调用\n",
|
615 |
+
" tool_calls = response.choices[0].message.tool_calls\n",
|
616 |
+
" results = []\n",
|
617 |
+
"\n",
|
618 |
+
" # 遍历每个工具调用\n",
|
619 |
+
" for tool_call in tool_calls:\n",
|
620 |
+
" tool_name = tool_call.function.name\n",
|
621 |
+
" tool_arguments = tool_call.function.arguments\n",
|
622 |
+
" tool_call_id = tool_call.id\n",
|
623 |
+
" \n",
|
624 |
+
" tool_args = json.loads(tool_arguments)\n",
|
625 |
+
" \n",
|
626 |
+
" if tool_name in TOOL_MAP:\n",
|
627 |
+
" # 调用工具并获取结果\n",
|
628 |
+
" tool_result = TOOL_MAP[tool_name](image_path)\n",
|
629 |
+
" else:\n",
|
630 |
+
" raise ValueError(f\"Unknown tool called: {tool_name}\")\n",
|
631 |
+
" \n",
|
632 |
+
" # 保存每个工具调用结果\n",
|
633 |
+
" results.append({\n",
|
634 |
+
" 'role': 'tool',\n",
|
635 |
+
" 'content': json.dumps({\n",
|
636 |
+
" 'image_path': image_path,\n",
|
637 |
+
" f'{tool_name}':(tool_result),\n",
|
638 |
+
" }),\n",
|
639 |
+
" 'tool_call_id': tool_call_id,\n",
|
640 |
+
" })\n",
|
641 |
+
"\n",
|
642 |
+
"\n",
|
643 |
+
"# Prepare the chat completion payload\n",
|
644 |
+
" completion_payload = {\n",
|
645 |
+
" 'model': 'gpt-4o',\n",
|
646 |
+
" 'messages': [\n",
|
647 |
+
" {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
|
648 |
+
" {\n",
|
649 |
+
" 'role': 'user',\n",
|
650 |
+
" 'content': [\n",
|
651 |
+
" {\n",
|
652 |
+
" 'type': 'text',\n",
|
653 |
+
" 'text': prompt\n",
|
654 |
+
" },\n",
|
655 |
+
" {\n",
|
656 |
+
" 'type': 'image_url',\n",
|
657 |
+
" 'image_url': {\n",
|
658 |
+
" 'url': f'data:image/png;base64,{base64_image}'\n",
|
659 |
+
" }\n",
|
660 |
+
" }\n",
|
661 |
+
" ]\n",
|
662 |
+
" },\n",
|
663 |
+
" response.choices[0].message,\n",
|
664 |
+
" *results\n",
|
665 |
+
" ],\n",
|
666 |
+
" }\n",
|
667 |
+
"\n",
|
668 |
+
"# Generate new response\n",
|
669 |
+
" response = client.chat.completions.create(\n",
|
670 |
+
" model=completion_payload[\"model\"],\n",
|
671 |
+
" messages=completion_payload[\"messages\"],\n",
|
672 |
+
" response_format={ 'type': 'json_object' },\n",
|
673 |
+
" temperature=0\n",
|
674 |
+
" )\n",
|
675 |
+
"\n",
|
676 |
+
"\n",
|
677 |
+
" \n",
|
678 |
+
" # 获取 GPT 生成的结果\n",
|
679 |
+
" gpt_output = json.loads(response.choices[0].message.content)\n",
|
680 |
+
" print(gpt_output)\n",
|
681 |
+
" return gpt_output\n",
|
682 |
+
"\n",
|
683 |
+
"\n",
|
684 |
+
"\n"
|
685 |
+
]
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"cell_type": "code",
|
689 |
+
"execution_count": null,
|
690 |
+
"metadata": {},
|
691 |
+
"outputs": [],
|
692 |
+
"source": [
|
693 |
+
"image_path = './data/bowen-4/2.png'\n",
|
694 |
+
"result = process_reaction_image_final(image_path)\n",
|
695 |
+
"print(json.dumps(result, indent=4))"
|
696 |
+
]
|
697 |
+
},
|
698 |
+
{
|
699 |
+
"cell_type": "code",
|
700 |
+
"execution_count": null,
|
701 |
+
"metadata": {},
|
702 |
+
"outputs": [],
|
703 |
+
"source": [
|
704 |
+
"# def get_reaction(image_path: str) -> list:\n",
|
705 |
+
"# '''Returns a list of reactions extracted from the image.'''\n",
|
706 |
+
"# image_file = image_path\n",
|
707 |
+
"# return json.dumps(model1.predict_image_file(image_file, molscribe=True, ocr=True))\n",
|
708 |
+
"\n",
|
709 |
+
"# reaction_output = get_reaction('./pdf/2/2_image_3_1.png')\n",
|
710 |
+
"# print(reaction_output)"
|
711 |
+
]
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"cell_type": "code",
|
715 |
+
"execution_count": null,
|
716 |
+
"metadata": {},
|
717 |
+
"outputs": [],
|
718 |
+
"source": [
|
719 |
+
"import os\n",
|
720 |
+
"import fitz # PyMuPDF\n",
|
721 |
+
"from core import run_visualheist\n",
|
722 |
+
"import base64\n",
|
723 |
+
"from openai import AzureOpenAI\n",
|
724 |
+
"\n",
|
725 |
+
"def full_pdf_extraction_pipeline_with_history(pdf_path,\n",
|
726 |
+
" output_dir,\n",
|
727 |
+
" api_key,\n",
|
728 |
+
" azure_endpoint,\n",
|
729 |
+
" model=\"gpt-4o\",\n",
|
730 |
+
" model_size=\"large\"):\n",
|
731 |
+
" \"\"\"\n",
|
732 |
+
" Full pipeline: from PDF to GPT-annotated related text.\n",
|
733 |
+
" Extracts markdown + figures + reaction data from a PDF and calls GPT-4o to annotate them.\n",
|
734 |
+
"\n",
|
735 |
+
" Args:\n",
|
736 |
+
" pdf_path (str): Path to input PDF file.\n",
|
737 |
+
" output_dir (str): Directory to save results.\n",
|
738 |
+
" api_key (str): Azure OpenAI API key.\n",
|
739 |
+
" azure_endpoint (str): Azure OpenAI endpoint.\n",
|
740 |
+
" model (str): GPT model name (default \"gpt-4o\").\n",
|
741 |
+
" model_size (str): VisualHeist model size (\"base\", \"large\", etc).\n",
|
742 |
+
"\n",
|
743 |
+
" Returns:\n",
|
744 |
+
" List of GPT-generated annotated related-text JSONs.\n",
|
745 |
+
" \"\"\"\n",
|
746 |
+
"\n",
|
747 |
+
"\n",
|
748 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
749 |
+
"\n",
|
750 |
+
" # Step 1: Extract Markdown text\n",
|
751 |
+
" doc = fitz.open(pdf_path)\n",
|
752 |
+
" md_text = \"\"\n",
|
753 |
+
" for i, page in enumerate(doc, start=1):\n",
|
754 |
+
" md_text += f\"\\n\\n## = Page {i} =\\n\\n\" + page.get_text()\n",
|
755 |
+
" filename = os.path.splitext(os.path.basename(pdf_path))[0]\n",
|
756 |
+
" md_path = os.path.join(output_dir, f\"{filename}.md\")\n",
|
757 |
+
" with open(md_path, \"w\", encoding=\"utf-8\") as f:\n",
|
758 |
+
" f.write(md_text.strip())\n",
|
759 |
+
" print(f\"[✓] Markdown saved to: {md_path}\")\n",
|
760 |
+
"\n",
|
761 |
+
" # Step 2: Extract figures using VisualHeist\n",
|
762 |
+
" run_visualheist(pdf_dir=pdf_path, model_size=model_size, image_dir=output_dir)\n",
|
763 |
+
" print(f\"[✓] Figures extracted to: {output_dir}\")\n",
|
764 |
+
"\n",
|
765 |
+
" # Step 3: Parse figures to JSON\n",
|
766 |
+
" image_data = []\n",
|
767 |
+
" known_molecules = []\n",
|
768 |
+
"\n",
|
769 |
+
" for fname in sorted(os.listdir(output_dir)):\n",
|
770 |
+
" if fname.endswith(\".png\"):\n",
|
771 |
+
" img_path = os.path.join(output_dir, fname)\n",
|
772 |
+
" try:\n",
|
773 |
+
" result = process_reaction_image_final(img_path)\n",
|
774 |
+
" result[\"image_name\"] = fname\n",
|
775 |
+
" image_data.append(result)\n",
|
776 |
+
" except Exception as e:\n",
|
777 |
+
" print(f\"[!] Failed on {fname}: {e}\")\n",
|
778 |
+
" new_mols_json = get_multi_molecular_text_to_correct(img_path)\n",
|
779 |
+
" new_mols = json.loads(new_mols_json)\n",
|
780 |
+
" for m in new_mols:\n",
|
781 |
+
" if m[\"smiles\"] not in {km[\"smiles\"] for km in known_molecules}:\n",
|
782 |
+
" known_molecules.append(m)\n",
|
783 |
+
"\n",
|
784 |
+
"\n",
|
785 |
+
" json_path = os.path.join(output_dir, f\"{filename}_reaction_data.json\")\n",
|
786 |
+
" with open(json_path, \"w\", encoding=\"utf-8\") as f:\n",
|
787 |
+
" json.dump(image_data, f, indent=2, ensure_ascii=False)\n",
|
788 |
+
" print(f\"[✓] Reaction data saved to: {json_path}\")\n",
|
789 |
+
"\n",
|
790 |
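+
" # Step 4: Call the Azure OpenAI chat model to annotate each image. NOTE(review): response_format below is the plain string json, while the other calls in this notebook pass {'type': 'json_object'} - confirm which form the API accepts\n",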
|
791 |
+
" client = AzureOpenAI(\n",
|
792 |
+
" api_key=api_key,\n",
|
793 |
+
" api_version=\"2024-06-01\",\n",
|
794 |
+
" azure_endpoint=azure_endpoint\n",
|
795 |
+
" )\n",
|
796 |
+
"\n",
|
797 |
+
" prompt = \"\"\"\n",
|
798 |
+
"You are a text-mining assistant for chemistry papers. Your task is to find the most relevant 1–3 sentences in a research article that describe a given figure or scheme.\n",
|
799 |
+
"\n",
|
800 |
+
"You will be given:\n",
|
801 |
+
"- A block of text extracted from the article (in Markdown format).\n",
|
802 |
+
"- The extracted structured data from one image (including its title and list of molecules or reactions).\n",
|
803 |
+
"\n",
|
804 |
+
"Your task is:\n",
|
805 |
+
"1. Match the image with sentences that are most relevant to it. Use clues like the figure/scheme/table number in the title, or molecule/reaction labels (e.g., 1a, 2b, 3).\n",
|
806 |
+
"2. Extract up to 3 short sentences that best describe or mention the contents of the image.\n",
|
807 |
+
"3. In these sentences, label any molecule or reaction identifiers (like “1a”, “2b”) with their role based on context: [reactant], [product], etc.\n",
|
808 |
+
"4. Also label experimental conditions with their roles:\n",
|
809 |
+
" - Percent values like “85%” as [yield]\n",
|
810 |
+
" - Temperatures like “100 °C” as [temperature]\n",
|
811 |
+
" - Time durations like “24 h”, “20 min” as [time]\n",
|
812 |
+
"5. Do **not** label chemical position numbers (e.g., in \"3-trifluoromethyl\", \"1,2,4-triazole\").\n",
|
813 |
+
"6. Do not repeat any labels. Only label each item once per sentence.\n",
|
814 |
+
"\n",
|
815 |
+
"Output format:\n",
|
816 |
+
"{\n",
|
817 |
+
" \"title\": \"<title from image>\",\n",
|
818 |
+
" \"related-text\": [\n",
|
819 |
+
" \"Sentence with roles like 1a[reactant], 2c[product], 100[temperature] °C.\",\n",
|
820 |
+
" ...\n",
|
821 |
+
" ]\n",
|
822 |
+
"}\n",
|
823 |
+
"\"\"\"\n",
|
824 |
+
"\n",
|
825 |
+
" annotated_results = []\n",
|
826 |
+
" for item in image_data:\n",
|
827 |
+
" img_path = os.path.join(output_dir, item[\"image_name\"])\n",
|
828 |
+
" with open(img_path, \"rb\") as f:\n",
|
829 |
+
" base64_image = base64.b64encode(f.read()).decode(\"utf-8\")\n",
|
830 |
+
"\n",
|
831 |
+
" combined_input = f\"\"\"\n",
|
832 |
+
"## Image Structured Data:\n",
|
833 |
+
"{json.dumps(item, indent=2)}\n",
|
834 |
+
"\n",
|
835 |
+
"## Article Text:\n",
|
836 |
+
"{md_text}\n",
|
837 |
+
"\"\"\"\n",
|
838 |
+
"\n",
|
839 |
+
" response = client.chat.completions.create(\n",
|
840 |
+
" model=model,\n",
|
841 |
+
" temperature=0,\n",
|
842 |
+
" response_format=\"json\",\n",
|
843 |
+
" messages=[\n",
|
844 |
+
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
845 |
+
" {\n",
|
846 |
+
" \"role\": \"user\",\n",
|
847 |
+
" \"content\": [\n",
|
848 |
+
" {\"type\": \"text\", \"text\": prompt + \"\\n\\n\" + combined_input},\n",
|
849 |
+
" {\n",
|
850 |
+
" \"type\": \"image_url\",\n",
|
851 |
+
" \"image_url\": {\n",
|
852 |
+
" \"url\": f\"data:image/png;base64,{base64_image}\"\n",
|
853 |
+
" }\n",
|
854 |
+
" }\n",
|
855 |
+
" ]\n",
|
856 |
+
" }\n",
|
857 |
+
" ]\n",
|
858 |
+
" )\n",
|
859 |
+
" annotated_results.append(json.loads(response.choices[0].message.content))\n",
|
860 |
+
"\n",
|
861 |
+
" # Optionally save output\n",
|
862 |
+
" with open(os.path.join(output_dir, f\"{filename}_annotated_related_text.json\"), \"w\", encoding=\"utf-8\") as f:\n",
|
863 |
+
" json.dump(annotated_results, f, indent=2, ensure_ascii=False)\n",
|
864 |
+
" print(f\"[✓] Annotated related-text saved.\")\n",
|
865 |
+
"\n",
|
866 |
+
" return annotated_results"
|
867 |
+
]
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"cell_type": "code",
|
871 |
+
"execution_count": null,
|
872 |
+
"metadata": {},
|
873 |
+
"outputs": [],
|
874 |
+
"source": [
|
875 |
+
"image_path = './data/example/example1/replace/Nesting/283.jpg'\n",
|
876 |
+
"#image_path = './pdf/2/2_image_1_1.png'\n",
|
877 |
+
"result = process_reaction_image_final(image_path)\n",
|
878 |
+
"print(json.dumps(result, indent=4))"
|
879 |
+
]
|
880 |
+
},
|
881 |
+
{
|
882 |
+
"cell_type": "code",
|
883 |
+
"execution_count": null,
|
884 |
+
"metadata": {},
|
885 |
+
"outputs": [],
|
886 |
+
"source": []
|
887 |
+
},
|
888 |
+
{
|
889 |
+
"cell_type": "code",
|
890 |
+
"execution_count": null,
|
891 |
+
"metadata": {},
|
892 |
+
"outputs": [],
|
893 |
+
"source": [
|
894 |
+
"# import os\n",
|
895 |
+
"\n",
|
896 |
+
"# image_folder = './example/example1/replace/regular/' # 图片文件夹路径\n",
|
897 |
+
"# output_folder = './batches_final_repalce_regular/' # 保存每批结果的文件夹路径\n",
|
898 |
+
"# batch_size = 3 # 每批处理文件数量\n",
|
899 |
+
"\n",
|
900 |
+
"# # 创建保存批次结果的文件夹(如果不存在)\n",
|
901 |
+
"# os.makedirs(output_folder, exist_ok=True)\n",
|
902 |
+
"\n",
|
903 |
+
"# # 获取所有图片文件并按字母顺序排序\n",
|
904 |
+
"# all_files = sorted([f for f in os.listdir(image_folder) if f.endswith('.jpg')])\n",
|
905 |
+
"\n",
|
906 |
+
"# # 获取已完成的批次\n",
|
907 |
+
"# completed_batches = [\n",
|
908 |
+
"# int(f.split('_')[1].split('.')[0]) for f in os.listdir(output_folder) if f.startswith('batch_') and f.endswith('.json')\n",
|
909 |
+
"# ]\n",
|
910 |
+
"# completed_batches = sorted(completed_batches) # 确保按顺序排序\n",
|
911 |
+
"\n",
|
912 |
+
"# # 从指定批次开始(如果有未完成批次)\n",
|
913 |
+
"# start_batch = (completed_batches[-1] + 1) if completed_batches else 1\n",
|
914 |
+
"\n",
|
915 |
+
"# # 将文件分批并从指定批次开始\n",
|
916 |
+
"# for batch_index in range((start_batch - 1) * batch_size, len(all_files), batch_size):\n",
|
917 |
+
"# batch_files = all_files[batch_index:batch_index + batch_size]\n",
|
918 |
+
"# results = []\n",
|
919 |
+
"\n",
|
920 |
+
"# batch_number = batch_index // batch_size + 1\n",
|
921 |
+
"# print(f\"正在按字母顺序处理第 {batch_number} 批,共 {len(batch_files)} 张图片...\")\n",
|
922 |
+
" \n",
|
923 |
+
"# for file_name in batch_files:\n",
|
924 |
+
"# image_path = os.path.join(image_folder, file_name)\n",
|
925 |
+
"# print(f\"处理文件 {file_name}...\")\n",
|
926 |
+
" \n",
|
927 |
+
"# try:\n",
|
928 |
+
"# # 处理单个图片\n",
|
929 |
+
"# result = process_reaction_image_final(image_path)\n",
|
930 |
+
" \n",
|
931 |
+
"# # 确保结果是字典\n",
|
932 |
+
"# if isinstance(result, dict):\n",
|
933 |
+
"# # 添加文件名信息\n",
|
934 |
+
"# result_with_filename = {\n",
|
935 |
+
"# \"file_name\": file_name,\n",
|
936 |
+
"# **result\n",
|
937 |
+
"# }\n",
|
938 |
+
"# results.append(result_with_filename)\n",
|
939 |
+
"# print(result_with_filename)\n",
|
940 |
+
"# else:\n",
|
941 |
+
"# print(f\"文件 {file_name} 的处理结果不是字典,跳过。\")\n",
|
942 |
+
" \n",
|
943 |
+
"# except Exception as e:\n",
|
944 |
+
"# print(f\"处理文件 {file_name} 时出错: {e}\")\n",
|
945 |
+
"\n",
|
946 |
+
"# # 保存当前批次结果\n",
|
947 |
+
"# batch_output_path = os.path.join(output_folder, f'batch_{batch_number}.json')\n",
|
948 |
+
"# with open(batch_output_path, 'w', encoding='utf-8') as json_file:\n",
|
949 |
+
"# json.dump(results, json_file, ensure_ascii=False, indent=4)\n",
|
950 |
+
"\n",
|
951 |
+
"# print(f\"第 {batch_number} 批处理完成,结果保存到 {batch_output_path}\")\n",
|
952 |
+
"\n",
|
953 |
+
"# print(\"所有批次处理完成!\")\n",
|
954 |
+
"\n",
|
955 |
+
"\n"
|
956 |
+
]
|
957 |
+
},
|
958 |
+
{
|
959 |
+
"cell_type": "code",
|
960 |
+
"execution_count": null,
|
961 |
+
"metadata": {},
|
962 |
+
"outputs": [],
|
963 |
+
"source": [
|
964 |
+
"import rdkit\n",
|
965 |
+
"from rdkit import Chem\n",
|
966 |
+
"from rdkit.Chem import Draw\n",
|
967 |
+
"\n",
|
968 |
+
"Draw.MolToImage(Chem.MolFromSmiles('[Si](C)(C)OC(c1ccccc1)(c1ccccc1)C1CCC2=NN(Cc3ccccc3)=CN21'))"
|
969 |
+
]
|
970 |
+
}
|
971 |
+
],
|
972 |
+
"metadata": {
|
973 |
+
"kernelspec": {
|
974 |
+
"display_name": "openchemie",
|
975 |
+
"language": "python",
|
976 |
+
"name": "python3"
|
977 |
+
},
|
978 |
+
"language_info": {
|
979 |
+
"codemirror_mode": {
|
980 |
+
"name": "ipython",
|
981 |
+
"version": 3
|
982 |
+
},
|
983 |
+
"file_extension": ".py",
|
984 |
+
"mimetype": "text/x-python",
|
985 |
+
"name": "python",
|
986 |
+
"nbconvert_exporter": "python",
|
987 |
+
"pygments_lexer": "ipython3",
|
988 |
+
"version": "3.10.14"
|
989 |
+
}
|
990 |
+
},
|
991 |
+
"nbformat": 4,
|
992 |
+
"nbformat_minor": 2
|
993 |
+
}
|
molscribe/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .interface import MolScribe
|
molscribe/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (189 Bytes). View file
|
|
molscribe/__pycache__/augment.cpython-310.pyc
ADDED
Binary file (8.98 kB). View file
|
|