CYF200127 committed on
Commit
1f516b6
·
verified ·
1 Parent(s): 768c438

Upload 162 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. __init__.py +3 -0
  3. __pycache__/get_molecular_agent.cpython-310.pyc +0 -0
  4. __pycache__/get_reaction_agent.cpython-310.pyc +0 -0
  5. __pycache__/main.cpython-310.pyc +0 -0
  6. app.ipynb +295 -0
  7. app.py +239 -0
  8. chemiener/__init__.py +1 -0
  9. chemiener/__pycache__/__init__.cpython-310.pyc +0 -0
  10. chemiener/__pycache__/__init__.cpython-38.pyc +0 -0
  11. chemiener/__pycache__/dataset.cpython-310.pyc +0 -0
  12. chemiener/__pycache__/dataset.cpython-38.pyc +0 -0
  13. chemiener/__pycache__/interface.cpython-310.pyc +0 -0
  14. chemiener/__pycache__/interface.cpython-38.pyc +0 -0
  15. chemiener/__pycache__/model.cpython-310.pyc +0 -0
  16. chemiener/__pycache__/model.cpython-38.pyc +0 -0
  17. chemiener/__pycache__/utils.cpython-310.pyc +0 -0
  18. chemiener/__pycache__/utils.cpython-38.pyc +0 -0
  19. chemiener/dataset.py +172 -0
  20. chemiener/interface.py +124 -0
  21. chemiener/main.py +345 -0
  22. chemiener/model.py +14 -0
  23. chemiener/utils.py +23 -0
  24. chemietoolkit/__init__.py +1 -0
  25. chemietoolkit/__pycache__/__init__.cpython-310.pyc +0 -0
  26. chemietoolkit/__pycache__/__init__.cpython-38.pyc +0 -0
  27. chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc +0 -0
  28. chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc +0 -0
  29. chemietoolkit/__pycache__/interface.cpython-310.pyc +0 -0
  30. chemietoolkit/__pycache__/interface.cpython-38.pyc +0 -0
  31. chemietoolkit/__pycache__/tableextractor.cpython-310.pyc +0 -0
  32. chemietoolkit/__pycache__/utils.cpython-310.pyc +0 -0
  33. chemietoolkit/chemrxnextractor.py +107 -0
  34. chemietoolkit/interface.py +749 -0
  35. chemietoolkit/tableextractor.py +340 -0
  36. chemietoolkit/utils.py +1018 -0
  37. examples/exp.png +3 -0
  38. examples/image.webp +0 -0
  39. examples/rdkit.png +0 -0
  40. examples/reaction1.jpg +0 -0
  41. examples/reaction2.png +0 -0
  42. examples/reaction3.png +0 -0
  43. examples/reaction4.png +3 -0
  44. get_molecular_agent.py +599 -0
  45. get_reaction_agent.py +507 -0
  46. main.py +546 -0
  47. main_Rgroup_debug.ipynb +993 -0
  48. molscribe/__init__.py +1 -0
  49. molscribe/__pycache__/__init__.cpython-310.pyc +0 -0
  50. molscribe/__pycache__/augment.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/exp.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/reaction4.png filter=lfs diff=lfs merge=lfs -text
38
+ molscribe/indigo/lib/Linux/x64/libbingo.so filter=lfs diff=lfs merge=lfs -text
39
+ molscribe/indigo/lib/Linux/x64/libindigo-renderer.so filter=lfs diff=lfs merge=lfs -text
40
+ molscribe/indigo/lib/Linux/x64/libindigo.so filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Package metadata for the top-level distribution.
__version__ = "0.1.0"
__author__ = 'Alex Wang'
__credits__ = 'CSAIL'
__pycache__/get_molecular_agent.cpython-310.pyc ADDED
Binary file (9.08 kB). View file
 
__pycache__/get_reaction_agent.cpython-310.pyc ADDED
Binary file (6.94 kB). View file
 
__pycache__/main.cpython-310.pyc ADDED
Binary file (8.97 kB). View file
 
app.ipynb ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "d13d3631",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7866\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ }
30
+ ],
31
+ "source": [
32
+ "import os\n",
33
+ "import gradio as gr\n",
34
+ "import json\n",
35
+ "from main import ChemEagle # 支持 API key 通过环境变量\n",
36
+ "from rdkit import Chem\n",
37
+ "from rdkit.Chem import rdChemReactions\n",
38
+ "from rdkit.Chem import Draw\n",
39
+ "from rdkit.Chem import AllChem\n",
40
+ "from rdkit.Chem.Draw import rdMolDraw2D\n",
41
+ "import cairosvg\n",
42
+ "import re\n",
43
+ "import torch\n",
44
+ "\n",
45
+ "example_diagram = \"examples/exp.png\"\n",
46
+ "rdkit_image = \"examples/rdkit.png\"\n",
47
+ "# 解析 ChemEagle 返回的结构化数据\n",
48
+ "def parse_reactions(output_json):\n",
49
+ " \"\"\"\n",
50
+ " 解析 JSON 格式的反应数据并格式化输出,包含颜色定制。\n",
51
+ " \"\"\"\n",
52
+ " if isinstance(output_json, str):\n",
53
+ " reactions_data = json.loads(output_json)\n",
54
+ " elif isinstance(output_json, dict):\n",
55
+ " reactions_data = output_json # 转换 JSON 字符串为字典\n",
56
+ " reactions_list = reactions_data.get(\"reactions\", [])\n",
57
+ " detailed_output = []\n",
58
+ " smiles_output = [] \n",
59
+ "\n",
60
+ " for reaction in reactions_list:\n",
61
+ " reaction_id = reaction.get(\"reaction_id\", \"Unknown ID\")\n",
62
+ " reactants = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"reactants\", [])]\n",
63
+ " conditions = [\n",
64
+ " f\"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
65
+ " for c in reaction.get(\"condition\", [])\n",
66
+ " ]\n",
67
+ " conditions_1 = [\n",
68
+ " f\"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
69
+ " for c in reaction.get(\"condition\", [])\n",
70
+ " ]\n",
71
+ " products = [f\"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
72
+ " products_1 = [f\"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
73
+ " products_2 = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"products\", [])]\n",
74
+ " \n",
75
+ " additional = reaction.get(\"additional_info\", [])\n",
76
+ " additional_str = [str(x) for x in additional if x is not None]\n",
77
+ "\n",
78
+ " tail = conditions_1 + additional_str\n",
79
+ " tail_str = \", \".join(tail)\n",
80
+ "\n",
81
+ " # 构造反应的完整字符串,定制字体颜色\n",
82
+ " full_reaction = f\"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}\"\n",
83
+ " full_reaction = f\"<span style='color:black'>{full_reaction}</span>\"\n",
84
+ " \n",
85
+ " # 详细反应格式化输出\n",
86
+ " reaction_output = f\"<b>Reaction: </b> {reaction_id}<br>\"\n",
87
+ " reaction_output += f\" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>\"\n",
88
+ " reaction_output += f\" Conditions: {', '.join(conditions)}<br>\"\n",
89
+ " reaction_output += f\" Products: {', '.join(products)}<br>\"\n",
90
+ " reaction_output += f\" additional_info: {', '.join(additional_str)}<br>\"\n",
91
+ " reaction_output += f\" <b>Full Reaction:</b> {full_reaction}<br>\"\n",
92
+ " reaction_output += \"<br>\"\n",
93
+ " detailed_output.append(reaction_output)\n",
94
+ "\n",
95
+ " reaction_smiles = f\"{'.'.join(reactants)}>>{'.'.join(products_2)}\"\n",
96
+ " smiles_output.append(reaction_smiles)\n",
97
+ " return detailed_output, smiles_output\n",
98
+ "\n",
99
+ "\n",
100
+ "# 核心处理函数,仅使用 API Key 和图像\n",
101
+ "def process_chem_image(api_key, image):\n",
102
+ " # 设置 API Key 环境变量,供 ChemEagle 使用\n",
103
+ " os.environ[\"CHEMEAGLE_API_KEY\"] = api_key\n",
104
+ "\n",
105
+ " # 保存上传图片\n",
106
+ " image_path = \"temp_image.png\"\n",
107
+ " image.save(image_path)\n",
108
+ "\n",
109
+ " # 调用 ChemEagle(实现内部读取 os.getenv)\n",
110
+ " chemeagle_result = ChemEagle(image_path)\n",
111
+ "\n",
112
+ " # 解析输出\n",
113
+ " detailed, smiles = parse_reactions(chemeagle_result)\n",
114
+ "\n",
115
+ " # 写出 JSON\n",
116
+ " json_path = \"output.json\"\n",
117
+ " with open(json_path, 'w') as jf:\n",
118
+ " json.dump(chemeagle_result, jf, indent=2)\n",
119
+ "\n",
120
+ " # 返回 HTML、SMILES 合并文本、示意图、JSON 下载\n",
121
+ " return \"\\n\\n\".join(detailed), smiles, example_diagram, json_path\n",
122
+ "\n",
123
+ "# 构建 Gradio 界面\n",
124
+ "with gr.Blocks() as demo:\n",
125
+ " gr.Markdown(\n",
126
+ " \"\"\"\n",
127
+ " <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>\n",
128
+ " Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.\n",
129
+ " \"\"\"\n",
130
+ " )\n",
131
+ "\n",
132
+ " with gr.Row():\n",
133
+ " # ———— 左侧:上传 + API Key + 按钮 ————\n",
134
+ " with gr.Column(scale=1):\n",
135
+ " image_input = gr.Image(type=\"pil\", label=\"Upload a multimodal reaction image\")\n",
136
+ " api_key_input = gr.Textbox(\n",
137
+ " label=\"Your API-Key\",\n",
138
+ " placeholder=\"Type your OpenAI_API_KEY\",\n",
139
+ " type=\"password\"\n",
140
+ " )\n",
141
+ " with gr.Row():\n",
142
+ " clear_btn = gr.Button(\"Clear\")\n",
143
+ " run_btn = gr.Button(\"Run\", elem_id=\"submit-btn\")\n",
144
+ "\n",
145
+ " # ———— 中间:解析结果 + 示意图 ————\n",
146
+ " with gr.Column(scale=1):\n",
147
+ " gr.Markdown(\"### Parsed Reactions\")\n",
148
+ " reaction_output = gr.HTML(label=\"Detailed Reaction Output\")\n",
149
+ " gr.Markdown(\"### Schematic Diagram\")\n",
150
+ " schematic_diagram = gr.Image(value=example_diagram, label=\"示意图\")\n",
151
+ "\n",
152
+ " # ———— 右侧:SMILES 拆分 & RDKit 渲染 + JSON 下载 ————\n",
153
+ " with gr.Column(scale=1):\n",
154
+ " gr.Markdown(\"### Machine-readable Output\")\n",
155
+ " smiles_output = gr.Textbox(\n",
156
+ " label=\"Reaction SMILES\",\n",
157
+ " show_copy_button=True,\n",
158
+ " interactive=False,\n",
159
+ " visible=False\n",
160
+ " )\n",
161
+ "\n",
162
+ " @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑\n",
163
+ " def show_split(inputs): # 定义处理和展示分割文本的函数\n",
164
+ " if not inputs or isinstance(inputs, str) and inputs.strip() == \"\": # 检查输入文本是否为空\n",
165
+ " return gr.Textbox(label= \"SMILES of Reaction i\"), gr.Image(value=rdkit_image, label= \"RDKit Image of Reaction i\",height=100)\n",
166
+ " else:\n",
167
+ " # 假设输入是逗号分隔的 SMILES 字符串\n",
168
+ " smiles_list = inputs.split(\",\")\n",
169
+ " smiles_list = [re.sub(r\"^\\s*\\[?'?|'\\]?\\s*$\", \"\", item) for item in smiles_list]\n",
170
+ " components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件\n",
171
+ " for i, smiles in enumerate(smiles_list): \n",
172
+ " smiles.replace('\"', '').replace(\"'\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\n",
173
+ " rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)\n",
174
+ " \n",
175
+ " if rxn:\n",
176
+ "\n",
177
+ " new_rxn = AllChem.ChemicalReaction()\t\n",
178
+ " for mol in rxn.GetReactants():\n",
179
+ " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
180
+ " new_rxn.AddReactantTemplate(mol)\n",
181
+ " for mol in rxn.GetProducts():\n",
182
+ " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
183
+ " new_rxn.AddProductTemplate(mol)\n",
184
+ "\n",
185
+ " rxn = new_rxn\n",
186
+ "\n",
187
+ " def atom_mapping_remover(rxn):\n",
188
+ " for reactant in rxn.GetReactants():\n",
189
+ " for atom in reactant.GetAtoms():\n",
190
+ " atom.SetAtomMapNum(0)\n",
191
+ " for product in rxn.GetProducts():\n",
192
+ " for atom in product.GetAtoms():\n",
193
+ " atom.SetAtomMapNum(0)\n",
194
+ " return rxn\n",
195
+ " \n",
196
+ " atom_mapping_remover(rxn)\n",
197
+ "\n",
198
+ " reactant1 = rxn.GetReactantTemplate(0)\n",
199
+ " print(reactant1.GetNumBonds)\n",
200
+ " reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None\n",
201
+ "\n",
202
+ " if reactant1.GetNumBonds() > 0:\n",
203
+ " bond_length_reference = Draw.MeanBondLength(reactant1)\n",
204
+ " elif reactant2 and reactant2.GetNumBonds() > 0:\n",
205
+ " bond_length_reference = Draw.MeanBondLength(reactant2)\n",
206
+ " else:\n",
207
+ " bond_length_reference = 1.0 \n",
208
+ "\n",
209
+ "\n",
210
+ " drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)\n",
211
+ " dopts = drawer.drawOptions()\n",
212
+ " dopts.padding = 0.1 \n",
213
+ " dopts.includeRadicals = True\n",
214
+ " Draw.SetACS1996Mode(dopts, bond_length_reference*0.55)\n",
215
+ " dopts.bondLineWidth = 1.5\n",
216
+ " drawer.DrawReaction(rxn)\n",
217
+ " drawer.FinishDrawing()\n",
218
+ " svg_content = drawer.GetDrawingText()\n",
219
+ " svg_file = f\"reaction{i+1}.svg\"\n",
220
+ " with open(svg_file, \"w\") as f:\n",
221
+ " f.write(svg_content)\n",
222
+ " png_file = f\"reaction_{i+1}.png\"\n",
223
+ " cairosvg.svg2png(url=svg_file, write_to=png_file)\n",
224
+ "\n",
225
+ "\n",
226
+ " \n",
227
+ " components.append(gr.Textbox(value=smiles,label= f\"SMILES of Reaction {i}\", show_copy_button=True, interactive=False))\n",
228
+ " components.append(gr.Image(value=png_file,label= f\"RDKit Image of Reaction {i}\")) \n",
229
+ " return components # 返回包含所有 SMILES Textbox 组件的列表\n",
230
+ "\n",
231
+ " download_json = gr.File(label=\"Download JSON File\")\n",
232
+ "\n",
233
+ "\n",
234
+ " gr.Examples(\n",
235
+ " examples=[\n",
236
+ " [\"examples/reaction1.jpg\", \"\"],\n",
237
+ " [\"examples/reaction2.png\", \"\"],\n",
238
+ " [\"examples/reaction3.png\", \"\"],\n",
239
+ " [\"examples/reaction4.png\", \"\"],\n",
240
+ " \n",
241
+ " \n",
242
+ " ],\n",
243
+ " inputs=[image_input, api_key_input],\n",
244
+ " outputs=[reaction_output, smiles_output, schematic_diagram, download_json],\n",
245
+ " cache_examples=False,\n",
246
+ " examples_per_page=4,\n",
247
+ " )\n",
248
+ "\n",
249
+ " # ———— 清空与运行 绑定 ————\n",
250
+ " clear_btn.click(\n",
251
+ " lambda: (None, None, None, None, None),\n",
252
+ " inputs=[],\n",
253
+ " outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]\n",
254
+ " )\n",
255
+ " run_btn.click(\n",
256
+ " process_chem_image,\n",
257
+ " inputs=[api_key_input, image_input],\n",
258
+ " outputs=[reaction_output, smiles_output, schematic_diagram, download_json]\n",
259
+ " )\n",
260
+ "\n",
261
+ " # 自定义按钮样式\n",
262
+ " demo.css = \"\"\"\n",
263
+ " #submit-btn {\n",
264
+ " background-color: #FF914D;\n",
265
+ " color: white;\n",
266
+ " font-weight: bold;\n",
267
+ " }\n",
268
+ " \"\"\"\n",
269
+ "\n",
270
+ " demo.launch()"
271
+ ]
272
+ }
273
+ ],
274
+ "metadata": {
275
+ "kernelspec": {
276
+ "display_name": "openchemie",
277
+ "language": "python",
278
+ "name": "python3"
279
+ },
280
+ "language_info": {
281
+ "codemirror_mode": {
282
+ "name": "ipython",
283
+ "version": 3
284
+ },
285
+ "file_extension": ".py",
286
+ "mimetype": "text/x-python",
287
+ "name": "python",
288
+ "nbconvert_exporter": "python",
289
+ "pygments_lexer": "ipython3",
290
+ "version": "3.10.14"
291
+ }
292
+ },
293
+ "nbformat": 4,
294
+ "nbformat_minor": 5
295
+ }
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ from main import ChemEagle # 支持 API key 通过环境变量
5
+ from rdkit import Chem
6
+ from rdkit.Chem import rdChemReactions
7
+ from rdkit.Chem import Draw
8
+ from rdkit.Chem import AllChem
9
+ from rdkit.Chem.Draw import rdMolDraw2D
10
+ import cairosvg
11
+ import re
12
+ import torch
13
+
14
+ example_diagram = "examples/exp.png"
15
+ rdkit_image = "examples/rdkit.png"
16
+ # 解析 ChemEagle 返回的结构化数据
17
def parse_reactions(output_json):
    """Parse reaction data returned by ChemEagle and format it for display.

    Parameters
    ----------
    output_json : str | dict
        Either a JSON string or an already-decoded dict with a top-level
        "reactions" list.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(detailed_output, smiles_output)``: HTML-formatted per-reaction
        summaries, and plain reaction SMILES strings
        ("reactants>>products") for machine-readable output.

    Raises
    ------
    TypeError
        If *output_json* is neither a str nor a dict.
    """
    if isinstance(output_json, str):
        reactions_data = json.loads(output_json)
    elif isinstance(output_json, dict):
        reactions_data = output_json
    else:
        # Bug fix: previously an unsupported type fell through and caused
        # an unbound-name NameError below; fail fast with a clear message.
        raise TypeError(f"output_json must be str or dict, got {type(output_json).__name__}")
    reactions_list = reactions_data.get("reactions", [])
    detailed_output = []
    smiles_output = []

    for reaction in reactions_list:
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
        # Conditions are rendered twice: red for the detail view, black for
        # the full-reaction line.
        conditions = [
            f"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
            for c in reaction.get("condition", [])
        ]
        conditions_1 = [
            f"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
            for c in reaction.get("condition", [])
        ]
        products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
        products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
        products_2 = [p.get("smiles", "Unknown") for p in reaction.get("products", [])]

        additional = reaction.get("additional_info", [])
        additional_str = [str(x) for x in additional if x is not None]

        tail = conditions_1 + additional_str
        tail_str = ", ".join(tail)

        # Full reaction string with customised font colour.
        full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}"
        full_reaction = f"<span style='color:black'>{full_reaction}</span>"

        # Detailed per-reaction HTML summary.
        reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
        reaction_output += f" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
        reaction_output += f" Conditions: {', '.join(conditions)}<br>"
        reaction_output += f" Products: {', '.join(products)}<br>"
        reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br>"
        reaction_output += "<br>"
        detailed_output.append(reaction_output)

        # Machine-readable reaction SMILES (no HTML markup).
        reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products_2)}"
        smiles_output.append(reaction_smiles)
    return detailed_output, smiles_output
67
+
68
+
69
+ # 核心处理函数,仅使用 API Key 和图像
70
def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and prepare all UI outputs.

    Publishes the user-supplied key via the CHEMEAGLE_API_KEY environment
    variable (ChemEagle reads it internally), saves the PIL image to disk,
    runs the extraction, and returns the rendered HTML, the reaction SMILES
    list, the static schematic diagram, and the path of the JSON dump.
    """
    # Expose the API key to ChemEagle through the environment.
    os.environ["CHEMEAGLE_API_KEY"] = api_key

    # Persist the uploaded picture where ChemEagle can read it.
    tmp_path = "temp_image.png"
    image.save(tmp_path)

    # ChemEagle reads the key via os.getenv internally.
    extraction = ChemEagle(tmp_path)

    # Turn the raw result into display HTML and reaction SMILES.
    details, rxn_smiles = parse_reactions(extraction)

    # Dump the raw result so the user can download it.
    result_path = "output.json"
    with open(result_path, 'w') as handle:
        json.dump(extraction, handle, indent=2)

    # HTML, combined SMILES text, schematic image, JSON download path.
    return "\n\n".join(details), rxn_smiles, example_diagram, result_path
91
+
92
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>
        Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
        """
    )

    with gr.Row():
        # Left column: image upload + API key + buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            api_key_input = gr.Textbox(
                label="Your API-Key",
                placeholder="Type your OpenAI_API_KEY",
                type="password"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        # Middle column: parsed results + schematic diagram.
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="示意图")

        # Right column: per-reaction SMILES split & RDKit rendering + JSON download.
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden textbox that carries the SMILES list; @gr.render below
            # re-renders whenever its value changes.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False
            )

            @gr.render(inputs = smiles_output)  # bind input and render logic
            def show_split(inputs):
                # Render one Textbox + RDKit image per reaction SMILES.
                if not inputs or isinstance(inputs, str) and inputs.strip() == "":
                    # Empty input: show placeholder components.
                    return gr.Textbox(label= "SMILES of Reaction i"), gr.Image(value=rdkit_image, label= "RDKit Image of Reaction i",height=100)
                else:
                    # Input is assumed to be a comma-separated SMILES string
                    # (the str() of a Python list); strip list punctuation.
                    smiles_list = inputs.split(",")
                    smiles_list = [re.sub(r"^\s*\[?'?|'\]?\s*$", "", item) for item in smiles_list]
                    components = []  # one Textbox + one Image per reaction
                    for i, smiles in enumerate(smiles_list):
                        # NOTE(review): str.replace returns a new string; this
                        # result is discarded, so the line is a no-op — verify
                        # whether `smiles = smiles.replace(...)` was intended.
                        smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
                        rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)

                        if rxn:

                            # Round-trip each template through MolBlock to get
                            # sanitized molecules with coordinates.
                            new_rxn = AllChem.ChemicalReaction()
                            for mol in rxn.GetReactants():
                                mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
                                new_rxn.AddReactantTemplate(mol)
                            for mol in rxn.GetProducts():
                                mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
                                new_rxn.AddProductTemplate(mol)

                            rxn = new_rxn

                            def atom_mapping_remover(rxn):
                                # Clear atom-map numbers so they are not drawn.
                                for reactant in rxn.GetReactants():
                                    for atom in reactant.GetAtoms():
                                        atom.SetAtomMapNum(0)
                                for product in rxn.GetProducts():
                                    for atom in product.GetAtoms():
                                        atom.SetAtomMapNum(0)
                                return rxn

                            atom_mapping_remover(rxn)

                            # Pick a bond-length reference from the first
                            # reactant that actually has bonds.
                            reactant1 = rxn.GetReactantTemplate(0)
                            print(reactant1.GetNumBonds)
                            reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None

                            if reactant1.GetNumBonds() > 0:
                                bond_length_reference = Draw.MeanBondLength(reactant1)
                            elif reactant2 and reactant2.GetNumBonds() > 0:
                                bond_length_reference = Draw.MeanBondLength(reactant2)
                            else:
                                bond_length_reference = 1.0

                            # Render the reaction as SVG, then convert to PNG.
                            drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
                            dopts = drawer.drawOptions()
                            dopts.padding = 0.1
                            dopts.includeRadicals = True
                            Draw.SetACS1996Mode(dopts, bond_length_reference*0.55)
                            dopts.bondLineWidth = 1.5
                            drawer.DrawReaction(rxn)
                            drawer.FinishDrawing()
                            svg_content = drawer.GetDrawingText()
                            svg_file = f"reaction{i+1}.svg"
                            with open(svg_file, "w") as f:
                                f.write(svg_content)
                            png_file = f"reaction_{i+1}.png"
                            cairosvg.svg2png(url=svg_file, write_to=png_file)

                        # NOTE(review): if rxn is falsy on the FIRST iteration,
                        # png_file is unbound here — confirm inputs are always
                        # parseable reaction SMILES.
                        components.append(gr.Textbox(value=smiles,label= f"SMILES of Reaction {i}", show_copy_button=True, interactive=False))
                        components.append(gr.Image(value=png_file,label= f"RDKit Image of Reaction {i}"))
                    return components  # list of SMILES Textbox / Image components

            download_json = gr.File(label="Download JSON File")


    gr.Examples(
        examples=[
            ["examples/reaction1.jpg", ""],
            ["examples/reaction2.png", ""],
            ["examples/reaction3.png", ""],


        ["examples/reaction4.png", ""],
        ],
        inputs=[image_input, api_key_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )

    # Bind clear and run actions.
    clear_btn.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]
    )
    run_btn.click(
        process_chem_image,
        inputs=[api_key_input, image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
    )

    # Custom button styling.
    demo.css = """
    #submit-btn {
        background-color: #FF914D;
        color: white;
        font-weight: bold;
    }
    """

    demo.launch()
chemiener/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .interface import ChemNER
chemiener/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
chemiener/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes). View file
 
chemiener/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
chemiener/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (5.35 kB). View file
 
chemiener/__pycache__/interface.cpython-310.pyc ADDED
Binary file (4.46 kB). View file
 
chemiener/__pycache__/interface.cpython-38.pyc ADDED
Binary file (4.47 kB). View file
 
chemiener/__pycache__/model.cpython-310.pyc ADDED
Binary file (684 Bytes). View file
 
chemiener/__pycache__/model.cpython-38.pyc ADDED
Binary file (680 Bytes). View file
 
chemiener/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.67 kB). View file
 
chemiener/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
chemiener/dataset.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import copy
4
+ import random
5
+ import json
6
+ import contextlib
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import DataLoader, Dataset
12
+ from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
13
+
14
+ from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast
15
+
16
+ from .utils import get_class_to_index
17
+
18
+
19
+
20
class NERDataset(Dataset):
    """Token-classification dataset for chemical named-entity recognition.

    Loads a JSON file mapping string indices to ``{"text", "entities"}``
    records, tokenizes with the configured HuggingFace tokenizer, and
    produces per-token BIO class labels aligned to the tokenization.
    """

    def __init__(self, args, data_file, split='train'):
        super().__init__()
        self.args = args
        # data_file may be None when the instance is used only for its
        # tokenizer / label-alignment utilities (see ChemNER interface).
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path) as f:
                self.data = json.load(f)
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.roberta_checkpoint, cache_dir = self.args.cache_dir)#BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased')
        self.class_to_index = get_class_to_index(self.args.corpus)
        # Inverse mapping: class index -> tag name.
        self.index_to_class = {self.class_to_index[key]: key for key in self.class_to_index}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Tokenize twice: truncated (model input) and untruncated (labels
        # over the full text, e.g. for evaluation).
        text_tokenized = self.tokenizer(self.data[str(idx)]['text'], truncation = True, max_length = self.args.max_seq_length)
        if len(text_tokenized['input_ids']) > 512: print(len(text_tokenized['input_ids']))
        text_tokenized_untruncated = self.tokenizer(self.data[str(idx)]['text'])
        return text_tokenized, self.align_labels(text_tokenized, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text'])), self.align_labels(text_tokenized_untruncated, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text']))

    def align_labels(self, text_tokenized, entities, length):
        """Map character-level entity spans to per-token BIO class indices.

        Characters inside a span get B-/I- prefixed classes; characters
        outside any entity get class 0; tokens with no character span
        (special tokens) get -100 so the loss ignores them.
        """
        char_to_class = {}

        for entity in entities:
            for span in entities[entity]["span"]:
                for i in range(span[0], span[1]):
                    # First character of the span is "B-", the rest "I-".
                    char_to_class[i] = self.class_to_index[('B-' if i == span[0] else 'I-')+str(entities[entity]["type"])]

        for i in range(length):
            if i not in char_to_class:
                char_to_class[i] = 0

        classes = []
        for i in range(len(text_tokenized[0])):
            span = text_tokenized.token_to_chars(i)
            if span is not None:
                # Label the token by its first character's class.
                classes.append(char_to_class[span.start])
            else:
                # Special token (CLS/SEP/PAD): ignored by the loss.
                classes.append(-100)

        return torch.LongTensor(classes)
67
+
68
def make_html(self, word_tokens, predictions):
    """Render token-level NER predictions as a color-coded HTML page.

    NOTE(review): the original definition omitted ``self`` even though the
    body reads ``self.tokenizer`` and ``self.index_to_class`` (both set in
    NERDataset.__init__), so it could never run as written; restored here
    as an instance method — confirm against callers.

    Parameters
    ----------
    word_tokens : sequence
        Token ids decodable by ``self.tokenizer.decode``.
    predictions : sequence of int
        Per-token class indices (0 = outside any entity).

    Returns
    -------
    str
        A complete HTML document with each entity underlined and colored
        by its class.
    """
    toreturn = '''<!DOCTYPE html>
<html>
<head>
<title>Named Entity Recognition Visualization</title>
<style>
.EXAMPLE_LABEL {
color: red;
text-decoration: underline red;
}
.REACTION_PRODUCT {
color: orange;
text-decoration: underline orange;
}
.STARTING_MATERIAL {
color: gold;
text-decoration: underline gold;
}
.REAGENT_CATALYST {
color: green;
text-decoration: underline green;
}
.SOLVENT {
color: cyan;
text-decoration: underline cyan;
}
.OTHER_COMPOUND {
color: blue;
text-decoration: underline blue;
}
.TIME {
color: purple;
text-decoration: underline purple;
}
.TEMPERATURE {
color: magenta;
text-decoration: underline magenta;
}
.YIELD_OTHER {
color: palegreen;
text-decoration: underline palegreen;
}
.YIELD_PERCENT {
color: pink;
text-decoration: underline pink;
}
</style>
</head>
<body>
<p>'''
    last_label = None
    for idx, item in enumerate(word_tokens):
        decoded = self.tokenizer.decode(item, skip_special_tokens = True)
        if len(decoded)>0:
            # WordPiece continuation tokens start with '#': join them to the
            # previous token without a space and strip the '##' prefix.
            if idx!=0 and decoded[0]!='#':
                toreturn+=" "
            label = predictions[idx]
            if label == last_label:
                # Same entity (or same non-entity run): keep appending text.
                toreturn+=decoded if decoded[0]!="#" else decoded[2:]
            else:
                # Close the previous entity span before opening a new one.
                if last_label is not None and last_label>0:
                    toreturn+="</u>"
                if label >0:
                    toreturn+="<u class=\""
                    toreturn+=self.index_to_class[label]
                    toreturn+="\">"
                    toreturn+=decoded if decoded[0]!="#" else decoded[2:]
                if label == 0:
                    toreturn+=decoded if decoded[0]!="#" else decoded[2:]
            # Bug fix: the original tested idx == len(word_tokens), which is
            # never true inside the loop, so a trailing entity was left with
            # an unclosed <u> tag. Compare against the last index instead.
            if idx==len(word_tokens)-1 and label>0:
                toreturn+="</u>"
            last_label = label

    toreturn += ''' </p>
</body>
</html>'''
    return toreturn
147
+
148
+
149
def get_collate_fn():
    """Build a DataLoader collate function for NERDataset triples.

    Each batch element is ``(tokenizer_output, label_tensor, ...)``; the
    returned callable pads input ids and attention masks with 0 and the
    label tensors with -100 (the index ignored by the loss).
    """
    def collate(batch):
        token_ids = [torch.LongTensor(example[0]['input_ids']) for example in batch]
        attention = [torch.Tensor(example[0]['attention_mask']) for example in batch]
        labels = [example[1] for example in batch]

        padded_ids = pad_sequence(token_ids, batch_first=True, padding_value=0)
        padded_masks = pad_sequence(attention, batch_first=True, padding_value=0)
        padded_refs = pad_sequence(labels, batch_first=True, padding_value=-100)
        return padded_ids, padded_masks, padded_refs

    return collate
170
+
171
+
172
+
chemiener/interface.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from typing import List
4
+ import torch
5
+ import numpy as np
6
+
7
+ from .model import build_model
8
+
9
+ from .dataset import NERDataset, get_collate_fn
10
+
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from .utils import get_class_to_index
14
+
15
class ChemNER:
    """Named-entity recognition over chemistry text.

    Wraps a token-classification model loaded from a Lightning-style
    checkpoint together with the tokenizer/collator used by NERDataset,
    and exposes batched prediction over raw strings.
    """

    def __init__(self, model_path, device=None, cache_dir=None):
        """
        Args:
            model_path: path to a checkpoint containing a 'state_dict' entry.
            device: torch.device to run inference on; defaults to CPU.
            cache_dir: optional HuggingFace cache directory.
        """
        self.args = self._get_args(cache_dir)
        # Load weights onto CPU first; the model is moved to `device` below.
        states = torch.load(model_path, map_location=torch.device('cpu'))
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        # The dataset object is used only for its tokenizer (no data file).
        self.dataset = NERDataset(self.args, data_file=None)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {v: k for k, v in self.class_to_index.items()}

    def _get_args(self, cache_dir):
        """Build the default argument namespace for model/dataset construction."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--roberta_checkpoint', default='dmis-lab/biobert-large-cased-v1.1',
                            type=str, help='which roberta config to use')
        parser.add_argument('--corpus', default="chemdner", type=str,
                            help="which corpus should the tags be from")
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args

    def get_model(self, args, device, model_states):
        """Instantiate the model, load checkpoint weights (stripping the
        Lightning 'model.' prefix), and return it in eval mode on `device`."""
        model = build_model(args)

        def remove_prefix(state_dict):
            return {k.replace('model.', ''): v for k, v in state_dict.items()}

        # strict=False: tolerate missing/unexpected keys in the checkpoint.
        model.load_state_dict(remove_prefix(model_states), strict=False)
        model.to(device)
        model.eval()
        return model

    def predict_strings(self, strings: List, batch_size=8):
        """Run NER over a list of raw strings.

        Returns, for each input string, a list of
        (entity_class, [char_start, char_end]) spans; consecutive B-/I- tags
        of the same class are merged into a single span.
        """
        device = self.device

        def prepare_output(char_span, prediction):
            # Merge BIO tags into (class, [char_start, char_end]) spans.
            spans = []
            for i in range(len(char_span)):
                if prediction[i][0] == 'B':
                    spans.append((prediction[i][2:],
                                  [char_span[i].start, char_span[i].end]))
                elif len(spans) > 0 and prediction[i][2:] == spans[-1][0]:
                    # I- tag continuing the previous entity: extend its end.
                    spans[-1] = (spans[-1][0], [spans[-1][1][0], char_span[i].end])
            return spans

        # NOTE: the original also initialized a module-level `predictions = []`
        # here that was immediately shadowed inside the loop; removed as dead.
        output = []
        for idx in range(0, len(strings), batch_size):
            batch_strings = strings[idx:idx + batch_size]
            # Dummy label tensors (-1) satisfy the collate function's format.
            batch_strings_tokenized = [
                (self.dataset.tokenizer(s, truncation=True, max_length=512),
                 torch.Tensor([-1]), torch.Tensor([-1]))
                for s in batch_strings
            ]

            sentences, masks, refs = self.collate(batch_strings_tokenized)

            predictions = self.model(
                input_ids=sentences.to(device),
                attention_mask=masks.to(device))[0].argmax(dim=2).to('cpu')

            sentences_list = list(sentences)
            predictions_list = list(predictions)

            # Character spans for every token that decodes to visible text
            # (special/padding tokens decode to '' and are skipped).
            char_spans = []
            for j, sentence in enumerate(sentences_list):
                to_add = [batch_strings_tokenized[j][0].token_to_chars(i)
                          for i, word in enumerate(sentence)
                          if len(self.dataset.tokenizer.decode(
                              int(word.item()), skip_special_tokens=True)) > 0]
                char_spans.append(to_add)

            class_predictions = [
                [self.index_to_class[int(pred.item())]
                 for (pred, word) in zip(sentence_p, sentence_w)
                 if len(self.dataset.tokenizer.decode(
                     int(word.item()), skip_special_tokens=True)) > 0]
                for (sentence_p, sentence_w) in zip(predictions_list, sentences_list)
            ]

            output += [prepare_output(char_span, prediction)
                       for char_span, prediction in zip(char_spans, class_predictions)]

        return output
120
+
121
+
122
+
123
+
124
+
chemiener/main.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import json
4
+ import random
5
+ import argparse
6
+ import numpy as np
7
+
8
+ import time
9
+
10
+ import torch
11
+ from torch.profiler import profile, record_function, ProfilerActivity
12
+ import torch.distributed as dist
13
+ import pytorch_lightning as pl
14
+ from pytorch_lightning import LightningModule, LightningDataModule
15
+ from pytorch_lightning.callbacks import LearningRateMonitor
16
+ from pytorch_lightning.strategies.ddp import DDPStrategy
17
+ from transformers import get_scheduler
18
+ import transformers
19
+
20
+ from dataset import NERDataset, get_collate_fn
21
+
22
+ from model import build_model
23
+
24
+ from utils import get_class_to_index
25
+
26
+ import evaluate
27
+
28
+ from seqeval.metrics import accuracy_score
29
+ from seqeval.metrics import classification_report
30
+ from seqeval.metrics import f1_score
31
+ from seqeval.scheme import IOB2
32
+
33
+
34
+
35
def get_args(notebook=False):
    """Parse command-line arguments; with notebook=True, return the defaults."""
    p = argparse.ArgumentParser()

    # Run modes / general
    p.add_argument('--do_train', action='store_true')
    p.add_argument('--do_valid', action='store_true')
    p.add_argument('--do_test', action='store_true')
    p.add_argument('--fp16', action='store_true')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--gpus', type=int, default=1)
    p.add_argument('--print_freq', type=int, default=200)
    p.add_argument('--debug', action='store_true')
    p.add_argument('--no_eval', action='store_true')

    # Data
    p.add_argument('--data_path', type=str, default=None)
    p.add_argument('--image_path', type=str, default=None)
    p.add_argument('--train_file', type=str, default=None)
    p.add_argument('--valid_file', type=str, default=None)
    p.add_argument('--test_file', type=str, default=None)
    p.add_argument('--vocab_file', type=str, default=None)
    p.add_argument('--format', type=str, default='reaction')
    p.add_argument('--num_workers', type=int, default=8)
    p.add_argument('--input_size', type=int, default=224)

    # Training
    p.add_argument('--epochs', type=int, default=8)
    p.add_argument('--batch_size', type=int, default=256)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--weight_decay', type=float, default=0.05)
    p.add_argument('--max_grad_norm', type=float, default=5.)
    p.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
    p.add_argument('--warmup_ratio', type=float, default=0)
    p.add_argument('--gradient_accumulation_steps', type=int, default=1)
    p.add_argument('--load_path', type=str, default=None)
    p.add_argument('--load_encoder_only', action='store_true')
    p.add_argument('--train_steps_per_epoch', type=int, default=-1)
    p.add_argument('--eval_per_epoch', type=int, default=10)
    p.add_argument('--save_path', type=str, default='output/')
    p.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
    p.add_argument('--load_ckpt', type=str, default='best')
    p.add_argument('--resume', action='store_true')
    p.add_argument('--num_train_example', type=int, default=None)

    # Model / corpus
    p.add_argument('--roberta_checkpoint', type=str, default="roberta-base")
    p.add_argument('--corpus', type=str, default="chemu")
    p.add_argument('--cache_dir')
    p.add_argument('--eval_truncated', action='store_true')
    p.add_argument('--max_seq_length', type=int, default=512)

    # A notebook has no meaningful sys.argv; fall back to the defaults.
    return p.parse_args([]) if notebook else p.parse_args()
95
+
96
+
97
class ChemIENERecognizer(LightningModule):
    """Lightning wrapper around the token-classification model.

    Implements the training/validation loop and seqeval-based evaluation of
    the predicted BIO tag sequences.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = build_model(args)
        # (sentences, argmax predictions, refs, untruncated refs) per batch,
        # collected on CPU and consumed in on_validation_epoch_end.
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        # Batch layout from the collate function: token ids, attention masks,
        # label ids (-100 = ignored), untruncated labels (unused in training).
        sentences, masks, refs, _ = batch
        loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs)
        self.log('train/loss', loss)
        self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        sentences, masks, refs, untruncated = batch
        logits = self.model(input_ids=sentences, attention_mask=masks)[0]
        # Move everything to CPU so GPU memory stays flat across the epoch.
        self.validation_step_outputs.append(
            (sentences.to("cpu"), logits.argmax(dim=2).to("cpu"),
             refs.to('cpu'), untruncated.to("cpu")))

    def on_validation_epoch_end(self):
        # Gather per-batch outputs from every rank when running distributed.
        if self.trainer.num_devices > 1:
            gathered_outputs = [None for _ in range(self.trainer.num_devices)]
            dist.all_gather_object(gathered_outputs, self.validation_step_outputs)
            gathered_outputs = sum(gathered_outputs, [])
        else:
            gathered_outputs = self.validation_step_outputs

        sentences = [list(output[0]) for output in gathered_outputs]
        class_to_index = get_class_to_index(self.args.corpus)
        index_to_class = {class_to_index[key]: key for key in class_to_index}
        predictions = [list(output[1]) for output in gathered_outputs]
        labels = [list(output[2]) for output in gathered_outputs]
        untruncateds = [list(output[3]) for output in gathered_outputs]

        # Full (untruncated) reference tag sequences; -100 positions dropped.
        untruncateds = [[index_to_class[int(label.item())] for label in sentence
                         if int(label.item()) != -100]
                        for batched in untruncateds for sentence in batched]

        # Flatten to per-sentence token ids / predicted tags / reference tags,
        # keeping only positions that carry a real label (label != -100).
        output = {
            "sentences": [[int(word.item()) for (word, label) in zip(sentence_w, sentence_l)
                           if label != -100]
                          for (batched_w, batched_l) in zip(sentences, labels)
                          for (sentence_w, sentence_l) in zip(batched_w, batched_l)],
            "predictions": [[index_to_class[int(pred.item())]
                             for (pred, label) in zip(sentence_p, sentence_l)
                             if label != -100]
                            for (batched_p, batched_l) in zip(predictions, labels)
                            for (sentence_p, sentence_l) in zip(batched_p, batched_l)],
            "groundtruth": [[index_to_class[int(label.item())] for label in sentence
                             if label != -100]
                            for batched in labels for sentence in batched]}

        name = self.eval_dataset.name
        scores = [0]

        if self.trainer.is_global_zero:
            if not self.args.no_eval:
                epoch = self.trainer.current_epoch

                metric = evaluate.load("seqeval", cache_dir=self.args.cache_dir)

                # Pad truncated predictions with 'O' so they align with the
                # untruncated references.
                predictions = [preds + ['O'] * (len(full_groundtruth) - len(preds))
                               for (preds, full_groundtruth)
                               in zip(output['predictions'], untruncateds)]
                # NOTE(review): result currently unused; kept for parity.
                all_metrics = metric.compute(predictions=predictions,
                                             references=untruncateds)

                if self.args.eval_truncated:
                    report = classification_report(output['groundtruth'],
                                                   output['predictions'],
                                                   mode='strict', scheme=IOB2,
                                                   output_dict=True)
                else:
                    report = classification_report(predictions, untruncateds,
                                                   mode='strict', scheme=IOB2,
                                                   output_dict=True)
                self.print(report)
                scores = [report['micro avg']['f1-score']]
                with open(os.path.join(self.trainer.default_root_dir,
                                       f'prediction_{name}.json'), 'w') as f:
                    json.dump(output, f)

        # Broadcast is a collective op: only issue it when a process group is
        # actually in use (multi-device); the original called it
        # unconditionally, which fails without an initialized group.
        if self.trainer.num_devices > 1:
            dist.broadcast_object_list(scores)

        self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True)
        # Reset for the next epoch (the original cleared twice; once suffices).
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW with an HF scheduler stepped once per optimizer step."""
        num_training_steps = self.trainer.num_training_steps
        self.print(f'Num training steps: {num_training_steps}')
        num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr,
                                      weight_decay=self.args.weight_decay)
        scheduler = get_scheduler(self.args.scheduler, optimizer,
                                  num_warmup_steps, num_training_steps)
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
237
+
238
class NERDataModule(LightningDataModule):
    """Builds the train/valid/test NERDataset splits and serves them through
    DataLoaders that share one collate function."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.collate_fn = get_collate_fn()

    def prepare_data(self):
        """Construct only the splits the configured run modes require."""
        args = self.args
        if args.do_train:
            self.train_dataset = NERDataset(args, args.train_file, split='train')
        if args.do_train or args.do_valid:
            self.val_dataset = NERDataset(args, args.valid_file, split='valid')
        if args.do_test:
            self.test_dataset = NERDataset(args, args.test_file, split='valid')

    def print_stats(self):
        """Print the size of each split that was constructed."""
        if self.args.do_train:
            print(f'Train dataset: {len(self.train_dataset)}')
        if self.args.do_train or self.args.do_valid:
            print(f'Valid dataset: {len(self.val_dataset)}')
        if self.args.do_test:
            print(f'Test dataset: {len(self.test_dataset)}')

    def _make_loader(self, dataset):
        # Every split shares batch size, worker count, and collate function.
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.args.batch_size,
            num_workers=self.args.num_workers, collate_fn=self.collate_fn)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)
277
+ collate_fn=self.collate_fn)
278
+
279
+
280
+
281
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback that skips metric-value interpolation and returns
    the plain formatted checkpoint name, so the same file is reused."""

    def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
        return self.format_checkpoint_name(monitor_candidates)
285
+
286
def main():
    """Entry point: build model + data, then train/validate/test per flags."""
    transformers.utils.logging.set_verbosity_error()
    args = get_args()

    pl.seed_everything(args.seed, workers=True)

    # Fresh model for training; otherwise restore the best saved checkpoint.
    if args.do_train:
        model = ChemIENERecognizer(args)
    else:
        best_ckpt = os.path.join(args.save_path, 'checkpoints/best.ckpt')
        model = ChemIENERecognizer.load_from_checkpoint(best_ckpt, strict=False,
                                                        args=args)

    dm = NERDataModule(args)
    dm.prepare_data()
    dm.print_stats()

    checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1,
                                 filename='best', save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')

    trainer = pl.Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator='gpu',
        precision=16,
        devices=args.gpus,
        logger=logger,
        default_root_dir=args.save_path,
        callbacks=[checkpoint, lr_monitor],
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        check_val_every_n_epoch=args.eval_per_epoch,
        log_every_n_steps=10,
        deterministic='warn')

    if args.do_train:
        # Optimizer steps over the whole run: batches per device per epoch,
        # divided across devices and accumulation, times epochs.
        trainer.num_training_steps = math.ceil(
            len(dm.train_dataset)
            / (args.batch_size * args.gpus * args.gradient_accumulation_steps)
        ) * args.epochs
        model.eval_dataset = dm.val_dataset
        ckpt_path = (os.path.join(args.save_path, 'checkpoints/last.ckpt')
                     if args.resume else None)
        trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
        model = ChemIENERecognizer.load_from_checkpoint(
            checkpoint.best_model_path, args=args)

    if args.do_valid:
        model.eval_dataset = dm.val_dataset
        trainer.validate(model, datamodule=dm)

    if args.do_test:
        model.test_dataset = dm.test_dataset
        trainer.test(model, datamodule=dm)


if __name__ == "__main__":
    main()
345
+
chemiener/model.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification
6
+
7
+
8
def build_model(args):
    """Construct the token-classification model for the configured corpus.

    The label count must match get_class_to_index() for the same corpus
    (chemu: 21, chemdner: 17, chemdner-mol: 3).

    Raises:
        ValueError: for an unknown corpus (the original fell through and
        silently returned None, which crashed later at the call site).
    """
    corpus_num_labels = {
        "chemu": 21,
        "chemdner": 17,
        "chemdner-mol": 3,
    }
    try:
        num_labels = corpus_num_labels[args.corpus]
    except KeyError:
        raise ValueError(f"Unknown corpus: {args.corpus!r}") from None
    return AutoModelForTokenClassification.from_pretrained(
        args.roberta_checkpoint, num_labels=num_labels,
        cache_dir=args.cache_dir, return_dict=False)
chemiener/utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
def merge_predictions(results):
    """Flatten per-batch prediction lists into a single ordered list.

    Bug fix: the original populated a dict keyed by the *within-batch* index
    (``enumerate`` restarts at 0 for every batch), so each batch overwrote
    the previous one and only the final batch survived. Concatenating the
    batches keeps every prediction in order. Returns [] for empty input.
    """
    merged = []
    for batch_preds in results:
        merged.extend(batch_preds)
    return merged
12
+
13
def get_class_to_index(corpus):
    """Return the BIO tag -> label-index mapping for `corpus`.

    Label counts must agree with build_model(): chemu -> 21 labels,
    chemdner -> 17, chemdner-mol -> 3.

    Raises:
        ValueError: for an unknown corpus (the original silently returned
        None, deferring the failure to the caller).
    """
    if corpus == "chemu":
        return {'B-EXAMPLE_LABEL': 1, 'B-REACTION_PRODUCT': 2, 'B-STARTING_MATERIAL': 3, 'B-REAGENT_CATALYST': 4, 'B-SOLVENT': 5, 'B-OTHER_COMPOUND': 6, 'B-TIME': 7, 'B-TEMPERATURE': 8, 'B-YIELD_OTHER': 9, 'B-YIELD_PERCENT': 10, 'O': 0,
                'I-EXAMPLE_LABEL': 11, 'I-REACTION_PRODUCT': 12, 'I-STARTING_MATERIAL': 13, 'I-REAGENT_CATALYST': 14, 'I-SOLVENT': 15, 'I-OTHER_COMPOUND': 16, 'I-TIME': 17, 'I-TEMPERATURE': 18, 'I-YIELD_OTHER': 19, 'I-YIELD_PERCENT': 20}
    elif corpus == "chemdner":
        return {'O': 0, 'B-ABBREVIATION': 1, 'B-FAMILY': 2, 'B-FORMULA': 3, 'B-IDENTIFIER': 4, 'B-MULTIPLE': 5, 'B-SYSTEMATIC': 6, 'B-TRIVIAL': 7, 'B-NO CLASS': 8, 'I-ABBREVIATION': 9, 'I-FAMILY': 10, 'I-FORMULA': 11, 'I-IDENTIFIER': 12, 'I-MULTIPLE': 13, 'I-SYSTEMATIC': 14, 'I-TRIVIAL': 15, 'I-NO CLASS': 16}
    elif corpus == "chemdner-mol":
        return {'O': 0, 'B-MOL': 1, 'I-MOL': 2}
    raise ValueError(f"Unknown corpus: {corpus!r}")
21
+
22
+
23
+
chemietoolkit/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .interface import ChemIEToolkit
chemietoolkit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
chemietoolkit/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes). View file
 
chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
chemietoolkit/__pycache__/interface.cpython-310.pyc ADDED
Binary file (29.3 kB). View file
 
chemietoolkit/__pycache__/interface.cpython-38.pyc ADDED
Binary file (30 kB). View file
 
chemietoolkit/__pycache__/tableextractor.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
chemietoolkit/__pycache__/utils.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
chemietoolkit/chemrxnextractor.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PyPDF2 import PdfReader, PdfWriter
2
+ import pdfminer.high_level
3
+ import pdfminer.layout
4
+ from operator import itemgetter
5
+ import os
6
+ import pdftotext
7
+ from chemrxnextractor import RxnExtractor
8
+
9
class ChemRxnExtractor(object):
    """Extract chemical reactions from PDF text using ChemRxnExtractor models.

    The PDF is converted to text with pdftotext at construction time;
    sentences are then segmented heuristically and passed to RxnExtractor
    page by page.
    """

    def __init__(self, pdf, pn, model_dir, device):
        # pdf: path to a PDF file; '' defers loading until set_pdf_file.
        # pn: number of pages to process (None = all pages).
        # model_dir: root directory containing the cre_models_v0.1 bundle.
        # device: 'cuda' enables GPU inference; anything else means CPU.
        self.pdf_file = pdf
        self.pages = pn
        self.model_dir = os.path.join(model_dir, "cre_models_v0.1") # directory saving both prod and role models
        use_cuda = (device == 'cuda')
        self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=use_cuda)
        self.text_file = "info.txt"
        self.pdf_text = ""
        if len(self.pdf_file) > 0:
            with open(self.pdf_file, "rb") as f:
                self.pdf_text = pdftotext.PDF(f)

    def set_pdf_file(self, pdf):
        # Point at a new PDF and re-extract its text immediately.
        self.pdf_file = pdf
        with open(self.pdf_file, "rb") as f:
            self.pdf_text = pdftotext.PDF(f)

    def set_pages(self, pn):
        self.pages = pn

    def set_model_dir(self, md):
        # NOTE(review): unlike __init__, this rebuilds RxnExtractor without
        # use_cuda, so the new extractor uses its default device — confirm
        # whether that is intended.
        self.model_dir = md
        self.rxn_extractor = RxnExtractor(self.model_dir)

    def set_text_file(self, tf):
        self.text_file = tf

    def extract_reactions_from_text(self):
        """Run extraction over self.pages pages, or all pages when None."""
        if self.pages is None:
            return self.extract_all(len(self.pdf_text))
        else:
            return self.extract_all(self.pages)

    def extract_all(self, pages):
        """Extract reactions from the first `pages` pages; one dict per page."""
        ans = []
        text = self.get_paragraphs_from_pdf(pages)
        for data in text:
            # Flatten the page's paragraphs into a single list of sentences.
            L = [sent for paragraph in data['paragraphs'] for sent in paragraph]
            reactions = self.get_reactions(L, page_number=data['page'])
            ans.append(reactions)
        return ans

    def get_reactions(self, sents, page_number=None):
        """Run RxnExtractor on `sents`, keeping only sentences that yielded
        at least one reaction; returns {'page': ..., 'reactions': [...]}."""
        rxns = self.rxn_extractor.get_reactions(sents)

        ret = []
        for r in rxns:
            if len(r['reactions']) != 0: ret.append(r)
        ans = {}
        ans.update({'page' : page_number})
        ans.update({'reactions' : ret})
        return ans


    def get_paragraphs_from_pdf(self, pages):
        """Split each page's text into paragraphs and sentences.

        Heuristic splitting: a '.' ends a sentence when it is not preceded by
        a digit, or when followed by whitespace; hyphenated line breaks
        ("- ") are rejoined first. Returns a list of
        {'paragraphs': [[sentence, ...], ...], 'page': int} dicts.
        """
        current_page_num = 1
        if pages is None:
            pages = len(self.pdf_text)
        result = []
        for page in range(pages):
            content = self.pdf_text[page]
            pg = content.split("\n\n")  # blank line = paragraph boundary
            L = []
            for line in pg:
                paragraph = []
                if '\x0c' in line:
                    # Skip form-feed page-break artifacts.
                    continue
                text = line
                text = text.replace("\n", " ")
                text = text.replace("- ", "-")  # undo end-of-line hyphenation
                curind = 0
                i = 0
                while i < len(text):
                    if text[i] == '.':
                        # Sentence boundary unless the '.' is part of a
                        # decimal number with no following space.
                        if i != 0 and not text[i-1].isdigit() or i != len(text) - 1 and (text[i+1] == " " or text[i+1] == "\n"):
                            paragraph.append(text[curind:i+1] + "\n")
                            # Advance past any attached trailing characters
                            # to the next word boundary.
                            while(i < len(text) and text[i] != " "):
                                i += 1
                            curind = i + 1
                    i += 1
                if curind != i:
                    # Flush the trailing fragment: trim one trailing space and
                    # make sure the fragment ends with a period.
                    if text[i - 1] == " ":
                        if i != 1:
                            i -= 1
                        else:
                            break
                    if text[i - 1] != '.':
                        paragraph.append(text[curind:i] + ".\n")
                    else:
                        paragraph.append(text[curind:i] + "\n")
                L.append(paragraph)

            result.append({
                'paragraphs': L,
                'page': current_page_num
            })
            current_page_num += 1
        return result
chemietoolkit/interface.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ from functools import lru_cache
4
+ import layoutparser as lp
5
+ import pdf2image
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from molscribe import MolScribe
9
+ from rxnscribe import RxnScribe, MolDetect
10
+ from chemiener import ChemNER
11
+ from .chemrxnextractor import ChemRxnExtractor
12
+ from .tableextractor import TableExtractor
13
+ from .utils import *
14
+
15
+ class ChemIEToolkit:
16
+ def __init__(self, device=None):
17
+ if device is None:
18
+ self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
19
+ else:
20
+ self.device = torch.device(device)
21
+
22
+ self._molscribe = None
23
+ self._rxnscribe = None
24
+ self._pdfparser = None
25
+ self._moldet = None
26
+ self._chemrxnextractor = None
27
+ self._chemner = None
28
+ self._coref = None
29
+
30
+ @property
31
+ def molscribe(self):
32
+ if self._molscribe is None:
33
+ self.init_molscribe()
34
+ return self._molscribe
35
+
36
+ @lru_cache(maxsize=None)
37
+ def init_molscribe(self, ckpt_path=None):
38
+ """
39
+ Set model to custom checkpoint
40
+ Parameters:
41
+ ckpt_path: path to checkpoint to use, if None then will use default
42
+ """
43
+ if ckpt_path is None:
44
+ ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth")
45
+ self._molscribe = MolScribe(ckpt_path, device=self.device)
46
+
47
+
48
+ @property
49
+ def rxnscribe(self):
50
+ if self._rxnscribe is None:
51
+ self.init_rxnscribe()
52
+ return self._rxnscribe
53
+
54
+ @lru_cache(maxsize=None)
55
+ def init_rxnscribe(self, ckpt_path=None):
56
+ """
57
+ Set model to custom checkpoint
58
+ Parameters:
59
+ ckpt_path: path to checkpoint to use, if None then will use default
60
+ """
61
+ if ckpt_path is None:
62
+ ckpt_path = hf_hub_download("yujieq/RxnScribe", "pix2seq_reaction_full.ckpt")
63
+ self._rxnscribe = RxnScribe(ckpt_path, device=self.device)
64
+
65
+
66
+ @property
67
+ def pdfparser(self):
68
+ if self._pdfparser is None:
69
+ self.init_pdfparser()
70
+ return self._pdfparser
71
+
72
+ @lru_cache(maxsize=None)
73
+ def init_pdfparser(self, ckpt_path=None):
74
+ """
75
+ Set model to custom checkpoint
76
+ Parameters:
77
+ ckpt_path: path to checkpoint to use, if None then will use default
78
+ """
79
+ config_path = "lp://efficientdet/PubLayNet/tf_efficientdet_d1"
80
+ self._pdfparser = lp.AutoLayoutModel(config_path, model_path=ckpt_path, device=self.device.type)
81
+
82
+
83
+ @property
84
+ def moldet(self):
85
+ if self._moldet is None:
86
+ self.init_moldet()
87
+ return self._moldet
88
+
89
+ @lru_cache(maxsize=None)
90
+ def init_moldet(self, ckpt_path=None):
91
+ """
92
+ Set model to custom checkpoint
93
+ Parameters:
94
+ ckpt_path: path to checkpoint to use, if None then will use default
95
+ """
96
+ if ckpt_path is None:
97
+ ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "best_hf.ckpt")
98
+ self._moldet = MolDetect(ckpt_path, device=self.device)
99
+
100
+
101
+ @property
102
+ def coref(self):
103
+ if self._coref is None:
104
+ self.init_coref()
105
+ return self._coref
106
+
107
+ @lru_cache(maxsize=None)
108
+ def init_coref(self, ckpt_path=None):
109
+ """
110
+ Set model to custom checkpoint
111
+ Parameters:
112
+ ckpt_path: path to checkpoint to use, if None then will use default
113
+ """
114
+ if ckpt_path is None:
115
+ ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "coref_best_hf.ckpt")
116
+ self._coref = MolDetect(ckpt_path, device=self.device, coref=True)
117
+
118
+
119
+ @property
120
+ def chemrxnextractor(self):
121
+ if self._chemrxnextractor is None:
122
+ self.init_chemrxnextractor()
123
+ return self._chemrxnextractor
124
+
125
+ @lru_cache(maxsize=None)
126
+ def init_chemrxnextractor(self, ckpt_path=None):
127
+ """
128
+ Set model to custom checkpoint
129
+ Parameters:
130
+ ckpt_path: path to checkpoint to use, if None then will use default
131
+ """
132
+ if ckpt_path is None:
133
+ ckpt_path = snapshot_download(repo_id="amberwang/chemrxnextractor-training-modules")
134
+ self._chemrxnextractor = ChemRxnExtractor("", None, ckpt_path, self.device.type)
135
+
136
+
137
+ @property
138
+ def chemner(self):
139
+ if self._chemner is None:
140
+ self.init_chemner()
141
+ return self._chemner
142
+
143
+ @lru_cache(maxsize=None)
144
+ def init_chemner(self, ckpt_path=None):
145
+ """
146
+ Set model to custom checkpoint
147
+ Parameters:
148
+ ckpt_path: path to checkpoint to use, if None then will use default
149
+ """
150
+ if ckpt_path is None:
151
+ ckpt_path = hf_hub_download("Ozymandias314/ChemNERckpt", "best.ckpt")
152
+ self._chemner = ChemNER(ckpt_path, device=self.device)
153
+
154
+
155
+ @property
156
+ def tableextractor(self):
157
+ return TableExtractor()
158
+
159
+
160
+ def extract_figures_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
161
+ """
162
+ Find and return all figures from a pdf page
163
+ Parameters:
164
+ pdf: path to pdf
165
+ num_pages: process only first `num_pages` pages, if `None` then process all
166
+ output_bbox: whether to output bounding boxes for each individual entry of a table
167
+ output_image: whether to include PIL image for figures. default is True
168
+ Returns:
169
+ list of content in the following format
170
+ [
171
+ { # first figure
172
+ 'title': str,
173
+ 'figure': {
174
+ 'image': PIL image or None,
175
+ 'bbox': list in form [x1, y1, x2, y2],
176
+ }
177
+ 'table': {
178
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
179
+ 'content': {
180
+ 'columns': list of column headers,
181
+ 'rows': list of list of row content,
182
+ } or None
183
+ }
184
+ 'footnote': str or empty,
185
+ 'page': int
186
+ }
187
+ # more figures
188
+ ]
189
+ """
190
+ pages = pdf2image.convert_from_path(pdf, last_page=num_pages)
191
+
192
+ table_ext = self.tableextractor
193
+ table_ext.set_pdf_file(pdf)
194
+ table_ext.set_output_image(output_image)
195
+
196
+ table_ext.set_output_bbox(output_bbox)
197
+
198
+ return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='figures')
199
+
200
+ def extract_tables_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
201
+ """
202
+ Find and return all tables from a pdf page
203
+ Parameters:
204
+ pdf: path to pdf
205
+ num_pages: process only first `num_pages` pages, if `None` then process all
206
+ output_bbox: whether to include bboxes for individual entries of the table
207
+ output_image: whether to include PIL image for figures. default is True
208
+ Returns:
209
+ list of content in the following format
210
+ [
211
+ { # first table
212
+ 'title': str,
213
+ 'figure': {
214
+ 'image': PIL image or None,
215
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
216
+ }
217
+ 'table': {
218
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
219
+ 'content': {
220
+ 'columns': list of column headers,
221
+ 'rows': list of list of row content,
222
+ }
223
+ }
224
+ 'footnote': str or empty,
225
+ 'page': int
226
+ }
227
+ # more tables
228
+ ]
229
+ """
230
+ pages = pdf2image.convert_from_path(pdf, last_page=num_pages)
231
+
232
+ table_ext = self.tableextractor
233
+ table_ext.set_pdf_file(pdf)
234
+ table_ext.set_output_image(output_image)
235
+
236
+ table_ext.set_output_bbox(output_bbox)
237
+
238
+ return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='tables')
239
+
240
+ def extract_molecules_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None):
241
+ """
242
+ Get all molecules and their information from a pdf
243
+ Parameters:
244
+ pdf: path to pdf, or byte file
245
+ batch_size: batch size for inference in all models
246
+ num_pages: process only first `num_pages` pages, if `None` then process all
247
+ Returns:
248
+ list of figures and corresponding molecule info in the following format
249
+ [
250
+ { # first figure
251
+ 'image': ndarray of the figure image,
252
+ 'molecules': [
253
+ { # first molecule
254
+ 'bbox': tuple in the form (x1, y1, x2, y2),
255
+ 'score': float,
256
+ 'image': ndarray of cropped molecule image,
257
+ 'smiles': str,
258
+ 'molfile': str
259
+ },
260
+ # more molecules
261
+ ],
262
+ 'page': int
263
+ },
264
+ # more figures
265
+ ]
266
+ """
267
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
268
+ images = [figure['figure']['image'] for figure in figures]
269
+ results = self.extract_molecules_from_figures(images, batch_size=batch_size)
270
+ for figure, result in zip(figures, results):
271
+ result['page'] = figure['page']
272
+ return results
273
+
274
+ def extract_molecule_bboxes_from_figures(self, figures, batch_size=16):
275
+ """
276
+ Return bounding boxes of molecules in images
277
+ Parameters:
278
+ figures: list of PIL or ndarray images
279
+ batch_size: batch size for inference
280
+ Returns:
281
+ list of results for each figure in the following format
282
+ [
283
+ [ # first figure
284
+ { # first bounding box
285
+ 'category': str,
286
+ 'bbox': tuple in the form (x1, y1, x2, y2),
287
+ 'category_id': int,
288
+ 'score': float
289
+ },
290
+ # more bounding boxes
291
+ ],
292
+ # more figures
293
+ ]
294
+ """
295
+ figures = [convert_to_pil(figure) for figure in figures]
296
+ return self.moldet.predict_images(figures, batch_size=batch_size)
297
+
298
+ def extract_molecules_from_figures(self, figures, batch_size=16):
299
+ """
300
+ Get all molecules and their information from list of figures
301
+ Parameters:
302
+ figures: list of PIL or ndarray images
303
+ batch_size: batch size for inference
304
+ Returns:
305
+ list of results for each figure in the following format
306
+ [
307
+ { # first figure
308
+ 'image': ndarray of the figure image,
309
+ 'molecules': [
310
+ { # first molecule
311
+ 'bbox': tuple in the form (x1, y1, x2, y2),
312
+ 'score': float,
313
+ 'image': ndarray of cropped molecule image,
314
+ 'smiles': str,
315
+ 'molfile': str
316
+ },
317
+ # more molecules
318
+ ],
319
+ },
320
+ # more figures
321
+ ]
322
+ """
323
+ bboxes = self.extract_molecule_bboxes_from_figures(figures, batch_size=batch_size)
324
+ figures = [convert_to_cv2(figure) for figure in figures]
325
+ results, cropped_images, refs = clean_bbox_output(figures, bboxes)
326
+ mol_info = self.molscribe.predict_images(cropped_images, batch_size=batch_size)
327
+ for info, ref in zip(mol_info, refs):
328
+ ref.update(info)
329
+ return results
330
+
331
+ def extract_molecule_corefs_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe = True, ocr = True):
332
+ """
333
+ Get all molecule bboxes and corefs from figures in pdf
334
+ Parameters:
335
+ pdf: path to pdf, or byte file
336
+ batch_size: batch size for inference in all models
337
+ num_pages: process only first `num_pages` pages, if `None` then process all
338
+ Returns:
339
+ list of results for each figure in the following format:
340
+ [
341
+ {
342
+ 'bboxes': [
343
+ { # first bbox
344
+ 'category': '[Sup]',
345
+ 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
346
+ 'category_id': 4,
347
+ 'score': -0.07593922317028046
348
+ },
349
+ # More bounding boxes
350
+ ],
351
+ 'corefs': [
352
+ [0, 1], # molecule bbox index, identifier bbox index
353
+ [3, 4],
354
+ # More coref pairs
355
+ ],
356
+ 'page': int
357
+ },
358
+ # More figures
359
+ ]
360
+ """
361
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
362
+ images = [figure['figure']['image'] for figure in figures]
363
+ results = self.extract_molecule_corefs_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
364
+ for figure, result in zip(figures, results):
365
+ result['page'] = figure['page']
366
+ return results
367
+
368
+ def extract_molecule_corefs_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
369
+ """
370
+ Get all molecule bboxes and corefs from list of figures
371
+ Parameters:
372
+ figures: list of PIL or ndarray images
373
+ batch_size: batch size for inference
374
+ Returns:
375
+ list of results for each figure in the following format:
376
+ [
377
+ {
378
+ 'bboxes': [
379
+ { # first bbox
380
+ 'category': '[Sup]',
381
+ 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
382
+ 'category_id': 4,
383
+ 'score': -0.07593922317028046
384
+ },
385
+ # More bounding boxes
386
+ ],
387
+ 'corefs': [
388
+ [0, 1], # molecule bbox index, identifier bbox index
389
+ [3, 4],
390
+ # More coref pairs
391
+ ],
392
+ },
393
+ # More figures
394
+ ]
395
+ """
396
+ figures = [convert_to_pil(figure) for figure in figures]
397
+ return self.coref.predict_images(figures, batch_size=batch_size, coref=True, molscribe = molscribe, ocr = ocr)
398
+
399
+ def extract_reactions_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
400
+ """
401
+ Get reaction information from figures in pdf
402
+ Parameters:
403
+ pdf: path to pdf, or byte file
404
+ batch_size: batch size for inference in all models
405
+ num_pages: process only first `num_pages` pages, if `None` then process all
406
+ molscribe: whether to predict and return smiles and molfile info
407
+ ocr: whether to predict and return text of conditions
408
+ Returns:
409
+ list of figures and corresponding molecule info in the following format
410
+ [
411
+ {
412
+ 'figure': PIL image
413
+ 'reactions': [
414
+ {
415
+ 'reactants': [
416
+ {
417
+ 'category': str,
418
+ 'bbox': tuple (x1,x2,y1,y2),
419
+ 'category_id': int,
420
+ 'smiles': str,
421
+ 'molfile': str,
422
+ },
423
+ # more reactants
424
+ ],
425
+ 'conditions': [
426
+ {
427
+ 'category': str,
428
+ 'bbox': tuple (x1,x2,y1,y2),
429
+ 'category_id': int,
430
+ 'text': list of str,
431
+ },
432
+ # more conditions
433
+ ],
434
+ 'products': [
435
+ # same structure as reactants
436
+ ]
437
+ },
438
+ # more reactions
439
+ ],
440
+ 'page': int
441
+ },
442
+ # more figures
443
+ ]
444
+ """
445
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
446
+ images = [figure['figure']['image'] for figure in figures]
447
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
448
+ for figure, result in zip(figures, results):
449
+ result['page'] = figure['page']
450
+ return results
451
+
452
+ def extract_reactions_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
453
+ """
454
+ Get reaction information from list of figures
455
+ Parameters:
456
+ figures: list of PIL or ndarray images
457
+ batch_size: batch size for inference in all models
458
+ molscribe: whether to predict and return smiles and molfile info
459
+ ocr: whether to predict and return text of conditions
460
+ Returns:
461
+ list of figures and corresponding molecule info in the following format
462
+ [
463
+ {
464
+ 'figure': PIL image
465
+ 'reactions': [
466
+ {
467
+ 'reactants': [
468
+ {
469
+ 'category': str,
470
+ 'bbox': tuple (x1,x2,y1,y2),
471
+ 'category_id': int,
472
+ 'smiles': str,
473
+ 'molfile': str,
474
+ },
475
+ # more reactants
476
+ ],
477
+ 'conditions': [
478
+ {
479
+ 'category': str,
480
+ 'bbox': tuple (x1,x2,y1,y2),
481
+ 'category_id': int,
482
+ 'text': list of str,
483
+ },
484
+ # more conditions
485
+ ],
486
+ 'products': [
487
+ # same structure as reactants
488
+ ]
489
+ },
490
+ # more reactions
491
+ ],
492
+ },
493
+ # more figures
494
+ ]
495
+
496
+ """
497
+ pil_figures = [convert_to_pil(figure) for figure in figures]
498
+ results = []
499
+ reactions = self.rxnscribe.predict_images(pil_figures, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
500
+ for figure, rxn in zip(figures, reactions):
501
+ data = {
502
+ 'figure': figure,
503
+ 'reactions': rxn,
504
+ }
505
+ results.append(data)
506
+ return results
507
+
508
+ def extract_molecules_from_text_in_pdf(self, pdf, batch_size=16, num_pages=None):
509
+ """
510
+ Get molecules in text of given pdf
511
+
512
+ Parameters:
513
+ pdf: path to pdf, or byte file
514
+ batch_size: batch size for inference in all models
515
+ num_pages: process only first `num_pages` pages, if `None` then process all
516
+ Returns:
517
+ list of sentences and found molecules in the following format
518
+ [
519
+ {
520
+ 'molecules': [
521
+ { # first paragraph
522
+ 'text': str,
523
+ 'labels': [
524
+ (str, int, int), # tuple of label, range start (inclusive), range end (exclusive)
525
+ # more labels
526
+ ]
527
+ },
528
+ # more paragraphs
529
+ ]
530
+ 'page': int
531
+ },
532
+ # more pages
533
+ ]
534
+ """
535
+ self.chemrxnextractor.set_pdf_file(pdf)
536
+ self.chemrxnextractor.set_pages(num_pages)
537
+ text = self.chemrxnextractor.get_paragraphs_from_pdf(num_pages)
538
+ result = []
539
+ for data in text:
540
+ model_inp = []
541
+ for paragraph in data['paragraphs']:
542
+ model_inp.append(' '.join(paragraph).replace('\n', ''))
543
+ output = self.chemner.predict_strings(model_inp, batch_size=batch_size)
544
+ to_add = {
545
+ 'molecules': [{
546
+ 'text': t,
547
+ 'labels': labels,
548
+ } for t, labels in zip(model_inp, output)],
549
+ 'page': data['page']
550
+ }
551
+ result.append(to_add)
552
+ return result
553
+
554
+
555
+ def extract_reactions_from_text_in_pdf(self, pdf, num_pages=None):
556
+ """
557
+ Get reaction information from text in pdf
558
+ Parameters:
559
+ pdf: path to pdf
560
+ num_pages: process only first `num_pages` pages, if `None` then process all
561
+ Returns:
562
+ list of pages and corresponding reaction info in the following format
563
+ [
564
+ {
565
+ 'page': page number
566
+ 'reactions': [
567
+ {
568
+ 'tokens': list of words in relevant sentence,
569
+ 'reactions' : [
570
+ {
571
+ # key, value pairs where key is the label and value is a tuple
572
+ # or list of tuples of the form (tokens, start index, end index)
573
+ # where indices are for the corresponding token list and start and end are inclusive
574
+ }
575
+ # more reactions
576
+ ]
577
+ }
578
+ # more reactions in other sentences
579
+ ]
580
+ },
581
+ # more pages
582
+ ]
583
+ """
584
+ self.chemrxnextractor.set_pdf_file(pdf)
585
+ self.chemrxnextractor.set_pages(num_pages)
586
+ return self.chemrxnextractor.extract_reactions_from_text()
587
+
588
+ def extract_reactions_from_text_in_pdf_combined(self, pdf, num_pages=None):
589
+ """
590
+ Get reaction information from text in pdf and combined with corefs from figures
591
+ Parameters:
592
+ pdf: path to pdf
593
+ num_pages: process only first `num_pages` pages, if `None` then process all
594
+ Returns:
595
+ list of pages and corresponding reaction info in the following format
596
+ [
597
+ {
598
+ 'page': page number
599
+ 'reactions': [
600
+ {
601
+ 'tokens': list of words in relevant sentence,
602
+ 'reactions' : [
603
+ {
604
+ # key, value pairs where key is the label and value is a tuple
605
+ # or list of tuples of the form (tokens, start index, end index)
606
+ # where indices are for the corresponding token list and start and end are inclusive
607
+ }
608
+ # more reactions
609
+ ]
610
+ }
611
+ # more reactions in other sentences
612
+ ]
613
+ },
614
+ # more pages
615
+ ]
616
+ """
617
+ results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
618
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
619
+ return associate_corefs(results, results_coref)
620
+
621
+ def extract_reactions_from_figures_and_tables_in_pdf(self, pdf, num_pages=None, batch_size=16, molscribe=True, ocr=True):
622
+ """
623
+ Get reaction information from figures and combine with table information in pdf
624
+ Parameters:
625
+ pdf: path to pdf, or byte file
626
+ batch_size: batch size for inference in all models
627
+ num_pages: process only first `num_pages` pages, if `None` then process all
628
+ molscribe: whether to predict and return smiles and molfile info
629
+ ocr: whether to predict and return text of conditions
630
+ Returns:
631
+ list of figures and corresponding molecule info in the following format
632
+ [
633
+ {
634
+ 'figure': PIL image
635
+ 'reactions': [
636
+ {
637
+ 'reactants': [
638
+ {
639
+ 'category': str,
640
+ 'bbox': tuple (x1,x2,y1,y2),
641
+ 'category_id': int,
642
+ 'smiles': str,
643
+ 'molfile': str,
644
+ },
645
+ # more reactants
646
+ ],
647
+ 'conditions': [
648
+ {
649
+ 'category': str,
650
+ 'text': list of str,
651
+ },
652
+ # more conditions
653
+ ],
654
+ 'products': [
655
+ # same structure as reactants
656
+ ]
657
+ },
658
+ # more reactions
659
+ ],
660
+ 'page': int
661
+ },
662
+ # more figures
663
+ ]
664
+ """
665
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
666
+ images = [figure['figure']['image'] for figure in figures]
667
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
668
+ results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
669
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
670
+ results = replace_rgroups_in_figure(figures, results, results_coref, self.molscribe, batch_size=batch_size)
671
+ results = expand_reactions_with_backout(results, results_coref, self.molscribe)
672
+ return results
673
+
674
+ def extract_reactions_from_pdf(self, pdf, num_pages=None, batch_size=16):
675
+ """
676
+ Returns:
677
+ dictionary of reactions from multimodal sources
678
+ {
679
+ 'figures': [
680
+ {
681
+ 'figure': PIL image
682
+ 'reactions': [
683
+ {
684
+ 'reactants': [
685
+ {
686
+ 'category': str,
687
+ 'bbox': tuple (x1,x2,y1,y2),
688
+ 'category_id': int,
689
+ 'smiles': str,
690
+ 'molfile': str,
691
+ },
692
+ # more reactants
693
+ ],
694
+ 'conditions': [
695
+ {
696
+ 'category': str,
697
+ 'text': list of str,
698
+ },
699
+ # more conditions
700
+ ],
701
+ 'products': [
702
+ # same structure as reactants
703
+ ]
704
+ },
705
+ # more reactions
706
+ ],
707
+ 'page': int
708
+ },
709
+ # more figures
710
+ ]
711
+ 'text': [
712
+ {
713
+ 'page': page number
714
+ 'reactions': [
715
+ {
716
+ 'tokens': list of words in relevant sentence,
717
+ 'reactions' : [
718
+ {
719
+ # key, value pairs where key is the label and value is a tuple
720
+ # or list of tuples of the form (tokens, start index, end index)
721
+ # where indices are for the corresponding token list and start and end are inclusive
722
+ }
723
+ # more reactions
724
+ ]
725
+ }
726
+ # more reactions in other sentences
727
+ ]
728
+ },
729
+ # more pages
730
+ ]
731
+ }
732
+
733
+ """
734
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
735
+ images = [figure['figure']['image'] for figure in figures]
736
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=True, ocr=True)
737
+ table_expanded_results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
738
+ text_results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
739
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
740
+ figure_results = replace_rgroups_in_figure(figures, table_expanded_results, results_coref, self.molscribe, batch_size=batch_size)
741
+ table_expanded_results = expand_reactions_with_backout(figure_results, results_coref, self.molscribe)
742
+ coref_expanded_results = associate_corefs(text_results, results_coref)
743
+ return {
744
+ 'figures': table_expanded_results,
745
+ 'text': coref_expanded_results,
746
+ }
747
+
748
+ if __name__=="__main__":
749
+ model = OpenChemIE()
chemietoolkit/tableextractor.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdf2image
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import layoutparser as lp
6
+ import cv2
7
+
8
+ from PyPDF2 import PdfReader, PdfWriter
9
+ import pandas as pd
10
+
11
+ import pdfminer.high_level
12
+ import pdfminer.layout
13
+ from operator import itemgetter
14
+
15
+ # inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox
16
# inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox
class TableExtractor(object):
    """Detect tables and figures on rendered pdf pages (via a layoutparser
    model) and reconstruct table content from the pdf's embedded text layer
    (via pdfminer). Stateful: configure with the setters, then call
    ``extract_all_tables_and_figures``.
    """

    def __init__(self, output_bbox=True):
        self.pdf_file = ""            # path to the pdf being processed
        self.page = ""                # current 0-based page index (set per page)
        self.image_dpi = 200          # dpi of the rendered page images
        self.pdf_dpi = 72             # native pdf coordinate resolution
        self.output_bbox = output_bbox  # include per-cell bboxes in output
        self.blocks = {}              # layout detections keyed by block type
        self.title_y = 0              # y of the last matched title line
        self.column_header_y = 0      # y of the last matched column-header row
        self.model = None             # layout detection model (set externally)
        self.img = None               # ndarray of the current page image
        self.output_image = True      # include cropped PIL images in output
        # Keyword lists used to heuristically tag column headers by topic.
        self.tagging = {
            'substance': ['compound', 'salt', 'base', 'solvent', 'CBr4', 'collidine', 'InX3', 'substrate', 'ligand', 'PPh3', 'PdL2', 'Cu', 'compd', 'reagent', 'reagant', 'acid', 'aldehyde', 'amine', 'Ln', 'H2O', 'enzyme', 'cofactor', 'oxidant', 'Pt(COD)Cl2', 'CuBr2', 'additive'],
            'ratio': [':'],
            'measurement': ['μM', 'nM', 'IC50', 'CI', 'excitation', 'emission', 'Φ', 'φ', 'shift', 'ee', 'ΔG', 'ΔH', 'TΔS', 'Δ', 'distance', 'trajectory', 'V', 'eV'],
            'temperature': ['temp', 'temperature', 'T', '°C'],
            'time': ['time', 't(', 't ('],
            'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'],
            'alkyl group': ['R', 'Ar', 'X', 'Y'],
            'solvent': ['solvent'],
            'counter': ['entry', 'no.'],
            'catalyst': ['catalyst', 'cat.'],
            'conditions': ['condition'],
            'reactant': ['reactant'],
        }

    def set_output_image(self, oi):
        """Set whether cropped PIL images are included in results."""
        self.output_image = oi

    def set_pdf_file(self, pdf):
        """Set the path of the pdf to process."""
        self.pdf_file = pdf

    def set_page_num(self, pn):
        """Set the 0-based page index to process."""
        self.page = pn

    def set_output_bbox(self, ob):
        """Set whether per-cell bounding boxes are included in results."""
        self.output_bbox = ob

    def run_model(self, page_info):
        """Run the layout model on one rendered page and cache the detected
        blocks (text/title/list/table/figure) in ``self.blocks``."""
        #img = np.asarray(pdf2image.convert_from_path(self.pdf_file, dpi=self.image_dpi)[self.page])

        #model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config', extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"})

        img = np.asarray(page_info)
        self.img = img

        layout_result = self.model.detect(img)

        # Partition detections by predicted block type.
        text_blocks = lp.Layout([b for b in layout_result if b.type == 'Text'])
        title_blocks = lp.Layout([b for b in layout_result if b.type == 'Title'])
        list_blocks = lp.Layout([b for b in layout_result if b.type == 'List'])
        table_blocks = lp.Layout([b for b in layout_result if b.type == 'Table'])
        figure_blocks = lp.Layout([b for b in layout_result if b.type == 'Figure'])

        self.blocks.update({'text': text_blocks})
        self.blocks.update({'title': title_blocks})
        self.blocks.update({'list': list_blocks})
        self.blocks.update({'table': table_blocks})
        self.blocks.update({'figure': figure_blocks})

    # type is what coordinates you want to get. it comes in text, title, list, table, and figure
    def convert_to_pdf_coordinates(self, type):
        """Convert cached image-space detections of kind ``type`` into
        pdf-space (x1, y1, x2, y2) tuples with the y axis flipped to pdf
        orientation. NOTE(review): the parameter shadows the builtin ``type``."""
        # scale coordinates

        blocks = self.blocks[type]
        # Rescale from image dpi down to the pdf's 72-dpi coordinate system.
        coordinates = [blocks[a].scale(self.pdf_dpi/self.image_dpi) for a in range(len(blocks))]

        reader = PdfReader(self.pdf_file)

        writer = PdfWriter()
        p = reader.pages[self.page]
        a = p.mediabox.upper_left
        new_coords = []
        for new_block in coordinates:
            # Flip y: image origin is top-left, pdf origin is bottom-left.
            new_coords.append((new_block.block.x_1, pd.to_numeric(a[1]) - new_block.block.y_2, new_block.block.x_2, pd.to_numeric(a[1]) - new_block.block.y_1))

        return new_coords
    # output: list of bounding boxes for tables but in pdf coordinates

    # input: new_coords is singular table bounding box in pdf coordinates
    def extract_singular_table(self, new_coords):
        """Reconstruct one table's columns/rows from the text lines inside the
        given pdf-space bounding box. Groups text lines into rows by vertical
        overlap, merges horizontally overlapping cells, and tags column
        headers via ``self.tagging``.
        NOTE(review): returns None implicitly when the page yields no text or
        fewer than two lines fall inside the bbox — callers must tolerate that."""
        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            elements = []
            for element in page_layout:
                if isinstance(element, pdfminer.layout.LTTextBox):
                    for e in element._objs:
                        temp = e.bbox
                        # Keep only text lines fully contained in the table bbox.
                        if temp[0] > min(new_coords[0], new_coords[2]) and temp[0] < max(new_coords[0], new_coords[2]) and temp[1] > min(new_coords[1], new_coords[3]) and temp[1] < max(new_coords[1], new_coords[3]) and temp[2] > min(new_coords[0], new_coords[2]) and temp[2] < max(new_coords[0], new_coords[2]) and temp[3] > min(new_coords[1], new_coords[3]) and temp[3] < max(new_coords[1], new_coords[3]) and isinstance(e, pdfminer.layout.LTTextLineHorizontal):
                            elements.append([e.bbox[0], e.bbox[1], e.bbox[2], e.bbox[3], e.get_text()])

            elements = sorted(elements, key=itemgetter(0))
            # w: lines ordered top-to-bottom (pdf y grows upward).
            w = sorted(elements, key=itemgetter(3), reverse=True)
            if len(w) <= 1:
                continue

            ret = {}
            i = 1
            g = [w[0]]

            # Collect the first vertically-overlapping run of lines: the header row.
            while i < len(w) and w[i][3] > w[i-1][1]:
                g.append(w[i])
                i += 1
            g = sorted(g, key=itemgetter(0))
            # check for overlaps
            for a in range(len(g)-1, 0, -1):
                if g[a][0] < g[a-1][2]:
                    # Merge horizontally overlapping header cells into one.
                    g[a-1][0] = min(g[a][0], g[a-1][0])
                    g[a-1][1] = min(g[a][1], g[a-1][1])
                    g[a-1][2] = max(g[a][2], g[a-1][2])
                    g[a-1][3] = max(g[a][3], g[a-1][3])
                    g[a-1][4] = g[a-1][4].strip() + " " + g[a][4]
                    g.pop(a)


            ret.update({"columns":[]})
            for t in g:
                temp_bbox = t[:4]

                column_text = t[4].strip()
                # First keyword match wins when tagging a header.
                tag = 'unknown'
                tagged = False
                for key in self.tagging.keys():
                    for word in self.tagging[key]:
                        if word in column_text:
                            tag = key
                            tagged = True
                            break
                    if tagged:
                        break

                if self.output_bbox:
                    ret["columns"].append({'text':column_text,'tag': tag, 'bbox':temp_bbox})
                else:
                    ret["columns"].append({'text':column_text,'tag': tag})
                self.column_header_y = max(t[1], t[3])
            ret.update({"rows":[]})

            # Sentinel cells at both ends so column-alignment checks below
            # can index g[a-1] / g[a+1] without bounds handling.
            g.insert(0, [0, 0, new_coords[0], 0, ''])
            g.append([new_coords[2], 0, 0, 0, ''])
            while i < len(w):
                # Gather the next row: a run of vertically-overlapping lines.
                group = [w[i]]
                i += 1
                while i < len(w) and w[i][3] > w[i-1][1]:
                    group.append(w[i])
                    i += 1
                group = sorted(group, key=itemgetter(0))

                # Merge horizontally overlapping cells within the row.
                for a in range(len(group)-1, 0, -1):
                    if group[a][0] < group[a-1][2]:
                        group[a-1][0] = min(group[a][0], group[a-1][0])
                        group[a-1][1] = min(group[a][1], group[a-1][1])
                        group[a-1][2] = max(group[a][2], group[a-1][2])
                        group[a-1][3] = max(group[a][3], group[a-1][3])
                        group[a-1][4] = group[a-1][4].strip() + " " + group[a][4]
                        group.pop(a)

                # Pad the row with empty cells where a column has no text,
                # judged by horizontal alignment against the header cells in g.
                a = 1
                while a < len(g) - 1:
                    if a > len(group):
                        group.append([0, 0, 0, 0, '\n'])
                        a += 1
                        continue
                    if group[a-1][0] >= g[a-1][2] and group[a-1][2] <= g[a+1][0]:
                        pass
                        """
                        if a < len(group) and group[a][0] >= g[a-1][2] and group[a][2] <= g[a+1][0]:
                            g.insert(1, [g[0][2], 0, group[a-1][2], 0, ''])
                            #ret["columns"].insert(0, '')
                        else:
                            a += 1
                            continue
                        """
                    else: group.insert(a-1, [0, 0, 0, 0, '\n'])
                    a += 1


                added_row = []
                for t in group:
                    temp_bbox = t[:4]
                    if self.output_bbox:
                        added_row.append({'text':t[4].strip(), 'bbox':temp_bbox})
                    else:
                        added_row.append(t[4].strip())
                ret["rows"].append(added_row)
            # If the first data row's width disagrees with the header, promote
            # it to be the header instead and re-tag.
            # NOTE(review): when output_bbox is False, rows hold plain strings,
            # so col['text'] below would fail — confirm intended usage.
            if ret["rows"] and len(ret["rows"][0]) != len(ret["columns"]):
                ret["columns"] = ret["rows"][0]
                ret["rows"] = ret["rows"][1:]
                for col in ret['columns']:
                    tag = 'unknown'
                    tagged = False
                    for key in self.tagging.keys():
                        for word in self.tagging[key]:
                            if word in col['text']:
                                tag = key
                                tagged = True
                                break
                        if tagged:
                            break
                    col['tag'] = tag

            return ret

    def get_title_and_footnotes(self, tb_coords):
        """Find the title ('Table …' / 'Scheme …' line nearest above) and
        footnote (superscript-letter line nearest below) for a table/figure
        bbox in pdf coordinates.
        NOTE(review): returns inside the loop on the first (only) extracted
        page; returns None implicitly if pdfminer yields no page."""

        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            title = (0, 0, 0, 0, '')
            footnote = (0, 0, 0, 0, '')
            # Maximum vertical distance (pdf units) to accept a candidate.
            title_gap = 30
            footnote_gap = 30
            for element in page_layout:
                if isinstance(element, pdfminer.layout.LTTextBoxHorizontal):
                    # Only consider text boxes horizontally overlapping the bbox.
                    if (element.bbox[0] >= tb_coords[0] and element.bbox[0] <= tb_coords[2]) or (element.bbox[2] >= tb_coords[0] and element.bbox[2] <= tb_coords[2]) or (tb_coords[0] >= element.bbox[0] and tb_coords[0] <= element.bbox[2]) or (tb_coords[2] >= element.bbox[0] and tb_coords[2] <= element.bbox[2]):
                        #print(element)
                        if 'Table' in element.get_text():
                            if abs(element.bbox[1] - tb_coords[3]) < title_gap:
                                title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Table'):].replace('\n', ' '),)
                                title_gap = abs(element.bbox[1] - tb_coords[3])
                        if 'Scheme' in element.get_text():
                            if abs(element.bbox[1] - tb_coords[3]) < title_gap:
                                title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Scheme'):].replace('\n', ' '),)
                                title_gap = abs(element.bbox[1] - tb_coords[3])
                        # Skip text inside the bbox itself when hunting footnotes.
                        if element.bbox[1] >= tb_coords[1] and element.bbox[3] <= tb_coords[3]: continue
                        #print(element)
                        # Footnote heuristic: text starting with a superscript
                        # marker like 'aA' … 'a9'.
                        temp = ['aA', 'aB', 'aC', 'aD', 'aE', 'aF', 'aG', 'aH', 'aI', 'aJ', 'aK', 'aL', 'aM', 'aN', 'aO', 'aP', 'aQ', 'aR', 'aS', 'aT', 'aU', 'aV', 'aW', 'aX', 'aY', 'aZ', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'a0']
                        for segment in temp:
                            if segment in element.get_text():
                                if abs(element.bbox[3] - tb_coords[1]) < footnote_gap:
                                    footnote = tuple(element.bbox) + (element.get_text()[element.get_text().index(segment):].replace('\n', ' '),)
                                    footnote_gap = abs(element.bbox[3] - tb_coords[1])
                                break
            self.title_y = min(title[1], title[3])
            if self.output_bbox:
                return ({'text': title[4], 'bbox': list(title[:4])}, {'text': footnote[4], 'bbox': list(footnote[:4])})
            else:
                return (title[4], footnote[4])

    def extract_table_information(self):
        """Build the result dict for every table detected on the current page
        (requires ``run_model`` to have populated ``self.blocks``)."""
        #self.run_model(page_info) # changed
        table_coordinates = self.blocks['table'] #should return a list of layout objects
        table_coordinates_in_pdf = self.convert_to_pdf_coordinates('table') #should return a list of lists

        ans = []
        i = 0
        for coordinate in table_coordinates_in_pdf:
            ret = {}
            # Widen the detection horizontally to catch slightly-overhanging text.
            pad = 20
            coordinate = [coordinate[0] - pad, coordinate[1], coordinate[2] + pad, coordinate[3]]
            ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]

            table_results = self.extract_singular_table(coordinate)
            tf = self.get_title_and_footnotes(coordinate)
            figure = Image.fromarray(table_coordinates[i].crop_image(self.img))
            ret.update({'title': tf[0]})
            ret.update({'figure': {
                'image': None,
                'bbox': []
            }})
            if self.output_image:
                ret['figure']['image'] = figure
            ret.update({'table': {'bbox': list(coordinate), 'content': table_results}})
            ret.update({'footnote': tf[1]})
            # A large gap between title and column header suggests the block
            # also contains a figure (e.g. a scheme above the table).
            if abs(self.title_y - self.column_header_y) > 50:
                ret['figure']['bbox'] = list(coordinate)

            ret.update({'page':self.page})

            ans.append(ret)
            i += 1

        return ans

    def extract_figure_information(self):
        """Build the result dict for every figure detected on the current page
        (requires ``run_model`` to have populated ``self.blocks``)."""
        figure_coordinates = self.blocks['figure']
        figure_coordinates_in_pdf = self.convert_to_pdf_coordinates('figure')

        ans = []
        for i in range(len(figure_coordinates)):
            ret = {}
            coordinate = figure_coordinates_in_pdf[i]
            ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]

            tf = self.get_title_and_footnotes(coordinate)
            figure = Image.fromarray(figure_coordinates[i].crop_image(self.img))
            ret.update({'title':tf[0]})
            ret.update({'figure': {
                'image': None,
                'bbox': []
            }})
            if self.output_image:
                ret['figure']['image'] = figure
            ret.update({'table': {
                'bbox': [],
                'content': None
            }})
            ret.update({'footnote': tf[1]})
            ret['figure']['bbox'] = list(coordinate)

            ret.update({'page':self.page})

            ans.append(ret)

        return ans


    def extract_all_tables_and_figures(self, pages, pdfparser, content=None):
        """Run layout detection + extraction over all rendered pages.

        Parameters:
            pages: list of rendered page images (PIL)
            pdfparser: layout detection model with a .detect(img) method
            content: 'tables' for tables only, 'figures' for figures plus any
                tables flagged as containing a figure, anything else for both
        Returns:
            flat list of per-table/per-figure result dicts
        """
        self.model = pdfparser
        ret = []
        for i in range(len(pages)):
            self.set_page_num(i)
            self.run_model(pages[i])
            table_info = self.extract_table_information()
            figure_info = self.extract_figure_information()
            if content == 'tables':
                ret += table_info
            elif content == 'figures':
                ret += figure_info
                # Also surface tables whose layout suggests an embedded figure.
                for table in table_info:
                    if table['figure']['bbox'] != []:
                        ret.append(table)
            else:
                ret += table_info
                ret += figure_info
        return ret
chemietoolkit/utils.py ADDED
@@ -0,0 +1,1018 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import cv2
4
+ import layoutparser as lp
5
+ from rdkit import Chem
6
+ from rdkit.Chem import Draw
7
+ from rdkit.Chem import rdDepictor
8
+ rdDepictor.SetPreferCoordGen(True)
9
+ from rdkit.Chem.Draw import IPythonConsole
10
+ from rdkit.Chem import AllChem
11
+ import re
12
+ import copy
13
+
14
# Map from textual bond labels (as produced by MolScribe atom/bond output)
# to the integer bond codes used in the graph 'edges' adjacency matrices
# built throughout this module.
BOND_TO_INT = {
    "": 0,
    "single": 1,
    "double": 2,
    "triple": 3,
    "aromatic": 4,
    "solid wedge": 5,
    "dashed wedge": 6
}

# Recognised R-group / placeholder atom labels as they may appear in parsed
# structures (R1..R12, Ra..Rf, generic X/Y/Z/Q/A/E, aryl placeholders, and
# numbered-star SMILES dummies).
RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
                  'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', "R'",
                  '1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]']

# Also accept every symbol in bracketed form, e.g. 'R1' -> '[R1]'.
RGROUP_SYMBOLS = RGROUP_SYMBOLS + [f'[{i}]' for i in RGROUP_SYMBOLS]

# SMILES-level wildcard / R-group atom tokens that should be treated as
# attachment points when rewriting product templates.
RGROUP_SMILES = ['[1*]', '[2*]','[3*]', '[4*]','[5*]', '[6*]','[7*]', '[8*]','[9*]', '[10*]','[11*]', '[12*]','[a*]', '[b*]','[c*]', '[d*]','*', '[Rf]']
31
+
32
def get_figures_from_pages(pages, pdfparser):
    """Detect 'Figure' layout blocks on each page and crop them out.

    Returns a list of {'image': PIL.Image, 'page': page_index} dicts,
    one per detected figure, in page order.
    """
    figures = []
    for page_idx, page in enumerate(pages):
        page_arr = np.asarray(page)
        layout = pdfparser.detect(page_arr)
        figure_blocks = lp.Layout([b for b in layout if b.type == "Figure"])
        for block in figure_blocks:
            cropped = Image.fromarray(block.crop_image(page_arr))
            figures.append({
                'image': cropped,
                'page': page_idx
            })
    return figures
45
+
46
def clean_bbox_output(figures, bboxes):
    """Crop every '[Mol]' detection out of its figure image.

    figures: list of HxWx3 ndarrays; bboxes: per-figure detection lists with
    'category', 'bbox' (fractional x1,y1,x2,y2) and 'score'.

    Returns (results, cropped, references): per-figure dicts, the flat list
    of cropped molecule images, and flat references to the per-molecule dicts
    (shared objects with results[i]['molecules']).
    """
    results = []
    cropped = []
    references = []
    for i, output in enumerate(bboxes):
        fig = figures[i]
        data = {'image': fig, 'molecules': []}
        results.append(data)
        for elt in output:
            if elt['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = elt['bbox']
            # bbox coordinates are fractions of the figure size
            height, width, _ = fig.shape
            crop = fig[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            cur_mol = {
                'bbox': elt['bbox'],
                'score': elt['score'],
                'image': crop,
            }
            cropped.append(crop)
            data['molecules'].append(cur_mol)
            references.append(cur_mol)
    return results, cropped, references
71
+
72
def convert_to_pil(image):
    """Return a PIL image; cv2-style (BGR ndarray) inputs are converted to RGB first.

    Fix: use isinstance instead of the `type(x) == T` anti-pattern.
    """
    if isinstance(image, np.ndarray):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
    return image
77
+
78
def convert_to_cv2(image):
    """Return an ndarray suitable for cv2; PIL (RGB) inputs get channel-swapped.

    Fix: use isinstance instead of the `type(x) != T` anti-pattern.
    (BGR2RGB and RGB2BGR are the same channel swap, so the constant is fine.)
    """
    if not isinstance(image, np.ndarray):
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    return image
82
+
83
def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16):
    """Expand 'R = group' annotations found in identifier text into extra reactions.

    For each figure, scans the coref '[Idt]' text boxes for substitutions of
    the form 'R1 = Me', then re-runs the first detected reaction with each
    R group replaced, appending the expanded reactions to
    ``result['reactions']``. Returns the (mutated) results list.

    Fixes: raw-string regex (the old literal relied on '\\d'/'\\w' surviving as
    escapes, which raises a SyntaxWarning on modern Python) and reuse of the
    already-extracted ``group`` value instead of a second ``res.group`` call.
    """
    # Matches e.g. 'R1 = OMe', capturing the placeholder name and the group.
    pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
    for figure, result, corefs in zip(figures, results, coref_results):
        r_groups = []
        seen_r_groups = set()
        for bbox in corefs['bboxes']:
            if bbox['category'] != '[Idt]':
                continue
            for text in bbox['text']:
                res = pattern.search(text)
                if res is None:
                    continue
                name = res.group('name')
                group = res.group('group')
                if (name, group) in seen_r_groups:
                    continue
                seen_r_groups.add((name, group))
                r_groups.append({name: group})
        if r_groups and result['reactions']:
            # Keep only the R-group names for locating atoms in the graphs.
            seen_r_groups = set(pair[0] for pair in seen_r_groups)
            orig_reaction = result['reactions'][0]
            graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
            relevant_locs = {}
            for i, graph in enumerate(graphs):
                # Atom symbols look like '[R1]'; strip brackets before matching.
                to_add = []
                for j, atom in enumerate(graph['chartok_coords']['symbols']):
                    if atom[1:-1] in seen_r_groups:
                        to_add.append((atom[1:-1], j))
                relevant_locs[i] = to_add

            for r_group in r_groups:
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe)
                to_add = {
                    'reactants': reaction['reactants'][:],
                    'conditions': orig_reaction['conditions'][:],
                    'products': reaction['products'][:]
                }
                result['reactions'].append(to_add)
    return results
121
+
122
def process_tables(figures, results, molscribe, batch_size=16):
    """Expand a reaction scheme using the rows of its companion table.

    For each figure that has parsed table content, takes the first detected
    reaction as a template: table columns tagged 'alkyl group' substitute
    R-group atoms (producing new reactions), all other columns become extra
    '[Table]' condition entries. Mutates and returns ``results``.
    """
    # Strips optional prefixes/suffixes like '4-Me' or 'Ph (1a)' down to the
    # bare group token.
    r_group_pattern = re.compile(r'^(\w+-)?(?P<group>[\w-]+)( \(\w+\))?$')
    for figure, result in zip(figures, results):
        result['page'] = figure['page']
        if figure['table']['content'] is not None:
            content = figure['table']['content']
            if len(result['reactions']) > 1:
                print("Warning: multiple reactions detected for table")
            elif len(result['reactions']) == 0:
                continue
            # The first reaction serves as the template for every table row.
            orig_reaction = result['reactions'][0]
            graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
            # graph index -> [(r_group_name, atom_index), ...] for the columns
            # tagged 'alkyl group'.
            relevant_locs = find_relevant_groups(graphs, content['columns'])
            conditions_to_extend = []
            for row in content['rows']:
                r_groups = {}
                expanded_conditions = orig_reaction['conditions'][:]
                replaced = False
                for col, entry in zip(content['columns'], row):
                    if col['tag'] != 'alkyl group':
                        # Non-R-group columns become additional condition
                        # records attached to this row's reaction.
                        expanded_conditions.append({
                            'category': '[Table]',
                            'text': entry['text'],
                            'tag': col['tag'],
                            'header': col['text'],
                        })
                    else:
                        found = r_group_pattern.match(entry['text'])
                        if found is not None:
                            r_groups[col['text']] = found.group('group')
                            replaced = True
                # NOTE(review): computed even when nothing was replaced; in
                # that case the result is discarded below.
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe)
                if replaced:
                    to_add = {
                        'reactants': reaction['reactants'][:],
                        'conditions': expanded_conditions,
                        'products': reaction['products'][:]
                    }
                    result['reactions'].append(to_add)
                else:
                    # Row changed only conditions, not structures: remember the
                    # condition set to attach to the template reaction.
                    conditions_to_extend.append(expanded_conditions)
            # The template reaction's conditions become a list of condition
            # sets: the original one plus one per conditions-only row.
            orig_reaction['conditions'] = [orig_reaction['conditions']]
            orig_reaction['conditions'].extend(conditions_to_extend)
    return results
166
+
167
+
168
def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16):
    """Re-run MolScribe on every '[Mol]' crop of a reaction image.

    Returns one graph dict per molecule entity, each holding the crop, its
    atom coordinates/symbols, a symmetric integer bond matrix, and a
    ``key`` = (reaction role, index within role) locating the entity in
    ``reaction``.
    """
    image = convert_to_cv2(image)
    crops = []
    graphs = []
    for role, entities in reaction.items():
        for position, entity in enumerate(entities):
            if type(entity) != dict or entity['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = entity['bbox']
            height, width, _ = image.shape
            # bbox values are fractions of the full reaction image.
            crop = image[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            crops.append(crop)
            graphs.append({
                'image': crop,
                'chartok_coords': {
                    'coords': [],
                    'symbols': [],
                },
                'edges': [],
                'key': (role, position),
            })
    outputs = molscribe.predict_images(crops, return_atoms_bonds=True, batch_size=batch_size)
    for mol, graph in zip(outputs, graphs):
        coords = graph['chartok_coords']['coords']
        symbols = graph['chartok_coords']['symbols']
        for atom in mol['atoms']:
            coords.append((atom['x'], atom['y']))
            symbols.append(atom['atom_symbol'])
        n_atoms = len(mol['atoms'])
        graph['edges'] = [[0] * n_atoms for _ in range(n_atoms)]
        for bond in mol['bonds']:
            i, j = bond['endpoint_atoms']
            code = BOND_TO_INT[bond['bond_type']]
            graph['edges'][i][j] = code
            graph['edges'][j][i] = code
    return graphs
200
+
201
def find_relevant_groups(graphs, columns):
    """Locate R-group atoms named by 'alkyl group' table columns.

    Returns {graph_index: [(r_group_name, atom_index), ...]} where atom
    symbols like '[R1]' match a column whose text is 'R1'.
    """
    wanted = {f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group'}
    locations = {}
    for graph_idx, graph in enumerate(graphs):
        hits = []
        for atom_idx, symbol in enumerate(graph['chartok_coords']['symbols']):
            if symbol in wanted:
                # Strip the surrounding brackets to recover the column name.
                hits.append((symbol[1:-1], atom_idx))
        locations[graph_idx] = hits
    return locations
211
+
212
def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe):
    """Return a copy of ``orig_reaction`` with R-group atoms substituted.

    graphs: molecule graphs from get_atoms_and_bonds (each keyed back into
    the reaction via graph['key']); relevant_locs: {graph_index:
    [(r_group_name, atom_index), ...]}; mappings: {r_group_name: replacement
    symbol}. The substituted graphs are re-serialized through
    ``molscribe.convert_graph_to_output`` so the copied '[Mol]' entities get
    fresh 'smiles'/'molfile' values. The original reaction is not mutated.
    """
    # Shallow-copy each graph's mutable fields so symbol substitution below
    # does not touch the caller's graphs.
    graph_copy = []
    for graph in graphs:
        graph_copy.append({
            'image': graph['image'],
            'chartok_coords': {
                'coords': graph['chartok_coords']['coords'][:],
                'symbols': graph['chartok_coords']['symbols'][:],
            },
            'edges': graph['edges'][:],
            'key': graph['key'],
        })
    # Swap in the replacement symbol at every recorded R-group location.
    for graph_idx, atoms in relevant_locs.items():
        for atom, atom_idx in atoms:
            if atom in mappings:
                graph_copy[graph_idx]['chartok_coords']['symbols'][atom_idx] = mappings[atom]
    reaction_copy = {}
    # '[Mol]' entities get a fresh dict (their smiles/molfile will be
    # overwritten below); everything else is shared by reference.
    def append_copy(copy_list, entity):
        if entity['category'] == '[Mol]':
            copy_list.append({
                k1: v1 for k1, v1 in entity.items()
            })
        else:
            copy_list.append(entity)

    # Mirror the reaction structure, preserving nested lists of entities.
    for k, v in orig_reaction.items():
        reaction_copy[k] = []
        for entity in v:
            if type(entity) == list:
                sub_list = []
                for e in entity:
                    append_copy(sub_list, e)
                reaction_copy[k].append(sub_list)
            else:
                append_copy(reaction_copy[k], entity)

    # Re-derive SMILES/molfile for each substituted graph and write them into
    # the corresponding copied entity, located via the (role, index) key.
    for graph in graph_copy:
        output = molscribe.convert_graph_to_output([graph], [graph['image']])
        molecule = reaction_copy[graph['key'][0]][graph['key'][1]]
        molecule['smiles'] = output[0]['smiles']
        molecule['molfile'] = output[0]['molfile']
    return reaction_copy
254
+
255
def get_sites(tar, ref, ref_site=False):
    """Find template atoms where ``tar`` extends beyond the ``ref`` substructure.

    Aligns ``tar`` onto ``ref`` with an RDKit depiction match, then collects
    every matched atom that has a neighbor outside the match (i.e. an
    attachment point). Returns a list of atom indices.
    """
    rdDepictor.Compute2DCoords(ref)
    rdDepictor.Compute2DCoords(tar)
    # idx_pair: pairs mapping ref atom indices to tar atom indices.
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref)

    # tar-side indices that participate in the template match.
    in_template = [i[1] for i in idx_pair]
    sites = []
    for i in range(tar.GetNumAtoms()):
        if i not in in_template:
            # i is outside the template; any matched neighbor is a site.
            for j in tar.GetAtomWithIdx(i).GetNeighbors():
                if j.GetIdx() in in_template and j.GetIdx() not in sites:

                    # NOTE(review): both branches append the same value
                    # (pair element [0], the ref-side index), so the
                    # ref_site flag currently has no effect — presumably the
                    # else branch was meant to use [1] (the tar-side index).
                    # Left as-is to preserve behavior; confirm with callers.
                    if ref_site: sites.append(idx_pair[in_template.index(j.GetIdx())][0])
                    else: sites.append(idx_pair[in_template.index(j.GetIdx())][0])
    return sites
270
+
271
def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed=None):
    """Map atom indices of ``prod_mol`` onto the query template built from ``prod_smiles``.

    Returns (prod_mol_to_query, prod_template_mol_query): the index mapping
    from ``prod_mol`` (molblock-derived) atoms to the wildcard query molecule,
    and the query molecule itself. ``r_sites_reversed`` maps prod_mol atom
    indices to R-group symbols; it is used to reject substructure matches
    that do not place every R-group site on a wildcard atom.
    """
    # Two parallel SMILES: one regenerated from prod_mol (the "intermediate")
    # and the caller-supplied one; both get their R-group tokens collapsed
    # to the '*' wildcard so they can be used as queries.
    prod_template_intermediate = Chem.MolToSmiles(prod_mol)
    prod_template = prod_smiles

    for r in RGROUP_SMILES:
        if r!='*' and r!='(*)':
            prod_template = prod_template.replace(r, '*')
            prod_template_intermediate = prod_template_intermediate.replace(r, '*')

    prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate)
    prod_template_mol = Chem.MolFromSmiles(prod_template)

    # Make '*' dummies behave as match-anything query atoms.
    p = Chem.AdjustQueryParameters.NoAdjustments()
    p.makeDummiesQueries = True

    prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p)
    prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p)
    rdDepictor.Compute2DCoords(prod_mol)
    rdDepictor.Compute2DCoords(prod_template_mol_query)
    rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query)
    # Align prod_mol onto the intermediate query to get a first index pairing.
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query)

    intermdiate_to_prod_mol = {a:b for a,b in idx_pair}
    prod_mol_to_intermediate = {b:a for a,b in idx_pair}


    #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query)

    #intermediate_to_query = {a:b for a,b in idx_pair_2}
    #query_to_intermediate = {b:a for a,b in idx_pair_2}

    #prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}


    # Enumerate all (non-uniquified) embeddings of the intermediate query in
    # the caller's query; the order of matches matters for index bookkeeping.
    substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify = False)

    #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query)
    for substruct in substructs:


        intermediate_to_query = {a:b for a, b in enumerate(substruct)}
        query_to_intermediate = {intermediate_to_query[i]: i for i in intermediate_to_query}

        # Compose: prod_mol -> intermediate -> query.
        prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}

        # Accept this embedding only if every R-group site lands on a
        # wildcard atom in the query template.
        good_map = True
        for i in r_sites_reversed:
            if prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() not in RGROUP_SMILES:
                good_map = False
        if good_map:
            break

    return prod_mol_to_query, prod_template_mol_query
325
+
326
def clean_corefs(coref_results_dict, idx):
    """Heuristically repair OCR-mangled product labels in place.

    Labels for products of compound ``idx`` should look like e.g. '1a', '1b'
    (index followed by letters). When a molecule has no such label but has
    one where a trailing letter was misread as a digit (l->1, o->0, s->5,
    g->9), append the corrected label. Mutates ``coref_results_dict``;
    returns None.

    Fixes: removed the unused local ``toreturn`` and iterate over a snapshot
    of the label list instead of mutating it while iterating.
    """
    label_pattern = rf'{re.escape(idx)}[a-zA-Z]+'
    for prod in coref_results_dict:
        labels = coref_results_dict[prod]
        has_good_label = any(re.search(label_pattern, parsed) for parsed in labels)
        if has_good_label:
            continue
        # Iterate over a snapshot: corrected labels are appended to the list.
        for parsed in list(labels):
            if idx + '1' in parsed:
                labels.append(idx + 'l')
            elif idx + '0' in parsed:
                labels.append(idx + 'o')
            elif idx + '5' in parsed:
                labels.append(idx + 's')
            elif idx + '9' in parsed:
                labels.append(idx + 'g')
345
+
346
+
347
+
348
def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe):
    """Apply a parsed 'R = group' substitution to a product graph.

    ``res`` is a regex match with 'name' and 'group' groups. Rebuilds the
    molecule graph for ``other_prod`` from its coref atom/bond record,
    replaces every atom labelled '[name]' with the group text, and returns
    the resulting RDKit Mol (via MolScribe serialization).
    """
    r_name = res.group('name')
    r_value = res.group('group')

    record = coref_smiles_to_graphs[other_prod]
    atom_list = record['atoms']
    bond_list = record['bonds']

    n_atoms = len(atom_list)
    graph = {
        'image': None,
        'chartok_coords': {
            'coords': [(a['x'], a['y']) for a in atom_list],
            'symbols': [a['atom_symbol'] for a in atom_list],
        },
        'edges': [[0] * n_atoms for _ in range(n_atoms)],
        'key': None,
    }
    for bond in bond_list:
        i, j = bond['endpoint_atoms']
        code = BOND_TO_INT[bond['bond_type']]
        graph['edges'][i][j] = code
        graph['edges'][j][i] = code

    # Substitute the R-group placeholder (symbols look like '[R1]').
    symbols = graph['chartok_coords']['symbols']
    for pos, symbol in enumerate(symbols):
        if symbol[1:-1] == r_name:
            symbols[pos] = r_value

    output = molscribe.convert_graph_to_output([graph], [graph['image']])
    return Chem.MolFromSmiles(output[0]['smiles'])
381
+
382
def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn):
    """Resolve R-group fragments from a concrete product and rebuild the reactants.

    ``query`` is one (template, r_sites, h_sites, n_r_groups) tuple from
    query_enumeration. Matches the template inside ``other_prod_mol``; the
    leftover fragments are the R groups, which are then grafted onto every
    reactant at its recorded attachment points. On success appends
    (modified_reactant_smiles, [product_smiles], parsed_label) to
    ``toreturn`` and returns True; returns False when no match yields the
    expected number of fragments.
    """
    prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query
    # we get the substruct matches. note that we set uniquify to false since the order matters for our method
    substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False)


    #for r in r_sites_reversed:
    #    print(prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[r]).GetSymbol())

    # for each substruct we create the mapping of the substruct onto the other_mol
    # delete all the molecules in other_mol correspond to the substruct
    # and check if they number of mol frags is equal to number of r groups
    # we do this to make sure we have the correct substruct
    if len(substructs) >= 1:
        for substruct in substructs:

            query_to_other = {a:b for a,b in enumerate(substruct)}
            other_to_query = {query_to_other[i]:i for i in query_to_other}

            editable = Chem.EditableMol(other_prod_mol)
            r_site_correspondence = []
            for r in r_sites_reversed_new:
                #get its id in substruct
                substruct_id = query_to_other[r]
                r_site_correspondence.append([substruct_id, r_sites_reversed_new[r]])

            # Remove matched template atoms (highest index first so earlier
            # indices stay valid), keeping the R-site atoms; shift the
            # recorded R-site indices as atoms below them disappear.
            for idx in tuple(sorted(substruct, reverse = True)):
                if idx not in [query_to_other[i] for i in r_sites_reversed_new]:
                    editable.RemoveAtom(idx)
                    for r_site in r_site_correspondence:
                        if idx < r_site[0]:
                            r_site[0]-=1
            other_prod_removed = editable.GetMol()

            # Correct match <=> one leftover fragment per R group.
            if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups:
                break

        # need to compute the sites at which correspond to each r_site_reversed

        r_site_correspondence.sort(key = lambda x: x[0])


        f = []
        ff = []
        frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff)

        # r_group_information maps r group name --> the fragment/molcule corresponding to the r group and the atom index it should be connected at
        r_group_information = {}
        #tosubtract = 0
        for idx, r_site in enumerate(r_site_correspondence):

            r_group_information[r_site[1]]= (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0]))
            #tosubtract += len(ff[idx])
        # Sites demoted to hydrogen by query_enumeration get an explicit [H].
        for r_site in h_sites:
            r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0)

        # now we modify all of the reactants according to the R groups we have found
        # for every reactant we disconnect its r group symbol, and connect it to the r group
        modify_reactants = copy.deepcopy(reactant_mols)
        modified_reactant_smiles = []
        for reactant_idx in reactant_information:
            if len(reactant_information[reactant_idx]) == 0:
                # Reactant has no R groups: keep it unchanged.
                modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx]))
            else:
                combined = reactant_mols[reactant_idx]
                if combined.GetNumAtoms() == 1:
                    # Whole reactant IS the R group placeholder.
                    r_group, _, _ = reactant_information[reactant_idx][0]
                    modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0]))
                else:
                    # Append every R-group fragment into one disconnected mol.
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        combined = Chem.CombineMols(combined, r_group_information[r_group][0])

                    editable = Chem.EditableMol(combined)
                    # Fragments were appended after the reactant's own atoms;
                    # atomIdxAdder tracks each fragment's offset.
                    atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms()
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        Chem.EditableMol.RemoveBond(editable, r_index, connect_index)
                        Chem.EditableMol.AddBond(editable, connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE)
                        atomIdxAdder += r_group_information[r_group][0].GetNumAtoms()
                    r_indices = [i[1] for i in reactant_information[reactant_idx]]

                    # Delete placeholder atoms highest-first to keep the
                    # remaining indices valid.
                    r_indices.sort(reverse = True)

                    for r_index in r_indices:
                        Chem.EditableMol.RemoveAtom(editable, r_index)

                    # SMILES round-trip canonicalizes/sanitizes the result.
                    modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol()))))

        toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed))
        return True
    else:
        return False
473
+
474
def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups):
    """Enumerate template variants with subsets of R-group sites removed.

    For every subset of R-group sites (from generate_subsets), produces a
    copy of the template query with those wildcard atoms deleted — treating
    them as implicit hydrogens instead of substituents. Returns a list of
    (edited_template, remaining_r_sites, hydrogen_sites, n_remaining) tuples,
    ordered from "keep all R groups" to "drop all".
    """
    subsets = generate_subsets(num_r_groups)

    toreturn = []

    for subset in subsets:
        # [atom_index, r_group_symbol], sorted by atom index so subset
        # entries (positions into this list) are stable.
        r_sites_list = [[i, r_sites_reversed_new[i]] for i in r_sites_reversed_new]
        r_sites_list.sort(key = lambda x: x[0])
        to_edit = Chem.EditableMol(prod_template_mol_query)

        # First detach each selected wildcard from its (single) neighbor...
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveBond(to_edit, r_sites_list[entry][0], prod_template_mol_query.GetAtomWithIdx(r_sites_list[entry][0]).GetNeighbors()[0].GetIdx())
        # ...then delete the atoms. Subsets are in descending order, so
        # higher atom indices are removed first and earlier ones stay valid.
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveAtom(to_edit, pos)

        edited = to_edit.GetMol()
        # Shift the indices of the sites that survive the deletions.
        for entry in subset:
            for i in range(entry + 1, num_r_groups):
                r_sites_list[i][0]-=1

        new_r_sites = {}
        new_h_sites = set()
        for i in range(num_r_groups):
            if i not in subset:
                new_r_sites[r_sites_list[i][0]] = r_sites_list[i][1]
            else:
                new_h_sites.add(r_sites_list[i][1])
        toreturn.append((edited, new_r_sites, new_h_sites, num_r_groups - len(subset)))
    return toreturn
505
+
506
def generate_subsets(n):
    """Enumerate every subset of {0, ..., n-1} as a descending list.

    Ordering matches the original backtracking enumeration after its final
    sort: shortest subsets first; within a length, subsets in descending
    lexicographic order of their (descending) contents.
    """
    collected = [[]]
    for value in range(n):
        # Prepending the (larger) new value keeps each subset descending.
        collected = collected + [[value] + subset for subset in collected]
    return sorted(collected, key=lambda subset: (-len(subset), subset), reverse=True)
517
+
518
def backout(results, coref_results, molscribe):
    """Infer concrete reactant structures for labelled product analogues.

    Given a parsed reaction whose single product is an R-group template
    (label e.g. '1'), finds coreference molecules labelled like members of
    that series ('1a', '1b', ...), extracts their R-group fragments by
    substructure matching against the template, grafts those fragments onto
    the template reactants, and returns a list of
    (modified_reactant_smiles, [product_smiles], parsed_label) tuples.
    Returns [] (or None on early warning exits) when nothing can be derived.

    NOTE(review): the whole body is wrapped in a bare ``except: pass`` — any
    failure silently yields the partial/empty result.
    """

    toreturn = []

    # Nothing to do without a parsed reaction and coref output.
    if not results or not results[0]['reactions'] or not coref_results:
        return toreturn

    try:
        reactants = results[0]['reactions'][0]['reactants']
        products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
        # coref pairs are (molecule_bbox_index, text_bbox_index):
        # molecule smiles -> its label texts / its full bbox record.
        coref_results_dict = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[1]]['text'] for coref in coref_results[0]['corefs']}
        coref_smiles_to_graphs = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[0]] for coref in coref_results[0]['corefs']}


        # Only the single-product case is supported.
        if len(products) == 1:
            if products[0] not in coref_results_dict:
                print("Warning: No Label Parsed")
                return
            product_labels = coref_results_dict[products[0]]
            prod = products[0]
            label_idx = product_labels[0]
            '''
            if len(product_labels) == 1:
                # get the coreference label of the product molecule
                label_idx = product_labels[0]
            else:
                print("Warning: Malformed Label Parsed.")
                return
            '''
        else:
            print("Warning: More than one product detected")
            return

        # format the regular expression for labels that correspond to the product label
        numbers = re.findall(r'\d+', label_idx)
        label_idx = numbers[0] if len(numbers) > 0 else ""
        label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'


        prod_smiles = prod
        prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])

        # identify the atom indices of the R groups in the product tempalte
        # h_counter compensates for explicit [H] atoms, which the molblock
        # parse collapses (shifting subsequent atom indices down).
        h_counter = 0
        r_sites = {}
        for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
            sym = atom['atom_symbol']
            if sym == '[H]':
                h_counter += 1
            if sym[0] == '[':
                sym = sym[1:-1]
            # Normalize 'R<n>' to the SMILES dummy form '[n*]'.
            if sym[0] == 'R' and sym[1:].isdigit():
                sym = sym[1:]+"*"
            sym = f'[{sym}]'
            if sym in RGROUP_SYMBOLS:
                if sym not in r_sites:
                    r_sites[sym] = [idx-h_counter]
                else:
                    r_sites[sym].append(idx-h_counter)

        # Invert: atom index -> R-group symbol.
        r_sites_reversed = {}
        for sym in r_sites:
            for pos in r_sites[sym]:
                r_sites_reversed[pos] = sym

        num_r_groups = len(r_sites_reversed)

        #prepare the product template and get the associated mapping

        prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)

        reactant_mols = []


        #--------------process the reactants-----------------

        reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]

        for idx, reactant in enumerate(reactants):
            reactant_information[idx] = []
            reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
            has_r = False

            r_sites_reactant = {}

            h_counter = 0

            for a_idx, atom in enumerate(reactant['atoms']):

                #go through all atoms and check if they are an R group, if so add it to reactant information
                sym = atom['atom_symbol']
                if sym == '[H]':
                    h_counter += 1
                if sym[0] == '[':
                    sym = sym[1:-1]
                if sym[0] == 'R' and sym[1:].isdigit():
                    sym = sym[1:]+"*"
                sym = f'[{sym}]'
                # A single-atom reactant is recorded with sentinel indices
                # (-1, -1): the whole reactant is the R group. Otherwise the
                # molblock-derived mol is used so atom indices line up with
                # the parsed atom list.
                # The '[1*]'/'[7*]' and '[1*]'/'[Rf]' pairs below are OCR
                # aliasing cases: the reactant's label is recorded under the
                # symbol the *product* uses.
                if sym in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append([sym, -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant[sym] = a_idx-h_counter
                elif sym == '[1*]' and '[7*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[7*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[7*]'] = a_idx-h_counter
                elif sym == '[7*]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter



                elif sym == '[1*]' and '[Rf]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[Rf]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[Rf]'] = a_idx-h_counter

                elif sym == '[Rf]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter


            r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
            # if the reactant had r groups, we had to use the molecule generated from the MolBlock.
            # but the molblock may have unexpanded elemeents that are not R groups
            # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
            # and adjust the indices of the r groups accordingly
            if has_r:
                #get the mapping
                reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)

                #make the adjustment
                for info in reactant_information[idx]:
                    info[1] = reactant_mol_to_query[info[1]]
                    info[2] = reactant_mol_to_query[info[2]]
                reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])

        #go through all the molecules in the coreference

        # Repair labels whose trailing letter was OCR'd as a digit.
        clean_corefs(coref_results_dict, label_idx)

        for other_prod in coref_results_dict:

            #check if they match the product label regex
            found_good_label = False
            for parsed in coref_results_dict[other_prod]:
                # Process each candidate molecule at most once (first label
                # that matches the series pattern).
                if re.search(label_pattern, parsed) and not found_good_label:
                    found_good_label = True
                    other_prod_mol = Chem.MolFromSmiles(other_prod)

                    if other_prod != prod_smiles and other_prod_mol is not None:

                        #check if there are R groups to be resolved in the target product

                        all_other_prod_mols = []

                        r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')

                        # A label like '1a, R = Me' means the drawn analogue
                        # itself still contains an R group: expand it first.
                        for parsed_labels in coref_results_dict[other_prod]:
                            res = r_group_sub_pattern.search(parsed_labels)

                            if res is not None:
                                all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))

                        if len(all_other_prod_mols) == 0:
                            if other_prod_mol is not None:
                                all_other_prod_mols.append((other_prod_mol, parsed))




                        for other_prod_mol, parsed in all_other_prod_mols:

                            # If drawn as a salt/mixture, keep the fragment
                            # that actually contains the template.
                            other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)

                            for other_prod_frag in other_prod_frags:
                                substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)

                                if len(substructs)>0:
                                    other_prod_mol = other_prod_frag
                                    break
                            # Re-key R sites onto template-query atom indices.
                            r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}

                            # Try templates with progressively more R sites
                            # treated as hydrogens until one matches.
                            queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)

                            matched = False

                            for query in queries:
                                if not matched:
                                    try:
                                        matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
                                    except:
                                        pass

    except:
        pass


    return toreturn
739
+
740
+
741
def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe):
    """Recover concrete reactions for labelled analogues of an R-group product template.

    Variant of ``backout`` that receives the coreference lookups directly:
    ``coref_results_dict`` maps a molecule SMILES to its parsed label strings and
    ``coref_smiles_to_graphs`` maps a SMILES to its recognised molecular graph.
    The single product of ``results[0]['reactions'][0]`` is treated as a template
    whose R groups are resolved against every labelled analogue (e.g. label "1a"
    for template label "1"), substituting the matched fragments back into the
    reactants.

    Returns:
        list: entries appended by ``get_r_group_frags_and_substitute``
        (empty if nothing could be resolved).

    NOTE(review): the early-exit warning branches use a bare ``return`` (None)
    instead of returning ``toreturn`` — callers must treat None and [] alike.
    """

    toreturn = []

    # Nothing to do without a parsed reaction and coreference data.
    if not results or not results[0]['reactions'] or not coref_results:
        return toreturn

    try:
        reactants = results[0]['reactions'][0]['reactants']
        products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
        # Self-assignments kept from the original (no-ops).
        coref_results_dict = coref_results_dict
        coref_smiles_to_graphs = coref_smiles_to_graphs

        # The back-out procedure only supports a single template product.
        if len(products) == 1:
            if products[0] not in coref_results_dict:
                print("Warning: No Label Parsed")
                return
            product_labels = coref_results_dict[products[0]]
            prod = products[0]
            # Use the first label attached to the product as the template label.
            label_idx = product_labels[0]
            '''
            if len(product_labels) == 1:
                # get the coreference label of the product molecule
                label_idx = product_labels[0]
            else:
                print("Warning: Malformed Label Parsed.")
                return
            '''
        else:
            print("Warning: More than one product detected")
            return

        # format the regular expression for labels that correspond to the product label
        # e.g. template label "1" matches analogue labels "1a", "1b", ...
        numbers = re.findall(r'\d+', label_idx)
        label_idx = numbers[0] if len(numbers) > 0 else ""
        label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'

        prod_smiles = prod
        prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])

        # identify the atom indices of the R groups in the product tempalte.
        # h_counter compensates for explicit [H] atoms, which are not present in
        # the molecule built from the molfile — presumably; TODO confirm.
        h_counter = 0
        r_sites = {}  # normalised R-group symbol -> list of atom indices
        for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
            sym = atom['atom_symbol']
            if sym == '[H]':
                h_counter += 1
            # Normalise "R<n>" spellings to the "[<n>*]" dummy-atom convention.
            if sym[0] == '[':
                sym = sym[1:-1]
            if sym[0] == 'R' and sym[1:].isdigit():
                sym = sym[1:]+"*"
            sym = f'[{sym}]'
            if sym in RGROUP_SYMBOLS:
                if sym not in r_sites:
                    r_sites[sym] = [idx-h_counter]
                else:
                    r_sites[sym].append(idx-h_counter)

        # Invert to atom index -> R-group symbol.
        r_sites_reversed = {}
        for sym in r_sites:
            for pos in r_sites[sym]:
                r_sites_reversed[pos] = sym

        num_r_groups = len(r_sites_reversed)

        # prepare the product template and get the associated mapping

        prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)

        reactant_mols = []

        #--------------process the reactants-----------------

        reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]

        for idx, reactant in enumerate(reactants):
            reactant_information[idx] = []
            reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
            has_r = False

            r_sites_reactant = {}

            h_counter = 0

            for a_idx, atom in enumerate(reactant['atoms']):

                # go through all atoms and check if they are an R group, if so add it to reactant information
                sym = atom['atom_symbol']
                if sym == '[H]':
                    h_counter += 1
                if sym[0] == '[':
                    sym = sym[1:-1]
                if sym[0] == 'R' and sym[1:].isdigit():
                    sym = sym[1:]+"*"
                sym = f'[{sym}]'
                if sym in r_sites:
                    # Single-atom reactant: the R group IS the molecule; mark with -1 sentinels.
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append([sym, -1, -1])
                    else:
                        has_r = True
                        # Rebuild from the molfile so atom indices line up with 'atoms'.
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant[sym] = a_idx-h_counter
                # The branches below treat [1*]/[7*] and [1*]/[Rf] as aliases of each
                # other — presumably recognition/parsing confusions; TODO confirm.
                elif sym == '[1*]' and '[7*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[7*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[7*]'] = a_idx-h_counter
                elif sym == '[7*]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter

                elif sym == '[1*]' and '[Rf]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[Rf]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[Rf]'] = a_idx-h_counter

                elif sym == '[Rf]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter

            r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
            # if the reactant had r groups, we had to use the molecule generated from the MolBlock.
            # but the molblock may have unexpanded elemeents that are not R groups
            # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
            # and adjust the indices of the r groups accordingly
            if has_r:
                # get the mapping
                reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)

                # make the adjustment
                for info in reactant_information[idx]:
                    info[1] = reactant_mol_to_query[info[1]]
                    info[2] = reactant_mol_to_query[info[2]]
                reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])

        # go through all the molecules in the coreference

        clean_corefs(coref_results_dict, label_idx)

        for other_prod in coref_results_dict:

            # check if they match the product label regex
            found_good_label = False
            for parsed in coref_results_dict[other_prod]:
                # Only process the first matching label per molecule.
                if re.search(label_pattern, parsed) and not found_good_label:
                    found_good_label = True
                    other_prod_mol = Chem.MolFromSmiles(other_prod)

                    if other_prod != prod_smiles and other_prod_mol is not None:

                        # check if there are R groups to be resolved in the target product

                        all_other_prod_mols = []

                        # Textual substitutions such as "R = Me" attached to the label.
                        r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')

                        for parsed_labels in coref_results_dict[other_prod]:
                            res = r_group_sub_pattern.search(parsed_labels)

                            if res is not None:
                                all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))

                        # No textual R-group substitutions: use the molecule as-is.
                        if len(all_other_prod_mols) == 0:
                            if other_prod_mol is not None:
                                all_other_prod_mols.append((other_prod_mol, parsed))

                        for other_prod_mol, parsed in all_other_prod_mols:

                            # Keep only the fragment that actually contains the template.
                            other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)

                            for other_prod_frag in other_prod_frags:
                                substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)

                                if len(substructs)>0:
                                    other_prod_mol = other_prod_frag
                                    break
                            # Re-key R-site indices into the query molecule's numbering.
                            r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}

                            queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)

                            matched = False

                            # Try each enumerated query until one substitution succeeds.
                            for query in queries:
                                if not matched:
                                    try:
                                        matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
                                    except:
                                        pass

    except:
        # Best-effort: any failure yields whatever was accumulated so far.
        pass

    return toreturn
959
+
960
+
961
+
962
def associate_corefs(results, results_coref):
    """Annotate compound labels in extracted reactions with their SMILES.

    Builds a label -> SMILES table from the coreference output (labels such as
    "12a" found in the text boxes paired with molecule boxes), then rewrites
    every 'Reactants'/'Product' entry whose label is known from "<label>" to
    "<label> (<smiles>)", keeping the entry's remaining tuple fields intact.

    Args:
        results: pages of parsed reactions; modified in place.
        results_coref: coreference output with 'bboxes' and 'corefs' per page.

    Returns:
        The same ``results`` object, updated in place.
    """
    label_to_smiles = {}
    idx_pattern = r'\b\d+[a-zA-Z]{0,2}\b'

    # Harvest label -> SMILES pairs from every coreference page.
    for coref_page in results_coref:
        boxes = coref_page['bboxes']
        for pair in coref_page['corefs']:
            mol_box, text_box = pair[0], pair[1]
            texts = boxes[text_box]['text']
            if not texts:
                continue
            for text in texts:
                for label in re.findall(idx_pattern, text):
                    label_to_smiles[label] = boxes[mol_box]['smiles']

    def _attach(entry):
        # A single compound is a tuple; a compound list is mutated in place.
        if isinstance(entry, tuple):
            if entry[0] in label_to_smiles:
                return (f'{entry[0]} ({label_to_smiles[entry[0]]})', entry[1], entry[2])
            return entry
        for pos, compound in enumerate(entry):
            if compound[0] in label_to_smiles:
                entry[pos] = (f'{compound[0]} ({label_to_smiles[compound[0]]})', compound[1], compound[2])
        return entry

    for page in results:
        for reactions in page['reactions']:
            for reaction in reactions['reactions']:
                if 'Reactants' in reaction:
                    reaction['Reactants'] = _attach(reaction['Reactants'])
                if 'Product' in reaction:
                    reaction['Product'] = _attach(reaction['Product'])
    return results
995
+
996
+
997
def expand_reactions_with_backout(initial_results, results_coref, molscribe):
    """Augment each page's reaction list with reactions recovered by ``backout``.

    For every page, the R-group back-out procedure is run against the page's
    coreference results; each recovered (reactants, products, label) triple is
    appended as a new reaction reusing the conditions of the page's first
    reaction. ``backout`` failures on a page are silently skipped so the rest
    of the document is still processed.

    Fixes over the previous version: removed the unused locals ``idx_pattern``
    and ``idt_to_smiles``.

    Args:
        initial_results: per-page reaction-extraction results; each element's
            'reactions' list is extended in place.
        results_coref: per-page coreference results, parallel to initial_results.
        molscribe: MolScribe instance forwarded to ``backout``.

    Returns:
        ``initial_results`` with each page's 'reactions' list extended in place.
    """
    for reactions, result_coref in zip(initial_results, results_coref):
        if not reactions['reactions']:
            continue
        try:
            backout_results = backout([reactions], [result_coref], molscribe)
        except Exception:
            # backout is best-effort; one bad page must not abort the batch.
            continue
        if not backout_results:
            continue
        # Recovered reactions inherit the conditions of the page's first reaction.
        conditions = reactions['reactions'][0]['conditions']
        for reactants, products, _idt in backout_results:
            reactions['reactions'].append({
                'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants],
                'conditions': conditions[:],
                'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products]
            })
    return initial_results
1018
+
examples/exp.png ADDED

Git LFS Details

  • SHA256: 3ce344ed33ff77f45d6e87a29e91426c3444ee9b58a8b10086ce3483a1ad2a2e
  • Pointer size: 131 Bytes
  • Size of remote file: 696 kB
examples/image.webp ADDED
examples/rdkit.png ADDED
examples/reaction1.jpg ADDED
examples/reaction2.png ADDED
examples/reaction3.png ADDED
examples/reaction4.png ADDED

Git LFS Details

  • SHA256: 341a3b9f6b24b3fe3793186ec198cf1171ffb84bca6c0316052f25e17c0eeb55
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
get_molecular_agent.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import json
4
+ from chemietoolkit import ChemIEToolkit
5
+ import cv2
6
+ from PIL import Image
7
+ import json
8
+ import sys
9
+ #sys.path.append('./RxnScribe-main/')
10
+ import torch
11
+ from rxnscribe import RxnScribe
12
+ import json
13
+ import sys
14
+ import torch
15
+ import json
16
+ model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
17
+ from molscribe.chemistry import _convert_graph_to_smiles
18
+ import base64
19
+ import torch
20
+ import json
21
+ from PIL import Image
22
+ import numpy as np
23
+ from chemietoolkit import ChemIEToolkit, utils
24
+ from openai import AzureOpenAI
25
+ import os
26
+
27
+
28
+
29
+
30
+ ckpt_path = "./pix2seq_reaction_full.ckpt"
31
+ model1 = RxnScribe(ckpt_path, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
32
+ device = torch.device(('cuda' if torch.cuda.is_available() else 'cpu'))
33
+
34
+ model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
35
+
36
def get_multi_molecular(image_path: str) -> str:
    '''Extract molecule/label coreferences from an image as a JSON string.

    Runs ChemIEToolkit coreference extraction on the image, strips the bulky
    per-molecule fields from each bounding box, and returns the remaining data
    serialized with ``json.dumps`` (a string — the original ``-> list``
    annotation was inaccurate).
    '''
    # Open the image file and normalise it to RGB.
    image = Image.open(image_path).convert('RGB')

    # Pass the image to the model as a single-image batch.
    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in ["category", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms'
                bbox.pop(key, None)  # remove the key safely (no-op if absent)
    print(json.dumps(coref_results))
    # Return the coreference list serialized as JSON.

    return json.dumps(coref_results)
51
+
52
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    '''Extract coreferences for GPT correction, without bbox coordinates.

    Same as ``get_multi_molecular`` but additionally drops the 'bbox' field,
    leaving a compact JSON payload suitable for sending to GPT for symbol
    correction. Returns a JSON string (the original ``-> list`` annotation
    was inaccurate).
    '''
    # Open the image file and normalise it to RGB.
    image = Image.open(image_path).convert('RGB')

    # Pass the image to the model as a single-image batch.
    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in ["category", "bbox", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms'
                bbox.pop(key, None)  # remove the key safely (no-op if absent)
    print(json.dumps(coref_results))
    # Return the coreference list serialized as JSON.

    return json.dumps(coref_results)
67
+
68
def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> str:
    '''Extract coreferences for GPT correction, keeping SMILES and symbols.

    Variant of ``get_multi_molecular_text_to_correct`` that retains the
    'smiles'/'symbols' fields (needed for atom-level correction) while
    dropping graph geometry ('coords', 'edges') and other bulky fields.
    Returns a JSON string (the original ``-> list`` annotation was
    inaccurate).
    '''
    # Open the image file and normalise it to RGB.
    image = Image.open(image_path).convert('RGB')

    # Pass the image to the model as a single-image batch.
    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in ["coords","edges","molfile", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms'
                bbox.pop(key, None)  # remove the key safely (no-op if absent)
    print(json.dumps(coref_results))
    # Return the coreference list serialized as JSON.
    return json.dumps(coref_results)
82
+
83
+
84
+
85
+
86
+
87
+
88
def process_reaction_image_with_multiple_products_and_text(image_path: str) -> dict:
    """
    Extract molecules (with text coreferences) from a table-style reaction
    image and let an Azure-hosted GPT-4o correct the recognised atom symbols.

    Pipeline:
      1. send the image plus a prompt to GPT-4o, advertising the local
         extraction function as a callable tool;
      2. execute the requested tool call(s) locally and feed the results back;
      3. merge GPT's corrected 'symbols' into the locally extracted graphs;
      4. rebuild 'smiles'/'molfile' from the updated graphs.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        The coreference extraction results with 'symbols', per-atom
        'atom_symbol', 'smiles' and 'molfile' updated from the GPT output.
    """
    # Configure the API key and Azure endpoint.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"  # Azure API-gateway endpoint

    # NOTE(review): shadows the module-level `model` with a fresh instance.
    model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Load the image and encode it as Base64 for the data-URL payload.
    def encode_image(image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schema advertised to GPT for function calling.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },

    ]

    # Prompt text shown to GPT alongside the image.
    with open('./prompt_getmolecular.txt', 'r') as prompt_file:
        prompt = prompt_file.read()
    # NOTE(review): this `messages` value is never used — the API call below
    # rebuilds the same structure inline.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}
            ]
        }
    ]

    # First GPT call: the model may respond with tool-call requests.
    response = client.chat.completions.create(
        model = 'gpt-4o',
        temperature = 0,
        response_format={ 'type': 'json_object' },
        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]},
        ],
        tools = tools)

    # Step 1: map tool names to local callables.
    TOOL_MAP = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    # Step 2: execute every tool call requested by the model.
    tool_calls = response.choices[0].message.tool_calls
    results = []

    # Iterate over each tool call.
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        tool_arguments = tool_call.function.arguments
        tool_call_id = tool_call.id

        # NOTE(review): parsed arguments are not used — the local image_path
        # is passed to the tool instead.
        tool_args = json.loads(tool_arguments)

        if tool_name in TOOL_MAP:
            # Invoke the tool and collect its result.
            tool_result = TOOL_MAP[tool_name](image_path)
        else:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # Record each tool result as a 'tool' message for the follow-up call.
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                f'{tool_name}':(tool_result),
            }),
            'tool_call_id': tool_call_id,
        })


    # Prepare the chat completion payload
    completion_payload = {
        'model': 'gpt-4o',
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]
            },
            response.choices[0].message,
            *results
        ],
    }

    # Generate new response
    response = client.chat.completions.create(
        model=completion_payload["model"],
        messages=completion_payload["messages"],
        response_format={ 'type': 'json_object' },
        temperature=0
    )

    # Parse the corrected output GPT produced (wrapped in a one-element list
    # to mirror the coreference results' page structure).
    gpt_output = [json.loads(response.choices[0].message.content)]

    # Local helper: re-run coreference extraction, keeping all fields this time.
    def get_multi_molecular(image_path: str) -> list:
        '''Returns a list of reactions extracted from the image.'''
        # Open the image file.
        image = Image.open(image_path).convert('RGB')

        # Pass the image to the model.
        coref_results = model.extract_molecule_corefs_from_figures([image])
        return coref_results

    coref_results = get_multi_molecular(image_path)

    def update_symbols_in_atoms(input1, input2):
        """
        Replace each bbox's 'symbols' in input2 with the corrected 'symbols'
        from input1, and propagate them to the per-atom 'atom_symbol' fields.
        Assumes input1 and input2 have the same structure.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])

            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue

            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                # Copy the corrected symbols.
                if 'symbols' in bbox1:
                    bbox2['symbols'] = bbox1['symbols']

                # Update each atom's atom_symbol from the corrected symbols.
                if 'symbols' in bbox1 and 'atoms' in bbox2:
                    symbols = bbox1['symbols']
                    atoms = bbox2.get('atoms', [])

                    # Symbols and atoms must correspond one-to-one.
                    if len(symbols) != len(atoms):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                        continue

                    for atom, symbol in zip(atoms, symbols):
                        atom['atom_symbol'] = symbol

        return input2

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)

    def update_smiles_and_molfile(input_data, conversion_function):
        """
        Regenerate each bbox's 'smiles' and 'molfile' from its (possibly
        corrected) 'symbols', 'coords' and 'edges' via `conversion_function`,
        replacing the old values in place.

        Args:
            input_data: nested structure containing 'bboxes' entries.
            conversion_function: callable taking (coords, symbols, edges) and
                returning (new_smiles, new_molfile, _).

        Returns:
            The updated data structure.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                # Only convert when all required graph fields are present.
                if all(key in bbox for key in ['coords', 'symbols', 'edges']):
                    coords = bbox['coords']
                    symbols = bbox['symbols']
                    edges = bbox['edges']

                    # Rebuild 'smiles' and 'molfile' from the graph.
                    new_smiles, new_molfile, _ = conversion_function(coords, symbols, edges)
                    print(f" Generated 'smiles': {new_smiles}")

                    # Replace the stale values.
                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile

        return input_data

    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)

    return updated_data
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> dict:
    """
    R-group-aware variant of process_reaction_image_with_multiple_products_and_text.

    Identical pipeline (GPT-4o tool call -> local extraction -> symbol merge ->
    SMILES/molfile regeneration) but driven by the R-group-correction prompt
    './prompt_getmolecular_correctR.txt'.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        The coreference extraction results with 'symbols', per-atom
        'atom_symbol', 'smiles' and 'molfile' updated from the GPT output.
    """
    # Configure the API key and Azure endpoint.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"  # Azure API-gateway endpoint
    # NOTE(review): shadows the module-level `model` with a fresh instance.
    model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Load the image and encode it as Base64 for the data-URL payload.
    def encode_image(image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schema advertised to GPT for function calling.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },

    ]

    # Prompt text (R-group correction) shown to GPT alongside the image.
    with open('./prompt_getmolecular_correctR.txt', 'r') as prompt_file:
        prompt = prompt_file.read()
    # NOTE(review): this `messages` value is never used — the API call below
    # rebuilds the same structure inline.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}
            ]
        }
    ]

    # First GPT call: the model may respond with tool-call requests.
    response = client.chat.completions.create(
        model = 'gpt-4o',
        temperature = 0,
        response_format={ 'type': 'json_object' },
        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]},
        ],
        tools = tools)

    # Step 1: map tool names to local callables.
    TOOL_MAP = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    # Step 2: execute every tool call requested by the model.
    tool_calls = response.choices[0].message.tool_calls
    results = []

    # Iterate over each tool call.
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        tool_arguments = tool_call.function.arguments
        tool_call_id = tool_call.id

        # NOTE(review): parsed arguments are not used — the local image_path
        # is passed to the tool instead.
        tool_args = json.loads(tool_arguments)

        if tool_name in TOOL_MAP:
            # Invoke the tool and collect its result.
            tool_result = TOOL_MAP[tool_name](image_path)
        else:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # Record each tool result as a 'tool' message for the follow-up call.
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                f'{tool_name}':(tool_result),
            }),
            'tool_call_id': tool_call_id,
        })


    # Prepare the chat completion payload
    completion_payload = {
        'model': 'gpt-4o',
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]
            },
            response.choices[0].message,
            *results
        ],
    }

    # Generate new response
    response = client.chat.completions.create(
        model=completion_payload["model"],
        messages=completion_payload["messages"],
        response_format={ 'type': 'json_object' },
        temperature=0
    )

    # Parse the corrected output GPT produced (wrapped in a one-element list
    # to mirror the coreference results' page structure).
    gpt_output = [json.loads(response.choices[0].message.content)]

    # Local helper: re-run coreference extraction, keeping all fields this time.
    def get_multi_molecular(image_path: str) -> list:
        '''Returns a list of reactions extracted from the image.'''
        # Open the image file.
        image = Image.open(image_path).convert('RGB')

        # Pass the image to the model.
        coref_results = model.extract_molecule_corefs_from_figures([image])
        return coref_results

    coref_results = get_multi_molecular(image_path)

    def update_symbols_in_atoms(input1, input2):
        """
        Replace each bbox's 'symbols' in input2 with the corrected 'symbols'
        from input1, and propagate them to the per-atom 'atom_symbol' fields.
        Assumes input1 and input2 have the same structure.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])

            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue

            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                # Copy the corrected symbols.
                if 'symbols' in bbox1:
                    bbox2['symbols'] = bbox1['symbols']

                # Update each atom's atom_symbol from the corrected symbols.
                if 'symbols' in bbox1 and 'atoms' in bbox2:
                    symbols = bbox1['symbols']
                    atoms = bbox2.get('atoms', [])

                    # Symbols and atoms must correspond one-to-one.
                    if len(symbols) != len(atoms):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                        continue

                    for atom, symbol in zip(atoms, symbols):
                        atom['atom_symbol'] = symbol

        return input2

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)

    def update_smiles_and_molfile(input_data, conversion_function):
        """
        Regenerate each bbox's 'smiles' and 'molfile' from its (possibly
        corrected) 'symbols', 'coords' and 'edges' via `conversion_function`,
        replacing the old values in place.

        Args:
            input_data: nested structure containing 'bboxes' entries.
            conversion_function: callable taking (coords, symbols, edges) and
                returning (new_smiles, new_molfile, _).

        Returns:
            The updated data structure.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                # Only convert when all required graph fields are present.
                if all(key in bbox for key in ['coords', 'symbols', 'edges']):
                    coords = bbox['coords']
                    symbols = bbox['symbols']
                    edges = bbox['edges']

                    # Rebuild 'smiles' and 'molfile' from the graph.
                    new_smiles, new_molfile, _ = conversion_function(coords, symbols, edges)
                    print(f" Generated 'smiles': {new_smiles}")

                    # Replace the stale values.
                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile

        return input_data

    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)
    print(f"updated_mol_data:{updated_data}")

    return updated_data
get_reaction_agent.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import json
4
+ from chemietoolkit import ChemIEToolkit
5
+ import cv2
6
+ from PIL import Image
7
+ import json
8
+ import sys
9
+ #sys.path.append('./RxnScribe-main/')
10
+ import torch
11
+ from rxnscribe import RxnScribe
12
+ import json
13
+ from molscribe.chemistry import _convert_graph_to_smiles
14
+
15
+ from openai import AzureOpenAI
16
+ import base64
17
+ import numpy as np
18
+ from chemietoolkit import utils
19
+ from PIL import Image
20
+
21
+
22
+
23
+
24
+ ckpt_path = "./pix2seq_reaction_full.ckpt"
25
+ model1 = RxnScribe(ckpt_path, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
26
+ device = torch.device(('cuda' if torch.cuda.is_available() else 'cpu'))
27
+ model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
28
+
29
def get_reaction(image_path: str) -> dict:
    '''
    Run RxnScribe on a reaction image and return a trimmed summary.

    The raw prediction's first entry is reduced to: 'smiles'/'bbox'/'symbols'
    per reactant and product, and 'bbox' plus optional 'smiles'/'text' per
    condition. The summary is also printed before being returned.
    '''
    raw_prediction = model1.predict_image_file(image_path, molscribe=True, ocr=True)
    first = raw_prediction[0]

    structured_output = {}
    for section_key in ['reactants', 'conditions', 'products']:
        if section_key not in first:
            continue
        entries = []
        for item in first[section_key]:
            if section_key == 'conditions':
                # Conditions: bbox always, smiles/text only when present.
                entry = {"bbox": item.get("bbox", [])}
                if "smiles" in item:
                    entry["smiles"] = item.get("smiles", "")
                if "text" in item:
                    entry["text"] = item.get("text", [])
            else:
                # Molecules (reactants/products): smiles, bbox and symbols.
                entry = {
                    "smiles": item.get("smiles", ""),
                    "bbox": item.get("bbox", []),
                    "symbols": item.get("symbols", []),
                }
            entries.append(entry)
        structured_output[section_key] = entries
    print(structured_output)

    return structured_output
61
+
62
+
63
+
64
def get_full_reaction(image_path: str) -> dict:
    '''
    Returns a structured dictionary of reactions extracted from the image,
    including reactants, conditions, and products, with their smiles, text, and bbox.
    '''
    # Raw RxnScribe output, untouched (includes atoms, coords, edges, ...).
    return model1.predict_image_file(image_path, molscribe=True, ocr=True)
def get_reaction_withatoms(image_path: str) -> dict:
    """
    Extract a reaction diagram and let GPT refine the per-atom symbols.

    Runs ``get_reaction`` (RxnScribe) as a GPT tool call, asks GPT to correct
    the atom symbols, then merges the corrected symbols back into the full
    RxnScribe prediction, regenerating SMILES and molfiles from the molecular
    graph.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: The full RxnScribe prediction (wrapped in a one-element list)
        with corrected symbols, atoms, SMILES and molfiles.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If GPT requests an unknown tool.
    """
    import os  # local import: credentials are read from the environment

    # Never hard-code API secrets in source; read them from the environment.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Encode the image as base64 for the vision prompt.
    def encode_image(path: str) -> str:
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schema exposed to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getreaction.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    user_message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url',
             'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
        ],
    }

    # Round 1: let GPT decide which tool(s) to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=[system_message, user_message],
        tools=tools)

    TOOL_MAP = {
        'get_reaction': get_reaction,
    }

    # Execute every requested tool call and collect tool-role messages.
    tool_calls = response.choices[0].message.tool_calls
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: feed tool outputs back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[system_message, user_message,
                  response.choices[0].message, *results],
        response_format={'type': 'json_object'},
        temperature=0)

    gpt_output = json.loads(response.choices[0].message.content)
    print(f"gpt_output1:{gpt_output}")

    # Full RxnScribe prediction (with atoms/coords/edges) to be patched.
    input2 = model1.predict_image_file(image_path, molscribe=True, ocr=True)

    def update_input_with_symbols(input1, pred, conversion_function):
        """Copy GPT-corrected symbols into the raw prediction (matched by
        bbox) and regenerate smiles/molfile from the molecular graph."""
        symbol_mapping = {}
        for key in ['reactants', 'products']:
            for item in input1.get(key, []):
                # bbox serves as the unique join key between both outputs.
                symbol_mapping[tuple(item['bbox'])] = item['symbols']

        for key in ['reactants', 'products']:
            for item in pred.get(key, []):
                bbox = tuple(item['bbox'])
                if bbox not in symbol_mapping:
                    continue
                updated_symbols = symbol_mapping[bbox]
                item['symbols'] = updated_symbols

                # Keep per-atom records consistent with the corrected symbols.
                if 'atoms' in item:
                    atoms = item['atoms']
                    if len(atoms) != len(updated_symbols):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox}")
                    else:
                        for atom, symbol in zip(atoms, updated_symbols):
                            atom['atom_symbol'] = symbol

                # Regenerate smiles/molfile from the corrected graph.
                if 'coords' in item and 'edges' in item:
                    new_smiles, new_molfile, _ = conversion_function(
                        item['coords'], updated_symbols, item['edges'])
                    item['smiles'] = new_smiles
                    item['molfile'] = new_molfile

        return pred

    return [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)]
def get_reaction_withatoms_correctR(image_path: str) -> dict:
    """
    Extract a reaction diagram and let GPT refine per-atom symbols,
    using the R-group-correction prompt (``prompt_getreaction_correctR.txt``).

    Runs ``get_reaction`` (RxnScribe) as a GPT tool call, asks GPT to correct
    the atom symbols, then merges the corrected symbols back into the full
    RxnScribe prediction, regenerating SMILES and molfiles from the molecular
    graph.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: The full RxnScribe prediction (wrapped in a one-element list)
        with corrected symbols, atoms, SMILES and molfiles.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If GPT requests an unknown tool.
    """
    import os  # local import: credentials are read from the environment

    # Never hard-code API secrets in source; read them from the environment.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Encode the image as base64 for the vision prompt.
    def encode_image(path: str) -> str:
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schema exposed to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_getreaction_correctR.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    user_message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url',
             'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
        ],
    }

    # Round 1: let GPT decide which tool(s) to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=[system_message, user_message],
        tools=tools)

    TOOL_MAP = {
        'get_reaction': get_reaction,
    }

    # Execute every requested tool call and collect tool-role messages.
    tool_calls = response.choices[0].message.tool_calls
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: feed tool outputs back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[system_message, user_message,
                  response.choices[0].message, *results],
        response_format={'type': 'json_object'},
        temperature=0)

    gpt_output = json.loads(response.choices[0].message.content)
    print(f"gpt_output1:{gpt_output}")

    # Full RxnScribe prediction (with atoms/coords/edges) to be patched.
    input2 = model1.predict_image_file(image_path, molscribe=True, ocr=True)

    def update_input_with_symbols(input1, pred, conversion_function):
        """Copy GPT-corrected symbols into the raw prediction (matched by
        bbox) and regenerate smiles/molfile from the molecular graph."""
        symbol_mapping = {}
        for key in ['reactants', 'products']:
            for item in input1.get(key, []):
                # bbox serves as the unique join key between both outputs.
                symbol_mapping[tuple(item['bbox'])] = item['symbols']

        for key in ['reactants', 'products']:
            for item in pred.get(key, []):
                bbox = tuple(item['bbox'])
                if bbox not in symbol_mapping:
                    continue
                updated_symbols = symbol_mapping[bbox]
                item['symbols'] = updated_symbols

                # Keep per-atom records consistent with the corrected symbols.
                if 'atoms' in item:
                    atoms = item['atoms']
                    if len(atoms) != len(updated_symbols):
                        print(f"Warning: Mismatched symbols and atoms in bbox {bbox}")
                    else:
                        for atom, symbol in zip(atoms, updated_symbols):
                            atom['atom_symbol'] = symbol

                # Regenerate smiles/molfile from the corrected graph.
                if 'coords' in item and 'edges' in item:
                    new_smiles, new_molfile, _ = conversion_function(
                        item['coords'], updated_symbols, item['edges'])
                    item['smiles'] = new_smiles
                    item['molfile'] = new_molfile

        return pred

    updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)]
    print(f"updated_reaction_data:{updated_data}")

    return updated_data
main.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# --- Imports -----------------------------------------------------------------
# Standard library
import base64
import json
import os
import sys

# Third-party
import cv2
import numpy as np
import torch
from PIL import Image
from openai import AzureOpenAI

# Project
from chemietoolkit import ChemIEToolkit, utils
from rxnscribe import RxnScribe
from get_molecular_agent import process_reaction_image_with_multiple_products_and_text_correctR
from get_reaction_agent import get_reaction_withatoms_correctR

# --- Shared model instances (loaded once at import time) ---------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ChemIEToolkit(device=device)
ckpt_path = "./pix2seq_reaction_full.ckpt"
model1 = RxnScribe(ckpt_path, device=device)

# Optional legacy key file. The functions below read CHEMEAGLE_API_KEY from
# the environment, so a missing file must not break module import.
try:
    with open('api_key.txt', 'r') as api_key_file:
        API_KEY = api_key_file.read()
except FileNotFoundError:
    API_KEY = None
def parse_coref_data_with_fallback(data):
    """Pair SMILES boxes with their text labels via coref links.

    Boxes referenced by a coref pair are emitted as ``{smiles, texts}``;
    any SMILES box left unpaired is emitted with a placeholder message so
    downstream code knows to re-check the image.
    """
    bboxes = data["bboxes"]
    matched = set()
    parsed = []

    # Coref-linked pairs: one entry holds the SMILES, the other the text.
    for first, second in data["corefs"]:
        mol_box = bboxes[first] if "smiles" in bboxes[first] else bboxes[second]
        label_box = bboxes[second] if "text" in bboxes[second] else bboxes[first]

        parsed.append({
            "smiles": mol_box.get("smiles", ""),
            "texts": label_box.get("text", []),
        })
        matched.update((first, second))

    # SMILES boxes that never appeared in any coref pair get a fallback label.
    for position, box in enumerate(bboxes):
        if "smiles" in box and position not in matched:
            parsed.append({
                "smiles": box["smiles"],
                "texts": ["There is no label or failed to detect, please recheck the image again"]
            })

    return parsed
def get_multi_molecular_text_to_correct(image_path: str) -> list:
    '''Returns a list of reactions extracted from the image.'''
    # Run molecule/label coreference extraction on the image.
    image = Image.open(image_path).convert('RGB')
    #coref_results = process_reaction_image_with_multiple_products_and_text_correctR(image_path)
    coref_results = model.extract_molecule_corefs_from_figures([image])

    # Strip heavyweight fields; only smiles/text (and coref indices) matter here.
    unwanted = ["category", "bbox", "molfile", "symbols", 'atoms', "bonds",
                'category_id', 'score', 'corefs', "coords", "edges"]
    for figure in coref_results:
        for box in figure.get("bboxes", []):
            for field in unwanted:
                box.pop(field, None)  # tolerate missing keys

    parsed = parse_coref_data_with_fallback(coref_results[0])

    print(f"coref_results:{json.dumps(parsed)}")
    return json.dumps(parsed)
def get_reaction(image_path: str) -> dict:
    '''
    Returns a structured dictionary of reactions extracted from the image,
    including only reactants, conditions, and products with their smiles, bbox, or text.
    '''
    #raw_prediction = model1.predict_image_file(image_path, molscribe=True, ocr=True)
    # GPT-corrected RxnScribe prediction; only the first region is summarised.
    prediction = get_reaction_withatoms_correctR(image_path)[0]

    structured_output = {}
    for section in ('reactants', 'conditions', 'products'):
        if section not in prediction:
            continue
        entries = []
        for item in prediction[section]:
            if section == 'conditions':
                # Conditions: keep text, bbox and smiles.
                entries.append({
                    "text": item.get("text", []),
                    "bbox": item.get("bbox", []),
                    "smiles": item.get("smiles", []),
                })
            else:
                # Molecules: keep smiles and bbox only.
                entries.append({
                    "smiles": item.get("smiles", ""),
                    "bbox": item.get("bbox", []),
                })
        structured_output[section] = entries

    return structured_output
def process_reaction_image(image_path: str) -> dict:
    """
    Extract and reconcile reaction data from a reaction-scheme image.

    Orchestrates two GPT tool calls (molecule/label coreference extraction
    and reaction extraction), then backs out every concrete reaction
    (R-group substitutions resolved) using chemietoolkit utilities.

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: Keys ``reaction_template``, ``reactions`` (label -> concrete
        reactants/products) and ``original_molecule_list``.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If GPT requests an unknown tool.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"

    # Reuse the module-level ChemIEToolkit instance rather than re-loading
    # the models on every call.
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Encode the image as base64 for the vision prompt.
    def encode_image(path: str) -> str:
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schemas exposed to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct',
                'description': 'Extracts the SMILES string and text coref from molecular images.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'Path to the reaction image.'
                        }
                    },
                    'required': ['image_path'],
                    'additionalProperties': False
                }
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    user_message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url',
             'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
        ],
    }

    # Round 1: let GPT decide which tool(s) to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=[system_message, user_message],
        tools=tools)

    TOOL_MAP = {
        'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,
        'get_reaction': get_reaction
    }

    # Execute every requested tool call and collect tool-role messages.
    tool_calls = response.choices[0].message.tool_calls
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: feed tool outputs back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[system_message, user_message,
                  response.choices[0].message, *results],
        response_format={'type': 'json_object'},
        temperature=0)

    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)

    # Re-run the extractors to obtain the full atom-level structures.
    image_np = np.array(Image.open(image_path).convert('RGB'))
    coref_results = model.extract_molecule_corefs_from_figures([image_np])

    reaction_results = get_reaction_withatoms_correctR(image_path)[0]
    reaction = {
        "reactants": reaction_results.get('reactants', []),
        "conditions": reaction_results.get('conditions', []),
        "products": reaction_results.get('products', [])
    }
    reaction_results = [{"reactions": [reaction]}]
    print(reaction_results)

    def extract_smiles_details(smiles_data, raw_details):
        """Look up atom/bond-level details for each SMILES in the coref output."""
        smiles_details = {}
        for smiles in smiles_data:
            for detail in raw_details:
                for bbox in detail.get('bboxes', []):
                    if bbox.get('smiles') == smiles:
                        smiles_details[smiles] = {
                            'category': bbox.get('category'),
                            'bbox': bbox.get('bbox'),
                            'category_id': bbox.get('category_id'),
                            'score': bbox.get('score'),
                            'molfile': bbox.get('molfile'),
                            'atoms': bbox.get('atoms'),
                            'bonds': bbox.get('bonds'),
                        }
                        break
        return smiles_details

    smiles_details = extract_smiles_details(gpt_output, coref_results)

    # Collect the template-level reactant/product SMILES.
    reactants_array = []
    products = []
    for reactant in reaction_results[0]['reactions'][0]['reactants']:
        if 'smiles' in reactant:
            print(f"SMILES:{reactant['smiles']}")
            reactants_array.append(reactant['smiles'])
    for product in reaction_results[0]['reactions'][0]['products']:
        products.append(product['smiles'])

    # Expand the template into concrete reactions, sorted by label.
    backed_out = utils.backout_without_coref(
        reaction_results, coref_results, gpt_output, smiles_details, model.molscribe)
    backed_out.sort(key=lambda x: x[2])
    extracted_rxns = {
        label: {'reactants': reactants, 'products': products_}
        for reactants, products_, label in backed_out
    }

    toadd = {
        "reaction_template": {
            "reactants": reactants_array,
            "products": products
        },
        "reactions": {key: extracted_rxns[key] for key in sorted(extracted_rxns)},
        "original_molecule_list": gpt_output
    }
    print(toadd)
    return toadd
def ChemEagle(image_path: str) -> dict:
    """
    Top-level entry point: extract structured reaction data from an image.

    Sends the image and prompt to GPT with ``process_reaction_image``
    registered as a tool, executes the tool call(s), and returns GPT's
    consolidated JSON output (reactants, products and reaction template).

    Args:
        image_path (str): Path to the reaction image file.

    Returns:
        dict: The final JSON object produced by GPT.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is
            not set.
        ValueError: If GPT requests an unknown tool.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    azure_endpoint = "https://hkust.azure-api.net"

    # NOTE: no local ChemIEToolkit is constructed here — the previous
    # version loaded the full toolkit and never used it.
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint
    )

    # Encode the image as base64 for the vision prompt.
    def encode_image(path: str) -> str:
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    base64_image = encode_image(image_path)

    # Tool schema exposed to GPT.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'process_reaction_image',
                'description': 'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]

    with open('./prompt_final_simple_version.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    user_message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url',
             'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
        ],
    }

    # Round 1: let GPT decide which tool(s) to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=[system_message, user_message],
        tools=tools)

    TOOL_MAP = {
        'process_reaction_image': process_reaction_image
    }

    # Execute every requested tool call and collect tool-role messages.
    tool_calls = response.choices[0].message.tool_calls
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = TOOL_MAP[tool_name](image_path)
        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: feed tool outputs back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[system_message, user_message,
                  response.choices[0].message, *results],
        response_format={'type': 'json_object'},
        temperature=0)

    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)
    return gpt_output
main_Rgroup_debug.ipynb ADDED
@@ -0,0 +1,993 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "import torch\n",
11
+ "import json\n",
12
+ "from chemietoolkit import ChemIEToolkit\n",
13
+ "import cv2\n",
14
+ "from PIL import Image\n",
15
+ "import json\n",
16
+ "model = ChemIEToolkit(device=torch.device('cpu')) \n",
17
+ "from get_molecular_agent import process_reaction_image_with_multiple_products_and_text\n",
18
+ "from get_reaction_agent import get_reaction_withatoms\n",
19
+ "from get_reaction_agent import get_full_reaction\n",
20
+ "\n",
21
+ "\n",
22
+ "# 定义函数,接受多个图像路径并返回反应列表\n",
23
+ "def get_multi_molecular(image_path: str) -> list:\n",
24
+ " '''Returns a list of reactions extracted from the image.'''\n",
25
+ " # 打开图像文件\n",
26
+ " image = Image.open(image_path).convert('RGB')\n",
27
+ " \n",
28
+ " # 将图像作为输入传递给模型\n",
29
+ " coref_results = model.extract_molecule_corefs_from_figures([image])\n",
30
+ " \n",
31
+ " for item in coref_results:\n",
32
+ " for bbox in item.get(\"bboxes\", []):\n",
33
+ " for key in [\"category\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
34
+ " bbox.pop(key, None) # 安全地移除键\n",
35
+ " print(json.dumps(coref_results))\n",
36
+ " # 返回反应列表,使用 json.dumps 进行格式化\n",
37
+ " \n",
38
+ " return json.dumps(coref_results)\n",
39
+ "\n",
40
+ "def get_multi_molecular_text_to_correct(image_path: str) -> list:\n",
41
+ " '''Returns a list of reactions extracted from the image.'''\n",
42
+ " # 打开图像文件\n",
43
+ " image = Image.open(image_path).convert('RGB')\n",
44
+ " \n",
45
+ " # 将图像作为输入传递给模型\n",
46
+ " coref_results = model.extract_molecule_corefs_from_figures([image])\n",
47
+ " #coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
48
+ " for item in coref_results:\n",
49
+ " for bbox in item.get(\"bboxes\", []):\n",
50
+ " for key in [\"category\", \"bbox\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
51
+ " bbox.pop(key, None) # 安全地移除键\n",
52
+ " print(json.dumps(coref_results))\n",
53
+ " # 返回反应列表,使用 json.dumps 进行格式化\n",
54
+ " \n",
55
+ " return json.dumps(coref_results)\n",
56
+ "\n",
57
+ "def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> list:\n",
58
+ " '''Returns a list of reactions extracted from the image.'''\n",
59
+ " # 打开图像文件\n",
60
+ " image = Image.open(image_path).convert('RGB')\n",
61
+ " \n",
62
+ " # 将图像作为输入传递给模型\n",
63
+ " #coref_results = model.extract_molecule_corefs_from_figures([image])\n",
64
+ " coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
65
+ " for item in coref_results:\n",
66
+ " for bbox in item.get(\"bboxes\", []):\n",
67
+ " for key in [\"molfile\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n",
68
+ " bbox.pop(key, None) # 安全地移除键\n",
69
+ " print(json.dumps(coref_results))\n",
70
+ " # 返回反应列表,使用 json.dumps 进行格式化\n",
71
+ " return json.dumps(coref_results)\n",
72
+ "\n",
73
+ "#get_multi_molecular_text_to_correct('./acs.joc.2c00176 example 1.png')\n",
74
+ "\n",
75
+ "import sys\n",
76
+ "#sys.path.append('./RxnScribe-main/')\n",
77
+ "import torch\n",
78
+ "from rxnscribe import RxnScribe\n",
79
+ "import json\n",
80
+ "\n",
81
+ "ckpt_path = \"./pix2seq_reaction_full.ckpt\"\n",
82
+ "model1 = RxnScribe(ckpt_path, device=torch.device('cpu'))\n",
83
+ "device = torch.device('cpu')\n",
84
+ "\n",
85
+ "def get_reaction(image_path: str) -> dict:\n",
86
+ " '''\n",
87
+ " Returns a structured dictionary of reactions extracted from the image,\n",
88
+ " including reactants, conditions, and products, with their smiles, text, and bbox.\n",
89
+ " '''\n",
90
+ " image_file = image_path\n",
91
+ " #raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True)\n",
92
+ " raw_prediction = get_reaction_withatoms(image_path)\n",
93
+ "\n",
94
+ " # Ensure raw_prediction is treated as a list directly\n",
95
+ " structured_output = {}\n",
96
+ " for section_key in ['reactants', 'conditions', 'products']:\n",
97
+ " if section_key in raw_prediction[0]:\n",
98
+ " structured_output[section_key] = []\n",
99
+ " for item in raw_prediction[0][section_key]:\n",
100
+ " if section_key in ['reactants', 'products']:\n",
101
+ " # Extract smiles and bbox for molecules\n",
102
+ " structured_output[section_key].append({\n",
103
+ " \"smiles\": item.get(\"smiles\", \"\"),\n",
104
+ " \"bbox\": item.get(\"bbox\", [])\n",
105
+ " })\n",
106
+ " elif section_key == 'conditions':\n",
107
+ " # Extract smiles, text, and bbox for conditions\n",
108
+ " condition_data = {\"bbox\": item.get(\"bbox\", [])}\n",
109
+ " if \"smiles\" in item:\n",
110
+ " condition_data[\"smiles\"] = item.get(\"smiles\", \"\")\n",
111
+ " if \"text\" in item:\n",
112
+ " condition_data[\"text\"] = item.get(\"text\", [])\n",
113
+ " structured_output[section_key].append(condition_data)\n",
114
+ " print(f\"structured_output:{structured_output}\")\n",
115
+ "\n",
116
+ " return structured_output\n",
117
+ "\n",
118
+ "\n",
119
+ "\n",
120
+ "\n",
121
+ "import base64\n",
122
+ "import torch\n",
123
+ "import json\n",
124
+ "from PIL import Image\n",
125
+ "import numpy as np\n",
126
+ "from chemietoolkit import ChemIEToolkit, utils\n",
127
+ "from openai import AzureOpenAI\n",
128
+ "\n",
129
+ "def process_reaction_image_with_multiple_products(image_path: str) -> dict:\n",
130
+ " \"\"\"\n",
131
+ " Args:\n",
132
+ " image_path (str): 图像文件路径。\n",
133
+ "\n",
134
+ " Returns:\n",
135
+ " dict: 整理后的反应数据,包括反应物、产物和反应模板。\n",
136
+ " \"\"\"\n",
137
+ " # 配置 API Key 和 Azure Endpoint\n",
138
+ "    api_key = \"YOUR_AZURE_OPENAI_API_KEY\"  # 替换为实际的 API Key\n",
139
+ " azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n",
140
+ " \n",
141
+ "\n",
142
+ " model = ChemIEToolkit(device=torch.device('cpu'))\n",
143
+ " client = AzureOpenAI(\n",
144
+ " api_key=api_key,\n",
145
+ " api_version='2024-06-01',\n",
146
+ " azure_endpoint=azure_endpoint\n",
147
+ " )\n",
148
+ "\n",
149
+ " # 加载图像并编码为 Base64\n",
150
+ " def encode_image(image_path: str):\n",
151
+ " with open(image_path, \"rb\") as image_file:\n",
152
+ " return base64.b64encode(image_file.read()).decode('utf-8')\n",
153
+ "\n",
154
+ " base64_image = encode_image(image_path)\n",
155
+ "\n",
156
+ " # GPT 工具调用配置\n",
157
+ " tools = [\n",
158
+ " {\n",
159
+ " 'type': 'function',\n",
160
+ " 'function': {\n",
161
+ " 'name': 'get_multi_molecular_text_to_correct',\n",
162
+ " 'description': 'Extracts the SMILES string and text coref from molecular images.',\n",
163
+ " 'parameters': {\n",
164
+ " 'type': 'object',\n",
165
+ " 'properties': {\n",
166
+ " 'image_path': {\n",
167
+ " 'type': 'string',\n",
168
+ " 'description': 'Path to the reaction image.'\n",
169
+ " }\n",
170
+ " },\n",
171
+ " 'required': ['image_path'],\n",
172
+ " 'additionalProperties': False\n",
173
+ " }\n",
174
+ " }\n",
175
+ " },\n",
176
+ " {\n",
177
+ " 'type': 'function',\n",
178
+ " 'function': {\n",
179
+ " 'name': 'get_reaction',\n",
180
+ " 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n",
181
+ " 'parameters': {\n",
182
+ " 'type': 'object',\n",
183
+ " 'properties': {\n",
184
+ " 'image_path': {\n",
185
+ " 'type': 'string',\n",
186
+ " 'description': 'The path to the reaction image.',\n",
187
+ " },\n",
188
+ " },\n",
189
+ " 'required': ['image_path'],\n",
190
+ " 'additionalProperties': False,\n",
191
+ " },\n",
192
+ " },\n",
193
+ " },\n",
194
+ " ]\n",
195
+ "\n",
196
+ " # 提供给 GPT 的消息内容\n",
197
+ " with open('./prompt.txt', 'r') as prompt_file:\n",
198
+ " prompt = prompt_file.read()\n",
199
+ " messages = [\n",
200
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
201
+ " {\n",
202
+ " 'role': 'user',\n",
203
+ " 'content': [\n",
204
+ " {'type': 'text', 'text': prompt},\n",
205
+ " {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n",
206
+ " ]\n",
207
+ " }\n",
208
+ " ]\n",
209
+ "\n",
210
+ " # 调用 GPT 接口\n",
211
+ " response = client.chat.completions.create(\n",
212
+ " model = 'gpt-4o',\n",
213
+ " temperature = 0,\n",
214
+ " response_format={ 'type': 'json_object' },\n",
215
+ " messages = [\n",
216
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
217
+ " {\n",
218
+ " 'role': 'user',\n",
219
+ " 'content': [\n",
220
+ " {\n",
221
+ " 'type': 'text',\n",
222
+ " 'text': prompt\n",
223
+ " },\n",
224
+ " {\n",
225
+ " 'type': 'image_url',\n",
226
+ " 'image_url': {\n",
227
+ " 'url': f'data:image/png;base64,{base64_image}'\n",
228
+ " }\n",
229
+ " }\n",
230
+ " ]},\n",
231
+ " ],\n",
232
+ " tools = tools)\n",
233
+ " \n",
234
+ "# Step 1: 工具映射表\n",
235
+ " TOOL_MAP = {\n",
236
+ " 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n",
237
+ " 'get_reaction': get_reaction\n",
238
+ " }\n",
239
+ "\n",
240
+ " # Step 2: 处理多个工具调用\n",
241
+ " tool_calls = response.choices[0].message.tool_calls\n",
242
+ " results = []\n",
243
+ "\n",
244
+ " # 遍历每个工具调用\n",
245
+ " for tool_call in tool_calls:\n",
246
+ " tool_name = tool_call.function.name\n",
247
+ " tool_arguments = tool_call.function.arguments\n",
248
+ " tool_call_id = tool_call.id\n",
249
+ " \n",
250
+ " tool_args = json.loads(tool_arguments)\n",
251
+ " \n",
252
+ " if tool_name in TOOL_MAP:\n",
253
+ " # 调用工具并获取结果\n",
254
+ " tool_result = TOOL_MAP[tool_name](image_path)\n",
255
+ " else:\n",
256
+ " raise ValueError(f\"Unknown tool called: {tool_name}\")\n",
257
+ " \n",
258
+ " # 保存每个工具调用结果\n",
259
+ " results.append({\n",
260
+ " 'role': 'tool',\n",
261
+ " 'content': json.dumps({\n",
262
+ " 'image_path': image_path,\n",
263
+ " f'{tool_name}':(tool_result),\n",
264
+ " }),\n",
265
+ " 'tool_call_id': tool_call_id,\n",
266
+ " })\n",
267
+ "\n",
268
+ "\n",
269
+ "# Prepare the chat completion payload\n",
270
+ " completion_payload = {\n",
271
+ " 'model': 'gpt-4o',\n",
272
+ " 'messages': [\n",
273
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
274
+ " {\n",
275
+ " 'role': 'user',\n",
276
+ " 'content': [\n",
277
+ " {\n",
278
+ " 'type': 'text',\n",
279
+ " 'text': prompt\n",
280
+ " },\n",
281
+ " {\n",
282
+ " 'type': 'image_url',\n",
283
+ " 'image_url': {\n",
284
+ " 'url': f'data:image/png;base64,{base64_image}'\n",
285
+ " }\n",
286
+ " }\n",
287
+ " ]\n",
288
+ " },\n",
289
+ " response.choices[0].message,\n",
290
+ " *results\n",
291
+ " ],\n",
292
+ " }\n",
293
+ "\n",
294
+ "# Generate new response\n",
295
+ " response = client.chat.completions.create(\n",
296
+ " model=completion_payload[\"model\"],\n",
297
+ " messages=completion_payload[\"messages\"],\n",
298
+ " response_format={ 'type': 'json_object' },\n",
299
+ " temperature=0\n",
300
+ " )\n",
301
+ "\n",
302
+ "\n",
303
+ " \n",
304
+ " # 获取 GPT 生成的结果\n",
305
+ " gpt_output = json.loads(response.choices[0].message.content)\n",
306
+ " print(f\"gptout:{gpt_output}\")\n",
307
+ "\n",
308
+ " image = Image.open(image_path).convert('RGB')\n",
309
+ " image_np = np.array(image)\n",
310
+ "\n",
311
+ " #########################\n",
312
+ " #reaction_results = model.extract_reactions_from_figures([image_np])\n",
313
+ " reaction_results = get_reaction_withatoms(image_path)[0]\n",
314
+ " reactions = []\n",
315
+ " \n",
316
+ " # 将 reactants 和 products 转换为 reactions\n",
317
+ " for reactants, conditions, products in zip(reaction_results.get('reactants', []), reaction_results.get('conditions', []), reaction_results.get('products', [])):\n",
318
+ " reaction = {\n",
319
+ " \"reactants\": [reactants],\n",
320
+ " \"conditions\": [conditions],\n",
321
+ " \"products\": [products]\n",
322
+ " }\n",
323
+ " reactions.append(reaction)\n",
324
+ " reaction_results = [{\"reactions\": reactions}]\n",
325
+ " #coref_results = model.extract_molecule_corefs_from_figures([image_np])\n",
326
+ " coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n",
327
+ " ########################\n",
328
+ "\n",
329
+ " # 定义更新工具输出的函数\n",
330
+ " def extract_smiles_details(smiles_data, raw_details):\n",
331
+ " smiles_details = {}\n",
332
+ " for smiles in smiles_data:\n",
333
+ " for detail in raw_details:\n",
334
+ " for bbox in detail.get('bboxes', []):\n",
335
+ " if bbox.get('smiles') == smiles:\n",
336
+ " smiles_details[smiles] = {\n",
337
+ " 'category': bbox.get('category'),\n",
338
+ " 'bbox': bbox.get('bbox'),\n",
339
+ " 'category_id': bbox.get('category_id'),\n",
340
+ " 'score': bbox.get('score'),\n",
341
+ " 'molfile': bbox.get('molfile'),\n",
342
+ " 'atoms': bbox.get('atoms'),\n",
343
+ " 'bonds': bbox.get('bonds')\n",
344
+ " }\n",
345
+ " break\n",
346
+ " return smiles_details\n",
347
+ "\n",
348
+ "# 获取结果\n",
349
+ " smiles_details = extract_smiles_details(gpt_output, coref_results)\n",
350
+ "\n",
351
+ " reactants_array = []\n",
352
+ " products = []\n",
353
+ "\n",
354
+ " for reactant in reaction_results[0]['reactions'][0]['reactants']:\n",
355
+ " #for reactant in reaction_results[0]['reactions'][0]['reactants']:\n",
356
+ " if 'smiles' in reactant:\n",
357
+ " #print(reactant['smiles'])\n",
358
+ " #print(reactant)\n",
359
+ " reactants_array.append(reactant['smiles'])\n",
360
+ "\n",
361
+ " for product in reaction_results[0]['reactions'][0]['products']:\n",
362
+ " #print(product['smiles'])\n",
363
+ " #print(product)\n",
364
+ " products.append(product['smiles'])\n",
365
+ " # 输出结果\n",
366
+ " #import pprint\n",
367
+ " #pprint.pprint(smiles_details)\n",
368
+ "\n",
369
+ " # 整理反应数据\n",
370
+ " try:\n",
371
+ " backed_out = utils.backout_without_coref(reaction_results, coref_results, gpt_output, smiles_details, model.molscribe)\n",
372
+ " backed_out.sort(key=lambda x: x[2])\n",
373
+ " extracted_rxns = {}\n",
374
+ " for reactants, products_, label in backed_out:\n",
375
+ " extracted_rxns[label] = {'reactants': reactants, 'products': products_}\n",
376
+ "\n",
377
+ " toadd = {\n",
378
+ " \"reaction_template\": {\n",
379
+ " \"reactants\": reactants_array,\n",
380
+ " \"products\": products\n",
381
+ " },\n",
382
+ " \"reactions\": extracted_rxns\n",
383
+ " }\n",
384
+ " \n",
385
+ "\n",
386
+ " # 按标签排序\n",
387
+ " sorted_keys = sorted(toadd[\"reactions\"].keys())\n",
388
+ " toadd[\"reactions\"] = {i: toadd[\"reactions\"][i] for i in sorted_keys}\n",
389
+ " original_molecular_list = {'Original molecular list': gpt_output}\n",
390
+ " final_data= toadd.copy()\n",
391
+ " final_data.update(original_molecular_list)\n",
392
+ " except:\n",
393
+ " #pass\n",
394
+ " final_data = {'Original molecular list': gpt_output}\n",
395
+ "\n",
396
+ " print(final_data)\n",
397
+ " return final_data\n",
398
+ " \n",
399
+ "\n",
400
+ "\n",
401
+ "\n"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": [
410
+ "# # image_path = './example/Replace/99.jpg'\n",
411
+ "# # result = process_reaction_image(image_path)\n",
412
+ "# # print(json.dumps(result, indent=4))\n",
413
+ "# image_path = './example/example1/replace/Nesting/283.jpg'\n",
414
+ "# image = Image.open(image_path).convert('RGB')\n",
415
+ "# image_np = np.array(image)\n",
416
+ "\n",
417
+ "# # input1 = get_multi_molecular_text_to_correct_withatoms('./example/example1/replace/Nesting/283.jpg')\n",
418
+ "# # input2 = get_reaction('./example/example1/replace/Nesting/283.jpg')\n",
419
+ "# # print(input1)\n",
420
+ "# # print(input2)\n",
421
+ "# #reaction_results = model.extract_reactions_from_figures([image_np])\n",
422
+ "# coorf = model.extract_molecule_corefs_from_figures([image_np])\n",
423
+ "# print(coorf)\n"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "import base64\n",
433
+ "import torch\n",
434
+ "import json\n",
435
+ "from PIL import Image\n",
436
+ "import numpy as np\n",
437
+ "from openai import AzureOpenAI\n",
438
+ "\n",
439
+ "def process_reaction_image_final(image_path: str) -> dict:\n",
440
+ " \"\"\"\n",
441
+ "\n",
442
+ " Args:\n",
443
+ " image_path (str): 图像文件路径。\n",
444
+ "\n",
445
+ " Returns:\n",
446
+ " dict: 整理后的反应数据,包括反应物、产物和反应模板。\n",
447
+ " \"\"\"\n",
448
+ " # 配置 API Key 和 Azure Endpoint\n",
449
+ "    api_key = \"YOUR_AZURE_OPENAI_API_KEY\"  # 替换为实际的 API Key\n",
450
+ " azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n",
451
+ " \n",
452
+ "\n",
453
+ " model = ChemIEToolkit(device=torch.device('cpu'))\n",
454
+ " client = AzureOpenAI(\n",
455
+ " api_key=api_key,\n",
456
+ " api_version='2024-06-01',\n",
457
+ " azure_endpoint=azure_endpoint\n",
458
+ " )\n",
459
+ "\n",
460
+ " # 加载图像并编码为 Base64\n",
461
+ " def encode_image(image_path: str):\n",
462
+ " with open(image_path, \"rb\") as image_file:\n",
463
+ " return base64.b64encode(image_file.read()).decode('utf-8')\n",
464
+ "\n",
465
+ " base64_image = encode_image(image_path)\n",
466
+ "\n",
467
+ " # GPT 工具调用配置\n",
468
+ " tools = [\n",
469
+ " {\n",
470
+ " 'type': 'function',\n",
471
+ " 'function': {\n",
472
+ " 'name': 'get_multi_molecular_text_to_correct',\n",
473
+ " 'description': 'Extracts the SMILES string and text coref from molecular sub-images from a reaction image and ready for further process.',\n",
474
+ " 'parameters': {\n",
475
+ " 'type': 'object',\n",
476
+ " 'properties': {\n",
477
+ " 'image_path': {\n",
478
+ " 'type': 'string',\n",
479
+ " 'description': 'Path to the reaction image.'\n",
480
+ " }\n",
481
+ " },\n",
482
+ " 'required': ['image_path'],\n",
483
+ " 'additionalProperties': False\n",
484
+ " }\n",
485
+ " }\n",
486
+ " },\n",
487
+ " {\n",
488
+ " 'type': 'function',\n",
489
+ " 'function': {\n",
490
+ " 'name': 'get_reaction',\n",
491
+ " 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n",
492
+ " 'parameters': {\n",
493
+ " 'type': 'object',\n",
494
+ " 'properties': {\n",
495
+ " 'image_path': {\n",
496
+ " 'type': 'string',\n",
497
+ " 'description': 'The path to the reaction image.',\n",
498
+ " },\n",
499
+ " },\n",
500
+ " 'required': ['image_path'],\n",
501
+ " 'additionalProperties': False,\n",
502
+ " },\n",
503
+ " },\n",
504
+ " },\n",
505
+ "\n",
506
+ " \n",
507
+ "\n",
508
+ " {\n",
509
+ " 'type': 'function',\n",
510
+ " 'function': {\n",
511
+ " 'name': 'process_reaction_image_with_multiple_products',\n",
512
+ " 'description': 'process the reaction image that contains a multiple products table. Get a list of reactions from the reaction image, Inculding the reaction template and detailed reaction with detailed R-group information.',\n",
513
+ " 'parameters': {\n",
514
+ " 'type': 'object',\n",
515
+ " 'properties': {\n",
516
+ " 'image_path': {\n",
517
+ " 'type': 'string',\n",
518
+ " 'description': 'The path to the reaction image.',\n",
519
+ " },\n",
520
+ " },\n",
521
+ " 'required': ['image_path'],\n",
522
+ " 'additionalProperties': False,\n",
523
+ " },\n",
524
+ " },\n",
525
+ " },\n",
526
+ "\n",
527
+ " {\n",
528
+ " 'type': 'function',\n",
529
+ " 'function': {\n",
530
+ " 'name': 'get_full_reaction',\n",
531
+ " 'description': 'Get a list of reactions from a reaction image without any tables. A reaction contains data of the reactants, conditions, and products.',\n",
532
+ " 'parameters': {\n",
533
+ " 'type': 'object',\n",
534
+ " 'properties': {\n",
535
+ " 'image_path': {\n",
536
+ " 'type': 'string',\n",
537
+ " 'description': 'The path to the reaction image.',\n",
538
+ " },\n",
539
+ " },\n",
540
+ " 'required': ['image_path'],\n",
541
+ " 'additionalProperties': False,\n",
542
+ " },\n",
543
+ " },\n",
544
+ " },\n",
545
+ "\n",
546
+ " {\n",
547
+ " 'type': 'function',\n",
548
+ " 'function': {\n",
549
+ " 'name': 'get_multi_molecular',\n",
550
+ " 'description': 'Extracts the SMILES string and text coref from a molecular image without any reactions',\n",
551
+ " 'parameters': {\n",
552
+ " 'type': 'object',\n",
553
+ " 'properties': {\n",
554
+ " 'image_path': {\n",
555
+ " 'type': 'string',\n",
556
+ " 'description': 'The path to the reaction image.',\n",
557
+ " },\n",
558
+ " },\n",
559
+ " 'required': ['image_path'],\n",
560
+ " 'additionalProperties': False,\n",
561
+ " },\n",
562
+ " },\n",
563
+ " },\n",
564
+ " ]\n",
565
+ "\n",
566
+ " # 提供给 GPT 的消息内容\n",
567
+ " with open('./prompt_final.txt', 'r') as prompt_file:\n",
568
+ " prompt = prompt_file.read()\n",
569
+ " messages = [\n",
570
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
571
+ " {\n",
572
+ " 'role': 'user',\n",
573
+ " 'content': [\n",
574
+ " {'type': 'text', 'text': prompt},\n",
575
+ " {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n",
576
+ " ]\n",
577
+ " }\n",
578
+ " ]\n",
579
+ "\n",
580
+ " # 调用 GPT 接口\n",
581
+ " response = client.chat.completions.create(\n",
582
+ " model = 'gpt-4o',\n",
583
+ " temperature = 0,\n",
584
+ " response_format={ 'type': 'json_object' },\n",
585
+ " messages = [\n",
586
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
587
+ " {\n",
588
+ " 'role': 'user',\n",
589
+ " 'content': [\n",
590
+ " {\n",
591
+ " 'type': 'text',\n",
592
+ " 'text': prompt\n",
593
+ " },\n",
594
+ " {\n",
595
+ " 'type': 'image_url',\n",
596
+ " 'image_url': {\n",
597
+ " 'url': f'data:image/png;base64,{base64_image}'\n",
598
+ " }\n",
599
+ " }\n",
600
+ " ]},\n",
601
+ " ],\n",
602
+ " tools = tools)\n",
603
+ " \n",
604
+ "# Step 1: 工具映射表\n",
605
+ " TOOL_MAP = {\n",
606
+ " 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n",
607
+ " 'get_reaction': get_reaction,\n",
608
+ " 'process_reaction_image_with_multiple_products':process_reaction_image_with_multiple_products,\n",
609
+ "\n",
610
+ " 'get_full_reaction': get_full_reaction,\n",
611
+ " 'get_multi_molecular':get_multi_molecular,\n",
612
+ " }\n",
613
+ "\n",
614
+ " # Step 2: 处理多个工具调用\n",
615
+ " tool_calls = response.choices[0].message.tool_calls\n",
616
+ " results = []\n",
617
+ "\n",
618
+ " # 遍历每个工具调用\n",
619
+ " for tool_call in tool_calls:\n",
620
+ " tool_name = tool_call.function.name\n",
621
+ " tool_arguments = tool_call.function.arguments\n",
622
+ " tool_call_id = tool_call.id\n",
623
+ " \n",
624
+ " tool_args = json.loads(tool_arguments)\n",
625
+ " \n",
626
+ " if tool_name in TOOL_MAP:\n",
627
+ " # 调用工具并获取结果\n",
628
+ " tool_result = TOOL_MAP[tool_name](image_path)\n",
629
+ " else:\n",
630
+ " raise ValueError(f\"Unknown tool called: {tool_name}\")\n",
631
+ " \n",
632
+ " # 保存每个工具调用结果\n",
633
+ " results.append({\n",
634
+ " 'role': 'tool',\n",
635
+ " 'content': json.dumps({\n",
636
+ " 'image_path': image_path,\n",
637
+ " f'{tool_name}':(tool_result),\n",
638
+ " }),\n",
639
+ " 'tool_call_id': tool_call_id,\n",
640
+ " })\n",
641
+ "\n",
642
+ "\n",
643
+ "# Prepare the chat completion payload\n",
644
+ " completion_payload = {\n",
645
+ " 'model': 'gpt-4o',\n",
646
+ " 'messages': [\n",
647
+ " {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
648
+ " {\n",
649
+ " 'role': 'user',\n",
650
+ " 'content': [\n",
651
+ " {\n",
652
+ " 'type': 'text',\n",
653
+ " 'text': prompt\n",
654
+ " },\n",
655
+ " {\n",
656
+ " 'type': 'image_url',\n",
657
+ " 'image_url': {\n",
658
+ " 'url': f'data:image/png;base64,{base64_image}'\n",
659
+ " }\n",
660
+ " }\n",
661
+ " ]\n",
662
+ " },\n",
663
+ " response.choices[0].message,\n",
664
+ " *results\n",
665
+ " ],\n",
666
+ " }\n",
667
+ "\n",
668
+ "# Generate new response\n",
669
+ " response = client.chat.completions.create(\n",
670
+ " model=completion_payload[\"model\"],\n",
671
+ " messages=completion_payload[\"messages\"],\n",
672
+ " response_format={ 'type': 'json_object' },\n",
673
+ " temperature=0\n",
674
+ " )\n",
675
+ "\n",
676
+ "\n",
677
+ " \n",
678
+ " # 获取 GPT 生成的结果\n",
679
+ " gpt_output = json.loads(response.choices[0].message.content)\n",
680
+ " print(gpt_output)\n",
681
+ " return gpt_output\n",
682
+ "\n",
683
+ "\n",
684
+ "\n"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": null,
690
+ "metadata": {},
691
+ "outputs": [],
692
+ "source": [
693
+ "image_path = './data/bowen-4/2.png'\n",
694
+ "result = process_reaction_image_final(image_path)\n",
695
+ "print(json.dumps(result, indent=4))"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "execution_count": null,
701
+ "metadata": {},
702
+ "outputs": [],
703
+ "source": [
704
+ "# def get_reaction(image_path: str) -> list:\n",
705
+ "# '''Returns a list of reactions extracted from the image.'''\n",
706
+ "# image_file = image_path\n",
707
+ "# return json.dumps(model1.predict_image_file(image_file, molscribe=True, ocr=True))\n",
708
+ "\n",
709
+ "# reaction_output = get_reaction('./pdf/2/2_image_3_1.png')\n",
710
+ "# print(reaction_output)"
711
+ ]
712
+ },
713
+ {
714
+ "cell_type": "code",
715
+ "execution_count": null,
716
+ "metadata": {},
717
+ "outputs": [],
718
+ "source": [
719
+ "import os\n",
720
+ "import fitz # PyMuPDF\n",
721
+ "from core import run_visualheist\n",
722
+ "import base64\n",
723
+ "from openai import AzureOpenAI\n",
724
+ "\n",
725
+ "def full_pdf_extraction_pipeline_with_history(pdf_path,\n",
726
+ " output_dir,\n",
727
+ " api_key,\n",
728
+ " azure_endpoint,\n",
729
+ " model=\"gpt-4o\",\n",
730
+ " model_size=\"large\"):\n",
731
+ " \"\"\"\n",
732
+ " Full pipeline: from PDF to GPT-annotated related text.\n",
733
+ " Extracts markdown + figures + reaction data from a PDF and calls GPT-4o to annotate them.\n",
734
+ "\n",
735
+ " Args:\n",
736
+ " pdf_path (str): Path to input PDF file.\n",
737
+ " output_dir (str): Directory to save results.\n",
738
+ " api_key (str): Azure OpenAI API key.\n",
739
+ " azure_endpoint (str): Azure OpenAI endpoint.\n",
740
+ " model (str): GPT model name (default \"gpt-4o\").\n",
741
+ " model_size (str): VisualHeist model size (\"base\", \"large\", etc).\n",
742
+ "\n",
743
+ " Returns:\n",
744
+ " List of GPT-generated annotated related-text JSONs.\n",
745
+ " \"\"\"\n",
746
+ "\n",
747
+ "\n",
748
+ " os.makedirs(output_dir, exist_ok=True)\n",
749
+ "\n",
750
+ " # Step 1: Extract Markdown text\n",
751
+ " doc = fitz.open(pdf_path)\n",
752
+ " md_text = \"\"\n",
753
+ " for i, page in enumerate(doc, start=1):\n",
754
+ " md_text += f\"\\n\\n## = Page {i} =\\n\\n\" + page.get_text()\n",
755
+ " filename = os.path.splitext(os.path.basename(pdf_path))[0]\n",
756
+ "    md_path = os.path.join(output_dir, f\"{filename}.md\")\n",
757
+ " with open(md_path, \"w\", encoding=\"utf-8\") as f:\n",
758
+ " f.write(md_text.strip())\n",
759
+ " print(f\"[✓] Markdown saved to: {md_path}\")\n",
760
+ "\n",
761
+ " # Step 2: Extract figures using VisualHeist\n",
762
+ " run_visualheist(pdf_dir=pdf_path, model_size=model_size, image_dir=output_dir)\n",
763
+ " print(f\"[✓] Figures extracted to: {output_dir}\")\n",
764
+ "\n",
765
+ " # Step 3: Parse figures to JSON\n",
766
+ " image_data = []\n",
767
+ " known_molecules = []\n",
768
+ "\n",
769
+ " for fname in sorted(os.listdir(output_dir)):\n",
770
+ " if fname.endswith(\".png\"):\n",
771
+ " img_path = os.path.join(output_dir, fname)\n",
772
+ " try:\n",
773
+ " result = process_reaction_image_final(img_path)\n",
774
+ " result[\"image_name\"] = fname\n",
775
+ " image_data.append(result)\n",
776
+ " except Exception as e:\n",
777
+ " print(f\"[!] Failed on {fname}: {e}\")\n",
778
+ " new_mols_json = get_multi_molecular_text_to_correct(img_path)\n",
779
+ " new_mols = json.loads(new_mols_json)\n",
780
+ " for m in new_mols:\n",
781
+ " if m[\"smiles\"] not in {km[\"smiles\"] for km in known_molecules}:\n",
782
+ " known_molecules.append(m)\n",
783
+ "\n",
784
+ "\n",
785
+ "    json_path = os.path.join(output_dir, f\"{filename}_reaction_data.json\")\n",
786
+ " with open(json_path, \"w\", encoding=\"utf-8\") as f:\n",
787
+ " json.dump(image_data, f, indent=2, ensure_ascii=False)\n",
788
+ " print(f\"[✓] Reaction data saved to: {json_path}\")\n",
789
+ "\n",
790
+ " # Step 4: Call Azure GPT-4 for annotation\n",
791
+ " client = AzureOpenAI(\n",
792
+ " api_key=api_key,\n",
793
+ " api_version=\"2024-06-01\",\n",
794
+ " azure_endpoint=azure_endpoint\n",
795
+ " )\n",
796
+ "\n",
797
+ " prompt = \"\"\"\n",
798
+ "You are a text-mining assistant for chemistry papers. Your task is to find the most relevant 1–3 sentences in a research article that describe a given figure or scheme.\n",
799
+ "\n",
800
+ "You will be given:\n",
801
+ "- A block of text extracted from the article (in Markdown format).\n",
802
+ "- The extracted structured data from one image (including its title and list of molecules or reactions).\n",
803
+ "\n",
804
+ "Your task is:\n",
805
+ "1. Match the image with sentences that are most relevant to it. Use clues like the figure/scheme/table number in the title, or molecule/reaction labels (e.g., 1a, 2b, 3).\n",
806
+ "2. Extract up to 3 short sentences that best describe or mention the contents of the image.\n",
807
+ "3. In these sentences, label any molecule or reaction identifiers (like “1a”, “2b”) with their role based on context: [reactant], [product], etc.\n",
808
+ "4. Also label experimental conditions with their roles:\n",
809
+ " - Percent values like “85%” as [yield]\n",
810
+ " - Temperatures like “100 °C” as [temperature]\n",
811
+ " - Time durations like “24 h”, “20 min” as [time]\n",
812
+ "5. Do **not** label chemical position numbers (e.g., in \"3-trifluoromethyl\", \"1,2,4-triazole\").\n",
813
+ "6. Do not repeat any labels. Only label each item once per sentence.\n",
814
+ "\n",
815
+ "Output format:\n",
816
+ "{\n",
817
+ " \"title\": \"<title from image>\",\n",
818
+ " \"related-text\": [\n",
819
+ " \"Sentence with roles like 1a[reactant], 2c[product], 100[temperature] °C.\",\n",
820
+ " ...\n",
821
+ " ]\n",
822
+ "}\n",
823
+ "\"\"\"\n",
824
+ "\n",
825
+ " annotated_results = []\n",
826
+ " for item in image_data:\n",
827
+ " img_path = os.path.join(output_dir, item[\"image_name\"])\n",
828
+ " with open(img_path, \"rb\") as f:\n",
829
+ " base64_image = base64.b64encode(f.read()).decode(\"utf-8\")\n",
830
+ "\n",
831
+ " combined_input = f\"\"\"\n",
832
+ "## Image Structured Data:\n",
833
+ "{json.dumps(item, indent=2)}\n",
834
+ "\n",
835
+ "## Article Text:\n",
836
+ "{md_text}\n",
837
+ "\"\"\"\n",
838
+ "\n",
839
+ " response = client.chat.completions.create(\n",
840
+ " model=model,\n",
841
+ " temperature=0,\n",
842
+ "            response_format={\"type\": \"json_object\"},\n",
843
+ " messages=[\n",
844
+ " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
845
+ " {\n",
846
+ " \"role\": \"user\",\n",
847
+ " \"content\": [\n",
848
+ " {\"type\": \"text\", \"text\": prompt + \"\\n\\n\" + combined_input},\n",
849
+ " {\n",
850
+ " \"type\": \"image_url\",\n",
851
+ " \"image_url\": {\n",
852
+ " \"url\": f\"data:image/png;base64,{base64_image}\"\n",
853
+ " }\n",
854
+ " }\n",
855
+ " ]\n",
856
+ " }\n",
857
+ " ]\n",
858
+ " )\n",
859
+ " annotated_results.append(json.loads(response.choices[0].message.content))\n",
860
+ "\n",
861
+ " # Optionally save output\n",
862
+ "    with open(os.path.join(output_dir, f\"{filename}_annotated_related_text.json\"), \"w\", encoding=\"utf-8\") as f:\n",
863
+ " json.dump(annotated_results, f, indent=2, ensure_ascii=False)\n",
864
+ " print(f\"[✓] Annotated related-text saved.\")\n",
865
+ "\n",
866
+ " return annotated_results"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": null,
872
+ "metadata": {},
873
+ "outputs": [],
874
+ "source": [
875
+ "image_path = './data/example/example1/replace/Nesting/283.jpg'\n",
876
+ "#image_path = './pdf/2/2_image_1_1.png'\n",
877
+ "result = process_reaction_image_final(image_path)\n",
878
+ "print(json.dumps(result, indent=4))"
879
+ ]
880
+ },
881
+ {
882
+ "cell_type": "code",
883
+ "execution_count": null,
884
+ "metadata": {},
885
+ "outputs": [],
886
+ "source": []
887
+ },
888
+ {
889
+ "cell_type": "code",
890
+ "execution_count": null,
891
+ "metadata": {},
892
+ "outputs": [],
893
+ "source": [
894
+ "# import os\n",
895
+ "\n",
896
+ "# image_folder = './example/example1/replace/regular/' # 图片文件夹路径\n",
897
+ "# output_folder = './batches_final_repalce_regular/' # 保存每批结果的文件夹路径\n",
898
+ "# batch_size = 3 # 每批处理文件数量\n",
899
+ "\n",
900
+ "# # 创建保存批次结果的文件夹(如果不存在)\n",
901
+ "# os.makedirs(output_folder, exist_ok=True)\n",
902
+ "\n",
903
+ "# # 获取所有图片文件并按字母顺序排序\n",
904
+ "# all_files = sorted([f for f in os.listdir(image_folder) if f.endswith('.jpg')])\n",
905
+ "\n",
906
+ "# # 获取已完成的批次\n",
907
+ "# completed_batches = [\n",
908
+ "# int(f.split('_')[1].split('.')[0]) for f in os.listdir(output_folder) if f.startswith('batch_') and f.endswith('.json')\n",
909
+ "# ]\n",
910
+ "# completed_batches = sorted(completed_batches) # 确保按顺序排序\n",
911
+ "\n",
912
+ "# # 从指定批次开始(如果有未完成批次)\n",
913
+ "# start_batch = (completed_batches[-1] + 1) if completed_batches else 1\n",
914
+ "\n",
915
+ "# # 将文件分批并从指定批次开始\n",
916
+ "# for batch_index in range((start_batch - 1) * batch_size, len(all_files), batch_size):\n",
917
+ "# batch_files = all_files[batch_index:batch_index + batch_size]\n",
918
+ "# results = []\n",
919
+ "\n",
920
+ "# batch_number = batch_index // batch_size + 1\n",
921
+ "# print(f\"正在按字母顺序处理第 {batch_number} 批,共 {len(batch_files)} 张图片...\")\n",
922
+ " \n",
923
+ "# for file_name in batch_files:\n",
924
+ "# image_path = os.path.join(image_folder, file_name)\n",
925
+ "# print(f\"处理文件 {file_name}...\")\n",
926
+ " \n",
927
+ "# try:\n",
928
+ "# # 处理单个图片\n",
929
+ "# result = process_reaction_image_final(image_path)\n",
930
+ " \n",
931
+ "# # 确保结果是字典\n",
932
+ "# if isinstance(result, dict):\n",
933
+ "# # 添加文件名信息\n",
934
+ "# result_with_filename = {\n",
935
+ "# \"file_name\": file_name,\n",
936
+ "# **result\n",
937
+ "# }\n",
938
+ "# results.append(result_with_filename)\n",
939
+ "# print(result_with_filename)\n",
940
+ "# else:\n",
941
+ "# print(f\"文件 {file_name} 的处理结果不是字典,跳过。\")\n",
942
+ " \n",
943
+ "# except Exception as e:\n",
944
+ "# print(f\"处理文件 {file_name} 时出错: {e}\")\n",
945
+ "\n",
946
+ "# # 保存当前批次结果\n",
947
+ "# batch_output_path = os.path.join(output_folder, f'batch_{batch_number}.json')\n",
948
+ "# with open(batch_output_path, 'w', encoding='utf-8') as json_file:\n",
949
+ "# json.dump(results, json_file, ensure_ascii=False, indent=4)\n",
950
+ "\n",
951
+ "# print(f\"第 {batch_number} 批处理完成,结果保存到 {batch_output_path}\")\n",
952
+ "\n",
953
+ "# print(\"所有批次处理完成!\")\n",
954
+ "\n",
955
+ "\n"
956
+ ]
957
+ },
958
+ {
959
+ "cell_type": "code",
960
+ "execution_count": null,
961
+ "metadata": {},
962
+ "outputs": [],
963
+ "source": [
964
+ "import rdkit\n",
965
+ "from rdkit import Chem\n",
966
+ "from rdkit.Chem import Draw\n",
967
+ "\n",
968
+ "Draw.MolToImage(Chem.MolFromSmiles('[Si](C)(C)OC(c1ccccc1)(c1ccccc1)C1CCC2=NN(Cc3ccccc3)=CN21'))"
969
+ ]
970
+ }
971
+ ],
972
+ "metadata": {
973
+ "kernelspec": {
974
+ "display_name": "openchemie",
975
+ "language": "python",
976
+ "name": "python3"
977
+ },
978
+ "language_info": {
979
+ "codemirror_mode": {
980
+ "name": "ipython",
981
+ "version": 3
982
+ },
983
+ "file_extension": ".py",
984
+ "mimetype": "text/x-python",
985
+ "name": "python",
986
+ "nbconvert_exporter": "python",
987
+ "pygments_lexer": "ipython3",
988
+ "version": "3.10.14"
989
+ }
990
+ },
991
+ "nbformat": 4,
992
+ "nbformat_minor": 2
993
+ }
molscribe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .interface import MolScribe
molscribe/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (189 Bytes). View file
 
molscribe/__pycache__/augment.cpython-310.pyc ADDED
Binary file (8.98 kB). View file