diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ad3146ed8be77347f79bcbb84afc704b87d38288 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/exp.png filter=lfs diff=lfs merge=lfs -text +examples/reaction4.png filter=lfs diff=lfs merge=lfs -text +molscribe/indigo/lib/Linux/x64/libbingo.so filter=lfs diff=lfs merge=lfs -text +molscribe/indigo/lib/Linux/x64/libindigo-renderer.so filter=lfs diff=lfs merge=lfs -text +molscribe/indigo/lib/Linux/x64/libindigo.so filter=lfs diff=lfs merge=lfs -text diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e751a3e13c1d7b28266bd9cab8f2a635e4652124 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +__version__ = "0.1.0" +__author__ = 'Alex Wang' +__credits__ = 'CSAIL' \ No newline at end of file diff --git a/__pycache__/get_molecular_agent.cpython-310.pyc b/__pycache__/get_molecular_agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa67ae20a47cb65e83bb4ebf9f22c2ef53cf1db9 Binary files /dev/null and b/__pycache__/get_molecular_agent.cpython-310.pyc differ diff --git a/__pycache__/get_reaction_agent.cpython-310.pyc b/__pycache__/get_reaction_agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d7e400ba6deee92bdbd2bfb98f192149e3df13 Binary files /dev/null and b/__pycache__/get_reaction_agent.cpython-310.pyc differ diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e5ba2f6aa51b1fd2981abbe4e7b6c15764dbc2d Binary files /dev/null and b/__pycache__/main.cpython-310.pyc differ diff --git a/app.ipynb b/app.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..789b36b67493b3197cd52832d7e502fc9098d9a8 --- /dev/null +++ b/app.ipynb @@ -0,0 +1,295 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "d13d3631", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7866\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import os\n", + "import gradio as gr\n", + "import json\n", + "from main import ChemEagle # 支持 API key 通过环境变量\n", + "from rdkit import Chem\n", + "from rdkit.Chem import rdChemReactions\n", + "from rdkit.Chem import Draw\n", + "from rdkit.Chem import AllChem\n", + "from rdkit.Chem.Draw import rdMolDraw2D\n", + "import cairosvg\n", + "import re\n", + "import torch\n", + "\n", + "example_diagram = \"examples/exp.png\"\n", + "rdkit_image = \"examples/rdkit.png\"\n", + "# 解析 ChemEagle 返回的结构化数据\n", + "def parse_reactions(output_json):\n", + " \"\"\"\n", + " 解析 JSON 格式的反应数据并格式化输出,包含颜色定制。\n", + " \"\"\"\n", + " if isinstance(output_json, str):\n", + " reactions_data = json.loads(output_json)\n", + " elif isinstance(output_json, dict):\n", + " reactions_data = output_json # 转换 JSON 字符串为字典\n", + " reactions_list = reactions_data.get(\"reactions\", [])\n", + " detailed_output = []\n", + " smiles_output = [] \n", + "\n", + " for reaction in reactions_list:\n", + " reaction_id = reaction.get(\"reaction_id\", \"Unknown ID\")\n", + " reactants = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"reactants\", [])]\n", + " conditions = [\n", + " f\"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]\"\n", + " for c in reaction.get(\"condition\", [])\n", + " ]\n", + " conditions_1 = [\n", + " f\"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]\"\n", + " for c in reaction.get(\"condition\", [])\n", + " ]\n", + " products = [f\"{p.get('smiles', 'Unknown')}\" for p in reaction.get(\"products\", [])]\n", + " products_1 = [f\"{p.get('smiles', 'Unknown')}\" for p in reaction.get(\"products\", [])]\n", + " products_2 = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"products\", [])]\n", + " \n", + " additional = reaction.get(\"additional_info\", [])\n", + " additional_str = [str(x) for x in additional if x is not None]\n", + "\n", + " tail = conditions_1 + additional_str\n", + " tail_str = \", \".join(tail)\n", + "\n", + " # 构造反应的完整字符串,定制字体颜色\n", + " full_reaction = f\"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}\"\n", + " full_reaction = f\"{full_reaction}\"\n", + " \n", + " # 详细反应格式化输出\n", + " reaction_output = f\"Reaction: {reaction_id}
\"\n", + " reaction_output += f\" Reactants: {', '.join(reactants)}
\"\n", + " reaction_output += f\" Conditions: {', '.join(conditions)}
\"\n", + " reaction_output += f\" Products: {', '.join(products)}
\"\n", + " reaction_output += f\" additional_info: {', '.join(additional_str)}
\"\n", + " reaction_output += f\" Full Reaction: {full_reaction}
\"\n", + " reaction_output += \"
\"\n", + " detailed_output.append(reaction_output)\n", + "\n", + " reaction_smiles = f\"{'.'.join(reactants)}>>{'.'.join(products_2)}\"\n", + " smiles_output.append(reaction_smiles)\n", + " return detailed_output, smiles_output\n", + "\n", + "\n", + "# 核心处理函数,仅使用 API Key 和图像\n", + "def process_chem_image(api_key, image):\n", + " # 设置 API Key 环境变量,供 ChemEagle 使用\n", + " os.environ[\"CHEMEAGLE_API_KEY\"] = api_key\n", + "\n", + " # 保存上传图片\n", + " image_path = \"temp_image.png\"\n", + " image.save(image_path)\n", + "\n", + " # 调用 ChemEagle(实现内部读取 os.getenv)\n", + " chemeagle_result = ChemEagle(image_path)\n", + "\n", + " # 解析输出\n", + " detailed, smiles = parse_reactions(chemeagle_result)\n", + "\n", + " # 写出 JSON\n", + " json_path = \"output.json\"\n", + " with open(json_path, 'w') as jf:\n", + " json.dump(chemeagle_result, jf, indent=2)\n", + "\n", + " # 返回 HTML、SMILES 合并文本、示意图、JSON 下载\n", + " return \"\\n\\n\".join(detailed), smiles, example_diagram, json_path\n", + "\n", + "# 构建 Gradio 界面\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\n", + " \"\"\"\n", + "

ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction

\n", + " Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.\n", + " \"\"\"\n", + " )\n", + "\n", + " with gr.Row():\n", + " # ———— 左侧:上传 + API Key + 按钮 ————\n", + " with gr.Column(scale=1):\n", + " image_input = gr.Image(type=\"pil\", label=\"Upload a multimodal reaction image\")\n", + " api_key_input = gr.Textbox(\n", + " label=\"Your API-Key\",\n", + " placeholder=\"Type your OpenAI_API_KEY\",\n", + " type=\"password\"\n", + " )\n", + " with gr.Row():\n", + " clear_btn = gr.Button(\"Clear\")\n", + " run_btn = gr.Button(\"Run\", elem_id=\"submit-btn\")\n", + "\n", + " # ———— 中间:解析结果 + 示意图 ————\n", + " with gr.Column(scale=1):\n", + " gr.Markdown(\"### Parsed Reactions\")\n", + " reaction_output = gr.HTML(label=\"Detailed Reaction Output\")\n", + " gr.Markdown(\"### Schematic Diagram\")\n", + " schematic_diagram = gr.Image(value=example_diagram, label=\"示意图\")\n", + "\n", + " # ———— 右侧:SMILES 拆分 & RDKit 渲染 + JSON 下载 ————\n", + " with gr.Column(scale=1):\n", + " gr.Markdown(\"### Machine-readable Output\")\n", + " smiles_output = gr.Textbox(\n", + " label=\"Reaction SMILES\",\n", + " show_copy_button=True,\n", + " interactive=False,\n", + " visible=False\n", + " )\n", + "\n", + " @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑\n", + " def show_split(inputs): # 定义处理和展示分割文本的函数\n", + " if not inputs or isinstance(inputs, str) and inputs.strip() == \"\": # 检查输入文本是否为空\n", + " return gr.Textbox(label= \"SMILES of Reaction i\"), gr.Image(value=rdkit_image, label= \"RDKit Image of Reaction i\",height=100)\n", + " else:\n", + " # 假设输入是逗号分隔的 SMILES 字符串\n", + " smiles_list = inputs.split(\",\")\n", + " smiles_list = [re.sub(r\"^\\s*\\[?'?|'\\]?\\s*$\", \"\", item) for item in smiles_list]\n", + " components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件\n", + " for i, smiles in enumerate(smiles_list): \n", + " smiles.replace('\"', '').replace(\"'\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\n", + " rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)\n", + " \n", + " if rxn:\n", + "\n", + " new_rxn = AllChem.ChemicalReaction()\t\n", + " for mol in rxn.GetReactants():\n", + " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n", + " new_rxn.AddReactantTemplate(mol)\n", + " for mol in rxn.GetProducts():\n", + " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n", + " new_rxn.AddProductTemplate(mol)\n", + "\n", + " rxn = new_rxn\n", + "\n", + " def atom_mapping_remover(rxn):\n", + " for reactant in rxn.GetReactants():\n", + " for atom in reactant.GetAtoms():\n", + " atom.SetAtomMapNum(0)\n", + " for product in rxn.GetProducts():\n", + " for atom in product.GetAtoms():\n", + " atom.SetAtomMapNum(0)\n", + " return rxn\n", + " \n", + " atom_mapping_remover(rxn)\n", + "\n", + " reactant1 = rxn.GetReactantTemplate(0)\n", + " print(reactant1.GetNumBonds)\n", + " reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None\n", + "\n", + " if reactant1.GetNumBonds() > 0:\n", + " bond_length_reference = Draw.MeanBondLength(reactant1)\n", + " elif reactant2 and reactant2.GetNumBonds() > 0:\n", + " bond_length_reference = Draw.MeanBondLength(reactant2)\n", + " else:\n", + " bond_length_reference = 1.0 \n", + "\n", + "\n", + " drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)\n", + " dopts = drawer.drawOptions()\n", + " dopts.padding = 0.1 \n", + " dopts.includeRadicals = True\n", + " Draw.SetACS1996Mode(dopts, bond_length_reference*0.55)\n", + " dopts.bondLineWidth = 1.5\n", + " drawer.DrawReaction(rxn)\n", + " drawer.FinishDrawing()\n", + " svg_content = drawer.GetDrawingText()\n", + " svg_file = f\"reaction{i+1}.svg\"\n", + " with open(svg_file, \"w\") as f:\n", + " f.write(svg_content)\n", + " png_file = f\"reaction_{i+1}.png\"\n", + " cairosvg.svg2png(url=svg_file, write_to=png_file)\n", + "\n", + "\n", + " \n", + " components.append(gr.Textbox(value=smiles,label= f\"SMILES of Reaction {i}\", show_copy_button=True, interactive=False))\n", + " components.append(gr.Image(value=png_file,label= f\"RDKit Image of Reaction {i}\")) \n", + " return components # 返回包含所有 SMILES Textbox 组件的列表\n", + "\n", + " download_json = gr.File(label=\"Download JSON File\")\n", + "\n", + "\n", + " gr.Examples(\n", + " examples=[\n", + " [\"examples/reaction1.jpg\", \"\"],\n", + " [\"examples/reaction2.png\", \"\"],\n", + " [\"examples/reaction3.png\", \"\"],\n", + " [\"examples/reaction4.png\", \"\"],\n", + " \n", + " \n", + " ],\n", + " inputs=[image_input, api_key_input],\n", + " outputs=[reaction_output, smiles_output, schematic_diagram, download_json],\n", + " cache_examples=False,\n", + " examples_per_page=4,\n", + " )\n", + "\n", + " # ———— 清空与运行 绑定 ————\n", + " clear_btn.click(\n", + " lambda: (None, None, None, None, None),\n", + " inputs=[],\n", + " outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]\n", + " )\n", + " run_btn.click(\n", + " process_chem_image,\n", + " inputs=[api_key_input, image_input],\n", + " outputs=[reaction_output, smiles_output, schematic_diagram, download_json]\n", + " )\n", + "\n", + " # 自定义按钮样式\n", + " demo.css = \"\"\"\n", + " #submit-btn {\n", + " background-color: #FF914D;\n", + " color: white;\n", + " font-weight: bold;\n", + " }\n", + " \"\"\"\n", + "\n", + " demo.launch()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openchemie", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5d10c150889dc102607a40cad3c5225e986613 --- /dev/null +++ b/app.py @@ -0,0 +1,239 @@ +import os +import gradio as gr +import json +from main import ChemEagle # 支持 API key 通过环境变量 +from rdkit import Chem +from rdkit.Chem import rdChemReactions +from rdkit.Chem import Draw +from rdkit.Chem import AllChem +from rdkit.Chem.Draw import rdMolDraw2D +import cairosvg +import re +import torch + +example_diagram = "examples/exp.png" +rdkit_image = "examples/rdkit.png" +# 解析 ChemEagle 返回的结构化数据 +def parse_reactions(output_json): + """ + 解析 JSON 格式的反应数据并格式化输出,包含颜色定制。 + """ + if isinstance(output_json, str): + reactions_data = json.loads(output_json) + elif isinstance(output_json, dict): + reactions_data = output_json # 转换 JSON 字符串为字典 + reactions_list = reactions_data.get("reactions", []) + detailed_output = [] + smiles_output = [] + + for reaction in reactions_list: + reaction_id = reaction.get("reaction_id", "Unknown ID") + reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])] + conditions = [ + f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]" + for c in reaction.get("condition", []) + ] + conditions_1 = [ + f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]" + for c in reaction.get("condition", []) + ] + products = [f"{p.get('smiles', 'Unknown')}" for p in reaction.get("products", [])] + products_1 = [f"{p.get('smiles', 'Unknown')}" for p in reaction.get("products", [])] + products_2 = [r.get("smiles", "Unknown") for r in reaction.get("products", [])] + + additional = reaction.get("additional_info", []) + additional_str = [str(x) for x in additional if x is not None] + + tail = conditions_1 + additional_str + tail_str = ", ".join(tail) + + # 构造反应的完整字符串,定制字体颜色 + full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}" + full_reaction = f"{full_reaction}" + + # 详细反应格式化输出 + reaction_output = f"Reaction: {reaction_id}
" + reaction_output += f" Reactants: {', '.join(reactants)}
" + reaction_output += f" Conditions: {', '.join(conditions)}
" + reaction_output += f" Products: {', '.join(products)}
" + reaction_output += f" additional_info: {', '.join(additional_str)}
" + reaction_output += f" Full Reaction: {full_reaction}
" + reaction_output += "
" + detailed_output.append(reaction_output) + + reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products_2)}" + smiles_output.append(reaction_smiles) + return detailed_output, smiles_output + + +# 核心处理函数,仅使用 API Key 和图像 +def process_chem_image(api_key, image): + # 设置 API Key 环境变量,供 ChemEagle 使用 + os.environ["CHEMEAGLE_API_KEY"] = api_key + + # 保存上传图片 + image_path = "temp_image.png" + image.save(image_path) + + # 调用 ChemEagle(实现内部读取 os.getenv) + chemeagle_result = ChemEagle(image_path) + + # 解析输出 + detailed, smiles = parse_reactions(chemeagle_result) + + # 写出 JSON + json_path = "output.json" + with open(json_path, 'w') as jf: + json.dump(chemeagle_result, jf, indent=2) + + # 返回 HTML、SMILES 合并文本、示意图、JSON 下载 + return "\n\n".join(detailed), smiles, example_diagram, json_path + +# 构建 Gradio 界面 +with gr.Blocks() as demo: + gr.Markdown( + """ +

ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction

+ Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information. + """ + ) + + with gr.Row(): + # ———— 左侧:上传 + API Key + 按钮 ———— + with gr.Column(scale=1): + image_input = gr.Image(type="pil", label="Upload a multimodal reaction image") + api_key_input = gr.Textbox( + label="Your API-Key", + placeholder="Type your OpenAI_API_KEY", + type="password" + ) + with gr.Row(): + clear_btn = gr.Button("Clear") + run_btn = gr.Button("Run", elem_id="submit-btn") + + # ———— 中间:解析结果 + 示意图 ———— + with gr.Column(scale=1): + gr.Markdown("### Parsed Reactions") + reaction_output = gr.HTML(label="Detailed Reaction Output") + gr.Markdown("### Schematic Diagram") + schematic_diagram = gr.Image(value=example_diagram, label="示意图") + + # ———— 右侧:SMILES 拆分 & RDKit 渲染 + JSON 下载 ———— + with gr.Column(scale=1): + gr.Markdown("### Machine-readable Output") + smiles_output = gr.Textbox( + label="Reaction SMILES", + show_copy_button=True, + interactive=False, + visible=False + ) + + @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑 + def show_split(inputs): # 定义处理和展示分割文本的函数 + if not inputs or isinstance(inputs, str) and inputs.strip() == "": # 检查输入文本是否为空 + return gr.Textbox(label= "SMILES of Reaction i"), gr.Image(value=rdkit_image, label= "RDKit Image of Reaction i",height=100) + else: + # 假设输入是逗号分隔的 SMILES 字符串 + smiles_list = inputs.split(",") + smiles_list = [re.sub(r"^\s*\[?'?|'\]?\s*$", "", item) for item in smiles_list] + components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件 + for i, smiles in enumerate(smiles_list): + smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "") + rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True) + + if rxn: + + new_rxn = AllChem.ChemicalReaction() + for mol in rxn.GetReactants(): + mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)) + new_rxn.AddReactantTemplate(mol) + for mol in rxn.GetProducts(): + mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)) + new_rxn.AddProductTemplate(mol) + + rxn = new_rxn + + def atom_mapping_remover(rxn): + for reactant in rxn.GetReactants(): + for atom in reactant.GetAtoms(): + atom.SetAtomMapNum(0) + for product in rxn.GetProducts(): + for atom in product.GetAtoms(): + atom.SetAtomMapNum(0) + return rxn + + atom_mapping_remover(rxn) + + reactant1 = rxn.GetReactantTemplate(0) + print(reactant1.GetNumBonds) + reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None + + if reactant1.GetNumBonds() > 0: + bond_length_reference = Draw.MeanBondLength(reactant1) + elif reactant2 and reactant2.GetNumBonds() > 0: + bond_length_reference = Draw.MeanBondLength(reactant2) + else: + bond_length_reference = 1.0 + + + drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1) + dopts = drawer.drawOptions() + dopts.padding = 0.1 + dopts.includeRadicals = True + Draw.SetACS1996Mode(dopts, bond_length_reference*0.55) + dopts.bondLineWidth = 1.5 + drawer.DrawReaction(rxn) + drawer.FinishDrawing() + svg_content = drawer.GetDrawingText() + svg_file = f"reaction{i+1}.svg" + with open(svg_file, "w") as f: + f.write(svg_content) + png_file = f"reaction_{i+1}.png" + cairosvg.svg2png(url=svg_file, write_to=png_file) + + + + components.append(gr.Textbox(value=smiles,label= f"SMILES of Reaction {i}", show_copy_button=True, interactive=False)) + components.append(gr.Image(value=png_file,label= f"RDKit Image of Reaction {i}")) + return components # 返回包含所有 SMILES Textbox 组件的列表 + + download_json = gr.File(label="Download JSON File") + + + gr.Examples( + examples=[ + ["examples/reaction1.jpg", ""], + ["examples/reaction2.png", ""], + ["examples/reaction3.png", ""], + ["examples/reaction4.png", ""], + + + ], + inputs=[image_input, api_key_input], + outputs=[reaction_output, smiles_output, schematic_diagram, download_json], + cache_examples=False, + examples_per_page=4, + ) + + # ———— 清空与运行 绑定 ———— + clear_btn.click( + lambda: (None, None, None, None, None), + inputs=[], + outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json] + ) + run_btn.click( + process_chem_image, + inputs=[api_key_input, image_input], + outputs=[reaction_output, smiles_output, schematic_diagram, download_json] + ) + + # 自定义按钮样式 + demo.css = """ + #submit-btn { + background-color: #FF914D; + color: white; + font-weight: bold; + } + """ + + demo.launch() \ No newline at end of file diff --git a/chemiener/__init__.py b/chemiener/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e85e549b1caf17e89fe52280c4cfca85b4b4b0d3 --- /dev/null +++ b/chemiener/__init__.py @@ -0,0 +1 @@ +from .interface import ChemNER diff --git a/chemiener/__pycache__/__init__.cpython-310.pyc b/chemiener/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a27295c7c9c96f1a00e14f85dc8bc869c09f9642 Binary files /dev/null and b/chemiener/__pycache__/__init__.cpython-310.pyc differ diff --git a/chemiener/__pycache__/__init__.cpython-38.pyc b/chemiener/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9886584b237f8c0a3169225198e2cad76995e65d Binary files /dev/null and b/chemiener/__pycache__/__init__.cpython-38.pyc differ diff --git a/chemiener/__pycache__/dataset.cpython-310.pyc b/chemiener/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1a59538451e048b3e936f92f32333331b95aa84 Binary files /dev/null and b/chemiener/__pycache__/dataset.cpython-310.pyc differ diff --git a/chemiener/__pycache__/dataset.cpython-38.pyc b/chemiener/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d559fc0a2651f8c505016896bddcae423839cac Binary files /dev/null and b/chemiener/__pycache__/dataset.cpython-38.pyc differ diff --git a/chemiener/__pycache__/interface.cpython-310.pyc b/chemiener/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc47309d4d94ad597179fe3345b02259d7df29b4 Binary files /dev/null and b/chemiener/__pycache__/interface.cpython-310.pyc differ diff --git a/chemiener/__pycache__/interface.cpython-38.pyc b/chemiener/__pycache__/interface.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2808d7e57accf5eb9dc8935bc8adc0326a4eda80 Binary files /dev/null and b/chemiener/__pycache__/interface.cpython-38.pyc differ diff --git a/chemiener/__pycache__/model.cpython-310.pyc b/chemiener/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c358e38733d2305da4bfec62c019ca0e69fcab93 Binary files /dev/null and b/chemiener/__pycache__/model.cpython-310.pyc differ diff --git a/chemiener/__pycache__/model.cpython-38.pyc b/chemiener/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee855688c15662e094f84eb1c4b7b6eb0f83cb2e Binary files /dev/null and b/chemiener/__pycache__/model.cpython-38.pyc differ diff --git a/chemiener/__pycache__/utils.cpython-310.pyc b/chemiener/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7dd9093dbbb0a7e1657949397e34833cbb8bed Binary files /dev/null and b/chemiener/__pycache__/utils.cpython-310.pyc differ diff --git a/chemiener/__pycache__/utils.cpython-38.pyc b/chemiener/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aca3842eb5dd8222e4fd52830ca78b4061bffe90 Binary files /dev/null and b/chemiener/__pycache__/utils.cpython-38.pyc differ diff --git a/chemiener/dataset.py b/chemiener/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc1a9619bf618547e715a0a7e49f14d28ecbcb7 --- /dev/null +++ b/chemiener/dataset.py @@ -0,0 +1,172 @@ +import os +import cv2 +import copy +import random +import json +import contextlib +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast + +from .utils import get_class_to_index + + + +class NERDataset(Dataset): + def __init__(self, args, data_file, split='train'): + super().__init__() + self.args = args + if data_file: + data_path = os.path.join(args.data_path, data_file) + with open(data_path) as f: + self.data = json.load(f) + self.name = os.path.basename(data_file).split('.')[0] + self.split = split + self.is_train = (split == 'train') + self.tokenizer = AutoTokenizer.from_pretrained(self.args.roberta_checkpoint, cache_dir = self.args.cache_dir)#BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased') + self.class_to_index = get_class_to_index(self.args.corpus) + self.index_to_class = {self.class_to_index[key]: key for key in self.class_to_index} + + #commment + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + text_tokenized = self.tokenizer(self.data[str(idx)]['text'], truncation = True, max_length = self.args.max_seq_length) + if len(text_tokenized['input_ids']) > 512: print(len(text_tokenized['input_ids'])) + text_tokenized_untruncated = self.tokenizer(self.data[str(idx)]['text']) + return text_tokenized, self.align_labels(text_tokenized, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text'])), self.align_labels(text_tokenized_untruncated, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text'])) + + def align_labels(self, text_tokenized, entities, length): + char_to_class = {} + + for entity in entities: + for span in entities[entity]["span"]: + for i in range(span[0], span[1]): + char_to_class[i] = self.class_to_index[('B-' if i == span[0] else 'I-')+str(entities[entity]["type"])] + + for i in range(length): + if i not in char_to_class: + char_to_class[i] = 0 + + classes = [] + for i in range(len(text_tokenized[0])): + span = text_tokenized.token_to_chars(i) + if span is not None: + classes.append(char_to_class[span.start]) + else: + classes.append(-100) + + return torch.LongTensor(classes) + + def make_html(word_tokens, predictions): + + toreturn = ''' + + + Named Entity Recognition Visualization + + + +

''' + last_label = None + for idx, item in enumerate(word_tokens): + decoded = self.tokenizer.decode(item, skip_special_tokens = True) + if len(decoded)>0: + if idx!=0 and decoded[0]!='#': + toreturn+=" " + label = predictions[idx] + if label == last_label: + + toreturn+=decoded if decoded[0]!="#" else decoded[2:] + else: + if last_label is not None and last_label>0: + toreturn+="" + if label >0: + toreturn+="" + toreturn+=decoded if decoded[0]!="#" else decoded[2:] + if label == 0: + toreturn+=decoded if decoded[0]!="#" else decoded[2:] + if idx==len(word_tokens) and label>0: + toreturn+="" + last_label = label + + toreturn += '''

+ + ''' + return toreturn + + +def get_collate_fn(): + def collate(batch): + + + + sentences = [] + masks = [] + refs = [] + + + for ex in batch: + sentences.append(torch.LongTensor(ex[0]['input_ids'])) + masks.append(torch.Tensor(ex[0]['attention_mask'])) + refs.append(ex[1]) + + sentences = pad_sequence(sentences, batch_first = True, padding_value = 0) + masks = pad_sequence(masks, batch_first = True, padding_value = 0) + refs = pad_sequence(refs, batch_first = True, padding_value = -100) + return sentences, masks, refs + + return collate + + + diff --git a/chemiener/interface.py b/chemiener/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..deb5094e3f9598ebe7e72e4b77ee9ed82b6a8fc8 --- /dev/null +++ b/chemiener/interface.py @@ -0,0 +1,124 @@ +import os +import argparse +from typing import List +import torch +import numpy as np + +from .model import build_model + +from .dataset import NERDataset, get_collate_fn + +from huggingface_hub import hf_hub_download + +from .utils import get_class_to_index + +class ChemNER: + + def __init__(self, model_path, device = None, cache_dir = None): + + self.args = self._get_args(cache_dir) + + states = torch.load(model_path, map_location = torch.device('cpu')) + + if device is None: + device = torch.device('cpu') + + self.device = device + + self.model = self.get_model(self.args, device, states['state_dict']) + + self.collate = get_collate_fn() + + self.dataset = NERDataset(self.args, data_file = None) + + self.class_to_index = get_class_to_index(self.args.corpus) + + self.index_to_class = {self.class_to_index[key]: key for key in self.class_to_index} + + def _get_args(self, cache_dir): + parser = argparse.ArgumentParser() + + parser.add_argument('--roberta_checkpoint', default = 'dmis-lab/biobert-large-cased-v1.1', type=str, help='which roberta config to use') + + parser.add_argument('--corpus', default = "chemdner", type=str, help="which corpus should the tags be from") + + args = parser.parse_args([]) + + args.cache_dir = cache_dir + + return args + + def get_model(self, args, device, model_states): + model = build_model(args) + + def remove_prefix(state_dict): + return {k.replace('model.', ''): v for k, v in state_dict.items()} + + model.load_state_dict(remove_prefix(model_states), strict = False) + + model.to(device) + + model.eval() + + return model + + def predict_strings(self, strings: List, batch_size = 8): + device = self.device + + predictions = [] + + def prepare_output(char_span, prediction): + toreturn = [] + + + i = 0 + + while i < len(char_span): + if prediction[i][0] == 'B': + toreturn.append((prediction[i][2:], [char_span[i].start, char_span[i].end])) + + + + + elif len(toreturn) > 0 and prediction[i][2:] == toreturn[-1][0]: + toreturn[-1] = (toreturn[-1][0], [toreturn[-1][1][0], char_span[i].end]) + + + + i += 1 + + + return toreturn + + output = [] + for idx in range(0, len(strings), batch_size): + batch_strings = strings[idx:idx+batch_size] + batch_strings_tokenized = [(self.dataset.tokenizer(s, truncation = True, max_length = 512), torch.Tensor([-1]), torch.Tensor([-1]) ) for s in batch_strings] + + + sentences, masks, refs = self.collate(batch_strings_tokenized) + + predictions = self.model(input_ids = sentences.to(device), attention_mask = masks.to(device))[0].argmax(dim = 2).to('cpu') + + sentences_list = list(sentences) + + predictions_list = list(predictions) + + + char_spans = [] + for j, sentence in enumerate(sentences_list): + to_add = [batch_strings_tokenized[j][0].token_to_chars(i) for i, word in enumerate(sentence) if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens = True)) > 0 ] + char_spans.append(to_add) + + class_predictions = [[self.index_to_class[int(pred.item())] for (pred, word) in zip(sentence_p, sentence_w) if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens = True)) > 0] for (sentence_p, sentence_w) in zip(predictions_list, sentences_list)] + + + + output+=[prepare_output(char_span, prediction) for char_span, prediction in zip(char_spans, class_predictions)] + + return output + + + + + diff --git a/chemiener/main.py b/chemiener/main.py new file mode 100644 index 0000000000000000000000000000000000000000..85bb16bcb1470872830cc1afa03042dc391b442f --- /dev/null +++ b/chemiener/main.py @@ -0,0 +1,345 @@ +import os +import math +import json +import random +import argparse +import numpy as np + +import time + +import torch +from torch.profiler import profile, record_function, ProfilerActivity +import torch.distributed as dist +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, LightningDataModule +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.strategies.ddp import DDPStrategy +from transformers import get_scheduler +import transformers + +from dataset import NERDataset, get_collate_fn + +from model import build_model + +from utils import get_class_to_index + +import evaluate + +from seqeval.metrics import accuracy_score +from seqeval.metrics import classification_report +from seqeval.metrics import f1_score +from seqeval.scheme import IOB2 + + + +def get_args(notebook=False): + parser = argparse.ArgumentParser() + parser.add_argument('--do_train', action='store_true') + parser.add_argument('--do_valid', action='store_true') + parser.add_argument('--do_test', action='store_true') + parser.add_argument('--fp16', action='store_true') + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpus', type=int, default=1) + parser.add_argument('--print_freq', type=int, default=200) + parser.add_argument('--debug', action='store_true') + parser.add_argument('--no_eval', action='store_true') + + + # Data + parser.add_argument('--data_path', type=str, default=None) + parser.add_argument('--image_path', type=str, default=None) + parser.add_argument('--train_file', type=str, default=None) + parser.add_argument('--valid_file', type=str, default=None) + parser.add_argument('--test_file', type=str, default=None) + parser.add_argument('--vocab_file', type=str, default=None) + parser.add_argument('--format', type=str, default='reaction') + parser.add_argument('--num_workers', type=int, default=8) + parser.add_argument('--input_size', type=int, default=224) + + # Training + parser.add_argument('--epochs', type=int, default=8) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--weight_decay', type=float, default=0.05) + parser.add_argument('--max_grad_norm', type=float, default=5.) + parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine') + parser.add_argument('--warmup_ratio', type=float, default=0) + parser.add_argument('--gradient_accumulation_steps', type=int, default=1) + parser.add_argument('--load_path', type=str, default=None) + parser.add_argument('--load_encoder_only', action='store_true') + parser.add_argument('--train_steps_per_epoch', type=int, default=-1) + parser.add_argument('--eval_per_epoch', type=int, default=10) + parser.add_argument('--save_path', type=str, default='output/') + parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last']) + parser.add_argument('--load_ckpt', type=str, default='best') + parser.add_argument('--resume', action='store_true') + parser.add_argument('--num_train_example', type=int, default=None) + + parser.add_argument('--roberta_checkpoint', type=str, default = "roberta-base") + + parser.add_argument('--corpus', type=str, default = "chemu") + + parser.add_argument('--cache_dir') + + parser.add_argument('--eval_truncated', action='store_true') + + parser.add_argument('--max_seq_length', type = int, default=512) + + args = parser.parse_args([]) if notebook else parser.parse_args() + + + + + + return args + + +class ChemIENERecognizer(LightningModule): + + def __init__(self, args): + super().__init__() + + self.args = args + + self.model = build_model(args) + + self.validation_step_outputs = [] + + def training_step(self, batch, batch_idx): + + + + + sentences, masks, refs,_ = batch + ''' + print("sentences " + str(sentences)) + print("sentence shape " + str(sentences.shape)) + print("masks " + str(masks)) + print("masks shape " + str(masks.shape)) + print("refs " + str(refs)) + print("refs shape " + str(refs.shape)) + ''' + + + + loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs) + self.log('train/loss', loss) + self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False) + return loss + + def validation_step(self, batch, batch_idx): + + sentences, masks, refs, untruncated = batch + ''' + print("sentences " + str(sentences)) + print("sentence shape " + str(sentences.shape)) + print("masks " + str(masks)) + print("masks shape " + str(masks.shape)) + print("refs " + str(refs)) + print("refs shape " + str(refs.shape)) + ''' + + logits = self.model(input_ids = sentences, attention_mask=masks)[0] + ''' + print("logits " + str(logits)) + print(sentences.shape) + print(logits.shape) + print(torch.eq(logits.argmax(dim = 2), refs).sum()) + ''' + self.validation_step_outputs.append((sentences.to("cpu"), logits.argmax(dim = 2).to("cpu"), refs.to('cpu'), untruncated.to("cpu"))) + + + def on_validation_epoch_end(self): + if self.trainer.num_devices > 1: + gathered_outputs = [None for i in range(self.trainer.num_devices)] + dist.all_gather_object(gathered_outputs, self.validation_step_outputs) + gathered_outputs = sum(gathered_outputs, []) + else: + gathered_outputs = self.validation_step_outputs + + sentences = [list(output[0]) for output in gathered_outputs] + + class_to_index = get_class_to_index(self.args.corpus) + + + + index_to_class = {class_to_index[key]: key for key in class_to_index} + predictions = [list(output[1]) for output in gathered_outputs] + labels = [list(output[2]) for output in gathered_outputs] + + untruncateds = [list(output[3]) for output in gathered_outputs] + + untruncateds = [[index_to_class[int(label.item())] for label in sentence if int(label.item()) != -100] for batched in untruncateds for sentence in batched] + + + output = {"sentences": [[int(word.item()) for (word, label) in zip(sentence_w, sentence_l) if label != -100] for (batched_w, batched_l) in zip(sentences, labels) for (sentence_w, sentence_l) in zip(batched_w, batched_l) ], + "predictions": [[index_to_class[int(pred.item())] for (pred, label) in zip(sentence_p, sentence_l) if label!=-100] for (batched_p, batched_l) in zip(predictions, labels) for (sentence_p, sentence_l) in zip(batched_p, batched_l) ], + "groundtruth": [[index_to_class[int(label.item())] for label in sentence if label != -100] for batched in labels for sentence in batched]} + + + #true_labels = [str(label.item()) for batched in labels for sentence in batched for label in sentence if label != -100] + #true_predictions = [str(pred.item()) for (batched_p, batched_l) in zip(predictions, labels) for (sentence_p, sentence_l) in zip(batched_p, batched_l) for (pred, label) in zip(sentence_p, sentence_l) if label!=-100 ] + + + + #print("true_label " + str(len(true_labels)) + " true_predictions "+str(len(true_predictions))) + + + #predictions = utils.merge_predictions(gathered_outputs) + name = self.eval_dataset.name + scores = [0] + + #print(predictions) + #print(predictions[0].shape) + + if self.trainer.is_global_zero: + if not self.args.no_eval: + epoch = self.trainer.current_epoch + + metric = evaluate.load("seqeval", cache_dir = self.args.cache_dir) + + predictions = [ preds + ['O'] * (len(full_groundtruth) - len(preds)) for (preds, full_groundtruth) in zip(output['predictions'], untruncateds)] + all_metrics = metric.compute(predictions = predictions, references = untruncateds) + + #accuracy = sum([1 if p == l else 0 for (p, l) in zip(true_predictions, true_labels)])/len(true_labels) + + #precision = torch.eq(self.eval_dataset.data, predictions.argmax(dim = 1)).sum().float()/self.eval_dataset.data.numel() + #self.print("Epoch: "+str(epoch)+" accuracy: "+str(accuracy)) + if self.args.eval_truncated: + report = classification_report(output['groundtruth'], output['predictions'], mode = 'strict', scheme = IOB2, output_dict = True) + else: + #report = classification_report(predictions, untruncateds, output_dict = True)#, mode = 'strict', scheme = IOB2, output_dict = True) + report = classification_report(predictions, untruncateds, mode = 'strict', scheme = IOB2, output_dict = True) + self.print(report) + #self.print("______________________________________________") + #self.print(report_strict) + scores = [report['micro avg']['f1-score']] + with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f: + json.dump(output, f) + + dist.broadcast_object_list(scores) + + self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True) + self.validation_step_outputs.clear() + + + + self.validation_step_outputs.clear() + + def configure_optimizers(self): + num_training_steps = self.trainer.num_training_steps + + self.print(f'Num training steps: {num_training_steps}') + num_warmup_steps = int(num_training_steps * self.args.warmup_ratio) + optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) + scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps) + return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}} + +class NERDataModule(LightningDataModule): + + def __init__(self, args): + super().__init__() + self.args = args + self.collate_fn = get_collate_fn() + + def prepare_data(self): + args = self.args + if args.do_train: + self.train_dataset = NERDataset(args, args.train_file, split='train') + if self.args.do_train or self.args.do_valid: + self.val_dataset = NERDataset(args, args.valid_file, split='valid') + if self.args.do_test: + self.test_dataset = NERDataset(args, args.test_file, split='valid') + + def print_stats(self): + if self.args.do_train: + print(f'Train dataset: {len(self.train_dataset)}') + if self.args.do_train or self.args.do_valid: + print(f'Valid dataset: {len(self.val_dataset)}') + if self.args.do_test: + print(f'Test dataset: {len(self.test_dataset)}') + + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, + collate_fn=self.collate_fn) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.val_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, + collate_fn=self.collate_fn) + + + def test_dataloader(self): + return torch.utils.data.DataLoader( + self.test_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, + collate_fn=self.collate_fn) + + + +class ModelCheckpoint(pl.callbacks.ModelCheckpoint): + def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str: + filepath = self.format_checkpoint_name(monitor_candidates) + return filepath + +def main(): + transformers.utils.logging.set_verbosity_error() + args = get_args() + + pl.seed_everything(args.seed, workers = True) + + if args.do_train: + model = ChemIENERecognizer(args) + else: + model = ChemIENERecognizer.load_from_checkpoint(os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False, + args=args) + + dm = NERDataModule(args) + dm.prepare_data() + dm.print_stats() + + checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True) + # checkpoint = ModelCheckpoint(monitor=None, save_top_k=0, save_last=True) + lr_monitor = LearningRateMonitor(logging_interval='step') + logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='') + + trainer = pl.Trainer( + strategy=DDPStrategy(find_unused_parameters=False), + accelerator='gpu', + precision = 16, + devices=args.gpus, + logger=logger, + default_root_dir=args.save_path, + callbacks=[checkpoint, lr_monitor], + max_epochs=args.epochs, + gradient_clip_val=args.max_grad_norm, + accumulate_grad_batches=args.gradient_accumulation_steps, + check_val_every_n_epoch=args.eval_per_epoch, + log_every_n_steps=10, + deterministic='warn') + + if args.do_train: + trainer.num_training_steps = math.ceil( + len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs + model.eval_dataset = dm.val_dataset + ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None + trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path) + model = ChemIENERecognizer.load_from_checkpoint(checkpoint.best_model_path, args=args) + + if args.do_valid: + + model.eval_dataset = dm.val_dataset + + trainer.validate(model, datamodule=dm) + + if args.do_test: + + model.test_dataset = dm.test_dataset + + trainer.test(model, datamodule=dm) + + +if __name__ == "__main__": + main() + diff --git a/chemiener/model.py b/chemiener/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6a45a35fbbd4374f5e9ddd30ce2a72c66eef7578 --- /dev/null +++ b/chemiener/model.py @@ -0,0 +1,14 @@ +import torch +from torch import nn + + +from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification + + +def build_model(args): + if args.corpus == "chemu": + return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 21, cache_dir = args.cache_dir, return_dict = False) + elif args.corpus == "chemdner": + return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 17, cache_dir = args.cache_dir, return_dict = False) + elif args.corpus == "chemdner-mol": + return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 3, cache_dir = args.cache_dir, return_dict = False) diff --git a/chemiener/utils.py b/chemiener/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab386cec33ed2ea1f2f54e1dbb77a6518dde199 --- /dev/null +++ b/chemiener/utils.py @@ -0,0 +1,23 @@ + +def merge_predictions(results): + if len(results) == 0: + return [] + predictions = {} + for batch_preds in results: + for idx, preds in enumerate(batch_preds): + predictions[idx] = preds + predictions = [predictions[i] for i in range(len(predictions))] + + return predictions + +def get_class_to_index(corpus): + if corpus == "chemu": + return {'B-EXAMPLE_LABEL': 1, 'B-REACTION_PRODUCT': 2, 'B-STARTING_MATERIAL': 3, 'B-REAGENT_CATALYST': 4, 'B-SOLVENT': 5, 'B-OTHER_COMPOUND': 6, 'B-TIME': 7, 'B-TEMPERATURE': 8, 'B-YIELD_OTHER': 9, 'B-YIELD_PERCENT': 10, 'O': 0, + 'I-EXAMPLE_LABEL': 11, 'I-REACTION_PRODUCT': 12, 'I-STARTING_MATERIAL': 13, 'I-REAGENT_CATALYST': 14, 'I-SOLVENT': 15, 'I-OTHER_COMPOUND': 16, 'I-TIME': 17, 'I-TEMPERATURE': 18, 'I-YIELD_OTHER': 19, 'I-YIELD_PERCENT': 20} + elif corpus == "chemdner": + return {'O': 0, 'B-ABBREVIATION': 1, 'B-FAMILY': 2, 'B-FORMULA': 3, 'B-IDENTIFIER': 4, 'B-MULTIPLE': 5, 'B-SYSTEMATIC': 6, 'B-TRIVIAL': 7, 'B-NO CLASS': 8, 'I-ABBREVIATION': 9, 'I-FAMILY': 10, 'I-FORMULA': 11, 'I-IDENTIFIER': 12, 'I-MULTIPLE': 13, 'I-SYSTEMATIC': 14, 'I-TRIVIAL': 15, 'I-NO CLASS': 16} + elif corpus == "chemdner-mol": + return {'O': 0, 'B-MOL': 1, 'I-MOL': 2} + + + diff --git a/chemietoolkit/__init__.py b/chemietoolkit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdfd65736d19fa860228cc8d8fae8cabd55160e --- /dev/null +++ b/chemietoolkit/__init__.py @@ -0,0 +1 @@ +from .interface import ChemIEToolkit diff --git a/chemietoolkit/__pycache__/__init__.cpython-310.pyc b/chemietoolkit/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897127a31f243b41f8074d88f8f03e0c55e8d986 Binary files /dev/null and b/chemietoolkit/__pycache__/__init__.cpython-310.pyc differ diff --git a/chemietoolkit/__pycache__/__init__.cpython-38.pyc b/chemietoolkit/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..361fdf785ff3f59c4682e9fe2deea392e6ef978e Binary files /dev/null and b/chemietoolkit/__pycache__/__init__.cpython-38.pyc differ diff --git a/chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc b/chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7b968a7d806befa1a61b29e85dd0df0cab8da9a Binary files /dev/null and b/chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc differ diff --git a/chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc b/chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffbbe9cadd2014752bfa908c3bff5b7658a61545 Binary files /dev/null and b/chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc differ diff --git a/chemietoolkit/__pycache__/interface.cpython-310.pyc b/chemietoolkit/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..111fcfdbc703b164bb6c3e2d5e149e0f8d8e7492 Binary files /dev/null and b/chemietoolkit/__pycache__/interface.cpython-310.pyc differ diff --git a/chemietoolkit/__pycache__/interface.cpython-38.pyc b/chemietoolkit/__pycache__/interface.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f28255f11e0e64d438c52eeeee85e44d52faeaf Binary files /dev/null and b/chemietoolkit/__pycache__/interface.cpython-38.pyc differ diff --git a/chemietoolkit/__pycache__/tableextractor.cpython-310.pyc b/chemietoolkit/__pycache__/tableextractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dd4b3edfcbe675a863d7232317b57b5f18c7740 Binary files /dev/null and b/chemietoolkit/__pycache__/tableextractor.cpython-310.pyc differ diff --git a/chemietoolkit/__pycache__/utils.cpython-310.pyc b/chemietoolkit/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b31890bd3f8723ca84c75a6c684a893a790ae287 Binary files /dev/null and b/chemietoolkit/__pycache__/utils.cpython-310.pyc differ diff --git a/chemietoolkit/chemrxnextractor.py b/chemietoolkit/chemrxnextractor.py new file mode 100644 index 0000000000000000000000000000000000000000..76c121b9cd3d5bdeb3dcdd38bc563d0d1c570733 --- /dev/null +++ b/chemietoolkit/chemrxnextractor.py @@ -0,0 +1,107 @@ +from PyPDF2 import PdfReader, PdfWriter +import pdfminer.high_level +import pdfminer.layout +from operator import itemgetter +import os +import pdftotext +from chemrxnextractor import RxnExtractor + +class ChemRxnExtractor(object): + def __init__(self, pdf, pn, model_dir, device): + self.pdf_file = pdf + self.pages = pn + self.model_dir = os.path.join(model_dir, "cre_models_v0.1") # directory saving both prod and role models + use_cuda = (device == 'cuda') + self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=use_cuda) + self.text_file = "info.txt" + self.pdf_text = "" + if len(self.pdf_file) > 0: + with open(self.pdf_file, "rb") as f: + self.pdf_text = pdftotext.PDF(f) + + def set_pdf_file(self, pdf): + self.pdf_file = pdf + with open(self.pdf_file, "rb") as f: + self.pdf_text = pdftotext.PDF(f) + + def set_pages(self, pn): + self.pages = pn + + def set_model_dir(self, md): + self.model_dir = md + self.rxn_extractor = RxnExtractor(self.model_dir) + + def set_text_file(self, tf): + self.text_file = tf + + def extract_reactions_from_text(self): + if self.pages is None: + return self.extract_all(len(self.pdf_text)) + else: + return self.extract_all(self.pages) + + def extract_all(self, pages): + ans = [] + text = self.get_paragraphs_from_pdf(pages) + for data in text: + L = [sent for paragraph in data['paragraphs'] for sent in paragraph] + reactions = self.get_reactions(L, page_number=data['page']) + ans.append(reactions) + return ans + + def get_reactions(self, sents, page_number=None): + rxns = self.rxn_extractor.get_reactions(sents) + + ret = [] + for r in rxns: + if len(r['reactions']) != 0: ret.append(r) + ans = {} + ans.update({'page' : page_number}) + ans.update({'reactions' : ret}) + return ans + + + def get_paragraphs_from_pdf(self, pages): + current_page_num = 1 + if pages is None: + pages = len(self.pdf_text) + result = [] + for page in range(pages): + content = self.pdf_text[page] + pg = content.split("\n\n") + L = [] + for line in pg: + paragraph = [] + if '\x0c' in line: + continue + text = line + text = text.replace("\n", " ") + text = text.replace("- ", "-") + curind = 0 + i = 0 + while i < len(text): + if text[i] == '.': + if i != 0 and not text[i-1].isdigit() or i != len(text) - 1 and (text[i+1] == " " or text[i+1] == "\n"): + paragraph.append(text[curind:i+1] + "\n") + while(i < len(text) and text[i] != " "): + i += 1 + curind = i + 1 + i += 1 + if curind != i: + if text[i - 1] == " ": + if i != 1: + i -= 1 + else: + break + if text[i - 1] != '.': + paragraph.append(text[curind:i] + ".\n") + else: + paragraph.append(text[curind:i] + "\n") + L.append(paragraph) + + result.append({ + 'paragraphs': L, + 'page': current_page_num + }) + current_page_num += 1 + return result diff --git a/chemietoolkit/interface.py b/chemietoolkit/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a44eea700f76147ba9eb2620981458b8bed2910c --- /dev/null +++ b/chemietoolkit/interface.py @@ -0,0 +1,749 @@ +import torch +import re +from functools import lru_cache +import layoutparser as lp +import pdf2image +from PIL import Image +from huggingface_hub import hf_hub_download, snapshot_download +from molscribe import MolScribe +from rxnscribe import RxnScribe, MolDetect +from chemiener import ChemNER +from .chemrxnextractor import ChemRxnExtractor +from .tableextractor import TableExtractor +from .utils import * + +class ChemIEToolkit: + def __init__(self, device=None): + if device is None: + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + else: + self.device = torch.device(device) + + self._molscribe = None + self._rxnscribe = None + self._pdfparser = None + self._moldet = None + self._chemrxnextractor = None + self._chemner = None + self._coref = None + + @property + def molscribe(self): + if self._molscribe is None: + self.init_molscribe() + return self._molscribe + + @lru_cache(maxsize=None) + def init_molscribe(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth") + self._molscribe = MolScribe(ckpt_path, device=self.device) + + + @property + def rxnscribe(self): + if self._rxnscribe is None: + self.init_rxnscribe() + return self._rxnscribe + + @lru_cache(maxsize=None) + def init_rxnscribe(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = hf_hub_download("yujieq/RxnScribe", "pix2seq_reaction_full.ckpt") + self._rxnscribe = RxnScribe(ckpt_path, device=self.device) + + + @property + def pdfparser(self): + if self._pdfparser is None: + self.init_pdfparser() + return self._pdfparser + + @lru_cache(maxsize=None) + def init_pdfparser(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + config_path = "lp://efficientdet/PubLayNet/tf_efficientdet_d1" + self._pdfparser = lp.AutoLayoutModel(config_path, model_path=ckpt_path, device=self.device.type) + + + @property + def moldet(self): + if self._moldet is None: + self.init_moldet() + return self._moldet + + @lru_cache(maxsize=None) + def init_moldet(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "best_hf.ckpt") + self._moldet = MolDetect(ckpt_path, device=self.device) + + + @property + def coref(self): + if self._coref is None: + self.init_coref() + return self._coref + + @lru_cache(maxsize=None) + def init_coref(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "coref_best_hf.ckpt") + self._coref = MolDetect(ckpt_path, device=self.device, coref=True) + + + @property + def chemrxnextractor(self): + if self._chemrxnextractor is None: + self.init_chemrxnextractor() + return self._chemrxnextractor + + @lru_cache(maxsize=None) + def init_chemrxnextractor(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = snapshot_download(repo_id="amberwang/chemrxnextractor-training-modules") + self._chemrxnextractor = ChemRxnExtractor("", None, ckpt_path, self.device.type) + + + @property + def chemner(self): + if self._chemner is None: + self.init_chemner() + return self._chemner + + @lru_cache(maxsize=None) + def init_chemner(self, ckpt_path=None): + """ + Set model to custom checkpoint + Parameters: + ckpt_path: path to checkpoint to use, if None then will use default + """ + if ckpt_path is None: + ckpt_path = hf_hub_download("Ozymandias314/ChemNERckpt", "best.ckpt") + self._chemner = ChemNER(ckpt_path, device=self.device) + + + @property + def tableextractor(self): + return TableExtractor() + + + def extract_figures_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True): + """ + Find and return all figures from a pdf page + Parameters: + pdf: path to pdf + num_pages: process only first `num_pages` pages, if `None` then process all + output_bbox: whether to output bounding boxes for each individual entry of a table + output_image: whether to include PIL image for figures. default is True + Returns: + list of content in the following format + [ + { # first figure + 'title': str, + 'figure': { + 'image': PIL image or None, + 'bbox': list in form [x1, y1, x2, y2], + } + 'table': { + 'bbox': list in form [x1, y1, x2, y2] or empty list, + 'content': { + 'columns': list of column headers, + 'rows': list of list of row content, + } or None + } + 'footnote': str or empty, + 'page': int + } + # more figures + ] + """ + pages = pdf2image.convert_from_path(pdf, last_page=num_pages) + + table_ext = self.tableextractor + table_ext.set_pdf_file(pdf) + table_ext.set_output_image(output_image) + + table_ext.set_output_bbox(output_bbox) + + return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='figures') + + def extract_tables_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True): + """ + Find and return all tables from a pdf page + Parameters: + pdf: path to pdf + num_pages: process only first `num_pages` pages, if `None` then process all + output_bbox: whether to include bboxes for individual entries of the table + output_image: whether to include PIL image for figures. default is True + Returns: + list of content in the following format + [ + { # first table + 'title': str, + 'figure': { + 'image': PIL image or None, + 'bbox': list in form [x1, y1, x2, y2] or empty list, + } + 'table': { + 'bbox': list in form [x1, y1, x2, y2] or empty list, + 'content': { + 'columns': list of column headers, + 'rows': list of list of row content, + } + } + 'footnote': str or empty, + 'page': int + } + # more tables + ] + """ + pages = pdf2image.convert_from_path(pdf, last_page=num_pages) + + table_ext = self.tableextractor + table_ext.set_pdf_file(pdf) + table_ext.set_output_image(output_image) + + table_ext.set_output_bbox(output_bbox) + + return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='tables') + + def extract_molecules_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None): + """ + Get all molecules and their information from a pdf + Parameters: + pdf: path to pdf, or byte file + batch_size: batch size for inference in all models + num_pages: process only first `num_pages` pages, if `None` then process all + Returns: + list of figures and corresponding molecule info in the following format + [ + { # first figure + 'image': ndarray of the figure image, + 'molecules': [ + { # first molecule + 'bbox': tuple in the form (x1, y1, x2, y2), + 'score': float, + 'image': ndarray of cropped molecule image, + 'smiles': str, + 'molfile': str + }, + # more molecules + ], + 'page': int + }, + # more figures + ] + """ + figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True) + images = [figure['figure']['image'] for figure in figures] + results = self.extract_molecules_from_figures(images, batch_size=batch_size) + for figure, result in zip(figures, results): + result['page'] = figure['page'] + return results + + def extract_molecule_bboxes_from_figures(self, figures, batch_size=16): + """ + Return bounding boxes of molecules in images + Parameters: + figures: list of PIL or ndarray images + batch_size: batch size for inference + Returns: + list of results for each figure in the following format + [ + [ # first figure + { # first bounding box + 'category': str, + 'bbox': tuple in the form (x1, y1, x2, y2), + 'category_id': int, + 'score': float + }, + # more bounding boxes + ], + # more figures + ] + """ + figures = [convert_to_pil(figure) for figure in figures] + return self.moldet.predict_images(figures, batch_size=batch_size) + + def extract_molecules_from_figures(self, figures, batch_size=16): + """ + Get all molecules and their information from list of figures + Parameters: + figures: list of PIL or ndarray images + batch_size: batch size for inference + Returns: + list of results for each figure in the following format + [ + { # first figure + 'image': ndarray of the figure image, + 'molecules': [ + { # first molecule + 'bbox': tuple in the form (x1, y1, x2, y2), + 'score': float, + 'image': ndarray of cropped molecule image, + 'smiles': str, + 'molfile': str + }, + # more molecules + ], + }, + # more figures + ] + """ + bboxes = self.extract_molecule_bboxes_from_figures(figures, batch_size=batch_size) + figures = [convert_to_cv2(figure) for figure in figures] + results, cropped_images, refs = clean_bbox_output(figures, bboxes) + mol_info = self.molscribe.predict_images(cropped_images, batch_size=batch_size) + for info, ref in zip(mol_info, refs): + ref.update(info) + return results + + def extract_molecule_corefs_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe = True, ocr = True): + """ + Get all molecule bboxes and corefs from figures in pdf + Parameters: + pdf: path to pdf, or byte file + batch_size: batch size for inference in all models + num_pages: process only first `num_pages` pages, if `None` then process all + Returns: + list of results for each figure in the following format: + [ + { + 'bboxes': [ + { # first bbox + 'category': '[Sup]', + 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168), + 'category_id': 4, + 'score': -0.07593922317028046 + }, + # More bounding boxes + ], + 'corefs': [ + [0, 1], # molecule bbox index, identifier bbox index + [3, 4], + # More coref pairs + ], + 'page': int + }, + # More figures + ] + """ + figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True) + images = [figure['figure']['image'] for figure in figures] + results = self.extract_molecule_corefs_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr) + for figure, result in zip(figures, results): + result['page'] = figure['page'] + return results + + def extract_molecule_corefs_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True): + """ + Get all molecule bboxes and corefs from list of figures + Parameters: + figures: list of PIL or ndarray images + batch_size: batch size for inference + Returns: + list of results for each figure in the following format: + [ + { + 'bboxes': [ + { # first bbox + 'category': '[Sup]', + 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168), + 'category_id': 4, + 'score': -0.07593922317028046 + }, + # More bounding boxes + ], + 'corefs': [ + [0, 1], # molecule bbox index, identifier bbox index + [3, 4], + # More coref pairs + ], + }, + # More figures + ] + """ + figures = [convert_to_pil(figure) for figure in figures] + return self.coref.predict_images(figures, batch_size=batch_size, coref=True, molscribe = molscribe, ocr = ocr) + + def extract_reactions_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True): + """ + Get reaction information from figures in pdf + Parameters: + pdf: path to pdf, or byte file + batch_size: batch size for inference in all models + num_pages: process only first `num_pages` pages, if `None` then process all + molscribe: whether to predict and return smiles and molfile info + ocr: whether to predict and return text of conditions + Returns: + list of figures and corresponding molecule info in the following format + [ + { + 'figure': PIL image + 'reactions': [ + { + 'reactants': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'smiles': str, + 'molfile': str, + }, + # more reactants + ], + 'conditions': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'text': list of str, + }, + # more conditions + ], + 'products': [ + # same structure as reactants + ] + }, + # more reactions + ], + 'page': int + }, + # more figures + ] + """ + figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True) + images = [figure['figure']['image'] for figure in figures] + results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr) + for figure, result in zip(figures, results): + result['page'] = figure['page'] + return results + + def extract_reactions_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True): + """ + Get reaction information from list of figures + Parameters: + figures: list of PIL or ndarray images + batch_size: batch size for inference in all models + molscribe: whether to predict and return smiles and molfile info + ocr: whether to predict and return text of conditions + Returns: + list of figures and corresponding molecule info in the following format + [ + { + 'figure': PIL image + 'reactions': [ + { + 'reactants': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'smiles': str, + 'molfile': str, + }, + # more reactants + ], + 'conditions': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'text': list of str, + }, + # more conditions + ], + 'products': [ + # same structure as reactants + ] + }, + # more reactions + ], + }, + # more figures + ] + + """ + pil_figures = [convert_to_pil(figure) for figure in figures] + results = [] + reactions = self.rxnscribe.predict_images(pil_figures, batch_size=batch_size, molscribe=molscribe, ocr=ocr) + for figure, rxn in zip(figures, reactions): + data = { + 'figure': figure, + 'reactions': rxn, + } + results.append(data) + return results + + def extract_molecules_from_text_in_pdf(self, pdf, batch_size=16, num_pages=None): + """ + Get molecules in text of given pdf + + Parameters: + pdf: path to pdf, or byte file + batch_size: batch size for inference in all models + num_pages: process only first `num_pages` pages, if `None` then process all + Returns: + list of sentences and found molecules in the following format + [ + { + 'molecules': [ + { # first paragraph + 'text': str, + 'labels': [ + (str, int, int), # tuple of label, range start (inclusive), range end (exclusive) + # more labels + ] + }, + # more paragraphs + ] + 'page': int + }, + # more pages + ] + """ + self.chemrxnextractor.set_pdf_file(pdf) + self.chemrxnextractor.set_pages(num_pages) + text = self.chemrxnextractor.get_paragraphs_from_pdf(num_pages) + result = [] + for data in text: + model_inp = [] + for paragraph in data['paragraphs']: + model_inp.append(' '.join(paragraph).replace('\n', '')) + output = self.chemner.predict_strings(model_inp, batch_size=batch_size) + to_add = { + 'molecules': [{ + 'text': t, + 'labels': labels, + } for t, labels in zip(model_inp, output)], + 'page': data['page'] + } + result.append(to_add) + return result + + + def extract_reactions_from_text_in_pdf(self, pdf, num_pages=None): + """ + Get reaction information from text in pdf + Parameters: + pdf: path to pdf + num_pages: process only first `num_pages` pages, if `None` then process all + Returns: + list of pages and corresponding reaction info in the following format + [ + { + 'page': page number + 'reactions': [ + { + 'tokens': list of words in relevant sentence, + 'reactions' : [ + { + # key, value pairs where key is the label and value is a tuple + # or list of tuples of the form (tokens, start index, end index) + # where indices are for the corresponding token list and start and end are inclusive + } + # more reactions + ] + } + # more reactions in other sentences + ] + }, + # more pages + ] + """ + self.chemrxnextractor.set_pdf_file(pdf) + self.chemrxnextractor.set_pages(num_pages) + return self.chemrxnextractor.extract_reactions_from_text() + + def extract_reactions_from_text_in_pdf_combined(self, pdf, num_pages=None): + """ + Get reaction information from text in pdf and combined with corefs from figures + Parameters: + pdf: path to pdf + num_pages: process only first `num_pages` pages, if `None` then process all + Returns: + list of pages and corresponding reaction info in the following format + [ + { + 'page': page number + 'reactions': [ + { + 'tokens': list of words in relevant sentence, + 'reactions' : [ + { + # key, value pairs where key is the label and value is a tuple + # or list of tuples of the form (tokens, start index, end index) + # where indices are for the corresponding token list and start and end are inclusive + } + # more reactions + ] + } + # more reactions in other sentences + ] + }, + # more pages + ] + """ + results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages) + results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages) + return associate_corefs(results, results_coref) + + def extract_reactions_from_figures_and_tables_in_pdf(self, pdf, num_pages=None, batch_size=16, molscribe=True, ocr=True): + """ + Get reaction information from figures and combine with table information in pdf + Parameters: + pdf: path to pdf, or byte file + batch_size: batch size for inference in all models + num_pages: process only first `num_pages` pages, if `None` then process all + molscribe: whether to predict and return smiles and molfile info + ocr: whether to predict and return text of conditions + Returns: + list of figures and corresponding molecule info in the following format + [ + { + 'figure': PIL image + 'reactions': [ + { + 'reactants': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'smiles': str, + 'molfile': str, + }, + # more reactants + ], + 'conditions': [ + { + 'category': str, + 'text': list of str, + }, + # more conditions + ], + 'products': [ + # same structure as reactants + ] + }, + # more reactions + ], + 'page': int + }, + # more figures + ] + """ + figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True) + images = [figure['figure']['image'] for figure in figures] + results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr) + results = process_tables(figures, results, self.molscribe, batch_size=batch_size) + results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages) + results = replace_rgroups_in_figure(figures, results, results_coref, self.molscribe, batch_size=batch_size) + results = expand_reactions_with_backout(results, results_coref, self.molscribe) + return results + + def extract_reactions_from_pdf(self, pdf, num_pages=None, batch_size=16): + """ + Returns: + dictionary of reactions from multimodal sources + { + 'figures': [ + { + 'figure': PIL image + 'reactions': [ + { + 'reactants': [ + { + 'category': str, + 'bbox': tuple (x1,x2,y1,y2), + 'category_id': int, + 'smiles': str, + 'molfile': str, + }, + # more reactants + ], + 'conditions': [ + { + 'category': str, + 'text': list of str, + }, + # more conditions + ], + 'products': [ + # same structure as reactants + ] + }, + # more reactions + ], + 'page': int + }, + # more figures + ] + 'text': [ + { + 'page': page number + 'reactions': [ + { + 'tokens': list of words in relevant sentence, + 'reactions' : [ + { + # key, value pairs where key is the label and value is a tuple + # or list of tuples of the form (tokens, start index, end index) + # where indices are for the corresponding token list and start and end are inclusive + } + # more reactions + ] + } + # more reactions in other sentences + ] + }, + # more pages + ] + } + + """ + figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True) + images = [figure['figure']['image'] for figure in figures] + results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=True, ocr=True) + table_expanded_results = process_tables(figures, results, self.molscribe, batch_size=batch_size) + text_results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages) + results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages) + figure_results = replace_rgroups_in_figure(figures, table_expanded_results, results_coref, self.molscribe, batch_size=batch_size) + table_expanded_results = expand_reactions_with_backout(figure_results, results_coref, self.molscribe) + coref_expanded_results = associate_corefs(text_results, results_coref) + return { + 'figures': table_expanded_results, + 'text': coref_expanded_results, + } + +if __name__=="__main__": + model = OpenChemIE() diff --git a/chemietoolkit/tableextractor.py b/chemietoolkit/tableextractor.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce2bf1a2da5c721c87b1cd388fb085c2a86135a --- /dev/null +++ b/chemietoolkit/tableextractor.py @@ -0,0 +1,340 @@ +import pdf2image +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +import layoutparser as lp +import cv2 + +from PyPDF2 import PdfReader, PdfWriter +import pandas as pd + +import pdfminer.high_level +import pdfminer.layout +from operator import itemgetter + +# inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox +class TableExtractor(object): + def __init__(self, output_bbox=True): + self.pdf_file = "" + self.page = "" + self.image_dpi = 200 + self.pdf_dpi = 72 + self.output_bbox = output_bbox + self.blocks = {} + self.title_y = 0 + self.column_header_y = 0 + self.model = None + self.img = None + self.output_image = True + self.tagging = { + 'substance': ['compound', 'salt', 'base', 'solvent', 'CBr4', 'collidine', 'InX3', 'substrate', 'ligand', 'PPh3', 'PdL2', 'Cu', 'compd', 'reagent', 'reagant', 'acid', 'aldehyde', 'amine', 'Ln', 'H2O', 'enzyme', 'cofactor', 'oxidant', 'Pt(COD)Cl2', 'CuBr2', 'additive'], + 'ratio': [':'], + 'measurement': ['μM', 'nM', 'IC50', 'CI', 'excitation', 'emission', 'Φ', 'φ', 'shift', 'ee', 'ΔG', 'ΔH', 'TΔS', 'Δ', 'distance', 'trajectory', 'V', 'eV'], + 'temperature': ['temp', 'temperature', 'T', '°C'], + 'time': ['time', 't(', 't ('], + 'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'], + 'alkyl group': ['R', 'Ar', 'X', 'Y'], + 'solvent': ['solvent'], + 'counter': ['entry', 'no.'], + 'catalyst': ['catalyst', 'cat.'], + 'conditions': ['condition'], + 'reactant': ['reactant'], + } + + def set_output_image(self, oi): + self.output_image = oi + + def set_pdf_file(self, pdf): + self.pdf_file = pdf + + def set_page_num(self, pn): + self.page = pn + + def set_output_bbox(self, ob): + self.output_bbox = ob + + def run_model(self, page_info): + #img = np.asarray(pdf2image.convert_from_path(self.pdf_file, dpi=self.image_dpi)[self.page]) + + #model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config', extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}) + + img = np.asarray(page_info) + self.img = img + + layout_result = self.model.detect(img) + + text_blocks = lp.Layout([b for b in layout_result if b.type == 'Text']) + title_blocks = lp.Layout([b for b in layout_result if b.type == 'Title']) + list_blocks = lp.Layout([b for b in layout_result if b.type == 'List']) + table_blocks = lp.Layout([b for b in layout_result if b.type == 'Table']) + figure_blocks = lp.Layout([b for b in layout_result if b.type == 'Figure']) + + self.blocks.update({'text': text_blocks}) + self.blocks.update({'title': title_blocks}) + self.blocks.update({'list': list_blocks}) + self.blocks.update({'table': table_blocks}) + self.blocks.update({'figure': figure_blocks}) + + # type is what coordinates you want to get. it comes in text, title, list, table, and figure + def convert_to_pdf_coordinates(self, type): + # scale coordinates + + blocks = self.blocks[type] + coordinates = [blocks[a].scale(self.pdf_dpi/self.image_dpi) for a in range(len(blocks))] + + reader = PdfReader(self.pdf_file) + + writer = PdfWriter() + p = reader.pages[self.page] + a = p.mediabox.upper_left + new_coords = [] + for new_block in coordinates: + new_coords.append((new_block.block.x_1, pd.to_numeric(a[1]) - new_block.block.y_2, new_block.block.x_2, pd.to_numeric(a[1]) - new_block.block.y_1)) + + return new_coords + # output: list of bounding boxes for tables but in pdf coordinates + + # input: new_coords is singular table bounding box in pdf coordinates + def extract_singular_table(self, new_coords): + for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]): + elements = [] + for element in page_layout: + if isinstance(element, pdfminer.layout.LTTextBox): + for e in element._objs: + temp = e.bbox + if temp[0] > min(new_coords[0], new_coords[2]) and temp[0] < max(new_coords[0], new_coords[2]) and temp[1] > min(new_coords[1], new_coords[3]) and temp[1] < max(new_coords[1], new_coords[3]) and temp[2] > min(new_coords[0], new_coords[2]) and temp[2] < max(new_coords[0], new_coords[2]) and temp[3] > min(new_coords[1], new_coords[3]) and temp[3] < max(new_coords[1], new_coords[3]) and isinstance(e, pdfminer.layout.LTTextLineHorizontal): + elements.append([e.bbox[0], e.bbox[1], e.bbox[2], e.bbox[3], e.get_text()]) + + elements = sorted(elements, key=itemgetter(0)) + w = sorted(elements, key=itemgetter(3), reverse=True) + if len(w) <= 1: + continue + + ret = {} + i = 1 + g = [w[0]] + + while i < len(w) and w[i][3] > w[i-1][1]: + g.append(w[i]) + i += 1 + g = sorted(g, key=itemgetter(0)) + # check for overlaps + for a in range(len(g)-1, 0, -1): + if g[a][0] < g[a-1][2]: + g[a-1][0] = min(g[a][0], g[a-1][0]) + g[a-1][1] = min(g[a][1], g[a-1][1]) + g[a-1][2] = max(g[a][2], g[a-1][2]) + g[a-1][3] = max(g[a][3], g[a-1][3]) + g[a-1][4] = g[a-1][4].strip() + " " + g[a][4] + g.pop(a) + + + ret.update({"columns":[]}) + for t in g: + temp_bbox = t[:4] + + column_text = t[4].strip() + tag = 'unknown' + tagged = False + for key in self.tagging.keys(): + for word in self.tagging[key]: + if word in column_text: + tag = key + tagged = True + break + if tagged: + break + + if self.output_bbox: + ret["columns"].append({'text':column_text,'tag': tag, 'bbox':temp_bbox}) + else: + ret["columns"].append({'text':column_text,'tag': tag}) + self.column_header_y = max(t[1], t[3]) + ret.update({"rows":[]}) + + g.insert(0, [0, 0, new_coords[0], 0, '']) + g.append([new_coords[2], 0, 0, 0, '']) + while i < len(w): + group = [w[i]] + i += 1 + while i < len(w) and w[i][3] > w[i-1][1]: + group.append(w[i]) + i += 1 + group = sorted(group, key=itemgetter(0)) + + for a in range(len(group)-1, 0, -1): + if group[a][0] < group[a-1][2]: + group[a-1][0] = min(group[a][0], group[a-1][0]) + group[a-1][1] = min(group[a][1], group[a-1][1]) + group[a-1][2] = max(group[a][2], group[a-1][2]) + group[a-1][3] = max(group[a][3], group[a-1][3]) + group[a-1][4] = group[a-1][4].strip() + " " + group[a][4] + group.pop(a) + + a = 1 + while a < len(g) - 1: + if a > len(group): + group.append([0, 0, 0, 0, '\n']) + a += 1 + continue + if group[a-1][0] >= g[a-1][2] and group[a-1][2] <= g[a+1][0]: + pass + """ + if a < len(group) and group[a][0] >= g[a-1][2] and group[a][2] <= g[a+1][0]: + g.insert(1, [g[0][2], 0, group[a-1][2], 0, '']) + #ret["columns"].insert(0, '') + else: + a += 1 + continue + """ + else: group.insert(a-1, [0, 0, 0, 0, '\n']) + a += 1 + + + added_row = [] + for t in group: + temp_bbox = t[:4] + if self.output_bbox: + added_row.append({'text':t[4].strip(), 'bbox':temp_bbox}) + else: + added_row.append(t[4].strip()) + ret["rows"].append(added_row) + if ret["rows"] and len(ret["rows"][0]) != len(ret["columns"]): + ret["columns"] = ret["rows"][0] + ret["rows"] = ret["rows"][1:] + for col in ret['columns']: + tag = 'unknown' + tagged = False + for key in self.tagging.keys(): + for word in self.tagging[key]: + if word in col['text']: + tag = key + tagged = True + break + if tagged: + break + col['tag'] = tag + + return ret + + def get_title_and_footnotes(self, tb_coords): + + for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]): + title = (0, 0, 0, 0, '') + footnote = (0, 0, 0, 0, '') + title_gap = 30 + footnote_gap = 30 + for element in page_layout: + if isinstance(element, pdfminer.layout.LTTextBoxHorizontal): + if (element.bbox[0] >= tb_coords[0] and element.bbox[0] <= tb_coords[2]) or (element.bbox[2] >= tb_coords[0] and element.bbox[2] <= tb_coords[2]) or (tb_coords[0] >= element.bbox[0] and tb_coords[0] <= element.bbox[2]) or (tb_coords[2] >= element.bbox[0] and tb_coords[2] <= element.bbox[2]): + #print(element) + if 'Table' in element.get_text(): + if abs(element.bbox[1] - tb_coords[3]) < title_gap: + title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Table'):].replace('\n', ' '),) + title_gap = abs(element.bbox[1] - tb_coords[3]) + if 'Scheme' in element.get_text(): + if abs(element.bbox[1] - tb_coords[3]) < title_gap: + title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Scheme'):].replace('\n', ' '),) + title_gap = abs(element.bbox[1] - tb_coords[3]) + if element.bbox[1] >= tb_coords[1] and element.bbox[3] <= tb_coords[3]: continue + #print(element) + temp = ['aA', 'aB', 'aC', 'aD', 'aE', 'aF', 'aG', 'aH', 'aI', 'aJ', 'aK', 'aL', 'aM', 'aN', 'aO', 'aP', 'aQ', 'aR', 'aS', 'aT', 'aU', 'aV', 'aW', 'aX', 'aY', 'aZ', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'a0'] + for segment in temp: + if segment in element.get_text(): + if abs(element.bbox[3] - tb_coords[1]) < footnote_gap: + footnote = tuple(element.bbox) + (element.get_text()[element.get_text().index(segment):].replace('\n', ' '),) + footnote_gap = abs(element.bbox[3] - tb_coords[1]) + break + self.title_y = min(title[1], title[3]) + if self.output_bbox: + return ({'text': title[4], 'bbox': list(title[:4])}, {'text': footnote[4], 'bbox': list(footnote[:4])}) + else: + return (title[4], footnote[4]) + + def extract_table_information(self): + #self.run_model(page_info) # changed + table_coordinates = self.blocks['table'] #should return a list of layout objects + table_coordinates_in_pdf = self.convert_to_pdf_coordinates('table') #should return a list of lists + + ans = [] + i = 0 + for coordinate in table_coordinates_in_pdf: + ret = {} + pad = 20 + coordinate = [coordinate[0] - pad, coordinate[1], coordinate[2] + pad, coordinate[3]] + ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]] + + table_results = self.extract_singular_table(coordinate) + tf = self.get_title_and_footnotes(coordinate) + figure = Image.fromarray(table_coordinates[i].crop_image(self.img)) + ret.update({'title': tf[0]}) + ret.update({'figure': { + 'image': None, + 'bbox': [] + }}) + if self.output_image: + ret['figure']['image'] = figure + ret.update({'table': {'bbox': list(coordinate), 'content': table_results}}) + ret.update({'footnote': tf[1]}) + if abs(self.title_y - self.column_header_y) > 50: + ret['figure']['bbox'] = list(coordinate) + + ret.update({'page':self.page}) + + ans.append(ret) + i += 1 + + return ans + + def extract_figure_information(self): + figure_coordinates = self.blocks['figure'] + figure_coordinates_in_pdf = self.convert_to_pdf_coordinates('figure') + + ans = [] + for i in range(len(figure_coordinates)): + ret = {} + coordinate = figure_coordinates_in_pdf[i] + ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]] + + tf = self.get_title_and_footnotes(coordinate) + figure = Image.fromarray(figure_coordinates[i].crop_image(self.img)) + ret.update({'title':tf[0]}) + ret.update({'figure': { + 'image': None, + 'bbox': [] + }}) + if self.output_image: + ret['figure']['image'] = figure + ret.update({'table': { + 'bbox': [], + 'content': None + }}) + ret.update({'footnote': tf[1]}) + ret['figure']['bbox'] = list(coordinate) + + ret.update({'page':self.page}) + + ans.append(ret) + + return ans + + + def extract_all_tables_and_figures(self, pages, pdfparser, content=None): + self.model = pdfparser + ret = [] + for i in range(len(pages)): + self.set_page_num(i) + self.run_model(pages[i]) + table_info = self.extract_table_information() + figure_info = self.extract_figure_information() + if content == 'tables': + ret += table_info + elif content == 'figures': + ret += figure_info + for table in table_info: + if table['figure']['bbox'] != []: + ret.append(table) + else: + ret += table_info + ret += figure_info + return ret diff --git a/chemietoolkit/utils.py b/chemietoolkit/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aedb65c0b4c1cb9f08498e60d434250ab6fe7ec4 --- /dev/null +++ b/chemietoolkit/utils.py @@ -0,0 +1,1018 @@ +import numpy as np +from PIL import Image +import cv2 +import layoutparser as lp +from rdkit import Chem +from rdkit.Chem import Draw +from rdkit.Chem import rdDepictor +rdDepictor.SetPreferCoordGen(True) +from rdkit.Chem.Draw import IPythonConsole +from rdkit.Chem import AllChem +import re +import copy + +BOND_TO_INT = { + "": 0, + "single": 1, + "double": 2, + "triple": 3, + "aromatic": 4, + "solid wedge": 5, + "dashed wedge": 6 +} + +RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', + 'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', "R'", + '1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]'] + +RGROUP_SYMBOLS = RGROUP_SYMBOLS + [f'[{i}]' for i in RGROUP_SYMBOLS] + +RGROUP_SMILES = ['[1*]', '[2*]','[3*]', '[4*]','[5*]', '[6*]','[7*]', '[8*]','[9*]', '[10*]','[11*]', '[12*]','[a*]', '[b*]','[c*]', '[d*]','*', '[Rf]'] + +def get_figures_from_pages(pages, pdfparser): + figures = [] + for i in range(len(pages)): + img = np.asarray(pages[i]) + layout = pdfparser.detect(img) + blocks = lp.Layout([b for b in layout if b.type == "Figure"]) + for block in blocks: + figure = Image.fromarray(block.crop_image(img)) + figures.append({ + 'image': figure, + 'page': i + }) + return figures + +def clean_bbox_output(figures, bboxes): + results = [] + cropped = [] + references = [] + for i, output in enumerate(bboxes): + mol_bboxes = [elt['bbox'] for elt in output if elt['category'] == '[Mol]'] + mol_scores = [elt['score'] for elt in output if elt['category'] == '[Mol]'] + data = {} + results.append(data) + data['image'] = figures[i] + data['molecules'] = [] + for bbox, score in zip(mol_bboxes, mol_scores): + x1, y1, x2, y2 = bbox + height, width, _ = figures[i].shape + cropped_img = figures[i][int(y1*height):int(y2*height),int(x1*width):int(x2*width)] + cur_mol = { + 'bbox': bbox, + 'score': score, + 'image': cropped_img, + #'info': None, + } + cropped.append(cropped_img) + data['molecules'].append(cur_mol) + references.append(cur_mol) + return results, cropped, references + +def convert_to_pil(image): + if type(image) == np.ndarray: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + return image + +def convert_to_cv2(image): + if type(image) != np.ndarray: + image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB) + return image + +def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16): + pattern = re.compile('(?P[RXY]\d?)[ ]*=[ ]*(?P\w+)') + for figure, result, corefs in zip(figures, results, coref_results): + r_groups = [] + seen_r_groups = set() + for bbox in corefs['bboxes']: + if bbox['category'] == '[Idt]': + for text in bbox['text']: + res = pattern.search(text) + if res is None: + continue + name = res.group('name') + group = res.group('group') + if (name, group) in seen_r_groups: + continue + seen_r_groups.add((name, group)) + r_groups.append({name: res.group('group')}) + if r_groups and result['reactions']: + seen_r_groups = set([pair[0] for pair in seen_r_groups]) + orig_reaction = result['reactions'][0] + graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size) + relevant_locs = {} + for i, graph in enumerate(graphs): + to_add = [] + for j, atom in enumerate(graph['chartok_coords']['symbols']): + if atom[1:-1] in seen_r_groups: + to_add.append((atom[1:-1], j)) + relevant_locs[i] = to_add + + for r_group in r_groups: + reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe) + to_add ={ + 'reactants': reaction['reactants'][:], + 'conditions': orig_reaction['conditions'][:], + 'products': reaction['products'][:] + } + result['reactions'].append(to_add) + return results + +def process_tables(figures, results, molscribe, batch_size=16): + r_group_pattern = re.compile(r'^(\w+-)?(?P[\w-]+)( \(\w+\))?$') + for figure, result in zip(figures, results): + result['page'] = figure['page'] + if figure['table']['content'] is not None: + content = figure['table']['content'] + if len(result['reactions']) > 1: + print("Warning: multiple reactions detected for table") + elif len(result['reactions']) == 0: + continue + orig_reaction = result['reactions'][0] + graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size) + relevant_locs = find_relevant_groups(graphs, content['columns']) + conditions_to_extend = [] + for row in content['rows']: + r_groups = {} + expanded_conditions = orig_reaction['conditions'][:] + replaced = False + for col, entry in zip(content['columns'], row): + if col['tag'] != 'alkyl group': + expanded_conditions.append({ + 'category': '[Table]', + 'text': entry['text'], + 'tag': col['tag'], + 'header': col['text'], + }) + else: + found = r_group_pattern.match(entry['text']) + if found is not None: + r_groups[col['text']] = found.group('group') + replaced = True + reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe) + if replaced: + to_add = { + 'reactants': reaction['reactants'][:], + 'conditions': expanded_conditions, + 'products': reaction['products'][:] + } + result['reactions'].append(to_add) + else: + conditions_to_extend.append(expanded_conditions) + orig_reaction['conditions'] = [orig_reaction['conditions']] + orig_reaction['conditions'].extend(conditions_to_extend) + return results + + +def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16): + image = convert_to_cv2(image) + cropped_images = [] + results = [] + for key, molecules in reaction.items(): + for i, elt in enumerate(molecules): + if type(elt) != dict or elt['category'] != '[Mol]': + continue + x1, y1, x2, y2 = elt['bbox'] + height, width, _ = image.shape + cropped_images.append(image[int(y1*height):int(y2*height),int(x1*width):int(x2*width)]) + to_add = { + 'image': cropped_images[-1], + 'chartok_coords': { + 'coords': [], + 'symbols': [], + }, + 'edges': [], + 'key': (key, i) + } + results.append(to_add) + outputs = molscribe.predict_images(cropped_images, return_atoms_bonds=True, batch_size=batch_size) + for mol, result in zip(outputs, results): + for atom in mol['atoms']: + result['chartok_coords']['coords'].append((atom['x'], atom['y'])) + result['chartok_coords']['symbols'].append(atom['atom_symbol']) + result['edges'] = [[0] * len(mol['atoms']) for _ in range(len(mol['atoms']))] + for bond in mol['bonds']: + i, j = bond['endpoint_atoms'] + result['edges'][i][j] = BOND_TO_INT[bond['bond_type']] + result['edges'][j][i] = BOND_TO_INT[bond['bond_type']] + return results + +def find_relevant_groups(graphs, columns): + results = {} + r_groups = set([f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group']) + for i, graph in enumerate(graphs): + to_add = [] + for j, atom in enumerate(graph['chartok_coords']['symbols']): + if atom in r_groups: + to_add.append((atom[1:-1], j)) + results[i] = to_add + return results + +def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe): + graph_copy = [] + for graph in graphs: + graph_copy.append({ + 'image': graph['image'], + 'chartok_coords': { + 'coords': graph['chartok_coords']['coords'][:], + 'symbols': graph['chartok_coords']['symbols'][:], + }, + 'edges': graph['edges'][:], + 'key': graph['key'], + }) + for graph_idx, atoms in relevant_locs.items(): + for atom, atom_idx in atoms: + if atom in mappings: + graph_copy[graph_idx]['chartok_coords']['symbols'][atom_idx] = mappings[atom] + reaction_copy = {} + def append_copy(copy_list, entity): + if entity['category'] == '[Mol]': + copy_list.append({ + k1: v1 for k1, v1 in entity.items() + }) + else: + copy_list.append(entity) + + for k, v in orig_reaction.items(): + reaction_copy[k] = [] + for entity in v: + if type(entity) == list: + sub_list = [] + for e in entity: + append_copy(sub_list, e) + reaction_copy[k].append(sub_list) + else: + append_copy(reaction_copy[k], entity) + + for graph in graph_copy: + output = molscribe.convert_graph_to_output([graph], [graph['image']]) + molecule = reaction_copy[graph['key'][0]][graph['key'][1]] + molecule['smiles'] = output[0]['smiles'] + molecule['molfile'] = output[0]['molfile'] + return reaction_copy + +def get_sites(tar, ref, ref_site = False): + rdDepictor.Compute2DCoords(ref) + rdDepictor.Compute2DCoords(tar) + idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref) + + in_template = [i[1] for i in idx_pair] + sites = [] + for i in range(tar.GetNumAtoms()): + if i not in in_template: + for j in tar.GetAtomWithIdx(i).GetNeighbors(): + if j.GetIdx() in in_template and j.GetIdx() not in sites: + + if ref_site: sites.append(idx_pair[in_template.index(j.GetIdx())][0]) + else: sites.append(idx_pair[in_template.index(j.GetIdx())][0]) + return sites + +def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = None): + # returns prod_mol_to_query which is the mapping of atom indices in prod_mol to the atom indices of the molecule represented by prod_smiles + prod_template_intermediate = Chem.MolToSmiles(prod_mol) + prod_template = prod_smiles + + for r in RGROUP_SMILES: + if r!='*' and r!='(*)': + prod_template = prod_template.replace(r, '*') + prod_template_intermediate = prod_template_intermediate.replace(r, '*') + + prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate) + prod_template_mol = Chem.MolFromSmiles(prod_template) + + p = Chem.AdjustQueryParameters.NoAdjustments() + p.makeDummiesQueries = True + + prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p) + prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p) + rdDepictor.Compute2DCoords(prod_mol) + rdDepictor.Compute2DCoords(prod_template_mol_query) + rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query) + idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query) + + intermdiate_to_prod_mol = {a:b for a,b in idx_pair} + prod_mol_to_intermediate = {b:a for a,b in idx_pair} + + + #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query) + + #intermediate_to_query = {a:b for a,b in idx_pair_2} + #query_to_intermediate = {b:a for a,b in idx_pair_2} + + #prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate} + + + substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify = False) + + #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query) + for substruct in substructs: + + + intermediate_to_query = {a:b for a, b in enumerate(substruct)} + query_to_intermediate = {intermediate_to_query[i]: i for i in intermediate_to_query} + + prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate} + + good_map = True + for i in r_sites_reversed: + if prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() not in RGROUP_SMILES: + good_map = False + if good_map: + break + + return prod_mol_to_query, prod_template_mol_query + +def clean_corefs(coref_results_dict, idx): + label_pattern = rf'{re.escape(idx)}[a-zA-Z]+' + #unclean_pattern = re.escape(idx) + r'\d(?![\d% ])' + toreturn = {} + for prod in coref_results_dict: + has_good_label = False + for parsed in coref_results_dict[prod]: + if re.search(label_pattern, parsed): + has_good_label = True + if not has_good_label: + for parsed in coref_results_dict[prod]: + if idx+'1' in parsed: + coref_results_dict[prod].append(idx+'l') + elif idx+'0' in parsed: + coref_results_dict[prod].append(idx+'o') + elif idx+'5' in parsed: + coref_results_dict[prod].append(idx+'s') + elif idx+'9' in parsed: + coref_results_dict[prod].append(idx+'g') + + + +def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe): + name = res.group('name') + group = res.group('group') + #print(other_prod) + atoms = coref_smiles_to_graphs[other_prod]['atoms'] + bonds = coref_smiles_to_graphs[other_prod]['bonds'] + + #print(atoms, bonds) + + graph = { + 'image': None, + 'chartok_coords': { + 'coords': [], + 'symbols': [], + }, + 'edges': [], + 'key': None + } + for atom in atoms: + graph['chartok_coords']['coords'].append((atom['x'], atom['y'])) + graph['chartok_coords']['symbols'].append(atom['atom_symbol']) + graph['edges'] = [[0] * len(atoms) for _ in range(len(atoms))] + for bond in bonds: + i, j = bond['endpoint_atoms'] + graph['edges'][i][j] = BOND_TO_INT[bond['bond_type']] + graph['edges'][j][i] = BOND_TO_INT[bond['bond_type']] + for i, symbol in enumerate(graph['chartok_coords']['symbols']): + if symbol[1:-1] == name: + graph['chartok_coords']['symbols'][i] = group + + #print(graph) + o = molscribe.convert_graph_to_output([graph], [graph['image']]) + return Chem.MolFromSmiles(o[0]['smiles']) + +def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn): + prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query + # we get the substruct matches. note that we set uniquify to false since the order matters for our method + substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False) + + + #for r in r_sites_reversed: + # print(prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[r]).GetSymbol()) + + # for each substruct we create the mapping of the substruct onto the other_mol + # delete all the molecules in other_mol correspond to the substruct + # and check if they number of mol frags is equal to number of r groups + # we do this to make sure we have the correct substruct + if len(substructs) >= 1: + for substruct in substructs: + + query_to_other = {a:b for a,b in enumerate(substruct)} + other_to_query = {query_to_other[i]:i for i in query_to_other} + + editable = Chem.EditableMol(other_prod_mol) + r_site_correspondence = [] + for r in r_sites_reversed_new: + #get its id in substruct + substruct_id = query_to_other[r] + r_site_correspondence.append([substruct_id, r_sites_reversed_new[r]]) + + for idx in tuple(sorted(substruct, reverse = True)): + if idx not in [query_to_other[i] for i in r_sites_reversed_new]: + editable.RemoveAtom(idx) + for r_site in r_site_correspondence: + if idx < r_site[0]: + r_site[0]-=1 + other_prod_removed = editable.GetMol() + + if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups: + break + + # need to compute the sites at which correspond to each r_site_reversed + + r_site_correspondence.sort(key = lambda x: x[0]) + + + f = [] + ff = [] + frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff) + + # r_group_information maps r group name --> the fragment/molcule corresponding to the r group and the atom index it should be connected at + r_group_information = {} + #tosubtract = 0 + for idx, r_site in enumerate(r_site_correspondence): + + r_group_information[r_site[1]]= (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0])) + #tosubtract += len(ff[idx]) + for r_site in h_sites: + r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0) + + # now we modify all of the reactants according to the R groups we have found + # for every reactant we disconnect its r group symbol, and connect it to the r group + modify_reactants = copy.deepcopy(reactant_mols) + modified_reactant_smiles = [] + for reactant_idx in reactant_information: + if len(reactant_information[reactant_idx]) == 0: + modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx])) + else: + combined = reactant_mols[reactant_idx] + if combined.GetNumAtoms() == 1: + r_group, _, _ = reactant_information[reactant_idx][0] + modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0])) + else: + for r_group, r_index, connect_index in reactant_information[reactant_idx]: + combined = Chem.CombineMols(combined, r_group_information[r_group][0]) + + editable = Chem.EditableMol(combined) + atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms() + for r_group, r_index, connect_index in reactant_information[reactant_idx]: + Chem.EditableMol.RemoveBond(editable, r_index, connect_index) + Chem.EditableMol.AddBond(editable, connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE) + atomIdxAdder += r_group_information[r_group][0].GetNumAtoms() + r_indices = [i[1] for i in reactant_information[reactant_idx]] + + r_indices.sort(reverse = True) + + for r_index in r_indices: + Chem.EditableMol.RemoveAtom(editable, r_index) + + modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol())))) + + toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed)) + return True + else: + return False + +def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups): + subsets = generate_subsets(num_r_groups) + + toreturn = [] + + for subset in subsets: + r_sites_list = [[i, r_sites_reversed_new[i]] for i in r_sites_reversed_new] + r_sites_list.sort(key = lambda x: x[0]) + to_edit = Chem.EditableMol(prod_template_mol_query) + + for entry in subset: + pos = r_sites_list[entry][0] + Chem.EditableMol.RemoveBond(to_edit, r_sites_list[entry][0], prod_template_mol_query.GetAtomWithIdx(r_sites_list[entry][0]).GetNeighbors()[0].GetIdx()) + for entry in subset: + pos = r_sites_list[entry][0] + Chem.EditableMol.RemoveAtom(to_edit, pos) + + edited = to_edit.GetMol() + for entry in subset: + for i in range(entry + 1, num_r_groups): + r_sites_list[i][0]-=1 + + new_r_sites = {} + new_h_sites = set() + for i in range(num_r_groups): + if i not in subset: + new_r_sites[r_sites_list[i][0]] = r_sites_list[i][1] + else: + new_h_sites.add(r_sites_list[i][1]) + toreturn.append((edited, new_r_sites, new_h_sites, num_r_groups - len(subset))) + return toreturn + +def generate_subsets(n): + def backtrack(start, subset): + result.append(subset[:]) + for i in range(start, -1, -1): # Iterate in reverse order + subset.append(i) + backtrack(i - 1, subset) + subset.pop() + + result = [] + backtrack(n - 1, []) + return sorted(result, key=lambda x: (-len(x), x), reverse=True) + +def backout(results, coref_results, molscribe): + + toreturn = [] + + if not results or not results[0]['reactions'] or not coref_results: + return toreturn + + try: + reactants = results[0]['reactions'][0]['reactants'] + products = [i['smiles'] for i in results[0]['reactions'][0]['products']] + coref_results_dict = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[1]]['text'] for coref in coref_results[0]['corefs']} + coref_smiles_to_graphs = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[0]] for coref in coref_results[0]['corefs']} + + + if len(products) == 1: + if products[0] not in coref_results_dict: + print("Warning: No Label Parsed") + return + product_labels = coref_results_dict[products[0]] + prod = products[0] + label_idx = product_labels[0] + ''' + if len(product_labels) == 1: + # get the coreference label of the product molecule + label_idx = product_labels[0] + else: + print("Warning: Malformed Label Parsed.") + return + ''' + else: + print("Warning: More than one product detected") + return + + # format the regular expression for labels that correspond to the product label + numbers = re.findall(r'\d+', label_idx) + label_idx = numbers[0] if len(numbers) > 0 else "" + label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+' + + + prod_smiles = prod + prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile']) + + # identify the atom indices of the R groups in the product tempalte + h_counter = 0 + r_sites = {} + for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']): + sym = atom['atom_symbol'] + if sym == '[H]': + h_counter += 1 + if sym[0] == '[': + sym = sym[1:-1] + if sym[0] == 'R' and sym[1:].isdigit(): + sym = sym[1:]+"*" + sym = f'[{sym}]' + if sym in RGROUP_SYMBOLS: + if sym not in r_sites: + r_sites[sym] = [idx-h_counter] + else: + r_sites[sym].append(idx-h_counter) + + r_sites_reversed = {} + for sym in r_sites: + for pos in r_sites[sym]: + r_sites_reversed[pos] = sym + + num_r_groups = len(r_sites_reversed) + + #prepare the product template and get the associated mapping + + prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed) + + reactant_mols = [] + + + #--------------process the reactants----------------- + + reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...] + + for idx, reactant in enumerate(reactants): + reactant_information[idx] = [] + reactant_mols.append(Chem.MolFromSmiles(reactant['smiles'])) + has_r = False + + r_sites_reactant = {} + + h_counter = 0 + + for a_idx, atom in enumerate(reactant['atoms']): + + #go through all atoms and check if they are an R group, if so add it to reactant information + sym = atom['atom_symbol'] + if sym == '[H]': + h_counter += 1 + if sym[0] == '[': + sym = sym[1:-1] + if sym[0] == 'R' and sym[1:].isdigit(): + sym = sym[1:]+"*" + sym = f'[{sym}]' + if sym in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append([sym, -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant[sym] = a_idx-h_counter + elif sym == '[1*]' and '[7*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[7*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[7*]'] = a_idx-h_counter + elif sym == '[7*]' and '[1*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[1*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[1*]'] = a_idx-h_counter + + + + elif sym == '[1*]' and '[Rf]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[Rf]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[Rf]'] = a_idx-h_counter + + elif sym == '[Rf]' and '[1*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[1*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[1*]'] = a_idx-h_counter + + + r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant} + # if the reactant had r groups, we had to use the molecule generated from the MolBlock. + # but the molblock may have unexpanded elemeents that are not R groups + # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles + # and adjust the indices of the r groups accordingly + if has_r: + #get the mapping + reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant) + + #make the adjustment + for info in reactant_information[idx]: + info[1] = reactant_mol_to_query[info[1]] + info[2] = reactant_mol_to_query[info[2]] + reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles']) + + #go through all the molecules in the coreference + + clean_corefs(coref_results_dict, label_idx) + + for other_prod in coref_results_dict: + + #check if they match the product label regex + found_good_label = False + for parsed in coref_results_dict[other_prod]: + if re.search(label_pattern, parsed) and not found_good_label: + found_good_label = True + other_prod_mol = Chem.MolFromSmiles(other_prod) + + if other_prod != prod_smiles and other_prod_mol is not None: + + #check if there are R groups to be resolved in the target product + + all_other_prod_mols = [] + + r_group_sub_pattern = re.compile('(?P[RXY]\d?)[ ]*=[ ]*(?P\w+)') + + for parsed_labels in coref_results_dict[other_prod]: + res = r_group_sub_pattern.search(parsed_labels) + + if res is not None: + all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels)) + + if len(all_other_prod_mols) == 0: + if other_prod_mol is not None: + all_other_prod_mols.append((other_prod_mol, parsed)) + + + + + for other_prod_mol, parsed in all_other_prod_mols: + + other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True) + + for other_prod_frag in other_prod_frags: + substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False) + + if len(substructs)>0: + other_prod_mol = other_prod_frag + break + r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed} + + queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups) + + matched = False + + for query in queries: + if not matched: + try: + matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn) + except: + pass + + except: + pass + + + return toreturn + + +def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe): + + toreturn = [] + + if not results or not results[0]['reactions'] or not coref_results: + return toreturn + + try: + reactants = results[0]['reactions'][0]['reactants'] + products = [i['smiles'] for i in results[0]['reactions'][0]['products']] + coref_results_dict = coref_results_dict + coref_smiles_to_graphs = coref_smiles_to_graphs + + + if len(products) == 1: + if products[0] not in coref_results_dict: + print("Warning: No Label Parsed") + return + product_labels = coref_results_dict[products[0]] + prod = products[0] + label_idx = product_labels[0] + ''' + if len(product_labels) == 1: + # get the coreference label of the product molecule + label_idx = product_labels[0] + else: + print("Warning: Malformed Label Parsed.") + return + ''' + else: + print("Warning: More than one product detected") + return + + # format the regular expression for labels that correspond to the product label + numbers = re.findall(r'\d+', label_idx) + label_idx = numbers[0] if len(numbers) > 0 else "" + label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+' + + + prod_smiles = prod + prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile']) + + # identify the atom indices of the R groups in the product tempalte + h_counter = 0 + r_sites = {} + for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']): + sym = atom['atom_symbol'] + if sym == '[H]': + h_counter += 1 + if sym[0] == '[': + sym = sym[1:-1] + if sym[0] == 'R' and sym[1:].isdigit(): + sym = sym[1:]+"*" + sym = f'[{sym}]' + if sym in RGROUP_SYMBOLS: + if sym not in r_sites: + r_sites[sym] = [idx-h_counter] + else: + r_sites[sym].append(idx-h_counter) + + r_sites_reversed = {} + for sym in r_sites: + for pos in r_sites[sym]: + r_sites_reversed[pos] = sym + + num_r_groups = len(r_sites_reversed) + + #prepare the product template and get the associated mapping + + prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed) + + reactant_mols = [] + + + #--------------process the reactants----------------- + + reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...] + + for idx, reactant in enumerate(reactants): + reactant_information[idx] = [] + reactant_mols.append(Chem.MolFromSmiles(reactant['smiles'])) + has_r = False + + r_sites_reactant = {} + + h_counter = 0 + + for a_idx, atom in enumerate(reactant['atoms']): + + #go through all atoms and check if they are an R group, if so add it to reactant information + sym = atom['atom_symbol'] + if sym == '[H]': + h_counter += 1 + if sym[0] == '[': + sym = sym[1:-1] + if sym[0] == 'R' and sym[1:].isdigit(): + sym = sym[1:]+"*" + sym = f'[{sym}]' + if sym in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append([sym, -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant[sym] = a_idx-h_counter + elif sym == '[1*]' and '[7*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[7*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[7*]'] = a_idx-h_counter + elif sym == '[7*]' and '[1*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[1*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[1*]'] = a_idx-h_counter + + elif sym == '[1*]' and '[Rf]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[Rf]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[Rf]'] = a_idx-h_counter + + elif sym == '[Rf]' and '[1*]' in r_sites: + if reactant_mols[-1].GetNumAtoms()==1: + reactant_information[idx].append(['[1*]', -1, -1]) + else: + has_r = True + reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) + reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) + r_sites_reactant['[1*]'] = a_idx-h_counter + + r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant} + # if the reactant had r groups, we had to use the molecule generated from the MolBlock. + # but the molblock may have unexpanded elemeents that are not R groups + # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles + # and adjust the indices of the r groups accordingly + if has_r: + #get the mapping + reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant) + + #make the adjustment + for info in reactant_information[idx]: + info[1] = reactant_mol_to_query[info[1]] + info[2] = reactant_mol_to_query[info[2]] + reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles']) + + #go through all the molecules in the coreference + + clean_corefs(coref_results_dict, label_idx) + + for other_prod in coref_results_dict: + + #check if they match the product label regex + found_good_label = False + for parsed in coref_results_dict[other_prod]: + if re.search(label_pattern, parsed) and not found_good_label: + found_good_label = True + other_prod_mol = Chem.MolFromSmiles(other_prod) + + if other_prod != prod_smiles and other_prod_mol is not None: + + #check if there are R groups to be resolved in the target product + + all_other_prod_mols = [] + + r_group_sub_pattern = re.compile('(?P[RXY]\d?)[ ]*=[ ]*(?P\w+)') + + for parsed_labels in coref_results_dict[other_prod]: + res = r_group_sub_pattern.search(parsed_labels) + + if res is not None: + all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels)) + + if len(all_other_prod_mols) == 0: + if other_prod_mol is not None: + all_other_prod_mols.append((other_prod_mol, parsed)) + + + + + for other_prod_mol, parsed in all_other_prod_mols: + + other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True) + + for other_prod_frag in other_prod_frags: + substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False) + + if len(substructs)>0: + other_prod_mol = other_prod_frag + break + r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed} + + queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups) + + matched = False + + for query in queries: + if not matched: + try: + matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn) + except: + pass + + except: + pass + + + return toreturn + + + +def associate_corefs(results, results_coref): + coref_smiles = {} + idx_pattern = r'\b\d+[a-zA-Z]{0,2}\b' + for result_coref in results_coref: + bboxes, corefs = result_coref['bboxes'], result_coref['corefs'] + for coref in corefs: + mol, idt = coref[0], coref[1] + if len(bboxes[idt]['text']) > 0: + for text in bboxes[idt]['text']: + matches = re.findall(idx_pattern, text) + for match in matches: + coref_smiles[match] = bboxes[mol]['smiles'] + + for page in results: + for reactions in page['reactions']: + for reaction in reactions['reactions']: + if 'Reactants' in reaction: + if isinstance(reaction['Reactants'], tuple): + if reaction['Reactants'][0] in coref_smiles: + reaction['Reactants'] = (f'{reaction["Reactants"][0]} ({coref_smiles[reaction["Reactants"][0]]})', reaction['Reactants'][1], reaction['Reactants'][2]) + else: + for idx, compound in enumerate(reaction['Reactants']): + if compound[0] in coref_smiles: + reaction['Reactants'][idx] = (f'{compound[0]} ({coref_smiles[compound[0]]})', compound[1], compound[2]) + if 'Product' in reaction: + if isinstance(reaction['Product'], tuple): + if reaction['Product'][0] in coref_smiles: + reaction['Product'] = (f'{reaction["Product"][0]} ({coref_smiles[reaction["Product"][0]]})', reaction['Product'][1], reaction['Product'][2]) + else: + for idx, compound in enumerate(reaction['Product']): + if compound[0] in coref_smiles: + reaction['Product'][idx] = (f'{compound[0]} ({coref_smiles[compound[0]]})', compound[1], compound[2]) + return results + + +def expand_reactions_with_backout(initial_results, results_coref, molscribe): + idx_pattern = r'^\d+[a-zA-Z]{0,2}$' + for reactions, result_coref in zip(initial_results, results_coref): + if not reactions['reactions']: + continue + try: + backout_results = backout([reactions], [result_coref], molscribe) + except Exception: + continue + conditions = reactions['reactions'][0]['conditions'] + idt_to_smiles = {} + if not backout_results: + continue + + for reactants, products, idt in backout_results: + reactions['reactions'].append({ + 'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants], + 'conditions': conditions[:], + 'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products] + }) + return initial_results + diff --git a/examples/exp.png b/examples/exp.png new file mode 100644 index 0000000000000000000000000000000000000000..f9055caa4429f1e1bc2e4ff6208808d1e84cd5ad --- /dev/null +++ b/examples/exp.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ce344ed33ff77f45d6e87a29e91426c3444ee9b58a8b10086ce3483a1ad2a2e +size 695688 diff --git a/examples/image.webp b/examples/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..25f53d472965dc1ca87c64a4248fa4e8fe4c7b03 Binary files /dev/null and b/examples/image.webp differ diff --git a/examples/rdkit.png b/examples/rdkit.png new file mode 100644 index 0000000000000000000000000000000000000000..3210df0f65f1b405854fb03cb2cf35b2e11a1f85 Binary files /dev/null and b/examples/rdkit.png differ diff --git a/examples/reaction1.jpg b/examples/reaction1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7cc7ef8a9638b6c6f31f3c83c0531f041f994ead Binary files /dev/null and b/examples/reaction1.jpg differ diff --git a/examples/reaction2.png b/examples/reaction2.png new file mode 100644 index 0000000000000000000000000000000000000000..4b1c9854a9e75dec5428b9e5e8372659b30ccd5a Binary files /dev/null and b/examples/reaction2.png differ diff --git a/examples/reaction3.png b/examples/reaction3.png new file mode 100644 index 0000000000000000000000000000000000000000..4f0c1e120fe65c6e4754854dd3d3a1e0031ce2d9 Binary files /dev/null and b/examples/reaction3.png differ diff --git a/examples/reaction4.png b/examples/reaction4.png new file mode 100644 index 0000000000000000000000000000000000000000..d21fb5a523855726590bbba281a71388aedd4f0b --- /dev/null +++ b/examples/reaction4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:341a3b9f6b24b3fe3793186ec198cf1171ffb84bca6c0316052f25e17c0eeb55 +size 231953 diff --git a/get_molecular_agent.py b/get_molecular_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..a931d179a62c3b4ccb46840334f752361e972b8a --- /dev/null +++ b/get_molecular_agent.py @@ -0,0 +1,599 @@ +import sys +import torch +import json +from chemietoolkit import ChemIEToolkit +import cv2 +from PIL import Image +import json +import sys +#sys.path.append('./RxnScribe-main/') +import torch +from rxnscribe import RxnScribe +import json +import sys +import torch +import json +model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) +from molscribe.chemistry import _convert_graph_to_smiles +import base64 +import torch +import json +from PIL import Image +import numpy as np +from chemietoolkit import ChemIEToolkit, utils +from openai import AzureOpenAI +import os + + + + +ckpt_path = "./pix2seq_reaction_full.ckpt" +model1 = RxnScribe(ckpt_path, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) +device = torch.device(('cuda' if torch.cuda.is_available() else 'cpu')) + +model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + +def get_multi_molecular(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + coref_results = model.extract_molecule_corefs_from_figures([image]) + for item in coref_results: + for bbox in item.get("bboxes", []): + for key in ["category", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms' + bbox.pop(key, None) # 安全地移除键 + print(json.dumps(coref_results)) + # 返回反应列表,使用 json.dumps 进行格式化 + + return json.dumps(coref_results) + +def get_multi_molecular_text_to_correct(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + coref_results = model.extract_molecule_corefs_from_figures([image]) + for item in coref_results: + for bbox in item.get("bboxes", []): + for key in ["category", "bbox", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms' + bbox.pop(key, None) # 安全地移除键 + print(json.dumps(coref_results)) + # 返回反应列表,使用 json.dumps 进行格式化 + + return json.dumps(coref_results) + +def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + coref_results = model.extract_molecule_corefs_from_figures([image]) + for item in coref_results: + for bbox in item.get("bboxes", []): + for key in ["coords","edges","molfile", 'atoms', "bonds", 'category_id', 'score', 'corefs']: #'atoms' + bbox.pop(key, None) # 安全地移除键 + print(json.dumps(coref_results)) + # 返回反应列表,使用 json.dumps 进行格式化 + return json.dumps(coref_results) + + + + + + +def process_reaction_image_with_multiple_products_and_text(image_path: str) -> dict: + """ + + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = os.getenv("CHEMEAGLE_API_KEY") + if not api_key: + raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable") + # 替换为实际的 API Key + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + + + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'get_multi_molecular_text_to_correct_withatoms', + 'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + + ] + + # 提供给 GPT 的消息内容 + with open('./prompt_getmolecular.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms, + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = [json.loads(response.choices[0].message.content)] + + + def get_multi_molecular(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + coref_results = model.extract_molecule_corefs_from_figures([image]) + return coref_results + + + coref_results = get_multi_molecular(image_path) + + + def update_symbols_in_atoms(input1, input2): + """ + 用 input1 中更新后的 'symbols' 替换 input2 中对应 bboxes 的 'symbols',并同步更新 'atoms' 的 'atom_symbol'。 + 假设 input1 和 input2 的结构一致。 + """ + for item1, item2 in zip(input1, input2): + bboxes1 = item1.get('bboxes', []) + bboxes2 = item2.get('bboxes', []) + + if len(bboxes1) != len(bboxes2): + print("Warning: Mismatched number of bboxes!") + continue + + for bbox1, bbox2 in zip(bboxes1, bboxes2): + # 更新 symbols + if 'symbols' in bbox1: + bbox2['symbols'] = bbox1['symbols'] # 更新 symbols + + # 更新 atoms 的 atom_symbol + if 'symbols' in bbox1 and 'atoms' in bbox2: + symbols = bbox1['symbols'] + atoms = bbox2.get('atoms', []) + + # 确保 symbols 和 atoms 的长度一致 + if len(symbols) != len(atoms): + print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!") + continue + + for atom, symbol in zip(atoms, symbols): + atom['atom_symbol'] = symbol # 更新 atom_symbol + + return input2 + + + input2_updated = update_symbols_in_atoms(gpt_output, coref_results) + + + + + + def update_smiles_and_molfile(input_data, conversion_function): + """ + 使用更新后的 'symbols'、'coords' 和 'edges' 调用 `conversion_function` 生成新的 'smiles' 和 'molfile', + 并替换到原数据结构中。 + + 参数: + - input_data: 包含 bboxes 的嵌套数据结构 + - conversion_function: 函数,接受 'coords', 'symbols', 'edges' 并返回 (new_smiles, new_molfile, _) + + 返回: + - 更新后的数据结构 + """ + for item in input_data: + for bbox in item.get('bboxes', []): + # 检查必需的键是否存在 + if all(key in bbox for key in ['coords', 'symbols', 'edges']): + coords = bbox['coords'] + symbols = bbox['symbols'] + edges = bbox['edges'] + + # 调用转换函数生成新的 'smiles' 和 'molfile' + new_smiles, new_molfile, _ = conversion_function(coords, symbols, edges) + print(f" Generated 'smiles': {new_smiles}") + + # 替换旧的 'smiles' 和 'molfile' + bbox['smiles'] = new_smiles + bbox['molfile'] = new_molfile + + return input_data + + updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles) + + return updated_data + + + + + + + + + +def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> dict: + """ + + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = os.getenv("CHEMEAGLE_API_KEY") + if not api_key: + raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable") + # 替换为实际的 API Key + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'get_multi_molecular_text_to_correct_withatoms', + 'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + + ] + + # 提供给 GPT 的消息内容 + with open('./prompt_getmolecular_correctR.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms, + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = [json.loads(response.choices[0].message.content)] + + + def get_multi_molecular(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + coref_results = model.extract_molecule_corefs_from_figures([image]) + return coref_results + + + coref_results = get_multi_molecular(image_path) + + + def update_symbols_in_atoms(input1, input2): + """ + 用 input1 中更新后的 'symbols' 替换 input2 中对应 bboxes 的 'symbols',并同步更新 'atoms' 的 'atom_symbol'。 + 假设 input1 和 input2 的结构一致。 + """ + for item1, item2 in zip(input1, input2): + bboxes1 = item1.get('bboxes', []) + bboxes2 = item2.get('bboxes', []) + + if len(bboxes1) != len(bboxes2): + print("Warning: Mismatched number of bboxes!") + continue + + for bbox1, bbox2 in zip(bboxes1, bboxes2): + # 更新 symbols + if 'symbols' in bbox1: + bbox2['symbols'] = bbox1['symbols'] # 更新 symbols + + # 更新 atoms 的 atom_symbol + if 'symbols' in bbox1 and 'atoms' in bbox2: + symbols = bbox1['symbols'] + atoms = bbox2.get('atoms', []) + + # 确保 symbols 和 atoms 的长度一致 + if len(symbols) != len(atoms): + print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!") + continue + + for atom, symbol in zip(atoms, symbols): + atom['atom_symbol'] = symbol # 更新 atom_symbol + + return input2 + + + input2_updated = update_symbols_in_atoms(gpt_output, coref_results) + + + + + + def update_smiles_and_molfile(input_data, conversion_function): + """ + 使用更新后的 'symbols'、'coords' 和 'edges' 调用 `conversion_function` 生成新的 'smiles' 和 'molfile', + 并替换到原数据结构中。 + + 参数: + - input_data: 包含 bboxes 的嵌套数据结构 + - conversion_function: 函数,接受 'coords', 'symbols', 'edges' 并返回 (new_smiles, new_molfile, _) + + 返回: + - 更新后的数据结构 + """ + for item in input_data: + for bbox in item.get('bboxes', []): + # 检查必需的键是否存在 + if all(key in bbox for key in ['coords', 'symbols', 'edges']): + coords = bbox['coords'] + symbols = bbox['symbols'] + edges = bbox['edges'] + + # 调用转换函数生成新的 'smiles' 和 'molfile' + new_smiles, new_molfile, _ = conversion_function(coords, symbols, edges) + print(f" Generated 'smiles': {new_smiles}") + + # 替换旧的 'smiles' 和 'molfile' + bbox['smiles'] = new_smiles + bbox['molfile'] = new_molfile + + return input_data + + updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles) + print(f"updated_mol_data:{updated_data}") + + return updated_data diff --git a/get_reaction_agent.py b/get_reaction_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..649c80cebc8a7d9cdf3e8ec4385e0a1e79266cfa --- /dev/null +++ b/get_reaction_agent.py @@ -0,0 +1,507 @@ +import sys +import torch +import json +from chemietoolkit import ChemIEToolkit +import cv2 +from PIL import Image +import json +import sys +#sys.path.append('./RxnScribe-main/') +import torch +from rxnscribe import RxnScribe +import json +from molscribe.chemistry import _convert_graph_to_smiles + +from openai import AzureOpenAI +import base64 +import numpy as np +from chemietoolkit import utils +from PIL import Image + + + + +ckpt_path = "./pix2seq_reaction_full.ckpt" +model1 = RxnScribe(ckpt_path, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) +device = torch.device(('cuda' if torch.cuda.is_available() else 'cpu')) +model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + +def get_reaction(image_path: str) -> dict: + ''' + Returns a structured dictionary of reactions extracted from the image, + including reactants, conditions, and products, with their smiles, text, and bbox. + ''' + image_file = image_path + raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) + + # Ensure raw_prediction is treated as a list directly + structured_output = {} + for section_key in ['reactants', 'conditions', 'products']: + if section_key in raw_prediction[0]: + structured_output[section_key] = [] + for item in raw_prediction[0][section_key]: + if section_key in ['reactants', 'products']: + # Extract smiles and bbox for molecules + structured_output[section_key].append({ + "smiles": item.get("smiles", ""), + "bbox": item.get("bbox", []), + "symbols": item.get("symbols", []) + }) + elif section_key == 'conditions': + # Extract smiles, text, and bbox for conditions + condition_data = {"bbox": item.get("bbox", [])} + if "smiles" in item: + condition_data["smiles"] = item.get("smiles", "") + if "text" in item: + condition_data["text"] = item.get("text", []) + structured_output[section_key].append(condition_data) + print(structured_output) + + return structured_output + + + +def get_full_reaction(image_path: str) -> dict: + ''' + Returns a structured dictionary of reactions extracted from the image, + including reactants, conditions, and products, with their smiles, text, and bbox. + ''' + image_file = image_path + raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) + return raw_prediction + + + +def get_reaction_withatoms(image_path: str) -> dict: + """ + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = "b038da96509b4009be931e035435e022" # 替换为实际的 API Key + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'get_reaction', + 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + ] + + # 提供给 GPT 的消息内容 + with open('./prompt_getreaction.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'get_reaction': get_reaction, + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = json.loads(response.choices[0].message.content) + print(f"gpt_output1:{gpt_output}") + + + def get_reaction_full(image_path: str) -> dict: + ''' + Returns a structured dictionary of reactions extracted from the image, + including reactants, conditions, and products, with their smiles, text, and bbox. + ''' + image_file = image_path + raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) + return raw_prediction + + input2 = get_reaction_full(image_path) + + + + def update_input_with_symbols(input1, input2, conversion_function): + symbol_mapping = {} + for key in ['reactants', 'products']: + for item in input1.get(key, []): + bbox = tuple(item['bbox']) # 使用 bbox 作为唯一标识 + symbol_mapping[bbox] = item['symbols'] + + for key in ['reactants', 'products']: + for item in input2.get(key, []): + bbox = tuple(item['bbox']) # 获取 bbox 作为匹配键 + + # 如果 bbox 存在于 input1 的映射中,则更新 symbols + if bbox in symbol_mapping: + updated_symbols = symbol_mapping[bbox] + item['symbols'] = updated_symbols + + # 更新 atoms 的 atom_symbol + if 'atoms' in item: + atoms = item['atoms'] + if len(atoms) != len(updated_symbols): + print(f"Warning: Mismatched symbols and atoms in bbox {bbox}") + else: + for atom, symbol in zip(atoms, updated_symbols): + atom['atom_symbol'] = symbol + + # 如果 coords 和 edges 存在,调用转换函数生成新的 smiles 和 molfile + if 'coords' in item and 'edges' in item: + coords = item['coords'] + edges = item['edges'] + new_smiles, new_molfile, _ = conversion_function(coords, updated_symbols, edges) + + # 替换旧的 smiles 和 molfile + item['smiles'] = new_smiles + item['molfile'] = new_molfile + + return input2 + + updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)] + + return updated_data + + + + +def get_reaction_withatoms_correctR(image_path: str) -> dict: + """ + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = "b038da96509b4009be931e035435e022" # 替换为实际的 API Key + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'get_reaction', + 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + ] + + # 提供给 GPT 的消息内容 + with open('./prompt_getreaction_correctR.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'get_reaction': get_reaction, + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = json.loads(response.choices[0].message.content) + print(f"gpt_output1:{gpt_output}") + + + def get_reaction_full(image_path: str) -> dict: + ''' + Returns a structured dictionary of reactions extracted from the image, + including reactants, conditions, and products, with their smiles, text, and bbox. + ''' + image_file = image_path + raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) + return raw_prediction + + input2 = get_reaction_full(image_path) + + + + def update_input_with_symbols(input1, input2, conversion_function): + symbol_mapping = {} + for key in ['reactants', 'products']: + for item in input1.get(key, []): + bbox = tuple(item['bbox']) # 使用 bbox 作为唯一标识 + symbol_mapping[bbox] = item['symbols'] + + for key in ['reactants', 'products']: + for item in input2.get(key, []): + bbox = tuple(item['bbox']) # 获取 bbox 作为匹配键 + + # 如果 bbox 存在于 input1 的映射中,则更新 symbols + if bbox in symbol_mapping: + updated_symbols = symbol_mapping[bbox] + item['symbols'] = updated_symbols + + # 更新 atoms 的 atom_symbol + if 'atoms' in item: + atoms = item['atoms'] + if len(atoms) != len(updated_symbols): + print(f"Warning: Mismatched symbols and atoms in bbox {bbox}") + else: + for atom, symbol in zip(atoms, updated_symbols): + atom['atom_symbol'] = symbol + + # 如果 coords 和 edges 存在,调用转换函数生成新的 smiles 和 molfile + if 'coords' in item and 'edges' in item: + coords = item['coords'] + edges = item['edges'] + new_smiles, new_molfile, _ = conversion_function(coords, updated_symbols, edges) + + # 替换旧的 smiles 和 molfile + item['smiles'] = new_smiles + item['molfile'] = new_molfile + + return input2 + + updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)] + print(f"updated_reaction_data:{updated_data}") + + return updated_data \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8a68254d92480f01a8c503d7f24299ae446963 --- /dev/null +++ b/main.py @@ -0,0 +1,546 @@ +import sys +import torch +import json +from chemietoolkit import ChemIEToolkit,utils +import cv2 +from openai import AzureOpenAI +import numpy as np +from PIL import Image +import json +from get_molecular_agent import process_reaction_image_with_multiple_products_and_text_correctR +from get_reaction_agent import get_reaction_withatoms_correctR +import sys +from rxnscribe import RxnScribe +import json +import base64 +model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) +ckpt_path = "./pix2seq_reaction_full.ckpt" +model1 = RxnScribe(ckpt_path, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) +device = torch.device(('cuda' if torch.cuda.is_available() else 'cpu')) +import os + +with open('api_key.txt', 'r') as api_key_file: + API_KEY = api_key_file.read() + +def parse_coref_data_with_fallback(data): + bboxes = data["bboxes"] + corefs = data["corefs"] + paired_indices = set() + + # 先处理有 coref 配对的 + results = [] + for idx1, idx2 in corefs: + smiles_entry = bboxes[idx1] if "smiles" in bboxes[idx1] else bboxes[idx2] + text_entry = bboxes[idx2] if "text" in bboxes[idx2] else bboxes[idx1] + + smiles = smiles_entry.get("smiles", "") + texts = text_entry.get("text", []) + + results.append({ + "smiles": smiles, + "texts": texts + }) + + # 记录下哪些 SMILES 被配对过了 + paired_indices.add(idx1) + paired_indices.add(idx2) + + # 处理未配对的 SMILES(补充进来) + for idx, entry in enumerate(bboxes): + if "smiles" in entry and idx not in paired_indices: + results.append({ + "smiles": entry["smiles"], + "texts": ["There is no label or failed to detect, please recheck the image again"] + }) + + return results + + +def get_multi_molecular_text_to_correct(image_path: str) -> list: + '''Returns a list of reactions extracted from the image.''' + # 打开图像文件 + image = Image.open(image_path).convert('RGB') + + # 将图像作为输入传递给模型 + #coref_results = process_reaction_image_with_multiple_products_and_text_correctR(image_path) + coref_results = model.extract_molecule_corefs_from_figures([image]) + for item in coref_results: + for bbox in item.get("bboxes", []): + for key in ["category", "bbox", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs',"coords","edges"]: #'atoms' + bbox.pop(key, None) # 安全地移除键 + + data = coref_results[0] + parsed = parse_coref_data_with_fallback(data) + + + print(f"coref_results:{json.dumps(parsed)}") + return json.dumps(parsed) + + + + + + + + +def get_reaction(image_path: str) -> dict: + ''' + Returns a structured dictionary of reactions extracted from the image, + including only reactants, conditions, and products with their smiles, bbox, or text. + ''' + image_file = image_path + #raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) + raw_prediction = get_reaction_withatoms_correctR(image_path) + + + # Ensure raw_prediction is treated as a list directly + structured_output = {} + for section_key in ['reactants', 'conditions', 'products']: + if section_key in raw_prediction[0]: + structured_output[section_key] = [] + for item in raw_prediction[0][section_key]: + if section_key in ['reactants', 'products']: + # Extract smiles and bbox for molecules + structured_output[section_key].append({ + "smiles": item.get("smiles", ""), + "bbox": item.get("bbox", []) + }) + elif section_key == 'conditions': + # Extract text and bbox for conditions + structured_output[section_key].append({ + "text": item.get("text", []), + "bbox": item.get("bbox", []), + "smiles": item.get("smiles", []), + }) + + return structured_output + + + +def process_reaction_image(image_path: str) -> dict: + """ + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = os.getenv("CHEMEAGLE_API_KEY") + if not api_key: + raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable") + + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'get_multi_molecular_text_to_correct', + 'description': 'Extracts the SMILES string and text coref from molecular images.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'Path to the reaction image.' + } + }, + 'required': ['image_path'], + 'additionalProperties': False + } + } + }, + { + 'type': 'function', + 'function': { + 'name': 'get_reaction', + 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + ] + + # 提供给 GPT 的消息内容 + with open('./prompt.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct, + 'get_reaction': get_reaction + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = json.loads(response.choices[0].message.content) + print(gpt_output) + image = Image.open(image_path).convert('RGB') + image_np = np.array(image) + + + # reaction_results = model.extract_reactions_from_figures([image_np]) + coref_results = model.extract_molecule_corefs_from_figures([image_np]) + + reaction_results = get_reaction_withatoms_correctR(image_path)[0] + reaction = { + "reactants": reaction_results.get('reactants', []), + "conditions": reaction_results.get('conditions', []), + "products": reaction_results.get('products', []) + } + reaction_results = [{"reactions": [reaction]}] + print(reaction_results) + #coref_results = process_reaction_image_with_multiple_products_and_text_correctR(image_path) + + + # 定义更新工具输出的函数 + def extract_smiles_details(smiles_data, raw_details): + smiles_details = {} + for smiles in smiles_data: + for detail in raw_details: + for bbox in detail.get('bboxes', []): + if bbox.get('smiles') == smiles: + smiles_details[smiles] = { + 'category': bbox.get('category'), + 'bbox': bbox.get('bbox'), + 'category_id': bbox.get('category_id'), + 'score': bbox.get('score'), + 'molfile': bbox.get('molfile'), + 'atoms': bbox.get('atoms'), + 'bonds': bbox.get('bonds'), + } + break + return smiles_details + +# 获取结果 + smiles_details = extract_smiles_details(gpt_output, coref_results) + + reactants_array = [] + products = [] + + for reactant in reaction_results[0]['reactions'][0]['reactants']: + if 'smiles' in reactant: + print(f"SMILES:{reactant['smiles']}") + #print(reactant) + reactants_array.append(reactant['smiles']) + + for product in reaction_results[0]['reactions'][0]['products']: + #print(product['smiles']) + #print(product) + products.append(product['smiles']) + # 输出结果 + #import pprint + #pprint.pprint(smiles_details) + + # 整理反应数据 + backed_out = utils.backout_without_coref(reaction_results, coref_results, gpt_output, smiles_details, model.molscribe) + backed_out.sort(key=lambda x: x[2]) + extracted_rxns = {} + for reactants, products_, label in backed_out: + extracted_rxns[label] = {'reactants': reactants, 'products': products_} + + toadd = { + "reaction_template": { + "reactants": reactants_array, + "products": products + }, + "reactions": extracted_rxns, + "original_molecule_list": gpt_output + } + +# 按标签排序 + sorted_keys = sorted(toadd["reactions"].keys()) + toadd["reactions"] = {i: toadd["reactions"][i] for i in sorted_keys} + print(toadd) + return toadd + + + + +def ChemEagle(image_path: str) -> dict: + """ + 输入化学反应图像路径,通过 GPT 模型和 TOOLS 提取反应信息并返回整理后的反应数据。 + + Args: + image_path (str): 图像文件路径。 + + Returns: + dict: 整理后的反应数据,包括反应物、产物和反应模板。 + """ + # 配置 API Key 和 Azure Endpoint + api_key = os.getenv("CHEMEAGLE_API_KEY") + if not api_key: + raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable") + + azure_endpoint = "https://hkust.azure-api.net" # 替换为实际的 Azure Endpoint + + model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + client = AzureOpenAI( + api_key=api_key, + api_version='2024-06-01', + azure_endpoint=azure_endpoint + ) + + # 加载图像并编码为 Base64 + def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + base64_image = encode_image(image_path) + + # GPT 工具调用配置 + tools = [ + { + 'type': 'function', + 'function': { + 'name': 'process_reaction_image', + 'description': 'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list.', + 'parameters': { + 'type': 'object', + 'properties': { + 'image_path': { + 'type': 'string', + 'description': 'The path to the reaction image.', + }, + }, + 'required': ['image_path'], + 'additionalProperties': False, + }, + }, + }, + ] + + # 提供给 GPT 的消息内容 + with open('./prompt_final_simple_version.txt', 'r') as prompt_file: + prompt = prompt_file.read() + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': prompt}, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} + ] + } + ] + + # 调用 GPT 接口 + response = client.chat.completions.create( + model = 'gpt-4o', + temperature = 0, + response_format={ 'type': 'json_object' }, + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ]}, + ], + tools = tools) + +# Step 1: 工具映射表 + TOOL_MAP = { + 'process_reaction_image': process_reaction_image + } + + # Step 2: 处理多个工具调用 + tool_calls = response.choices[0].message.tool_calls + results = [] + + # 遍历每个工具调用 + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_arguments = tool_call.function.arguments + tool_call_id = tool_call.id + + tool_args = json.loads(tool_arguments) + + if tool_name in TOOL_MAP: + # 调用工具并获取结果 + tool_result = TOOL_MAP[tool_name](image_path) + else: + raise ValueError(f"Unknown tool called: {tool_name}") + + # 保存每个工具调用结果 + results.append({ + 'role': 'tool', + 'content': json.dumps({ + 'image_path': image_path, + f'{tool_name}':(tool_result), + }), + 'tool_call_id': tool_call_id, + }) + + +# Prepare the chat completion payload + completion_payload = { + 'model': 'gpt-4o', + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}' + } + } + ] + }, + response.choices[0].message, + *results + ], + } + +# Generate new response + response = client.chat.completions.create( + model=completion_payload["model"], + messages=completion_payload["messages"], + response_format={ 'type': 'json_object' }, + temperature=0 + ) + + + + # 获取 GPT 生成的结果 + gpt_output = json.loads(response.choices[0].message.content) + print(gpt_output) + return gpt_output \ No newline at end of file diff --git a/main_Rgroup_debug.ipynb b/main_Rgroup_debug.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..99768b1c4508a2c592220ce4a6d6e3805280c61f --- /dev/null +++ b/main_Rgroup_debug.ipynb @@ -0,0 +1,993 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import torch\n", + "import json\n", + "from chemietoolkit import ChemIEToolkit\n", + "import cv2\n", + "from PIL import Image\n", + "import json\n", + "model = ChemIEToolkit(device=torch.device('cpu')) \n", + "from get_molecular_agent import process_reaction_image_with_multiple_products_and_text\n", + "from get_reaction_agent import get_reaction_withatoms\n", + "from get_reaction_agent import get_full_reaction\n", + "\n", + "\n", + "# 定义函数,接受多个图像路径并返回反应列表\n", + "def get_multi_molecular(image_path: str) -> list:\n", + " '''Returns a list of reactions extracted from the image.'''\n", + " # 打开图像文件\n", + " image = Image.open(image_path).convert('RGB')\n", + " \n", + " # 将图像作为输入传递给模型\n", + " coref_results = model.extract_molecule_corefs_from_figures([image])\n", + " \n", + " for item in coref_results:\n", + " for bbox in item.get(\"bboxes\", []):\n", + " for key in [\"category\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n", + " bbox.pop(key, None) # 安全地移除键\n", + " print(json.dumps(coref_results))\n", + " # 返回反应列表,使用 json.dumps 进行格式化\n", + " \n", + " return json.dumps(coref_results)\n", + "\n", + "def get_multi_molecular_text_to_correct(image_path: str) -> list:\n", + " '''Returns a list of reactions extracted from the image.'''\n", + " # 打开图像文件\n", + " image = Image.open(image_path).convert('RGB')\n", + " \n", + " # 将图像作为输入传递给模型\n", + " coref_results = model.extract_molecule_corefs_from_figures([image])\n", + " #coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n", + " for item in coref_results:\n", + " for bbox in item.get(\"bboxes\", []):\n", + " for key in [\"category\", \"bbox\", \"molfile\", \"symbols\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n", + " bbox.pop(key, None) # 安全地移除键\n", + " print(json.dumps(coref_results))\n", + " # 返回反应列表,使用 json.dumps 进行格式化\n", + " \n", + " return json.dumps(coref_results)\n", + "\n", + "def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> list:\n", + " '''Returns a list of reactions extracted from the image.'''\n", + " # 打开图像文件\n", + " image = Image.open(image_path).convert('RGB')\n", + " \n", + " # 将图像作为输入传递给模型\n", + " #coref_results = model.extract_molecule_corefs_from_figures([image])\n", + " coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n", + " for item in coref_results:\n", + " for bbox in item.get(\"bboxes\", []):\n", + " for key in [\"molfile\", 'atoms', \"bonds\", 'category_id', 'score', 'corefs',\"coords\",\"edges\"]: #'atoms'\n", + " bbox.pop(key, None) # 安全地移除键\n", + " print(json.dumps(coref_results))\n", + " # 返回反应列表,使用 json.dumps 进行格式化\n", + " return json.dumps(coref_results)\n", + "\n", + "#get_multi_molecular_text_to_correct('./acs.joc.2c00176 example 1.png')\n", + "\n", + "import sys\n", + "#sys.path.append('./RxnScribe-main/')\n", + "import torch\n", + "from rxnscribe import RxnScribe\n", + "import json\n", + "\n", + "ckpt_path = \"./pix2seq_reaction_full.ckpt\"\n", + "model1 = RxnScribe(ckpt_path, device=torch.device('cpu'))\n", + "device = torch.device('cpu')\n", + "\n", + "def get_reaction(image_path: str) -> dict:\n", + " '''\n", + " Returns a structured dictionary of reactions extracted from the image,\n", + " including reactants, conditions, and products, with their smiles, text, and bbox.\n", + " '''\n", + " image_file = image_path\n", + " #raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True)\n", + " raw_prediction = get_reaction_withatoms(image_path)\n", + "\n", + " # Ensure raw_prediction is treated as a list directly\n", + " structured_output = {}\n", + " for section_key in ['reactants', 'conditions', 'products']:\n", + " if section_key in raw_prediction[0]:\n", + " structured_output[section_key] = []\n", + " for item in raw_prediction[0][section_key]:\n", + " if section_key in ['reactants', 'products']:\n", + " # Extract smiles and bbox for molecules\n", + " structured_output[section_key].append({\n", + " \"smiles\": item.get(\"smiles\", \"\"),\n", + " \"bbox\": item.get(\"bbox\", [])\n", + " })\n", + " elif section_key == 'conditions':\n", + " # Extract smiles, text, and bbox for conditions\n", + " condition_data = {\"bbox\": item.get(\"bbox\", [])}\n", + " if \"smiles\" in item:\n", + " condition_data[\"smiles\"] = item.get(\"smiles\", \"\")\n", + " if \"text\" in item:\n", + " condition_data[\"text\"] = item.get(\"text\", [])\n", + " structured_output[section_key].append(condition_data)\n", + " print(f\"structured_output:{structured_output}\")\n", + "\n", + " return structured_output\n", + "\n", + "\n", + "\n", + "\n", + "import base64\n", + "import torch\n", + "import json\n", + "from PIL import Image\n", + "import numpy as np\n", + "from chemietoolkit import ChemIEToolkit, utils\n", + "from openai import AzureOpenAI\n", + "\n", + "def process_reaction_image_with_multiple_products(image_path: str) -> dict:\n", + " \"\"\"\n", + " Args:\n", + " image_path (str): 图像文件路径。\n", + "\n", + " Returns:\n", + " dict: 整理后的反应数据,包括反应物、产物和反应模板。\n", + " \"\"\"\n", + " # 配置 API Key 和 Azure Endpoint\n", + " api_key = \"b038da96509b4009be931e035435e022\" # 替换为实际的 API Key\n", + " azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n", + " \n", + "\n", + " model = ChemIEToolkit(device=torch.device('cpu'))\n", + " client = AzureOpenAI(\n", + " api_key=api_key,\n", + " api_version='2024-06-01',\n", + " azure_endpoint=azure_endpoint\n", + " )\n", + "\n", + " # 加载图像并编码为 Base64\n", + " def encode_image(image_path: str):\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode('utf-8')\n", + "\n", + " base64_image = encode_image(image_path)\n", + "\n", + " # GPT 工具调用配置\n", + " tools = [\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_multi_molecular_text_to_correct',\n", + " 'description': 'Extracts the SMILES string and text coref from molecular images.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'Path to the reaction image.'\n", + " }\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False\n", + " }\n", + " }\n", + " },\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_reaction',\n", + " 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'The path to the reaction image.',\n", + " },\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False,\n", + " },\n", + " },\n", + " },\n", + " ]\n", + "\n", + " # 提供给 GPT 的消息内容\n", + " with open('./prompt.txt', 'r') as prompt_file:\n", + " prompt = prompt_file.read()\n", + " messages = [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': prompt},\n", + " {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n", + " ]\n", + " }\n", + " ]\n", + "\n", + " # 调用 GPT 接口\n", + " response = client.chat.completions.create(\n", + " model = 'gpt-4o',\n", + " temperature = 0,\n", + " response_format={ 'type': 'json_object' },\n", + " messages = [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {\n", + " 'type': 'text',\n", + " 'text': prompt\n", + " },\n", + " {\n", + " 'type': 'image_url',\n", + " 'image_url': {\n", + " 'url': f'data:image/png;base64,{base64_image}'\n", + " }\n", + " }\n", + " ]},\n", + " ],\n", + " tools = tools)\n", + " \n", + "# Step 1: 工具映射表\n", + " TOOL_MAP = {\n", + " 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n", + " 'get_reaction': get_reaction\n", + " }\n", + "\n", + " # Step 2: 处理多个工具调用\n", + " tool_calls = response.choices[0].message.tool_calls\n", + " results = []\n", + "\n", + " # 遍历每个工具调用\n", + " for tool_call in tool_calls:\n", + " tool_name = tool_call.function.name\n", + " tool_arguments = tool_call.function.arguments\n", + " tool_call_id = tool_call.id\n", + " \n", + " tool_args = json.loads(tool_arguments)\n", + " \n", + " if tool_name in TOOL_MAP:\n", + " # 调用工具并获取结果\n", + " tool_result = TOOL_MAP[tool_name](image_path)\n", + " else:\n", + " raise ValueError(f\"Unknown tool called: {tool_name}\")\n", + " \n", + " # 保存每个工具调用结果\n", + " results.append({\n", + " 'role': 'tool',\n", + " 'content': json.dumps({\n", + " 'image_path': image_path,\n", + " f'{tool_name}':(tool_result),\n", + " }),\n", + " 'tool_call_id': tool_call_id,\n", + " })\n", + "\n", + "\n", + "# Prepare the chat completion payload\n", + " completion_payload = {\n", + " 'model': 'gpt-4o',\n", + " 'messages': [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {\n", + " 'type': 'text',\n", + " 'text': prompt\n", + " },\n", + " {\n", + " 'type': 'image_url',\n", + " 'image_url': {\n", + " 'url': f'data:image/png;base64,{base64_image}'\n", + " }\n", + " }\n", + " ]\n", + " },\n", + " response.choices[0].message,\n", + " *results\n", + " ],\n", + " }\n", + "\n", + "# Generate new response\n", + " response = client.chat.completions.create(\n", + " model=completion_payload[\"model\"],\n", + " messages=completion_payload[\"messages\"],\n", + " response_format={ 'type': 'json_object' },\n", + " temperature=0\n", + " )\n", + "\n", + "\n", + " \n", + " # 获取 GPT 生成的结果\n", + " gpt_output = json.loads(response.choices[0].message.content)\n", + " print(f\"gptout:{gpt_output}\")\n", + "\n", + " image = Image.open(image_path).convert('RGB')\n", + " image_np = np.array(image)\n", + "\n", + " #########################\n", + " #reaction_results = model.extract_reactions_from_figures([image_np])\n", + " reaction_results = get_reaction_withatoms(image_path)[0]\n", + " reactions = []\n", + " \n", + " # 将 reactants 和 products 转换为 reactions\n", + " for reactants, conditions, products in zip(reaction_results.get('reactants', []), reaction_results.get('conditions', []), reaction_results.get('products', [])):\n", + " reaction = {\n", + " \"reactants\": [reactants],\n", + " \"conditions\": [conditions],\n", + " \"products\": [products]\n", + " }\n", + " reactions.append(reaction)\n", + " reaction_results = [{\"reactions\": reactions}]\n", + " #coref_results = model.extract_molecule_corefs_from_figures([image_np])\n", + " coref_results = process_reaction_image_with_multiple_products_and_text(image_path)\n", + " ########################\n", + "\n", + " # 定义更新工具输出的函数\n", + " def extract_smiles_details(smiles_data, raw_details):\n", + " smiles_details = {}\n", + " for smiles in smiles_data:\n", + " for detail in raw_details:\n", + " for bbox in detail.get('bboxes', []):\n", + " if bbox.get('smiles') == smiles:\n", + " smiles_details[smiles] = {\n", + " 'category': bbox.get('category'),\n", + " 'bbox': bbox.get('bbox'),\n", + " 'category_id': bbox.get('category_id'),\n", + " 'score': bbox.get('score'),\n", + " 'molfile': bbox.get('molfile'),\n", + " 'atoms': bbox.get('atoms'),\n", + " 'bonds': bbox.get('bonds')\n", + " }\n", + " break\n", + " return smiles_details\n", + "\n", + "# 获取结果\n", + " smiles_details = extract_smiles_details(gpt_output, coref_results)\n", + "\n", + " reactants_array = []\n", + " products = []\n", + "\n", + " for reactant in reaction_results[0]['reactions'][0]['reactants']:\n", + " #for reactant in reaction_results[0]['reactions'][0]['reactants']:\n", + " if 'smiles' in reactant:\n", + " #print(reactant['smiles'])\n", + " #print(reactant)\n", + " reactants_array.append(reactant['smiles'])\n", + "\n", + " for product in reaction_results[0]['reactions'][0]['products']:\n", + " #print(product['smiles'])\n", + " #print(product)\n", + " products.append(product['smiles'])\n", + " # 输出结果\n", + " #import pprint\n", + " #pprint.pprint(smiles_details)\n", + "\n", + " # 整理反应数据\n", + " try:\n", + " backed_out = utils.backout_without_coref(reaction_results, coref_results, gpt_output, smiles_details, model.molscribe)\n", + " backed_out.sort(key=lambda x: x[2])\n", + " extracted_rxns = {}\n", + " for reactants, products_, label in backed_out:\n", + " extracted_rxns[label] = {'reactants': reactants, 'products': products_}\n", + "\n", + " toadd = {\n", + " \"reaction_template\": {\n", + " \"reactants\": reactants_array,\n", + " \"products\": products\n", + " },\n", + " \"reactions\": extracted_rxns\n", + " }\n", + " \n", + "\n", + " # 按标签排序\n", + " sorted_keys = sorted(toadd[\"reactions\"].keys())\n", + " toadd[\"reactions\"] = {i: toadd[\"reactions\"][i] for i in sorted_keys}\n", + " original_molecular_list = {'Original molecular list': gpt_output}\n", + " final_data= toadd.copy()\n", + " final_data.update(original_molecular_list)\n", + " except:\n", + " #pass\n", + " final_data = {'Original molecular list': gpt_output}\n", + "\n", + " print(final_data)\n", + " return final_data\n", + " \n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # image_path = './example/Replace/99.jpg'\n", + "# # result = process_reaction_image(image_path)\n", + "# # print(json.dumps(result, indent=4))\n", + "# image_path = './example/example1/replace/Nesting/283.jpg'\n", + "# image = Image.open(image_path).convert('RGB')\n", + "# image_np = np.array(image)\n", + "\n", + "# # input1 = get_multi_molecular_text_to_correct_withatoms('./example/example1/replace/Nesting/283.jpg')\n", + "# # input2 = get_reaction('./example/example1/replace/Nesting/283.jpg')\n", + "# # print(input1)\n", + "# # print(input2)\n", + "# #reaction_results = model.extract_reactions_from_figures([image_np])\n", + "# coorf = model.extract_molecule_corefs_from_figures([image_np])\n", + "# print(coorf)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "import torch\n", + "import json\n", + "from PIL import Image\n", + "import numpy as np\n", + "from openai import AzureOpenAI\n", + "\n", + "def process_reaction_image_final(image_path: str) -> dict:\n", + " \"\"\"\n", + "\n", + " Args:\n", + " image_path (str): 图像文件路径。\n", + "\n", + " Returns:\n", + " dict: 整理后的反应数据,包括反应物、产物和反应模板。\n", + " \"\"\"\n", + " # 配置 API Key 和 Azure Endpoint\n", + " api_key = \"b038da96509b4009be931e035435e022\" # 替换为实际的 API Key\n", + " azure_endpoint = \"https://hkust.azure-api.net\" # 替换为实际的 Azure Endpoint\n", + " \n", + "\n", + " model = ChemIEToolkit(device=torch.device('cpu'))\n", + " client = AzureOpenAI(\n", + " api_key=api_key,\n", + " api_version='2024-06-01',\n", + " azure_endpoint=azure_endpoint\n", + " )\n", + "\n", + " # 加载图像并编码为 Base64\n", + " def encode_image(image_path: str):\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode('utf-8')\n", + "\n", + " base64_image = encode_image(image_path)\n", + "\n", + " # GPT 工具调用配置\n", + " tools = [\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_multi_molecular_text_to_correct',\n", + " 'description': 'Extracts the SMILES string and text coref from molecular sub-images from a reaction image and ready for further process.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'Path to the reaction image.'\n", + " }\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False\n", + " }\n", + " }\n", + " },\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_reaction',\n", + " 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'The path to the reaction image.',\n", + " },\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False,\n", + " },\n", + " },\n", + " },\n", + "\n", + " \n", + "\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'process_reaction_image_with_multiple_products',\n", + " 'description': 'process the reaction image that contains a multiple products table. Get a list of reactions from the reaction image, Inculding the reaction template and detailed reaction with detailed R-group information.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'The path to the reaction image.',\n", + " },\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False,\n", + " },\n", + " },\n", + " },\n", + "\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_full_reaction',\n", + " 'description': 'Get a list of reactions from a reaction image without any tables. A reaction contains data of the reactants, conditions, and products.',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'The path to the reaction image.',\n", + " },\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False,\n", + " },\n", + " },\n", + " },\n", + "\n", + " {\n", + " 'type': 'function',\n", + " 'function': {\n", + " 'name': 'get_multi_molecular',\n", + " 'description': 'Extracts the SMILES string and text coref from a molecular image without any reactions',\n", + " 'parameters': {\n", + " 'type': 'object',\n", + " 'properties': {\n", + " 'image_path': {\n", + " 'type': 'string',\n", + " 'description': 'The path to the reaction image.',\n", + " },\n", + " },\n", + " 'required': ['image_path'],\n", + " 'additionalProperties': False,\n", + " },\n", + " },\n", + " },\n", + " ]\n", + "\n", + " # 提供给 GPT 的消息内容\n", + " with open('./prompt_final.txt', 'r') as prompt_file:\n", + " prompt = prompt_file.read()\n", + " messages = [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {'type': 'text', 'text': prompt},\n", + " {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}}\n", + " ]\n", + " }\n", + " ]\n", + "\n", + " # 调用 GPT 接口\n", + " response = client.chat.completions.create(\n", + " model = 'gpt-4o',\n", + " temperature = 0,\n", + " response_format={ 'type': 'json_object' },\n", + " messages = [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {\n", + " 'type': 'text',\n", + " 'text': prompt\n", + " },\n", + " {\n", + " 'type': 'image_url',\n", + " 'image_url': {\n", + " 'url': f'data:image/png;base64,{base64_image}'\n", + " }\n", + " }\n", + " ]},\n", + " ],\n", + " tools = tools)\n", + " \n", + "# Step 1: 工具映射表\n", + " TOOL_MAP = {\n", + " 'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,\n", + " 'get_reaction': get_reaction,\n", + " 'process_reaction_image_with_multiple_products':process_reaction_image_with_multiple_products,\n", + "\n", + " 'get_full_reaction': get_full_reaction,\n", + " 'get_multi_molecular':get_multi_molecular,\n", + " }\n", + "\n", + " # Step 2: 处理多个工具调用\n", + " tool_calls = response.choices[0].message.tool_calls\n", + " results = []\n", + "\n", + " # 遍历每个工具调用\n", + " for tool_call in tool_calls:\n", + " tool_name = tool_call.function.name\n", + " tool_arguments = tool_call.function.arguments\n", + " tool_call_id = tool_call.id\n", + " \n", + " tool_args = json.loads(tool_arguments)\n", + " \n", + " if tool_name in TOOL_MAP:\n", + " # 调用工具并获取结果\n", + " tool_result = TOOL_MAP[tool_name](image_path)\n", + " else:\n", + " raise ValueError(f\"Unknown tool called: {tool_name}\")\n", + " \n", + " # 保存每个工具调用结果\n", + " results.append({\n", + " 'role': 'tool',\n", + " 'content': json.dumps({\n", + " 'image_path': image_path,\n", + " f'{tool_name}':(tool_result),\n", + " }),\n", + " 'tool_call_id': tool_call_id,\n", + " })\n", + "\n", + "\n", + "# Prepare the chat completion payload\n", + " completion_payload = {\n", + " 'model': 'gpt-4o',\n", + " 'messages': [\n", + " {'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {\n", + " 'role': 'user',\n", + " 'content': [\n", + " {\n", + " 'type': 'text',\n", + " 'text': prompt\n", + " },\n", + " {\n", + " 'type': 'image_url',\n", + " 'image_url': {\n", + " 'url': f'data:image/png;base64,{base64_image}'\n", + " }\n", + " }\n", + " ]\n", + " },\n", + " response.choices[0].message,\n", + " *results\n", + " ],\n", + " }\n", + "\n", + "# Generate new response\n", + " response = client.chat.completions.create(\n", + " model=completion_payload[\"model\"],\n", + " messages=completion_payload[\"messages\"],\n", + " response_format={ 'type': 'json_object' },\n", + " temperature=0\n", + " )\n", + "\n", + "\n", + " \n", + " # 获取 GPT 生成的结果\n", + " gpt_output = json.loads(response.choices[0].message.content)\n", + " print(gpt_output)\n", + " return gpt_output\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image_path = './data/bowen-4/2.png'\n", + "result = process_reaction_image_final(image_path)\n", + "print(json.dumps(result, indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# def get_reaction(image_path: str) -> list:\n", + "# '''Returns a list of reactions extracted from the image.'''\n", + "# image_file = image_path\n", + "# return json.dumps(model1.predict_image_file(image_file, molscribe=True, ocr=True))\n", + "\n", + "# reaction_output = get_reaction('./pdf/2/2_image_3_1.png')\n", + "# print(reaction_output)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import fitz # PyMuPDF\n", + "from core import run_visualheist\n", + "import base64\n", + "from openai import AzureOpenAI\n", + "\n", + "def full_pdf_extraction_pipeline_with_history(pdf_path,\n", + " output_dir,\n", + " api_key,\n", + " azure_endpoint,\n", + " model=\"gpt-4o\",\n", + " model_size=\"large\"):\n", + " \"\"\"\n", + " Full pipeline: from PDF to GPT-annotated related text.\n", + " Extracts markdown + figures + reaction data from a PDF and calls GPT-4o to annotate them.\n", + "\n", + " Args:\n", + " pdf_path (str): Path to input PDF file.\n", + " output_dir (str): Directory to save results.\n", + " api_key (str): Azure OpenAI API key.\n", + " azure_endpoint (str): Azure OpenAI endpoint.\n", + " model (str): GPT model name (default \"gpt-4o\").\n", + " model_size (str): VisualHeist model size (\"base\", \"large\", etc).\n", + "\n", + " Returns:\n", + " List of GPT-generated annotated related-text JSONs.\n", + " \"\"\"\n", + "\n", + "\n", + " os.makedirs(output_dir, exist_ok=True)\n", + "\n", + " # Step 1: Extract Markdown text\n", + " doc = fitz.open(pdf_path)\n", + " md_text = \"\"\n", + " for i, page in enumerate(doc, start=1):\n", + " md_text += f\"\\n\\n## = Page {i} =\\n\\n\" + page.get_text()\n", + " filename = os.path.splitext(os.path.basename(pdf_path))[0]\n", + " md_path = os.path.join(output_dir, f\"{filename}.md\")\n", + " with open(md_path, \"w\", encoding=\"utf-8\") as f:\n", + " f.write(md_text.strip())\n", + " print(f\"[✓] Markdown saved to: {md_path}\")\n", + "\n", + " # Step 2: Extract figures using VisualHeist\n", + " run_visualheist(pdf_dir=pdf_path, model_size=model_size, image_dir=output_dir)\n", + " print(f\"[✓] Figures extracted to: {output_dir}\")\n", + "\n", + " # Step 3: Parse figures to JSON\n", + " image_data = []\n", + " known_molecules = []\n", + "\n", + " for fname in sorted(os.listdir(output_dir)):\n", + " if fname.endswith(\".png\"):\n", + " img_path = os.path.join(output_dir, fname)\n", + " try:\n", + " result = process_reaction_image_final(img_path)\n", + " result[\"image_name\"] = fname\n", + " image_data.append(result)\n", + " except Exception as e:\n", + " print(f\"[!] Failed on {fname}: {e}\")\n", + " new_mols_json = get_multi_molecular_text_to_correct(img_path)\n", + " new_mols = json.loads(new_mols_json)\n", + " for m in new_mols:\n", + " if m[\"smiles\"] not in {km[\"smiles\"] for km in known_molecules}:\n", + " known_molecules.append(m)\n", + "\n", + "\n", + " json_path = os.path.join(output_dir, f\"{filename}_reaction_data.json\")\n", + " with open(json_path, \"w\", encoding=\"utf-8\") as f:\n", + " json.dump(image_data, f, indent=2, ensure_ascii=False)\n", + " print(f\"[✓] Reaction data saved to: {json_path}\")\n", + "\n", + " # Step 4: Call Azure GPT-4 for annotation\n", + " client = AzureOpenAI(\n", + " api_key=api_key,\n", + " api_version=\"2024-06-01\",\n", + " azure_endpoint=azure_endpoint\n", + " )\n", + "\n", + " prompt = \"\"\"\n", + "You are a text-mining assistant for chemistry papers. Your task is to find the most relevant 1–3 sentences in a research article that describe a given figure or scheme.\n", + "\n", + "You will be given:\n", + "- A block of text extracted from the article (in Markdown format).\n", + "- The extracted structured data from one image (including its title and list of molecules or reactions).\n", + "\n", + "Your task is:\n", + "1. Match the image with sentences that are most relevant to it. Use clues like the figure/scheme/table number in the title, or molecule/reaction labels (e.g., 1a, 2b, 3).\n", + "2. Extract up to 3 short sentences that best describe or mention the contents of the image.\n", + "3. In these sentences, label any molecule or reaction identifiers (like “1a”, “2b”) with their role based on context: [reactant], [product], etc.\n", + "4. Also label experimental conditions with their roles:\n", + " - Percent values like “85%” as [yield]\n", + " - Temperatures like “100 °C” as [temperature]\n", + " - Time durations like “24 h”, “20 min” as [time]\n", + "5. Do **not** label chemical position numbers (e.g., in \"3-trifluoromethyl\", \"1,2,4-triazole\").\n", + "6. Do not repeat any labels. Only label each item once per sentence.\n", + "\n", + "Output format:\n", + "{\n", + " \"title\": \"\",\n", + " \"related-text\": [\n", + " \"Sentence with roles like 1a[reactant], 2c[product], 100[temperature] °C.\",\n", + " ...\n", + " ]\n", + "}\n", + "\"\"\"\n", + "\n", + " annotated_results = []\n", + " for item in image_data:\n", + " img_path = os.path.join(output_dir, item[\"image_name\"])\n", + " with open(img_path, \"rb\") as f:\n", + " base64_image = base64.b64encode(f.read()).decode(\"utf-8\")\n", + "\n", + " combined_input = f\"\"\"\n", + "## Image Structured Data:\n", + "{json.dumps(item, indent=2)}\n", + "\n", + "## Article Text:\n", + "{md_text}\n", + "\"\"\"\n", + "\n", + " response = client.chat.completions.create(\n", + " model=model,\n", + " temperature=0,\n", + " response_format=\"json\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\"type\": \"text\", \"text\": prompt + \"\\n\\n\" + combined_input},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\n", + " \"url\": f\"data:image/png;base64,{base64_image}\"\n", + " }\n", + " }\n", + " ]\n", + " }\n", + " ]\n", + " )\n", + " annotated_results.append(json.loads(response.choices[0].message.content))\n", + "\n", + " # Optionally save output\n", + " with open(os.path.join(output_dir, f\"{filename}_annotated_related_text.json\"), \"w\", encoding=\"utf-8\") as f:\n", + " json.dump(annotated_results, f, indent=2, ensure_ascii=False)\n", + " print(f\"[✓] Annotated related-text saved.\")\n", + "\n", + " return annotated_results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image_path = './data/example/example1/replace/Nesting/283.jpg'\n", + "#image_path = './pdf/2/2_image_1_1.png'\n", + "result = process_reaction_image_final(image_path)\n", + "print(json.dumps(result, indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import os\n", + "\n", + "# image_folder = './example/example1/replace/regular/' # 图片文件夹路径\n", + "# output_folder = './batches_final_repalce_regular/' # 保存每批结果的文件夹路径\n", + "# batch_size = 3 # 每批处理文件数量\n", + "\n", + "# # 创建保存批次结果的文件夹(如果不存在)\n", + "# os.makedirs(output_folder, exist_ok=True)\n", + "\n", + "# # 获取所有图片文件并按字母顺序排序\n", + "# all_files = sorted([f for f in os.listdir(image_folder) if f.endswith('.jpg')])\n", + "\n", + "# # 获取已完成的批次\n", + "# completed_batches = [\n", + "# int(f.split('_')[1].split('.')[0]) for f in os.listdir(output_folder) if f.startswith('batch_') and f.endswith('.json')\n", + "# ]\n", + "# completed_batches = sorted(completed_batches) # 确保按顺序排序\n", + "\n", + "# # 从指定批次开始(如果有未完成批次)\n", + "# start_batch = (completed_batches[-1] + 1) if completed_batches else 1\n", + "\n", + "# # 将文件分批并从指定批次开始\n", + "# for batch_index in range((start_batch - 1) * batch_size, len(all_files), batch_size):\n", + "# batch_files = all_files[batch_index:batch_index + batch_size]\n", + "# results = []\n", + "\n", + "# batch_number = batch_index // batch_size + 1\n", + "# print(f\"正在按字母顺序处理第 {batch_number} 批,共 {len(batch_files)} 张图片...\")\n", + " \n", + "# for file_name in batch_files:\n", + "# image_path = os.path.join(image_folder, file_name)\n", + "# print(f\"处理文件 {file_name}...\")\n", + " \n", + "# try:\n", + "# # 处理单个图片\n", + "# result = process_reaction_image_final(image_path)\n", + " \n", + "# # 确保结果是字典\n", + "# if isinstance(result, dict):\n", + "# # 添加文件名信息\n", + "# result_with_filename = {\n", + "# \"file_name\": file_name,\n", + "# **result\n", + "# }\n", + "# results.append(result_with_filename)\n", + "# print(result_with_filename)\n", + "# else:\n", + "# print(f\"文件 {file_name} 的处理结果不是字典,跳过。\")\n", + " \n", + "# except Exception as e:\n", + "# print(f\"处理文件 {file_name} 时出错: {e}\")\n", + "\n", + "# # 保存当前批次结果\n", + "# batch_output_path = os.path.join(output_folder, f'batch_{batch_number}.json')\n", + "# with open(batch_output_path, 'w', encoding='utf-8') as json_file:\n", + "# json.dump(results, json_file, ensure_ascii=False, indent=4)\n", + "\n", + "# print(f\"第 {batch_number} 批处理完成,结果保存到 {batch_output_path}\")\n", + "\n", + "# print(\"所有批次处理完成!\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import rdkit\n", + "from rdkit import Chem\n", + "from rdkit.Chem import Draw\n", + "\n", + "Draw.MolToImage(Chem.MolFromSmiles('[Si](C)(C)OC(c1ccccc1)(c1ccccc1)C1CCC2=NN(Cc3ccccc3)=CN21'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openchemie", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/molscribe/__init__.py b/molscribe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4473879606f65474f5fe4d5ba1be97c695649852 --- /dev/null +++ b/molscribe/__init__.py @@ -0,0 +1 @@ +from .interface import MolScribe diff --git a/molscribe/__pycache__/__init__.cpython-310.pyc b/molscribe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73b34e8c2d7fa7cd5895e58fc505b40e39107b98 Binary files /dev/null and b/molscribe/__pycache__/__init__.cpython-310.pyc differ diff --git a/molscribe/__pycache__/augment.cpython-310.pyc b/molscribe/__pycache__/augment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee220a878064d6837bd04c03d9ddd6ebaa47e993 Binary files /dev/null and b/molscribe/__pycache__/augment.cpython-310.pyc differ diff --git a/molscribe/__pycache__/chemistry.cpython-310.pyc b/molscribe/__pycache__/chemistry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07a0e9583b0bba4b725c940e44ed76c1f4e0fd51 Binary files /dev/null and b/molscribe/__pycache__/chemistry.cpython-310.pyc differ diff --git a/molscribe/__pycache__/constants.cpython-310.pyc b/molscribe/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e220fe63d488e3afaf34f7b826e6fbde44126bf7 Binary files /dev/null and b/molscribe/__pycache__/constants.cpython-310.pyc differ diff --git a/molscribe/__pycache__/dataset.cpython-310.pyc b/molscribe/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad56ea018f7182e1c2f4631f3fb81afec47a4128 Binary files /dev/null and b/molscribe/__pycache__/dataset.cpython-310.pyc differ diff --git a/molscribe/__pycache__/interface.cpython-310.pyc b/molscribe/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1d52ffac16b105df14eff2837fd4b4245e0bb5f Binary files /dev/null and b/molscribe/__pycache__/interface.cpython-310.pyc differ diff --git a/molscribe/__pycache__/model.cpython-310.pyc b/molscribe/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88101f543e84996b343631bad1dcfff17ff1313b Binary files /dev/null and b/molscribe/__pycache__/model.cpython-310.pyc differ diff --git a/molscribe/__pycache__/tokenizer.cpython-310.pyc b/molscribe/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c27bad89a0e87d4fab2afec1165cd0917e06084 Binary files /dev/null and b/molscribe/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/molscribe/__pycache__/utils.cpython-310.pyc b/molscribe/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64b69c0c372ae3288f9f28bf48a3160c423aed8f Binary files /dev/null and b/molscribe/__pycache__/utils.cpython-310.pyc differ diff --git a/molscribe/augment.py b/molscribe/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..a80ebc505c4698238386902affd3377b7a7ea885 --- /dev/null +++ b/molscribe/augment.py @@ -0,0 +1,282 @@ +import albumentations as A +from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \ + keypoint_rotate +import cv2 +import math +import random +import numpy as np + + +def safe_rotate( + img: np.ndarray, + angle: int = 0, + interpolation: int = cv2.INTER_LINEAR, + value: int = None, + border_mode: int = cv2.BORDER_REFLECT_101, +): + + old_rows, old_cols = img.shape[:2] + + # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape + image_center = (old_cols / 2, old_rows / 2) + + # Rows and columns of the rotated image (not cropped) + new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols) + + # Rotation Matrix + rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + + # Shift the image to create padding + rotation_mat[0, 2] += new_cols / 2 - image_center[0] + rotation_mat[1, 2] += new_rows / 2 - image_center[1] + + # CV2 Transformation function + warp_affine_fn = _maybe_process_in_chunks( + cv2.warpAffine, + M=rotation_mat, + dsize=(new_cols, new_rows), + flags=interpolation, + borderMode=border_mode, + borderValue=value, + ) + + # rotate image with the new bounds + rotated_img = warp_affine_fn(img) + + return rotated_img + + +def keypoint_safe_rotate(keypoint, angle, rows, cols): + old_rows = rows + old_cols = cols + + # Rows and columns of the rotated image (not cropped) + new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols) + + col_diff = (new_cols - old_cols) / 2 + row_diff = (new_rows - old_rows) / 2 + + # Shift keypoint + shifted_keypoint = (int(keypoint[0] + col_diff), int(keypoint[1] + row_diff), keypoint[2], keypoint[3]) + + # Rotate keypoint + rotated_keypoint = keypoint_rotate(shifted_keypoint, angle, rows=new_rows, cols=new_cols) + + return rotated_keypoint + + +class SafeRotate(A.SafeRotate): + + def __init__( + self, + limit=90, + interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_REFLECT_101, + value=None, + mask_value=None, + always_apply=False, + p=0.5, + ): + super(SafeRotate, self).__init__( + limit=limit, + interpolation=interpolation, + border_mode=border_mode, + value=value, + mask_value=mask_value, + always_apply=always_apply, + p=p) + + def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params): + return safe_rotate( + img=img, value=self.value, angle=angle, interpolation=interpolation, border_mode=self.border_mode) + + def apply_to_keypoint(self, keypoint, angle=0, **params): + return keypoint_safe_rotate(keypoint, angle=angle, rows=params["rows"], cols=params["cols"]) + + +class CropWhite(A.DualTransform): + + def __init__(self, value=(255, 255, 255), pad=0, p=1.0): + super(CropWhite, self).__init__(p=p) + self.value = value + self.pad = pad + assert pad >= 0 + + def update_params(self, params, **kwargs): + super().update_params(params, **kwargs) + assert "image" in kwargs + img = kwargs["image"] + height, width, _ = img.shape + x = (img != self.value).sum(axis=2) + if x.sum() == 0: + return params + row_sum = x.sum(axis=1) + top = 0 + while row_sum[top] == 0 and top+1 < height: + top += 1 + bottom = height + while row_sum[bottom-1] == 0 and bottom-1 > top: + bottom -= 1 + col_sum = x.sum(axis=0) + left = 0 + while col_sum[left] == 0 and left+1 < width: + left += 1 + right = width + while col_sum[right-1] == 0 and right-1 > left: + right -= 1 + # crop_top = max(0, top - self.pad) + # crop_bottom = max(0, height - bottom - self.pad) + # crop_left = max(0, left - self.pad) + # crop_right = max(0, width - right - self.pad) + # params.update({"crop_top": crop_top, "crop_bottom": crop_bottom, + # "crop_left": crop_left, "crop_right": crop_right}) + params.update({"crop_top": top, "crop_bottom": height - bottom, + "crop_left": left, "crop_right": width - right}) + return params + + def apply(self, img, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params): + height, width, _ = img.shape + img = img[crop_top:height - crop_bottom, crop_left:width - crop_right] + img = A.augmentations.pad_with_params( + img, self.pad, self.pad, self.pad, self.pad, border_mode=cv2.BORDER_CONSTANT, value=self.value) + return img + + def apply_to_keypoint(self, keypoint, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params): + x, y, angle, scale = keypoint[:4] + return x - crop_left + self.pad, y - crop_top + self.pad, angle, scale + + def get_transform_init_args_names(self): + return ('value', 'pad') + + +class PadWhite(A.DualTransform): + + def __init__(self, pad_ratio=0.2, p=0.5, value=(255, 255, 255)): + super(PadWhite, self).__init__(p=p) + self.pad_ratio = pad_ratio + self.value = value + + def update_params(self, params, **kwargs): + super().update_params(params, **kwargs) + assert "image" in kwargs + img = kwargs["image"] + height, width, _ = img.shape + side = random.randrange(4) + if side == 0: + params['pad_top'] = int(height * self.pad_ratio * random.random()) + elif side == 1: + params['pad_bottom'] = int(height * self.pad_ratio * random.random()) + elif side == 2: + params['pad_left'] = int(width * self.pad_ratio * random.random()) + elif side == 3: + params['pad_right'] = int(width * self.pad_ratio * random.random()) + return params + + def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params): + height, width, _ = img.shape + img = A.augmentations.pad_with_params( + img, pad_top, pad_bottom, pad_left, pad_right, border_mode=cv2.BORDER_CONSTANT, value=self.value) + return img + + def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params): + x, y, angle, scale = keypoint[:4] + return x + pad_left, y + pad_top, angle, scale + + def get_transform_init_args_names(self): + return ('value', 'pad_ratio') + + +class SaltAndPepperNoise(A.DualTransform): + + def __init__(self, num_dots, value=(0, 0, 0), p=0.5): + super().__init__(p) + self.num_dots = num_dots + self.value = value + + def apply(self, img, **params): + height, width, _ = img.shape + num_dots = random.randrange(self.num_dots + 1) + for i in range(num_dots): + x = random.randrange(height) + y = random.randrange(width) + img[x, y] = self.value + return img + + def apply_to_keypoint(self, keypoint, **params): + return keypoint + + def get_transform_init_args_names(self): + return ('value', 'num_dots') + +class ResizePad(A.DualTransform): + + def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, value=(255, 255, 255)): + super(ResizePad, self).__init__(always_apply=True) + self.height = height + self.width = width + self.interpolation = interpolation + self.value = value + + def apply(self, img, interpolation=cv2.INTER_LINEAR, **params): + h, w, _ = img.shape + img = A.augmentations.geometric.functional.resize( + img, + height=min(h, self.height), + width=min(w, self.width), + interpolation=interpolation + ) + h, w, _ = img.shape + pad_top = (self.height - h) // 2 + pad_bottom = (self.height - h) - pad_top + pad_left = (self.width - w) // 2 + pad_right = (self.width - w) - pad_left + img = A.augmentations.pad_with_params( + img, + pad_top, + pad_bottom, + pad_left, + pad_right, + border_mode=cv2.BORDER_CONSTANT, + value=self.value, + ) + return img + + +def normalized_grid_distortion( + img, + num_steps=10, + xsteps=(), + ysteps=(), + *args, + **kwargs +): + height, width = img.shape[:2] + + # compensate for smaller last steps in source image. + x_step = width // num_steps + last_x_step = min(width, ((num_steps + 1) * x_step)) - (num_steps * x_step) + xsteps[-1] *= last_x_step / x_step + + y_step = height // num_steps + last_y_step = min(height, ((num_steps + 1) * y_step)) - (num_steps * y_step) + ysteps[-1] *= last_y_step / y_step + + # now normalize such that distortion never leaves image bounds. + tx = width / math.floor(width / num_steps) + ty = height / math.floor(height / num_steps) + xsteps = np.array(xsteps) * (tx / np.sum(xsteps)) + ysteps = np.array(ysteps) * (ty / np.sum(ysteps)) + + # do actual distortion. + return A.augmentations.functional.grid_distortion(img, num_steps, xsteps, ysteps, *args, **kwargs) + + +class NormalizedGridDistortion(A.augmentations.transforms.GridDistortion): + def apply(self, img, stepsx=(), stepsy=(), interpolation=cv2.INTER_LINEAR, **params): + return normalized_grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, + self.value) + + def apply_to_mask(self, img, stepsx=(), stepsy=(), **params): + return normalized_grid_distortion( + img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + diff --git a/molscribe/chemistry.py b/molscribe/chemistry.py new file mode 100644 index 0000000000000000000000000000000000000000..b77f5065061b02754b3bf2a6df95f4d62f8d512b --- /dev/null +++ b/molscribe/chemistry.py @@ -0,0 +1,641 @@ +import copy +import traceback +import numpy as np +import multiprocessing +import itertools + +import rdkit +import rdkit.Chem as Chem + +rdkit.RDLogger.DisableLog('rdApp.*') + +from SmilesPE.pretokenizer import atomwise_tokenizer + +from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX + + +def is_valid_mol(s, format_='atomtok'): + if format_ == 'atomtok': + mol = Chem.MolFromSmiles(s) + elif format_ == 'inchi': + if not s.startswith('InChI=1S'): + s = f"InChI=1S/{s}" + mol = Chem.MolFromInchi(s) + else: + raise NotImplemented + return mol is not None + + +def _convert_smiles_to_inchi(smiles): + try: + mol = Chem.MolFromSmiles(smiles) + inchi = Chem.MolToInchi(mol) + except: + inchi = None + return inchi + + +def convert_smiles_to_inchi(smiles_list, num_workers=16): + with multiprocessing.Pool(num_workers) as p: + inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128) + n_success = sum([x is not None for x in inchi_list]) + r_success = n_success / len(inchi_list) + inchi_list = [x if x else 'InChI=1S/H2O/h1H2' for x in inchi_list] + return inchi_list, r_success + + +def merge_inchi(inchi1, inchi2): + replaced = 0 + inchi1 = copy.deepcopy(inchi1) + for i in range(len(inchi1)): + if inchi1[i] == 'InChI=1S/H2O/h1H2': + inchi1[i] = inchi2[i] + replaced += 1 + return inchi1, replaced + + +def _get_num_atoms(smiles): + try: + return Chem.MolFromSmiles(smiles).GetNumAtoms() + except: + return 0 + + +def get_num_atoms(smiles, num_workers=16): + if type(smiles) is str: + return _get_num_atoms(smiles) + with multiprocessing.Pool(num_workers) as p: + num_atoms = p.map(_get_num_atoms, smiles) + return num_atoms + + +def normalize_nodes(nodes, flip_y=True): + x, y = nodes[:, 0], nodes[:, 1] + minx, maxx = min(x), max(x) + miny, maxy = min(y), max(y) + x = (x - minx) / max(maxx - minx, 1e-6) + if flip_y: + y = (maxy - y) / max(maxy - miny, 1e-6) + else: + y = (y - miny) / max(maxy - miny, 1e-6) + return np.stack([x, y], axis=1) + + +def _verify_chirality(mol, coords, symbols, edges, debug=False): + try: + n = mol.GetNumAtoms() + # Make a temp mol to find chiral centers + mol_tmp = mol.GetMol() + Chem.SanitizeMol(mol_tmp) + + chiral_centers = Chem.FindMolChiralCenters( + mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False) + chiral_center_ids = [idx for idx, _ in chiral_centers] # List[Tuple[int, any]] -> List[int] + + # correction to clear pre-condition violation (for some corner cases) + for bond in mol.GetBonds(): + if bond.GetBondType() == Chem.BondType.SINGLE: + bond.SetBondDir(Chem.BondDir.NONE) + + # Create conformer from 2D coordinate + conf = Chem.Conformer(n) + conf.Set3D(True) + for i, (x, y) in enumerate(coords): + conf.SetAtomPosition(i, (x, 1 - y, 0)) + mol.AddConformer(conf) + Chem.SanitizeMol(mol) + Chem.AssignStereochemistryFrom3D(mol) + # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z + # So we do this first, remove the conformer and add back the 2D conformer for chiral correction + + mol.RemoveAllConformers() + conf = Chem.Conformer(n) + conf.Set3D(False) + for i, (x, y) in enumerate(coords): + conf.SetAtomPosition(i, (x, 1 - y, 0)) + mol.AddConformer(conf) + + # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE. + Chem.SanitizeMol(mol) + Chem.AssignChiralTypesFromBondDirs(mol) + Chem.AssignStereochemistry(mol, force=True) + + # Second loop to reset any wedge/dash bond to be starting from the chiral center) + for i in chiral_center_ids: + for j in range(n): + if edges[i][j] == 5: + # assert edges[j][i] == 6 + mol.RemoveBond(i, j) + mol.AddBond(i, j, Chem.BondType.SINGLE) + mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE) + elif edges[i][j] == 6: + # assert edges[j][i] == 5 + mol.RemoveBond(i, j) + mol.AddBond(i, j, Chem.BondType.SINGLE) + mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH) + Chem.AssignChiralTypesFromBondDirs(mol) + Chem.AssignStereochemistry(mol, force=True) + + # reset chiral tags for non-carbon atom + for atom in mol.GetAtoms(): + if atom.GetSymbol() != "C": + atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED) + mol = mol.GetMol() + + except Exception as e: + if debug: + raise e + pass + return mol + + +def _parse_tokens(tokens: list): + """ + Parse tokens of condensed formula into list of pairs `(elt, num)` + where `num` is the multiplicity of the atom (or nested condensed formula) `elt` + Used by `_parse_formula`, which does the same thing but takes a formula in string form as input + """ + elements = [] + i = 0 + j = 0 + while i < len(tokens): + if tokens[i] == '(': + while j < len(tokens) and tokens[j] != ')': + j += 1 + elt = _parse_tokens(tokens[i + 1:j]) + else: + elt = tokens[i] + j += 1 + if j < len(tokens) and tokens[j].isnumeric(): + num = int(tokens[j]) + j += 1 + else: + num = 1 + elements.append((elt, num)) + i = j + return elements + + +def _parse_formula(formula: str): + """ + Parse condensed formula into list of pairs `(elt, num)` + where `num` is the subscript to the atom (or nested condensed formula) `elt` + Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)] + """ + tokens = FORMULA_REGEX.findall(formula) + # if ''.join(tokens) != formula: + # tokens = FORMULA_REGEX_BACKUP.findall(formula) + return _parse_tokens(tokens) + + +def _expand_carbon(elements: list): + """ + Given list of pairs `(elt, num)`, output single list of all atoms in order, + expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary + Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']) + """ + expanded = [] + i = 0 + while i < len(elements): + elt, num = elements[i] + # expand carbon sequence + if elt == 'C' and num > 1 and i + 1 < len(elements): + next_elt, next_num = elements[i + 1] + quotient, remainder = next_num // num, next_num % num + for _ in range(num): + expanded.append('C') + for _ in range(quotient): + expanded.append(next_elt) + for _ in range(remainder): + expanded.append(next_elt) + i += 2 + # recurse if `elt` itself is a list (nested formula) + elif isinstance(elt, list): + new_elt = _expand_carbon(elt) + for _ in range(num): + expanded.append(new_elt) + i += 1 + # simplest case: simply append `elt` `num` times + else: + for _ in range(num): + expanded.append(elt) + i += 1 + return expanded + + +def _expand_abbreviation(abbrev): + """ + Expand abbreviation into its SMILES; also converts [Rn] to [n*] + Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula + """ + if abbrev in ABBREVIATIONS: + return ABBREVIATIONS[abbrev].smiles + if abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()): + if abbrev[1:].isdigit(): + return f'[{abbrev[1:]}*]' + return '*' + return f'[{abbrev}]' + + +def _get_bond_symb(bond_num): + """ + Get SMILES symbol for a bond given bond order + Used in `_condensed_formula_list_to_smiles` while writing the SMILES string + """ + if bond_num == 0: + return '.' + if bond_num == 1: + return '' + if bond_num == 2: + return '=' + if bond_num == 3: + return '#' + return '' + + +def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None): + """ + Converts condensed formula (in the form of a list of symbols) to smiles + Input: + `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2 + `start_bond`: # bonds attached to beginning of formula + `end_bond`: # bonds attached to end of formula (deduce automatically if None) + `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: deduce automatically) + Returns: + `smiles`: smiles corresponding to input condensed formula + `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified + `num_trials`: number of trials + `success` (bool): whether conversion was successful + """ + # `direction` not specified: try left to right; if fails, try right to left + if direction is None: + num_trials = 1 + for dir_choice in [1, -1]: + smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice) + num_trials += trials + if success: + return smiles, bonds_left, num_trials, success + return None, None, num_trials, False + assert direction == 1 or direction == -1 + + def dfs(smiles, bonds_left, cur_idx, add_idx): + """ + `smiles`: SMILES string so far + `cur_idx`: index (in list `formula`) of current atom (i.e. atom to which subsequent atoms are being attached) + `cur_flat_idx`: index of current atom in list of atom tokens of SMILES so far + `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to + `add_idx`: index (in list `formula`) of atom to be attached to current atom + `add_flat_idx`: index of atom to be added in list of atom tokens of SMILES so far + Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2) + """ + num_trials = 1 + # end of formula: return result + if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1): + if end_bond is not None and end_bond != bonds_left: + return smiles, bonds_left, num_trials, False + return smiles, bonds_left, num_trials, True + + # no more bonds but there are atoms remaining: conversion failed + if bonds_left <= 0: + return smiles, bonds_left, num_trials, False + to_add = formula_list[add_idx] # atom to be added to current atom + + if isinstance(to_add, list): # "atom" added is a list (i.e. nested condensed formula): assume valence of 1 + if bonds_left > 1: + # "atom" added does not use up remaining bonds of current atom + # get smiles of "atom" (which is itself a condensed formula) + add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction) + if val > 0: + add_str = _get_bond_symb(val + 1) + add_str + num_trials += trials + if not success: + return smiles, bonds_left, num_trials, False + # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom + result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction) + else: + # "atom" added uses up remaining bonds of current atom + # get smiles of "atom" and bonds left on it + add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction) + num_trials += trials + if not success: + return smiles, bonds_left, num_trials, False + # append smiles of "atom" (without parentheses) to smiles; it becomes new current atom + result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction) + smiles, bonds_left, trials, success = result + num_trials += trials + return smiles, bonds_left, num_trials, success + + # atom added is a single symbol (as opposed to nested condensed formula) + for val in VALENCES.get(to_add, [1]): # try all possible valences of atom added + add_str = _expand_abbreviation(to_add) # expand to smiles if symbol is abbreviation + if bonds_left > val: # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom + if cur_idx >= 0: + add_str = _get_bond_symb(val) + add_str + result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction) + else: # atom added uses up remaining bonds of current atom; it becomes new current atom + if cur_idx >= 0: + add_str = _get_bond_symb(bonds_left) + add_str + result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction) + trials, success = result[2:] + num_trials += trials + if success: + return result[0], result[1], num_trials, success + if num_trials > 10000: + break + return smiles, bonds_left, num_trials, False + + cur_idx = -1 if direction == 1 else len(formula_list) + add_idx = 0 if direction == 1 else len(formula_list) - 1 + return dfs('', start_bond, cur_idx, add_idx) + + +def get_smiles_from_symbol(symbol, mol, atom, bonds): + """ + Convert symbol (abbrev. or condensed formula) to smiles + If condensed formula, determine parsing direction and num. bonds on each side using coordinates + """ + if symbol in ABBREVIATIONS: + return ABBREVIATIONS[symbol].smiles + if len(symbol) > 20: + return None + + total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds])) + formula_list = _expand_carbon(_parse_formula(symbol)) + smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None) + if success: + return smiles + return None + + +def _replace_functional_group(smiles): + smiles = smiles.replace('<unk>', 'C') + for i, r in enumerate(RGROUP_SYMBOLS): + symbol = f'[{r}]' + if symbol in smiles: + if r[0] == 'R' and r[1:].isdigit(): + smiles = smiles.replace(symbol, f'[{int(r[1:])}*]') + else: + smiles = smiles.replace(symbol, '*') + # For unknown tokens (i.e. rdkit cannot parse), replace them with [{isotope}*], where isotope is an identifier. + tokens = atomwise_tokenizer(smiles) + new_tokens = [] + mappings = {} # isotope : symbol + isotope = 50 + for token in tokens: + if token[0] == '[': + if token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None: + while f'[{isotope}*]' in smiles or f'[{isotope}*]' in new_tokens: + isotope += 1 + placeholder = f'[{isotope}*]' + mappings[isotope] = token[1:-1] + new_tokens.append(placeholder) + continue + new_tokens.append(token) + smiles = ''.join(new_tokens) + return smiles, mappings + + +def convert_smiles_to_mol(smiles): + if smiles is None or smiles == '': + return None + try: + mol = Chem.MolFromSmiles(smiles) + except: + return None + return mol + + +BOND_TYPES = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE} + + +def _expand_functional_group(mol, mappings, debug=False): + def _need_expand(mol, mappings): + return any([len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()]) or len(mappings) > 0 + + if _need_expand(mol, mappings): + mol_w = Chem.RWMol(mol) + num_atoms = mol_w.GetNumAtoms() + for i, atom in enumerate(mol_w.GetAtoms()): # reset radical electrons + atom.SetNumRadicalElectrons(0) + + atoms_to_remove = [] + for i in range(num_atoms): + atom = mol_w.GetAtomWithIdx(i) + if atom.GetSymbol() == '*': + symbol = Chem.GetAtomAlias(atom) + isotope = atom.GetIsotope() + if isotope > 0 and isotope in mappings: + symbol = mappings[isotope] + if not (isinstance(symbol, str) and len(symbol) > 0): + continue + # rgroups do not need to be expanded + if symbol in RGROUP_SYMBOLS: + continue + + bonds = atom.GetBonds() + sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds) + + # create mol object for abbreviation/condensed formula from its SMILES + mol_r = convert_smiles_to_mol(sub_smiles) + + if mol_r is None: + # atom.SetAtomicNum(6) + atom.SetIsotope(0) + continue + + # remove bonds connected to abbreviation/condensed formula + adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds] + for adjacent_idx in adjacent_indices: + mol_w.RemoveBond(i, adjacent_idx) + + adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices] + for adjacent_atom, bond in zip(adjacent_atoms, bonds): + adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble())) + + # get indices of atoms of main body that connect to substituent + bonding_atoms_w = adjacent_indices + # assume indices are concated after combine mol_w and mol_r + bonding_atoms_r = [mol_w.GetNumAtoms()] + for atm in mol_r.GetAtoms(): + if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0: + bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx()) + + # combine main body and substituent into a single molecule object + combo = Chem.CombineMols(mol_w, mol_r) + + # connect substituent to main body with bonds + mol_w = Chem.RWMol(combo) + # if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body + for atm in bonding_atoms_w: + bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons() + mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order]) + + # reset radical electrons + for atm in bonding_atoms_w: + mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) + for atm in bonding_atoms_r: + mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) + atoms_to_remove.append(i) + + # Remove atom in the end, otherwise the id will change + # Reverse the order and remove atoms with larger id first + atoms_to_remove.sort(reverse=True) + for i in atoms_to_remove: + mol_w.RemoveAtom(i) + smiles = Chem.MolToSmiles(mol_w) + mol = mol_w.GetMol() + else: + smiles = Chem.MolToSmiles(mol) + return smiles, mol + + +def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False): + mol = Chem.RWMol() + n = len(symbols) + ids = [] + for i in range(n): + symbol = symbols[i] + if symbol[0] == '[': + symbol = symbol[1:-1] + if symbol in RGROUP_SYMBOLS: + atom = Chem.Atom("*") + if symbol[0] == 'R' and symbol[1:].isdigit(): + atom.SetIsotope(int(symbol[1:])) + Chem.SetAtomAlias(atom, symbol) + elif symbol in ABBREVIATIONS: + atom = Chem.Atom("*") + Chem.SetAtomAlias(atom, symbol) + else: + try: # try to get SMILES of atom + atom = Chem.AtomFromSmiles(symbols[i]) + atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED) + except: # otherwise, abbreviation or condensed formula + atom = Chem.Atom("*") + Chem.SetAtomAlias(atom, symbol) + + if atom.GetSymbol() == '*': + atom.SetProp('molFileAlias', symbol) + + idx = mol.AddAtom(atom) + assert idx == i + ids.append(idx) + + for i in range(n): + for j in range(i + 1, n): + if edges[i][j] == 1: + mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE) + elif edges[i][j] == 2: + mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE) + elif edges[i][j] == 3: + mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE) + elif edges[i][j] == 4: + mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC) + elif edges[i][j] == 5: + mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE) + mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE) + elif edges[i][j] == 6: + mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE) + mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH) + + pred_smiles = '<invalid>' + + try: + # TODO: move to an util function + if image is not None: + height, width, _ = image.shape + ratio = width / height + coords = [[x * ratio * 10, y * 10] for x, y in coords] + mol = _verify_chirality(mol, coords, symbols, edges, debug) + # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates. + # TODO: make sure molblock has the abbreviation information + pred_molblock = Chem.MolToMolBlock(mol) + pred_smiles, mol = _expand_functional_group(mol, {}, debug) + success = True + except Exception as e: + if debug: + print(traceback.format_exc()) + pred_molblock = '' + success = False + + if debug: + return pred_smiles, pred_molblock, mol, success + return pred_smiles, pred_molblock, success + + +def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16): + if images is None: + args_zip = zip(coords, symbols, edges) + else: + args_zip = zip(coords, symbols, edges, images) + + if num_workers <= 1: + results = itertools.starmap(_convert_graph_to_smiles, args_zip) + results = list(results) + else: + with multiprocessing.Pool(num_workers) as p: + results = p.starmap(_convert_graph_to_smiles, args_zip, chunksize=128) + + smiles_list, molblock_list, success = zip(*results) + r_success = np.mean(success) + return smiles_list, molblock_list, r_success + + +def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False): + if type(smiles) is not str or smiles == '': + return '', False + mol = None + pred_molblock = '' + try: + pred_smiles = smiles + pred_smiles, mappings = _replace_functional_group(pred_smiles) + if coords is not None and symbols is not None and edges is not None: + pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '') + mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False)) + mol = _verify_chirality(mol, coords, symbols, edges, debug) + else: + mol = Chem.MolFromSmiles(pred_smiles, sanitize=False) + # pred_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True) + if molblock: + pred_molblock = Chem.MolToMolBlock(mol) + pred_smiles, mol = _expand_functional_group(mol, mappings) + success = True + except Exception as e: + if debug: + print(traceback.format_exc()) + pred_smiles = smiles + pred_molblock = '' + success = False + if debug: + return pred_smiles, pred_molblock, mol, success + return pred_smiles, pred_molblock, success + + +def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16): + with multiprocessing.Pool(num_workers) as p: + if coords is not None and symbols is not None and edges is not None: + results = p.starmap(_postprocess_smiles, zip(smiles, coords, symbols, edges), chunksize=128) + else: + results = p.map(_postprocess_smiles, smiles, chunksize=128) + smiles_list, molblock_list, success = zip(*results) + r_success = np.mean(success) + return smiles_list, molblock_list, r_success + + +def _keep_main_molecule(smiles, debug=False): + try: + mol = Chem.MolFromSmiles(smiles) + frags = Chem.GetMolFrags(mol, asMols=True) + if len(frags) > 1: + num_atoms = [m.GetNumAtoms() for m in frags] + main_mol = frags[np.argmax(num_atoms)] + smiles = Chem.MolToSmiles(main_mol) + except Exception as e: + if debug: + print(traceback.format_exc()) + return smiles + + +def keep_main_molecule(smiles, num_workers=16): + with multiprocessing.Pool(num_workers) as p: + results = p.map(_keep_main_molecule, smiles, chunksize=128) + return results diff --git a/molscribe/constants.py b/molscribe/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..feb3ea3f91794ae8d839b86a480bae101e530193 --- /dev/null +++ b/molscribe/constants.py @@ -0,0 +1,126 @@ +from typing import List +import re + +ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'} + +RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', + 'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar'] + +PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"] + + +class Substitution(object): + '''Define common substitutions for chemical shorthand''' + def __init__(self, abbrvs, smarts, smiles, probability): + assert type(abbrvs) is list + self.abbrvs = abbrvs + self.smarts = smarts + self.smiles = smiles + self.probability = probability + + +SUBSTITUTIONS: List[Substitution] = [ + Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5), + Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5), + Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5), + + Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7), + Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7), + Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1), + + Substitution(['OBz'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7), # Benzoyl + Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2), # Benzoyl + + Substitution(['OBn'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7), # Benzyl + Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2), # Benzyl + + Substitution(['NHBoc'], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6), + Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6), + Substitution(['Boc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2), + + Substitution(['ClC6H4', '2-ClC6H4'], 'c2ccccc2Cl', "c2ccccc2Cl", 0.4), + Substitution(['[CF3]2C6H3', '3,5-[CF3]2C6H3'], 'C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F)', "C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F)", 0.4), + + + + Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2), + Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4), + Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3), + Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3', + "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6), + Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5), + Substitution(['OMs'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7), + Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2), + Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5), + Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2), + Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1), + Substitution(['SEM'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]CSi(C)(C)C", 0.2), + Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2), + Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5), + Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2), + Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7), + Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2), + Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3), + Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5), + Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6), # Tos + + # Alkyl chains + Substitution(['OMe', 'MeO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3), + Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3), + Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[NH]C", 0.3), + Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1), + Substitution(['OEt', 'EtO'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5), + Substitution(['Et', 'C2H5'], '[CH2;D2][CH3]', "[CH2]C", 0.3), + Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3), + Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3), + + # Branched + Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2), + Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2), + Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2), + Substitution(['OtBu'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6), + Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3), + + # Other shorthands (MIGHT NOT WANT ALL OF THESE) + Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5), + Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5), + Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5), + Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5), + Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5), # COOH + Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5), + Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4), + Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4), + +] + +ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs} + +VALENCES = { + "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1], + "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1], "K": [1], "Ca": [2], + "Br": [1], "I": [1] +} + +ELEMENTS = [ + "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", + "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", + "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", + "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", + "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", + "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", + "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", + "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", + "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", + "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", + "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", + "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og" +] + +COLORS = { + u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0', + u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75' +} + +# tokens of condensed formula +FORMULA_REGEX = re.compile( + '(' + '|'.join(list(ABBREVIATIONS.keys())) + '|R[0-9]*|[A-Z][a-z]+|[A-Z]|[0-9]+|\(|\))') diff --git a/molscribe/dataset.py b/molscribe/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b34c0dd36dc7b8561dc5f44ba37edfbe6e3fe4 --- /dev/null +++ b/molscribe/dataset.py @@ -0,0 +1,594 @@ +import os +import cv2 +import time +import random +import re +import string +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence +import albumentations as A +from albumentations.pytorch import ToTensorV2 + +from .indigo import Indigo +from .indigo.renderer import IndigoRenderer + +from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise +from .utils import FORMAT_INFO +from .tokenizer import PAD_ID +from .chemistry import get_num_atoms, normalize_nodes +from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS + +cv2.setNumThreads(1) + +INDIGO_HYGROGEN_PROB = 0.2 +INDIGO_FUNCTIONAL_GROUP_PROB = 0.8 +INDIGO_CONDENSED_PROB = 0.5 +INDIGO_RGROUP_PROB = 0.5 +INDIGO_COMMENT_PROB = 0.3 +INDIGO_DEARMOTIZE_PROB = 0.8 +INDIGO_COLOR_PROB = 0.2 + + +def get_transforms(input_size, augment=True, rotate=True, debug=False): + trans_list = [] + if augment and rotate: + trans_list.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255))) + trans_list.append(CropWhite(pad=5)) + if augment: + trans_list += [ + # NormalizedGridDistortion(num_steps=10, distort_limit=0.3), + A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5), + PadWhite(pad_ratio=0.4, p=0.2), + A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3), + A.Blur(), + A.GaussNoise(), + SaltAndPepperNoise(num_dots=20, p=0.5) + ] + trans_list.append(A.Resize(input_size, input_size)) + if not debug: + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + trans_list += [ + A.ToGray(p=1), + A.Normalize(mean=mean, std=std), + ToTensorV2(), + ] + return A.Compose(trans_list, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False)) + + +def add_functional_group(indigo, mol, debug=False): + if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB: + return mol + # Delete functional group and add a pseudo atom with its abbrv + substitutions = [sub for sub in SUBSTITUTIONS] + random.shuffle(substitutions) + for sub in substitutions: + query = indigo.loadSmarts(sub.smarts) + matcher = indigo.substructureMatcher(mol) + matched_atoms_ids = set() + for match in matcher.iterateMatches(query): + if random.random() < sub.probability or debug: + atoms = [] + atoms_ids = set() + for item in query.iterateAtoms(): + atom = match.mapAtom(item) + atoms.append(atom) + atoms_ids.add(atom.index()) + if len(matched_atoms_ids.intersection(atoms_ids)) > 0: + continue + abbrv = random.choice(sub.abbrvs) + superatom = mol.addAtom(abbrv) + for atom in atoms: + for nei in atom.iterateNeighbors(): + if nei.index() not in atoms_ids: + if nei.symbol() == 'H': + # indigo won't match explicit hydrogen, so remove them explicitly + atoms_ids.add(nei.index()) + else: + superatom.addBond(nei, nei.bond().bondOrder()) + for id in atoms_ids: + mol.getAtom(id).remove() + matched_atoms_ids = matched_atoms_ids.union(atoms_ids) + return mol + + +def add_explicit_hydrogen(indigo, mol): + atoms = [] + for atom in mol.iterateAtoms(): + try: + hs = atom.countImplicitHydrogens() + if hs > 0: + atoms.append((atom, hs)) + except: + continue + if len(atoms) > 0 and random.random() < INDIGO_HYGROGEN_PROB: + atom, hs = random.choice(atoms) + for i in range(hs): + h = mol.addAtom('H') + h.addBond(atom, 1) + return mol + + +def add_rgroup(indigo, mol, smiles): + atoms = [] + for atom in mol.iterateAtoms(): + try: + hs = atom.countImplicitHydrogens() + if hs > 0: + atoms.append(atom) + except: + continue + if len(atoms) > 0 and '*' not in smiles: + if random.random() < INDIGO_RGROUP_PROB: + atom_idx = random.choice(range(len(atoms))) + atom = atoms[atom_idx] + atoms.pop(atom_idx) + symbol = random.choice(RGROUP_SYMBOLS) + r = mol.addAtom(symbol) + r.addBond(atom, 1) + return mol + + +def get_rand_symb(): + symb = random.choice(ELEMENTS) + if random.random() < 0.1: + symb += random.choice(string.ascii_lowercase) + if random.random() < 0.1: + symb += random.choice(string.ascii_uppercase) + if random.random() < 0.1: + symb = f'({gen_rand_condensed()})' + return symb + + +def get_rand_num(): + if random.random() < 0.9: + if random.random() < 0.8: + return '' + else: + return str(random.randint(2, 9)) + else: + return '1' + str(random.randint(2, 9)) + + +def gen_rand_condensed(): + tokens = [] + for i in range(5): + if i >= 1 and random.random() < 0.8: + break + tokens.append(get_rand_symb()) + tokens.append(get_rand_num()) + return ''.join(tokens) + + +def add_rand_condensed(indigo, mol): + atoms = [] + for atom in mol.iterateAtoms(): + try: + hs = atom.countImplicitHydrogens() + if hs > 0: + atoms.append(atom) + except: + continue + if len(atoms) > 0 and random.random() < INDIGO_CONDENSED_PROB: + atom = random.choice(atoms) + symbol = gen_rand_condensed() + r = mol.addAtom(symbol) + r.addBond(atom, 1) + return mol + + +def generate_output_smiles(indigo, mol): + # TODO: if using mol.canonicalSmiles(), explicit H will be removed + smiles = mol.smiles() + mol = indigo.loadMolecule(smiles) + if '*' in smiles: + part_a, part_b = smiles.split(' ', maxsplit=1) + part_b = re.search(r'\$.*\$', part_b).group(0)[1:-1] + symbols = [t for t in part_b.split(';') if len(t) > 0] + output = '' + cnt = 0 + for i, c in enumerate(part_a): + if c != '*': + output += c + else: + output += f'[{symbols[cnt]}]' + cnt += 1 + return mol, output + else: + if ' ' in smiles: + # special cases with extension + smiles = smiles.split(' ')[0] + return mol, smiles + + +def add_comment(indigo): + if random.random() < INDIGO_COMMENT_PROB: + indigo.setOption('render-comment', str(random.randint(1, 20)) + random.choice(string.ascii_letters)) + indigo.setOption('render-comment-font-size', random.randint(40, 60)) + indigo.setOption('render-comment-alignment', random.choice([0, 0.5, 1])) + indigo.setOption('render-comment-position', random.choice(['top', 'bottom'])) + indigo.setOption('render-comment-offset', random.randint(2, 30)) + + +def add_color(indigo, mol): + if random.random() < INDIGO_COLOR_PROB: + indigo.setOption('render-coloring', True) + if random.random() < INDIGO_COLOR_PROB: + indigo.setOption('render-base-color', random.choice(list(COLORS.values()))) + if random.random() < INDIGO_COLOR_PROB: + if random.random() < 0.5: + indigo.setOption('render-highlight-color-enabled', True) + indigo.setOption('render-highlight-color', random.choice(list(COLORS.values()))) + if random.random() < 0.5: + indigo.setOption('render-highlight-thickness-enabled', True) + for atom in mol.iterateAtoms(): + if random.random() < 0.1: + atom.highlight() + return mol + + +def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False): + mol.layout() + coords, symbols = [], [] + index_map = {} + atoms = [atom for atom in mol.iterateAtoms()] + if shuffle_nodes: + random.shuffle(atoms) + for i, atom in enumerate(atoms): + if pseudo_coords: + x, y, z = atom.xyz() + else: + x, y = atom.coords() + coords.append([x, y]) + symbols.append(atom.symbol()) + index_map[atom.index()] = i + if pseudo_coords: + coords = normalize_nodes(np.array(coords)) + h, w, _ = image.shape + coords[:, 0] = coords[:, 0] * w + coords[:, 1] = coords[:, 1] * h + n = len(symbols) + edges = np.zeros((n, n), dtype=int) + for bond in mol.iterateBonds(): + s = index_map[bond.source().index()] + t = index_map[bond.destination().index()] + # 1/2/3/4 : single/double/triple/aromatic + edges[s, t] = bond.bondOrder() + edges[t, s] = bond.bondOrder() + if bond.bondStereo() in [5, 6]: + edges[s, t] = bond.bondStereo() + edges[t, s] = 11 - bond.bondStereo() + graph = { + 'coords': coords, + 'symbols': symbols, + 'edges': edges, + 'num_atoms': len(symbols) + } + return graph + + +def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False, + include_condensed=True, debug=False): + indigo = Indigo() + renderer = IndigoRenderer(indigo) + indigo.setOption('render-output-format', 'png') + indigo.setOption('render-background-color', '1,1,1') + indigo.setOption('render-stereo-style', 'none') + indigo.setOption('render-label-mode', 'hetero') + indigo.setOption('render-font-family', 'Arial') + if not default_option: + thickness = random.uniform(0.5, 2) # limit the sum of the following two parameters to be smaller than 4 + indigo.setOption('render-relative-thickness', thickness) + indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness)) + if random.random() < 0.5: + indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica'])) + indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero'])) + indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False])) + if random.random() < 0.1: + indigo.setOption('render-stereo-style', 'old') + if random.random() < 0.2: + indigo.setOption('render-atom-ids-visible', True) + + try: + mol = indigo.loadMolecule(smiles) + if mol_augment: + if random.random() < INDIGO_DEARMOTIZE_PROB: + mol.dearomatize() + else: + mol.aromatize() + smiles = mol.canonicalSmiles() + add_comment(indigo) + mol = add_explicit_hydrogen(indigo, mol) + mol = add_rgroup(indigo, mol, smiles) + if include_condensed: + mol = add_rand_condensed(indigo, mol) + mol = add_functional_group(indigo, mol, debug) + mol = add_color(indigo, mol) + mol, smiles = generate_output_smiles(indigo, mol) + + buf = renderer.renderToBuffer(mol) + img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1) # decode buffer to image + # img = np.repeat(np.expand_dims(img, 2), 3, axis=2) # expand to RGB + graph = get_graph(mol, img, shuffle_nodes, pseudo_coords) + success = True + except Exception: + if debug: + raise Exception + img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32) + graph = {} + success = False + return img, smiles, graph, success + + +class TrainDataset(Dataset): + def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False): + super().__init__() + self.df = df + self.args = args + self.tokenizer = tokenizer + if 'file_path' in df.columns: + self.file_paths = df['file_path'].values + if not self.file_paths[0].startswith(args.data_path): + self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']] + self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None + self.formats = args.formats + self.labelled = (split == 'train') + if self.labelled: + self.labels = {} + for format_ in self.formats: + if format_ in ['atomtok', 'inchi']: + field = FORMAT_INFO[format_]['name'] + if field in df.columns: + self.labels[format_] = df[field].values + self.transform = get_transforms(args.input_size, + augment=(self.labelled and args.augment)) + # self.fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)]) + self.dynamic_indigo = (dynamic_indigo and split == 'train') + if self.labelled and not dynamic_indigo and args.coords_file is not None: + if args.coords_file == 'aux_file': + self.coords_df = df + self.pseudo_coords = True + else: + self.coords_df = pd.read_csv(args.coords_file) + self.pseudo_coords = False + else: + self.coords_df = None + self.pseudo_coords = args.pseudo_coords + + def __len__(self): + return len(self.df) + + def image_transform(self, image, coords=[], renormalize=False): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # .astype(np.float32) + augmented = self.transform(image=image, keypoints=coords) + image = augmented['image'] + if len(coords) > 0: + coords = np.array(augmented['keypoints']) + if renormalize: + coords = normalize_nodes(coords, flip_y=False) + else: + _, height, width = image.shape + coords[:, 0] = coords[:, 0] / width + coords[:, 1] = coords[:, 1] / height + coords = np.array(coords).clip(0, 1) + return image, coords + return image + + def __getitem__(self, idx): + try: + return self.getitem(idx) + except Exception as e: + with open(os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log'), 'w') as f: + f.write(str(e)) + raise e + + def getitem(self, idx): + ref = {} + if self.dynamic_indigo: + begin = time.time() + image, smiles, graph, success = generate_indigo_image( + self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option, + shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords, + include_condensed=self.args.include_condensed) + # raw_image = image + end = time.time() + if idx < 30 and self.args.save_image: + path = os.path.join(self.args.save_path, 'images') + os.makedirs(path, exist_ok=True) + cv2.imwrite(os.path.join(path, f'{idx}.png'), image) + if not success: + return idx, None, {} + image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords) + graph['coords'] = coords + ref['time'] = end - begin + if 'atomtok' in self.formats: + max_len = FORMAT_INFO['atomtok']['max_len'] + label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False) + ref['atomtok'] = torch.LongTensor(label[:max_len]) + if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats: + ref['edges'] = torch.tensor(graph['edges']) + if 'atomtok_coords' in self.formats: + self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'], + mask_ratio=self.args.mask_ratio) + if 'chartok_coords' in self.formats: + self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'], + mask_ratio=self.args.mask_ratio) + return idx, image, ref + else: + file_path = self.file_paths[idx] + image = cv2.imread(file_path) + if image is None: + image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32) + print(file_path, 'not found!') + if self.coords_df is not None: + h, w, _ = image.shape + coords = np.array(eval(self.coords_df.loc[idx, 'node_coords'])) + if self.pseudo_coords: + coords = normalize_nodes(coords) + coords[:, 0] = coords[:, 0] * w + coords[:, 1] = coords[:, 1] * h + image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords) + else: + image = self.image_transform(image) + coords = None + if self.labelled: + smiles = self.smiles[idx] + if 'atomtok' in self.formats: + max_len = FORMAT_INFO['atomtok']['max_len'] + label = self.tokenizer['atomtok'].text_to_sequence(smiles, False) + ref['atomtok'] = torch.LongTensor(label[:max_len]) + if 'atomtok_coords' in self.formats: + if coords is not None: + self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0) + else: + self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1) + if 'chartok_coords' in self.formats: + if coords is not None: + self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0) + else: + self._process_chartok_coords(idx, ref, smiles, mask_ratio=1) + if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats): + smiles = self.smiles[idx] + if 'atomtok_coords' in self.formats: + self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1) + if 'chartok_coords' in self.formats: + self._process_chartok_coords(idx, ref, smiles, mask_ratio=1) + return idx, image, ref + + def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0): + max_len = FORMAT_INFO['atomtok_coords']['max_len'] + tokenizer = self.tokenizer['atomtok_coords'] + if smiles is None or type(smiles) is not str: + smiles = "" + label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio) + ref['atomtok_coords'] = torch.LongTensor(label[:max_len]) + indices = [i for i in indices if i < max_len] + ref['atom_indices'] = torch.LongTensor(indices) + if tokenizer.continuous_coords: + if coords is not None: + ref['coords'] = torch.tensor(coords) + else: + ref['coords'] = torch.ones(len(indices), 2) * -1. + if edges is not None: + ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)] + else: + if 'edges' in self.df.columns: + edge_list = eval(self.df.loc[idx, 'edges']) + n = len(indices) + edges = torch.zeros((n, n), dtype=torch.long) + for u, v, t in edge_list: + if u < n and v < n: + if t <= 4: + edges[u, v] = t + edges[v, u] = t + else: + edges[u, v] = t + edges[v, u] = 11 - t + ref['edges'] = edges + else: + ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100) + + def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0): + max_len = FORMAT_INFO['chartok_coords']['max_len'] + tokenizer = self.tokenizer['chartok_coords'] + if smiles is None or type(smiles) is not str: + smiles = "" + label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio) + ref['chartok_coords'] = torch.LongTensor(label[:max_len]) + indices = [i for i in indices if i < max_len] + ref['atom_indices'] = torch.LongTensor(indices) + if tokenizer.continuous_coords: + if coords is not None: + ref['coords'] = torch.tensor(coords) + else: + ref['coords'] = torch.ones(len(indices), 2) * -1. + if edges is not None: + ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)] + else: + if 'edges' in self.df.columns: + edge_list = eval(self.df.loc[idx, 'edges']) + n = len(indices) + edges = torch.zeros((n, n), dtype=torch.long) + for u, v, t in edge_list: + if u < n and v < n: + if t <= 4: + edges[u, v] = t + edges[v, u] = t + else: + edges[u, v] = t + edges[v, u] = 11 - t + ref['edges'] = edges + else: + ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100) + + +class AuxTrainDataset(Dataset): + + def __init__(self, args, train_df, aux_df, tokenizer): + super().__init__() + self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo) + self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False) + + def __len__(self): + return len(self.train_dataset) + len(self.aux_dataset) + + def __getitem__(self, idx): + if idx < len(self.train_dataset): + return self.train_dataset[idx] + else: + return self.aux_dataset[idx - len(self.train_dataset)] + + +def pad_images(imgs): + # B, C, H, W + max_shape = [0, 0] + for img in imgs: + for i in range(len(max_shape)): + max_shape[i] = max(max_shape[i], img.shape[-1 - i]) + stack = [] + for img in imgs: + pad = [] + for i in range(len(max_shape)): + pad = pad + [0, max_shape[i] - img.shape[-1 - i]] + stack.append(F.pad(img, pad, value=0)) + return torch.stack(stack) + + +def bms_collate(batch): + ids = [] + imgs = [] + batch = [ex for ex in batch if ex[1] is not None] + formats = list(batch[0][2].keys()) + seq_formats = [k for k in formats if + k in ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']] + refs = {key: [[], []] for key in seq_formats} + for ex in batch: + ids.append(ex[0]) + imgs.append(ex[1]) + ref = ex[2] + for key in seq_formats: + refs[key][0].append(ref[key]) + refs[key][1].append(torch.LongTensor([len(ref[key])])) + # Sequence + for key in seq_formats: + # this padding should work for atomtok_with_coords too, each of which has shape (length, 4) + refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=PAD_ID) + refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1) + # Time + # if 'time' in formats: + # refs['time'] = [ex[2]['time'] for ex in batch] + # Coords + if 'coords' in formats: + refs['coords'] = pad_sequence([ex[2]['coords'] for ex in batch], batch_first=True, padding_value=-1.) + # Edges + if 'edges' in formats: + edges_list = [ex[2]['edges'] for ex in batch] + max_len = max([len(edges) for edges in edges_list]) + refs['edges'] = torch.stack( + [F.pad(edges, (0, max_len - len(edges), 0, max_len - len(edges)), value=-100) for edges in edges_list], + dim=0) + return ids, pad_images(imgs), refs diff --git a/molscribe/evaluate.py b/molscribe/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..ff28829c50df4746f2398ae30acdeecd00f49508 --- /dev/null +++ b/molscribe/evaluate.py @@ -0,0 +1,79 @@ +import numpy as np +import multiprocessing + +import rdkit +import rdkit.Chem as Chem +rdkit.RDLogger.DisableLog('rdApp.*') +from SmilesPE.pretokenizer import atomwise_tokenizer + + +def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True): + if type(smiles) is not str or smiles == '': + return '', False + if ignore_cistrans: + smiles = smiles.replace('/', '').replace('\\', '') + if replace_rgroup: + tokens = atomwise_tokenizer(smiles) + for j, token in enumerate(tokens): + if token[0] == '[' and token[-1] == ']': + symbol = token[1:-1] + if symbol[0] == 'R' and symbol[1:].isdigit(): + tokens[j] = f'[{symbol[1:]}*]' + elif Chem.AtomFromSmiles(token) is None: + tokens[j] = '*' + smiles = ''.join(tokens) + try: + canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral)) + success = True + except: + canon_smiles = smiles + success = False + return canon_smiles, success + + +def convert_smiles_to_canonsmiles( + smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16): + with multiprocessing.Pool(num_workers) as p: + results = p.starmap(canonicalize_smiles, + [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list], + chunksize=128) + canon_smiles, success = zip(*results) + return list(canon_smiles), np.mean(success) + + +class SmilesEvaluator(object): + + def __init__(self, gold_smiles, num_workers=16): + self.gold_smiles = gold_smiles + self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(gold_smiles, num_workers=num_workers) + self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles, + ignore_chiral=True, num_workers=num_workers) + self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles, + ignore_cistrans=True, num_workers=num_workers) + self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles) + self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral) + self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans) + + def _replace_empty(self, smiles_list): + """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty.""" + return [smiles if smiles is not None and type(smiles) is str and smiles != "" else "<empty>" + for smiles in smiles_list] + + def evaluate(self, pred_smiles): + results = {} + results['gold_valid'] = self.gold_valid + # Canon SMILES + pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(pred_smiles) + results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean() + results['pred_valid'] = pred_valid + # Ignore chirality (Graph exact match) + pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_chiral=True) + results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean() + # Ignore double bond cis/trans + pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_cistrans=True) + results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean() + # Evaluate on molecules with chiral centers + chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g]) + results['chiral_ratio'] = len(chiral) / len(self.gold_smiles) + results['chiral'] = (chiral[:, 0] == chiral[:, 1]).mean() if len(chiral) > 0 else -1 + return results diff --git a/molscribe/indigo/__init__.py b/molscribe/indigo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbf42753036b29b9fc5a726d72db3cecb588c20 --- /dev/null +++ b/molscribe/indigo/__init__.py @@ -0,0 +1,4164 @@ +# +# +# Copyright (C) from 2009 to Present EPAM Systems. +# +# This file is part of Indigo toolkit. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +import sys +import warnings +from array import array +from ctypes import (CDLL, POINTER, RTLD_GLOBAL, c_byte, c_char_p, c_double, + c_float, c_int, c_ulonglong, pointer) + +DECODE_ENCODING = "utf-8" +ENCODE_ENCODING = "utf-8" + + +class IndigoException(Exception): + def __init__(self, value): + if sys.version_info > (3, 0) and not isinstance(value, str): + self.value = value.decode(DECODE_ENCODING) + else: + self.value = value + + def __str__(self): + return self.value + + +class IndigoObject(object): + """Docstring for class IndigoObject.""" + + def __init__(self, dispatcher, id, parent=None): + self.id = id + self.dispatcher = dispatcher + self.parent = parent + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.dispatcher._setSessionId() + self.dispatcher._lib.indigoClose(self.id) + + def __del__(self): + self.dispose() + + def dispose(self): + if self.id >= 0: + if getattr(Indigo, "_lib", None) is not None: + self.dispatcher._setSessionId() + Indigo._lib.indigoFree(self.id) + self.id = -1 + + def __iter__(self): + return self + + def _next(self): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult(Indigo._lib.indigoNext(self.id)) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def __next__(self): + obj = self._next() + if obj == None: + raise StopIteration + return obj + + def next(self): + return self.__next__() + + def oneBitsList(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoOneBitsList(self.id) + ) + + def mdlct(self): + buf = self.dispatcher.writeBuffer() + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoSaveMDLCT(self.id, buf.id) + ) + return buf.toBuffer() + + def xyz(self): + self.dispatcher._setSessionId() + xyz = Indigo._lib.indigoXYZ(self.id) + if xyz is None: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return [xyz[0], xyz[1], xyz[2]] + + def coords(self): + self.dispatcher._setSessionId() + xyz = Indigo._lib.indigoCoords(self.id) + if xyz is None: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return [xyz[0], xyz[1]] + + def alignAtoms(self, atom_ids, desired_xyz): + if len(atom_ids) * 3 != len(desired_xyz): + raise IndigoException( + "alignAtoms(): desired_xyz[] must be exactly 3 times bigger than atom_ids[]" + ) + atoms = (c_int * len(atom_ids))() + for i in range(len(atoms)): + atoms[i] = atom_ids[i] + xyz = (c_float * len(desired_xyz))() + for i in range(len(desired_xyz)): + xyz[i] = desired_xyz[i] + self.dispatcher._setSessionId() + return self.dispatcher._checkResultFloat( + self.dispatcher._lib.indigoAlignAtoms( + self.id, len(atoms), atoms, xyz + ) + ) + + def addStereocenter(self, type, v1, v2, v3, v4=-1): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddStereocenter(self.id, type, v1, v2, v3, v4) + ) + + def clone(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult(Indigo._lib.indigoClone(self.id)), + ) + + def check(self, props=""): + if props is None: + props = "" + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCheck(self.id, props.encode(ENCODE_ENCODING)) + ) + + def close(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoClose(self.id)) + + def hasNext(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult(Indigo._lib.indigoHasNext(self.id)) + ) + + def index(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoIndex(self.id)) + + def remove(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoRemove(self.id)) + + def saveMolfile(self, filename): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSaveMolfileToFile( + self.id, filename.encode(ENCODE_ENCODING) + ) + ) + + def molfile(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoMolfile(self.id) + ) + + def saveCml(self, filename): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSaveCmlToFile( + self.id, filename.encode(ENCODE_ENCODING) + ) + ) + + def cml(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCml(self.id) + ) + + def saveCdxml(self, filename): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSaveCdxmlToFile( + self.id, filename.encode(ENCODE_ENCODING) + ) + ) + + def cdxml(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCdxml(self.id) + ) + + def json(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoJson(self.id) + ) + + def saveMDLCT(self, output): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSaveMDLCT(self.id, output.id) + ) + + def addReactant(self, molecule): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddReactant(self.id, molecule.id) + ) + + def addProduct(self, molecule): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddProduct(self.id, molecule.id) + ) + + def addCatalyst(self, molecule): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddCatalyst(self.id, molecule.id) + ) + + def countReactants(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountReactants(self.id) + ) + + def countProducts(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountProducts(self.id) + ) + + def countCatalysts(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountCatalysts(self.id) + ) + + def countMolecules(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountMolecules(self.id) + ) + + def getMolecule(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetMolecule(self.id, index) + ), + ) + + def iterateReactants(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateReactants(self.id) + ), + ) + + def iterateProducts(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateProducts(self.id) + ), + ) + + def iterateCatalysts(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateCatalysts(self.id) + ), + ) + + def iterateMolecules(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateMolecules(self.id) + ), + ) + + def saveRxnfile(self, filename): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSaveRxnfileToFile( + self.id, filename.encode(ENCODE_ENCODING) + ) + ) + + def rxnfile(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoRxnfile(self.id) + ) + + def optimize(self, options=""): + if options is None: + options = "" + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoOptimize( + self.id, options.encode(ENCODE_ENCODING) + ) + ) + + def normalize(self, options=""): + if options is None: + options = "" + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoNormalize( + self.id, options.encode(ENCODE_ENCODING) + ) + ) + ) + + def standardize(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoStandardize(self.id) + ) + + def ionize(self, pH, pH_toll): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoIonize(self.id, pH, pH_toll) + ) + + def getAcidPkaValue(self, atom, level, min_level): + self.dispatcher._setSessionId() + result = self.dispatcher._checkResultPtr( + Indigo._lib.indigoGetAcidPkaValue( + self.id, atom.id, level, min_level + ) + ) + return result[0] + + def getBasicPkaValue(self, atom, level, min_level): + self.dispatcher._setSessionId() + result = self.dispatcher._checkResultPtr( + Indigo._lib.indigoGetBasicPkaValue( + self.id, atom.id, level, min_level + ) + ) + return result[0] + + def automap(self, mode=""): + if mode is None: + mode = "" + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAutomap(self.id, mode.encode(ENCODE_ENCODING)) + ) + + def atomMappingNumber(self, reaction_atom): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetAtomMappingNumber(self.id, reaction_atom.id) + ) + + def setAtomMappingNumber(self, reaction_atom, number): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetAtomMappingNumber( + self.id, reaction_atom.id, number + ) + ) + + def reactingCenter(self, reaction_bond): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoGetReactingCenter( + self.id, reaction_bond.id, pointer(value) + ) + ) + if res == 0: + return None + return value.value + + def setReactingCenter(self, reaction_bond, rc): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetReactingCenter(self.id, reaction_bond.id, rc) + ) + + def clearAAM(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearAAM(self.id) + ) + + def correctReactingCenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCorrectReactingCenters(self.id) + ) + + def iterateAtoms(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateAtoms(self.id) + ), + ) + + def iteratePseudoatoms(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIteratePseudoatoms(self.id) + ), + ) + + def iterateRSites(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateRSites(self.id) + ), + ) + + def iterateStereocenters(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateStereocenters(self.id) + ), + ) + + def iterateAlleneCenters(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateAlleneCenters(self.id) + ), + ) + + def iterateRGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateRGroups(self.id) + ), + ) + + def countRGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountRGroups(self.id) + ) + + def isPseudoatom(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoIsPseudoatom(self.id) + ) + ) + + def isRSite(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult(Indigo._lib.indigoIsRSite(self.id)) + ) + + def isTemplateAtom(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoIsTemplateAtom(self.id) + ) + ) + + def stereocenterType(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoStereocenterType(self.id) + ) + + def stereocenterGroup(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoStereocenterGroup(self.id) + ) + + def setStereocenterGroup(self, group): + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoSetStereocenterGroup(self.id, group) + ) + + def changeStereocenterType(self, type): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoChangeStereocenterType(self.id, type) + ) + + def validateChirality(self): + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoValidateChirality(self.id) + ) + + def singleAllowedRGroup(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSingleAllowedRGroup(self.id) + ) + + def iterateRGroupFragments(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateRGroupFragments(self.id) + ), + ) + + def countAttachmentPoints(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountAttachmentPoints(self.id) + ) + + def iterateAttachmentPoints(self, order): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateAttachmentPoints(self.id, order) + ), + ) + + def symbol(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoSymbol(self.id) + ) + + def degree(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoDegree(self.id)) + + def charge(self): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoGetCharge(self.id, pointer(value)) + ) + if res == 0: + return None + return value.value + + def getExplicitValence(self): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoGetExplicitValence(self.id, pointer(value)) + ) + if res == 0: + return None + return value.value + + def setExplicitValence(self, valence): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetExplicitValence(self.id, valence) + ) + + def radicalElectrons(self): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoGetRadicalElectrons(self.id, pointer(value)) + ) + if res == 0: + return None + return value.value + + def radical(self): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoGetRadical(self.id, pointer(value)) + ) + if res == 0: + return None + return value.value + + def setRadical(self, radical): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetRadical(self.id, radical) + ) + + def atomicNumber(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAtomicNumber(self.id) + ) + + def isotope(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoIsotope(self.id)) + + def valence(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoValence(self.id)) + + def checkValence(self): + + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheckValence(self.id) + ) + + def checkQuery(self): + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheckQuery(self.id) + ) + + def checkRGroups(self): + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheckRGroups(self.id) + ) + + def checkChirality(self): + + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheckChirality(self.id) + ) + + def check3DStereo(self): + + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheck3DStereo(self.id) + ) + + def checkStereo(self): + + """ + :: + + Since version 1.3.0 + """ + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCheckStereo(self.id) + ) + + def countHydrogens(self): + value = c_int() + self.dispatcher._setSessionId() + res = self.dispatcher._checkResult( + Indigo._lib.indigoCountHydrogens(self.id, pointer(value)) + ) + if res == 0: + return None + return value.value + + def countImplicitHydrogens(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountImplicitHydrogens(self.id) + ) + + def setXYZ(self, x, y, z): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetXYZ(self.id, x, y, z) + ) + + def countSuperatoms(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountSuperatoms(self.id) + ) + + def countDataSGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountDataSGroups(self.id) + ) + + def countRepeatingUnits(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountRepeatingUnits(self.id) + ) + + def countMultipleGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountMultipleGroups(self.id) + ) + + def countGenericSGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountGenericSGroups(self.id) + ) + + def iterateDataSGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateDataSGroups(self.id) + ), + ) + + def iterateSuperatoms(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateSuperatoms(self.id) + ), + ) + + def iterateGenericSGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateGenericSGroups(self.id) + ), + ) + + def iterateRepeatingUnits(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateRepeatingUnits(self.id) + ), + ) + + def iterateMultipleGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateMultipleGroups(self.id) + ), + ) + + def iterateSGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateSGroups(self.id) + ), + ) + + def iterateTGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateTGroups(self.id) + ), + ) + + def getSuperatom(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetSuperatom(self.id, index) + ), + ) + + def getDataSGroup(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetDataSGroup(self.id, index) + ), + ) + + def getGenericSGroup(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetGenericSGroup(self.id, index) + ), + ) + + def getMultipleGroup(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetMultipleGroup(self.id, index) + ), + ) + + def getRepeatingUnit(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetRepeatingUnit(self.id, index) + ), + ) + + def description(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoDescription(self.id) + ) + + def data(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoData(self.id) + ) + + def addDataSGroup(self, atoms, bonds, description, data): + arr2 = (c_int * len(atoms))() + for i in range(len(atoms)): + arr2[i] = atoms[i] + arr4 = (c_int * len(bonds))() + for i in range(len(bonds)): + arr4[i] = bonds[i] + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAddDataSGroup( + self.id, + len(arr2), + arr2, + len(arr4), + arr4, + description.encode(ENCODE_ENCODING), + data.encode(ENCODE_ENCODING), + ) + ), + ) + + def addSuperatom(self, atoms, name): + arr2 = (c_int * len(atoms))() + for i in range(len(atoms)): + arr2[i] = atoms[i] + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAddSuperatom( + self.id, len(arr2), arr2, name.encode(ENCODE_ENCODING) + ) + ), + ) + + def setDataSGroupXY(self, x, y, options=""): + self.dispatcher._setSessionId() + if options is None: + options = "" + return self.dispatcher._checkResult( + Indigo._lib.indigoSetDataSGroupXY( + self.id, x, y, options.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupData(self, data): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupData( + self.id, data.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupCoords(self, x, y): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupCoords(self.id, x, y) + ) + + def setSGroupDescription(self, description): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupDescription( + self.id, description.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupFieldName(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupFieldName( + self.id, name.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupQueryCode(self, code): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupQueryCode( + self.id, code.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupQueryOper(self, oper): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupQueryOper( + self.id, oper.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupDisplay(self, option): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupDisplay( + self.id, option.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupLocation(self, option): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupLocation( + self.id, option.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupTag(self, tag): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupTag( + self.id, tag.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupTagAlign(self, tag_align): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupTagAlign(self.id, tag_align) + ) + + def setSGroupDataType(self, data_type): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupDataType( + self.id, data_type.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupXCoord(self, x): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupXCoord(self.id, x) + ) + + def setSGroupYCoord(self, y): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupYCoord(self.id, y) + ) + + def createSGroup(self, sgtype, mapping, name): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoCreateSGroup( + sgtype.encode(ENCODE_ENCODING), + mapping.id, + name.encode(ENCODE_ENCODING), + ) + ), + ) + + def setSGroupClass(self, sgclass): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupClass( + self.id, sgclass.encode(ENCODE_ENCODING) + ) + ) + + def setSGroupName(self, sgname): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupName( + self.id, sgname.encode(ENCODE_ENCODING) + ) + ) + + def getSGroupClass(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetSGroupClass(self.id) + ) + + def getSGroupName(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetSGroupName(self.id) + ) + + def getSGroupNumCrossBonds(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupNumCrossBonds(self.id) + ) + + def addSGroupAttachmentPoint(self, aidx, lvidx, apid): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddSGroupAttachmentPoint( + self.id, aidx, lvidx, apid.encode(ENCODE_ENCODING) + ) + ) + + def deleteSGroupAttachmentPoint(self, apidx): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoDeleteSGroupAttachmentPoint(self.id, apidx) + ) + + def getSGroupDisplayOption(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupDisplayOption(self.id) + ) + + def setSGroupDisplayOption(self, option): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupDisplayOption(self.id, option) + ) + + def getSGroupSeqId(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupSeqId(self.id) + ) + + def getSGroupCoords(self): + """ + Returns: + XY coordinates for Data sgroup + :: + Since 1.3.0 + """ + self.dispatcher._setSessionId() + xyz = Indigo._lib.indigoGetSGroupCoords(self.id) + if xyz is None: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return [xyz[0], xyz[1]] + + def getRepeatingUnitSubscript(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetRepeatingUnitSubscript(self.id) + ) + + def getRepeatingUnitConnectivity(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetRepeatingUnitConnectivity(self.id) + ) + + def getSGroupMultiplier(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupMultiplier(self.id) + ) + + def setSGroupMultiplier(self, mult): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupMultiplier(self.id, mult) + ) + + def setSGroupBrackets(self, style, x1, y1, x2, y2, x3, y3, x4, y4): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupBrackets( + self.id, style, x1, y1, x2, y2, x3, y3, x4, y4 + ) + ) + + def findSGroups(self, prop, val): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoFindSGroups( + self.id, + prop.encode(ENCODE_ENCODING), + val.encode(ENCODE_ENCODING), + ) + ), + ) + + def getSGroupType(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupType(self.id) + ) + + def getSGroupIndex(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupIndex(self.id) + ) + + def getSGroupOriginalId(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupOriginalId(self.id) + ) + + def setSGroupOriginalId(self, original): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupOriginalId(self.id, original) + ) + + def getSGroupParentId(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetSGroupParentId(self.id) + ) + + def setSGroupParentId(self, parent): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetSGroupParentId(self.id, parent) + ) + + def addTemplate(self, templates, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddTemplate( + self.id, templates.id, name.encode(ENCODE_ENCODING) + ) + ) + + def removeTemplate(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRemoveTemplate( + self.id, name.encode(ENCODE_ENCODING) + ) + ) + + def findTemplate(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoFindTemplate( + self.id, name.encode(ENCODE_ENCODING) + ) + ) + + def getTGroupClass(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetTGroupClass(self.id) + ) + + def getTGroupName(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetTGroupName(self.id) + ) + + def getTGroupAlias(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetTGroupAlias(self.id) + ) + + def transformSCSRtoCTAB(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoTransformSCSRtoCTAB(self.id) + ) + + def transformCTABtoSCSR(self, templates): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoTransformCTABtoSCSR(self.id, templates.id) + ) + + def getTemplateAtomClass(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoGetTemplateAtomClass(self.id) + ) + + def setTemplateAtomClass(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetTemplateAtomClass( + self.id, name.encode(ENCODE_ENCODING) + ) + ) + + def clean2d(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoClean2d(self.id)) + + def resetCharge(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetCharge(self.id) + ) + + def resetExplicitValence(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetExplicitValence(self.id) + ) + + def resetRadical(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetRadical(self.id) + ) + + def resetIsotope(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetIsotope(self.id) + ) + + def setAttachmentPoint(self, order): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetAttachmentPoint(self.id, order) + ) + + def clearAttachmentPoints(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearAttachmentPoints(self.id) + ) + + def removeConstraints(self, type): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRemoveConstraints( + self.id, type.encode(ENCODE_ENCODING) + ) + ) + + def addConstraint(self, type, value): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddConstraint( + self.id, + type.encode(ENCODE_ENCODING), + value.encode(ENCODE_ENCODING), + ) + ) + + def addConstraintNot(self, type, value): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddConstraintNot( + self.id, + type.encode(ENCODE_ENCODING), + value.encode(ENCODE_ENCODING), + ) + ) + + def addConstraintOr(self, type, value): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddConstraintOr( + self.id, + type.encode(ENCODE_ENCODING), + value.encode(ENCODE_ENCODING), + ) + ) + + def resetStereo(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetStereo(self.id) + ) + + def invertStereo(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoInvertStereo(self.id) + ) + + def countAtoms(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountAtoms(self.id) + ) + + def countBonds(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountBonds(self.id) + ) + + def countPseudoatoms(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountPseudoatoms(self.id) + ) + + def countRSites(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountRSites(self.id) + ) + + def iterateBonds(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateBonds(self.id) + ), + ) + + def bondOrder(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoBondOrder(self.id) + ) + + def bondStereo(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoBondStereo(self.id) + ) + + def topology(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoTopology(self.id) + ) + + def iterateNeighbors(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateNeighbors(self.id) + ), + ) + + def bond(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult(Indigo._lib.indigoBond(self.id)), + ) + + def getAtom(self, idx): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetAtom(self.id, idx) + ), + ) + + def getBond(self, idx): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetBond(self.id, idx) + ), + ) + + def source(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult(Indigo._lib.indigoSource(self.id)), + ) + + def destination(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoDestination(self.id) + ), + ) + + def clearCisTrans(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearCisTrans(self.id) + ) + + def clearStereocenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearStereocenters(self.id) + ) + + def countStereocenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountStereocenters(self.id) + ) + + def clearAlleneCenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearAlleneCenters(self.id) + ) + + def countAlleneCenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountAlleneCenters(self.id) + ) + + def resetSymmetricCisTrans(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetSymmetricCisTrans(self.id) + ) + + def resetSymmetricStereocenters(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoResetSymmetricStereocenters(self.id) + ) + + def markEitherCisTrans(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoMarkEitherCisTrans(self.id) + ) + + def markStereobonds(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoMarkStereobonds(self.id) + ) + + def addAtom(self, symbol): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAddAtom( + self.id, symbol.encode(ENCODE_ENCODING) + ) + ), + ) + + def resetAtom(self, symbol): + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoResetAtom( + self.id, symbol.encode(ENCODE_ENCODING) + ) + ) + + def addRSite(self, name): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAddRSite( + self.id, name.encode(ENCODE_ENCODING) + ) + ), + ) + + def setRSite(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetRSite(self.id, name.encode(ENCODE_ENCODING)) + ) + + def setCharge(self, charge): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetCharge(self.id, charge) + ) + + def setIsotope(self, isotope): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetIsotope(self.id, isotope) + ) + + def setImplicitHCount(self, impl_h): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetImplicitHCount(self.id, impl_h) + ) + + def addBond(self, destination, order): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAddBond(self.id, destination.id, order) + ), + ) + + def setBondOrder(self, order): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoSetBondOrder(self.id, order) + ), + ) + + def merge(self, what): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoMerge(self.id, what.id) + ), + ) + + def highlight(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoHighlight(self.id) + ) + + def unhighlight(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoUnhighlight(self.id) + ) + + def isHighlighted(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoIsHighlighted(self.id) + ) + ) + + def countComponents(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountComponents(self.id) + ) + + def componentIndex(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoComponentIndex(self.id) + ) + + def iterateComponents(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateComponents(self.id) + ), + ) + + def component(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoComponent(self.id, index) + ), + ) + + def countSSSR(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountSSSR(self.id) + ) + + def iterateSSSR(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateSSSR(self.id) + ), + ) + + def iterateSubtrees(self, min_atoms, max_atoms): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateSubtrees( + self.id, min_atoms, max_atoms + ) + ), + ) + + def iterateRings(self, min_atoms, max_atoms): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateRings(self.id, min_atoms, max_atoms) + ), + ) + + def iterateEdgeSubmolecules(self, min_bonds, max_bonds): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateEdgeSubmolecules( + self.id, min_bonds, max_bonds + ) + ), + ) + + def countHeavyAtoms(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountHeavyAtoms(self.id) + ) + + def grossFormula(self): + self.dispatcher._setSessionId() + gfid = self.dispatcher._checkResult( + Indigo._lib.indigoGrossFormula(self.id) + ) + gf = self.dispatcher.IndigoObject(self.dispatcher, gfid) + return self.dispatcher._checkResultString( + Indigo._lib.indigoToString(gf.id) + ) + + def molecularWeight(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultFloat( + Indigo._lib.indigoMolecularWeight(self.id) + ) + + def mostAbundantMass(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultFloat( + Indigo._lib.indigoMostAbundantMass(self.id) + ) + + def monoisotopicMass(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultFloat( + Indigo._lib.indigoMonoisotopicMass(self.id) + ) + + def massComposition(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoMassComposition(self.id) + ) + + def canonicalSmiles(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCanonicalSmiles(self.id) + ) + + def canonicalSmarts(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCanonicalSmarts(self.id) + ) + + def layeredCode(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoLayeredCode(self.id) + ) + + def symmetryClasses(self): + c_size = c_int() + self.dispatcher._setSessionId() + c_buf = self.dispatcher._checkResultPtr( + Indigo._lib.indigoSymmetryClasses(self.id, pointer(c_size)) + ) + res = array("i") + for i in range(c_size.value): + res.append(c_buf[i]) + return res + + def hasCoord(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult(Indigo._lib.indigoHasCoord(self.id)) + ) + + def hasZCoord(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult(Indigo._lib.indigoHasZCoord(self.id)) + ) + + def isChiral(self): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult(Indigo._lib.indigoIsChiral(self.id)) + ) + + def isPossibleFischerProjection(self, options): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoIsPossibleFischerProjection( + self.id, options.encode(ENCODE_ENCODING) + ) + ) + ) + + def createSubmolecule(self, vertices): + arr2 = (c_int * len(vertices))() + for i in range(len(vertices)): + arr2[i] = vertices[i] + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoCreateSubmolecule(self.id, len(arr2), arr2) + ), + ) + + def createEdgeSubmolecule(self, vertices, edges): + arr2 = (c_int * len(vertices))() + for i in range(len(vertices)): + arr2[i] = vertices[i] + arr4 = (c_int * len(edges))() + for i in range(len(edges)): + arr4[i] = edges[i] + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoCreateEdgeSubmolecule( + self.id, len(arr2), arr2, len(arr4), arr4 + ) + ), + ) + + def getSubmolecule(self, vertices): + arr2 = (c_int * len(vertices))() + for i in range(len(vertices)): + arr2[i] = vertices[i] + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoGetSubmolecule(self.id, len(arr2), arr2) + ), + self, + ) + + def removeAtoms(self, vertices): + arr2 = (c_int * len(vertices))() + for i in range(len(vertices)): + arr2[i] = vertices[i] + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRemoveAtoms(self.id, len(arr2), arr2) + ) + + def removeBonds(self, bonds): + arr2 = (c_int * len(bonds))() + for i in range(len(bonds)): + arr2[i] = bonds[i] + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRemoveBonds(self.id, len(arr2), arr2) + ) + + def aromatize(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAromatize(self.id) + ) + + def dearomatize(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoDearomatize(self.id) + ) + + def foldHydrogens(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoFoldHydrogens(self.id) + ) + + def unfoldHydrogens(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoUnfoldHydrogens(self.id) + ) + + def layout(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoLayout(self.id)) + + def smiles(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoSmiles(self.id) + ) + + def smarts(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoSmarts(self.id) + ) + + def name(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoName(self.id) + ) + + def setName(self, name): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetName(self.id, name.encode(ENCODE_ENCODING)) + ) + + def serialize(self): + c_size = c_int() + c_buf = POINTER(c_byte)() + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoSerialize( + self.id, pointer(c_buf), pointer(c_size) + ) + ) + res = array("b") + for i in range(c_size.value): + res.append(c_buf[i]) + return res + + def hasProperty(self, prop): + self.dispatcher._setSessionId() + return bool( + self.dispatcher._checkResult( + Indigo._lib.indigoHasProperty(self.id, prop) + ) + ) + + def getProperty(self, prop): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoGetProperty( + self.id, prop.encode(ENCODE_ENCODING) + ) + ) + + def setProperty(self, prop, value): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSetProperty( + self.id, + prop.encode(ENCODE_ENCODING), + value.encode(ENCODE_ENCODING), + ) + ) + + def removeProperty(self, prop): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRemoveProperty( + self.id, prop.encode(ENCODE_ENCODING) + ) + ) + + def iterateProperties(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateProperties(self.id) + ), + ) + + def clearProperties(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoClearProperties(self.id) + ) + + def checkBadValence(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCheckBadValence(self.id) + ) + + def checkAmbiguousH(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoCheckAmbiguousH(self.id) + ) + + def fingerprint(self, type): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoFingerprint( + self.id, type.encode(ENCODE_ENCODING) + ) + ) + if newobj == 0: + return None + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def countBits(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountBits(self.id) + ) + + def rawData(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoRawData(self.id) + ) + + def tell(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoTell(self.id)) + + def sdfAppend(self, item): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSdfAppend(self.id, item.id) + ) + + def smilesAppend(self, item): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoSmilesAppend(self.id, item.id) + ) + + def rdfHeader(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRdfHeader(self.id) + ) + + def rdfAppend(self, item): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoRdfAppend(self.id, item.id) + ) + + def cmlHeader(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCmlHeader(self.id) + ) + + def cmlAppend(self, item): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCmlAppend(self.id, item.id) + ) + + def cmlFooter(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCmlFooter(self.id) + ) + + def append(self, object): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAppend(self.id, object.id) + ) + + def arrayAdd(self, object): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoArrayAdd(self.id, object.id) + ) + + def at(self, index): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult(Indigo._lib.indigoAt(self.id, index)), + ) + + def count(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoCount(self.id)) + + def clear(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult(Indigo._lib.indigoClear(self.id)) + + def iterateArray(self): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoIterateArray(self.id) + ) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def ignoreAtom(self, atom_object): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoIgnoreAtom(self.id, atom_object.id) + ) + + def unignoreAtom(self, atom_object): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoUnignoreAtom(self.id, atom_object.id) + ) + + def unignoreAllAtoms(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoUnignoreAllAtoms(self.id) + ) + + def match(self, query): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoMatch(self.id, query.id) + ) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def countMatches(self, query): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountMatches(self.id, query.id) + ) + + def countMatchesWithLimit(self, query, embeddings_limit): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoCountMatchesWithLimit( + self.id, query.id, embeddings_limit + ) + ) + + def iterateMatches(self, query): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateMatches(self.id, query.id) + ), + ) + + def highlightedTarget(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoHighlightedTarget(self.id) + ), + ) + + def mapAtom(self, atom): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoMapAtom(self.id, atom.id) + ) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def mapBond(self, bond): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoMapBond(self.id, bond.id) + ) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def mapMolecule(self, molecule): + self.dispatcher._setSessionId() + newobj = self.dispatcher._checkResult( + Indigo._lib.indigoMapMolecule(self.id, molecule.id) + ) + if newobj == 0: + return None + else: + return self.dispatcher.IndigoObject(self.dispatcher, newobj, self) + + def allScaffolds(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoAllScaffolds(self.id) + ), + ) + + def decomposedMoleculeScaffold(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoDecomposedMoleculeScaffold(self.id) + ), + ) + + def iterateDecomposedMolecules(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateDecomposedMolecules(self.id) + ), + ) + + def decomposedMoleculeHighlighted(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoDecomposedMoleculeHighlighted(self.id) + ), + ) + + def decomposedMoleculeWithRGroups(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoDecomposedMoleculeWithRGroups(self.id) + ), + ) + + def decomposeMolecule(self, mol): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoDecomposeMolecule(self.id, mol.id) + ), + ) + + def iterateDecompositions(self): + self.dispatcher._setSessionId() + return self.dispatcher.IndigoObject( + self.dispatcher, + self.dispatcher._checkResult( + Indigo._lib.indigoIterateDecompositions(self.id) + ), + ) + + def addDecomposition(self, q_match): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoAddDecomposition(self.id, q_match.id) + ) + + def toString(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoToString(self.id) + ) + + def toBuffer(self): + c_size = c_int() + c_buf = POINTER(c_byte)() + self.dispatcher._setSessionId() + self.dispatcher._checkResult( + Indigo._lib.indigoToBuffer( + self.id, pointer(c_buf), pointer(c_size) + ) + ) + res = array("b") + for i in range(c_size.value): + res.append(c_buf[i]) + return res + + def stereocenterPyramid(self): + self.dispatcher._setSessionId() + ptr = self.dispatcher._checkResultPtr( + Indigo._lib.indigoStereocenterPyramid(self.id) + ) + res = [0] * 4 + for i in range(4): + res[i] = ptr[i] + return res + + def expandAbbreviations(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResult( + Indigo._lib.indigoExpandAbbreviations(self.id) + ) + + def dbgInternalType(self): + self.dispatcher._setSessionId() + return self.dispatcher._checkResultString( + Indigo._lib.indigoDbgInternalType(self.id) + ) + + +class Indigo(object): + ABS = 1 + OR = 2 + AND = 3 + EITHER = 4 + UP = 5 + DOWN = 6 + CIS = 7 + TRANS = 8 + CHAIN = 9 + RING = 10 + ALLENE = 11 + + SINGLET = 101 + DOUBLET = 102 + TRIPLET = 103 + RC_NOT_CENTER = -1 + RC_UNMARKED = 0 + RC_CENTER = 1 + RC_UNCHANGED = 2 + RC_MADE_OR_BROKEN = 4 + RC_ORDER_CHANGED = 8 + + SG_TYPE_GEN = 0 + SG_TYPE_DAT = 1 + SG_TYPE_SUP = 2 + SG_TYPE_SRU = 3 + SG_TYPE_MUL = 4 + SG_TYPE_MON = 5 + SG_TYPE_MER = 6 + SG_TYPE_COP = 7 + SG_TYPE_CRO = 8 + SG_TYPE_MOD = 9 + SG_TYPE_GRA = 10 + SG_TYPE_COM = 11 + SG_TYPE_MIX = 12 + SG_TYPE_FOR = 13 + SG_TYPE_ANY = 14 + + _crt = None + _crtp = None + _lib = None + + # Python embeds path into .pyc code if method is marked with @staticmethod + # This causes an error when Indigo is loaded from different places by relative path + def _initStatic(self, path=None): + def cdll_if_exists(cdll_path_): + if os.path.exists(cdll_path_): + return CDLL(cdll_path_) + return None + + paths = [] + if not path: + cur_file = os.path.abspath(__file__) + paths = [ + os.path.join(os.path.dirname(cur_file), "lib"), + os.path.join( + os.path.dirname(os.path.dirname(cur_file)), "lib" + ), + ] + else: + paths.append(path) + + indigoFound = False + for path in paths: + if ( + os.name == "posix" + and not platform.mac_ver()[0] + and not platform.system().startswith("CYGWIN") + ): + arch = platform.architecture()[0] + path = os.path.join(path, "Linux") + if arch == "32bit": + path = os.path.join(path, "x86") + elif arch == "64bit": + path = os.path.join(path, "x64") + else: + raise IndigoException("unknown platform " + arch) + if os.path.exists(os.path.join(path, "libindigo.so")): + Indigo._lib = CDLL( + os.path.join(path, "libindigo.so"), mode=RTLD_GLOBAL + ) + indigoFound = True + Indigo.dllpath = path + elif os.name == "nt" or platform.system().startswith("CYGWIN"): + arch = platform.architecture()[0] + path = os.path.join(path, "Win") + if arch == "32bit": + path = os.path.join(path, "x86") + elif arch == "64bit": + path = os.path.join(path, "x64") + else: + raise IndigoException("unknown platform " + arch) + if os.path.exists(os.path.join(path, "indigo.dll")): + Indigo._crt = cdll_if_exists( + os.path.join(path, "vcruntime140.dll") + ) + Indigo._crt_1 = cdll_if_exists( + os.path.join(path, "vcruntime140_1.dll") + ) + Indigo._crtp = cdll_if_exists( + os.path.join(path, "msvcp140.dll") + ) + Indigo._crtc = cdll_if_exists( + os.path.join(path, "concrt140.dll") + ) + Indigo._lib = CDLL(os.path.join(path, "indigo.dll")) + indigoFound = True + Indigo.dllpath = path + elif platform.mac_ver()[0]: + path = os.path.join(path, "Mac") + mac_ver = ".".join(platform.mac_ver()[0].split(".")[:2]) + current_mac_ver = int(mac_ver.split(".")[1]) + using_mac_ver = None + for version in reversed(range(5, current_mac_ver + 1)): + if os.path.exists( + os.path.join(path, "10." + str(version)) + ): + using_mac_ver = str(version) + break + if using_mac_ver: + path = os.path.join(path, "10." + using_mac_ver) + Indigo._lib = CDLL( + os.path.join(path, "libindigo.dylib"), mode=RTLD_GLOBAL + ) + indigoFound = True + Indigo.dllpath = path + else: + raise IndigoException("unsupported OS: " + os.name) + if not indigoFound: + raise IndigoException( + "Could not find native libraries for target OS in search directories: {}".format( + os.pathsep.join(paths) + ) + ) + + def _setSessionId(self): + Indigo._lib.indigoSetSessionId(self._sid) + + def __init__(self, path=None): + if Indigo._lib is None: + self._initStatic(path) + self._sid = Indigo._lib.indigoAllocSessionId() + # Capture a reference to the _lib to access it in the __del__ method because + # at interpreter shutdown, the module's global variables are set to None + self._lib = Indigo._lib + self._setSessionId() + self.IndigoObject = IndigoObject + Indigo._lib.indigoVersion.restype = c_char_p + Indigo._lib.indigoVersion.argtypes = None + Indigo._lib.indigoAllocSessionId.restype = c_ulonglong + Indigo._lib.indigoAllocSessionId.argtypes = None + Indigo._lib.indigoSetSessionId.restype = None + Indigo._lib.indigoSetSessionId.argtypes = [c_ulonglong] + Indigo._lib.indigoReleaseSessionId.restype = None + Indigo._lib.indigoReleaseSessionId.argtypes = [c_ulonglong] + Indigo._lib.indigoGetLastError.restype = c_char_p + Indigo._lib.indigoGetLastError.argtypes = None + Indigo._lib.indigoFree.restype = c_int + Indigo._lib.indigoFree.argtypes = [c_int] + Indigo._lib.indigoCountReferences.restype = c_int + Indigo._lib.indigoCountReferences.argtypes = None + Indigo._lib.indigoFreeAllObjects.restype = c_int + Indigo._lib.indigoFreeAllObjects.argtypes = None + Indigo._lib.indigoSetOption.restype = c_int + Indigo._lib.indigoSetOption.argtypes = [c_char_p, c_char_p] + Indigo._lib.indigoSetOptionInt.restype = c_int + Indigo._lib.indigoSetOptionInt.argtypes = [c_char_p, c_int] + Indigo._lib.indigoSetOptionBool.restype = c_int + Indigo._lib.indigoSetOptionBool.argtypes = [c_char_p, c_int] + Indigo._lib.indigoSetOptionFloat.restype = c_int + Indigo._lib.indigoSetOptionFloat.argtypes = [c_char_p, c_float] + Indigo._lib.indigoSetOptionColor.restype = c_int + Indigo._lib.indigoSetOptionColor.argtypes = [ + c_char_p, + c_float, + c_float, + c_float, + ] + Indigo._lib.indigoSetOptionXY.restype = c_int + Indigo._lib.indigoSetOptionXY.argtypes = [c_char_p, c_int, c_int] + Indigo._lib.indigoGetOption.restype = c_char_p + Indigo._lib.indigoGetOption.argtypes = [c_char_p] + Indigo._lib.indigoGetOptionInt.restype = c_int + Indigo._lib.indigoGetOptionInt.argtypes = [c_char_p, POINTER(c_int)] + Indigo._lib.indigoGetOptionBool.argtypes = [c_char_p, POINTER(c_int)] + Indigo._lib.indigoGetOptionBool.restype = c_int + Indigo._lib.indigoGetOptionFloat.argtypes = [ + c_char_p, + POINTER(c_float), + ] + Indigo._lib.indigoGetOptionFloat.restype = c_int + Indigo._lib.indigoGetOptionColor.argtypes = [ + c_char_p, + POINTER(c_float), + POINTER(c_float), + POINTER(c_float), + ] + Indigo._lib.indigoGetOptionColor.restype = c_int + Indigo._lib.indigoGetOptionXY.argtypes = [ + c_char_p, + POINTER(c_int), + POINTER(c_int), + ] + Indigo._lib.indigoGetOptionXY.restype = c_int + Indigo._lib.indigoGetOptionType.restype = c_char_p + Indigo._lib.indigoGetOptionType.argtypes = [c_char_p] + Indigo._lib.indigoReadFile.restype = c_int + Indigo._lib.indigoReadFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadString.restype = c_int + Indigo._lib.indigoLoadString.argtypes = [c_char_p] + Indigo._lib.indigoLoadBuffer.restype = c_int + Indigo._lib.indigoLoadBuffer.argtypes = [POINTER(c_byte), c_int] + Indigo._lib.indigoWriteFile.restype = c_int + Indigo._lib.indigoWriteFile.argtypes = [c_char_p] + Indigo._lib.indigoWriteBuffer.restype = c_int + Indigo._lib.indigoWriteBuffer.argtypes = None + Indigo._lib.indigoCreateMolecule.restype = c_int + Indigo._lib.indigoCreateMolecule.argtypes = None + Indigo._lib.indigoCreateQueryMolecule.restype = c_int + Indigo._lib.indigoCreateQueryMolecule.argtypes = None + Indigo._lib.indigoLoadMoleculeFromString.restype = c_int + Indigo._lib.indigoLoadMoleculeFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadMoleculeFromFile.restype = c_int + Indigo._lib.indigoLoadMoleculeFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadMoleculeFromBuffer.restype = c_int + Indigo._lib.indigoLoadMoleculeFromBuffer.argtypes = [ + POINTER(c_byte), + c_int, + ] + Indigo._lib.indigoLoadQueryMoleculeFromString.restype = c_int + Indigo._lib.indigoLoadQueryMoleculeFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadQueryMoleculeFromFile.restype = c_int + Indigo._lib.indigoLoadQueryMoleculeFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadSmartsFromString.restype = c_int + Indigo._lib.indigoLoadSmartsFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadSmartsFromFile.restype = c_int + Indigo._lib.indigoLoadSmartsFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadReactionFromString.restype = c_int + Indigo._lib.indigoLoadReactionFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadReactionFromFile.restype = c_int + Indigo._lib.indigoLoadReactionFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadQueryReactionFromString.restype = c_int + Indigo._lib.indigoLoadQueryReactionFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadQueryReactionFromFile.restype = c_int + Indigo._lib.indigoLoadQueryReactionFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadReactionSmartsFromString.restype = c_int + Indigo._lib.indigoLoadReactionSmartsFromString.argtypes = [c_char_p] + Indigo._lib.indigoLoadReactionSmartsFromFile.restype = c_int + Indigo._lib.indigoLoadReactionSmartsFromFile.argtypes = [c_char_p] + Indigo._lib.indigoLoadStructureFromString.restype = c_int + Indigo._lib.indigoLoadStructureFromString.argtypes = [ + c_char_p, + c_char_p, + ] + Indigo._lib.indigoLoadStructureFromBuffer.restype = c_int + Indigo._lib.indigoLoadStructureFromBuffer.argtypes = [ + POINTER(c_byte), + c_int, + c_char_p, + ] + Indigo._lib.indigoLoadStructureFromFile.restype = c_int + Indigo._lib.indigoLoadStructureFromFile.argtypes = [c_char_p, c_char_p] + Indigo._lib.indigoCreateReaction.restype = c_int + Indigo._lib.indigoCreateReaction.argtypes = None + Indigo._lib.indigoCreateQueryReaction.restype = c_int + Indigo._lib.indigoCreateQueryReaction.argtypes = None + Indigo._lib.indigoExactMatch.restype = c_int + Indigo._lib.indigoExactMatch.argtypes = [c_int, c_int, c_char_p] + Indigo._lib.indigoSetTautomerRule.restype = c_int + Indigo._lib.indigoSetTautomerRule.argtypes = [ + c_int, + c_char_p, + c_char_p, + ] + Indigo._lib.indigoRemoveTautomerRule.restype = c_int + Indigo._lib.indigoRemoveTautomerRule.argtypes = [c_int] + Indigo._lib.indigoClearTautomerRules.restype = c_int + Indigo._lib.indigoClearTautomerRules.argtypes = None + Indigo._lib.indigoUnserialize.restype = c_int + Indigo._lib.indigoUnserialize.argtypes = [POINTER(c_byte), c_int] + Indigo._lib.indigoCommonBits.restype = c_int + Indigo._lib.indigoCommonBits.argtypes = [c_int, c_int] + Indigo._lib.indigoSimilarity.restype = c_float + Indigo._lib.indigoSimilarity.argtypes = [c_int, c_int, c_char_p] + Indigo._lib.indigoIterateSDF.restype = c_int + Indigo._lib.indigoIterateSDF.argtypes = [c_int] + Indigo._lib.indigoIterateRDF.restype = c_int + Indigo._lib.indigoIterateRDF.argtypes = [c_int] + Indigo._lib.indigoIterateSmiles.restype = c_int + Indigo._lib.indigoIterateSmiles.argtypes = [c_int] + Indigo._lib.indigoIterateCML.restype = c_int + Indigo._lib.indigoIterateCML.argtypes = [c_int] + Indigo._lib.indigoIterateCDX.restype = c_int + Indigo._lib.indigoIterateCDX.argtypes = [c_int] + Indigo._lib.indigoIterateSDFile.restype = c_int + Indigo._lib.indigoIterateSDFile.argtypes = [c_char_p] + Indigo._lib.indigoIterateRDFile.restype = c_int + Indigo._lib.indigoIterateRDFile.argtypes = [c_char_p] + Indigo._lib.indigoIterateSmilesFile.restype = c_int + Indigo._lib.indigoIterateSmilesFile.argtypes = [c_char_p] + Indigo._lib.indigoIterateCMLFile.restype = c_int + Indigo._lib.indigoIterateCMLFile.argtypes = [c_char_p] + Indigo._lib.indigoIterateCDXFile.restype = c_int + Indigo._lib.indigoIterateCDXFile.argtypes = [c_char_p] + Indigo._lib.indigoCreateSaver.restype = c_int + Indigo._lib.indigoCreateSaver.argtypes = [c_int, c_char_p] + Indigo._lib.indigoCreateFileSaver.restype = c_int + Indigo._lib.indigoCreateFileSaver.argtypes = [c_char_p, c_char_p] + Indigo._lib.indigoCreateArray.restype = c_int + Indigo._lib.indigoCreateArray.argtypes = None + Indigo._lib.indigoSubstructureMatcher.restype = c_int + Indigo._lib.indigoSubstructureMatcher.argtypes = [c_int, c_char_p] + Indigo._lib.indigoExtractCommonScaffold.restype = c_int + Indigo._lib.indigoExtractCommonScaffold.argtypes = [c_int, c_char_p] + Indigo._lib.indigoDecomposeMolecules.restype = c_int + Indigo._lib.indigoDecomposeMolecules.argtypes = [c_int, c_int] + Indigo._lib.indigoRGroupComposition.restype = c_int + Indigo._lib.indigoRGroupComposition.argtypes = [c_int, c_char_p] + Indigo._lib.indigoGetFragmentedMolecule.restype = c_int + Indigo._lib.indigoGetFragmentedMolecule.argtypes = [c_int, c_char_p] + Indigo._lib.indigoCreateDecomposer.restype = c_int + Indigo._lib.indigoCreateDecomposer.argtypes = [c_int] + Indigo._lib.indigoReactionProductEnumerate.restype = c_int + Indigo._lib.indigoReactionProductEnumerate.argtypes = [c_int, c_int] + Indigo._lib.indigoTransform.restype = c_int + Indigo._lib.indigoTransform.argtypes = [c_int, c_int] + Indigo._lib.indigoDbgBreakpoint.restype = None + Indigo._lib.indigoDbgBreakpoint.argtypes = None + Indigo._lib.indigoClone.restype = c_int + Indigo._lib.indigoClone.argtypes = [c_int] + Indigo._lib.indigoCheck.restype = c_char_p + Indigo._lib.indigoCheck.argtypes = [c_int, c_char_p] + Indigo._lib.indigoCheckStructure.restype = c_char_p + Indigo._lib.indigoCheckStructure.argtypes = [c_char_p, c_char_p] + Indigo._lib.indigoClose.restype = c_int + Indigo._lib.indigoClose.argtypes = [c_int] + Indigo._lib.indigoNext.restype = c_int + Indigo._lib.indigoNext.argtypes = [c_int] + Indigo._lib.indigoHasNext.restype = c_int + Indigo._lib.indigoHasNext.argtypes = [c_int] + Indigo._lib.indigoIndex.restype = c_int + Indigo._lib.indigoIndex.argtypes = [c_int] + Indigo._lib.indigoRemove.restype = c_int + Indigo._lib.indigoRemove.argtypes = [c_int] + Indigo._lib.indigoSaveMolfileToFile.restype = c_int + Indigo._lib.indigoSaveMolfileToFile.argtypes = [c_int, c_char_p] + Indigo._lib.indigoMolfile.restype = c_char_p + Indigo._lib.indigoMolfile.argtypes = [c_int] + Indigo._lib.indigoSaveCmlToFile.restype = c_int + Indigo._lib.indigoSaveCmlToFile.argtypes = [c_int, c_char_p] + Indigo._lib.indigoCml.restype = c_char_p + Indigo._lib.indigoCml.argtypes = [c_int] + Indigo._lib.indigoSaveCdxmlToFile.restype = c_int + Indigo._lib.indigoSaveCdxmlToFile.argtypes = [c_int, c_char_p] + Indigo._lib.indigoCdxml.restype = c_char_p + Indigo._lib.indigoCdxml.argtypes = [c_int] + Indigo._lib.indigoJson.restype = c_char_p + Indigo._lib.indigoJson.argtypes = [c_int] + Indigo._lib.indigoSaveMDLCT.restype = c_int + Indigo._lib.indigoSaveMDLCT.argtypes = [c_int, c_int] + Indigo._lib.indigoAddReactant.restype = c_int + Indigo._lib.indigoAddReactant.argtypes = [c_int, c_int] + Indigo._lib.indigoAddProduct.restype = c_int + Indigo._lib.indigoAddProduct.argtypes = [c_int, c_int] + Indigo._lib.indigoAddCatalyst.restype = c_int + Indigo._lib.indigoAddCatalyst.argtypes = [c_int, c_int] + Indigo._lib.indigoCountReactants.restype = c_int + Indigo._lib.indigoCountReactants.argtypes = [c_int] + Indigo._lib.indigoCountProducts.restype = c_int + Indigo._lib.indigoCountProducts.argtypes = [c_int] + Indigo._lib.indigoCountCatalysts.restype = c_int + Indigo._lib.indigoCountCatalysts.argtypes = [c_int] + Indigo._lib.indigoCountMolecules.restype = c_int + Indigo._lib.indigoCountMolecules.argtypes = [c_int] + Indigo._lib.indigoGetMolecule.restype = c_int + Indigo._lib.indigoGetMolecule.argtypes = [c_int, c_int] + Indigo._lib.indigoIterateReactants.restype = c_int + Indigo._lib.indigoIterateReactants.argtypes = [c_int] + Indigo._lib.indigoIterateProducts.restype = c_int + Indigo._lib.indigoIterateProducts.argtypes = [c_int] + Indigo._lib.indigoIterateCatalysts.restype = c_int + Indigo._lib.indigoIterateCatalysts.argtypes = [c_int] + Indigo._lib.indigoIterateMolecules.restype = c_int + Indigo._lib.indigoIterateMolecules.argtypes = [c_int] + Indigo._lib.indigoSaveRxnfileToFile.restype = c_int + Indigo._lib.indigoSaveRxnfileToFile.argtypes = [c_int, c_char_p] + Indigo._lib.indigoRxnfile.restype = c_char_p + Indigo._lib.indigoRxnfile.argtypes = [c_int] + Indigo._lib.indigoOptimize.restype = c_int + Indigo._lib.indigoOptimize.argtypes = [c_int, c_char_p] + Indigo._lib.indigoNormalize.restype = c_int + Indigo._lib.indigoNormalize.argtypes = [c_int, c_char_p] + Indigo._lib.indigoStandardize.restype = c_int + Indigo._lib.indigoStandardize.argtypes = [c_int] + Indigo._lib.indigoIonize.restype = c_int + Indigo._lib.indigoIonize.argtypes = [c_int, c_float, c_float] + Indigo._lib.indigoBuildPkaModel.restype = c_int + Indigo._lib.indigoBuildPkaModel.argtypes = [c_int, c_float, c_char_p] + Indigo._lib.indigoGetAcidPkaValue.restype = POINTER(c_float) + Indigo._lib.indigoGetAcidPkaValue.argtypes = [ + c_int, + c_int, + c_int, + c_int, + ] + Indigo._lib.indigoGetBasicPkaValue.restype = POINTER(c_float) + Indigo._lib.indigoGetBasicPkaValue.argtypes = [ + c_int, + c_int, + c_int, + c_int, + ] + Indigo._lib.indigoAutomap.restype = c_int + Indigo._lib.indigoAutomap.argtypes = [c_int, c_char_p] + Indigo._lib.indigoGetAtomMappingNumber.restype = c_int + Indigo._lib.indigoGetAtomMappingNumber.argtypes = [c_int, c_int] + Indigo._lib.indigoSetAtomMappingNumber.restype = c_int + Indigo._lib.indigoSetAtomMappingNumber.argtypes = [c_int, c_int, c_int] + Indigo._lib.indigoGetReactingCenter.restype = c_int + Indigo._lib.indigoGetReactingCenter.argtypes = [ + c_int, + c_int, + POINTER(c_int), + ] + Indigo._lib.indigoSetReactingCenter.restype = c_int + Indigo._lib.indigoSetReactingCenter.argtypes = [c_int, c_int, c_int] + Indigo._lib.indigoClearAAM.restype = c_int + Indigo._lib.indigoClearAAM.argtypes = [c_int] + Indigo._lib.indigoCorrectReactingCenters.restype = c_int + Indigo._lib.indigoCorrectReactingCenters.argtypes = [c_int] + Indigo._lib.indigoIterateAtoms.restype = c_int + Indigo._lib.indigoIterateAtoms.argtypes = [c_int] + Indigo._lib.indigoIteratePseudoatoms.restype = c_int + Indigo._lib.indigoIteratePseudoatoms.argtypes = [c_int] + Indigo._lib.indigoIterateRSites.restype = c_int + Indigo._lib.indigoIterateRSites.argtypes = [c_int] + Indigo._lib.indigoIterateStereocenters.restype = c_int + Indigo._lib.indigoIterateStereocenters.argtypes = [c_int] + Indigo._lib.indigoIterateAlleneCenters.restype = c_int + Indigo._lib.indigoIterateAlleneCenters.argtypes = [c_int] + Indigo._lib.indigoIterateRGroups.restype = c_int + Indigo._lib.indigoIterateRGroups.argtypes = [c_int] + Indigo._lib.indigoCountRGroups.restype = c_int + Indigo._lib.indigoCountRGroups.argtypes = [c_int] + Indigo._lib.indigoIsPseudoatom.restype = c_int + Indigo._lib.indigoIsPseudoatom.argtypes = [c_int] + Indigo._lib.indigoIsRSite.restype = c_int + Indigo._lib.indigoIsRSite.argtypes = [c_int] + Indigo._lib.indigoIsTemplateAtom.restype = c_int + Indigo._lib.indigoIsTemplateAtom.argtypes = [c_int] + Indigo._lib.indigoStereocenterType.restype = c_int + Indigo._lib.indigoStereocenterType.argtypes = [c_int] + Indigo._lib.indigoStereocenterGroup.restype = c_int + Indigo._lib.indigoStereocenterGroup.argtypes = [c_int] + Indigo._lib.indigoSetStereocenterGroup.restype = c_int + Indigo._lib.indigoSetStereocenterGroup.argtypes = [c_int, c_int] + Indigo._lib.indigoChangeStereocenterType.restype = c_int + Indigo._lib.indigoChangeStereocenterType.argtypes = [c_int, c_int] + Indigo._lib.indigoValidateChirality.restype = c_int + Indigo._lib.indigoValidateChirality.argtypes = [c_int] + Indigo._lib.indigoSingleAllowedRGroup.restype = c_int + Indigo._lib.indigoSingleAllowedRGroup.argtypes = [c_int] + Indigo._lib.indigoAddStereocenter.restype = c_int + Indigo._lib.indigoAddStereocenter.argtypes = [ + c_int, + c_int, + c_int, + c_int, + c_int, + c_int, + ] + Indigo._lib.indigoIterateRGroupFragments.restype = c_int + Indigo._lib.indigoIterateRGroupFragments.argtypes = [c_int] + Indigo._lib.indigoCountAttachmentPoints.restype = c_int + Indigo._lib.indigoCountAttachmentPoints.argtypes = [c_int] + Indigo._lib.indigoIterateAttachmentPoints.restype = c_int + Indigo._lib.indigoIterateAttachmentPoints.argtypes = [c_int, c_int] + Indigo._lib.indigoSymbol.restype = c_char_p + Indigo._lib.indigoSymbol.argtypes = [c_int] + Indigo._lib.indigoDegree.restype = c_int + Indigo._lib.indigoDegree.argtypes = [c_int] + Indigo._lib.indigoGetCharge.restype = c_int + Indigo._lib.indigoGetCharge.argtypes = [c_int, POINTER(c_int)] + Indigo._lib.indigoGetExplicitValence.restype = c_int + Indigo._lib.indigoGetExplicitValence.argtypes = [c_int, POINTER(c_int)] + Indigo._lib.indigoSetExplicitValence.restype = c_int + Indigo._lib.indigoSetExplicitValence.argtypes = [c_int, c_int] + Indigo._lib.indigoGetRadicalElectrons.restype = c_int + Indigo._lib.indigoGetRadicalElectrons.argtypes = [ + c_int, + POINTER(c_int), + ] + Indigo._lib.indigoGetRadical.restype = c_int + Indigo._lib.indigoGetRadical.argtypes = [c_int, POINTER(c_int)] + Indigo._lib.indigoSetRadical.restype = c_int + Indigo._lib.indigoSetRadical.argtypes = [c_int, c_int] + Indigo._lib.indigoAtomicNumber.restype = c_int + Indigo._lib.indigoAtomicNumber.argtypes = [c_int] + Indigo._lib.indigoIsotope.restype = c_int + Indigo._lib.indigoIsotope.argtypes = [c_int] + Indigo._lib.indigoValence.restype = c_int + Indigo._lib.indigoValence.argtypes = [c_int] + Indigo._lib.indigoCheckValence.restype = c_int + Indigo._lib.indigoCheckValence.argtypes = [c_int] + Indigo._lib.indigoCheckQuery.restype = c_int + Indigo._lib.indigoCheckQuery.argtypes = [c_int] + Indigo._lib.indigoCheckRGroups.restype = c_int + Indigo._lib.indigoCheckRGroups.argtypes = [c_int] + Indigo._lib.indigoCountHydrogens.restype = c_int + Indigo._lib.indigoCountHydrogens.argtypes = [c_int, POINTER(c_int)] + Indigo._lib.indigoCountImplicitHydrogens.restype = c_int + Indigo._lib.indigoCountImplicitHydrogens.argtypes = [c_int] + Indigo._lib.indigoXYZ.restype = POINTER(c_float) + Indigo._lib.indigoXYZ.argtypes = [c_int] + Indigo._lib.indigoCoords.restype = POINTER(c_float) + Indigo._lib.indigoCoords.argtypes = [c_int] + Indigo._lib.indigoSetXYZ.restype = c_int + Indigo._lib.indigoSetXYZ.argtypes = [c_int, c_float, c_float, c_float] + Indigo._lib.indigoCountSuperatoms.restype = c_int + Indigo._lib.indigoCountSuperatoms.argtypes = [c_int] + Indigo._lib.indigoCountDataSGroups.restype = c_int + Indigo._lib.indigoCountDataSGroups.argtypes = [c_int] + Indigo._lib.indigoCountRepeatingUnits.restype = c_int + Indigo._lib.indigoCountRepeatingUnits.argtypes = [c_int] + Indigo._lib.indigoCountMultipleGroups.restype = c_int + Indigo._lib.indigoCountMultipleGroups.argtypes = [c_int] + Indigo._lib.indigoCountGenericSGroups.restype = c_int + Indigo._lib.indigoCountGenericSGroups.argtypes = [c_int] + Indigo._lib.indigoIterateDataSGroups.restype = c_int + Indigo._lib.indigoIterateDataSGroups.argtypes = [c_int] + Indigo._lib.indigoIterateSuperatoms.restype = c_int + Indigo._lib.indigoIterateSuperatoms.argtypes = [c_int] + Indigo._lib.indigoIterateGenericSGroups.restype = c_int + Indigo._lib.indigoIterateGenericSGroups.argtypes = [c_int] + Indigo._lib.indigoIterateRepeatingUnits.restype = c_int + Indigo._lib.indigoIterateRepeatingUnits.argtypes = [c_int] + Indigo._lib.indigoIterateMultipleGroups.restype = c_int + Indigo._lib.indigoIterateMultipleGroups.argtypes = [c_int] + Indigo._lib.indigoIterateSGroups.restype = c_int + Indigo._lib.indigoIterateSGroups.argtypes = [c_int] + Indigo._lib.indigoIterateTGroups.restype = c_int + Indigo._lib.indigoIterateTGroups.argtypes = [c_int] + Indigo._lib.indigoGetSuperatom.restype = c_int + Indigo._lib.indigoGetSuperatom.argtypes = [c_int, c_int] + Indigo._lib.indigoGetDataSGroup.restype = c_int + Indigo._lib.indigoGetDataSGroup.argtypes = [c_int, c_int] + Indigo._lib.indigoGetGenericSGroup.restype = c_int + Indigo._lib.indigoGetGenericSGroup.argtypes = [c_int, c_int] + Indigo._lib.indigoGetMultipleGroup.restype = c_int + Indigo._lib.indigoGetMultipleGroup.argtypes = [c_int, c_int] + Indigo._lib.indigoGetRepeatingUnit.restype = c_int + Indigo._lib.indigoGetRepeatingUnit.argtypes = [c_int, c_int] + Indigo._lib.indigoDescription.restype = c_char_p + Indigo._lib.indigoDescription.argtypes = [c_int] + Indigo._lib.indigoData.restype = c_char_p + Indigo._lib.indigoData.argtypes = [c_int] + Indigo._lib.indigoAddDataSGroup.restype = c_int + Indigo._lib.indigoAddDataSGroup.argtypes = [ + c_int, + c_int, + POINTER(c_int), + c_int, + POINTER(c_int), + c_char_p, + c_char_p, + ] + Indigo._lib.indigoAddSuperatom.restype = c_int + Indigo._lib.indigoAddSuperatom.argtypes = [ + c_int, + c_int, + POINTER(c_int), + c_char_p, + ] + Indigo._lib.indigoSetDataSGroupXY.restype = c_int + Indigo._lib.indigoSetDataSGroupXY.argtypes = [ + c_int, + c_float, + c_float, + c_char_p, + ] + Indigo._lib.indigoSetSGroupData.restype = c_int + Indigo._lib.indigoSetSGroupData.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupCoords.restype = c_int + Indigo._lib.indigoSetSGroupCoords.argtypes = [c_int, c_float, c_float] + Indigo._lib.indigoSetSGroupDescription.restype = c_int + Indigo._lib.indigoSetSGroupDescription.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupFieldName.restype = c_int + Indigo._lib.indigoSetSGroupFieldName.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupQueryCode.restype = c_int + Indigo._lib.indigoSetSGroupQueryCode.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupQueryOper.restype = c_int + Indigo._lib.indigoSetSGroupQueryOper.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupDisplay.restype = c_int + Indigo._lib.indigoSetSGroupDisplay.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupLocation.restype = c_int + Indigo._lib.indigoSetSGroupLocation.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupTag.restype = c_int + Indigo._lib.indigoSetSGroupTag.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupTagAlign.restype = c_int + Indigo._lib.indigoSetSGroupTagAlign.argtypes = [c_int, c_int] + Indigo._lib.indigoSetSGroupDataType.restype = c_int + Indigo._lib.indigoSetSGroupDataType.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupXCoord.restype = c_int + Indigo._lib.indigoSetSGroupXCoord.argtypes = [c_int, c_float] + Indigo._lib.indigoSetSGroupYCoord.restype = c_int + Indigo._lib.indigoSetSGroupYCoord.argtypes = [c_int, c_float] + Indigo._lib.indigoCreateSGroup.restype = c_int + Indigo._lib.indigoCreateSGroup.argtypes = [c_char_p, c_int, c_char_p] + Indigo._lib.indigoSetSGroupClass.restype = c_int + Indigo._lib.indigoSetSGroupClass.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetSGroupName.restype = c_int + Indigo._lib.indigoSetSGroupName.argtypes = [c_int, c_char_p] + Indigo._lib.indigoGetSGroupClass.restype = c_char_p + Indigo._lib.indigoGetSGroupClass.argtypes = [c_int] + Indigo._lib.indigoGetSGroupName.restype = c_char_p + Indigo._lib.indigoGetSGroupName.argtypes = [c_int] + Indigo._lib.indigoGetSGroupNumCrossBonds.restype = c_int + Indigo._lib.indigoGetSGroupNumCrossBonds.argtypes = [c_int] + Indigo._lib.indigoAddSGroupAttachmentPoint.restype = c_int + Indigo._lib.indigoAddSGroupAttachmentPoint.argtypes = [ + c_int, + c_int, + c_int, + c_char_p, + ] + Indigo._lib.indigoDeleteSGroupAttachmentPoint.restype = c_int + Indigo._lib.indigoDeleteSGroupAttachmentPoint.argtypes = [c_int, c_int] + Indigo._lib.indigoGetSGroupDisplayOption.restype = c_int + Indigo._lib.indigoGetSGroupDisplayOption.argtypes = [c_int] + Indigo._lib.indigoSetSGroupDisplayOption.restype = c_int + Indigo._lib.indigoSetSGroupDisplayOption.argtypes = [c_int, c_int] + Indigo._lib.indigoGetSGroupSeqId.restype = c_int + Indigo._lib.indigoGetSGroupSeqId.argtypes = [c_int] + Indigo._lib.indigoGetSGroupCoords.restype = POINTER(c_float) + Indigo._lib.indigoGetSGroupCoords.argtypes = [c_int] + Indigo._lib.indigoGetRepeatingUnitSubscript.restype = c_char_p + Indigo._lib.indigoGetRepeatingUnitSubscript.argtypes = [c_int] + Indigo._lib.indigoGetRepeatingUnitConnectivity.restype = c_int + Indigo._lib.indigoGetRepeatingUnitConnectivity.argtypes = [c_int] + Indigo._lib.indigoGetSGroupMultiplier.restype = c_int + Indigo._lib.indigoGetSGroupMultiplier.argtypes = [c_int] + Indigo._lib.indigoSetSGroupMultiplier.restype = c_int + Indigo._lib.indigoSetSGroupMultiplier.argtypes = [c_int, c_int] + Indigo._lib.indigoSetSGroupBrackets.restype = c_int + Indigo._lib.indigoSetSGroupBrackets.argtypes = [ + c_int, + c_int, + c_float, + c_float, + c_float, + c_float, + c_float, + c_float, + c_float, + c_float, + ] + Indigo._lib.indigoFindSGroups.restype = c_int + Indigo._lib.indigoFindSGroups.argtypes = [c_int, c_char_p, c_char_p] + Indigo._lib.indigoGetSGroupType.restype = c_int + Indigo._lib.indigoGetSGroupType.argtypes = [c_int] + Indigo._lib.indigoGetSGroupIndex.restype = c_int + Indigo._lib.indigoGetSGroupIndex.argtypes = [c_int] + Indigo._lib.indigoGetSGroupOriginalId.restype = c_int + Indigo._lib.indigoGetSGroupOriginalId.argtypes = [c_int] + Indigo._lib.indigoSetSGroupOriginalId.restype = c_int + Indigo._lib.indigoSetSGroupOriginalId.argtypes = [c_int, c_int] + Indigo._lib.indigoGetSGroupParentId.restype = c_int + Indigo._lib.indigoGetSGroupParentId.argtypes = [c_int] + Indigo._lib.indigoSetSGroupParentId.restype = c_int + Indigo._lib.indigoSetSGroupParentId.argtypes = [c_int, c_int] + Indigo._lib.indigoAddTemplate.restype = c_int + Indigo._lib.indigoAddTemplate.argtypes = [c_int, c_int, c_char_p] + Indigo._lib.indigoRemoveTemplate.restype = c_int + Indigo._lib.indigoRemoveTemplate.argtypes = [c_int, c_char_p] + Indigo._lib.indigoFindTemplate.restype = c_int + Indigo._lib.indigoFindTemplate.argtypes = [c_int, c_char_p] + Indigo._lib.indigoGetTGroupClass.restype = c_char_p + Indigo._lib.indigoGetTGroupClass.argtypes = [c_int] + Indigo._lib.indigoGetTGroupName.restype = c_char_p + Indigo._lib.indigoGetTGroupName.argtypes = [c_int] + Indigo._lib.indigoGetTGroupAlias.restype = c_char_p + Indigo._lib.indigoGetTGroupAlias.argtypes = [c_int] + Indigo._lib.indigoTransformSCSRtoCTAB.restype = c_int + Indigo._lib.indigoTransformSCSRtoCTAB.argtypes = [c_int] + Indigo._lib.indigoTransformCTABtoSCSR.restype = c_int + Indigo._lib.indigoTransformCTABtoSCSR.argtypes = [c_int, c_int] + Indigo._lib.indigoTransformHELMtoSCSR.restype = c_int + Indigo._lib.indigoTransformHELMtoSCSR.argtypes = [c_int] + Indigo._lib.indigoGetTemplateAtomClass.restype = c_char_p + Indigo._lib.indigoGetTemplateAtomClass.argtypes = [c_int] + Indigo._lib.indigoSetTemplateAtomClass.restype = c_int + Indigo._lib.indigoSetTemplateAtomClass.argtypes = [c_int, c_char_p] + Indigo._lib.indigoResetCharge.restype = c_int + Indigo._lib.indigoResetCharge.argtypes = [c_int] + Indigo._lib.indigoResetExplicitValence.restype = c_int + Indigo._lib.indigoResetExplicitValence.argtypes = [c_int] + Indigo._lib.indigoResetRadical.restype = c_int + Indigo._lib.indigoResetRadical.argtypes = [c_int] + Indigo._lib.indigoResetIsotope.restype = c_int + Indigo._lib.indigoResetIsotope.argtypes = [c_int] + Indigo._lib.indigoSetAttachmentPoint.restype = c_int + Indigo._lib.indigoSetAttachmentPoint.argtypes = [c_int, c_int] + Indigo._lib.indigoClearAttachmentPoints.restype = c_int + Indigo._lib.indigoClearAttachmentPoints.argtypes = [c_int] + Indigo._lib.indigoRemoveConstraints.restype = c_int + Indigo._lib.indigoRemoveConstraints.argtypes = [c_int, c_char_p] + Indigo._lib.indigoAddConstraint.restype = c_int + Indigo._lib.indigoAddConstraint.argtypes = [c_int, c_char_p, c_char_p] + Indigo._lib.indigoAddConstraintNot.restype = c_int + Indigo._lib.indigoAddConstraintNot.argtypes = [ + c_int, + c_char_p, + c_char_p, + ] + Indigo._lib.indigoAddConstraintOr.restype = c_int + Indigo._lib.indigoAddConstraintOr.argtypes = [ + c_int, + c_char_p, + c_char_p, + ] + Indigo._lib.indigoResetStereo.restype = c_int + Indigo._lib.indigoResetStereo.argtypes = [c_int] + Indigo._lib.indigoInvertStereo.restype = c_int + Indigo._lib.indigoInvertStereo.argtypes = [c_int] + Indigo._lib.indigoCountAtoms.restype = c_int + Indigo._lib.indigoCountAtoms.argtypes = [c_int] + Indigo._lib.indigoCountBonds.restype = c_int + Indigo._lib.indigoCountBonds.argtypes = [c_int] + Indigo._lib.indigoCountPseudoatoms.restype = c_int + Indigo._lib.indigoCountPseudoatoms.argtypes = [c_int] + Indigo._lib.indigoCountRSites.restype = c_int + Indigo._lib.indigoCountRSites.argtypes = [c_int] + Indigo._lib.indigoIterateBonds.restype = c_int + Indigo._lib.indigoIterateBonds.argtypes = [c_int] + Indigo._lib.indigoBondOrder.restype = c_int + Indigo._lib.indigoBondOrder.argtypes = [c_int] + Indigo._lib.indigoBondStereo.restype = c_int + Indigo._lib.indigoBondStereo.argtypes = [c_int] + Indigo._lib.indigoTopology.restype = c_int + Indigo._lib.indigoTopology.argtypes = [c_int] + Indigo._lib.indigoIterateNeighbors.restype = c_int + Indigo._lib.indigoIterateNeighbors.argtypes = [c_int] + Indigo._lib.indigoBond.restype = c_int + Indigo._lib.indigoBond.argtypes = [c_int] + Indigo._lib.indigoGetAtom.restype = c_int + Indigo._lib.indigoGetAtom.argtypes = [c_int, c_int] + Indigo._lib.indigoGetBond.restype = c_int + Indigo._lib.indigoGetBond.argtypes = [c_int, c_int] + Indigo._lib.indigoSource.restype = c_int + Indigo._lib.indigoSource.argtypes = [c_int] + Indigo._lib.indigoDestination.restype = c_int + Indigo._lib.indigoDestination.argtypes = [c_int] + Indigo._lib.indigoClearCisTrans.restype = c_int + Indigo._lib.indigoClearCisTrans.argtypes = [c_int] + Indigo._lib.indigoClearStereocenters.restype = c_int + Indigo._lib.indigoClearStereocenters.argtypes = [c_int] + Indigo._lib.indigoCountStereocenters.restype = c_int + Indigo._lib.indigoCountStereocenters.argtypes = [c_int] + Indigo._lib.indigoClearAlleneCenters.restype = c_int + Indigo._lib.indigoClearAlleneCenters.argtypes = [c_int] + Indigo._lib.indigoCountAlleneCenters.restype = c_int + Indigo._lib.indigoCountAlleneCenters.argtypes = [c_int] + Indigo._lib.indigoResetSymmetricCisTrans.restype = c_int + Indigo._lib.indigoResetSymmetricCisTrans.argtypes = [c_int] + Indigo._lib.indigoResetSymmetricStereocenters.restype = c_int + Indigo._lib.indigoResetSymmetricStereocenters.argtypes = [c_int] + Indigo._lib.indigoMarkEitherCisTrans.restype = c_int + Indigo._lib.indigoMarkEitherCisTrans.argtypes = [c_int] + Indigo._lib.indigoMarkStereobonds.restype = c_int + Indigo._lib.indigoMarkStereobonds.argtypes = [c_int] + Indigo._lib.indigoAddAtom.restype = c_int + Indigo._lib.indigoAddAtom.argtypes = [c_int, c_char_p] + Indigo._lib.indigoResetAtom.restype = c_int + Indigo._lib.indigoResetAtom.argtypes = [c_int, c_char_p] + Indigo._lib.indigoAddRSite.restype = c_int + Indigo._lib.indigoAddRSite.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetRSite.restype = c_int + Indigo._lib.indigoSetRSite.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetCharge.restype = c_int + Indigo._lib.indigoSetCharge.argtypes = [c_int, c_int] + Indigo._lib.indigoSetIsotope.restype = c_int + Indigo._lib.indigoSetIsotope.argtypes = [c_int, c_int] + Indigo._lib.indigoSetImplicitHCount.restype = c_int + Indigo._lib.indigoSetImplicitHCount.argtypes = [c_int, c_int] + Indigo._lib.indigoAddBond.restype = c_int + Indigo._lib.indigoAddBond.argtypes = [c_int, c_int, c_int] + Indigo._lib.indigoSetBondOrder.restype = c_int + Indigo._lib.indigoSetBondOrder.argtypes = [c_int, c_int] + Indigo._lib.indigoMerge.restype = c_int + Indigo._lib.indigoMerge.argtypes = [c_int, c_int] + Indigo._lib.indigoHighlight.restype = c_int + Indigo._lib.indigoHighlight.argtypes = [c_int] + Indigo._lib.indigoUnhighlight.restype = c_int + Indigo._lib.indigoUnhighlight.argtypes = [c_int] + Indigo._lib.indigoIsHighlighted.restype = c_int + Indigo._lib.indigoIsHighlighted.argtypes = [c_int] + Indigo._lib.indigoCountComponents.restype = c_int + Indigo._lib.indigoCountComponents.argtypes = [c_int] + Indigo._lib.indigoComponentIndex.restype = c_int + Indigo._lib.indigoComponentIndex.argtypes = [c_int] + Indigo._lib.indigoIterateComponents.restype = c_int + Indigo._lib.indigoIterateComponents.argtypes = [c_int] + Indigo._lib.indigoComponent.restype = c_int + Indigo._lib.indigoComponent.argtypes = [c_int, c_int] + Indigo._lib.indigoCountSSSR.restype = c_int + Indigo._lib.indigoCountSSSR.argtypes = [c_int] + Indigo._lib.indigoIterateSSSR.restype = c_int + Indigo._lib.indigoIterateSSSR.argtypes = [c_int] + Indigo._lib.indigoIterateSubtrees.restype = c_int + Indigo._lib.indigoIterateSubtrees.argtypes = [c_int, c_int, c_int] + Indigo._lib.indigoIterateRings.restype = c_int + Indigo._lib.indigoIterateRings.argtypes = [c_int, c_int, c_int] + Indigo._lib.indigoIterateEdgeSubmolecules.restype = c_int + Indigo._lib.indigoIterateEdgeSubmolecules.argtypes = [ + c_int, + c_int, + c_int, + ] + Indigo._lib.indigoCountHeavyAtoms.restype = c_int + Indigo._lib.indigoCountHeavyAtoms.argtypes = [c_int] + Indigo._lib.indigoGrossFormula.restype = c_int + Indigo._lib.indigoGrossFormula.argtypes = [c_int] + Indigo._lib.indigoMolecularWeight.restype = c_double + Indigo._lib.indigoMolecularWeight.argtypes = [c_int] + Indigo._lib.indigoMostAbundantMass.restype = c_double + Indigo._lib.indigoMostAbundantMass.argtypes = [c_int] + Indigo._lib.indigoMonoisotopicMass.restype = c_double + Indigo._lib.indigoMonoisotopicMass.argtypes = [c_int] + Indigo._lib.indigoMassComposition.restype = c_char_p + Indigo._lib.indigoMassComposition.argtypes = [c_int] + Indigo._lib.indigoCanonicalSmiles.restype = c_char_p + Indigo._lib.indigoCanonicalSmiles.argtypes = [c_int] + Indigo._lib.indigoCanonicalSmarts.restype = c_char_p + Indigo._lib.indigoCanonicalSmarts.argtypes = [c_int] + Indigo._lib.indigoLayeredCode.restype = c_char_p + Indigo._lib.indigoLayeredCode.argtypes = [c_int] + Indigo._lib.indigoSymmetryClasses.restype = POINTER(c_int) + Indigo._lib.indigoSymmetryClasses.argtypes = [c_int, POINTER(c_int)] + Indigo._lib.indigoHasCoord.restype = c_int + Indigo._lib.indigoHasCoord.argtypes = [c_int] + Indigo._lib.indigoHasZCoord.restype = c_int + Indigo._lib.indigoHasZCoord.argtypes = [c_int] + Indigo._lib.indigoIsChiral.restype = c_int + Indigo._lib.indigoIsChiral.argtypes = [c_int] + Indigo._lib.indigoIsPossibleFischerProjection.restype = c_int + Indigo._lib.indigoIsPossibleFischerProjection.argtypes = [ + c_int, + c_char_p, + ] + Indigo._lib.indigoCreateSubmolecule.restype = c_int + Indigo._lib.indigoCreateSubmolecule.argtypes = [ + c_int, + c_int, + POINTER(c_int), + ] + Indigo._lib.indigoCreateEdgeSubmolecule.restype = c_int + Indigo._lib.indigoCreateEdgeSubmolecule.argtypes = [ + c_int, + c_int, + POINTER(c_int), + c_int, + POINTER(c_int), + ] + Indigo._lib.indigoGetSubmolecule.restype = c_int + Indigo._lib.indigoGetSubmolecule.argtypes = [ + c_int, + c_int, + POINTER(c_int), + ] + Indigo._lib.indigoRemoveAtoms.restype = c_int + Indigo._lib.indigoRemoveAtoms.argtypes = [c_int, c_int, POINTER(c_int)] + Indigo._lib.indigoRemoveBonds.restype = c_int + Indigo._lib.indigoRemoveBonds.argtypes = [c_int, c_int, POINTER(c_int)] + Indigo._lib.indigoAlignAtoms.restype = c_float + Indigo._lib.indigoAlignAtoms.argtypes = [ + c_int, + c_int, + POINTER(c_int), + POINTER(c_float), + ] + Indigo._lib.indigoAromatize.restype = c_int + Indigo._lib.indigoAromatize.argtypes = [c_int] + Indigo._lib.indigoDearomatize.restype = c_int + Indigo._lib.indigoDearomatize.argtypes = [c_int] + Indigo._lib.indigoFoldHydrogens.restype = c_int + Indigo._lib.indigoFoldHydrogens.argtypes = [c_int] + Indigo._lib.indigoUnfoldHydrogens.restype = c_int + Indigo._lib.indigoUnfoldHydrogens.argtypes = [c_int] + Indigo._lib.indigoLayout.restype = c_int + Indigo._lib.indigoLayout.argtypes = [c_int] + Indigo._lib.indigoClean2d.restype = c_int + Indigo._lib.indigoClean2d.argtypes = [c_int] + Indigo._lib.indigoSmiles.restype = c_char_p + Indigo._lib.indigoSmiles.argtypes = [c_int] + Indigo._lib.indigoSmarts.restype = c_char_p + Indigo._lib.indigoSmarts.argtypes = [c_int] + Indigo._lib.indigoName.restype = c_char_p + Indigo._lib.indigoName.argtypes = [c_int] + Indigo._lib.indigoSetName.restype = c_int + Indigo._lib.indigoSetName.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSerialize.restype = c_int + Indigo._lib.indigoSerialize.argtypes = [ + c_int, + POINTER(POINTER(c_byte)), + POINTER(c_int), + ] + Indigo._lib.indigoHasProperty.restype = c_int + Indigo._lib.indigoHasProperty.argtypes = [c_int, c_char_p] + Indigo._lib.indigoGetProperty.restype = c_char_p + Indigo._lib.indigoGetProperty.argtypes = [c_int, c_char_p] + Indigo._lib.indigoSetProperty.restype = c_int + Indigo._lib.indigoSetProperty.argtypes = [c_int, c_char_p, c_char_p] + Indigo._lib.indigoRemoveProperty.restype = c_int + Indigo._lib.indigoRemoveProperty.argtypes = [c_int, c_char_p] + Indigo._lib.indigoIterateProperties.restype = c_int + Indigo._lib.indigoIterateProperties.argtypes = [c_int] + Indigo._lib.indigoClearProperties.restype = c_int + Indigo._lib.indigoClearProperties.argtypes = [c_int] + Indigo._lib.indigoCheckBadValence.restype = c_char_p + Indigo._lib.indigoCheckBadValence.argtypes = [c_int] + Indigo._lib.indigoCheckAmbiguousH.restype = c_char_p + Indigo._lib.indigoCheckAmbiguousH.argtypes = [c_int] + Indigo._lib.indigoCheckChirality.restype = c_int + Indigo._lib.indigoCheckChirality.argtypes = [c_int] + Indigo._lib.indigoCheck3DStereo.restype = c_int + Indigo._lib.indigoCheck3DStereo.argtypes = [c_int] + Indigo._lib.indigoCheckStereo.restype = c_int + Indigo._lib.indigoCheckStereo.argtypes = [c_int] + Indigo._lib.indigoFingerprint.restype = c_int + Indigo._lib.indigoFingerprint.argtypes = [c_int, c_char_p] + Indigo._lib.indigoLoadFingerprintFromBuffer.restype = c_int + Indigo._lib.indigoLoadFingerprintFromBuffer.argtypes = [ + POINTER(c_byte), + c_int, + ] + Indigo._lib.indigoLoadFingerprintFromDescriptors.restype = c_int + Indigo._lib.indigoLoadFingerprintFromDescriptors.argtypes = [ + POINTER(c_double), + c_int, + c_int, + c_double, + ] + Indigo._lib.indigoCountBits.restype = c_int + Indigo._lib.indigoCountBits.argtypes = [c_int] + Indigo._lib.indigoRawData.restype = c_char_p + Indigo._lib.indigoRawData.argtypes = [c_int] + Indigo._lib.indigoTell.restype = c_int + Indigo._lib.indigoTell.argtypes = [c_int] + Indigo._lib.indigoSdfAppend.restype = c_int + Indigo._lib.indigoSdfAppend.argtypes = [c_int, c_int] + Indigo._lib.indigoSmilesAppend.restype = c_int + Indigo._lib.indigoSmilesAppend.argtypes = [c_int, c_int] + Indigo._lib.indigoRdfHeader.restype = c_int + Indigo._lib.indigoRdfHeader.argtypes = [c_int] + Indigo._lib.indigoRdfAppend.restype = c_int + Indigo._lib.indigoRdfAppend.argtypes = [c_int, c_int] + Indigo._lib.indigoCmlHeader.restype = c_int + Indigo._lib.indigoCmlHeader.argtypes = [c_int] + Indigo._lib.indigoCmlAppend.restype = c_int + Indigo._lib.indigoCmlAppend.argtypes = [c_int, c_int] + Indigo._lib.indigoCmlFooter.restype = c_int + Indigo._lib.indigoCmlFooter.argtypes = [c_int] + Indigo._lib.indigoAppend.restype = c_int + Indigo._lib.indigoAppend.argtypes = [c_int, c_int] + Indigo._lib.indigoArrayAdd.restype = c_int + Indigo._lib.indigoArrayAdd.argtypes = [c_int, c_int] + Indigo._lib.indigoAt.restype = c_int + Indigo._lib.indigoAt.argtypes = [c_int, c_int] + Indigo._lib.indigoCount.restype = c_int + Indigo._lib.indigoCount.argtypes = [c_int] + Indigo._lib.indigoClear.restype = c_int + Indigo._lib.indigoClear.argtypes = [c_int] + Indigo._lib.indigoIterateArray.restype = c_int + Indigo._lib.indigoIterateArray.argtypes = [c_int] + Indigo._lib.indigoIgnoreAtom.restype = c_int + Indigo._lib.indigoIgnoreAtom.argtypes = [c_int, c_int] + Indigo._lib.indigoUnignoreAtom.restype = c_int + Indigo._lib.indigoUnignoreAtom.argtypes = [c_int, c_int] + Indigo._lib.indigoUnignoreAllAtoms.restype = c_int + Indigo._lib.indigoUnignoreAllAtoms.argtypes = [c_int] + Indigo._lib.indigoMatch.restype = c_int + Indigo._lib.indigoMatch.argtypes = [c_int, c_int] + Indigo._lib.indigoCountMatches.restype = c_int + Indigo._lib.indigoCountMatches.argtypes = [c_int, c_int] + Indigo._lib.indigoCountMatchesWithLimit.restype = c_int + Indigo._lib.indigoCountMatchesWithLimit.argtypes = [ + c_int, + c_int, + c_int, + ] + Indigo._lib.indigoIterateMatches.restype = c_int + Indigo._lib.indigoIterateMatches.argtypes = [c_int, c_int] + Indigo._lib.indigoHighlightedTarget.restype = c_int + Indigo._lib.indigoHighlightedTarget.argtypes = [c_int] + Indigo._lib.indigoMapAtom.restype = c_int + Indigo._lib.indigoMapAtom.argtypes = [c_int, c_int] + Indigo._lib.indigoMapBond.restype = c_int + Indigo._lib.indigoMapBond.argtypes = [c_int, c_int] + Indigo._lib.indigoMapMolecule.restype = c_int + Indigo._lib.indigoMapMolecule.argtypes = [c_int, c_int] + Indigo._lib.indigoIterateTautomers.restype = c_int + Indigo._lib.indigoIterateTautomers.argtypes = [c_int, c_char_p] + Indigo._lib.indigoAllScaffolds.restype = c_int + Indigo._lib.indigoAllScaffolds.argtypes = [c_int] + Indigo._lib.indigoDecomposedMoleculeScaffold.restype = c_int + Indigo._lib.indigoDecomposedMoleculeScaffold.argtypes = [c_int] + Indigo._lib.indigoIterateDecomposedMolecules.restype = c_int + Indigo._lib.indigoIterateDecomposedMolecules.argtypes = [c_int] + Indigo._lib.indigoDecomposedMoleculeHighlighted.restype = c_int + Indigo._lib.indigoDecomposedMoleculeHighlighted.argtypes = [c_int] + Indigo._lib.indigoDecomposedMoleculeWithRGroups.restype = c_int + Indigo._lib.indigoDecomposedMoleculeWithRGroups.argtypes = [c_int] + Indigo._lib.indigoDecomposeMolecule.restype = c_int + Indigo._lib.indigoDecomposeMolecule.argtypes = [c_int, c_int] + Indigo._lib.indigoIterateDecompositions.restype = c_int + Indigo._lib.indigoIterateDecompositions.argtypes = [c_int] + Indigo._lib.indigoAddDecomposition.restype = c_int + Indigo._lib.indigoAddDecomposition.argtypes = [c_int, c_int] + Indigo._lib.indigoToString.restype = c_char_p + Indigo._lib.indigoToString.argtypes = [c_int] + Indigo._lib.indigoOneBitsList.restype = c_char_p + Indigo._lib.indigoOneBitsList.argtypes = [c_int] + Indigo._lib.indigoToBuffer.restype = c_int + Indigo._lib.indigoToBuffer.argtypes = [ + c_int, + POINTER(POINTER(c_byte)), + POINTER(c_int), + ] + Indigo._lib.indigoStereocenterPyramid.restype = POINTER(c_int) + Indigo._lib.indigoStereocenterPyramid.argtypes = [c_int] + Indigo._lib.indigoExpandAbbreviations.restype = c_int + Indigo._lib.indigoExpandAbbreviations.argtypes = [c_int] + Indigo._lib.indigoDbgInternalType.restype = c_char_p + Indigo._lib.indigoDbgInternalType.argtypes = [c_int] + Indigo._lib.indigoNameToStructure.restype = c_int + Indigo._lib.indigoNameToStructure.argtypes = [c_char_p, c_char_p] + Indigo._lib.indigoResetOptions.restype = c_int + Indigo._lib.indigoResetOptions.argtypes = None + + def __del__(self): + if hasattr(self, "_lib"): + self._lib.indigoReleaseSessionId(self._sid) + + def deserialize(self, arr): + values = (c_byte * len(arr))() + for i in range(len(arr)): + values[i] = arr[i] + self._setSessionId() + res = Indigo._lib.indigoUnserialize(values, len(arr)) + return self.IndigoObject(self, self._checkResult(res)) + + def unserialize(self, arr): + warnings.warn( + "unserialize() is deprecated, use deserialize() instead", + DeprecationWarning, + ) + return self.deserialize(arr) + + def setOption(self, option, value1, value2=None, value3=None): + self._setSessionId() + if ( + ( + type(value1).__name__ == "str" + or type(value1).__name__ == "unicode" + ) + and value2 is None + and value3 is None + ): + self._checkResult( + Indigo._lib.indigoSetOption( + option.encode(ENCODE_ENCODING), + value1.encode(ENCODE_ENCODING), + ) + ) + elif ( + type(value1).__name__ == "int" + and value2 is None + and value3 is None + ): + self._checkResult( + Indigo._lib.indigoSetOptionInt( + option.encode(ENCODE_ENCODING), value1 + ) + ) + elif ( + type(value1).__name__ == "float" + and value2 is None + and value3 is None + ): + self._checkResult( + Indigo._lib.indigoSetOptionFloat( + option.encode(ENCODE_ENCODING), value1 + ) + ) + elif ( + type(value1).__name__ == "bool" + and value2 is None + and value3 is None + ): + value1_b = 0 + if value1: + value1_b = 1 + self._checkResult( + Indigo._lib.indigoSetOptionBool( + option.encode(ENCODE_ENCODING), value1_b + ) + ) + elif ( + type(value1).__name__ == "int" + and value2 + and type(value2).__name__ == "int" + and value3 is None + ): + self._checkResult( + Indigo._lib.indigoSetOptionXY( + option.encode(ENCODE_ENCODING), value1, value2 + ) + ) + elif ( + type(value1).__name__ == "float" + and value2 + and type(value2).__name__ == "float" + and value3 + and type(value3).__name__ == "float" + ): + self._checkResult( + Indigo._lib.indigoSetOptionColor( + option.encode(ENCODE_ENCODING), value1, value2, value3 + ) + ) + else: + raise IndigoException("bad option") + + def getOption(self, option): + self._setSessionId() + return self._checkResultString( + Indigo._lib.indigoGetOption(option.encode(ENCODE_ENCODING)) + ) + + def getOptionInt(self, option): + self._setSessionId() + value = c_int() + self._checkResult( + Indigo._lib.indigoGetOptionInt( + option.encode(ENCODE_ENCODING), pointer(value) + ) + ) + return value.value + + def getOptionBool(self, option): + self._setSessionId() + value = c_int() + self._checkResult( + Indigo._lib.indigoGetOptionBool( + option.encode(ENCODE_ENCODING), pointer(value) + ) + ) + if value.value == 1: + return True + return False + + def getOptionFloat(self, option): + self._setSessionId() + value = c_float() + self._checkResult( + Indigo._lib.indigoGetOptionFloat( + option.encode(ENCODE_ENCODING), pointer(value) + ) + ) + return value.value + + def getOptionType(self, option): + self._setSessionId() + return self._checkResultString( + Indigo._lib.indigoGetOptionType(option.encode(ENCODE_ENCODING)) + ) + + def resetOptions(self): + self._setSessionId() + self._checkResult(Indigo._lib.indigoResetOptions()) + + def _checkResult(self, result): + if result < 0: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return result + + def _checkResultFloat(self, result): + if result < -0.5: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return result + + def _checkResultPtr(self, result): + if result is None: + raise IndigoException(Indigo._lib.indigoGetLastError()) + return result + + def _checkResultString(self, result): + return self._checkResultPtr(result).decode(DECODE_ENCODING) + + def convertToArray(self, iteratable): + if isinstance(iteratable, IndigoObject): + return iteratable + try: + some_object_iterator = iter(iteratable) + res = self.createArray() + for obj in some_object_iterator: + res.arrayAdd(self.convertToArray(obj)) + return res + except TypeError: + raise IndigoException( + "Cannot convert object %s to an array" % (iteratable) + ) + + def dbgBreakpoint(self): + self._setSessionId() + return Indigo._lib.indigoDbgBreakpoint() + + def version(self): + self._setSessionId() + return self._checkResultString(Indigo._lib.indigoVersion()) + + def countReferences(self): + self._setSessionId() + return self._checkResult(Indigo._lib.indigoCountReferences()) + + def writeFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoWriteFile(filename.encode(ENCODE_ENCODING)) + ), + ) + + def writeBuffer(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoWriteBuffer()) + ) + + def createMolecule(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoCreateMolecule()) + ) + + def createQueryMolecule(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoCreateQueryMolecule()) + ) + + def loadMolecule(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadMoleculeFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadMoleculeFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadMoleculeFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadMoleculeFromBuffer(self, data): + """ + Loads molecule from given buffer. Automatically detects input format + + Args: + * buf - byte array + + Usage: + ``` + with open (..), 'rb') as f: + m = indigo.loadMoleculeFromBuffer(f.read()) + ``` + Raises: + Exception if structure format is incorrect + + :: + + Since version 1.3.0 + """ + if sys.version_info[0] < 3: + buf = map(ord, data) + else: + buf = data + values = (c_byte * len(buf))() + for i in range(len(buf)): + values[i] = buf[i] + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadMoleculeFromBuffer(values, len(buf)) + ), + ) + + def loadQueryMolecule(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadQueryMoleculeFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadQueryMoleculeFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadQueryMoleculeFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadSmarts(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadSmartsFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadSmartsFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadSmartsFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadReaction(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadReactionFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadReactionFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadReactionFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadQueryReaction(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadQueryReactionFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadQueryReactionFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadQueryReactionFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadReactionSmarts(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadReactionSmartsFromString( + string.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadReactionSmartsFromFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadReactionSmartsFromFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadStructure(self, structureStr, parameter=None): + self._setSessionId() + parameter = "" if parameter is None else parameter + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadStructureFromString( + structureStr.encode(ENCODE_ENCODING), + parameter.encode(ENCODE_ENCODING), + ) + ), + ) + + def loadStructureFromBuffer(self, structureData, parameter=None): + if sys.version_info[0] < 3: + buf = map(ord, structureData) + else: + buf = structureData + values = (c_byte * len(buf))() + for i in range(len(buf)): + values[i] = buf[i] + self._setSessionId() + parameter = "" if parameter is None else parameter + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadStructureFromBuffer( + values, len(buf), parameter.encode(ENCODE_ENCODING) + ) + ), + ) + + def loadStructureFromFile(self, filename, parameter=None): + self._setSessionId() + parameter = "" if parameter is None else parameter + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadStructureFromFile( + filename.encode(ENCODE_ENCODING), + parameter.encode(ENCODE_ENCODING), + ) + ), + ) + + def checkStructure(self, structure, props=""): + if props is None: + props = "" + self._setSessionId() + return self._checkResultString( + Indigo._lib.indigoCheckStructure( + structure.encode(ENCODE_ENCODING), + props.encode(ENCODE_ENCODING), + ) + ) + + def loadFingerprintFromBuffer(self, buffer): + """Creates a fingerprint from the supplied binary data + + :param buffer: a list of bytes + :return: a fingerprint object + + Since version 1.3.0 + """ + self._setSessionId() + length = len(buffer) + + values = (c_byte * length)() + for i in range(length): + values[i] = buffer[i] + + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadFingerprintFromBuffer(values, length) + ), + ) + + def loadFingerprintFromDescriptors(self, descriptors, size, density): + """Packs a list of molecule descriptors into a fingerprint object + + :param descriptors: list of normalized numbers (roughly) between 0.0 and 1.0 + :param size: size of the fingerprint in bytes + :param density: approximate density of '1's vs `0`s in the fingerprint + :return: a fingerprint object + + Since version 1.3.0 + """ + self._setSessionId() + length = len(descriptors) + + descr_arr = (c_double * length)() + for i in range(length): + descr_arr[i] = descriptors[i] + + result = Indigo._lib.indigoLoadFingerprintFromDescriptors( + descr_arr, length, size, density + ) + return self.IndigoObject(self, self._checkResult(result)) + + def createReaction(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoCreateReaction()) + ) + + def createQueryReaction(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoCreateQueryReaction()) + ) + + def exactMatch(self, item1, item2, flags=""): + if flags is None: + flags = "" + self._setSessionId() + newobj = self._checkResult( + Indigo._lib.indigoExactMatch( + item1.id, item2.id, flags.encode(ENCODE_ENCODING) + ) + ) + if newobj == 0: + return None + else: + return self.IndigoObject(self, newobj, [item1, item2, self]) + + def setTautomerRule(self, id, beg, end): + self._setSessionId() + return self._checkResult( + Indigo._lib.indigoSetTautomerRule( + id, beg.encode(ENCODE_ENCODING), end.encode(ENCODE_ENCODING) + ) + ) + + def removeTautomerRule(self, id): + self._setSessionId() + return self._checkResult(Indigo._lib.indigoRemoveTautomerRule(id)) + + def clearTautomerRules(self): + self._setSessionId() + return self._checkResult(Indigo._lib.indigoClearTautomerRules()) + + def commonBits(self, fingerprint1, fingerprint2): + self._setSessionId() + return self._checkResult( + Indigo._lib.indigoCommonBits(fingerprint1.id, fingerprint2.id) + ) + + def similarity(self, item1, item2, metrics=""): + if metrics is None: + metrics = "" + self._setSessionId() + return self._checkResultFloat( + Indigo._lib.indigoSimilarity( + item1.id, item2.id, metrics.encode(ENCODE_ENCODING) + ) + ) + + def iterateSDFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateSDFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def iterateRDFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateRDFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def iterateSmilesFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateSmilesFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def iterateCMLFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateCMLFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def iterateCDXFile(self, filename): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateCDXFile( + filename.encode(ENCODE_ENCODING) + ) + ), + ) + + def createFileSaver(self, filename, format): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoCreateFileSaver( + filename.encode(ENCODE_ENCODING), + format.encode(ENCODE_ENCODING), + ) + ), + ) + + def createSaver(self, obj, format): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoCreateSaver( + obj.id, format.encode(ENCODE_ENCODING) + ) + ), + ) + + def createArray(self): + self._setSessionId() + return self.IndigoObject( + self, self._checkResult(Indigo._lib.indigoCreateArray()) + ) + + def substructureMatcher(self, target, mode=""): + if mode is None: + mode = "" + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoSubstructureMatcher( + target.id, mode.encode(ENCODE_ENCODING) + ) + ), + target, + ) + + def extractCommonScaffold(self, structures, options=""): + structures = self.convertToArray(structures) + if options is None: + options = "" + self._setSessionId() + newobj = self._checkResult( + Indigo._lib.indigoExtractCommonScaffold( + structures.id, options.encode(ENCODE_ENCODING) + ) + ) + if newobj == 0: + return None + else: + return self.IndigoObject(self, newobj, self) + + def decomposeMolecules(self, scaffold, structures): + structures = self.convertToArray(structures) + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoDecomposeMolecules( + scaffold.id, structures.id + ) + ), + scaffold, + ) + + def rgroupComposition(self, molecule, options=""): + if options is None: + options = "" + self._setSessionId() + newobj = self._checkResult( + Indigo._lib.indigoRGroupComposition( + molecule.id, options.encode(ENCODE_ENCODING) + ) + ) + if newobj == 0: + return None + else: + return self.IndigoObject(self, newobj, self) + + def getFragmentedMolecule(self, elem, options=""): + if options is None: + options = "" + self._setSessionId() + newobj = self._checkResult( + Indigo._lib.indigoGetFragmentedMolecule( + elem.id, options.encode(ENCODE_ENCODING) + ) + ) + if newobj == 0: + return None + else: + return self.IndigoObject(self, newobj, self) + + def createDecomposer(self, scaffold): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult(Indigo._lib.indigoCreateDecomposer(scaffold.id)), + scaffold, + ) + + def reactionProductEnumerate(self, replacedaction, monomers): + self._setSessionId() + monomers = self.convertToArray(monomers) + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoReactionProductEnumerate( + replacedaction.id, monomers.id + ) + ), + replacedaction, + ) + + def transform(self, reaction, monomers): + self._setSessionId() + newobj = self._checkResult( + Indigo._lib.indigoTransform(reaction.id, monomers.id) + ) + if newobj == 0: + return None + else: + return self.IndigoObject(self, newobj, self) + + def loadBuffer(self, buf): + buf = list(buf) + values = (c_byte * len(buf))() + for i in range(len(buf)): + values[i] = buf[i] + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult(Indigo._lib.indigoLoadBuffer(values, len(buf))), + ) + + def loadString(self, string): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoLoadString(string.encode(ENCODE_ENCODING)) + ), + ) + + def iterateSDF(self, reader): + self._setSessionId() + result = self._checkResult(Indigo._lib.indigoIterateSDF(reader.id)) + if not result: + return None + return self.IndigoObject(self, result, reader) + + def iterateSmiles(self, reader): + self._setSessionId() + result = self._checkResult(Indigo._lib.indigoIterateSmiles(reader.id)) + if not result: + return None + return self.IndigoObject(self, result, reader) + + def iterateCML(self, reader): + self._setSessionId() + result = self._checkResult(Indigo._lib.indigoIterateCML(reader.id)) + if not result: + return None + return self.IndigoObject(self, result, reader) + + def iterateCDX(self, reader): + self._setSessionId() + result = self._checkResult(Indigo._lib.indigoIterateCDX(reader.id)) + if not result: + return None + return self.IndigoObject(self, result, reader) + + def iterateRDF(self, reader): + self._setSessionId() + result = self._checkResult(Indigo._lib.indigoIterateRDF(reader.id)) + if not result: + return None + return self.IndigoObject(self, result, reader) + + def iterateTautomers(self, molecule, params): + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoIterateTautomers( + molecule.id, params.encode(ENCODE_ENCODING) + ) + ), + molecule, + ) + + def nameToStructure(self, name, params=None): + """ + Converts a chemical name into a corresponding structure + + Args: + * name - a name to parse + * params - a string containing parsing options or nullptr if no options are changed + + Raises: + Exception if parsing fails or no structure is found + + :: + + Since version 1.3.0 + """ + if params is None: + params = "" + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult( + Indigo._lib.indigoNameToStructure( + name.encode(ENCODE_ENCODING), + params.encode(ENCODE_ENCODING), + ) + ), + ) + + def buildPkaModel(self, level, threshold, filename): + self._setSessionId() + return self._checkResult( + Indigo._lib.indigoBuildPkaModel( + level, threshold, filename.encode(ENCODE_ENCODING) + ) + ) + + def transformHELMtoSCSR(self, item): + """ + :: + + Since version 1.3.0 + """ + self._setSessionId() + return self.IndigoObject( + self, + self._checkResult(Indigo._lib.indigoTransformHELMtoSCSR(item.id)), + ) diff --git a/molscribe/indigo/__pycache__/__init__.cpython-310.pyc b/molscribe/indigo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9abbf5bc0daf5fa2a06c4a247b857e51a48e41ff Binary files /dev/null and b/molscribe/indigo/__pycache__/__init__.cpython-310.pyc differ diff --git a/molscribe/indigo/__pycache__/renderer.cpython-310.pyc b/molscribe/indigo/__pycache__/renderer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab54b2677fd6247a01507359b98b27252a1461e Binary files /dev/null and b/molscribe/indigo/__pycache__/renderer.cpython-310.pyc differ diff --git a/molscribe/indigo/bingo.py b/molscribe/indigo/bingo.py new file mode 100644 index 0000000000000000000000000000000000000000..07c1f2701f31387bd32b2b189fda34f17baadabd --- /dev/null +++ b/molscribe/indigo/bingo.py @@ -0,0 +1,334 @@ +# +# Copyright (C) from 2009 to Present EPAM Systems. +# +# This file is part of Indigo toolkit. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from . import * + + +class BingoException(Exception): + + def __init__(self, value): + self.value = value + + def __str__(self): + if sys.version_info > (3, 0): + return repr(self.value.decode('ascii')) + else: + return repr(self.value) + + +class Bingo(object): + def __init__(self, bingoId, indigo, lib): + self._id = bingoId + self._indigo = indigo + self._lib = lib + self._lib.bingoVersion.restype = c_char_p + self._lib.bingoVersion.argtypes = None + self._lib.bingoCreateDatabaseFile.restype = c_int + self._lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p] + self._lib.bingoLoadDatabaseFile.restype = c_int + self._lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p] + self._lib.bingoCloseDatabase.restype = c_int + self._lib.bingoCloseDatabase.argtypes = [c_int] + self._lib.bingoInsertRecordObj.restype = c_int + self._lib.bingoInsertRecordObj.argtypes = [c_int, c_int] + self._lib.bingoInsertRecordObjWithExtFP.restype = c_int + self._lib.bingoInsertRecordObjWithExtFP.argtypes = [c_int, c_int, c_int] + self._lib.bingoGetRecordObj.restype = c_int + self._lib.bingoGetRecordObj.argtypes = [c_int, c_int] + self._lib.bingoInsertRecordObjWithId.restype = c_int + self._lib.bingoInsertRecordObjWithId.argtypes = [c_int, c_int, c_int] + self._lib.bingoInsertRecordObjWithIdAndExtFP.restype = c_int + self._lib.bingoInsertRecordObjWithIdAndExtFP.argtypes = [c_int, c_int, c_int, c_int] + self._lib.bingoDeleteRecord.restype = c_int + self._lib.bingoDeleteRecord.argtypes = [c_int, c_int] + self._lib.bingoSearchSub.restype = c_int + self._lib.bingoSearchSub.argtypes = [c_int, c_int, c_char_p] + self._lib.bingoSearchExact.restype = c_int + self._lib.bingoSearchExact.argtypes = [c_int, c_int, c_char_p] + self._lib.bingoSearchMolFormula.restype = c_int + self._lib.bingoSearchMolFormula.argtypes = [c_int, c_char_p, c_char_p] + self._lib.bingoSearchSim.restype = c_int + self._lib.bingoSearchSim.argtypes = [c_int, c_int, c_float, c_float, c_char_p] + self._lib.bingoSearchSimWithExtFP.restype = c_int + self._lib.bingoSearchSimWithExtFP.argtypes = [c_int, c_int, c_float, c_float, c_int, c_char_p] + self._lib.bingoSearchSimTopN.restype = c_int + self._lib.bingoSearchSimTopN.argtypes = [c_int, c_int, c_int, c_float, c_char_p] + self._lib.bingoSearchSimTopNWithExtFP.restype = c_int + self._lib.bingoSearchSimTopNWithExtFP.argtypes = [c_int, c_int, c_int, c_float, c_int, c_char_p] + self._lib.bingoEnumerateId.restype = c_int + self._lib.bingoEnumerateId.argtypes = [c_int] + self._lib.bingoNext.restype = c_int + self._lib.bingoNext.argtypes = [c_int] + self._lib.bingoGetCurrentId.restype = c_int + self._lib.bingoGetCurrentId.argtypes = [c_int] + self._lib.bingoGetObject.restype = c_int + self._lib.bingoGetObject.argtypes = [c_int] + self._lib.bingoEndSearch.restype = c_int + self._lib.bingoEndSearch.argtypes = [c_int] + self._lib.bingoGetCurrentSimilarityValue.restype = c_float + self._lib.bingoGetCurrentSimilarityValue.argtypes = [c_int] + self._lib.bingoOptimize.restype = c_int + self._lib.bingoOptimize.argtypes = [c_int] + self._lib.bingoEstimateRemainingResultsCount.restype = c_int + self._lib.bingoEstimateRemainingResultsCount.argtypes = [c_int] + self._lib.bingoEstimateRemainingResultsCountError.restype = c_int + self._lib.bingoEstimateRemainingResultsCountError.argtypes = [c_int] + self._lib.bingoEstimateRemainingTime.restype = c_int + self._lib.bingoEstimateRemainingTime.argtypes = [c_int, POINTER(c_float)] + self._lib.bingoContainersCount.restype = c_int + self._lib.bingoContainersCount.argtypes = [c_int] + self._lib.bingoCellsCount.restype = c_int + self._lib.bingoCellsCount.argtypes = [c_int] + self._lib.bingoCurrentCell.restype = c_int + self._lib.bingoCurrentCell.argtypes = [c_int] + self._lib.bingoMinCell.restype = c_int + self._lib.bingoMinCell.argtypes = [c_int] + self._lib.bingoMaxCell.restype = c_int + self._lib.bingoMaxCell.argtypes = [c_int] + + def __del__(self): + self.close() + + def close(self): + self._indigo._setSessionId() + if self._id >= 0: + Bingo._checkResult(self._indigo, self._lib.bingoCloseDatabase(self._id)) + self._id = -1 + + @staticmethod + def _checkResult(indigo, result): + if result < 0: + raise BingoException(indigo._lib.indigoGetLastError()) + return result + + @staticmethod + def _checkResultPtr (indigo, result): + if result is None: + raise BingoException(indigo._lib.indigoGetLastError()) + return result + + @staticmethod + def _checkResultString (indigo, result): + res = Bingo._checkResultPtr(indigo, result) + if sys.version_info >= (3, 0): + return res.decode('ascii') + else: + return res.encode('ascii') + + @staticmethod + def _getLib(indigo): + if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"): + _lib = CDLL(indigo.dllpath + "/libbingo.so") + elif os.name == 'nt' or platform.system().startswith("CYGWIN"): + _lib = CDLL(indigo.dllpath + "/bingo.dll") + elif platform.mac_ver()[0]: + _lib = CDLL(indigo.dllpath + "/libbingo.dylib") + else: + raise BingoException("unsupported OS: " + os.name) + return _lib + + @staticmethod + def createDatabaseFile(indigo, path, databaseType, options=''): + indigo._setSessionId() + if not options: + options = '' + lib = Bingo._getLib(indigo) + lib.bingoCreateDatabaseFile.restype = c_int + lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p] + return Bingo(Bingo._checkResult(indigo, lib.bingoCreateDatabaseFile(path.encode('ascii'), databaseType.encode('ascii'), options.encode('ascii'))), indigo, lib) + + @staticmethod + def loadDatabaseFile(indigo, path, options=''): + indigo._setSessionId() + if not options: + options = '' + lib = Bingo._getLib(indigo) + lib.bingoLoadDatabaseFile.restype = c_int + lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p] + return Bingo(Bingo._checkResult(indigo, lib.bingoLoadDatabaseFile(path.encode('ascii'), options.encode('ascii'))), indigo, lib) + + def version(self): + self._indigo._setSessionId() + return Bingo._checkResultString(self._indigo, self._lib.bingoVersion()) + + def insert(self, indigoObject, index=None): + self._indigo._setSessionId() + if not index: + return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObj(self._id, indigoObject.id)) + else: + return Bingo._checkResult(self._indigo, + self._lib.bingoInsertRecordObjWithId(self._id, indigoObject.id, index)) + + def insertWithExtFP(self, indigoObject, ext_fp, index=None): + self._indigo._setSessionId() + if not index: + return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObjWithExtFP(self._id, indigoObject.id, ext_fp.id)) + else: + return Bingo._checkResult(self._indigo, + self._lib.bingoInsertRecordObjWithIdAndExtFP(self._id, indigoObject.id, index, ext_fp.id)) + + def delete(self, index): + self._indigo._setSessionId() + Bingo._checkResult(self._indigo, self._lib.bingoDeleteRecord(self._id, index)) + + def searchSub(self, query, options=''): + self._indigo._setSessionId() + if not options: + options = '' + return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchSub(self._id, query.id, options.encode('ascii'))), + self._indigo, self) + + def searchExact(self, query, options=''): + self._indigo._setSessionId() + if not options: + options = '' + return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchExact(self._id, query.id, options.encode('ascii'))), + self._indigo, self) + + def searchSim(self, query, minSim, maxSim, metric='tanimoto'): + self._indigo._setSessionId() + if not metric: + metric = 'tanimoto' + return BingoObject( + Bingo._checkResult(self._indigo, self._lib.bingoSearchSim(self._id, query.id, minSim, maxSim, metric.encode('ascii'))), + self._indigo, self) + + def searchSimWithExtFP(self, query, minSim, maxSim, ext_fp, metric='tanimoto'): + self._indigo._setSessionId() + if not metric: + metric = 'tanimoto' + return BingoObject( + Bingo._checkResult(self._indigo, self._lib.bingoSearchSimWithExtFP(self._id, query.id, minSim, maxSim, ext_fp.id, metric.encode('ascii'))), + self._indigo, self) + + def searchSimTopN(self, query, limit, minSim, metric='tanimoto'): + self._indigo._setSessionId() + if not metric: + metric = 'tanimoto' + return BingoObject( + Bingo._checkResult(self._indigo, self._lib.bingoSearchSimTopN(self._id, query.id, limit, minSim, metric.encode('ascii'))), + self._indigo, self) + + def searchSimTopNWithExtFP(self, query, limit, minSim, ext_fp, metric='tanimoto'): + self._indigo._setSessionId() + if not metric: + metric = 'tanimoto' + return BingoObject( + Bingo._checkResult(self._indigo, self._lib.bingoSearchSimTopNWithExtFP(self._id, query.id, limit, minSim, ext_fp.id, metric.encode('ascii'))), + self._indigo, self) + + def enumerateId(self): + self._indigo._setSessionId() + e = self._lib.bingoEnumerateId(self._id) + result = Bingo._checkResult(self._indigo, e) + return BingoObject(result, self._indigo, self) + + def searchMolFormula(self, query, options=''): + self._indigo._setSessionId() + if not options: + options = '' + return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchMolFormula(self._id, query.encode('ascii'), options.encode('ascii'))), + self._indigo, self) + + def optimize(self): + self._indigo._setSessionId() + Bingo._checkResult(self._indigo, self._lib.bingoOptimize(self._id)) + + def getRecordById (self, id): + self._indigo._setSessionId() + return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._lib.bingoGetRecordObj(self._id, id))) + +class BingoObject(object): + def __init__(self, objId, indigo, bingo): + self._id = objId + self._indigo = indigo + self._bingo = bingo + + def __del__(self): + self.close() + + def close(self): + self._indigo._setSessionId() + if self._id >= 0: + Bingo._checkResult(self._indigo, self._bingo._lib.bingoEndSearch(self._id)) + self._id = -1 + + def next(self): + self._indigo._setSessionId() + return (Bingo._checkResult(self._indigo, self._bingo._lib.bingoNext(self._id)) == 1) + + def getCurrentId(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentId(self._id)) + + def getIndigoObject(self): + self._indigo._setSessionId() + return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetObject(self._id))) + + def getCurrentSimilarityValue(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentSimilarityValue(self._id)) + + def estimateRemainingResultsCount(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCount(self._id)) + + def estimateRemainingResultsCountError(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCountError(self._id)) + + def estimateRemainingTime(self): + self._indigo._setSessionId() + value = c_float() + Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingTime(self._id, pointer(value))) + return value.value + + def containersCount(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoContainersCount(self._id)) + + def cellsCount(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCellsCount(self._id)) + + def currentCell(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCurrentCell(self._id)) + + def minCell(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMinCell(self._id)) + + def maxCell(self): + self._indigo._setSessionId() + return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMaxCell(self._id)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __iter__(self): + return self + + def __next__(self): + next_item = self.next() + if next_item: + return self + raise StopIteration diff --git a/molscribe/indigo/inchi.py b/molscribe/indigo/inchi.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0cfd528ae05ed12af1963675217e1840468a0a --- /dev/null +++ b/molscribe/indigo/inchi.py @@ -0,0 +1,84 @@ +# +# Copyright (C) from 2009 to Present EPAM Systems. +# +# This file is part of Indigo toolkit. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import * + + +class IndigoInchi(object): + def __init__(self, indigo): + self.indigo = indigo + + if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"): + self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.so") + elif os.name == 'nt' or platform.system().startswith("CYGWIN"): + self._lib = CDLL(indigo.dllpath + "\indigo-inchi.dll") + elif platform.mac_ver()[0]: + self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.dylib") + else: + raise IndigoException("unsupported OS: " + os.name) + + self._lib.indigoInchiVersion.restype = c_char_p + self._lib.indigoInchiVersion.argtypes = [] + self._lib.indigoInchiResetOptions.restype = c_int + self._lib.indigoInchiResetOptions.argtypes = [] + self._lib.indigoInchiLoadMolecule.restype = c_int + self._lib.indigoInchiLoadMolecule.argtypes = [c_char_p] + self._lib.indigoInchiGetInchi.restype = c_char_p + self._lib.indigoInchiGetInchi.argtypes = [c_int] + self._lib.indigoInchiGetInchiKey.restype = c_char_p + self._lib.indigoInchiGetInchiKey.argtypes = [c_char_p] + self._lib.indigoInchiGetWarning.restype = c_char_p + self._lib.indigoInchiGetWarning.argtypes = [] + self._lib.indigoInchiGetLog.restype = c_char_p + self._lib.indigoInchiGetLog.argtypes = [] + self._lib.indigoInchiGetAuxInfo.restype = c_char_p + self._lib.indigoInchiGetAuxInfo.argtypes = [] + + def resetOptions(self): + self.indigo._setSessionId() + self.indigo._checkResult(self._lib.indigoInchiResetOptions()) + + def loadMolecule(self, inchi): + self.indigo._setSessionId() + res = self.indigo._checkResult(self._lib.indigoInchiLoadMolecule(inchi.encode('ascii'))) + if res == 0: + return None + return self.indigo.IndigoObject(self.indigo, res) + + def version(self): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiVersion()) + + def getInchi(self, molecule): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiGetInchi(molecule.id)) + + def getInchiKey(self, inchi): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiGetInchiKey(inchi.encode('ascii'))) + + def getWarning(self): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiGetWarning()) + + def getLog(self): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiGetLog()) + + def getAuxInfo(self): + self.indigo._setSessionId() + return self.indigo._checkResultString(self._lib.indigoInchiGetAuxInfo()) diff --git a/molscribe/indigo/lib/Linux/x64/libbingo.so b/molscribe/indigo/lib/Linux/x64/libbingo.so new file mode 100644 index 0000000000000000000000000000000000000000..f4d94061696e1667ed4d83275452647f299a95cf --- /dev/null +++ b/molscribe/indigo/lib/Linux/x64/libbingo.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7b7c6f577a12f13101ad6de1c617c4563eedfeb5892c49eab1c0b177739cbd8 +size 467096 diff --git a/molscribe/indigo/lib/Linux/x64/libindigo-inchi.so b/molscribe/indigo/lib/Linux/x64/libindigo-inchi.so new file mode 100644 index 0000000000000000000000000000000000000000..3d4dfe713e2a35ad618524a48d3788a46457b9ef Binary files /dev/null and b/molscribe/indigo/lib/Linux/x64/libindigo-inchi.so differ diff --git a/molscribe/indigo/lib/Linux/x64/libindigo-renderer.so b/molscribe/indigo/lib/Linux/x64/libindigo-renderer.so new file mode 100644 index 0000000000000000000000000000000000000000..294dba39f14098b56c17d84ab424557171191586 --- /dev/null +++ b/molscribe/indigo/lib/Linux/x64/libindigo-renderer.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d7fcc1f549b32a8003dee0e146fb834de16909d5225dc0a7ff175a7350f27a1 +size 2869440 diff --git a/molscribe/indigo/lib/Linux/x64/libindigo.so b/molscribe/indigo/lib/Linux/x64/libindigo.so new file mode 100644 index 0000000000000000000000000000000000000000..adc3d4180118f920ff6801b5846351a2d3352e0d --- /dev/null +++ b/molscribe/indigo/lib/Linux/x64/libindigo.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1520c77081f1a1c3f7497c000ea37dae10db00512bfe30c392bb061027f61a4 +size 9389128 diff --git a/molscribe/indigo/renderer.py b/molscribe/indigo/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..699c80e905d3c307239cef3c8f6d8c8f9b7dc075 --- /dev/null +++ b/molscribe/indigo/renderer.py @@ -0,0 +1,113 @@ +# +# Copyright (C) from 2009 to Present EPAM Systems. +# +# This file is part of Indigo toolkit. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +from ctypes import CDLL, POINTER, c_char_p, c_int + +from . import IndigoException + + +class IndigoRenderer(object): + def __init__(self, indigo): + self.indigo = indigo + + if ( + os.name == "posix" + and not platform.mac_ver()[0] + and not platform.system().startswith("CYGWIN") + ): + self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.so") + elif os.name == "nt" or platform.system().startswith("CYGWIN"): + self._lib = CDLL(indigo.dllpath + "\indigo-renderer.dll") + elif platform.mac_ver()[0]: + self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.dylib") + else: + raise IndigoException("unsupported OS: " + os.name) + + self._lib.indigoRender.restype = c_int + self._lib.indigoRender.argtypes = [c_int, c_int] + self._lib.indigoRenderToFile.restype = c_int + self._lib.indigoRenderToFile.argtypes = [c_int, c_char_p] + self._lib.indigoRenderGrid.restype = c_int + self._lib.indigoRenderGrid.argtypes = [ + c_int, + POINTER(c_int), + c_int, + c_int, + ] + self._lib.indigoRenderGridToFile.restype = c_int + self._lib.indigoRenderGridToFile.argtypes = [ + c_int, + POINTER(c_int), + c_int, + c_char_p, + ] + self._lib.indigoRenderReset.restype = c_int + self._lib.indigoRenderReset.argtypes = [c_int] + + def renderToBuffer(self, obj): + self.indigo._setSessionId() + wb = self.indigo.writeBuffer() + try: + self.indigo._checkResult(self._lib.indigoRender(obj.id, wb.id)) + return wb.toBuffer() + finally: + wb.dispose() + + def renderToFile(self, obj, filename): + self.indigo._setSessionId() + self.indigo._checkResult( + self._lib.indigoRenderToFile(obj.id, filename.encode("ascii")) + ) + + def renderGridToFile(self, objects, refatoms, ncolumns, filename): + self.indigo._setSessionId() + arr = None + if refatoms: + if len(refatoms) != objects.count(): + raise IndigoException( + "renderGridToFile(): refatoms[] size must be equal to the number of objects" + ) + arr = (c_int * len(refatoms))() + for i in range(len(refatoms)): + arr[i] = refatoms[i] + self.indigo._checkResult( + self._lib.indigoRenderGridToFile( + objects.id, arr, ncolumns, filename.encode("ascii") + ) + ) + + def renderGridToBuffer(self, objects, refatoms, ncolumns): + self.indigo._setSessionId() + arr = None + if refatoms: + if len(refatoms) != objects.count(): + raise IndigoException( + "renderGridToBuffer(): refatoms[] size must be equal to the number of objects" + ) + arr = (c_int * len(refatoms))() + for i in range(len(refatoms)): + arr[i] = refatoms[i] + wb = self.indigo.writeBuffer() + try: + self.indigo._checkResult( + self._lib.indigoRenderGrid(objects.id, arr, ncolumns, wb.id) + ) + return wb.toBuffer() + finally: + wb.dispose() diff --git a/molscribe/inference/__init__.py b/molscribe/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c61feef15b9a94dbb97126be593f0c445f1870c0 --- /dev/null +++ b/molscribe/inference/__init__.py @@ -0,0 +1,4 @@ +from .greedy_search import GreedySearch +from .beam_search import BeamSearch + +__all__ = ["GreedySearch", "BeamSearch"] diff --git a/molscribe/inference/__pycache__/__init__.cpython-310.pyc b/molscribe/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92294d56612388c41f9b5b76420f803dc009d24 Binary files /dev/null and b/molscribe/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/molscribe/inference/__pycache__/beam_search.cpython-310.pyc b/molscribe/inference/__pycache__/beam_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0ea6af141eda69248f3ac82bd447a0e4b59019 Binary files /dev/null and b/molscribe/inference/__pycache__/beam_search.cpython-310.pyc differ diff --git a/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc b/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1927e5d16b3f89ea52bbecf14b2dfedf64bd4405 Binary files /dev/null and b/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc differ diff --git a/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc b/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2917f052652959690237909ac09ab67350d0fa3e Binary files /dev/null and b/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc differ diff --git a/molscribe/inference/beam_search.py b/molscribe/inference/beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..60e7df337f92d51ddceceb7a51070aae92e3db08 --- /dev/null +++ b/molscribe/inference/beam_search.py @@ -0,0 +1,190 @@ +import torch +from .decode_strategy import DecodeStrategy + + +class BeamSearch(DecodeStrategy): + """Generation with beam search. + """ + + def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length, + return_attention, max_length): + super(BeamSearch, self).__init__( + pad, bos, eos, batch_size, beam_size, min_length, return_attention, max_length) + self.beam_size = beam_size + self.n_best = n_best + + # result caching + self.hypotheses = [[] for _ in range(batch_size)] + + # beam state + self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool) + + self._batch_offset = torch.arange(batch_size, dtype=torch.long) + + self.select_indices = None + self.done = False + + def initialize(self, memory_bank, device=None): + """Repeat src objects `beam_size` times. + """ + + def fn_map_state(state, dim): + return torch.repeat_interleave(state, self.beam_size, dim=dim) + + memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0) + if device is None: + device = memory_bank.device + + self.memory_length = memory_bank.size(1) + super().initialize(memory_bank, device) + + self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=device) + self._beam_offset = torch.arange( + 0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device) + self.topk_log_probs = torch.tensor( + [0.0] + [float("-inf")] * (self.beam_size - 1), device=device + ).repeat(self.batch_size) + # buffers for the topk scores and 'backpointer' + self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device) + self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device) + self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device) + + return fn_map_state, memory_bank + + @property + def current_predictions(self): + return self.alive_seq[:, -1] + + @property + def current_backptr(self): + # for testing + return self.select_indices.view(self.batch_size, self.beam_size) + + @property + def batch_offset(self): + return self._batch_offset + + def _pick(self, log_probs): + """Return token decision for a step. + + Args: + log_probs (FloatTensor): (B, vocab_size) + + Returns: + topk_scores (FloatTensor): (B, beam_size) + topk_ids (LongTensor): (B, beam_size) + """ + vocab_size = log_probs.size(-1) + + # Flatten probs into a list of probabilities. + curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size) + topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1) + return topk_scores, topk_ids + + def advance(self, log_probs, attn): + """ + Args: + log_probs: (B * beam_size, vocab_size) + """ + vocab_size = log_probs.size(-1) + + # (non-finished) batch_size + _B = log_probs.shape[0] // self.beam_size + + step = len(self) # alive_seq + self.ensure_min_length(log_probs) + + # Multiply probs by the beam probability + log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) + + curr_length = step + 1 + curr_scores = log_probs / curr_length # avg log_prob + self.topk_scores, self.topk_ids = self._pick(curr_scores) + # topk_scores/topk_ids: (batch_size, beam_size) + + # Recover log probs + torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs) + + # Resolve beam origin and map to batch index flat representation. + self._batch_index = self.topk_ids // vocab_size + self._batch_index += self._beam_offset[:_B].unsqueeze(1) + self.select_indices = self._batch_index.view(_B * self.beam_size) + self.topk_ids.fmod_(vocab_size) # resolve true word ids + + # Append last prediction. + self.alive_seq = torch.cat( + [self.alive_seq.index_select(0, self.select_indices), + self.topk_ids.view(_B * self.beam_size, 1)], -1) + + if self.return_attention: + current_attn = attn.index_select(1, self.select_indices) + if step == 1: + self.alive_attn = current_attn + else: + self.alive_attn = self.alive_attn.index_select( + 1, self.select_indices) + self.alive_attn = torch.cat([self.alive_attn, current_attn], 0) + + self.is_finished = self.topk_ids.eq(self.eos) + self.ensure_max_length() + + def update_finished(self): + _B_old = self.topk_log_probs.shape[0] + step = self.alive_seq.shape[-1] # len(self) + self.topk_log_probs.masked_fill_(self.is_finished, -1e10) + + self.is_finished = self.is_finished.to('cpu') + self.top_beam_finished |= self.is_finished[:, 0].eq(1) + predictions = self.alive_seq.view(_B_old, self.beam_size, step) + attention = ( + self.alive_attn.view( + step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) + if self.alive_attn is not None else None) + non_finished_batch = [] + for i in range(self.is_finished.size(0)): + b = self._batch_offset[i] + finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1) + # Store finished hypothesis for this batch. + for j in finished_hyp: # Beam level: finished beam j in batch i + self.hypotheses[b].append(( + self.topk_scores[i, j], + predictions[i, j, 1:], # Ignore start token + attention[:, i, j, :self.memory_length] + if attention is not None else None)) + # End condition is the top beam finished and we can return + # n_best hypotheses. + finish_flag = self.top_beam_finished[i] != 0 + if finish_flag and len(self.hypotheses[b]) >= self.n_best: + best_hyp = sorted( + self.hypotheses[b], key=lambda x: x[0], reverse=True) + for n, (score, pred, attn) in enumerate(best_hyp): + if n >= self.n_best: + break + self.scores[b].append(score.item()) + self.predictions[b].append(pred) + self.attention[b].append( + attn if attn is not None else []) + else: + non_finished_batch.append(i) + non_finished = torch.tensor(non_finished_batch) + + if len(non_finished) == 0: + self.done = True + return + + _B_new = non_finished.shape[0] + # Remove finished batches for the next step + self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished) + self._batch_offset = self._batch_offset.index_select(0, non_finished) + non_finished = non_finished.to(self.topk_ids.device) + self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished) + self._batch_index = self._batch_index.index_select(0, non_finished) + self.select_indices = self._batch_index.view(_B_new * self.beam_size) + self.alive_seq = predictions.index_select(0, non_finished).view(-1, self.alive_seq.size(-1)) + self.topk_scores = self.topk_scores.index_select(0, non_finished) + self.topk_ids = self.topk_ids.index_select(0, non_finished) + + if self.alive_attn is not None: + inp_seq_len = self.alive_attn.size(-1) + self.alive_attn = attention.index_select(1, non_finished) \ + .view(step - 1, _B_new * self.beam_size, inp_seq_len) diff --git a/molscribe/inference/decode_strategy.py b/molscribe/inference/decode_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..3a92a9d4f69f3f5cf7f5d15c74ccf178351eb48e --- /dev/null +++ b/molscribe/inference/decode_strategy.py @@ -0,0 +1,63 @@ +import torch + + +class DecodeStrategy(object): + def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length, + return_attention=False, return_hidden=False): + self.pad = pad + self.bos = bos + self.eos = eos + + self.batch_size = batch_size + self.parallel_paths = parallel_paths + # result catching + self.predictions = [[] for _ in range(batch_size)] + self.scores = [[] for _ in range(batch_size)] + self.token_scores = [[] for _ in range(batch_size)] + self.attention = [[] for _ in range(batch_size)] + self.hidden = [[] for _ in range(batch_size)] + + self.alive_attn = None + self.alive_hidden = None + + self.min_length = min_length + self.max_length = max_length + + n_paths = batch_size * parallel_paths + self.return_attention = return_attention + self.return_hidden = return_hidden + + self.done = False + + def initialize(self, memory_bank, device=None): + if device is None: + device = torch.device('cpu') + self.alive_seq = torch.full( + [self.batch_size * self.parallel_paths, 1], self.bos, + dtype=torch.long, device=device) + self.is_finished = torch.zeros( + [self.batch_size, self.parallel_paths], + dtype=torch.uint8, device=device) + self.alive_log_token_scores = torch.zeros( + [self.batch_size * self.parallel_paths, 0], + dtype=torch.float, device=device) + + return None, memory_bank + + def __len__(self): + return self.alive_seq.shape[1] + + def ensure_min_length(self, log_probs): + if len(self) <= self.min_length: + log_probs[:, self.eos] = -1e20 # forced non-end + + def ensure_max_length(self): + if len(self) == self.max_length + 1: + self.is_finished.fill_(1) + + def advance(self, log_probs, attn): + raise NotImplementedError() + + def update_finished(self): + raise NotImplementedError + diff --git a/molscribe/inference/greedy_search.py b/molscribe/inference/greedy_search.py new file mode 100644 index 0000000000000000000000000000000000000000..55593ea99643a1163429e3c8263fa8105d653166 --- /dev/null +++ b/molscribe/inference/greedy_search.py @@ -0,0 +1,128 @@ +import torch +from .decode_strategy import DecodeStrategy + + +def sample_with_temperature(logits, sampling_temp, keep_topk): + """Select next tokens randomly from the top k possible next tokens. + + Samples from a categorical distribution over the ``keep_topk`` words using + the category probabilities ``logits / sampling_temp``. + """ + + if sampling_temp == 0.0 or keep_topk == 1: + # argmax + topk_scores, topk_ids = logits.topk(1, dim=-1) + if sampling_temp > 0: + topk_scores /= sampling_temp + else: + logits = torch.div(logits, sampling_temp) + if keep_topk > 0: + top_values, top_indices = torch.topk(logits, keep_topk, dim=1) + kth_best = top_values[:, -1].view([-1, 1]) + kth_best = kth_best.repeat([1, logits.shape[1]]).float() + ignore = torch.lt(logits, kth_best) + logits = logits.masked_fill(ignore, -10000) + + dist = torch.distributions.Multinomial(logits=logits, total_count=1) + topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True) + topk_scores = logits.gather(dim=1, index=topk_ids) + + return topk_ids, topk_scores + + +class GreedySearch(DecodeStrategy): + """Select next tokens randomly from the top k possible next tokens. + """ + + def __init__(self, pad, bos, eos, batch_size, min_length, max_length, + return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1): + super().__init__( + pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden) + self.sampling_temp = sampling_temp + self.keep_topk = keep_topk + self.topk_scores = None + + def initialize(self, memory_bank, device=None): + fn_map_state = None + + if device is None: + device = memory_bank.device + + self.memory_length = memory_bank.size(1) + super().initialize(memory_bank, device) + + self.select_indices = torch.arange( + self.batch_size, dtype=torch.long, device=device) + self.original_batch_idx = torch.arange( + self.batch_size, dtype=torch.long, device=device) + + return fn_map_state, memory_bank + + @property + def current_predictions(self): + return self.alive_seq[:, -1] + + @property + def batch_offset(self): + return self.select_indices + + def _pick(self, log_probs): + """Function used to pick next tokens. + """ + topk_ids, topk_scores = sample_with_temperature( + log_probs, self.sampling_temp, self.keep_topk) + return topk_ids, topk_scores + + def advance(self, log_probs, attn=None, hidden=None, label=None): + """Select next tokens randomly from the top k possible next tokens. + """ + self.ensure_min_length(log_probs) + topk_ids, self.topk_scores = self._pick(log_probs) # log_probs: b x v; topk_ids & self.topk_scores: b x (t=1) + self.is_finished = topk_ids.eq(self.eos) + if label is not None: + label = label.view_as(self.is_finished) + self.is_finished = label.eq(self.eos) + self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1) # b x (l+1) (first element is <bos>; note l = len(self)-1) + self.alive_log_token_scores = torch.cat([self.alive_log_token_scores, self.topk_scores], -1) + + if self.return_attention: + if self.alive_attn is None: + self.alive_attn = attn + else: + self.alive_attn = torch.cat([self.alive_attn, attn], 1) + if self.return_hidden: + if self.alive_hidden is None: + self.alive_hidden = hidden + else: + self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1) # b x l x h + self.ensure_max_length() + + def update_finished(self): + """Finalize scores and predictions.""" + # is_finished indicates the decoder finished generating the sequence. Remove it from the batch and update + # the results. + finished_batches = self.is_finished.view(-1).nonzero() + for b in finished_batches.view(-1): + b_orig = self.original_batch_idx[b] + # scores/predictions/attention are lists, + # (to be compatible with beam-search) + self.scores[b_orig].append(torch.exp(torch.mean(self.alive_log_token_scores[b])).item()) + self.token_scores[b_orig].append(torch.exp(self.alive_log_token_scores[b]).tolist()) + self.predictions[b_orig].append(self.alive_seq[b, 1:]) # skip <bos> + self.attention[b_orig].append( + self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else []) + self.hidden[b_orig].append( + self.alive_hidden[b, :] if self.alive_hidden is not None else []) + self.done = self.is_finished.all() + if self.done: + return + is_alive = ~self.is_finished.view(-1) + self.alive_seq = self.alive_seq[is_alive] + self.alive_log_token_scores = self.alive_log_token_scores[is_alive] + if self.alive_attn is not None: + self.alive_attn = self.alive_attn[is_alive] + if self.alive_hidden is not None: + self.alive_hidden = self.alive_hidden[is_alive] + self.select_indices = is_alive.nonzero().view(-1) + self.original_batch_idx = self.original_batch_idx[is_alive] + # select_indices is equal to original_batch_idx for greedy search? diff --git a/molscribe/interface.py b/molscribe/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9bccfa2fd9029db5970e8fdb54d7e4cf6cd704 --- /dev/null +++ b/molscribe/interface.py @@ -0,0 +1,222 @@ +import argparse +from typing import List + +import cv2 +import torch +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from .dataset import get_transforms +from .model import Encoder, Decoder +from .chemistry import convert_graph_to_smiles +from .tokenizer import get_tokenizer + + +BOND_TYPES = ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"] + + +def safe_load(module, module_states): + def remove_prefix(state_dict): + return {k.replace('module.', ''): v for k, v in state_dict.items()} + missing_keys, unexpected_keys = module.load_state_dict(remove_prefix(module_states), strict=False) + return + + +class MolScribe: + + def __init__(self, model_path, device=None, num_workers=1): + """ + MolScribe Interface + :param model_path: path of the model checkpoint. + :param device: torch device, defaults to be CPU. + :param multiprocessing_enabled: uses multiprocessing to parallelize parts of the inference when enabled, defaults to False. + """ + model_states = torch.load(model_path, map_location=torch.device('cpu')) + args = self._get_args(model_states['args']) + if device is None: + device = torch.device('cpu') + self.device = device + self.tokenizer = get_tokenizer(args) + self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states) + self.transform = get_transforms(args.input_size, augment=False) + self.num_workers = num_workers + + def _get_args(self, args_states=None): + parser = argparse.ArgumentParser() + # Model + parser.add_argument('--encoder', type=str, default='swin_base') + parser.add_argument('--decoder', type=str, default='transformer') + parser.add_argument('--trunc_encoder', action='store_true') # use the hidden states before downsample + parser.add_argument('--no_pretrained', action='store_true') + parser.add_argument('--use_checkpoint', action='store_true', default=True) + parser.add_argument('--dropout', type=float, default=0.5) + parser.add_argument('--embed_dim', type=int, default=256) + parser.add_argument('--enc_pos_emb', action='store_true') + group = parser.add_argument_group("transformer_options") + group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6) + group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256) + group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8) + group.add_argument("--dec_num_queries", type=int, default=128) + group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1) + group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1) + group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0) + parser.add_argument('--continuous_coords', action='store_true') + parser.add_argument('--compute_confidence', action='store_true') + # Data + parser.add_argument('--input_size', type=int, default=384) + parser.add_argument('--vocab_file', type=str, default=None) + parser.add_argument('--coord_bins', type=int, default=64) + parser.add_argument('--sep_xy', action='store_true', default=True) + + args = parser.parse_args([]) + if args_states: + for key, value in args_states.items(): + args.__dict__[key] = value + return args + + def _get_model(self, args, tokenizer, device, states): + encoder = Encoder(args, pretrained=False) + args.encoder_dim = encoder.n_features + decoder = Decoder(args, tokenizer) + + safe_load(encoder, states['encoder']) + safe_load(decoder, states['decoder']) + # print(f"Model loaded from {load_path}") + + encoder.to(device) + decoder.to(device) + encoder.eval() + decoder.eval() + return encoder, decoder + + def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16): + device = self.device + predictions = [] + self.decoder.compute_confidence = return_confidence + + for idx in range(0, len(input_images), batch_size): + batch_images = input_images[idx:idx+batch_size] + images = [self.transform(image=image, keypoints=[])['image'] for image in batch_images] + images = torch.stack(images, dim=0).to(device) + with torch.no_grad(): + features, hiddens = self.encoder(images) + batch_predictions = self.decoder.decode(features, hiddens) + predictions += batch_predictions + + smiles = [pred['chartok_coords']['smiles'] for pred in predictions] + node_coords = [pred['chartok_coords']['coords'] for pred in predictions] + node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions] + edges = [pred['edges'] for pred in predictions] + + smiles_list, molblock_list, r_success = convert_graph_to_smiles( + node_coords, node_symbols, edges, images=input_images, num_workers=self.num_workers) + + outputs = [] + for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions): + #pred_dict = {"smiles": smiles, "molfile": molblock} + #pred_dict = {"smiles": smiles,"original_symbols": pred['chartok_coords']['symbols'], "molfile": molblock} + pred_dict = {"smiles": smiles, "symbols": pred['chartok_coords']['symbols'], "coords": pred['chartok_coords']['coords'], "edges": pred['edges'], "molfile": molblock} + if return_confidence: + pred_dict["confidence"] = pred["overall_score"] + if return_atoms_bonds: + coords = pred['chartok_coords']['coords'] + symbols = pred['chartok_coords']['symbols'] + # get atoms info + atom_list = [] + for i, (symbol, coord) in enumerate(zip(symbols, coords)): + atom_dict = {"atom_symbol": symbol, "x": round(coord[0],3), "y": round(coord[1],3)} + if return_confidence: + atom_dict["confidence"] = pred['chartok_coords']['atom_scores'][i] + atom_list.append(atom_dict) + pred_dict["atoms"] = atom_list + # get bonds info + bond_list = [] + num_atoms = len(symbols) + for i in range(num_atoms-1): + for j in range(i+1, num_atoms): + bond_type_int = pred['edges'][i][j] + if bond_type_int != 0: + bond_type_str = BOND_TYPES[bond_type_int] + bond_dict = {"bond_type": bond_type_str, "endpoint_atoms": (i, j)} + if return_confidence: + bond_dict["confidence"] = pred["edge_scores"][i][j] + bond_list.append(bond_dict) + pred_dict["bonds"] = bond_list + outputs.append(pred_dict) + return outputs + + def predict_image(self, image, return_atoms_bonds=False, return_confidence=False): + return self.predict_images([ + image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0] + + def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False): + input_images = [] + for path in image_files: + image = cv2.imread(path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + input_images.append(image) + return self.predict_images( + input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence) + + def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False): + return self.predict_image_files( + [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0] + + def draw_prediction(self, prediction, image, notebook=False): + if "atoms" not in prediction or "bonds" not in prediction: + raise ValueError("atoms and bonds information are not provided.") + h, w, _ = image.shape + h, w = np.array([h, w]) * 400 / max(h, w) + image = cv2.resize(image, (int(w), int(h))) + fig, ax = plt.subplots(1, 1) + ax.axis('off') + ax.set_xlim(-0.05 * w, w * 1.05) + ax.set_ylim(1.05 * h, -0.05 * h) + plt.imshow(image, alpha=0.) + x = [a['x'] * w for a in prediction['atoms']] + y = [a['y'] * h for a in prediction['atoms']] + markersize = min(w, h) / 3 + plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10) + for i, atom in enumerate(prediction['atoms']): + symbol = atom['atom_symbol'].lstrip('[').rstrip(']') + plt.annotate(symbol, xy=(x[i], y[i]), ha='center', va='center', color='black', zorder=100) + for bond in prediction['bonds']: + u, v = bond['endpoint_atoms'] + x1, y1, x2, y2 = x[u], y[u], x[v], y[v] + bond_type = bond['bond_type'] + if bond_type == 'single': + color = 'tab:green' + ax.plot([x1, x2], [y1, y2], color, linewidth=4) + elif bond_type == 'aromatic': + color = 'tab:purple' + ax.plot([x1, x2], [y1, y2], color, linewidth=4) + elif bond_type == 'double': + color = 'tab:green' + ax.plot([x1, x2], [y1, y2], color=color, linewidth=7) + ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1) + elif bond_type == 'triple': + color = 'tab:green' + x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2 + y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2 + ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9) + ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05) + ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1) + else: + length = 10 + width = 10 + color = 'tab:green' + if bond_type == 'solid wedge': + ax.annotate('', xy=(x1, y1), xytext=(x2, y2), + arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2) + else: + ax.annotate('', xy=(x2, y2), xytext=(x1, y1), + arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2) + fig.tight_layout() + if not notebook: + canvas = FigureCanvasAgg(fig) + canvas.draw() + buf = canvas.buffer_rgba() + result_image = np.asarray(buf) + plt.close(fig) + return result_image diff --git a/molscribe/loss.py b/molscribe/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..24d7dc33c24e5be8d36c9a5b11f504cd9e64b389 --- /dev/null +++ b/molscribe/loss.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from .tokenizer import PAD_ID, MASK, MASK_ID + + +class LabelSmoothingLoss(nn.Module): + """ + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + """ + def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super(LabelSmoothingLoss, self).__init__() + + smoothing_value = label_smoothing / (tgt_vocab_size - 2) + one_hot = torch.full((tgt_vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer('one_hot', one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output, target): + """ + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size + """ + # assuming output is raw logits + # convert to log_probs + log_probs = F.log_softmax(output, dim=-1) + + model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob.scatter_(1, target.unsqueeze(1), self.confidence) + model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) + + # reduction mean or sum? + return F.kl_div(log_probs, model_prob, reduction='batchmean') + + +class SequenceLoss(nn.Module): + + def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[]): + super(SequenceLoss, self).__init__() + if ignore_indices: + ignore_index = ignore_indices[0] + self.ignore_index = ignore_index + self.ignore_indices = ignore_indices + if label_smoothing == 0: + self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean') + else: + self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index) + + def forward(self, output, target): + """ + :param output: [batch, len, vocab] + :param target: [batch, len] + :return: + """ + batch_size, max_len, vocab_size = output.size() + output = output.reshape(-1, vocab_size) + target = target.reshape(-1) + for idx in self.ignore_indices: + if idx != self.ignore_index: + target.masked_fill_((target == idx), self.ignore_index) + loss = self.criterion(output, target) + return loss + + +class GraphLoss(nn.Module): + + def __init__(self): + super(GraphLoss, self).__init__() + weight = torch.ones(7) * 10 + weight[0] = 1 + self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100) + + def forward(self, outputs, targets): + results = {} + if 'coords' in outputs: + pred = outputs['coords'] + max_len = pred.size(1) + target = targets['coords'][:, :max_len] + mask = target.ge(0) + loss = F.l1_loss(pred, target, reduction='none') + results['coords'] = (loss * mask).sum() / mask.sum() + if 'edges' in outputs: + pred = outputs['edges'] + max_len = pred.size(-1) + target = targets['edges'][:, :max_len, :max_len] + results['edges'] = self.criterion(pred, target) + return results + + +class Criterion(nn.Module): + + def __init__(self, args, tokenizer): + super(Criterion, self).__init__() + criterion = {} + for format_ in args.formats: + if format_ == 'edges': + criterion['edges'] = GraphLoss() + else: + if MASK in tokenizer[format_].stoi: + ignore_indices = [PAD_ID, MASK_ID] + else: + ignore_indices = [] + criterion[format_] = SequenceLoss(args.label_smoothing, len(tokenizer[format_]), + ignore_index=PAD_ID, ignore_indices=ignore_indices) + self.criterion = nn.ModuleDict(criterion) + + def forward(self, results, refs): + losses = {} + for format_ in results: + predictions, targets, *_ = results[format_] + loss_ = self.criterion[format_](predictions, targets) + if type(loss_) is dict: + losses.update(loss_) + else: + if loss_.numel() > 1: + loss_ = loss_.mean() + losses[format_] = loss_ + return losses diff --git a/molscribe/model.py b/molscribe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bff85e087d679daf2b25df787950fb72a0c5fcf8 --- /dev/null +++ b/molscribe/model.py @@ -0,0 +1,397 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import timm + +from .utils import FORMAT_INFO, to_device +from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID +from .inference import GreedySearch, BeamSearch +from .transformer import TransformerDecoder, Embeddings + + +class Encoder(nn.Module): + def __init__(self, args, pretrained=False): + super().__init__() + model_name = args.encoder + self.model_name = model_name + if model_name.startswith('resnet'): + self.model_type = 'resnet' + self.cnn = timm.create_model(model_name, pretrained=pretrained) + self.n_features = self.cnn.num_features # encoder_dim + self.cnn.global_pool = nn.Identity() + self.cnn.fc = nn.Identity() + elif model_name.startswith('swin'): + self.model_type = 'swin' + self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False, + use_checkpoint=args.use_checkpoint) + self.n_features = self.transformer.num_features + self.transformer.head = nn.Identity() + elif 'efficientnet' in model_name: + self.model_type = 'efficientnet' + self.cnn = timm.create_model(model_name, pretrained=pretrained) + self.n_features = self.cnn.num_features + self.cnn.global_pool = nn.Identity() + self.cnn.classifier = nn.Identity() + else: + raise NotImplemented + + def swin_forward(self, transformer, x): + x = transformer.patch_embed(x) + if transformer.absolute_pos_embed is not None: + x = x + transformer.absolute_pos_embed + x = transformer.pos_drop(x) + + def layer_forward(layer, x, hiddens): + for blk in layer.blocks: + if not torch.jit.is_scripting() and layer.use_checkpoint: + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + H, W = layer.input_resolution + B, L, C = x.shape + hiddens.append(x.view(B, H, W, C)) + if layer.downsample is not None: + x = layer.downsample(x) + return x, hiddens + + hiddens = [] + for layer in transformer.layers: + x, hiddens = layer_forward(layer, x, hiddens) + x = transformer.norm(x) # B L C + hiddens[-1] = x.view_as(hiddens[-1]) + return x, hiddens + + def forward(self, x, refs=None): + if self.model_type in ['resnet', 'efficientnet']: + features = self.cnn(x) + features = features.permute(0, 2, 3, 1) + hiddens = [] + elif self.model_type == 'swin': + if 'patch' in self.model_name: + features, hiddens = self.swin_forward(self.transformer, x) + else: + features, hiddens = self.transformer(x) + else: + raise NotImplemented + return features, hiddens + + +class TransformerDecoderBase(nn.Module): + + def __init__(self, args): + super().__init__() + self.args = args + + self.enc_trans_layer = nn.Sequential( + nn.Linear(args.encoder_dim, args.dec_hidden_size) + # nn.LayerNorm(args.dec_hidden_size, eps=1e-6) + ) + self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None + + self.decoder = TransformerDecoder( + num_layers=args.dec_num_layers, + d_model=args.dec_hidden_size, + heads=args.dec_attn_heads, + d_ff=args.dec_hidden_size * 4, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=args.hidden_dropout, + attention_dropout=args.attn_dropout, + max_relative_positions=args.max_relative_positions, + aan_useffn=False, + full_context_alignment=False, + alignment_layer=0, + alignment_heads=0, + pos_ffn_activation_fn='gelu' + ) + + def enc_transform(self, encoder_out): + batch_size = encoder_out.size(0) + encoder_dim = encoder_out.size(-1) + encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) + max_len = encoder_out.size(1) + device = encoder_out.device + if self.enc_pos_emb: + pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0) + encoder_out = encoder_out + pos_emb + encoder_out = self.enc_trans_layer(encoder_out) + return encoder_out + + +class TransformerDecoderAR(TransformerDecoderBase): + """Autoregressive Transformer Decoder""" + + def __init__(self, args, tokenizer): + super().__init__(args) + self.tokenizer = tokenizer + self.vocab_size = len(self.tokenizer) + self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True) + self.embeddings = Embeddings( + word_vec_size=args.dec_hidden_size, + word_vocab_size=self.vocab_size, + word_padding_idx=PAD_ID, + position_encoding=True, + dropout=args.hidden_dropout) + + def dec_embedding(self, tgt, step=None): + pad_idx = self.embeddings.word_padding_idx + tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2) # [B, 1, T_tgt] + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # batch x len x embedding_dim + return emb, tgt_pad_mask + + def forward(self, encoder_out, labels, label_lengths): + """Training mode""" + batch_size, max_len, _ = encoder_out.size() + memory_bank = self.enc_transform(encoder_out) + + tgt = labels.unsqueeze(-1) # (b, t, 1) + tgt_emb, tgt_pad_mask = self.dec_embedding(tgt) + dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask) + + logits = self.output_layer(dec_out) # (b, t, h) -> (b, t, v) + return logits[:, :-1], labels[:, 1:], dec_out + + def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256, + labels=None): + """Inference mode. Autoregressively decode the sequence. Only greedy search is supported now. Beam search is + out-dated. The labels is used for partial prediction, i.e. part of the sequence is given. In standard decoding, + labels=None.""" + batch_size, max_len, _ = encoder_out.size() + memory_bank = self.enc_transform(encoder_out) + orig_labels = labels + + if beam_size == 1: + decode_strategy = GreedySearch( + sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length, + pad=PAD_ID, bos=SOS_ID, eos=EOS_ID, + return_attention=False, return_hidden=True) + else: + decode_strategy = BeamSearch( + beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length, + pad=PAD_ID, bos=SOS_ID, eos=EOS_ID, + return_attention=False) + + # adapted from onmt.translate.translator + results = { + "predictions": None, + "scores": None, + "attention": None + } + + # (2) prep decode_strategy. Possibly repeat src objects. + _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank) + + # (3) Begin decoding step by step: + for step in range(decode_strategy.max_length): + tgt = decode_strategy.current_predictions.view(-1, 1, 1) + if labels is not None: + label = labels[:, step].view(-1, 1, 1) + mask = label.eq(MASK_ID).long() + tgt = tgt * mask + label * (1 - mask) + tgt_emb, tgt_pad_mask = self.dec_embedding(tgt) + dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, + tgt_pad_mask=tgt_pad_mask, step=step) + + attn = dec_attn.get("std", None) + + dec_logits = self.output_layer(dec_out) # [b, t, h] => [b, t, v] + dec_logits = dec_logits.squeeze(1) + log_probs = F.log_softmax(dec_logits, dim=-1) + + if self.tokenizer.output_constraint: + output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()] + output_mask = torch.tensor(output_mask, device=log_probs.device) + log_probs.masked_fill_(output_mask, -10000) + + label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None + decode_strategy.advance(log_probs, attn, dec_out, label) + any_finished = decode_strategy.is_finished.any() + if any_finished: + decode_strategy.update_finished() + if decode_strategy.done: + break + + select_indices = decode_strategy.select_indices + if any_finished: + # Reorder states. + memory_bank = memory_bank.index_select(0, select_indices) + if labels is not None: + labels = labels.index_select(0, select_indices) + self.map_state(lambda state, dim: state.index_select(dim, select_indices)) + + results["scores"] = decode_strategy.scores # fixed to be average of token scores + results["token_scores"] = decode_strategy.token_scores + results["predictions"] = decode_strategy.predictions + results["attention"] = decode_strategy.attention + results["hidden"] = decode_strategy.hidden + if orig_labels is not None: + for i in range(batch_size): + pred = results["predictions"][i][0] + label = orig_labels[i][1:len(pred) + 1] + mask = label.eq(MASK_ID).long() + pred = pred[:len(label)] + results["predictions"][i][0] = pred * mask + label * (1 - mask) + + return results["predictions"], results['scores'], results["token_scores"], results["hidden"] + + # adapted from onmt.decoders.transformer + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + if self.decoder.state["cache"] is not None: + _recursive_map(self.decoder.state["cache"]) + + +class GraphPredictor(nn.Module): + + def __init__(self, decoder_dim, coords=False): + super(GraphPredictor, self).__init__() + self.coords = coords + self.mlp = nn.Sequential( + nn.Linear(decoder_dim * 2, decoder_dim), nn.GELU(), + nn.Linear(decoder_dim, 7) + ) + if coords: + self.coords_mlp = nn.Sequential( + nn.Linear(decoder_dim, decoder_dim), nn.GELU(), + nn.Linear(decoder_dim, 2) + ) + + def forward(self, hidden, indices=None): + b, l, dim = hidden.size() + if indices is None: + index = [i for i in range(3, l, 3)] + hidden = hidden[:, index] + else: + batch_id = torch.arange(b).unsqueeze(1).expand_as(indices).reshape(-1) + indices = indices.view(-1) + hidden = hidden[batch_id, indices].view(b, -1, dim) + b, l, dim = hidden.size() + results = {} + hh = torch.cat([hidden.unsqueeze(2).expand(b, l, l, dim), hidden.unsqueeze(1).expand(b, l, l, dim)], dim=3) + results['edges'] = self.mlp(hh).permute(0, 3, 1, 2) + if self.coords: + results['coords'] = self.coords_mlp(hidden) + return results + + +def get_edge_prediction(edge_prob): + if not edge_prob: + return [], [] + n = len(edge_prob) + if n == 0: + return [], [] + for i in range(n): + for j in range(i + 1, n): + for k in range(5): + edge_prob[i][j][k] = (edge_prob[i][j][k] + edge_prob[j][i][k]) / 2 + edge_prob[j][i][k] = edge_prob[i][j][k] + edge_prob[i][j][5] = (edge_prob[i][j][5] + edge_prob[j][i][6]) / 2 + edge_prob[i][j][6] = (edge_prob[i][j][6] + edge_prob[j][i][5]) / 2 + edge_prob[j][i][5] = edge_prob[i][j][6] + edge_prob[j][i][6] = edge_prob[i][j][5] + prediction = np.argmax(edge_prob, axis=2).tolist() + score = np.max(edge_prob, axis=2).tolist() + return prediction, score + + +class Decoder(nn.Module): + """This class is a wrapper for different decoder architectures, and support multiple decoders.""" + + def __init__(self, args, tokenizer): + super(Decoder, self).__init__() + self.args = args + self.formats = args.formats + self.tokenizer = tokenizer + decoder = {} + for format_ in args.formats: + if format_ == 'edges': + decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords) + else: + decoder[format_] = TransformerDecoderAR(args, tokenizer[format_]) + self.decoder = nn.ModuleDict(decoder) + self.compute_confidence = args.compute_confidence + + def forward(self, encoder_out, hiddens, refs): + """Training mode. Compute the logits with teacher forcing.""" + results = {} + refs = to_device(refs, encoder_out.device) + for format_ in self.formats: + if format_ == 'edges': + if 'atomtok_coords' in results: + dec_out = results['atomtok_coords'][2] + predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0]) + elif 'chartok_coords' in results: + dec_out = results['chartok_coords'][2] + predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0]) + else: + raise NotImplemented + targets = {'edges': refs['edges']} + if 'coords' in predictions: + targets['coords'] = refs['coords'] + results['edges'] = (predictions, targets) + else: + labels, label_lengths = refs[format_] + results[format_] = self.decoder[format_](encoder_out, labels, label_lengths) + return results + + def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1): + """Inference mode. Call each decoder's decode method (if required), convert the output format (e.g. token to + sequence). Beam search is not supported yet.""" + results = {} + predictions = [] + for format_ in self.formats: + if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']: + max_len = FORMAT_INFO[format_]['max_len'] + results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len) + outputs, scores, token_scores, *_ = results[format_] + beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred] + for pred in outputs] + predictions = [{format_: pred[0]} for pred in beam_preds] + if self.compute_confidence: + for i in range(len(predictions)): + # -1: y score, -2: x score, -3: symbol score + indices = np.array(predictions[i][format_]['indices']) - 3 + if format_ == 'chartok_coords': + atom_scores = [] + for symbol, index in zip(predictions[i][format_]['symbols'], indices): + atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1]) + ** (1 / len(symbol))).item() + atom_scores.append(atom_score) + else: + atom_scores = np.array(token_scores[i][0])[indices].tolist() + predictions[i][format_]['atom_scores'] = atom_scores + predictions[i][format_]['average_token_score'] = scores[i][0] + if format_ == 'edges': + if 'atomtok_coords' in results: + atom_format = 'atomtok_coords' + elif 'chartok_coords' in results: + atom_format = 'chartok_coords' + else: + raise NotImplemented + dec_out = results[atom_format][3] # batch x n_best x len x dim + for i in range(len(dec_out)): + hidden = dec_out[i][0].unsqueeze(0) # 1 * len * dim + indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0) # 1 * k + pred = self.decoder['edges'](hidden, indices) # k * k + prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist() # k * k * 7 + edge_pred, edge_score = get_edge_prediction(prob) + predictions[i]['edges'] = edge_pred + if self.compute_confidence: + predictions[i]['edge_scores'] = edge_score + predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item() + predictions[i]['overall_score'] = predictions[i][atom_format]['average_token_score'] * \ + predictions[i]['edge_score_product'] + predictions[i][atom_format].pop('average_token_score') + predictions[i].pop('edge_score_product') + return predictions diff --git a/molscribe/tokenizer.py b/molscribe/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6789ab09433ad13e6f61512cfe6a6f0f307ff75d --- /dev/null +++ b/molscribe/tokenizer.py @@ -0,0 +1,524 @@ +import os +import json +import random +import numpy as np +from SmilesPE.pretokenizer import atomwise_tokenizer + +PAD = '<pad>' +SOS = '<sos>' +EOS = '<eos>' +UNK = '<unk>' +MASK = '<mask>' +PAD_ID = 0 +SOS_ID = 1 +EOS_ID = 2 +UNK_ID = 3 +MASK_ID = 4 + + +class Tokenizer(object): + + def __init__(self, path=None): + self.stoi = {} + self.itos = {} + if path: + self.load(path) + + def __len__(self): + return len(self.stoi) + + @property + def output_constraint(self): + return False + + def save(self, path): + with open(path, 'w') as f: + json.dump(self.stoi, f) + + def load(self, path): + with open(path) as f: + self.stoi = json.load(f) + self.itos = {item[1]: item[0] for item in self.stoi.items()} + + def fit_on_texts(self, texts): + vocab = set() + for text in texts: + vocab.update(text.split(' ')) + vocab = [PAD, SOS, EOS, UNK] + list(vocab) + for i, s in enumerate(vocab): + self.stoi[s] = i + self.itos = {item[1]: item[0] for item in self.stoi.items()} + assert self.stoi[PAD] == PAD_ID + assert self.stoi[SOS] == SOS_ID + assert self.stoi[EOS] == EOS_ID + assert self.stoi[UNK] == UNK_ID + + def text_to_sequence(self, text, tokenized=True): + sequence = [] + sequence.append(self.stoi['<sos>']) + if tokenized: + tokens = text.split(' ') + else: + tokens = atomwise_tokenizer(text) + for s in tokens: + if s not in self.stoi: + s = '<unk>' + sequence.append(self.stoi[s]) + sequence.append(self.stoi['<eos>']) + return sequence + + def texts_to_sequences(self, texts): + sequences = [] + for text in texts: + sequence = self.text_to_sequence(text) + sequences.append(sequence) + return sequences + + def sequence_to_text(self, sequence): + return ''.join(list(map(lambda i: self.itos[i], sequence))) + + def sequences_to_texts(self, sequences): + texts = [] + for sequence in sequences: + text = self.sequence_to_text(sequence) + texts.append(text) + return texts + + def predict_caption(self, sequence): + caption = '' + for i in sequence: + if i == self.stoi['<eos>'] or i == self.stoi['<pad>']: + break + caption += self.itos[i] + return caption + + def predict_captions(self, sequences): + captions = [] + for sequence in sequences: + caption = self.predict_caption(sequence) + captions.append(caption) + return captions + + def sequence_to_smiles(self, sequence): + return {'smiles': self.predict_caption(sequence)} + + +class NodeTokenizer(Tokenizer): + + def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False): + super().__init__(path) + self.maxx = input_size # height + self.maxy = input_size # width + self.sep_xy = sep_xy + self.special_tokens = [PAD, SOS, EOS, UNK, MASK] + self.continuous_coords = continuous_coords + self.debug = debug + + def __len__(self): + if self.sep_xy: + return self.offset + self.maxx + self.maxy + else: + return self.offset + max(self.maxx, self.maxy) + + @property + def offset(self): + return len(self.stoi) + + @property + def output_constraint(self): + return not self.continuous_coords + + def len_symbols(self): + return len(self.stoi) + + def fit_atom_symbols(self, atoms): + vocab = self.special_tokens + list(set(atoms)) + for i, s in enumerate(vocab): + self.stoi[s] = i + assert self.stoi[PAD] == PAD_ID + assert self.stoi[SOS] == SOS_ID + assert self.stoi[EOS] == EOS_ID + assert self.stoi[UNK] == UNK_ID + assert self.stoi[MASK] == MASK_ID + self.itos = {item[1]: item[0] for item in self.stoi.items()} + + def is_x(self, x): + return self.offset <= x < self.offset + self.maxx + + def is_y(self, y): + if self.sep_xy: + return self.offset + self.maxx <= y + return self.offset <= y + + def is_symbol(self, s): + return len(self.special_tokens) <= s < self.offset or s == UNK_ID + + def is_atom(self, id): + if self.is_symbol(id): + return self.is_atom_token(self.itos[id]) + return False + + def is_atom_token(self, token): + return token.isalpha() or token.startswith("[") or token == '*' or token == UNK + + def x_to_id(self, x): + return self.offset + round(x * (self.maxx - 1)) + + def y_to_id(self, y): + if self.sep_xy: + return self.offset + self.maxx + round(y * (self.maxy - 1)) + return self.offset + round(y * (self.maxy - 1)) + + def id_to_x(self, id): + return (id - self.offset) / (self.maxx - 1) + + def id_to_y(self, id): + if self.sep_xy: + return (id - self.offset - self.maxx) / (self.maxy - 1) + return (id - self.offset) / (self.maxy - 1) + + def get_output_mask(self, id): + mask = [False] * len(self) + if self.continuous_coords: + return mask + if self.is_atom(id): + return [True] * self.offset + [False] * self.maxx + [True] * self.maxy + if self.is_x(id): + return [True] * (self.offset + self.maxx) + [False] * self.maxy + if self.is_y(id): + return [False] * self.offset + [True] * (self.maxx + self.maxy) + return mask + + def symbol_to_id(self, symbol): + if symbol not in self.stoi: + return UNK_ID + return self.stoi[symbol] + + def symbols_to_labels(self, symbols): + labels = [] + for symbol in symbols: + labels.append(self.symbol_to_id(symbol)) + return labels + + def labels_to_symbols(self, labels): + symbols = [] + for label in labels: + symbols.append(self.itos[label]) + return symbols + + def nodes_to_grid(self, nodes): + coords, symbols = nodes['coords'], nodes['symbols'] + grid = np.zeros((self.maxx, self.maxy), dtype=int) + for [x, y], symbol in zip(coords, symbols): + x = round(x * (self.maxx - 1)) + y = round(y * (self.maxy - 1)) + grid[x][y] = self.symbol_to_id(symbol) + return grid + + def grid_to_nodes(self, grid): + coords, symbols, indices = [], [], [] + for i in range(self.maxx): + for j in range(self.maxy): + if grid[i][j] != 0: + x = i / (self.maxx - 1) + y = j / (self.maxy - 1) + coords.append([x, y]) + symbols.append(self.itos[grid[i][j]]) + indices.append([i, j]) + return {'coords': coords, 'symbols': symbols, 'indices': indices} + + def nodes_to_sequence(self, nodes): + coords, symbols = nodes['coords'], nodes['symbols'] + labels = [SOS_ID] + for (x, y), symbol in zip(coords, symbols): + assert 0 <= x <= 1 + assert 0 <= y <= 1 + labels.append(self.x_to_id(x)) + labels.append(self.y_to_id(y)) + labels.append(self.symbol_to_id(symbol)) + labels.append(EOS_ID) + return labels + + def sequence_to_nodes(self, sequence): + coords, symbols = [], [] + i = 0 + if sequence[0] == SOS_ID: + i += 1 + while i + 2 < len(sequence): + if sequence[i] == EOS_ID: + break + if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]): + x = self.id_to_x(sequence[i]) + y = self.id_to_y(sequence[i+1]) + symbol = self.itos[sequence[i+2]] + coords.append([x, y]) + symbols.append(symbol) + i += 3 + return {'coords': coords, 'symbols': symbols} + + def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False): + tokens = atomwise_tokenizer(smiles) + labels = [SOS_ID] + indices = [] + atom_idx = -1 + for token in tokens: + if atom_only and not self.is_atom_token(token): + continue + if token in self.stoi: + labels.append(self.stoi[token]) + else: + if self.debug: + print(f'{token} not in vocab') + labels.append(UNK_ID) + if self.is_atom_token(token): + atom_idx += 1 + if not self.continuous_coords: + if mask_ratio > 0 and random.random() < mask_ratio: + labels.append(MASK_ID) + labels.append(MASK_ID) + elif coords is not None: + if atom_idx < len(coords): + x, y = coords[atom_idx] + assert 0 <= x <= 1 + assert 0 <= y <= 1 + else: + x = random.random() + y = random.random() + labels.append(self.x_to_id(x)) + labels.append(self.y_to_id(y)) + indices.append(len(labels) - 1) + labels.append(EOS_ID) + return labels, indices + + def sequence_to_smiles(self, sequence): + has_coords = not self.continuous_coords + smiles = '' + coords, symbols, indices = [], [], [] + for i, label in enumerate(sequence): + if label == EOS_ID or label == PAD_ID: + break + if self.is_x(label) or self.is_y(label): + continue + token = self.itos[label] + smiles += token + if self.is_atom_token(token): + if has_coords: + if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]): + x = self.id_to_x(sequence[i+1]) + y = self.id_to_y(sequence[i+2]) + coords.append([x, y]) + symbols.append(token) + indices.append(i+3) + else: + if i+1 < len(sequence): + symbols.append(token) + indices.append(i+1) + results = {'smiles': smiles, 'symbols': symbols, 'indices': indices} + if has_coords: + results['coords'] = coords + return results + + +class CharTokenizer(NodeTokenizer): + + def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False): + super().__init__(input_size, path, sep_xy, continuous_coords, debug) + + def fit_on_texts(self, texts): + vocab = set() + for text in texts: + vocab.update(list(text)) + if ' ' in vocab: + vocab.remove(' ') + vocab = [PAD, SOS, EOS, UNK] + list(vocab) + for i, s in enumerate(vocab): + self.stoi[s] = i + self.itos = {item[1]: item[0] for item in self.stoi.items()} + assert self.stoi[PAD] == PAD_ID + assert self.stoi[SOS] == SOS_ID + assert self.stoi[EOS] == EOS_ID + assert self.stoi[UNK] == UNK_ID + + def text_to_sequence(self, text, tokenized=True): + sequence = [] + sequence.append(self.stoi['<sos>']) + if tokenized: + tokens = text.split(' ') + assert all(len(s) == 1 for s in tokens) + else: + tokens = list(text) + for s in tokens: + if s not in self.stoi: + s = '<unk>' + sequence.append(self.stoi[s]) + sequence.append(self.stoi['<eos>']) + return sequence + + def fit_atom_symbols(self, atoms): + atoms = list(set(atoms)) + chars = [] + for atom in atoms: + chars.extend(list(atom)) + vocab = self.special_tokens + chars + for i, s in enumerate(vocab): + self.stoi[s] = i + assert self.stoi[PAD] == PAD_ID + assert self.stoi[SOS] == SOS_ID + assert self.stoi[EOS] == EOS_ID + assert self.stoi[UNK] == UNK_ID + assert self.stoi[MASK] == MASK_ID + self.itos = {item[1]: item[0] for item in self.stoi.items()} + + def get_output_mask(self, id): + ''' TO FIX ''' + mask = [False] * len(self) + if self.continuous_coords: + return mask + if self.is_x(id): + return [True] * (self.offset + self.maxx) + [False] * self.maxy + if self.is_y(id): + return [False] * self.offset + [True] * (self.maxx + self.maxy) + return mask + + def nodes_to_sequence(self, nodes): + coords, symbols = nodes['coords'], nodes['symbols'] + labels = [SOS_ID] + for (x, y), symbol in zip(coords, symbols): + assert 0 <= x <= 1 + assert 0 <= y <= 1 + labels.append(self.x_to_id(x)) + labels.append(self.y_to_id(y)) + for char in symbol: + labels.append(self.symbol_to_id(char)) + labels.append(EOS_ID) + return labels + + def sequence_to_nodes(self, sequence): + coords, symbols = [], [] + i = 0 + if sequence[0] == SOS_ID: + i += 1 + while i < len(sequence): + if sequence[i] == EOS_ID: + break + if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]): + x = self.id_to_x(sequence[i]) + y = self.id_to_y(sequence[i+1]) + for j in range(i+2, len(sequence)): + if not self.is_symbol(sequence[j]): + break + symbol = ''.join(self.itos(sequence[k]) for k in range(i+2, j)) + coords.append([x, y]) + symbols.append(symbol) + i = j + else: + i += 1 + return {'coords': coords, 'symbols': symbols} + + def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False): + tokens = atomwise_tokenizer(smiles) + labels = [SOS_ID] + indices = [] + atom_idx = -1 + for token in tokens: + if atom_only and not self.is_atom_token(token): + continue + for c in token: + if c in self.stoi: + labels.append(self.stoi[c]) + else: + if self.debug: + print(f'{c} not in vocab') + labels.append(UNK_ID) + if self.is_atom_token(token): + atom_idx += 1 + if not self.continuous_coords: + if mask_ratio > 0 and random.random() < mask_ratio: + labels.append(MASK_ID) + labels.append(MASK_ID) + elif coords is not None: + if atom_idx < len(coords): + x, y = coords[atom_idx] + assert 0 <= x <= 1 + assert 0 <= y <= 1 + else: + x = random.random() + y = random.random() + labels.append(self.x_to_id(x)) + labels.append(self.y_to_id(y)) + indices.append(len(labels) - 1) + labels.append(EOS_ID) + return labels, indices + + def sequence_to_smiles(self, sequence): + has_coords = not self.continuous_coords + smiles = '' + coords, symbols, indices = [], [], [] + i = 0 + while i < len(sequence): + label = sequence[i] + if label == EOS_ID or label == PAD_ID: + break + if self.is_x(label) or self.is_y(label): + i += 1 + continue + if not self.is_atom(label): + smiles += self.itos[label] + i += 1 + continue + if self.itos[label] == '[': + j = i + 1 + while j < len(sequence): + if not self.is_symbol(sequence[j]): + break + if self.itos[sequence[j]] == ']': + j += 1 + break + j += 1 + else: + if i+1 < len(sequence) and (self.itos[label] == 'C' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'l' \ + or self.itos[label] == 'B' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'r'): + j = i+2 + else: + j = i+1 + token = ''.join(self.itos[sequence[k]] for k in range(i, j)) + smiles += token + if has_coords: + if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]): + x = self.id_to_x(sequence[j]) + y = self.id_to_y(sequence[j+1]) + coords.append([x, y]) + symbols.append(token) + indices.append(j+2) + i = j+2 + else: + i = j + else: + if j < len(sequence): + symbols.append(token) + indices.append(j) + i = j + results = {'smiles': smiles, 'symbols': symbols, 'indices': indices} + if has_coords: + results['coords'] = coords + return results + + +def get_tokenizer(args): + tokenizer = {} + for format_ in args.formats: + if format_ == 'atomtok': + if args.vocab_file is None: + args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json') + tokenizer['atomtok'] = Tokenizer(args.vocab_file) + elif format_ == "atomtok_coords": + if args.vocab_file is None: + args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json') + tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy, + continuous_coords=args.continuous_coords) + elif format_ == "chartok_coords": + if args.vocab_file is None: + args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json') + tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy, + continuous_coords=args.continuous_coords) + return tokenizer diff --git a/molscribe/transformer/__init__.py b/molscribe/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c953b157c20f4c8bd13345f0a846fd70e0815e --- /dev/null +++ b/molscribe/transformer/__init__.py @@ -0,0 +1,3 @@ +from .decoder import TransformerDecoder +from .embedding import Embeddings +from .swin_transformer import swin_base, swin_large diff --git a/molscribe/transformer/__pycache__/__init__.cpython-310.pyc b/molscribe/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be9d5f68f31ab7d5ca635c0dd5673c6bcc972fb7 Binary files /dev/null and b/molscribe/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/molscribe/transformer/__pycache__/decoder.cpython-310.pyc b/molscribe/transformer/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5337fa6bddadfa70e2bc8044957db3d1e4d743d1 Binary files /dev/null and b/molscribe/transformer/__pycache__/decoder.cpython-310.pyc differ diff --git a/molscribe/transformer/__pycache__/embedding.cpython-310.pyc b/molscribe/transformer/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45ddffec0f204271e500302663bcd573481ce1ad Binary files /dev/null and b/molscribe/transformer/__pycache__/embedding.cpython-310.pyc differ diff --git a/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc b/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9651e7961967be7476e853abd4cb27868dd4eb96 Binary files /dev/null and b/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc differ diff --git a/molscribe/transformer/decoder.py b/molscribe/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a04a96aa4c472fd00d3e8a9d470bd62d380b3202 --- /dev/null +++ b/molscribe/transformer/decoder.py @@ -0,0 +1,487 @@ +""" +Implementation of "Attention is All You Need" and of +subsequent transformer based architectures +""" + +import torch +import torch.nn as nn + +from onmt.decoders.decoder import DecoderBase +from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules.position_ffn import PositionwiseFeedForward +from onmt.modules.position_ffn import ActivationFunction +from onmt.utils.misc import sequence_mask + + +class TransformerDecoderLayerBase(nn.Module): + def __init__( + self, + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="scaled-dot", + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=False, + alignment_heads=0, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + """ + Args: + d_model (int): the dimension of keys/values/queries in + :class:`MultiHeadedAttention`, also the input size of + the first-layer of the :class:`PositionwiseFeedForward`. + heads (int): the number of heads for MultiHeadedAttention. + d_ff (int): the second-layer of the + :class:`PositionwiseFeedForward`. + dropout (float): dropout in residual, self-attn(dot) and + feed-forward + attention_dropout (float): dropout in context_attn (and + self-attn(avg)) + self_attn_type (string): type of self-attention scaled-dot, + average + max_relative_positions (int): + Max distance between inputs in relative positions + representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for + alignment + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + pos_ffn_activation_fn (ActivationFunction): + activation function choice for PositionwiseFeedForward layer + + """ + super(TransformerDecoderLayerBase, self).__init__() + + if self_attn_type == "scaled-dot": + self.self_attn = MultiHeadedAttention( + heads, + d_model, + dropout=attention_dropout, + max_relative_positions=max_relative_positions, + ) + elif self_attn_type == "average": + self.self_attn = AverageAttention( + d_model, dropout=attention_dropout, aan_useffn=aan_useffn + ) + + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, + pos_ffn_activation_fn + ) + self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) + self.drop = nn.Dropout(dropout) + self.full_context_alignment = full_context_alignment + self.alignment_heads = alignment_heads + + def forward(self, *args, **kwargs): + """Extend `_forward` for (possibly) multiple decoder pass: + Always a default (future masked) decoder forward pass, + Possibly a second future aware decoder pass for joint learn + full context alignement, :cite:`garg2019jointly`. + + Args: + * All arguments of _forward. + with_align (bool): whether return alignment attention. + + Returns: + (FloatTensor, FloatTensor, FloatTensor or None): + + * output ``(batch_size, T, model_dim)`` + * top_attn ``(batch_size, T, src_len)`` + * attn_align ``(batch_size, T, src_len)`` or None + """ + with_align = kwargs.pop("with_align", False) + output, attns = self._forward(*args, **kwargs) + top_attn = attns[:, 0, :, :].contiguous() + attn_align = None + if with_align: + if self.full_context_alignment: + # return _, (B, Q_len, K_len) + _, attns = self._forward(*args, **kwargs, future=True) + + if self.alignment_heads > 0: + attns = attns[:, : self.alignment_heads, :, :].contiguous() + # layer average attention across heads, get ``(B, Q, K)`` + # Case 1: no full_context, no align heads -> layer avg baseline + # Case 2: no full_context, 1 align heads -> guided align + # Case 3: full_context, 1 align heads -> full cte guided align + attn_align = attns.mean(dim=1) + return output, top_attn, attn_align + + def update_dropout(self, dropout, attention_dropout): + self.self_attn.update_dropout(attention_dropout) + self.feed_forward.update_dropout(dropout) + self.drop.p = dropout + + def _forward(self, *args, **kwargs): + raise NotImplementedError + + def _compute_dec_mask(self, tgt_pad_mask, future): + tgt_len = tgt_pad_mask.size(-1) + if not future: # apply future_mask, result mask in (B, T, T) + future_mask = torch.ones( + [tgt_len, tgt_len], + device=tgt_pad_mask.device, + dtype=torch.uint8, + ) + future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) + # BoolTensor was introduced in pytorch 1.2 + try: + future_mask = future_mask.bool() + except AttributeError: + pass + dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) + else: # only mask padding, result mask in (B, 1, T) + dec_mask = tgt_pad_mask + return dec_mask + + def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step): + if isinstance(self.self_attn, MultiHeadedAttention): + return self.self_attn( + inputs_norm, + inputs_norm, + inputs_norm, + mask=dec_mask, + layer_cache=layer_cache, + attn_type="self", + ) + elif isinstance(self.self_attn, AverageAttention): + return self.self_attn( + inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step + ) + else: + raise ValueError( + f"self attention {type(self.self_attn)} not supported" + ) + + +class TransformerDecoderLayer(TransformerDecoderLayerBase): + """Transformer Decoder layer block in Pre-Norm style. + Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, + providing better converge speed and performance. This is also the actual + implementation in tensor2tensor and also avalable in fairseq. + See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. + + .. mermaid:: + + graph LR + %% "*SubLayer" can be self-attn, src-attn or feed forward block + A(input) --> B[Norm] + B --> C["*SubLayer"] + C --> D[Drop] + D --> E((+)) + A --> E + E --> F(out) + + """ + + def __init__( + self, + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="scaled-dot", + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=False, + alignment_heads=0, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + """ + Args: + See TransformerDecoderLayerBase + """ + super(TransformerDecoderLayer, self).__init__( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type, + max_relative_positions, + aan_useffn, + full_context_alignment, + alignment_heads, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + self.context_attn = MultiHeadedAttention( + heads, d_model, dropout=attention_dropout + ) + self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) + + def update_dropout(self, dropout, attention_dropout): + super(TransformerDecoderLayer, self).update_dropout( + dropout, attention_dropout + ) + self.context_attn.update_dropout(attention_dropout) + + def _forward( + self, + inputs, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=None, + step=None, + future=False, + ): + """A naive forward pass for transformer decoder. + + # T: could be 1 in the case of stepwise decoding or tgt_len + + Args: + inputs (FloatTensor): ``(batch_size, T, model_dim)`` + memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)`` + src_pad_mask (bool): ``(batch_size, 1, src_len)`` + tgt_pad_mask (bool): ``(batch_size, 1, T)`` + layer_cache (dict or None): cached layer info when stepwise decode + step (int or None): stepwise decoding counter + future (bool): If set True, do not apply future_mask. + + Returns: + (FloatTensor, FloatTensor): + + * output ``(batch_size, T, model_dim)`` + * attns ``(batch_size, head, T, src_len)`` + + """ + dec_mask = None + + if inputs.size(1) > 1: + # masking is necessary when sequence length is greater than one + dec_mask = self._compute_dec_mask(tgt_pad_mask, future) + + inputs_norm = self.layer_norm_1(inputs) + + query, _ = self._forward_self_attn( + inputs_norm, dec_mask, layer_cache, step + ) + + query = self.drop(query) + inputs + + query_norm = self.layer_norm_2(query) + mid, attns = self.context_attn( + memory_bank, + memory_bank, + query_norm, + mask=src_pad_mask, + layer_cache=layer_cache, + attn_type="context", + ) + output = self.feed_forward(self.drop(mid) + query) + + return output, attns + + +class TransformerDecoderBase(DecoderBase): + def __init__(self, d_model, copy_attn, alignment_layer): + super(TransformerDecoderBase, self).__init__() + + # Decoder State + self.state = {} + + # previously, there was a GlobalAttention module here for copy + # attention. But it was never actually used -- the "copy" attention + # just reuses the context attention. + self._copy = copy_attn + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + self.alignment_layer = alignment_layer + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.dec_layers, + opt.dec_rnn_size, + opt.heads, + opt.transformer_ff, + opt.copy_attn, + opt.self_attn_type, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, + embeddings, + opt.max_relative_positions, + opt.aan_useffn, + opt.full_context_alignment, + opt.alignment_layer, + alignment_heads=opt.alignment_heads, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + ) + + def init_state(self, src, memory_bank, enc_hidden): + """Initialize decoder state.""" + self.state["src"] = src + self.state["cache"] = None + + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + if self.state["src"] is not None: + self.state["src"] = fn(self.state["src"], 1) + if self.state["cache"] is not None: + _recursive_map(self.state["cache"]) + + def detach_state(self): + raise NotImplementedError + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def update_dropout(self, dropout, attention_dropout): + self.embeddings.update_dropout(dropout) + for layer in self.transformer_layers: + layer.update_dropout(dropout, attention_dropout) + + +class TransformerDecoder(TransformerDecoderBase): + """The Transformer decoder from "Attention is All You Need". + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + BB[multi-head src-attn] + C[feed forward] + O[output] + A --> B + B --> BB + BB --> C + C --> O + + + Args: + num_layers (int): number of decoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for alignment + alignment_layer (int): N° Layer to supervise with for alignment guiding + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + """ + + def __init__( + self, + num_layers, + d_model, + heads, + d_ff, + copy_attn, + self_attn_type, + dropout, + attention_dropout, + max_relative_positions, + aan_useffn, + full_context_alignment, + alignment_layer, + alignment_heads, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + super(TransformerDecoder, self).__init__( + d_model, copy_attn, alignment_layer + ) + + self.transformer_layers = nn.ModuleList( + [ + TransformerDecoderLayer( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type=self_attn_type, + max_relative_positions=max_relative_positions, + aan_useffn=aan_useffn, + full_context_alignment=full_context_alignment, + alignment_heads=alignment_heads, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + for i in range(num_layers) + ] + ) + + def detach_state(self): + self.state["src"] = self.state["src"].detach() + + def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache(memory_bank) + + batch_size, src_len, src_dim = memory_bank.size() + device = memory_bank.device + if src_pad_mask is None: + src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device) + output = tgt_emb + batch_size, tgt_len, tgt_dim = tgt_emb.size() + if tgt_pad_mask is None: + tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device) + + future = kwargs.pop("future", False) + with_align = kwargs.pop("with_align", False) + attn_aligns = [] + hiddens = [] + + for i, layer in enumerate(self.transformer_layers): + layer_cache = ( + self.state["cache"]["layer_{}".format(i)] + if step is not None + else None + ) + output, attn, attn_align = layer( + output, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align, + future=future + ) + hiddens.append(output) + if attn_align is not None: + attn_aligns.append(attn_align) + + output = self.layer_norm(output) # (B, L, D) + + attns = {"std": attn} + if self._copy: + attns["copy"] = attn + if with_align: + attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` + # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg + + # TODO change the way attns is returned dict => list or tuple (onnx) + return output, attns, hiddens + + def _init_cache(self, memory_bank): + self.state["cache"] = {} + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None} + self.state["cache"]["layer_{}".format(i)] = layer_cache + diff --git a/molscribe/transformer/embedding.py b/molscribe/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..39647774d7f183690a3443bdf608479088fb0f5d --- /dev/null +++ b/molscribe/transformer/embedding.py @@ -0,0 +1,260 @@ +""" Embeddings module """ +import math +import warnings + +import torch +import torch.nn as nn + +from onmt.modules.util_class import Elementwise + + +class SequenceTooLongError(Exception): + pass + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. + + Implementation based on "Attention Is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + Args: + dropout (float): dropout parameter + dim (int): embedding size + """ + + def __init__(self, dropout, dim, max_len=5000): + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(1) + super(PositionalEncoding, self).__init__() + self.register_buffer('pe', pe) + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + + def forward(self, emb, step=None): + """Embed inputs. + + Args: + emb (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + step (int or NoneType): If stepwise (``seq_len = 1``), use + the encoding for this position. + """ + + emb = emb * math.sqrt(self.dim) + step = step or 0 + if self.pe.size(0) < step + emb.size(0): + raise SequenceTooLongError( + f"Sequence is {emb.size(0) + step} but PositionalEncoding is" + f" limited to {self.pe.size(0)}. See max_len argument." + ) + emb = emb + self.pe[step:emb.size(0)+step] + emb = self.dropout(emb) + return emb + + +class Embeddings(nn.Module): + """Words embeddings for encoder/decoder. + + Additionally includes ability to add sparse input features + based on "Linguistic Input Features Improve Neural Machine Translation" + :cite:`sennrich2016linguistic`. + + + .. mermaid:: + + graph LR + A[Input] + C[Feature 1 Lookup] + A-->B[Word Lookup] + A-->C + A-->D[Feature N Lookup] + B-->E[MLP/Concat] + C-->E + D-->E + E-->F[Output] + + Args: + word_vec_size (int): size of the dictionary of embeddings. + word_padding_idx (int): padding index for words in the embeddings. + feat_padding_idx (List[int]): padding index for a list of features + in the embeddings. + word_vocab_size (int): size of dictionary of embeddings for words. + feat_vocab_sizes (List[int], optional): list of size of dictionary + of embeddings for each feature. + position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding` + feat_merge (string): merge action for the features embeddings: + concat, sum or mlp. + feat_vec_exponent (float): when using `-feat_merge concat`, feature + embedding size is N^feat_dim_exponent, where N is the + number of values the feature takes. + feat_vec_size (int): embedding dimension for features when using + `-feat_merge mlp` + dropout (float): dropout probability. + freeze_word_vecs (bool): freeze weights of word vectors. + """ + + def __init__(self, word_vec_size, + word_vocab_size, + word_padding_idx, + position_encoding=False, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + feat_padding_idx=[], + feat_vocab_sizes=[], + dropout=0, + sparse=False, + freeze_word_vecs=False): + self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, + feat_vec_size, feat_padding_idx) + + if feat_padding_idx is None: + feat_padding_idx = [] + self.word_padding_idx = word_padding_idx + + self.word_vec_size = word_vec_size + + # Dimensions and padding for constructing the word embedding matrix + vocab_sizes = [word_vocab_size] + emb_dims = [word_vec_size] + pad_indices = [word_padding_idx] + + # Dimensions and padding for feature embedding matrices + # (these have no effect if feat_vocab_sizes is empty) + if feat_merge == 'sum': + feat_dims = [word_vec_size] * len(feat_vocab_sizes) + elif feat_vec_size > 0: + feat_dims = [feat_vec_size] * len(feat_vocab_sizes) + else: + feat_dims = [int(vocab ** feat_vec_exponent) + for vocab in feat_vocab_sizes] + vocab_sizes.extend(feat_vocab_sizes) + emb_dims.extend(feat_dims) + pad_indices.extend(feat_padding_idx) + + # The embedding matrix look-up tables. The first look-up table + # is for words. Subsequent ones are for features, if any exist. + emb_params = zip(vocab_sizes, emb_dims, pad_indices) + embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse) + for vocab, dim, pad in emb_params] + emb_luts = Elementwise(feat_merge, embeddings) + + # The final output size of word + feature vectors. This can vary + # from the word vector size if and only if features are defined. + # This is the attribute you should access if you need to know + # how big your embeddings are going to be. + self.embedding_size = (sum(emb_dims) if feat_merge == 'concat' + else word_vec_size) + + # The sequence of operations that converts the input sequence + # into a sequence of embeddings. At minimum this consists of + # looking up the embeddings for each word and feature in the + # input. Model parameters may require the sequence to contain + # additional operations as well. + super(Embeddings, self).__init__() + self.make_embedding = nn.Sequential() + self.make_embedding.add_module('emb_luts', emb_luts) + + if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0: + in_dim = sum(emb_dims) + mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU()) + self.make_embedding.add_module('mlp', mlp) + + self.position_encoding = position_encoding + + if self.position_encoding: + pe = PositionalEncoding(dropout, self.embedding_size) + self.make_embedding.add_module('pe', pe) + + if freeze_word_vecs: + self.word_lut.weight.requires_grad = False + + def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent, + feat_vec_size, feat_padding_idx): + if feat_merge == "sum": + # features must use word_vec_size + if feat_vec_exponent != 0.7: + warnings.warn("Merging with sum, but got non-default " + "feat_vec_exponent. It will be unused.") + if feat_vec_size != -1: + warnings.warn("Merging with sum, but got non-default " + "feat_vec_size. It will be unused.") + elif feat_vec_size > 0: + # features will use feat_vec_size + if feat_vec_exponent != -1: + warnings.warn("Not merging with sum and positive " + "feat_vec_size, but got non-default " + "feat_vec_exponent. It will be unused.") + else: + if feat_vec_exponent <= 0: + raise ValueError("Using feat_vec_exponent to determine " + "feature vec size, but got feat_vec_exponent " + "less than or equal to 0.") + n_feats = len(feat_vocab_sizes) + if n_feats != len(feat_padding_idx): + raise ValueError("Got unequal number of feat_vocab_sizes and " + "feat_padding_idx ({:d} != {:d})".format( + n_feats, len(feat_padding_idx))) + + @property + def word_lut(self): + """Word look-up table.""" + return self.make_embedding[0][0] + + @property + def emb_luts(self): + """Embedding look-up table.""" + return self.make_embedding[0] + + def load_pretrained_vectors(self, emb_file): + """Load in pretrained embeddings. + + Args: + emb_file (str) : path to torch serialized embeddings + """ + + if emb_file: + pretrained = torch.load(emb_file) + pretrained_vec_size = pretrained.size(1) + if self.word_vec_size > pretrained_vec_size: + self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained + elif self.word_vec_size < pretrained_vec_size: + self.word_lut.weight.data \ + .copy_(pretrained[:, :self.word_vec_size]) + else: + self.word_lut.weight.data.copy_(pretrained) + + def forward(self, source, step=None): + """Computes the embeddings for words and features. + + Args: + source (LongTensor): index tensor ``(len, batch, nfeat)`` + + Returns: + FloatTensor: Word embeddings ``(len, batch, embedding_size)`` + """ + + if self.position_encoding: + for i, module in enumerate(self.make_embedding._modules.values()): + if i == len(self.make_embedding._modules.values()) - 1: + source = module(source, step=step) + else: + source = module(source) + else: + source = self.make_embedding(source) + + return source + + def update_dropout(self, dropout): + if self.position_encoding: + self._modules['make_embedding'][1].dropout.p = dropout + diff --git a/molscribe/transformer/swin_transformer.py b/molscribe/transformer/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2580fb5f730e4f66e2d07fbe071c09c976fd80 --- /dev/null +++ b/molscribe/transformer/swin_transformer.py @@ -0,0 +1,677 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import checkpoint_filter_fn, _init_vit_weights + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def get_attn_mask(self, H, W, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + return attn_mask + + def forward(self, x, H, W): + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_mask = self.get_attn_mask(Hp, Wp, x.device) + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + # assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + H, W = x.shape[1:3] + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x, H, W + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W, hiddens): + for blk in self.blocks: + if not torch.jit.is_scripting() and self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, H, W) + else: + x = blk(x, H, W) + hiddens.append(x) + if self.downsample is not None: + x, H, W = self.downsample(x, H, W) + return x, H, W + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) + H, W = x.shape[2:] + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + if weight_init.startswith('jax'): + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + self.apply(_init_vit_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x): + x, H, W = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + hiddens = [] + for layer in self.layers: + x, H, W = layer(x, H, W, hiddens) + x = self.norm(x) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return x, hiddens + + # def forward(self, x): + # x = self.forward_features(x) + # x = self.head(x) + # return x + + +def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + + +@register_model +def swin_base(pretrained=False, **kwargs): + """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large(pretrained=False, **kwargs): + """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def swin_small_patch4_window7_224(pretrained=False, **kwargs): +# """ Swin-S @ 224x224, trained ImageNet-1k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) +# return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_tiny_patch4_window7_224(pretrained=False, **kwargs): +# """ Swin-T @ 224x224, trained ImageNet-1k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) +# return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs): +# """ Swin-B @ 384x384, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) +# return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs): +# """ Swin-B @ 224x224, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) +# return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs): +# """ Swin-L @ 384x384, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) +# return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): +# """ Swin-L @ 224x224, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) +# return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) diff --git a/molscribe/utils.py b/molscribe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1dd8331835b2c7075fa0c0229b40cc5d9f7266b --- /dev/null +++ b/molscribe/utils.py @@ -0,0 +1,163 @@ +import os +import random +import numpy as np +import torch +import math +import time +import datetime +import json +from json import encoder + + +FORMAT_INFO = { + "inchi": { + "name": "InChI_text", + "tokenizer": "tokenizer_inchi.json", + "max_len": 300 + }, + "atomtok": { + "name": "SMILES_atomtok", + "tokenizer": "tokenizer_smiles_atomtok.json", + "max_len": 256 + }, + "nodes": {"max_len": 384}, + "atomtok_coords": {"max_len": 480}, + "chartok_coords": {"max_len": 480} +} + + +def init_logger(log_file='train.log'): + from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler + logger = getLogger(__name__) + logger.setLevel(INFO) + handler1 = StreamHandler() + handler1.setFormatter(Formatter("%(message)s")) + handler2 = FileHandler(filename=log_file) + handler2.setFormatter(Formatter("%(message)s")) + logger.addHandler(handler1) + logger.addHandler(handler2) + return logger + + +def init_summary_writer(save_path): + from tensorboardX import SummaryWriter + summary = SummaryWriter(save_path) + return summary + + +def save_args(args): + dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M") + path = os.path.join(args.save_path, f'train_{dt}.log') + with open(path, 'w') as f: + for k, v in vars(args).items(): + f.write(f"**** {k} = *{v}*\n") + return + + +def seed_torch(seed=42): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class EpochMeter(AverageMeter): + def __init__(self): + super().__init__() + self.epoch = AverageMeter() + + def update(self, val, n=1): + super().update(val, n) + self.epoch.update(val, n) + + +class LossMeter(EpochMeter): + def __init__(self): + self.subs = {} + super().__init__() + + def reset(self): + super().reset() + for k in self.subs: + self.subs[k].reset() + + def update(self, loss, losses, n=1): + loss = loss.item() + super().update(loss, n) + losses = {k: v.item() for k, v in losses.items()} + for k, v in losses.items(): + if k not in self.subs: + self.subs[k] = EpochMeter() + self.subs[k].update(v, n) + + +def asMinutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + +def timeSince(since, percent): + now = time.time() + s = now - since + es = s / (percent) + rs = es - s + return '%s (remain %s)' % (asMinutes(s), asMinutes(rs)) + + +def print_rank_0(message): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def to_device(data, device): + if torch.is_tensor(data): + return data.to(device) + if type(data) is list: + return [to_device(v, device) for v in data] + if type(data) is dict: + return {k: to_device(v, device) for k, v in data.items()} + + +def round_floats(o): + if isinstance(o, float): + return round(o, 3) + if isinstance(o, dict): + return {k: round_floats(v) for k, v in o.items()} + if isinstance(o, (list, tuple)): + return [round_floats(x) for x in o] + return o + + +def format_df(df): + def _dumps(obj): + if obj is None: + return obj + return json.dumps(round_floats(obj)).replace(" ", "") + for field in ['node_coords', 'node_symbols', 'edges']: + if field in df.columns: + df[field] = [_dumps(obj) for obj in df[field]] + return df diff --git a/molscribe/vocab/vocab_chars.json b/molscribe/vocab/vocab_chars.json new file mode 100644 index 0000000000000000000000000000000000000000..daf380c9efda9b89dcdd07398e84609971f0c2bc --- /dev/null +++ b/molscribe/vocab/vocab_chars.json @@ -0,0 +1 @@ +{"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4, ".": 5, "-": 6, "=": 7, "#": 8, ":": 9, "/": 10, "\\": 11, "(": 12, ")": 13, "[": 14, "]": 15, "@": 16, "+": 17, "%": 18, "0": 19, "1": 20, "2": 21, "3": 22, "4": 23, "5": 24, "6": 25, "7": 26, "8": 27, "9": 28, "a": 29, "b": 30, "c": 31, "d": 32, "e": 33, "f": 34, "g": 35, "h": 36, "i": 37, "j": 38, "k": 39, "l": 40, "m": 41, "n": 42, "o": 43, "p": 44, "q": 45, "r": 46, "s": 47, "t": 48, "u": 49, "v": 50, "w": 51, "x": 52, "y": 53, "z": 54, "A": 55, "B": 56, "C": 57, "D": 58, "E": 59, "F": 60, "G": 61, "H": 62, "I": 63, "J": 64, "K": 65, "L": 66, "M": 67, "N": 68, "O": 69, "P": 70, "Q": 71, "R": 72, "S": 73, "T": 74, "U": 75, "V": 76, "W": 77, "X": 78, "Y": 79, "Z": 80, "*": 81, "~": 82, "\u000f": 83, "!": 84, "\"": 85, "$": 86, "&": 87, "'": 88, ",": 89, ";": 90, "<": 91, ">": 92, "?": 93, "^": 94, "_": 95, "`": 96, "{": 97, "|": 98, "}": 99, "\u0155": 100} diff --git a/molscribe/vocab/vocab_uspto.json b/molscribe/vocab/vocab_uspto.json new file mode 100644 index 0000000000000000000000000000000000000000..1916eaf0291909eee431678cfcfc390609c8479a --- /dev/null +++ b/molscribe/vocab/vocab_uspto.json @@ -0,0 +1 @@ +{"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4, "[OR12]": 5, "[Z;]": 6, "[LG]": 7, "[10*:0]": 8, "[R35]": 9, "[U1]": 10, "[CH2)]": 11, "[(XV)]": 12, "[fmoc]": 13, "[(Z)n]": 14, "[(L)m]": 15, "[24*]": 16, "[CN;]": 17, "[E,]": 18, "[OC4H9(n)]": 19, "[62*]": 20, "[NH2;]": 21, "[OR]": 22, "[Rb3]": 23, "[(Ra)p]": 24, "[Z13]": 25, "[Den1]": 26, "[NOMe]": 27, "[R1]": 28, "[(NH]": 29, "4": 30, "[R70]": 31, "[O*]": 32, "8": 33, "[Cl(O]": 34, "%15": 35, "[R4a]": 36, "[RK]": 37, "[Br-]": 38, "[)n]": 39, "[KH]": 40, "[12]": 41, "[K]": 42, "[ROOC]": 43, "[(CH2CH2O)x]": 44, "[Re1]": 45, "[Cl]": 46, "[i]": 47, "[X.]": 48, "[SO3-M+]": 49, "[37*]": 50, "[(CR3R4)n]": 51, "[Ry]": 52, "[(R2)m]": 53, "[(OH)n]": 54, "[Pg2]": 55, "[OH.]": 56, "[(A)p]": 57, "[b2]": 58, "[R1c]": 59, "[NO2.]": 60, "[22*]": 61, "[OR1A]": 62, "[R50]": 63, "[(L)]": 64, "[R52]": 65, "[Ce]": 66, "[Ra3]": 67, "[39*]": 68, "[(R2)p]": 69, "[(R10)t]": 70, "[H]": 71, "[NHAc]": 72, "[(n)C4H9]": 73, "[CONR1R2]": 74, "[AlH3]": 75, "[Gal]": 76, "[XVII]": 77, "[20]": 78, "[X;]": 79, "[HZ]": 80, "[C*]": 81, "[RA1]": 82, "[(CH2)n]": 83, "[R(6)]": 84, "[S@@]": 85, "[(R4)q]": 86, "[(R5)c]": 87, "[EtOOC]": 88, "[(R12)m]": 89, "[NR7R8]": 90, "[P4]": 91, "[ALOC]": 92, "[Ra]": 93, "[45*]": 94, "[Dye]": 95, "[se]": 96, "[W(1)(2)]": 97, "[LIGAND]": 98, "[4*]": 99, "[O,]": 100, "[FG]": 101, "[(SiO)n]": 102, "[*HN]": 103, "[(R8)r]": 104, "[CH2)3]": 105, "[)a]": 106, "[(R)m]": 107, "[NBoc]": 108, "[CH2OH.]": 109, "[AA3]": 110, "[3*+]": 111, "[OR15]": 112, "[(O)x]": 113, "[CO2R6]": 114, "[NH,]": 115, "[S@H]": 116, "[Bu-t]": 117, "[Bn]": 118, "[Rc4]": 119, "[IH]": 120, "[Z21]": 121, "[65*]": 122, "[IH+]": 123, "[42*]": 124, "[CO2But]": 125, "[R7]": 126, "[30*:0]": 127, "[N2]": 128, "[Thy]": 129, "[(CF2)n]": 130, "[(Rx)q]": 131, "[(R4b)]": 132, "[(t)C5H11]": 133, "[Rx1]": 134, "[Bz]": 135, "[X-Ph3P+CH2]": 136, "[N3]": 137, "[(Y)a]": 138, "[29]": 139, "[NH.]": 140, "[A0]": 141, "[G4]": 142, "[NR10R11]": 143, "[SH-]": 144, "[(X)]": 145, "[Z4]": 146, "[85%]": 147, "[COR2]": 148, "[nH+]": 149, "[t-BuO]": 150, "[Ac]": 151, "[(CR7R8)n]": 152, "[R9a]": 153, "[G]": 154, "[OR2]": 155, "[(CH3)m]": 156, "%20": 157, "[O)x]": 158, "[G7]": 159, "[(OCH2CH2)n]": 160, "2": 161, "[S;]": 162, "[R2b]": 163, "/": 164, "[Mc]": 165, "[CH+]": 166, "[(CH2)z]": 167, "[BocN]": 168, "[OTs]": 169, "[L]": 170, "[(a)]": 171, "[Aryl]": 172, "[O(CH2)z]": 173, "[(R9)p]": 174, "[RL]": 175, "[An]": 176, "[CH]": 177, "[Pb]": 178, "[X-]": 179, "[CO2R1]": 180, "[)C]": 181, "[Rii]": 182, "[N(R1)2]": 183, "[(CR2R3)n]": 184, "[NR5]": 185, "[O)CH3]": 186, "[55*]": 187, "[(Br)n]": 188, "[*N]": 189, "[Ib]": 190, "[OR9]": 191, "[Halo]": 192, "[hal]": 193, "[Qs]": 194, "[X14]": 195, "[C2F5]": 196, "[CH)m]": 197, "[(R4)a]": 198, "[halo]": 199, "[(MeO)3Si(CH2)3]": 200, "[OR20]": 201, "[(R13)b]": 202, "[X5]": 203, "[OR6,]": 204, "[Hf]": 205, "[21*]": 206, "S": 207, "[CR12]": 208, "[65535*:0]": 209, "[(R1]": 210, "[[A]": 211, "[L11]": 212, "[CO2X]": 213, "[Sx]": 214, "[OBut]": 215, "[[Z]": 216, "[(OCH2CH2)x]": 217, "[Z,]": 218, "[CR6]": 219, "[F,]": 220, "[(CR10R11)m]": 221, "[Ab]": 222, "[PG]": 223, "[d(H2C)]": 224, "F": 225, "[(OH)m]": 226, "[OPG1]": 227, "[Sb]": 228, "[43*]": 229, "[HO2C]": 230, "%24": 231, "[60*]": 232, "B": 233, "[OH)]": 234, "[SG]": 235, "[(t)H9C4]": 236, "[W5]": 237, "[NR3R4]": 238, "[Rsa2]": 239, "[(R2)n]": 240, "[SOmR4]": 241, "[CH2;]": 242, "[RB2]": 243, "[28*]": 244, "[)2]": 245, "[Rq]": 246, "[NH3+]": 247, "[CO2R7]": 248, "[base]": 249, "[SO2R]": 250, "[Rn]": 251, "[SMe]": 252, "[(Q2)n]": 253, "[NHZ]": 254, "[Ri]": 255, "[(XVII)]": 256, "[P@@]": 257, "[NH2]": 258, "[C2H4]": 259, "[W-]": 260, "[(O)n]": 261, "[(R4)p]": 262, "[R54]": 263, "[51*]": 264, "[M2]": 265, "[tBuO]": 266, "[NHMe]": 267, "[X15]": 268, "[16*]": 269, "[LG2]": 270, "[C@@]": 271, "[15]": 272, "[i)]": 273, "[*C]": 274, "[DMT]": 275, "[PG1]": 276, "[Rb2]": 277, "[(L1)a]": 278, "[S]": 279, "[-]": 280, "[b)]": 281, "[90*]": 282, "[RA3]": 283, "[RI]": 284, "[C12H25-n]": 285, "[R22]": 286, "[X21]": 287, "[(CH2)m]": 288, "[R14]": 289, "[(SO3H)n]": 290, "[IVa]": 291, "[NO2;]": 292, "[H3CO]": 293, "[C5]": 294, "[Rv]": 295, "[Ra5]": 296, "[(Z)]": 297, "[0]": 298, "[O]": 299, "[RHN]": 300, "[S@+]": 301, "[Se]": 302, "[C4H9-n]": 303, "[64]": 304, "[RO]": 305, "[Y2,]": 306, "[(O)p]": 307, "[A8]": 308, "[*:0]": 309, "[O)n]": 310, "[Ar7]": 311, "[R51]": 312, "[Os]": 313, "[NR3]": 314, "[J1]": 315, "[CN.]": 316, "[ORc]": 317, "[29*]": 318, "[L32]": 319, "[(R8)p]": 320, "[L31]": 321, "[X22]": 322, "[Example]": 323, "[30]": 324, "[CHO.]": 325, "[R31]": 326, "[CF2]": 327, "[13C]": 328, "[Ag]": 329, "[RN1]": 330, "[(CH2)a]": 331, "[(CR1R2)n]": 332, "[L2]": 333, "[K1]": 334, "[I-2]": 335, "[(C1]": 336, "[(R8)s]": 337, "[wherein]": 338, "[NH]": 339, "[S(O)m]": 340, "[r1]": 341, "[(2)]": 342, "[NMe]": 343, "[3.]": 344, "[H2,]": 345, "[M+]": 346, "[B(OR)2]": 347, "[70*]": 348, "[YR]": 349, "[OP2]": 350, "[(CH2)q]": 351, "[COOM]": 352, "[ButO2C]": 353, "[A5]": 354, "[(CH2)j]": 355, "[CR9]": 356, "[Cl;]": 357, "[ZH]": 358, "[[CH2CH2O]": 359, "[CR3R4]": 360, "[O-]": 361, "[CnH2n+1]": 362, "[2.]": 363, "[S.]": 364, "[b4]": 365, "[11]": 366, "[(C(R6)2)k]": 367, "[125I]": 368, "[1]": 369, "[(CH2)c]": 370, "[CH2)m]": 371, "[W4]": 372, "[SO2]": 373, "[G3]": 374, "[R12]": 375, "[(X)q]": 376, "[Formula]": 377, "[NHR5]": 378, "[HfLn]": 379, "[t-Bu]": 380, "[(CH2CH(CH3)O)y]": 381, "[CH2CO2Me]": 382, "[X2]": 383, "[(IX)]": 384, "[O(CH2)n]": 385, "[F1]": 386, "[44*]": 387, "[(L)r]": 388, "[NR2]": 389, "[14]": 390, "[CO2]": 391, "[SiR1R2R3]": 392, "[(x)]": 393, "[(N]": 394, "[Xn]": 395, "[OM]": 396, "[P(O)(OR1)2]": 397, "[(R5)i]": 398, "[C(]": 399, "[CH3.]": 400, "[Z1)n]": 401, "[MgH]": 402, "[PHV]": 403, "[Mn+]": 404, "[H;]": 405, "[(R2)b]": 406, "[n-Bu]": 407, "[B]": 408, "[HH]": 409, "[[R1]": 410, "[(R5)n]": 411, "[OMe.]": 412, "[SeH]": 413, "[(OH)y]": 414, "[I-]": 415, "[RA2]": 416, ")": 417, "[SR]": 418, "[OE]": 419, "[R;]": 420, "[R2a]": 421, "[}]": 422, "[(R3)c]": 423, "[t-Boc]": 424, "[Rh]": 425, "[S+]": 426, "9": 427, "[z]": 428, "[*-]": 429, "[Polymer]": 430, "[(CH2)1]": 431, "[Base]": 432, "[(X1)d]": 433, "[Me3Si]": 434, "[CO2R10]": 435, "[CH3SO2]": 436, "[RC1]": 437, "[Lv]": 438, "[RN]": 439, "[MO3S]": 440, "[Rb1]": 441, "[(CH2)l]": 442, "[(C)q]": 443, "[23]": 444, "[w]": 445, "[E5]": 446, "[R24]": 447, "[CmH2m]": 448, "[SiH2]": 449, "[OSiMe2tBu]": 450, "[(CF2]": 451, "[(AcG)GLYACHMGPIT(1-nal)VCQPLR(MeG)]": 452, "[P-]": 453, "[Rx]": 454, "[Ra6]": 455, "[R(4)]": 456, "[CH3;]": 457, "[*]": 458, "[R15]": 459, "[Xc]": 460, "[52*]": 461, "[C-]": 462, "[Bu]": 463, "[R21]": 464, "[Den]": 465, "[L10]": 466, "[mPEG]": 467, "[PGO]": 468, "[Cy2]": 469, "[NHFmoc]": 470, "[NA]": 471, "[C6H4]": 472, "[c]": 473, "[(L)n]": 474, "[(SO3M)n]": 475, "[Ra11]": 476, "[2*:0]": 477, "[PH2+]": 478, "[I+3]": 479, "[F-]": 480, "[N.]": 481, "[82*]": 482, "[YR1]": 483, "[SO3H.]": 484, "[OX1]": 485, "[Me.]": 486, "[R(2)]": 487, "[O+]": 488, "[Y]": 489, "[OH3+]": 490, "[CH2CH3.]": 491, "[X31]": 492, "[(CH2)4]": 493, "[Ax]": 494, "[Y2]": 495, "[COOR3]": 496, "[CH2)n]": 497, "[Z22]": 498, "[y(RYZ)]": 499, "[Rm]": 500, "[OCF3]": 501, "[50]": 502, "[2*]": 503, "[NR1R2]": 504, "[Q2]": 505, "3": 506, "[;]": 507, "[NHR]": 508, "[CnH2n]": 509, "[X13]": 510, "[Tr]": 511, "[SO2R1]": 512, "[D3]": 513, "[g]": 514, "[OTES]": 515, "[(R10)r]": 516, "[C4H8SO3]": 517, "[Pr]": 518, "[AR]": 519, "[NR4R5]": 520, "[(C)m]": 521, "[Y7]": 522, "[M]": 523, "[N;]": 524, "[OQ]": 525, "[C@@H]": 526, "[Rb]": 527, "[Rj]": 528, "[where]": 529, "[(R4)x]": 530, "[Z14]": 531, "[R41]": 532, "[)x]": 533, "[Ar12]": 534, "[(CH2)r]": 535, "[C@H]": 536, "[(R5)x]": 537, "[(R14)n]": 538, "o": 539, "[O-Cat+]": 540, "[C(R7)]": 541, "[M(X-Y)n]": 542, "[(T)n]": 543, "[Q3]": 544, "[CsH]": 545, "[(V)t]": 546, "[L,]": 547, "[(R4)s]": 548, "[CO2R5b]": 549, "[63*]": 550, "[E4]": 551, "[EO]": 552, "[Cr]": 553, "[(R1)s]": 554, "[(R11)p]": 555, "[OPg]": 556, "[R']": 557, "[18]": 558, "[(CH2)y]": 559, "[(R4)t]": 560, "[(OH)b]": 561, "[RD]": 562, "[Si]": 563, "[(R8)q]": 564, "[56*]": 565, "[Bc]": 566, "[R]": 567, "[)d]": 568, "[C4H9(t)]": 569, "[R39]": 570, "[H2NSO2]": 571, "[*,]": 572, "[XX]": 573, "[(CH2)f]": 574, "[(R2)q]": 575, "[R37]": 576, "[(H2C)n]": 577, "[48*]": 578, "[QL]": 579, "[(CH2)w]": 580, "[Rk]": 581, "[D5]": 582, "[(R3)r]": 583, "[Cl.]": 584, "[(Q)n]": 585, "[OMe]": 586, "[DE]": 587, "[C4H8]": 588, "[BH3-]": 589, "[V3]": 590, "[R34]": 591, "[68*]": 592, "[Linker]": 593, "[PCy3]": 594, "[(XVI)]": 595, "[OtBu]": 596, "[R16]": 597, "[K+]": 598, "[X8]": 599, "[RJ]": 600, "[RAn]": 601, "[u]": 602, "[NHEt]": 603, "%21": 604, "[(Rb2)n2]": 605, "[OP3]": 606, "[OiPr]": 607, "[P5]": 608, "[tBu]": 609, "[COORD1b]": 610, "[Xa]": 611, "[(R)q]": 612, "[27]": 613, "[Trt]": 614, "[RS]": 615, "[XV]": 616, "[OMe;]": 617, "[Nuc]": 618, "[Zn+2]": 619, "[ORx]": 620, "[R(7)]": 621, "[link]": 622, "[OZ]": 623, "[(XII)]": 624, "[78*]": 625, "C": 626, "[C5H11(t)]": 627, "[RP2]": 628, "[CuPc]": 629, "[Silica]": 630, "[Den2]": 631, "[HO3S]": 632, "\\": 633, "[RaH2]": 634, "6": 635, "[SiH4]": 636, "[nBu]": 637, "[(R2)]": 638, "[COOR1]": 639, "[X0]": 640, "[L13]": 641, "[1*:0]": 642, "[Rc1]": 643, "[(R5)m]": 644, "[RY]": 645, "[A11]": 646, "[OR7]": 647, "[*CH*]": 648, "[alk]": 649, "[Na]": 650, "[P2O]": 651, "[L9]": 652, "[(O)a]": 653, "[(X)p]": 654, "[I+]": 655, "[CR1]": 656, "[Y13]": 657, "[Me2N]": 658, "[C8H17(n)]": 659, "[CO2t-Bu]": 660, "[C2]": 661, "[BaH2]": 662, "1": 663, "[But]": 664, "I": 665, "[CO2Et.]": 666, "[Ti]": 667, "[Et]": 668, "[XH]": 669, "[Y8]": 670, "[C4H9-t]": 671, "[S(O)n]": 672, "[A22]": 673, "[34]": 674, "[B4]": 675, "[CO2R5]": 676, "[(R3)a]": 677, "[Ha]": 678, "[Ta]": 679, "[NHR8]": 680, "[phenyl]": 681, "[(R7)n]": 682, "[HET]": 683, "[Ar1SO2]": 684, "[(R4)b]": 685, "[PROTECTING]": 686, "[(CH2)i]": 687, "[[R]": 688, "[53*]": 689, "[40]": 690, "[3CF3CO2H]": 691, "[12c]": 692, "[R56]": 693, "[A9]": 694, "[COOR10]": 695, "[X18]": 696, "[COOEt]": 697, "[R5a]": 698, "[S3]": 699, "[NHBoc]": 700, "[61*]": 701, "[8]": 702, "[F2C]": 703, "[RSUB2]": 704, "[SO3-Et3NH+]": 705, "[PG2]": 706, "[X1,]": 707, "[R9]": 708, "[(R3)m]": 709, "[(XIV)]": 710, "[CHR3]": 711, "[OR24]": 712, "[B5]": 713, "[1/2]": 714, "[13CH]": 715, "[2Mana1]": 716, "[(CH)m]": 717, "%14": 718, "[66*]": 719, "[COOMe]": 720, "[COOtBu]": 721, "[(X)n]": 722, "[L1*]": 723, "[RB1]": 724, "[(CH2CH2O)m]": 725, "[#]": 726, "[NH4+]": 727, "[Cm]": 728, "[L7]": 729, "[10*]": 730, "[(R2)a]": 731, "[O)m]": 732, "[11*]": 733, "[Cn]": 734, "[6*:0]": 735, "%12": 736, "[[SEQ]": 737, "[(CH2)s]": 738, "[H-]": 739, "[20*]": 740, "Br": 741, "[(R13)m]": 742, "[(Ic)]": 743, "[X1b]": 744, "[W]": 745, "[PH-]": 746, "[IH-4]": 747, "[Al]": 748, "[p]": 749, "[RB4]": 750, "[f]": 751, "[74*]": 752, "[CH2.]": 753, "[(R3)q]": 754, "[FmocHN]": 755, "[OR4]": 756, "[R29]": 757, "[COOC4H9(n)]": 758, "[R32]": 759, "[ClH+]": 760, "[P1]": 761, "[(O]": 762, "[Sm]": 763, "[OL]": 764, "[R8,]": 765, "[MeS]": 766, "[MeLeu]": 767, "[C6H13(n)]": 768, "[S(]": 769, "[COOR7]": 770, "[(X1)m]": 771, "[b3]": 772, "[e]": 773, "[RC]": 774, "[OR13]": 775, "[NR13]": 776, "[E]": 777, "[R-link]": 778, "[Fmoc]": 779, "[C(CH3)]": 780, "[H.]": 781, "[F3CO]": 782, "[CO2R8]": 783, "[NHR2]": 784, "[NHR11]": 785, "[ZO]": 786, "[Alkyl]": 787, "[OR14]": 788, "[X1]": 789, "[SEM]": 790, "[J4]": 791, "[NO2]": 792, "[(Z]": 793, "[Gd+3]": 794, "[(R4)o]": 795, "[SO3M1]": 796, "[(R6)n]": 797, "[A2]": 798, "[CO2PG]": 799, "[Het1]": 800, "[OR8]": 801, "[O)2]": 802, "[SO2R3]": 803, "[23*]": 804, "[(R)k]": 805, "[T6]": 806, "[Q6]": 807, "[OX2]": 808, "[t-C4H9]": 809, "[RSUB1]": 810, "[75*]": 811, "[CH)n]": 812, "[OEt.]": 813, "[Qn]": 814, "[5*]": 815, "[R23]": 816, "[t-BuO2C]": 817, "[C@]": 818, "[R38]": 819, "[Pn]": 820, "[34*]": 821, "[M*]": 822, "[X6]": 823, "[Ar1]": 824, "[Pt]": 825, "[R*]": 826, "[3H]": 827, "[OG]": 828, "[NR7]": 829, "[R8]": 830, "[(CH2]": 831, "[(OH)p]": 832, "[(TR7)t]": 833, "%27": 834, "[OR6]": 835, "[(CH2)e]": 836, "[R25]": 837, "[J]": 838, "[MgH2]": 839, "[CF3.]": 840, "[(SO3H)m]": 841, "[ORd]": 842, "[51*:0]": 843, "[Cd]": 844, "[Ru]": 845, "[21]": 846, "[L14]": 847, "[W5a]": 848, "[Z3]": 849, "[Ir]": 850, "[Sc]": 851, "n": 852, "[Ge]": 853, "[V1]": 854, "[R3b]": 855, "[AlkO]": 856, "[Hg]": 857, "[(R10)0-2]": 858, "[Y10]": 859, "[N@]": 860, "[R10]": 861, "[V]": 862, "[NHSO2Me]": 863, "[Te]": 864, "[CR1R2]": 865, "[R33]": 866, "[2R]": 867, "[35*]": 868, "[alkyl]": 869, "[SO3M]": 870, "[OR3,]": 871, "[XII]": 872, "[CH3]": 873, "[(R4]": 874, "[NCO2t-Bu]": 875, "O": 876, "[s]": 877, "[(CH2)k]": 878, "[R28]": 879, "[Y)n]": 880, "[(R41P)0-2]": 881, "[67*]": 882, "[(ZRY)y]": 883, "%16": 884, "[58*]": 885, "[M3]": 886, "[2H]": 887, "[N-]": 888, "[R11]": 889, "[Ba]": 890, "[C3]": 891, "[(R)p]": 892, "[R:]": 893, "(": 894, "[Cu+]": 895, "[Q4]": 896, "[Nu1]": 897, "[(WRW)m]": 898, "[G1]": 899, "[NHR3]": 900, "[MSG1]": 901, "[Ch]": 902, "[(R4)d]": 903, "[N*]": 904, "[Ph3P]": 905, "[RaH]": 906, "[R.]": 907, "[R72]": 908, "[T4]": 909, "[R45]": 910, "[A]": 911, "[CR3]": 912, "b": 913, "[NHOH.]": 914, "[(A1]": 915, "[BPin]": 916, "[(R5)a]": 917, "[Ni+2]": 918, "[S@]": 919, "[A.]": 920, "[(R9)n]": 921, "[(Z)q]": 922, "[Y1,]": 923, "[n-]": 924, "[OBn]": 925, "[R0]": 926, "[9*:0]": 927, "[peptide]": 928, "[Ra12]": 929, "[32*]": 930, "[(R5)d]": 931, "[25]": 932, "[n-C3H7]": 933, "[COOR6]": 934, "[PGN]": 935, "[SiO3/2]": 936, "[O-M+]": 937, "[Cy]": 938, "[Mj]": 939, "[PgO]": 940, "[Z15]": 941, "[C5H11-t]": 942, "[(R)n]": 943, "[TG]": 944, "[R43]": 945, "[H2n+1Cn]": 946, "[)]": 947, "[(C)p]": 948, "[T3]": 949, "[(V)n-1]": 950, "[La]": 951, "[alkenyl]": 952, "[D4]": 953, "[Val]": 954, "[)b]": 955, "[R18]": 956, "[(R1)p]": 957, "[Rz]": 958, "[(C)r]": 959, "[D1]": 960, "[6*]": 961, "[COR]": 962, "[Ua]": 963, "[M(X]": 964, "%19": 965, "[R,]": 966, "[OR5]": 967, "[N+]": 968, "[Ar]": 969, "[(R1)3]": 970, "[Hf-]": 971, "[CH2OR]": 972, "[SnH]": 973, "[RA]": 974, "[boc]": 975, "[CH2+]": 976, "[(A]": 977, "[[X]": 978, "[n-Pr]": 979, "[Py]": 980, "[(R5)q]": 981, "[OPv]": 982, "[Z0]": 983, "[k]": 984, "[i-Bu]": 985, "[PhO]": 986, "[R100]": 987, "[Z12]": 988, "[15N]": 989, "[Drug]": 990, "[MeSO2]": 991, "[X3]": 992, "[Ar3]": 993, "[Y4]": 994, "[n-C6H13]": 995, "[NHCORB]": 996, "[(R6)m]": 997, "[S*]": 998, "[CF]": 999, "[(Rz)v]": 1000, "[(W)n]": 1001, "[B-]": 1002, "[26*]": 1003, "[Core]": 1004, "[(R1)g]": 1005, "[17*]": 1006, "[H,]": 1007, "[ROH]": 1008, "[Tc]": 1009, "[CF3]": 1010, "[Ar11]": 1011, "[tC4H9]": 1012, "[basic]": 1013, "[Mg]": 1014, "[G2]": 1015, "[7*:0]": 1016, "[CO2H;]": 1017, "[3H-]": 1018, "[[R5]": 1019, "[Cys]": 1020, "[P3]": 1021, "[(R9)q]": 1022, "[b]": 1023, "[O2]": 1024, "[DNA]": 1025, "[X2R3]": 1026, "[CH2*]": 1027, "[NEt2]": 1028, "[TsO]": 1029, "[(R1)x]": 1030, "[CH2R]": 1031, "[NR12]": 1032, "[OPiv]": 1033, "[(CH]": 1034, "[NH-]": 1035, "[L4]": 1036, "[NH2-]": 1037, "[Zb]": 1038, "[1)]": 1039, "[31*]": 1040, "[T1]": 1041, "[(R3)p]": 1042, "[NMe2]": 1043, "[(R2)m1]": 1044, "[2CF3CO2H]": 1045, "[F.]": 1046, "Cl": 1047, "[OEt]": 1048, "[NR48]": 1049, "[OR17]": 1050, "[Ar1p]": 1051, "[Rt]": 1052, "[Pd/C]": 1053, "[TES]": 1054, "[Qa]": 1055, "[(R2)s]": 1056, "[PH2]": 1057, "[BH4-]": 1058, "[(R6)b]": 1059, "[Z9]": 1060, "[C1]": 1061, "[CPG]": 1062, "[OA]": 1063, "[tBuOOC]": 1064, "[H+]": 1065, "[Cbm]": 1066, "[(HO)n]": 1067, "[Compound]": 1068, "[q]": 1069, "[NHPG]": 1070, "[PAG]": 1071, "[Rd]": 1072, "[NR11R12]": 1073, "[OHC]": 1074, "[CONHtBu]": 1075, "[A3]": 1076, "[R2]": 1077, "[19*]": 1078, "[CO2H.]": 1079, "[31]": 1080, "[(R1)2]": 1081, "[(1)]": 1082, "[CO2tBu]": 1083, "[72*]": 1084, "[(R1)]": 1085, "[C(R3)]": 1086, "[SiH]": 1087, "[ORN-a]": 1088, "[S1]": 1089, "[36]": 1090, "[O;]": 1091, "[SO3X]": 1092, "[[R6O2C]": 1093, "[(R2)r]": 1094, "[80*]": 1095, "[Cl,]": 1096, "[+]": 1097, "[S2]": 1098, "[Z.]": 1099, "[a1]": 1100, "[Lb]": 1101, "[*CH2]": 1102, "[PR1]": 1103, "[5*:0]": 1104, "[59*]": 1105, "[77*]": 1106, "[(Z)o]": 1107, "[Rb8]": 1108, "[NR14]": 1109, "[Li+]": 1110, "[L6]": 1111, "[PMP]": 1112, "[(R1)t]": 1113, "[Co-2]": 1114, "[o+]": 1115, "[2)]": 1116, "[98*]": 1117, "[81*]": 1118, "[C4H9(n)]": 1119, "[X16]": 1120, "[l]": 1121, "[85*]": 1122, "[R20]": 1123, "[RpO]": 1124, "[SO2Ph]": 1125, "[Mg+2]": 1126, "[COOX]": 1127, "[OR22]": 1128, "[(Y]": 1129, "[A31]": 1130, "[Rsa1]": 1131, "[OR11]": 1132, "[aryl]": 1133, "[O2N]": 1134, "[Ar4]": 1135, "[30*]": 1136, "[NaBH3CN]": 1137, "[Au]": 1138, "[14*]": 1139, "[CO2R]": 1140, "[Reaction]": 1141, "[(Si]": 1142, "[93*]": 1143, "[Ia]": 1144, "[COOR4]": 1145, "[n+]": 1146, "[Z1]": 1147, "[R17]": 1148, "[(XI)]": 1149, "[2H-]": 1150, "[CR2]": 1151, "[(Y)q]": 1152, "[PAG1]": 1153, "[C34H56+a]": 1154, "[22]": 1155, "[SO3]": 1156, "[S@@+]": 1157, "[CH2]": 1158, "[Pd(PPh3)4/NMM/HOAc/CHCl3]": 1159, "[R61]": 1160, "[Z100]": 1161, "[OPG2]": 1162, "[(XX)]": 1163, "[NR8]": 1164, "[W6]": 1165, "[O)b]": 1166, "[Q]": 1167, "[Pa]": 1168, "[16]": 1169, "[Ar5]": 1170, "[L16]": 1171, "[POLY]": 1172, "[CF3;]": 1173, "[NR6]": 1174, "[(IIIa)]": 1175, "[[Si]": 1176, "[NHR4]": 1177, "[Q,]": 1178, "[CHR1]": 1179, "[C6H4)]": 1180, "[Rc2]": 1181, "[CH3,]": 1182, "[Y12]": 1183, "[Y3]": 1184, "[MeO2C]": 1185, "[(R5)o]": 1186, "[Cb]": 1187, "[,]": 1188, "[Y,]": 1189, ".": 1190, "[(C]": 1191, "[Zc]": 1192, "c": 1193, "[Yp]": 1194, "%26": 1195, "[SH2+]": 1196, "[AM]": 1197, "[NaH]": 1198, "+": 1199, "[5]": 1200, "[R44]": 1201, "[COOR8]": 1202, "[[CH2]": 1203, "[V-2]": 1204, "[4GlcNAc]": 1205, "[91*]": 1206, "[(4)]": 1207, "[COOR2]": 1208, "[R13]": 1209, "[(CH2)x]": 1210, "[NCO]": 1211, "[X+]": 1212, "[Z7]": 1213, "[OP1]": 1214, "[*+]": 1215, "[CR]": 1216, "[(A2]": 1217, "[A21]": 1218, "[nC4H9]": 1219, "[(Bn)2]": 1220, "[(CH2)t]": 1221, "[(H2C]": 1222, "[S-]": 1223, "[W1]": 1224, "[Cbz]": 1225, "[(R12)n]": 1226, "[(XIII)]": 1227, "[Ca]": 1228, "[Ym]": 1229, "[Rg]": 1230, "[3*]": 1231, "%25": 1232, "[O2S]": 1233, "[SO2Me]": 1234, "[SiR]": 1235, "[P@]": 1236, "[(R)b]": 1237, "[(R1)a]": 1238, "[A+]": 1239, "[NH2.]": 1240, "[Zm]": 1241, "[(CH2)u]": 1242, "[NHR6]": 1243, "[Gd]": 1244, "[CH2OR1]": 1245, "[Ra1]": 1246, "[NHR1]": 1247, "[COR6a]": 1248, "[Antibody]": 1249, "[Si(OR)3]": 1250, "[XVI]": 1251, "[(CH2)nCH3]": 1252, "[(A)m]": 1253, "[IH-2]": 1254, "[28]": 1255, "[Y22]": 1256, "[(R2)x]": 1257, "[7]": 1258, "[(A)]": 1259, "[33*]": 1260, "p": 1261, "[J5]": 1262, "[Z8]": 1263, "[RG]": 1264, "[Rb6]": 1265, "[41*]": 1266, "[m-PEG]": 1267, "[OMs]": 1268, "[LQ]": 1269, "[(X)m]": 1270, "[SiH3]": 1271, "[79*]": 1272, "[CO2Bn]": 1273, "[A4]": 1274, "[12*:0]": 1275, "[P3O]": 1276, "[(Z)p]": 1277, "[(R8a)m]": 1278, "[L8]": 1279, "[Rc]": 1280, "[R7,]": 1281, "[mPEG(30K)]": 1282, "[(CH2)v]": 1283, "[OCH3;]": 1284, "[CH)x]": 1285, "[AR2]": 1286, "[69*]": 1287, "[IH-]": 1288, "[COR1]": 1289, "[(Z1)n]": 1290, "[Man]": 1291, "[linker]": 1292, "[R3a]": 1293, "[(OH)q]": 1294, "[M,]": 1295, "[(AO)n]": 1296, "[Y21]": 1297, "[{]": 1298, "[(R3]": 1299, "[n]": 1300, "[Prt]": 1301, "[OR3]": 1302, "[(NH)p]": 1303, "[COOH;]": 1304, "[Z6]": 1305, "[OH-]": 1306, "[Nu]": 1307, "[13CH2]": 1308, "[F3]": 1309, "[(R5)b]": 1310, "[EtO]": 1311, "[CO2Z1]": 1312, "[n-C12H25]": 1313, "[Q8]": 1314, "[(Y)t]": 1315, "[SO3-]": 1316, "[9*]": 1317, "[W3]": 1318, "[2*+]": 1319, "[C)n]": 1320, "[Ca+2]": 1321, "[A-]": 1322, "[Fe]": 1323, "[Q+]": 1324, "[SH]": 1325, "[ODMT]": 1326, "[(R0)q]": 1327, "[XIV]": 1328, "[NHCbz]": 1329, "[(CH)i]": 1330, "[NR6R7]": 1331, "[(Y)m]": 1332, "[MsO]": 1333, "[COOBn]": 1334, "[X7]": 1335, "[y]": 1336, "[Lg]": 1337, "[P2]": 1338, "[DMTO]": 1339, "[X,]": 1340, "[XXII]": 1341, "[Me;]": 1342, "[(Y)p]": 1343, "[COOR]": 1344, "[NR2R3]": 1345, "[P1O]": 1346, "[Sp]": 1347, "[O-2]": 1348, "[B0]": 1349, "[T2]": 1350, "[Hs]": 1351, "[20*:0]": 1352, "[NCS]": 1353, "[NR5R6]": 1354, "[d]": 1355, "[HG]": 1356, "[R40]": 1357, "[Sg]": 1358, "[(CH2)m-1]": 1359, "[NPht]": 1360, "[Si(R2)(R3)]": 1361, "%18": 1362, "[83*]": 1363, "[Aa]": 1364, "[N(iPr)2]": 1365, "[CH2,]": 1366, "[D]": 1367, "[(alkyl)]": 1368, "[F3C]": 1369, "[Pt+2]": 1370, "[3HH]": 1371, "[pH]": 1372, "[OR102]": 1373, "[h]": 1374, "[Z-]": 1375, "[Ep]": 1376, "[ORE]": 1377, "[[R3]": 1378, "[Rr]": 1379, "[OCF2CFHCF3]": 1380, "[)y]": 1381, "[[Y]": 1382, "[SH.]": 1383, "[Cs]": 1384, "[U]": 1385, "[Q7]": 1386, "[Ln]": 1387, "[R81]": 1388, "[(R12)a]": 1389, "[Riii]": 1390, "%17": 1391, "[(R8)n]": 1392, "[Hb]": 1393, "[E1]": 1394, "[(Y)n]": 1395, "[(RA)n]": 1396, "[Z1,]": 1397, "[nH]": 1398, "[Li]": 1399, "[PH+]": 1400, "[E3]": 1401, "[Heteroaryl]": 1402, "[(Ra)w]": 1403, "[(CH2)g]": 1404, "[E2]": 1405, "[ClSO2]": 1406, "[COOBut]": 1407, "N": 1408, "[RF]": 1409, "[No]": 1410, "[Co]": 1411, "[[C(O)]": 1412, "[C)]": 1413, "[[N]": 1414, "[R5]": 1415, "[C]": 1416, "[s+]": 1417, "[HA]": 1418, "[ii)]": 1419, "[R3,]": 1420, "[CmH2m+1]": 1421, "[SA]": 1422, "[a2]": 1423, "[S@@H]": 1424, "[38*]": 1425, "[SR2]": 1426, "[(CH2)n2]": 1427, "[n-C4H9]": 1428, "[(Ra)n]": 1429, "[BASE]": 1430, "[Z5]": 1431, "[RZ]": 1432, "[Ns]": 1433, "[RP]": 1434, "[Ara]": 1435, "[i-C3F7]": 1436, "[Cy1]": 1437, "[SR1]": 1438, "[a4]": 1439, "[Prot]": 1440, "[8*:0]": 1441, "[Zn]": 1442, "[Chiral]": 1443, "[[]": 1444, "[19]": 1445, "[(R4)m]": 1446, "[NHSO2]": 1447, "[A6]": 1448, "[or]": 1449, "[*;]": 1450, "[RM]": 1451, "[OR10]": 1452, "[LiH]": 1453, "[a)]": 1454, "[Abu]": 1455, "[NR9]": 1456, "[COO-t-Bu]": 1457, "[QO]": 1458, "[Rl]": 1459, "[(AcG)GLYACHMGPIT(1-nal)VCQPLR]": 1460, "[1*+]": 1461, "[(CH2CH2]": 1462, "[Lys]": 1463, "[R4,]": 1464, "[OH+]": 1465, "[3)]": 1466, "[MeLeu-MeVal-N]": 1467, "[a3]": 1468, "[OR16]": 1469, "[Cu]": 1470, "[X11]": 1471, "[(R3)n]": 1472, "[(R10)n]": 1473, "[(CH2)p]": 1474, "[P+]": 1475, "[OR7a]": 1476, ":": 1477, "%11": 1478, "[BnO]": 1479, "[76*]": 1480, "[O)y]": 1481, "[X4]": 1482, "[1*]": 1483, "[O1/2]": 1484, "[NR1]": 1485, "%13": 1486, "[Rb4]": 1487, "[Mo]": 1488, "[(R3aR3bC)n]": 1489, "[Cyc2]": 1490, "[R**]": 1491, "[**]": 1492, "[R30]": 1493, "[Hal]": 1494, "[Ala]": 1495, "[C4F9]": 1496, "[RN2]": 1497, "[pol]": 1498, "[36*]": 1499, "[47*]": 1500, "[is]": 1501, "*": 1502, "[R(5)]": 1503, "[27*]": 1504, "[SOmR4,]": 1505, "[Roxa]": 1506, "[86*]": 1507, "[CbzHN]": 1508, "[PH]": 1509, "[(CH2)n,]": 1510, "[8*]": 1511, "[(Z)m]": 1512, "[Rb7]": 1513, "[10]": 1514, "[R27]": 1515, "[SO2R6]": 1516, "[Fe+3]": 1517, "[(R5)p]": 1518, "[(R15)q]": 1519, "[NH2+]": 1520, "[(R12)r]": 1521, "[33]": 1522, "[OPG]": 1523, "[(R5)r]": 1524, "[OY1]": 1525, "[H5C2]": 1526, "[Pg0O]": 1527, "[OR21]": 1528, "[TESO]": 1529, "[Lc]": 1530, "[RP1]": 1531, "[[C]": 1532, "[54*]": 1533, "[(R10)m]": 1534, "[Y5]": 1535, "[B3]": 1536, "[Pd]": 1537, "[Arom]": 1538, "[(CH2)nO]": 1539, "[(CRn)a]": 1540, "[84*]": 1541, "[AA]": 1542, "[CH2X]": 1543, "[(E)]": 1544, "[(t)C4H9]": 1545, "[(R7)q]": 1546, "[(QRX)x]": 1547, "[(R4)n]": 1548, "[Y-]": 1549, "[CH2)2]": 1550, "[GlcNAc]": 1551, "[CH)]": 1552, "[(C(R6)2)p]": 1553, "[Mm+]": 1554, "[M4]": 1555, "[CH3(CH2)n]": 1556, "[and/or]": 1557, "[46*]": 1558, "[(CRVRVI)x]": 1559, "[(CH2)o]": 1560, "[DMTrO]": 1561, "[K2]": 1562, "[(CH2)h]": 1563, "[Y11]": 1564, "[(R1)q]": 1565, "[CO2R11]": 1566, "[RB]": 1567, "[(H,]": 1568, "[Ya]": 1569, "[OH2+]": 1570, "[NR10]": 1571, "[X17]": 1572, "[(R9)m]": 1573, "[X9]": 1574, "[TfO]": 1575, "[(R2)y]": 1576, "[Z32]": 1577, "[Y3,]": 1578, "[)3]": 1579, "[SiMe3]": 1580, "[(R4)r]": 1581, "[OCOR]": 1582, "[(R6)s]": 1583, "[Z1)m]": 1584, "[CO2R3]": 1585, "[(R4)c]": 1586, "[CO2R2]": 1587, "[SO2R2]": 1588, "[Sn]": 1589, "[DMF]": 1590, "[AlH]": 1591, "[R6]": 1592, "[CH2)x]": 1593, "[15*]": 1594, "[Ar6]": 1595, "[b1]": 1596, "[(R10)q]": 1597, "[(R6)q]": 1598, "[11*:0]": 1599, "[X10]": 1600, "[P]": 1601, "[NH+]": 1602, "[N1]": 1603, "[2]": 1604, "[Si@]": 1605, "[)m]": 1606, "[32]": 1607, "[~]": 1608, "[O)]": 1609, "[A1]": 1610, "[ZN]": 1611, "[MeLeu-D-Ala-Ala-MeLeu-Val-MeLeu]": 1612, "[(CR1b2)p]": 1613, "[OH;]": 1614, "[NR]": 1615, "[OH,]": 1616, "[Boc]": 1617, "[CO2Me]": 1618, "[CR5]": 1619, "[Rs]": 1620, "[Cl-]": 1621, "[PPh2]": 1622, "=": 1623, "[Bx]": 1624, "[Bzl]": 1625, "[Z2,]": 1626, "[TBZ]": 1627, "[iPr]": 1628, "[Y.]": 1629, "[49*]": 1630, "[(C(R13)H)r]": 1631, "[.]": 1632, "[a]": 1633, "[18F]": 1634, "[Synthesis]": 1635, "[OH]": 1636, "[COOR5]": 1637, "[)c]": 1638, "[R53]": 1639, "[Yb]": 1640, "[(R3)b]": 1641, "[Cx]": 1642, "[(]": 1643, "[4*:0]": 1644, "[71*]": 1645, "[(R)]": 1646, "[XIX]": 1647, "[L1]": 1648, "[12*]": 1649, "[EWG]": 1650, "[MeOH]": 1651, "[A32]": 1652, "[Dq]": 1653, "[m]": 1654, "%28": 1655, "#": 1656, "[(R3)s]": 1657, "[NH;]": 1658, "[v]": 1659, "[o]": 1660, "[92*]": 1661, "[HX]": 1662, "[BocNH]": 1663, "[(CH2)n-1]": 1664, "[R42]": 1665, "[(L)k]": 1666, "[[R2]": 1667, "[J3]": 1668, "[Alk1]": 1669, "[L22]": 1670, "[(CHR2)n]": 1671, "[C18H37]": 1672, "[x]": 1673, "[SO2R4]": 1674, "[64*]": 1675, "[(T)t]": 1676, "[MeOOC]": 1677, "[(CH)n]": 1678, "[C4]": 1679, "[Y1a]": 1680, "[OCH2Ph]": 1681, "[CO2R14]": 1682, "[(R1)n]": 1683, "[CO]": 1684, "[R46]": 1685, "[(R)2]": 1686, "[Me]": 1687, "[N,]": 1688, "[n-C5H11]": 1689, "[Zr]": 1690, "[Et2N]": 1691, "[SO2NMe2,]": 1692, "[R-link-P]": 1693, "[Y1]": 1694, "[RB3]": 1695, "[Het]": 1696, "[CH*]": 1697, "[Re]": 1698, "[3]": 1699, "[R62]": 1700, "[3R]": 1701, "[ZHN]": 1702, "[Sf]": 1703, "[E6]": 1704, "[A,]": 1705, "[CBz]": 1706, "[Xb]": 1707, "[CH2C6H5]": 1708, "[XVIII]": 1709, "[[CH]": 1710, "[STol]": 1711, "[(R8)m]": 1712, "[C8F17]": 1713, "[(CHR5)m]": 1714, "[73*]": 1715, "[EQ]": 1716, "[R66]": 1717, "[Abu-Sar]": 1718, "[(O)k]": 1719, "[(X]": 1720, "[O.]": 1721, "[Y6]": 1722, "[R26]": 1723, "[Z]": 1724, "[COOt-Bu]": 1725, "[(CR23]": 1726, "[(R7)m]": 1727, "s": 1728, "[2+]": 1729, "[SnBu3]": 1730, "[Rb5]": 1731, "[(R19)p]": 1732, "[Na+]": 1733, "[CCH2)y]": 1734, "[Br]": 1735, "[n-C8H17]": 1736, "[Protecting]": 1737, "[V-]": 1738, "[CHR2]": 1739, "[(CH2)n1]": 1740, "[Fc]": 1741, "[(R1)m]": 1742, "[NCF3]": 1743, "[Z11]": 1744, "[(C(R14)R20)n]": 1745, "[AA1]": 1746, "[C2F4]": 1747, "[R4]": 1748, "[Cycl2]": 1749, "[R1b]": 1750, "[F8]": 1751, "[(R1)r]": 1752, "[XR1]": 1753, "[J2]": 1754, "[OCH3.]": 1755, "[Y2b]": 1756, "[Peptide]": 1757, "[F3CN]": 1758, "[L15]": 1759, "[H2]": 1760, "[R10,]": 1761, "[(R)x]": 1762, "[R71]": 1763, "[A;]": 1764, "[Alk]": 1765, "[C+]": 1766, "[Z10]": 1767, "P": 1768, "[L3]": 1769, "%23": 1770, "[OTf]": 1771, "[H2/Pd]": 1772, "[GP]": 1773, "[tBuO2C]": 1774, "[t]": 1775, "[(Rv)r,]": 1776, "[(Q)d]": 1777, "[L5]": 1778, "[4GlcNAcb1]": 1779, "[polymer]": 1780, "[A23]": 1781, "[XI]": 1782, "[3*:0]": 1783, "[F(Cl,]": 1784, "[C6]": 1785, "[2HCl]": 1786, "[1.]": 1787, "[CO2Et]": 1788, "[G11]": 1789, "[G6]": 1790, "[(B)n]": 1791, "[Yn]": 1792, "[B.]": 1793, "[U2]": 1794, "[PEG]": 1795, "[OX]": 1796, "[CO2CH3.]": 1797, "[CH3-]": 1798, "[Qb]": 1799, "[Rc5]": 1800, "[(C)n]": 1801, "[BocHN]": 1802, "[XO]": 1803, "[(R5)t]": 1804, "[(L1)m]": 1805, "[B2]": 1806, "5": 1807, "[G5]": 1808, "[Ey]": 1809, "[C(H)p]": 1810, "[IH-3]": 1811, "[*CH]": 1812, "[1H]": 1813, "[MO]": 1814, "[X12]": 1815, "[TBS]": 1816, "[Rc3]": 1817, "[7*]": 1818, "[B1]": 1819, "[Y+]": 1820, "[Resin]": 1821, "[COX]": 1822, "[NHR7]": 1823, "[CBZ]": 1824, "[Rf1]": 1825, "[XIII]": 1826, "[(CH2)nRf]": 1827, "[Q5]": 1828, "[Pg]": 1829, "[Rw]": 1830, "[RX]": 1831, "[IX]": 1832, "[AR1]": 1833, "[25*]": 1834, "[SCHEME]": 1835, "[I-3]": 1836, "[F2]": 1837, "[L2*]": 1838, "[COOZ2]": 1839, "[CH2N2]": 1840, "[Mn]": 1841, "[Za]": 1842, "[CX3]": 1843, "[Cl+3]": 1844, "[iBu]": 1845, "[NR8R9]": 1846, "[Fm]": 1847, "[Cu+2]": 1848, "[NHR;]": 1849, "[C15H31-n]": 1850, "[R(1)]": 1851, "[OTr]": 1852, "[PG3]": 1853, "[C6H5]": 1854, "[CH2-]": 1855, "[alkyl,]": 1856, "[(R5]": 1857, "-": 1858, "[A10]": 1859, "[Step]": 1860, "[Q1]": 1861, "[A13]": 1862, "[(R2]": 1863, "[57*]": 1864, "[EE]": 1865, "[PhSO2]": 1866, "[AlH2]": 1867, "[A7]": 1868, "[(O)m]": 1869, "[COOH]": 1870, "[APG]": 1871, "[OR1]": 1872, "[A15]": 1873, "[Ni]": 1874, "[Z31]": 1875, "[(Ib)]": 1876, "[Yq]": 1877, "[Al+3]": 1878, "[As]": 1879, "[(Ia)]": 1880, "[OPh]": 1881, "[Rf2]": 1882, "[Br;]": 1883, "[W2]": 1884, "[Rf]": 1885, "[t-butyl]": 1886, "[R1,]": 1887, "[CaH2]": 1888, "[V2]": 1889, "[C8H17(t)]": 1890, "[CR5R6]": 1891, "[D2]": 1892, "[(O)q]": 1893, "[SH+]": 1894, "[(R0)n]": 1895, "[Cp]": 1896, "[R2,]": 1897, "[(A)n]": 1898, "[Tb]": 1899, "[A12]": 1900, "[nPr]": 1901, "[FULL]": 1902, "[R1a]": 1903, "[RbH]": 1904, "[X1a]": 1905, "[SO2NHtBu]": 1906, "~": 1907, "[Br.]": 1908, "[R(3)]": 1909, "[L0]": 1910, "[(IIa)]": 1911, "[6]": 1912, "[(CRR)m]": 1913, "[Mana1]": 1914, "[ORA]": 1915, "[Ra4]": 1916, "[24]": 1917, "[BH2-]": 1918, "[and]": 1919, "[(CH2)b]": 1920, "[Ra2]": 1921, "[I]": 1922, "[halogen]": 1923, "[R36]": 1924, "[R19]": 1925, "[r]": 1926, "[[O]": 1927, "[OR2(2-a)]": 1928, "[R3]": 1929, "[13*]": 1930, "[M1]": 1931, "[Ry1]": 1932, "[ED]": 1933, "[40*]": 1934, "[RO2C]": 1935, "[OAlk]": 1936, "[MeO]": 1937, "[18*]": 1938, "%22": 1939, "[Fe+2]": 1940, "[Nb]": 1941, "[Protein]": 1942, "[LINKER]": 1943, "[CHR4]": 1944, "[Z2]": 1945, "[Ph]": 1946, "[n(H2C)]": 1947, "[PivO]": 1948, "[N]": 1949, "[GO]": 1950, "[(3)]": 1951, "[(CH2CH]": 1952, "[i-Pr]": 1953, "[L12]": 1954, "[2HH]": 1955, "[1R]": 1956, "[(CH2CH)z]": 1957, "7": 1958, "[F;]": 1959, "[9]": 1960, "[BzO]": 1961, "[Scheme]": 1962, "[L21]": 1963, "[13]": 1964, "[(CH2)d]": 1965, "[Rp]": 1966, "[Y;]": 1967, "[(R1)k]": 1968, "%10": 1969, "[CONH-n-C12H25]": 1970, "[X]": 1971, "[TrO]": 1972, "[(CR4R5)n]": 1973, "[4]": 1974, "[(R6)p]": 1975, "[NH2,]": 1976, "[COOH.]": 1977, "[c+]": 1978, "[OEE]": 1979, "[OR,]": 1980, "[Cu-]": 1981, "[(M)n]": 1982, "[17]": 1983, "[Si@@]": 1984, "[Pg1]": 1985, "[RE]": 1986, "[pg]": 1987, "[0.5]": 1988, "[R6a]": 1989, "[CO2R4]": 1990, "[OBz]": 1991, "[Ts]": 1992, "[50*]": 1993, "[T5]": 1994, "[(O)b]": 1995, "[Xm]": 1996, "[Ar2]": 1997, "[Ro]": 1998, "[S-2]": 1999, "[MSG2]": 2000, "[(CH2CH2O)n]": 2001, "[AA2]": 2002, "[A14]": 2003, "[(Al]": 2004, "[ETA]": 2005, "[(R]": 2006, "[(C(R12)H)q]": 2007, "[=]": 2008, "[NR4]": 2009, "[EtO2C]": 2010, "[CR4]": 2011, "[CH-]": 2012} \ No newline at end of file diff --git a/output.json b/output.json new file mode 100644 index 0000000000000000000000000000000000000000..2c9ce8a2ee1cebe0600f50b86b94136a05a57bad --- /dev/null +++ b/output.json @@ -0,0 +1,326 @@ +{ + "image_title": "Selected Examples", + "reactions": [ + { + "reaction_id": "0_1", + "reactants": [ + { + "smiles": "*C([H])=O", + "label": "1" + }, + { + "smiles": "*C([H])=O", + "label": "2" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "61 - 99%" + } + ], + "products": [ + { + "smiles": "*C(=O)C(*)O", + "label": "3" + } + ] + }, + { + "reaction_id": "1_1", + "reactants": [ + { + "smiles": "O=CCCc1ccccc1", + "label": "1a" + }, + { + "smiles": "COc1ccc(C=O)cc1", + "label": "2a" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "99%" + } + ], + "products": [ + { + "smiles": "COc1ccc(C(O)C(=O)CCc2ccccc2)cc1", + "label": "3a" + } + ], + "additional_info": [] + }, + { + "reaction_id": "2_1", + "reactants": [ + { + "smiles": "O=CCCc1ccccc1", + "label": "1b" + }, + { + "smiles": "O=Cc1ccccc1Br", + "label": "2b" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "89%" + } + ], + "products": [ + { + "smiles": "O=C(CCc1ccccc1)C(O)c1ccccc1Br", + "label": "3b" + } + ], + "additional_info": [] + }, + { + "reaction_id": "3_1", + "reactants": [ + { + "smiles": "O=CCCc1ccccc1", + "label": "1c" + }, + { + "smiles": "O=Cc1ccccc1", + "label": "2c" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "84%" + } + ], + "products": [ + { + "smiles": "O=C(CCc1ccccc1)C(O)c1ccccc1", + "label": "3c" + } + ], + "additional_info": [] + }, + { + "reaction_id": "4_1", + "reactants": [ + { + "smiles": "O=CCCc1ccccc1", + "label": "1d" + }, + { + "smiles": "COC(=O)c1ccc(C=O)cc1", + "label": "2d" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "61%" + } + ], + "products": [ + { + "smiles": "COC(=O)c1ccc(C(O)C(=O)CCc2ccccc2)cc1", + "label": "3d" + } + ], + "additional_info": [] + }, + { + "reaction_id": "5_1", + "reactants": [ + { + "smiles": "CC(C)C=O", + "label": "1e" + }, + { + "smiles": "O=Cc1ccccc1", + "label": "2e" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "73%" + } + ], + "products": [ + { + "smiles": "CC(C)C(=O)C(O)c1ccccc1", + "label": "3e" + } + ], + "additional_info": [] + }, + { + "reaction_id": "6_1", + "reactants": [ + { + "smiles": "CC=O", + "label": "1f" + }, + { + "smiles": "O=Cc1ccccc1", + "label": "2f" + } + ], + "condition": [ + { + "role": "reagent", + "text": "5 mol% G31", + "smiles": "*N1=CN2CCCCC2=N1", + "label": "G31" + }, + { + "role": "reagent", + "text": "iPr2NEt (1 equiv)", + "smiles": "CC(C)N(C)C" + }, + { + "role": "solvent", + "text": "CH2Cl2", + "smiles": "ClCCl" + }, + { + "role": "temperature", + "text": "70 \u00b0C" + }, + { + "role": "yield", + "text": "68%" + } + ], + "products": [ + { + "smiles": "CC(=O)C(O)c1ccccc1", + "label": "3f" + } + ], + "additional_info": [] + } + ] +} \ No newline at end of file diff --git a/pix2seq_reaction_full.ckpt b/pix2seq_reaction_full.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..059bb7122c96e5dc0209614bfdd7a412da59285b --- /dev/null +++ b/pix2seq_reaction_full.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0020634f13fb3e1f588bddca97f68fd6483f0cdecd83e2ca31c2434ea4340fe +size 432324497 diff --git a/prompt.txt b/prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..4af06b8e25f776818b0972e33e8aaaf7a053083c --- /dev/null +++ b/prompt.txt @@ -0,0 +1,30 @@ +You are a helpful assistant in identifying chemistry data in an image. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. Your task is to use "get_multi_molecular_text_to_correct" and "get_reaction" function get the tools outputs first. for the get_multi_molecular_text_to_correct output, please Double-check the text in the image and fix the OCR error of the label and text information in the key 'text'(Sometimes the correct number plus letter label will be incorrectly recognized as multiple digits and miss the letter. For example, 3a is incorrectly recognized as 33, 3o is incorrectly recognized as 30,. Please check the image and fix the OCR error according to the image and the associated text), and find the missing text for some molecular smiles. +If there is no any label and coref information provided in the image and output, please then look into the "get_reaction" output to find the condition smiles, the reactant and product template(ususally there are special sumbol such as "*", "1*" to represent R-group in the template). Then first find them in the get_multi_molecular_text_to_correct output, then find other reaction products. And create labels for all the molecules. +You should do: + Rearrange the output according to the image and assign a label to each molecule (use "1", "2", "3", "3a","3b", assign a single number to the reactant/product templete such as "1"/"2".... Then find all the different products depend on the 'smiles' and 'text'(if smiles and text come in one after another such as: {"smiles": "CCC1(c2ccccc2)C(=O)OC(c2ccccc2)=NN1c1ccccc1", "text": ["79% yield", "96% ee"]}, then there's a good chance it's product), use number + a,b,c.. to represent different products such as "2a","2b" (please use from a to z), make sure the single number is assigned before the same number + alphabet, and should start from "a", and the product template uses the same number as different products. Suppose that there is only one such set of corresponding products), please find all the products, don't miss the molecule that miss corresponding text. And also find the text that is missed in the tool output. + And make sure that the 'product template' and 'product' use the same number as the label (For example ['3' and '3a','3b','3da']... or ['4' and '4a','4b','4fa']...And so on.). + Find if there is any smiles in the 'condition' in "get_reaction" output. If not, please recheck the image and find from the 'get_multi_molecular_text_to_correct' output which is neither smiles of reactant, product template, nor smiles of product, They are also condition smiles. Please also find their label in the image if there are (such as B17, A18, B27). Then output it like : 'CC(CC(=O)c1ccccc1)OCCC#N':['B17', 'conditions']. + Please do not change any tool outputs of the 'smiles'. + Please make sure that all SMILES in the output of 'the get_multi_molecular_text_to_correct' tool are in the final output in the four categories: 'reactant template', 'product template', 'condition smiles', 'product'. + + !!!If the molecule in the image already has a label, such as 2a, 3b, 3fa, or 3da, use the label provided in the image. + !!!! important NOTE: due to the characteristics of the tool, if YOU find the labels of the molecule appear likse 3ab,3ac,3ad,3ae,3af (the first English letter (here is a)remains unchanged), our tool cannot output normally, so you should copy the second letter to the first, make sure that the first English letters are different, such as (3bab,3cac,3dad,3eae,3faf). + !!!! change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf. +Your output should be like the json format. An example output is: + + {'*C(=O)NN=CC(F)(F)F': ['1','reactant template'], + 'N#CN':['2','reactant template'],###smiles and text and assigned label for reactants template in the "get_reaction".(Sometimes there are same reactants) + '*C(=O)N1NC(C(F)(F)F)N=C1N': ['3','product template'], ###smiles and text and assigned label for the product template in the "get_reaction". And make sure that the product template and product use the same number. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3a', '6 h, 88%','product'],###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3b, X:CF3, 6 h, 63%, 2q, X:OMe, 4 h, 60%','product'],###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3ca', "8 h, 91%', 'product']###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3cac', '3ac, 8 h, 91%', 'product'],###change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3dad', '3ad, 8 h, 90%', 'product'],###change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf + 'CC(CC(=O)c1ccccc1)OCCC#N':['B17', 'conditions'] ## if there is any smiles in the 'condition' in "get_reaction" output. If not, please find from the 'get_multi_molecular_text_to_correct output' which is neither smiles of reactant, product template, nor smiles of product, They are also conditions smiles. + } + +!!!! important NOTE: due to the characteristics of the tool, if YOU find the labels of the molecule appear likse 3ab,3ac,3ad,3ae,3af (the first English letter (here is a)remains unchanged), our tool cannot output normally, so you should copy the second letter to the first, make sure that the first English letters are different, such as (3bab,3cac,3dad,3eae,3faf). +!!!! important NOTE: change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf. +!!!! important NOTE: Please make sure again that all SMILES in the output of 'the get_multi_molecular_text_to_correct' tool are in the final output, Which means the number of the SMLIES in the output of 'the get_multi_molecular_text_to_correct' tool is the same as in the final output. + +!!!!!!!important: Please check your results again to make sure the 'product template' and 'product' use the same number, avoid like ['2','product template'] and ['3a','product'] appear together. And Make sure that the number of SMILES in you final output is equal to the number of SMILES in "get_multi_molecular_text_to_correct" \ No newline at end of file diff --git a/prompt/prompt.txt b/prompt/prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..4af06b8e25f776818b0972e33e8aaaf7a053083c --- /dev/null +++ b/prompt/prompt.txt @@ -0,0 +1,30 @@ +You are a helpful assistant in identifying chemistry data in an image. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. Your task is to use "get_multi_molecular_text_to_correct" and "get_reaction" function get the tools outputs first. for the get_multi_molecular_text_to_correct output, please Double-check the text in the image and fix the OCR error of the label and text information in the key 'text'(Sometimes the correct number plus letter label will be incorrectly recognized as multiple digits and miss the letter. For example, 3a is incorrectly recognized as 33, 3o is incorrectly recognized as 30,. Please check the image and fix the OCR error according to the image and the associated text), and find the missing text for some molecular smiles. +If there is no any label and coref information provided in the image and output, please then look into the "get_reaction" output to find the condition smiles, the reactant and product template(ususally there are special sumbol such as "*", "1*" to represent R-group in the template). Then first find them in the get_multi_molecular_text_to_correct output, then find other reaction products. And create labels for all the molecules. +You should do: + Rearrange the output according to the image and assign a label to each molecule (use "1", "2", "3", "3a","3b", assign a single number to the reactant/product templete such as "1"/"2".... Then find all the different products depend on the 'smiles' and 'text'(if smiles and text come in one after another such as: {"smiles": "CCC1(c2ccccc2)C(=O)OC(c2ccccc2)=NN1c1ccccc1", "text": ["79% yield", "96% ee"]}, then there's a good chance it's product), use number + a,b,c.. to represent different products such as "2a","2b" (please use from a to z), make sure the single number is assigned before the same number + alphabet, and should start from "a", and the product template uses the same number as different products. Suppose that there is only one such set of corresponding products), please find all the products, don't miss the molecule that miss corresponding text. And also find the text that is missed in the tool output. + And make sure that the 'product template' and 'product' use the same number as the label (For example ['3' and '3a','3b','3da']... or ['4' and '4a','4b','4fa']...And so on.). + Find if there is any smiles in the 'condition' in "get_reaction" output. If not, please recheck the image and find from the 'get_multi_molecular_text_to_correct' output which is neither smiles of reactant, product template, nor smiles of product, They are also condition smiles. Please also find their label in the image if there are (such as B17, A18, B27). Then output it like : 'CC(CC(=O)c1ccccc1)OCCC#N':['B17', 'conditions']. + Please do not change any tool outputs of the 'smiles'. + Please make sure that all SMILES in the output of 'the get_multi_molecular_text_to_correct' tool are in the final output in the four categories: 'reactant template', 'product template', 'condition smiles', 'product'. + + !!!If the molecule in the image already has a label, such as 2a, 3b, 3fa, or 3da, use the label provided in the image. + !!!! important NOTE: due to the characteristics of the tool, if YOU find the labels of the molecule appear likse 3ab,3ac,3ad,3ae,3af (the first English letter (here is a)remains unchanged), our tool cannot output normally, so you should copy the second letter to the first, make sure that the first English letters are different, such as (3bab,3cac,3dad,3eae,3faf). + !!!! change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf. +Your output should be like the json format. An example output is: + + {'*C(=O)NN=CC(F)(F)F': ['1','reactant template'], + 'N#CN':['2','reactant template'],###smiles and text and assigned label for reactants template in the "get_reaction".(Sometimes there are same reactants) + '*C(=O)N1NC(C(F)(F)F)N=C1N': ['3','product template'], ###smiles and text and assigned label for the product template in the "get_reaction". And make sure that the product template and product use the same number. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3a', '6 h, 88%','product'],###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3b, X:CF3, 6 h, 63%, 2q, X:OMe, 4 h, 60%','product'],###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3ca', "8 h, 91%', 'product']###smiles and text and assigned label for different products. Note that please also identify the missing text of the tool. + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3cac', '3ac, 8 h, 91%', 'product'],###change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf + 'Cc1cc(C)cc(C(=O)N2NC(C(F)(F)F)N=C2N)c1': ['3dad', '3ad, 8 h, 90%', 'product'],###change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf + 'CC(CC(=O)c1ccccc1)OCCC#N':['B17', 'conditions'] ## if there is any smiles in the 'condition' in "get_reaction" output. If not, please find from the 'get_multi_molecular_text_to_correct output' which is neither smiles of reactant, product template, nor smiles of product, They are also conditions smiles. + } + +!!!! important NOTE: due to the characteristics of the tool, if YOU find the labels of the molecule appear likse 3ab,3ac,3ad,3ae,3af (the first English letter (here is a)remains unchanged), our tool cannot output normally, so you should copy the second letter to the first, make sure that the first English letters are different, such as (3bab,3cac,3dad,3eae,3faf). +!!!! important NOTE: change label like 3ab,3ac,3ad,3ae,3af to 3bab,3cac,3dad,3eae,3faf. +!!!! important NOTE: Please make sure again that all SMILES in the output of 'the get_multi_molecular_text_to_correct' tool are in the final output, Which means the number of the SMLIES in the output of 'the get_multi_molecular_text_to_correct' tool is the same as in the final output. + +!!!!!!!important: Please check your results again to make sure the 'product template' and 'product' use the same number, avoid like ['2','product template'] and ['3a','product'] appear together. And Make sure that the number of SMILES in you final output is equal to the number of SMILES in "get_multi_molecular_text_to_correct" \ No newline at end of file diff --git a/prompt/prompt_final.txt b/prompt/prompt_final.txt new file mode 100644 index 0000000000000000000000000000000000000000..ae2924c10de5e5ae9fa2bca593c7ee5ec60981a1 --- /dev/null +++ b/prompt/prompt_final.txt @@ -0,0 +1,261 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there is a chemistry reaction diagram with one step reaction tempelete and a image-based table consisting of product molecular images with detailed R-group and different conditions. Use the "get_reaction" function provided to get the reaction data of the reaction diagram and use the "process_reaction_image_with_multiple_products" function provided to get SMILES strings of every detailed reaction in reaction diagram and the table. Then based on the tool results, your task is to recheck them and the image with the table, match the detailed product and condition in the table with the corresponding detailed reaction, and re-label each reaction and product according to the order in the picture. Also please identifying the condition role in"reagents","solvents","yield","time(such as "1 h", "24 h")","temperature (Note "rt" is temperature too)",if there is no then use "None, and show additional information displayed in the table in "additional_info" section. Additionally, assign a reaction number to each modified reaction and output the updated results. + +Requirements: + Use the tools to get the SMILES of reaction template, and the SMILES of the detailed reactions that the the detailed products are given in the table. + First identify the condition roles in the original conditions in the reaction template. Then add the conditions above the products in table, identifying their condition roles too, and add additional imformation. And match these conditions that in the table to different reactions based on the reaction SMILES. And for simple chemistry texts in reagent and solvent, please conver them into SMILES as well based on your knowledge. + if there is molecule and it's label in the condition. please find it's SMILES in the "get_reaction" tool output and combine the SMILES and it's label with the corresponding text. If the tool does not output this condition SMILES, please recheck the image yourself and try to convert the molecule to SMILES based on your chemical knowledge. + Please re-label each reaction and product according to the order in the image.(use the single number for reactant and product tamplate(1,2,3), and single number + English alphabet(1a,1b,2a,2b,3a,3b) for diffirent reactant and product) + Generate a complete reaction list +An example is: +First out put the original reaction with (with coref (label) when the label is provided such as "1a","2a","3b", or else use "label":"None"). Then for each row of the table, generate the corresponding reaction by replacing the molecular SMILES and the conditions. +The result should look like this json format: +{ +"reactions":[ +{ + # reaction tempelete + "reaction_id": 0_1 + "reactants": [ + { + "smiles": " reactant template smiles", + "label": "1" + }, + { + "smiles": "reactant template smiles", + "label": "2" + }, + ] + .... + “condition":[ + + {###The Molecule and it's label in the condition. + "role": "reagent", + "text": "5 mol% G30" + "smiles": "reagent smiles of G30", + "label":"G30" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G31", + "label":"G31" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G33", + "label":"G33" + }, + ###### + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)", + "smiles": "reagent smiles", + }, + { + "role": "solvent", + "text": "toluene", + "smiles": "solvent smiles", + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### original yield if provided + "role": "yield", + "text": "38%" + } + ] + "products": [ + { + "smiles": "product template smiles", + "label": "3" + }, + ... + ] +}, +{ + "reaction_id": 1_1 + "reactants": { + "smiles": "reactant smiles", + "label": "1a" + }, + { + "smiles": "reactant smiles", + "label": "2a" + }, + .... + “condition":[#Note: identify the condition roles in the original conditions and table both + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)" + }, + { + "role": "solvent", + "text": "toluene" + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### new yield + "role": "yield", + "text": "89%" + } + ,] + "products": [ + { + "smiles": "product smiles", + "label": "3a" + }, + ] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 2_1 + "reactants": [...], + .... + “condition":[{new condition based on the table}, {...}, ...] + "products": [...] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 3_1, ... +}, ... +] +}. + + +After you check the image, if this is a reaction image that contains only a text-based table and does not involve any R-group replacement, simply call the "get_full_reaction" agent tool and organize the condition on each row of the table. Please carefully confirm how many rows there are in this table, and then output the corresponding number of reactions according to this. The output format remains unchanged. !!!important: Make sure that no matter how many rows there are in the table, you should complete the output reactions for every rows. Sometimes tables are 20 rows, then output 20 reactions. + + +After you check the image, if this is a reaction image does not contain any tables or sets of product variants, then just call the "get_full_reaction" agent tool and the output format should be like: +{ + "reactions": [ + { + "reaction_id": "0_1"(The first step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)NN=CC(F)(F)F", + "label": "None" (or a number if there is) + }, + { + "smiles": "N#CN", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "K2CO3", + "smiles": ...... + }, + { + "role": "solvent", + "text": "THF", + "smiles": ...... + } + ], + "products": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "0_2"(The second step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "NBS" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DCM" + "smiles": ...... + }, + { + "role": "temperature", + "text": "reflux" + } + ], + "products": [ + { + "smiles": "*C(=O)n1nc(C(F)(F)F)nc1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "1_1"(The first step of the second reaction), + "reactants": [ + { + "smiles": "*CC12=C3C1=CC=C(C(=O)NN=CC(F)(F)F)C=23*", + "label": "None" (or a number if there is) + }, + { + "smiles": "N=C=N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "CuI" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DMF" + "smiles": ...... + }, + { + "role": "temperature", + "text": "150 \u00b0C" + }, + { + "role": "additional_info", + "text": "X = Cl, Br, I" + } + ], + "products": [ + { + "smiles": "*C.O=C1C2=CCC=CC=2NC2=NC(C(F)(F)F)=NN12", + "label": "None" (or a number if there is) + } + ] + } + ] +} + + +After you check the image, if this is a image does not contain any reactions and just have one on multiple separated molecular sub-images, then just call the "get_multi_molecular" agent tool and the output format should be like: + { + "molecules": [ + { + "smiles": "....", + "label": "....", + "bbox": + }, + { + "smiles": "....", + "label": "....", + "bbox": + }, + ] +} \ No newline at end of file diff --git a/prompt/prompt_final_simple_version.txt b/prompt/prompt_final_simple_version.txt new file mode 100644 index 0000000000000000000000000000000000000000..e898ef88b6c92eb01ebde2b04e2ca0717b8858b1 --- /dev/null +++ b/prompt/prompt_final_simple_version.txt @@ -0,0 +1,263 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there is a chemistry reaction diagram with one step reaction tempelete and a image-based table consisting of product molecular images with detailed R-group and different conditions. Use the "process_reaction_image" agent function provided to get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list. Then based on the tool results, your task is to recheck them and the image with the table, match the detailed product and condition in the table with the corresponding detailed reaction, and re-label each reaction and product according to the order in the picture. Also please identifying the condition role in"reagents","solvents","yield","time(such as "1 h", "24 h")","temperature (Note "rt" is temperature too)",if there is no then use "None, and show additional information displayed in the table in "additional_info" section. Additionally, assign a reaction number to each modified reaction and output the updated results. + +Requirements: + Use the tools to get the SMILES of reaction template, and the SMILES of the detailed reactions that the the detailed products are given in the table. + If any molecule or label lacks a SMILES in the tool output, please consult the “Library of known molecules” before proceeding. + First identify the condition roles in the original conditions in the reaction template. Then add the conditions above the products in table, identifying their condition roles too, and add additional imformation. And match these conditions that in the table to different reactions based on the reaction SMILES. And for simple chemistry texts in reagent and solvent, please conver them into SMILES as well based on your knowledge. + if there is molecule and it's label in the condition. please find it's SMILES in the "process_reaction_image" tool output and combine the SMILES and it's label with the corresponding text. If still cannot find, recheck the image yourself and try to convert the molecule to SMILES based on your chemical knowledge. If there is only a label in the condition, please first check the 'Library of known molecules' to find the corresponding SMILES, if can't find ,then no need to output SMILES. + Please re-label each reaction and product according to the order in the image.(use the single number for reactant and product tamplate(1,2,3), and single number + English alphabet(1a,1b,2a,2b,3a,3b) for diffirent reactant and product) + Generate a complete reaction list +An example is: +First out put the original reaction with (with coref (label) when the label is provided such as "1a","2a","3b", or else use "label":"None"). Then for each row of the table, generate the corresponding reaction by replacing the molecular SMILES and the conditions. +The result should look like this json format: +{ +"image_title":"Table 1 .../Figure 2 .... OR None", +"reactions":[ +{ + # reaction tempelete + "reaction_id": 0_1 + "reactants": [ + { + "smiles": " reactant template smiles", + "label": "1" + }, + { + "smiles": "reactant template smiles", + "label": "2" + }, + ] + .... + “condition":[ + + {###The Molecule and it's label in the condition. + "role": "reagent", + "text": "5 mol% G30" + "smiles": "reagent smiles of G30", + "label":"G30" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G31", + "label":"G31" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G33", + "label":"G33" + }, + ###### + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)", + "smiles": "reagent smiles", + }, + { + "role": "solvent", + "text": "toluene", + "smiles": "solvent smiles", + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### original yield if provided + "role": "yield", + "text": "38%" + } + ] + "products": [ + { + "smiles": "product template smiles", + "label": "3" + }, + ... + ] +}, +{ + "reaction_id": 1_1 + "reactants": { + "smiles": "reactant smiles", + "label": "1a" + }, + { + "smiles": "reactant smiles", + "label": "2a" + }, + .... + “condition":[#Note: identify the condition roles in the original conditions and table both + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)" + }, + { + "role": "solvent", + "text": "toluene" + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### new yield + "role": "yield", + "text": "89%" + } + ,] + "products": [ + { + "smiles": "product smiles", + "label": "3a" + }, + ] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 2_1 + "reactants": [...], + .... + “condition":[{new condition based on the table}, {...}, ...] + "products": [...] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 3_1, ... +}, ... +] +}. + + +After you check the image, if this is a reaction image that contains only a text-based table and does not involve any R-group replacement, simply call the "get_full_reaction" agent tool and organize the condition on each row of the table. Please carefully confirm how many rows there are in this table, and then output the corresponding number of reactions according to this. The output format remains unchanged. !!!important: Make sure that no matter how many rows there are in the table, you should complete the output reactions for every rows. Sometimes tables are 20 rows, then output 20 reactions. + + +After you check the image, if this is a reaction image does not contain any tables or sets of product variants, then just call the "get_full_reaction" agent tool and the output format should be like: +{ + "reactions": [ + { + "reaction_id": "0_1"(The first step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)NN=CC(F)(F)F", + "label": "None" (or a number if there is) + }, + { + "smiles": "N#CN", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "K2CO3", + "smiles": ...... + }, + { + "role": "solvent", + "text": "THF", + "smiles": ...... + } + ], + "products": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "0_2"(The second step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "NBS" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DCM" + "smiles": ...... + }, + { + "role": "temperature", + "text": "reflux" + } + ], + "products": [ + { + "smiles": "*C(=O)n1nc(C(F)(F)F)nc1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "1_1"(The first step of the second reaction), + "reactants": [ + { + "smiles": "*CC12=C3C1=CC=C(C(=O)NN=CC(F)(F)F)C=23*", + "label": "None" (or a number if there is) + }, + { + "smiles": "N=C=N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "CuI" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DMF" + "smiles": ...... + }, + { + "role": "temperature", + "text": "150 \u00b0C" + }, + { + "role": "additional_info", + "text": "X = Cl, Br, I" + } + ], + "products": [ + { + "smiles": "*C.O=C1C2=CCC=CC=2NC2=NC(C(F)(F)F)=NN12", + "label": "None" (or a number if there is) + } + ] + } + ] +} + + +After you check the image, if this is a image does not contain any reactions and just have one on multiple separated molecular sub-images, then just call the "get_multi_molecular" agent tool and the output format should be like: + { + "molecules": [ + { + "smiles": "....", + "label": "....", + "bbox": + }, + { + "smiles": "....", + "label": "....", + "bbox": + }, + ] +} \ No newline at end of file diff --git a/prompt/prompt_getmolecular.txt b/prompt/prompt_getmolecular.txt new file mode 100644 index 0000000000000000000000000000000000000000..97cfb9eebf8d845ffde495f4488b5682d493ebdb --- /dev/null +++ b/prompt/prompt_getmolecular.txt @@ -0,0 +1,22 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. +Your task is to: + use "get_multi_molecular_text_to_correct_withatoms" function get the tools outputs first. + Check the image, only find and extract the text based R-group equation in the image such as Ar = ..., R = ... without any reasoning. Please just extract the text based equation from the image and don't do any further image reasoning. and output 'extracted text based R-group equation (without any reasoning)' + Then replace them in the 'symbol' key in the "get_multi_molecular_text_to_correct_withatoms" output. For example, if there is a Ar2 = 3,5-(CF3)2CH3 and in the: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Ar2]", "N", "[Ts]", "C", "O"], change "[Ar2]" to "[(CF3)2CH3]" output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[(CF3)2CH3]", "N", "[Ts]", "C", "O"]. Please output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".. + *** Another more complex nesting example is if there is a Ar = '2-ClC6H4' and in a "symbols": ["[C@@]", "[Et]", "C", "C", "[C@H]", "[SO2Ar]", "N", "C", "O"], also change the composite symbols "[SO2Ar]" that inlude "Ar" to "[SO2ClC6H4]". + Please make sure you output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".(exclude any numbers and symbols that precede the r-group) + Finally output json format and please leave all other parts unchanged. + +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. Don't do any reasoning, just extract. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. + +An output example is: +{ + "extracted text based R-group equation (without any reasoning and originally form in X = ABCD extracted in the image)" : [ ... ], + "bboxes": [ ... + + ], + "corefs": [ ... + + ] + } + diff --git a/prompt/prompt_getmolecular_correctR.txt b/prompt/prompt_getmolecular_correctR.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1e06c29073d02e71a13418d30470295527b0820 --- /dev/null +++ b/prompt/prompt_getmolecular_correctR.txt @@ -0,0 +1,18 @@ +You are a helpful chemical assistant in identifying chemistry data in an image and check for and fix obvious R-group OCR errors. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. However, you only need to focus on molecules with ambiguous R-groups (R1,R2,R3) in the reaction template. Sometimes R1,R2,R3 will be incorrectly identified by the tool, which will cause the subsequent R-group replacement to fail +Your task is to: + use "get_multi_molecular_text_to_correct_withatoms" function get the tools outputs first. + First find and match molecules with ambiguous R-groups (R1,R2,R3) and their outputs, then carefully compare with the original image to find those OCR errors. (Classic error: R2,R3 misidentifying each other. R1 is incorrectly identified as Rf or Pa or R.) + Then replace them in the 'symbol' key in the "get_multi_molecular_text_to_correct_withatoms" output. For example, if there is a R1 is misidentified Rf: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Rf]", "N", "[Ts]", "C", "O"], change "[Rf]" to "[R1]" , output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[R1]", "N", "[Ts]", "C", "O"]. + Finally output json format and please leave all other parts unchanged. !!!Do not arbitrarily change the order of the atomic set (if original is ['C', '[Rf]', 'O', 'C', '[R2]', '[R4]', '[R3]'], after revise, the output should be ['C', '[R1]', 'O', 'C', '[R2]', '[R4]', '[R3]'], not ['C', '[R1]', 'O', 'C', '[R2]', '[R3]', '[R4]']). + + +An output example is: +{ + "bboxes": [ ... + + ], + "corefs": [ ... + + ] + } + diff --git a/prompt/prompt_getreaction.txt b/prompt/prompt_getreaction.txt new file mode 100644 index 0000000000000000000000000000000000000000..97bf0614ede35a3b9745698143a1206dda53c3e9 --- /dev/null +++ b/prompt/prompt_getreaction.txt @@ -0,0 +1,44 @@ +You are a helpful chemical assistant in identifying chemical reaction in an image. In this reaction image, there is a chemical reaction scheme with some text-based R-group equation. +Your task is to: + use "get_reaction" function get the tools outputs first. + Check the image, only find and extract the text based R-group equation in the image such as Ar = ..., R = ... without any reasoning. Please just extract the text based equation from the image and don't do any further image reasoning. and output 'extracted text based R-group equation (without any reasoning)' + Then replace them in the 'symbol' key in the "get_reaction" output. For example, if there is a Ar2 = 3,5-(CF3)2CH3 and in the: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Ar2]", "N", "[Ts]", "C", "O"], change "[Ar2]" to "[(CF3)2CH3]" output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[(CF3)2CH3]", "N", "[Ts]", "C", "O"]. Please output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".. + Please make sure you output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".(exclude any numbers and symbols that precede the r-group) + Finally output json format and please leave all other parts unchanged. + + +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. Don't do any reasoning, just extract. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. + +One output example is: ###if there +{ 'extracted text based R-group equation without any reasoning and originally form in X = ABCD in the image':[ ... ] + 'reactants': [{'smiles': '*C(*)=O', + 'bbox': (0.277, 0.02, 0.337, 0.082), + 'symbols': ['C', '[Ar]', '[R]', 'O']}, + {'smiles': '*C1ON1S(=O)(=O)c1ccc(C)cc1', + 'bbox': (0.387, 0.03, 0.476, 0.071), + 'symbols': ['C', '[(CF3)2CH3]', 'O', 'N', '[Ts]']}], #### after repalcement + 'conditions': [{'bbox': (0.5, 0.009, 0.634, 0.059), + 'text': ['10 mol%', 'or Bz7', '10 mol% CszCO3', 'PhMe, rt', 'B17 ']}, + {'bbox': (0.534, 0.067, 0.598, 0.083), 'text': ['38', '78%']}], + 'products': [{'smiles': '*C1(*)O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O', + 'bbox': (0.652, 0.005, 0.756, 0.114), + 'symbols': ['N', + '[Ts]', + 'C', + 'O', + '[C@]', + '[R]', + '[Ar]', + 'O', + '[C@@H]', + 'C', + 'C', + 'C', + 'C', + 'C', + 'C', + 'Cl']}]} + +Another output example is: +{'extracted text based R-group equation without any reasoning and originally form in X = ABCD in the image': 'There is no any text based R-group equations originally form in X =ABCD in the image', 'reactants': [{'smiles': '*C(=O)N1NC(C(F)(F)F)N=C1N', 'bbox': [0.318, 0.061, 0.425, 0.241], 'symbols': ['C', '[F3C]', 'N', 'N', 'C', '[R]', 'O', 'C', 'N', 'N']}], 'conditions': [{'bbox': [0.479, 0.115, 0.52, 0.149], 'text': ['NBS']}, {'bbox': [0.458, 0.159, 0.54, 0.195], 'text': ['DCM, reflux']}], 'products': [{'smiles': '*C(=O)n1nc(C(F)(F)F)nc1N', 'bbox': [0.574, 0.062, 0.68, 0.24], 'symbols': ['C', '[R]', 'N', 'N', 'C', '[F3C]', 'N', 'C', 'N', 'O']}]} +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. \ No newline at end of file diff --git a/prompt/prompt_getreaction_correctR.txt b/prompt/prompt_getreaction_correctR.txt new file mode 100644 index 0000000000000000000000000000000000000000..df0673a9a49fb3db8d30b4eedf45e91bab8258dc --- /dev/null +++ b/prompt/prompt_getreaction_correctR.txt @@ -0,0 +1,26 @@ +You are a helpful chemical assistant in identifying chemical reaction in an image. In this reaction image, there is a chemical reaction scheme with some text-based R-group equation. +Your task is to: + use "get_reaction" function get the tools outputs first. + First find and match molecules with ambiguous R-groups (R1,R2,R3) and their outputs, then carefully compare with the original image to find those OCR errors. (Classic error: R2,R3 misidentifying each other. R1 is incorrectly identified as Rf or Pa or R.) + Then replace them in the 'symbol' key in the "get_reaction" output. For example, if there is a R1 is misidentified Rf: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Rf]", "N", "[Ts]", "C", "O"], change "[Rf]" to "[R1]" , output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[R1]", "N", "[Ts]", "C", "O"]. + Finally output json format and please leave all other parts unchanged. !!!Do not arbitrarily change the order of the atomic set including R-groups (if original is ['C', '[Rf]', 'O', 'C', '[R2]', '[R4]', '[R3]'], after revise, the output should be ['C', '[R1]', 'O', 'C', '[R2]', '[R4]', '[R3]'], not ['C', '[R1]', 'O', 'C', '[R2]', '[R3]', '[R4]']). + !!!! Do not change the number of atoms in the atom set!!!! Do not change anything else!!!! + + +One output example is: ###if there +{ + 'reactants': [{'smiles': '*C([Rf])=O', + 'bbox': (0.277, 0.02, 0.337, 0.082), + 'symbols': ['C', '[R]', '[R1]', 'O']},#### outputcorrect R-group,originally is ['C', '[R]', '[Rf]', 'O'] + {'smiles': '[Fr]C1ON1S(=O)(=O)c1ccc(C)cc1', + 'bbox': (0.387, 0.03, 0.476, 0.071), + 'symbols': ['C', '[R1]', 'O', 'N', '[Ts]']}], #### correct R-group, originally is ['C', '[Fr]', 'O', 'N', '[Ts]] + 'conditions': [{'bbox': (0.5, 0.009, 0.634, 0.059), + 'text': ['10 mol%', 'or Bz7', '10 mol% CszCO3', 'PhMe, rt', 'B17 ']}, + {'bbox': (0.534, 0.067, 0.598, 0.083), 'text': ['38', '78%']}], + 'products': [{'smiles': '[2*]C([3*])([4*])C(=O)ON1C(=O)c2ccccc2C1=O', + 'bbox': (0.652, 0.005, 0.756, 0.114), + 'symbols': ['C', 'C', 'C', 'C', 'C', 'N', 'O', 'C', 'C', '[R4]', '[R2]', '[R3]', 'O', 'C', 'O', 'C', 'C', 'O']}], #### output original R-group when there is no OCR error. Do not arbitrarily change the order of the atomic set including R-groups (keep the sort '[R4]', '[R2]', '[R3]'). + + + } \ No newline at end of file diff --git a/prompt_final.txt b/prompt_final.txt new file mode 100644 index 0000000000000000000000000000000000000000..ae2924c10de5e5ae9fa2bca593c7ee5ec60981a1 --- /dev/null +++ b/prompt_final.txt @@ -0,0 +1,261 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there is a chemistry reaction diagram with one step reaction tempelete and a image-based table consisting of product molecular images with detailed R-group and different conditions. Use the "get_reaction" function provided to get the reaction data of the reaction diagram and use the "process_reaction_image_with_multiple_products" function provided to get SMILES strings of every detailed reaction in reaction diagram and the table. Then based on the tool results, your task is to recheck them and the image with the table, match the detailed product and condition in the table with the corresponding detailed reaction, and re-label each reaction and product according to the order in the picture. Also please identifying the condition role in"reagents","solvents","yield","time(such as "1 h", "24 h")","temperature (Note "rt" is temperature too)",if there is no then use "None, and show additional information displayed in the table in "additional_info" section. Additionally, assign a reaction number to each modified reaction and output the updated results. + +Requirements: + Use the tools to get the SMILES of reaction template, and the SMILES of the detailed reactions that the the detailed products are given in the table. + First identify the condition roles in the original conditions in the reaction template. Then add the conditions above the products in table, identifying their condition roles too, and add additional imformation. And match these conditions that in the table to different reactions based on the reaction SMILES. And for simple chemistry texts in reagent and solvent, please conver them into SMILES as well based on your knowledge. + if there is molecule and it's label in the condition. please find it's SMILES in the "get_reaction" tool output and combine the SMILES and it's label with the corresponding text. If the tool does not output this condition SMILES, please recheck the image yourself and try to convert the molecule to SMILES based on your chemical knowledge. + Please re-label each reaction and product according to the order in the image.(use the single number for reactant and product tamplate(1,2,3), and single number + English alphabet(1a,1b,2a,2b,3a,3b) for diffirent reactant and product) + Generate a complete reaction list +An example is: +First out put the original reaction with (with coref (label) when the label is provided such as "1a","2a","3b", or else use "label":"None"). Then for each row of the table, generate the corresponding reaction by replacing the molecular SMILES and the conditions. +The result should look like this json format: +{ +"reactions":[ +{ + # reaction tempelete + "reaction_id": 0_1 + "reactants": [ + { + "smiles": " reactant template smiles", + "label": "1" + }, + { + "smiles": "reactant template smiles", + "label": "2" + }, + ] + .... + “condition":[ + + {###The Molecule and it's label in the condition. + "role": "reagent", + "text": "5 mol% G30" + "smiles": "reagent smiles of G30", + "label":"G30" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G31", + "label":"G31" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G33", + "label":"G33" + }, + ###### + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)", + "smiles": "reagent smiles", + }, + { + "role": "solvent", + "text": "toluene", + "smiles": "solvent smiles", + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### original yield if provided + "role": "yield", + "text": "38%" + } + ] + "products": [ + { + "smiles": "product template smiles", + "label": "3" + }, + ... + ] +}, +{ + "reaction_id": 1_1 + "reactants": { + "smiles": "reactant smiles", + "label": "1a" + }, + { + "smiles": "reactant smiles", + "label": "2a" + }, + .... + “condition":[#Note: identify the condition roles in the original conditions and table both + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)" + }, + { + "role": "solvent", + "text": "toluene" + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### new yield + "role": "yield", + "text": "89%" + } + ,] + "products": [ + { + "smiles": "product smiles", + "label": "3a" + }, + ] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 2_1 + "reactants": [...], + .... + “condition":[{new condition based on the table}, {...}, ...] + "products": [...] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 3_1, ... +}, ... +] +}. + + +After you check the image, if this is a reaction image that contains only a text-based table and does not involve any R-group replacement, simply call the "get_full_reaction" agent tool and organize the condition on each row of the table. Please carefully confirm how many rows there are in this table, and then output the corresponding number of reactions according to this. The output format remains unchanged. !!!important: Make sure that no matter how many rows there are in the table, you should complete the output reactions for every rows. Sometimes tables are 20 rows, then output 20 reactions. + + +After you check the image, if this is a reaction image does not contain any tables or sets of product variants, then just call the "get_full_reaction" agent tool and the output format should be like: +{ + "reactions": [ + { + "reaction_id": "0_1"(The first step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)NN=CC(F)(F)F", + "label": "None" (or a number if there is) + }, + { + "smiles": "N#CN", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "K2CO3", + "smiles": ...... + }, + { + "role": "solvent", + "text": "THF", + "smiles": ...... + } + ], + "products": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "0_2"(The second step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "NBS" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DCM" + "smiles": ...... + }, + { + "role": "temperature", + "text": "reflux" + } + ], + "products": [ + { + "smiles": "*C(=O)n1nc(C(F)(F)F)nc1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "1_1"(The first step of the second reaction), + "reactants": [ + { + "smiles": "*CC12=C3C1=CC=C(C(=O)NN=CC(F)(F)F)C=23*", + "label": "None" (or a number if there is) + }, + { + "smiles": "N=C=N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "CuI" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DMF" + "smiles": ...... + }, + { + "role": "temperature", + "text": "150 \u00b0C" + }, + { + "role": "additional_info", + "text": "X = Cl, Br, I" + } + ], + "products": [ + { + "smiles": "*C.O=C1C2=CCC=CC=2NC2=NC(C(F)(F)F)=NN12", + "label": "None" (or a number if there is) + } + ] + } + ] +} + + +After you check the image, if this is a image does not contain any reactions and just have one on multiple separated molecular sub-images, then just call the "get_multi_molecular" agent tool and the output format should be like: + { + "molecules": [ + { + "smiles": "....", + "label": "....", + "bbox": + }, + { + "smiles": "....", + "label": "....", + "bbox": + }, + ] +} \ No newline at end of file diff --git a/prompt_final_simple_version.txt b/prompt_final_simple_version.txt new file mode 100644 index 0000000000000000000000000000000000000000..e898ef88b6c92eb01ebde2b04e2ca0717b8858b1 --- /dev/null +++ b/prompt_final_simple_version.txt @@ -0,0 +1,263 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there is a chemistry reaction diagram with one step reaction tempelete and a image-based table consisting of product molecular images with detailed R-group and different conditions. Use the "process_reaction_image" agent function provided to get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list. Then based on the tool results, your task is to recheck them and the image with the table, match the detailed product and condition in the table with the corresponding detailed reaction, and re-label each reaction and product according to the order in the picture. Also please identifying the condition role in"reagents","solvents","yield","time(such as "1 h", "24 h")","temperature (Note "rt" is temperature too)",if there is no then use "None, and show additional information displayed in the table in "additional_info" section. Additionally, assign a reaction number to each modified reaction and output the updated results. + +Requirements: + Use the tools to get the SMILES of reaction template, and the SMILES of the detailed reactions that the the detailed products are given in the table. + If any molecule or label lacks a SMILES in the tool output, please consult the “Library of known molecules” before proceeding. + First identify the condition roles in the original conditions in the reaction template. Then add the conditions above the products in table, identifying their condition roles too, and add additional imformation. And match these conditions that in the table to different reactions based on the reaction SMILES. And for simple chemistry texts in reagent and solvent, please conver them into SMILES as well based on your knowledge. + if there is molecule and it's label in the condition. please find it's SMILES in the "process_reaction_image" tool output and combine the SMILES and it's label with the corresponding text. If still cannot find, recheck the image yourself and try to convert the molecule to SMILES based on your chemical knowledge. If there is only a label in the condition, please first check the 'Library of known molecules' to find the corresponding SMILES, if can't find ,then no need to output SMILES. + Please re-label each reaction and product according to the order in the image.(use the single number for reactant and product tamplate(1,2,3), and single number + English alphabet(1a,1b,2a,2b,3a,3b) for diffirent reactant and product) + Generate a complete reaction list +An example is: +First out put the original reaction with (with coref (label) when the label is provided such as "1a","2a","3b", or else use "label":"None"). Then for each row of the table, generate the corresponding reaction by replacing the molecular SMILES and the conditions. +The result should look like this json format: +{ +"image_title":"Table 1 .../Figure 2 .... OR None", +"reactions":[ +{ + # reaction tempelete + "reaction_id": 0_1 + "reactants": [ + { + "smiles": " reactant template smiles", + "label": "1" + }, + { + "smiles": "reactant template smiles", + "label": "2" + }, + ] + .... + “condition":[ + + {###The Molecule and it's label in the condition. + "role": "reagent", + "text": "5 mol% G30" + "smiles": "reagent smiles of G30", + "label":"G30" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G31", + "label":"G31" + }, + {###If there is a "or" in the condition. + "role": "reagent", + "text": "5 mol% G31 or G33" + "smiles": "reagent smiles of G33", + "label":"G33" + }, + ###### + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)", + "smiles": "reagent smiles", + }, + { + "role": "solvent", + "text": "toluene", + "smiles": "solvent smiles", + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### original yield if provided + "role": "yield", + "text": "38%" + } + ] + "products": [ + { + "smiles": "product template smiles", + "label": "3" + }, + ... + ] +}, +{ + "reaction_id": 1_1 + "reactants": { + "smiles": "reactant smiles", + "label": "1a" + }, + { + "smiles": "reactant smiles", + "label": "2a" + }, + .... + “condition":[#Note: identify the condition roles in the original conditions and table both + { + "role": "reagent", + "text": "B-chlorocatecholborane (1.4 equiv)" + }, + { + "role": "solvent", + "text": "toluene" + }, + { + "role": "time", + "text": "24 h" + }, + { + "role": "temperature", + "text": "100 °C" + }, + { ### new yield + "role": "yield", + "text": "89%" + } + ,] + "products": [ + { + "smiles": "product smiles", + "label": "3a" + }, + ] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 2_1 + "reactants": [...], + .... + “condition":[{new condition based on the table}, {...}, ...] + "products": [...] + “additional_info" :[{...},{...}] +}, +{ + "reaction_id": 3_1, ... +}, ... +] +}. + + +After you check the image, if this is a reaction image that contains only a text-based table and does not involve any R-group replacement, simply call the "get_full_reaction" agent tool and organize the condition on each row of the table. Please carefully confirm how many rows there are in this table, and then output the corresponding number of reactions according to this. The output format remains unchanged. !!!important: Make sure that no matter how many rows there are in the table, you should complete the output reactions for every rows. Sometimes tables are 20 rows, then output 20 reactions. + + +After you check the image, if this is a reaction image does not contain any tables or sets of product variants, then just call the "get_full_reaction" agent tool and the output format should be like: +{ + "reactions": [ + { + "reaction_id": "0_1"(The first step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)NN=CC(F)(F)F", + "label": "None" (or a number if there is) + }, + { + "smiles": "N#CN", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "K2CO3", + "smiles": ...... + }, + { + "role": "solvent", + "text": "THF", + "smiles": ...... + } + ], + "products": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "0_2"(The second step of the first reaction), + "reactants": [ + { + "smiles": "*C(=O)N1NC(C(F)(F)F)N=C1N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "NBS" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DCM" + "smiles": ...... + }, + { + "role": "temperature", + "text": "reflux" + } + ], + "products": [ + { + "smiles": "*C(=O)n1nc(C(F)(F)F)nc1N", + "label": "None" (or a number if there is) + } + ] + }, + { + "reaction_id": "1_1"(The first step of the second reaction), + "reactants": [ + { + "smiles": "*CC12=C3C1=CC=C(C(=O)NN=CC(F)(F)F)C=23*", + "label": "None" (or a number if there is) + }, + { + "smiles": "N=C=N", + "label": "None" (or a number if there is) + } + ], + "conditions": [ + { + "role": "reagent", + "text": "CuI" + "smiles": ...... + }, + { + "role": "solvent", + "text": "DMF" + "smiles": ...... + }, + { + "role": "temperature", + "text": "150 \u00b0C" + }, + { + "role": "additional_info", + "text": "X = Cl, Br, I" + } + ], + "products": [ + { + "smiles": "*C.O=C1C2=CCC=CC=2NC2=NC(C(F)(F)F)=NN12", + "label": "None" (or a number if there is) + } + ] + } + ] +} + + +After you check the image, if this is a image does not contain any reactions and just have one on multiple separated molecular sub-images, then just call the "get_multi_molecular" agent tool and the output format should be like: + { + "molecules": [ + { + "smiles": "....", + "label": "....", + "bbox": + }, + { + "smiles": "....", + "label": "....", + "bbox": + }, + ] +} \ No newline at end of file diff --git a/prompt_getmolecular.txt b/prompt_getmolecular.txt new file mode 100644 index 0000000000000000000000000000000000000000..97cfb9eebf8d845ffde495f4488b5682d493ebdb --- /dev/null +++ b/prompt_getmolecular.txt @@ -0,0 +1,22 @@ +You are a helpful chemical assistant in identifying chemistry data in an image. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. +Your task is to: + use "get_multi_molecular_text_to_correct_withatoms" function get the tools outputs first. + Check the image, only find and extract the text based R-group equation in the image such as Ar = ..., R = ... without any reasoning. Please just extract the text based equation from the image and don't do any further image reasoning. and output 'extracted text based R-group equation (without any reasoning)' + Then replace them in the 'symbol' key in the "get_multi_molecular_text_to_correct_withatoms" output. For example, if there is a Ar2 = 3,5-(CF3)2CH3 and in the: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Ar2]", "N", "[Ts]", "C", "O"], change "[Ar2]" to "[(CF3)2CH3]" output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[(CF3)2CH3]", "N", "[Ts]", "C", "O"]. Please output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".. + *** Another more complex nesting example is if there is a Ar = '2-ClC6H4' and in a "symbols": ["[C@@]", "[Et]", "C", "C", "[C@H]", "[SO2Ar]", "N", "C", "O"], also change the composite symbols "[SO2Ar]" that inlude "Ar" to "[SO2ClC6H4]". + Please make sure you output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".(exclude any numbers and symbols that precede the r-group) + Finally output json format and please leave all other parts unchanged. + +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. Don't do any reasoning, just extract. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. + +An output example is: +{ + "extracted text based R-group equation (without any reasoning and originally form in X = ABCD extracted in the image)" : [ ... ], + "bboxes": [ ... + + ], + "corefs": [ ... + + ] + } + diff --git a/prompt_getmolecular_correctR.txt b/prompt_getmolecular_correctR.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1e06c29073d02e71a13418d30470295527b0820 --- /dev/null +++ b/prompt_getmolecular_correctR.txt @@ -0,0 +1,18 @@ +You are a helpful chemical assistant in identifying chemistry data in an image and check for and fix obvious R-group OCR errors. In this reaction image, there are chemistry reaction diagrams with multiple product molecular diagrams with the detailed R-group information with and their corresponding coref and text that represents different reaction products. However, you only need to focus on molecules with ambiguous R-groups (R1,R2,R3) in the reaction template. Sometimes R1,R2,R3 will be incorrectly identified by the tool, which will cause the subsequent R-group replacement to fail +Your task is to: + use "get_multi_molecular_text_to_correct_withatoms" function get the tools outputs first. + First find and match molecules with ambiguous R-groups (R1,R2,R3) and their outputs, then carefully compare with the original image to find those OCR errors. (Classic error: R2,R3 misidentifying each other. R1 is incorrectly identified as Rf or Pa or R.) + Then replace them in the 'symbol' key in the "get_multi_molecular_text_to_correct_withatoms" output. For example, if there is a R1 is misidentified Rf: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Rf]", "N", "[Ts]", "C", "O"], change "[Rf]" to "[R1]" , output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[R1]", "N", "[Ts]", "C", "O"]. + Finally output json format and please leave all other parts unchanged. !!!Do not arbitrarily change the order of the atomic set (if original is ['C', '[Rf]', 'O', 'C', '[R2]', '[R4]', '[R3]'], after revise, the output should be ['C', '[R1]', 'O', 'C', '[R2]', '[R4]', '[R3]'], not ['C', '[R1]', 'O', 'C', '[R2]', '[R3]', '[R4]']). + + +An output example is: +{ + "bboxes": [ ... + + ], + "corefs": [ ... + + ] + } + diff --git a/prompt_getreaction.txt b/prompt_getreaction.txt new file mode 100644 index 0000000000000000000000000000000000000000..97bf0614ede35a3b9745698143a1206dda53c3e9 --- /dev/null +++ b/prompt_getreaction.txt @@ -0,0 +1,44 @@ +You are a helpful chemical assistant in identifying chemical reaction in an image. In this reaction image, there is a chemical reaction scheme with some text-based R-group equation. +Your task is to: + use "get_reaction" function get the tools outputs first. + Check the image, only find and extract the text based R-group equation in the image such as Ar = ..., R = ... without any reasoning. Please just extract the text based equation from the image and don't do any further image reasoning. and output 'extracted text based R-group equation (without any reasoning)' + Then replace them in the 'symbol' key in the "get_reaction" output. For example, if there is a Ar2 = 3,5-(CF3)2CH3 and in the: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Ar2]", "N", "[Ts]", "C", "O"], change "[Ar2]" to "[(CF3)2CH3]" output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[(CF3)2CH3]", "N", "[Ts]", "C", "O"]. Please output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".. + Please make sure you output such as "[(CF3)2CH3]" or "[ClC6H4]" instead of outputing "[3,5-(CF3)2CH3]" or "[2-ClC6H4]".(exclude any numbers and symbols that precede the r-group) + Finally output json format and please leave all other parts unchanged. + + +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. Don't do any reasoning, just extract. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. + +One output example is: ###if there +{ 'extracted text based R-group equation without any reasoning and originally form in X = ABCD in the image':[ ... ] + 'reactants': [{'smiles': '*C(*)=O', + 'bbox': (0.277, 0.02, 0.337, 0.082), + 'symbols': ['C', '[Ar]', '[R]', 'O']}, + {'smiles': '*C1ON1S(=O)(=O)c1ccc(C)cc1', + 'bbox': (0.387, 0.03, 0.476, 0.071), + 'symbols': ['C', '[(CF3)2CH3]', 'O', 'N', '[Ts]']}], #### after repalcement + 'conditions': [{'bbox': (0.5, 0.009, 0.634, 0.059), + 'text': ['10 mol%', 'or Bz7', '10 mol% CszCO3', 'PhMe, rt', 'B17 ']}, + {'bbox': (0.534, 0.067, 0.598, 0.083), 'text': ['38', '78%']}], + 'products': [{'smiles': '*C1(*)O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O', + 'bbox': (0.652, 0.005, 0.756, 0.114), + 'symbols': ['N', + '[Ts]', + 'C', + 'O', + '[C@]', + '[R]', + '[Ar]', + 'O', + '[C@@H]', + 'C', + 'C', + 'C', + 'C', + 'C', + 'C', + 'Cl']}]} + +Another output example is: +{'extracted text based R-group equation without any reasoning and originally form in X = ABCD in the image': 'There is no any text based R-group equations originally form in X =ABCD in the image', 'reactants': [{'smiles': '*C(=O)N1NC(C(F)(F)F)N=C1N', 'bbox': [0.318, 0.061, 0.425, 0.241], 'symbols': ['C', '[F3C]', 'N', 'N', 'C', '[R]', 'O', 'C', 'N', 'N']}], 'conditions': [{'bbox': [0.479, 0.115, 0.52, 0.149], 'text': ['NBS']}, {'bbox': [0.458, 0.159, 0.54, 0.195], 'text': ['DCM, reflux']}], 'products': [{'smiles': '*C(=O)n1nc(C(F)(F)F)nc1N', 'bbox': [0.574, 0.062, 0.68, 0.24], 'symbols': ['C', '[R]', 'N', 'N', 'C', '[F3C]', 'N', 'C', 'N', 'O']}]} +!!! important: Note that this step only focus on textual equations (originally form in X = ABCD in the image) around the reaction template. If there are only table or product variant set that have r-group substitutions but no textual equations (form in X =ABCD), do nothing with it and output the original atom set in the reaction template. \ No newline at end of file diff --git a/prompt_getreaction_correctR.txt b/prompt_getreaction_correctR.txt new file mode 100644 index 0000000000000000000000000000000000000000..df0673a9a49fb3db8d30b4eedf45e91bab8258dc --- /dev/null +++ b/prompt_getreaction_correctR.txt @@ -0,0 +1,26 @@ +You are a helpful chemical assistant in identifying chemical reaction in an image. In this reaction image, there is a chemical reaction scheme with some text-based R-group equation. +Your task is to: + use "get_reaction" function get the tools outputs first. + First find and match molecules with ambiguous R-groups (R1,R2,R3) and their outputs, then carefully compare with the original image to find those OCR errors. (Classic error: R2,R3 misidentifying each other. R1 is incorrectly identified as Rf or Pa or R.) + Then replace them in the 'symbol' key in the "get_reaction" output. For example, if there is a R1 is misidentified Rf: "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[Rf]", "N", "[Ts]", "C", "O"], change "[Rf]" to "[R1]" , output "symbols": ["[C@@]", "[Et]", "C", "C", "C", "C", "C", "C", "O", "[C@H]", "[R1]", "N", "[Ts]", "C", "O"]. + Finally output json format and please leave all other parts unchanged. !!!Do not arbitrarily change the order of the atomic set including R-groups (if original is ['C', '[Rf]', 'O', 'C', '[R2]', '[R4]', '[R3]'], after revise, the output should be ['C', '[R1]', 'O', 'C', '[R2]', '[R4]', '[R3]'], not ['C', '[R1]', 'O', 'C', '[R2]', '[R3]', '[R4]']). + !!!! Do not change the number of atoms in the atom set!!!! Do not change anything else!!!! + + +One output example is: ###if there +{ + 'reactants': [{'smiles': '*C([Rf])=O', + 'bbox': (0.277, 0.02, 0.337, 0.082), + 'symbols': ['C', '[R]', '[R1]', 'O']},#### outputcorrect R-group,originally is ['C', '[R]', '[Rf]', 'O'] + {'smiles': '[Fr]C1ON1S(=O)(=O)c1ccc(C)cc1', + 'bbox': (0.387, 0.03, 0.476, 0.071), + 'symbols': ['C', '[R1]', 'O', 'N', '[Ts]']}], #### correct R-group, originally is ['C', '[Fr]', 'O', 'N', '[Ts]] + 'conditions': [{'bbox': (0.5, 0.009, 0.634, 0.059), + 'text': ['10 mol%', 'or Bz7', '10 mol% CszCO3', 'PhMe, rt', 'B17 ']}, + {'bbox': (0.534, 0.067, 0.598, 0.083), 'text': ['38', '78%']}], + 'products': [{'smiles': '[2*]C([3*])([4*])C(=O)ON1C(=O)c2ccccc2C1=O', + 'bbox': (0.652, 0.005, 0.756, 0.114), + 'symbols': ['C', 'C', 'C', 'C', 'C', 'N', 'O', 'C', 'C', '[R4]', '[R2]', '[R3]', 'O', 'C', 'O', 'C', 'C', 'O']}], #### output original R-group when there is no OCR error. Do not arbitrarily change the order of the atomic set including R-groups (keep the sort '[R4]', '[R2]', '[R3]'). + + + } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c95cd428e772cccfc6ce23424ddfca2a47c7e07d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,341 @@ +absl-py==2.1.0 +aiofiles==23.2.1 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.5 +aiosignal==1.3.1 +albucore==0.0.11 +albumentations==1.1.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio @ file:///croot/anyio_1706220167567/work +argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work +argon2-cffi-bindings @ file:///tmp/build/80754af9/argon2-cffi-bindings_1644553347904/work +asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work +async-lru @ file:///croot/async-lru_1699554519285/work +async-timeout==4.0.3 +attrs @ file:///croot/attrs_1695717823297/work +Babel @ file:///croot/babel_1671781930836/work +backcall==0.2.0 +backoff==2.2.1 +beautifulsoup4 @ file:///croot/beautifulsoup4-split_1718029820055/work +bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work +blinker==1.8.2 +Brotli @ file:///croot/brotli-split_1714483155106/work +cachetools==5.5.0 +cairocffi==1.7.1 +CairoSVG==2.7.1 +certifi @ file:///croot/certifi_1738623731865/work/certifi +cffi @ file:///croot/cffi_1714483155441/work +chardet==5.2.0 +charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work +chemrxnextractor @ git+https://github.com/CrystalEye42/ChemRxnExtractor.git@0f9529dbe3656e4ef5ea96c5b5ba990f7481700b +ci-info==0.3.0 +click==8.1.7 +cmake==4.0.0 +coloredlogs==15.0.1 +comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work +ConfigArgParse==1.7 +configobj==5.0.8 +configparser==7.1.0 +contourpy==1.3.0 +cryptography==43.0.1 +cssselect2==0.7.0 +cycler==0.12.1 +dataclasses-json==0.6.7 +debugpy @ file:///croot/debugpy_1690905042057/work +decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work +deepdiff==8.0.1 +defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work +Deprecated==1.2.14 +dirtyjson==1.0.8 +distro==1.9.0 +easyocr==1.7.1 +effdet==0.2.1 +einops==0.8.1 +emoji==2.12.1 +entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work +et-xmlfile==1.1.0 +etelemetry==0.3.1 +eval_type_backport==0.2.0 +exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work +executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1725214404607/work +fastapi==0.115.4 +fastjsonschema @ file:///opt/conda/conda-bld/python-fastjsonschema_1661371079312/work +fastprogress==1.0.3 +ffmpy==0.4.0 +filelock @ file:///croot/filelock_1700591183607/work +filetype==1.2.0 +fitz==0.0.1.dev2 +Flask==3.0.3 +flatbuffers==24.3.25 +fonttools==4.53.1 +frontend==0.0.3 +frozenlist==1.4.1 +fsspec==2024.9.0 +GeneralAgent==0.3.25 +gensim==4.3.3 +gmpy2 @ file:///croot/gmpy2_1738085463648/work +google-api-core==2.19.2 +google-auth==2.34.0 +google-cloud-vision==3.7.4 +googleapis-common-protos==1.65.0 +gptpdf==0.0.15 +gradio==5.5.0 +gradio_client==1.4.2 +greenlet==3.1.0 +grpcio==1.66.1 +grpcio-status==1.66.1 +h11 @ file:///croot/h11_1706652277403/work +httpcore @ file:///croot/httpcore_1706728464539/work +httplib2==0.22.0 +httpx @ file:///croot/httpx_1723474802858/work +huggingface-hub==0.30.2 +humanfriendly==10.0 +idna @ file:///croot/idna_1714398848350/work +imageio==2.35.1 +iopath==0.1.10 +ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work +ipython==8.12.2 +isodate==0.6.1 +itsdangerous==2.2.0 +jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work +Jinja2 @ file:///croot/jinja2_1716993405101/work +jiter==0.5.0 +joblib==1.4.2 +json5 @ file:///tmp/build/80754af9/json5_1624432770122/work +jsonpatch==1.33 +jsonpath-python==1.0.6 +jsonpointer==3.0.0 +jsonschema @ file:///croot/jsonschema_1699041609003/work +jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work +jupyter-events @ file:///croot/jupyter_events_1718738097486/work +jupyter-lsp @ file:///croot/jupyter-lsp-meta_1699978238815/work +jupyter_client @ file:///croot/jupyter_client_1699455897726/work +jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work +jupyter_server @ file:///croot/jupyter_server_1718827083372/work +jupyter_server_terminals @ file:///croot/jupyter_server_terminals_1686870725608/work +jupyterlab @ file:///croot/jupyterlab_1725895214311/work +jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work +jupyterlab_server @ file:///croot/jupyterlab_server_1725865349919/work +kiwisolver==1.4.7 +langchain==0.3.0 +langchain-community==0.3.0 +langchain-core==0.3.0 +langchain-text-splitters==0.3.0 +langdetect==1.0.9 +langsmith==0.1.121 +layoutparser==0.3.4 +lazy_loader==0.4 +lightning-utilities==0.11.7 +lit==18.1.8 +llama-cloud==0.0.17 +llama-index==0.11.10 +llama-index-agent-openai==0.3.1 +llama-index-cli==0.3.1 +llama-index-core==0.11.10 +llama-index-embeddings-openai==0.2.5 +llama-index-indices-managed-llama-cloud==0.3.1 +llama-index-legacy==0.9.48.post3 +llama-index-llms-openai==0.2.7 +llama-index-multi-modal-llms-openai==0.2.1 +llama-index-program-openai==0.2.0 +llama-index-question-gen-openai==0.2.0 +llama-index-readers-file==0.2.1 +llama-index-readers-llama-parse==0.3.0 +llama-parse==0.5.5 +looseversion==1.3.0 +lxml==5.3.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe @ file:///croot/markupsafe_1704205993651/work +marshmallow==3.22.0 +matplotlib==3.9.2 +matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work +mdurl==0.1.2 +-e git+https://github.com/aspuru-guzik-group/MERMaid/@4a4a455ff69e976cf8de5b9d9db38651c4714d6f#egg=MERMaid +mistune @ file:///opt/conda/conda-bld/mistune_1661496219659/work +mkl-service==2.4.0 +mkl_fft @ file:///io/mkl313/mkl_fft_1730824109137/work +mkl_random @ file:///io/mkl313/mkl_random_1730823916628/work +MolScribe @ git+https://github.com/CrystalEye42/MolScribe.git@250f683f52f5050eb624870ccfd04bccbcaa27e1 +mpmath @ file:///croot/mpmath_1690848262763/work +multidict==6.0.5 +mypy-extensions==1.0.0 +nbclient @ file:///croot/nbclient_1698934205032/work +nbconvert @ file:///croot/nbconvert_1699022732553/work +nbformat @ file:///croot/nbformat_1694616755618/work +nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work +networkx==3.3 +nibabel==5.2.1 +ninja==1.11.1.1 +nipype==1.8.6 +nltk==3.9.1 +notebook @ file:///croot/notebook_1725954770513/work +notebook_shim @ file:///croot/notebook-shim_1699455894279/work +numpy==1.26.4 +nvidia-cublas-cu11==11.11.3.6 +nvidia-cuda-cupti-cu11==11.8.87 +nvidia-cuda-nvrtc-cu11==11.8.89 +nvidia-cuda-runtime-cu11==11.8.89 +nvidia-cudnn-cu11==8.7.0.84 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.3.0.86 +nvidia-cusolver-cu11==11.4.1.48 +nvidia-cusparse-cu11==11.7.5.86 +nvidia-nccl-cu11==2.20.5 +nvidia-nvtx-cu11==11.8.86 +olefile==0.47 +omegaconf==2.3.0 +onnx==1.16.2 +onnxruntime==1.19.2 +openai==1.44.1 +opencv-python==4.5.5.64 +opencv-python-headless==4.10.0.84 +OpenNMT-py==2.2.0 +openpyxl==3.1.5 +orderly-set==5.2.2 +orjson==3.10.7 +overrides @ file:///croot/overrides_1699371140756/work +packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work +pandas==2.2.2 +pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work +parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work +pathlib==1.0.1 +pdf2image==1.17.0 +pdfminer.six==20231228 +pdfplumber==0.11.4 +pdftotext==2.2.2 +pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work +pi_heif==0.18.0 +pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work +pikepdf==9.2.1 +pillow==10.4.0 +platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work +portalocker==2.10.1 +prometheus-client @ file:///tmp/abs_d3zeliano1/croots/recipe/prometheus_client_1659455100375/work +prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work +proto-plus==1.24.0 +protobuf==5.28.0 +prov==2.0.1 +psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work +ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +PubChemPy==1.0.4 +pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1721585709575/work +pyasn1==0.6.1 +pyasn1_modules==0.4.1 +pyclipper==1.3.0.post5 +pycocotools==2.0.8 +pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work +pydantic==2.9.0 +pydantic-settings==2.5.2 +pydantic_core==2.23.2 +pydot==3.0.1 +pydub==0.25.1 +Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work +PyMuPDF==1.25.5 +pymupdf4llm==0.0.21 +PyMuPDFb==1.24.10 +pyonmttok==1.37.1 +pypandoc==1.13 +pyparsing==3.1.4 +pypdf==4.3.1 +PyPDF2==3.0.1 +pypdfium2==4.30.0 +PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work +python-bidi==0.6.0 +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work +python-docx==1.1.2 +python-dotenv==1.0.1 +python-iso639==2024.4.27 +python-json-logger @ file:///croot/python-json-logger_1683823803357/work +python-magic==0.4.27 +python-multipart==0.0.12 +python-oxmsg==0.0.1 +python-pptx==1.0.2 +pytz @ file:///croot/pytz_1713974312559/work +pyxnat==1.6.2 +PyYAML @ file:///croot/pyyaml_1698096049011/work +pyzmq @ file:///croot/pyzmq_1705605076900/work +qudida==0.0.4 +rapidfuzz==3.9.7 +rdflib==6.3.2 +rdkit==2024.3.5 +rdkit-pypi==2022.9.5 +referencing @ file:///croot/referencing_1699012038513/work +regex==2024.7.24 +requests @ file:///croot/requests_1721410876868/work +requests-toolbelt==1.0.0 +rfc3339-validator @ file:///croot/rfc3339-validator_1683077044675/work +rfc3986-validator @ file:///croot/rfc3986-validator_1683058983515/work +rich==13.9.4 +rpds-py @ file:///croot/rpds-py_1698945930462/work +rsa==4.9 +ruff==0.7.3 +safehttpx==0.1.1 +safetensors==0.4.5 +scikit-image==0.24.0 +scikit-learn==1.5.1 +scipy==1.13.1 +semantic-version==2.10.0 +Send2Trash @ file:///croot/send2trash_1699371139552/work +sentencepiece==0.2.0 +seqeval==1.2.2 +shapely==2.0.6 +shellingham==1.5.4 +simplejson==3.19.3 +six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work +smart-open==7.0.4 +SmilesPE==0.0.3 +sniffio @ file:///croot/sniffio_1705431295498/work +soupsieve @ file:///croot/soupsieve_1696347547217/work +SQLAlchemy==2.0.34 +stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work +starlette==0.41.2 +striprtf==0.0.26 +sympy==1.13.2 +tabulate==0.9.0 +tenacity==8.5.0 +tensorboard==2.17.1 +tensorboard-data-server==0.7.2 +terminado @ file:///croot/terminado_1671751832461/work +tesseract==0.1.3 +threadpoolctl==3.5.0 +tifffile==2024.8.30 +tiktoken==0.7.0 +timm==0.4.12 +tinycss2 @ file:///croot/tinycss2_1668168815555/work +tinydb==4.8.0 +tokenizers==0.21.1 +tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work +tomlkit==0.12.0 +torch==2.3.0+cu118 +torchaudio==2.3.0+cu118 +torchtext==0.5.0 +torchvision==0.18.0+cu118 +tornado @ file:///croot/tornado_1718740109488/work +tqdm==4.66.5 +traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work +traits==6.3.2 +transformers==4.47.0 +triton==2.3.0 +typer==0.13.0 +typing-inspect==0.9.0 +typing_extensions @ file:///croot/typing_extensions_1715268824938/work +tzdata==2024.1 +unstructured==0.15.12 +unstructured-client==0.25.8 +unstructured-inference==0.7.36 +unstructured.pytesseract==0.3.13 +urllib3 @ file:///croot/urllib3_1718912636303/work +uvicorn==0.30.6 +valgrind==0.0.0 +waitress==3.0.0 +wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work +webencodings==0.5.1 +websocket-client @ file:///croot/websocket-client_1715878298792/work +websockets==12.0 +Werkzeug==3.0.4 +wrapt==1.16.0 +xlrd==2.0.1 +XlsxWriter==3.2.0 +yarl==1.10.0 diff --git a/rxnscribe/__init__.py b/rxnscribe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c93c6d2426e412f0febde0a9ea1507786ab39a2b --- /dev/null +++ b/rxnscribe/__init__.py @@ -0,0 +1,2 @@ +from .interface import RxnScribe +from .interface import MolDetect diff --git a/rxnscribe/__pycache__/__init__.cpython-310.pyc b/rxnscribe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c30ec53bde17fee585e089f89f5e4cbe997276f Binary files /dev/null and b/rxnscribe/__pycache__/__init__.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/__init__.cpython-38.pyc b/rxnscribe/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c469d08bb41149e050a0d29bfe2caf39ab0387fa Binary files /dev/null and b/rxnscribe/__pycache__/__init__.cpython-38.pyc differ diff --git a/rxnscribe/__pycache__/data.cpython-310.pyc b/rxnscribe/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6d6b29147a49b126b6558f3437d8fcb582d359b Binary files /dev/null and b/rxnscribe/__pycache__/data.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/data.cpython-38.pyc b/rxnscribe/__pycache__/data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..491614c8e45de72547c57b51638a9a7aac298ac8 Binary files /dev/null and b/rxnscribe/__pycache__/data.cpython-38.pyc differ diff --git a/rxnscribe/__pycache__/dataset.cpython-310.pyc b/rxnscribe/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93ef49fc73bea54c7e6f20b13b2624e6032f0a8d Binary files /dev/null and b/rxnscribe/__pycache__/dataset.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/dataset.cpython-38.pyc b/rxnscribe/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8966c84735d72a1b5475cfec17a63227ce163e0 Binary files /dev/null and b/rxnscribe/__pycache__/dataset.cpython-38.pyc differ diff --git a/rxnscribe/__pycache__/interface.cpython-310.pyc b/rxnscribe/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e8c571cd99e8f4f5817a02fb7104d41e905509e Binary files /dev/null and b/rxnscribe/__pycache__/interface.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/interface.cpython-38.pyc b/rxnscribe/__pycache__/interface.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f9631a6693a3aae404ae4dfdaf06a8fa52fef1e Binary files /dev/null and b/rxnscribe/__pycache__/interface.cpython-38.pyc differ diff --git a/rxnscribe/__pycache__/tokenizer.cpython-310.pyc b/rxnscribe/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c07916d4453c9332ac76d1656a5e6b22a313a80c Binary files /dev/null and b/rxnscribe/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/tokenizer.cpython-38.pyc b/rxnscribe/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a908df83c4216fdaa9873a6d2350fed35f4ed4 Binary files /dev/null and b/rxnscribe/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/rxnscribe/__pycache__/transforms.cpython-310.pyc b/rxnscribe/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e956f43b04b11b7b218c6a817e22053981fa8d8 Binary files /dev/null and b/rxnscribe/__pycache__/transforms.cpython-310.pyc differ diff --git a/rxnscribe/__pycache__/transforms.cpython-38.pyc b/rxnscribe/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31703e3e2677bdf3d55fe3d6a4726e1c0d1465a3 Binary files /dev/null and b/rxnscribe/__pycache__/transforms.cpython-38.pyc differ diff --git a/rxnscribe/data.py b/rxnscribe/data.py new file mode 100644 index 0000000000000000000000000000000000000000..4dcf76e882e55a4ba1c8e16a6e62cb6f719e1793 --- /dev/null +++ b/rxnscribe/data.py @@ -0,0 +1,532 @@ +import os +import cv2 +import numpy as np +import matplotlib.colors as colors +import matplotlib.patches as patches +from PIL import Image + + +class BBox(object): + + def __init__(self, bbox, image_data=None, xyxy=False, normalized=False): + """ + :param bbox: {'catrgory_id', 'bbox'} + :param input_image: ImageData object + :param xyxy: + :param normalized: + """ + self.data = bbox + self.image_data = image_data + if image_data is not None: + self.width = image_data.width + self.height = image_data.height + self.category_id = bbox['category_id'] + if xyxy: + x1, y1, x2, y2 = bbox['bbox'] + else: + x1, y1, w, h = bbox['bbox'] + x2, y2 = x1 + w, y1 + h + if not normalized: + x1, y1, x2, y2 = x1 / self.width, y1 / self.height, x2 / self.width, y2 / self.height + self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2 + + @property + def is_mol(self): + return self.category_id == 1 + + @property + def is_idt(self): + return self.category_id == 3 + + @property + def is_empty(self): + return abs(self.x2 - self.x1) <= 0.01 or abs(self.y2 - self.y1) <= 0.01 + + def unnormalize(self): + return self.x1 * self.width, self.y1 * self.height, self.x2 * self.width, self.y2 * self.height + + def image(self): + x1, y1, x2, y2 = self.unnormalize() + x1, y1, x2, y2 = max(int(x1), 0), max(int(y1), 0), min(int(x2), self.width), min(int(y2), self.height) + return self.image_data.image[y1:y2, x1:x2] + + COLOR = {1: 'r', 2: 'g', 3: 'b', 4: 'y'} + CATEGORY = {1: 'Mol', 2: 'Txt', 3: 'Idt', 4: 'Sup'} + + def draw(self, ax, color='r', text = None): + x1, y1, x2, y2 = self.unnormalize() + if color is None: + color = self.COLOR[self.category_id] + rect = patches.Rectangle( + (x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor=color, facecolor=colors.to_rgba(color, 0.2)) + text = f'{self.CATEGORY[self.category_id]}' + if text == 'Mol': + ax.text(x1 , y1-15, text, fontsize=10, bbox=dict(linewidth=0, facecolor='yellow', alpha=0.5)) + else: + ax.text(x1-45 , y1+10, text, fontsize=10, bbox=dict(linewidth=0, facecolor='yellow', alpha=0.5)) + ax.add_patch(rect) + return + + def set_smiles(self, smiles, symbols,coords,edges, molfile=None, atoms=None, bonds=None): + self.data['smiles'] = smiles + self.data['symbols'] = symbols + self.data['coords'] = coords + self.data['edges'] = edges + if molfile: + self.data['molfile'] = molfile + + if atoms: + self.data['atoms'] = atoms + if bonds: + self.data['bonds'] = bonds + + def set_text(self, text): + self.data['text'] = text + + def to_json(self): + return self.data + + +class Reaction(object): + + def __init__(self, reaction=None, bboxes=None, image_data=None): + ''' + if image_data is None, create from prediction + if image_data is not None, create from groundtruth + ''' + self.reactants = [] + self.conditions = [] + self.products = [] + self.bboxes = [] + if reaction is not None: + for x in reaction['reactants']: + bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True) + self.bboxes.append(bbox) + self.reactants.append(len(self.bboxes) - 1) + for x in reaction['conditions']: + bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True) + self.bboxes.append(bbox) + self.conditions.append(len(self.bboxes) - 1) + for x in reaction['products']: + bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True) + self.bboxes.append(bbox) + self.products.append(len(self.bboxes) - 1) + + def to_json(self): + return { + 'reactants': [self.bboxes[i].to_json() for i in self.reactants], + 'conditions': [self.bboxes[i].to_json() for i in self.conditions], + 'products': [self.bboxes[i].to_json() for i in self.products] + } + + def _deduplicate_bboxes(self, indices): + results = [] + for i, idx_i in enumerate(indices): + duplicate = False + for j, idx_j in enumerate(indices[:i]): + if get_iou(self.bboxes[idx_i], self.bboxes[idx_j]) > 0.6: + duplicate = True + break + if not duplicate: + results.append(idx_i) + return results + + def deduplicate(self): + flags = [False] * len(self.bboxes) + bbox_list = self.reactants + self.products + self.conditions + for i, idx_i in enumerate(bbox_list): + if self.bboxes[idx_i].is_empty: + flags[idx_i] = True + continue + for idx_j in bbox_list[:i]: + if flags[idx_j] is False and get_iou(self.bboxes[idx_i], self.bboxes[idx_j]) > 0.6: + flags[idx_i] = True + break + self.reactants = [i for i in self.reactants if not flags[i]] + self.conditions = [i for i in self.conditions if not flags[i]] + self.products = [i for i in self.products if not flags[i]] + + def schema(self, mol_only=False): + # Return reactants, conditions, and products. If mol_only is True, only include bboxes that are mol structures. + if mol_only: + reactants, conditions, products = [[idx for idx in indices if self.bboxes[idx].is_mol] + for indices in [self.reactants, self.conditions, self.products]] + # It would be unfair to compare two reactions if their reactants or products are empty after filtering. + # Setting them to the original ones in this case. + if len(reactants) == 0: + reactants = self.reactants + if len(products) == 0: + products = self.products + return reactants, conditions, products + else: + return self.reactants, self.conditions, self.products + + def compare(self, other, mol_only=False, merge_condition=False, debug=False): + reactants1, conditions1, products1 = self.schema(mol_only) + reactants2, conditions2, products2 = other.schema(mol_only) + if debug: + print(reactants1, conditions1, products1, ';', reactants2, conditions2, products2) + if len(reactants1) + len(conditions1) + len(products1) == 0: + # schema is empty, always return False + return False + if len(reactants1) + len(conditions1) + len(products1) != len(reactants2) + len(conditions2) + len(products2): + return False + # Match use original index + match1, match2, scores = get_bboxes_match(self.bboxes, other.bboxes, iou_thres=0.5) + m_reactants, m_conditions, m_products = [[match1[i] for i in x] for x in [reactants1, conditions1, products1]] + if any([m == -1 for m in m_reactants + m_conditions + m_products]): + return False + if debug: + print(m_reactants, m_conditions, m_products, ';', reactants2, conditions2, products2) + if merge_condition: + return sorted(m_reactants + m_conditions) == sorted(reactants2 + conditions2) \ + and sorted(m_products) == sorted(products2) + else: + return sorted(m_reactants) == sorted(reactants2) and sorted(m_conditions) == sorted(conditions2) \ + and sorted(m_products) == sorted(products2) + + def __eq__(self, other): + # Exact matching of two reactions + return self.compare(other) + + def draw(self, ax): + for i in self.reactants: + self.bboxes[i].draw(ax, color='r') + for i in self.conditions: + self.bboxes[i].draw(ax, color='g') + for i in self.products: + self.bboxes[i].draw(ax, color='b') + return + + +class ReactionSet(object): + + def __init__(self, reactions, bboxes=None, image_data=None): + self.reactions = [Reaction(reaction, bboxes, image_data) for reaction in reactions] + + def __len__(self): + return len(self.reactions) + + def __iter__(self): + return iter(self.reactions) + + def __getitem__(self, item): + return self.reactions[item] + + def deduplicate(self): + results = [] + for reaction in self.reactions: + if any(r == reaction for r in results): + continue + if len(reaction.reactants) < 1 or len(reaction.products) < 1: + continue + results.append(reaction) + self.reactions = results + + def to_json(self): + return [r.to_json() for r in self.reactions] + + +class ImageData(object): + + def __init__(self, data=None, predictions=None, image_file=None, image=None): + self.width, self.height = None, None + if data: + self.file_name = data['file_name'] + self.width = data['width'] + self.height = data['height'] + if image_file: + self.image = cv2.imread(image_file) + self.height, self.width, _ = self.image.shape + if image is not None: + if not isinstance(image, np.ndarray): + image = np.asarray(image) + self.image = image + self.height, self.width, _ = self.image.shape + if data and 'bboxes' in data: + self.gold_bboxes = [BBox(bbox, self, xyxy=False, normalized=False) for bbox in data['bboxes']] + if predictions is not None: + self.pred_bboxes = [BBox(bbox, self, xyxy=True, normalized=True) for bbox in predictions] + + def draw_gold(self, ax, image=None): + if image is not None: + ax.imshow(image) + for i, b in enumerate(self.gold_bboxes): + b.draw(ax, color = None) + + def draw_prediction(self, ax, image=None): + if image is not None: + ax.imshow(image) + for i, b in enumerate(self.pred_bboxes): + b.draw(ax, color = None) + + +class ReactionImageData(ImageData): + + def __init__(self, data=None, predictions=None, image_file=None, image=None): + super().__init__(data=data, image_file=image_file, image=image) + if data and 'reactions' in data: + self.gold_reactions = ReactionSet(data['reactions'], self.gold_bboxes, image_data=self) + if predictions is not None: + self.pred_reactions = ReactionSet(predictions, image_data=self) + self.pred_reactions.deduplicate() + + def evaluate(self, mol_only=False, merge_condition=False, debug=False): + gold_total = len(self.gold_reactions) + gold_hit = [False] * gold_total + pred_total = len(self.pred_reactions) + pred_hit = [False] * pred_total + for i, ri in enumerate(self.gold_reactions): + for j, rj in enumerate(self.pred_reactions): + if gold_hit[i] and pred_hit[j]: + continue + if ri.compare(rj, mol_only, merge_condition, debug): + gold_hit[i] = True + pred_hit[j] = True + return gold_hit, pred_hit + +class CorefImageData(ImageData): + + def __init__(self, data=None, predictions=None, image_file=None, image=None): + super().__init__(data=data, predictions = predictions, image_file=image_file, image=image) + if data and 'corefs' in data: + self.gold_corefs = data['corefs'] + + def evaluate(self): + #for every bbox in self.gold_bboxes, match with highest iou in self.pred_bboxes + #a true hit is defined as follows: suppose a pair (i, j) is a coref. then if highest_iou(j) follows + #highest_iou(i) in pred_bboxes, it is a hit. + #total number of predictions is number of bboxes in pred/2. + #precision = TP/number of predictions + #recall = TP/number of gt pairs + + if hasattr(self, "pred_bboxes"): + hits = 0 + num_preds = 0 + for pred in self.pred_bboxes: + if pred.category_id == 3: + num_preds+=1 + matches = {} + for gold in self.gold_bboxes: + highest_iou = 0 + highest_index = -1 + for i, pred in enumerate(self.pred_bboxes): + iou = get_iou(gold, pred) + if iou> highest_iou: + highest_iou = iou + highest_index = i + if highest_iou > 0.3 and gold.category_id == 1: + matches[gold] = highest_index + else: + matches[gold]=highest_index + for coref_pair in self.gold_corefs: + mol = self.gold_bboxes[coref_pair[0]] + idx = self.gold_bboxes[coref_pair[1]] + + if mol in matches and idx in matches: + all_ids = True + if matches[mol] < matches[idx]: + for counter in range(matches[mol]+1, matches[idx], 1): + if self.pred_bboxes[counter].category_id != 3: + all_ids = False + if all_ids: + hits+=1 + return hits, len(self.gold_corefs), num_preds + + return 0, 0, 0 + + def draw_gold(self, ax, image=None): + if image is not None: + ax.imshow(image) + counter_dict = {} + counter = 0 + + for pair in self.gold_corefs: + mol, idt = pair + if mol in counter_dict: + xmin, ymin, xmax, ymax = self.gold_bboxes[idt].unnormalize() + ax.text(xmin - 50, ymin+ 60, str(counter_dict[mol]), fontsize=20, bbox=dict(facecolor='purple', alpha=0.5)) + + else: + counter+=1 + counter_dict[mol] = counter + xmin, ymin, xmax, ymax = self.gold_bboxes[mol].unnormalize() + ax.text(xmin - 50, ymin+ 60, str(counter), fontsize=20, bbox=dict(facecolor='purple', alpha=0.5)) + xmin, ymin, xmax, ymax = self.gold_bboxes[idt].unnormalize() + ax.text(xmin - 50, ymin+ 60, str(counter), fontsize=20, bbox=dict(facecolor='purple', alpha=0.5)) + for b in self.gold_bboxes: + b.draw(ax) + + def draw_prediction(self, ax, image=None): + if image is not None: + ax.imshow(image) + counter = 0 + colours = ['#648fff', '#785ef0','#dc267f', '#fe6100','#ffb000','r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y'] + colorcounter = -1 + for i, b in enumerate(self.pred_bboxes): + if (b.category_id == 1 or b.category_id == 2): + counter += 1 + colorcounter += 1 + b.draw(ax, color = colours[colorcounter%len(colours)]) + elif b.category_id == 3: + b.draw(ax, color = colours[colorcounter%len(colours)]) + + + +def deduplicate_bboxes(bboxes): + results = [] + for i in range(len(bboxes)): + duplicate = False + for j in range(i): + if get_iou(bboxes[i], bboxes[j]) > 0.9: + duplicate = True + break + if not duplicate: + results.append(bboxes[i]) + return results + +def get_iou(bb1, bb2): + """Calculate the Intersection over Union (IoU) of two bounding boxes.""" + bb1 = {'x1': bb1.x1, 'y1': bb1.y1, 'x2': bb1.x2, 'y2': bb1.y2} + bb2 = {'x1': bb2.x1, 'y1': bb2.y1, 'x2': bb2.x2, 'y2': bb2.y2} + + assert bb1['x1'] < bb1['x2'] + assert bb1['y1'] < bb1['y2'] + assert bb2['x1'] < bb2['x2'] + assert bb2['y1'] < bb2['y2'] + + # determine the coordinates of the intersection rectangle + x_left = max(bb1['x1'], bb2['x1']) + y_top = max(bb1['y1'], bb2['y1']) + x_right = min(bb1['x2'], bb2['x2']) + y_bottom = min(bb1['y2'], bb2['y2']) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + # The intersection of two axis-aligned bounding boxes is always an + # axis-aligned bounding box + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # compute the area of both AABBs + bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) + bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) + + # compute the intersection over union by taking the intersection + # area and dividing it by the sum of prediction + ground-truth + # areas - the interesection area + iou = intersection_area / float(bb1_area + bb2_area - intersection_area) + assert iou >= 0.0 + assert iou <= 1.0 + return iou + + +def get_bboxes_match(bboxes1, bboxes2, iou_thres=0.5, match_category=False): + """Find the match between two sets of bboxes. Each bbox is matched with a bbox with maximum overlap + (at least above iou_thres). -1 if a bbox does not have a match.""" + scores = np.zeros((len(bboxes1), len(bboxes2))) + for i, bbox1 in enumerate(bboxes1): + for j, bbox2 in enumerate(bboxes2): + if match_category and bbox1.category_id != bbox2.category_id: + scores[i, j] = 0 + else: + scores[i, j] = get_iou(bbox1, bbox2) + match1 = scores.argmax(axis=1) + for i in range(len(match1)): + if scores[i, match1[i]] < iou_thres: + match1[i] = -1 + match2 = scores.argmax(axis=0) + for j in range(len(match2)): + if scores[match2[j], j] < iou_thres: + match2[j] = -1 + return match1, match2, scores + + +def deduplicate_reactions(reactions): + pred_reactions = ReactionSet(reactions) + for r in pred_reactions: + r.deduplicate() + pred_reactions.deduplicate() + return pred_reactions.to_json() + + +def postprocess_reactions(reactions, image_file=None, image=None, molscribe=None, ocr=None, batch_size=32): + image_data = ReactionImageData(predictions=reactions, image_file=image_file, image=image) + pred_reactions = image_data.pred_reactions + for r in pred_reactions: + r.deduplicate() + pred_reactions.deduplicate() + if molscribe: + bbox_images, bbox_indices = [], [] + for i, reaction in enumerate(pred_reactions): + for j, bbox in enumerate(reaction.bboxes): + if bbox.is_mol: + bbox_images.append(bbox.image()) + bbox_indices.append((i, j)) + if len(bbox_images) > 0: + predictions = molscribe.predict_images(bbox_images, return_atoms_bonds=True, batch_size=batch_size) + + for (i, j), pred in zip(bbox_indices, predictions): + pred_reactions[i].bboxes[j].set_smiles(pred['smiles'],pred["symbols"], pred["coords"],pred["edges"],pred['molfile'],pred['atoms'], pred['bonds']) + #deduplicated[i].set_smiles(pred['smiles'],pred['oringinal_coords'],pred['original_symbols'],pred['orignal_edges']) + if ocr: + for reaction in pred_reactions: + for bbox in reaction.bboxes: + if not bbox.is_mol: + text = ocr.readtext(bbox.image(), detail=0) + bbox.set_text(text) + return pred_reactions.to_json() + +def postprocess_bboxes(bboxes, image = None, molscribe = None, batch_size = 32): + image_d = ImageData(image = image) + bbox_objects = [BBox(bbox = bbox, image_data = image_d, xyxy = True, normalized = True) for bbox in bboxes] + bbox_objects_no_empty = [bbox for bbox in bbox_objects if not bbox.is_empty] + #deduplicate + deduplicated = deduplicate_bboxes(bbox_objects_no_empty) + + if molscribe: + bbox_images, bbox_indices = [], [] + + for i, bbox in enumerate(deduplicated): + if bbox.is_mol: + bbox_images.append(bbox.image()) + bbox_indices.append(i) + + if len(bbox_images) > 0: + predictions = molscribe.predict_images(bbox_images, return_atoms_bonds=True, batch_size = batch_size) + + for i, pred in zip(bbox_indices, predictions): + #deduplicated[i].set_smiles(pred['smiles'], pred["original_symbols"],pred['molfile'],pred['atoms'], pred['bonds']) + deduplicated[i].set_smiles(pred['smiles'],pred["symbols"], pred["coords"],pred["edges"],pred['molfile'],pred['atoms'], pred['bonds']) + return [bbox.to_json() for bbox in deduplicated] + +def postprocess_coref_results(bboxes, image, molscribe = None, ocr = None, batch_size = 32): + image_d = ImageData(image = cv2.resize(np.asarray(image), None, fx=3, fy=3)) + bbox_objects = [BBox(bbox = bbox, image_data = image_d, xyxy = True, normalized = True) for bbox in bboxes['bboxes']] + if molscribe: + + bbox_images, bbox_indices = [], [] + + for i, bbox in enumerate(bbox_objects): + if bbox.is_mol: + bbox_images.append(bbox.image()) + bbox_indices.append(i) + + if len(bbox_images) > 0: + predictions = molscribe.predict_images(bbox_images, return_atoms_bonds=True, batch_size = batch_size) + + for i, pred in zip(bbox_indices, predictions): + bbox_objects[i].set_smiles(pred['smiles'],pred["symbols"], pred["coords"],pred["edges"],pred['molfile'],pred['atoms'], pred['bonds']) + if ocr: + for bbox in bbox_objects: + if bbox.is_idt: + text = ocr.readtext(bbox.image(), detail = 0) + bbox.set_text(text) + + return {'bboxes': [bbox.to_json() for bbox in bbox_objects], 'corefs': bboxes['corefs']} + + + + + + diff --git a/rxnscribe/dataset.py b/rxnscribe/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..51b6f19e93ac9d6990e59c3f42d6b7c2041534a3 --- /dev/null +++ b/rxnscribe/dataset.py @@ -0,0 +1,263 @@ +import os +import cv2 +import copy +import random +import json +import contextlib +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from . import transforms as T + +from pycocotools.coco import COCO +from PIL import Image + + +class ReactionDataset(Dataset): + def __init__(self, args, tokenizer, data_file=None, image_files=None, split='train', debug=False): + super().__init__() + self.args = args + self.tokenizer = tokenizer + if data_file: + data_file = os.path.join(args.data_path, data_file) + with open(data_file) as f: + self.data = json.load(f)['images'] + if split == 'train' and args.num_train_example is not None: + self.data = self.data[:args.num_train_example] + if split != 'train': + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + self.coco = COCO(data_file) + self.name = os.path.basename(data_file).split('.')[0] + if image_files: + self.data = [{'file_name': file} for file in image_files] + self.image_path = args.image_path + self.split = split + self.format = args.format + self.is_train = (split == 'train') + self.transform = make_transforms(split, args.augment, debug) + # self.reaction_transform = T.RandomReactionCrop() + + def __len__(self): + return len(self.data) + + @property + def pad_id(self): + return self.tokenizer[self.format].PAD_ID + + def generate_sample(self, image, target): + ref = {} + # coordinates are normalized after transform + image, target = self.transform(image, target) + ref['scale'] = target['scale'] + if self.is_train or True: + args = self.args + if self.format == 'reaction': + max_len = self.tokenizer['reaction'].max_len + label, label_out = self.tokenizer['reaction'].data_to_sequence( + target, rand_order=args.rand_order, shuffle_bbox=args.shuffle_bbox, add_noise=args.add_noise, + mix_noise=args.mix_noise) + ref['reaction'] = torch.LongTensor(label[:max_len]) + ref['reaction_out'] = torch.LongTensor(label_out[:max_len]) + if self.format == 'bbox': + max_len = self.tokenizer['bbox'].max_len + label, label_out = self.tokenizer['bbox'].data_to_sequence( + target, rand_order=args.rand_order, split_heuristic = args.split_heuristic, add_noise=args.add_noise) + ref['bbox'] = torch.LongTensor(label[:max_len]) + ref['bbox_out'] = torch.LongTensor(label_out[:max_len]) + if self.format == 'coref': + max_len = self.tokenizer['coref'].max_len + label, label_out = self.tokenizer['coref'].data_to_sequence( + target, rand_order = False, add_noise = False, split_heuristic = args.split_heuristic + ) + + ref['coref'] = torch.LongTensor(label[:max_len]) + ref['coref_out'] = torch.LongTensor(label_out[:max_len]) + return image, ref + + def __getitem__(self, idx): + image, target = self.load_and_prepare(idx) + if self.is_train and self.args.composite_augment: + cnt = 0 + while idx % 2 == random.randrange(2) and cnt < 5: + # Augment with probability 0.5 + n = len(self) + idx2 = (idx + random.randrange(n)) % n + image2, target2 = self.load_and_prepare(idx2) + # if 'reaction' in self.formats: + # image, target = self.reaction_transform(image, target) + # image2, target2 = self.reaction_transform(image2, target2) + image, target = self.concat(image, target, image2, target2) + cnt += 1 + if self.is_train and self.args.augment: + image1, ref1 = self.generate_sample(image, target) + image2, ref2 = self.generate_sample(image, target) + return [[idx, image1, ref1], [idx, image2, ref2]] + else: + image, ref = self.generate_sample(image, target) + ref['file_name'] = self.data[idx]['file_name'] + return [[idx, image, ref]] + + def load_and_prepare(self, idx): + target = self.data[idx] + if self.args.is_coco: + if self.is_train: + path = os.path.join(self.image_path, 'train2017', target['file_name']) + else: + path = os.path.join(self.image_path, 'val2017', target['file_name']) + else: + path = os.path.join(self.image_path, target['file_name']) + if not os.path.exists(path): + print(path, "doesn't exists.", flush=True) + image = Image.open(path).convert("RGB") + if self.is_train or True: + image, target = self.prepare(image, target) + return image, target + + def prepare(self, image, target): + w, h = target['width'], target['height'] + + image_id = target["id"] + image_id = torch.tensor([image_id]) + + anno = target["bboxes"] + + boxes = [obj['bbox'] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + target = copy.deepcopy(target) + target["boxes"] = boxes + target["labels"] = classes + target["image_id"] = image_id + + # for conversion to coco api + area = torch.tensor([obj["bbox"][2] * obj['bbox'][3] for obj in anno]) + target["area"] = area + target["orig_size"] = torch.as_tensor([int(w), int(h)]) + target["size"] = torch.as_tensor([int(w), int(h)]) + + return image, target + + def concat(self, image1, target1, image2, target2): + color = (255, 255, 255) + if random.random() < 1: + # Vertically concat two images + w = max(image1.width, image2.width) + h = image1.height + image2.height + if image1.width > image2.width: + x1, y1 = 0, 0 + x2, y2 = random.randint(0, image1.width - image2.width), image1.height + else: + x1, y1 = random.randint(0, image2.width - image1.width), 0 + x2, y2 = 0, image1.height + else: + # Horizontally concat two images + w = image1.width + image2.width + h = max(image1.height, image2.height) + if image1.height > image2.height: + x1, y1 = 0, 0 + x2, y2 = image1.width, random.randint(0, image1.height - image2.height) + else: + x1, y1 = 0, random.randint(0, image2.height - image1.height) + x2, y2 = image1.width, 0 + image = Image.new('RGB', (w, h), color) + image.paste(image1, (x1, y1)) + image.paste(image2, (x2, y2)) + target = { + "image_id": target1["image_id"], + "orig_size": torch.as_tensor([int(w), int(h)]), + "size": torch.as_tensor([int(w), int(h)]) + } + target1["boxes"][:, 0::2] += x1 + target1["boxes"][:, 1::2] += y1 + target2["boxes"][:, 0::2] += x2 + target2["boxes"][:, 1::2] += y2 + for key in ["boxes", "labels", "area"]: + target[key] = torch.cat([target1[key], target2[key]], dim=0) + if "reactions" in target1 and self.format == 'reactions': + target["reactions"] = [r for r in target1["reactions"]] + nbox = len(target1["boxes"]) + for r in target2["reactions"]: + newr = {} + for key, seq in r.items(): + newr[key] = [x + nbox for x in seq] + + target["reactions"].append(newr) + if "corefs" in target1 and self.format == 'coref': + target["corefs"] = [pair for pair in target1["corefs"]] + nBoxes1 = len(target1["boxes"]) + for pair in target2["corefs"]: + target["corefs"].append([pair[0]+nBoxes1, pair[1]+nBoxes1]) + return image, target + + +def make_transforms(image_set, augment=False, debug=False): + normalize = T.Compose([ + # T.Resize((1333, 1333)), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], debug) + ]) + + if image_set == 'train' and augment: + return T.Compose([ + T.RandomRotate(), + T.RandomHorizontalFlip(), + T.LargeScaleJitter(output_size=1333, aug_scale_min=0.3, aug_scale_max=2.0), + T.RandomDistortion(0.5, 0.5, 0.5, 0.5), + normalize]) + else: + return T.Compose([ + T.LargeScaleJitter(output_size=1333, aug_scale_min=1.0, aug_scale_max=1.0), + normalize]) + + +def pad_images(imgs): + # B, C, H, W + max_shape = [0, 0] + for img in imgs: + for i in range(len(max_shape)): + max_shape[i] = max(max_shape[i], img.shape[-1-i]) + stack = [] + for img in imgs: + pad = [] + for i in range(len(max_shape)): + pad = pad + [0, max_shape[i] - img.shape[-1-i]] + stack.append(F.pad(img, pad, value=0)) + return torch.stack(stack) + + +def get_collate_fn(pad_id): + def rxn_collate(batch): + ids = [] + imgs = [] + batch = [ex for seq in batch for ex in seq] + keys = list(batch[0][2].keys()) + seq_formats = [key for key in keys if key in ['bbox', 'bbox_out', 'reaction', 'reaction_out', 'coref', 'coref_out']] + refs = {key: [[], []] for key in seq_formats} + for ex in batch: + ids.append(ex[0]) + imgs.append(ex[1]) + ref = ex[2] + for key in seq_formats: + refs[key][0].append(ref[key]) + refs[key][1].append(torch.LongTensor([len(ref[key])])) + # Sequence + for key in keys: + if key in seq_formats: + refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=pad_id) + refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1) + else: + refs[key] = [ex[2][key] for ex in batch] + return ids, pad_images(imgs), refs + + return rxn_collate diff --git a/rxnscribe/evaluate.py b/rxnscribe/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..12c54548c4bd46725df046791cfcce5a7555d0d0 --- /dev/null +++ b/rxnscribe/evaluate.py @@ -0,0 +1,175 @@ +import os +import contextlib +import copy +import numpy as np + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO + +from .data import ImageData, ReactionImageData, CorefImageData + + +class CocoEvaluator(object): + + def __init__(self, coco_gt): + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + def evaluate(self, predictions): + img_ids, results = self.prepare(predictions, 'bbox') + if len(results) == 0: + return np.zeros((12,)) + coco_dt = self.coco_gt.loadRes(results) + cocoEval = COCOeval(self.coco_gt, coco_dt, 'bbox') + cocoEval.params.imgIds = img_ids + cocoEval.params.catIds = [1] + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + self.cocoEval = cocoEval + return cocoEval.stats + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + img_ids = [] + coco_results = [] + for idx, prediction in enumerate(predictions): + if len(prediction) == 0: + continue + + image = self.coco_gt.dataset['images'][idx] + img_ids.append(image['id']) + width = image['width'] + height = image['height'] + coco_results.extend( + [ + { + "image_id": image['id'], + "category_id": pred['category_id'], + "bbox": convert_to_xywh(pred['bbox'], width, height), + "score": pred['score'], + } + for pred in prediction + ] + ) + return img_ids, coco_results + + +def convert_to_xywh(box, width, height): + xmin, ymin, xmax, ymax = box + return [xmin * width, ymin * height, (xmax - xmin) * width, (ymax - ymin) * height] + + +EMPTY_STATS = {'gold_hits': 0, 'gold_total': 0, 'pred_hits': 0, 'pred_total': 0, 'image': 0} + + +class ReactionEvaluator(object): + + def evaluate_image(self, gold_image, pred_image, **kwargs): + data = ReactionImageData(gold_image, pred_image) + + + return data.evaluate(**kwargs) + + def compute_metrics(self, gold_hits, gold_total, pred_hits, pred_total): + precision = pred_hits / max(pred_total, 1) + recall = gold_hits / max(gold_total, 1) + f1 = precision * recall * 2 / max(precision + recall, 1e-6) + return {'precision': precision, 'recall': recall, 'f1': f1} + + def evaluate(self, groundtruths, predictions, **kwargs): + gold_hits, gold_total, pred_hits, pred_total = 0, 0, 0, 0 + for gold_image, pred_image in zip(groundtruths, predictions): + gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs) + gold_hits += sum(gh) + gold_total += len(gh) + pred_hits += sum(ph) + pred_total += len(ph) + return self.compute_metrics(gold_hits, gold_total, pred_hits, pred_total) + + def evaluate_by_size(self, groundtruths, predictions, **kwargs): + group_stats = {} + for gold_image, pred_image in zip(groundtruths, predictions): + gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs) + gtotal = len(gh) + if gtotal not in group_stats: + group_stats[gtotal] = copy.deepcopy(EMPTY_STATS) + group_stats[gtotal]['gold_hits'] += sum(gh) + group_stats[gtotal]['gold_total'] += len(gh) + group_stats[gtotal]['pred_hits'] += sum(ph) + group_stats[gtotal]['pred_total'] += len(ph) + group_stats[gtotal]['image'] += 1 + group_scores = {} + for gtotal, stats in group_stats.items(): + group_scores[gtotal] = self.compute_metrics( + stats['gold_hits'], stats['gold_total'], stats['pred_hits'], stats['pred_total']) + return group_scores, group_stats + + def evaluate_by_group(self, groundtruths, predictions, **kwargs): + group_stats = {} + for gold_image, pred_image in zip(groundtruths, predictions): + gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs) + diagram_type = gold_image['diagram_type'] + if diagram_type not in group_stats: + group_stats[diagram_type] = copy.deepcopy(EMPTY_STATS) + group_stats[diagram_type]['gold_hits'] += sum(gh) + group_stats[diagram_type]['gold_total'] += len(gh) + group_stats[diagram_type]['pred_hits'] += sum(ph) + group_stats[diagram_type]['pred_total'] += len(ph) + group_stats[diagram_type]['image'] += 1 + group_scores = {} + for group, stats in group_stats.items(): + group_scores[group] = self.compute_metrics( + stats['gold_hits'], stats['gold_total'], stats['pred_hits'], stats['pred_total']) + return group_scores, group_stats + + def evaluate_summarize(self, groundtruths, predictions, **kwargs): + size_scores, size_stats = self.evaluate_by_size(groundtruths, predictions, **kwargs) + summarize = { + 'overall': copy.deepcopy(EMPTY_STATS), + # 'single': copy.deepcopy(EMPTY_STATS), + # 'multiple': copy.deepcopy(EMPTY_STATS) + } + for size, stats in size_stats.items(): + if type(size) is int: + # output = summarize['single'] if size <= 1 else summarize['multiple'] + for key in stats: + # output[key] += stats[key] + summarize['overall'][key] += stats[key] + scores = {} + for key, val in summarize.items(): + scores[key] = self.compute_metrics(val['gold_hits'], val['gold_total'], val['pred_hits'], val['pred_total']) + return scores, summarize, size_stats + +class CorefEvaluator(object): + + def evaluate_image(self, gold_image, pred_image, **kwargs): + data = CorefImageData(gold_image, predictions = pred_image) + return data.evaluate() + + def evaluate(self, groundtruths, predictions): + hits, gold_total, pred_total = 0, 0, 0 + counter = 0 + print(len(predictions)) + for gold_image, pred_image in zip(groundtruths, predictions): + + try: hit, gold_pairs, pred_pairs = self.evaluate_image(gold_image, pred_image) + except: print(counter) + hits += hit + gold_total += gold_pairs + pred_total += pred_pairs + counter += 1 + return hits, gold_total, pred_total + + def evaluate_summarize(self, groundtruths, predictions): + hits, gold_total, pred_total = self.evaluate(groundtruths, predictions) + precision = hits/max(pred_total, 1) + recall = hits/max(gold_total, 1) + f1 = precision * recall * 2 / max(precision + recall, 1e-6) + return (precision, recall, f1) + \ No newline at end of file diff --git a/rxnscribe/inference/__init__.py b/rxnscribe/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c61feef15b9a94dbb97126be593f0c445f1870c0 --- /dev/null +++ b/rxnscribe/inference/__init__.py @@ -0,0 +1,4 @@ +from .greedy_search import GreedySearch +from .beam_search import BeamSearch + +__all__ = ["GreedySearch", "BeamSearch"] diff --git a/rxnscribe/inference/beam_search.py b/rxnscribe/inference/beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..5b2fe9361f564b796913c00fee294f0367af66d1 --- /dev/null +++ b/rxnscribe/inference/beam_search.py @@ -0,0 +1,200 @@ +import torch +from .decode_strategy import DecodeStrategy + +import warnings + + +class BeamSearch(DecodeStrategy): + """Generation with beam search. + """ + + def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length, + return_attention, max_length): + super(BeamSearch, self).__init__( + pad, bos, eos, batch_size, beam_size, min_length, return_attention, max_length) + self.beam_size = beam_size + self.n_best = n_best + + # result caching + self.hypotheses = [[] for _ in range(batch_size)] + + # beam state + self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool) + + self._batch_offset = torch.arange(batch_size, dtype=torch.long) + + self.select_indices = None + self.done = False + + def initialize(self, memory_bank, device=None): + """Repeat src objects `beam_size` times. + """ + def fn_map_state(state, dim): + return torch.repeat_interleave(state, self.beam_size, dim=dim) + + memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0) + if device is None: + device = memory_bank.device + + self.memory_length = memory_bank.size(1) + super().initialize(memory_bank, device) + + self.best_scores = torch.full( + [self.batch_size], -1e10, dtype=torch.float, device=device) + self._beam_offset = torch.arange( + 0, self.batch_size * self.beam_size, step=self.beam_size, + dtype=torch.long, device=device) + self.topk_log_probs = torch.tensor( + [0.0] + [float("-inf")] * (self.beam_size - 1), device=device + ).repeat(self.batch_size) + # buffers for the topk scores and 'backpointer' + self.topk_scores = torch.empty((self.batch_size, self.beam_size), + dtype=torch.float, device=device) + self.topk_ids = torch.empty((self.batch_size, self.beam_size), + dtype=torch.long, device=device) + self._batch_index = torch.empty([self.batch_size, self.beam_size], + dtype=torch.long, device=device) + + return fn_map_state, memory_bank + + @property + def current_predictions(self): + return self.alive_seq[:, -1] + + @property + def current_backptr(self): + # for testing + return self.select_indices.view(self.batch_size, self.beam_size) + + @property + def batch_offset(self): + return self._batch_offset + + def _pick(self, log_probs): + """Return token decision for a step. + + Args: + log_probs (FloatTensor): (B, vocab_size) + + Returns: + topk_scores (FloatTensor): (B, beam_size) + topk_ids (LongTensor): (B, beam_size) + """ + vocab_size = log_probs.size(-1) + + # Flatten probs into a list of probabilities. + curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size) + topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1) + return topk_scores, topk_ids + + def advance(self, log_probs, attn): + """ + Args: + log_probs: (B * beam_size, vocab_size) + """ + vocab_size = log_probs.size(-1) + + # (non-finished) batch_size + _B = log_probs.shape[0] // self.beam_size + + step = len(self) # alive_seq + self.ensure_min_length(log_probs) + + # Multiply probs by the beam probability + log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) + + curr_length = step + 1 + curr_scores = log_probs / curr_length # avg log_prob + self.topk_scores, self.topk_ids = self._pick(curr_scores) + # topk_scores/topk_ids: (batch_size, beam_size) + + # Recover log probs + torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs) + + # Resolve beam origin and map to batch index flat representation. + self._batch_index = self.topk_ids // vocab_size + self._batch_index += self._beam_offset[:_B].unsqueeze(1) + self.select_indices = self._batch_index.view(_B * self.beam_size) + self.topk_ids.fmod_(vocab_size) # resolve true word ids + + # Append last prediction. + self.alive_seq = torch.cat( + [self.alive_seq.index_select(0, self.select_indices), + self.topk_ids.view(_B * self.beam_size, 1)], -1) + + if self.return_attention: + current_attn = attn.index_select(1, self.select_indices) + if step == 1: + self.alive_attn = current_attn + else: + self.alive_attn = self.alive_attn.index_select( + 1, self.select_indices) + self.alive_attn = torch.cat([self.alive_attn, current_attn], 0) + + self.is_finished = self.topk_ids.eq(self.eos) + self.ensure_max_length() + + def update_finished(self): + _B_old = self.topk_log_probs.shape[0] + step = self.alive_seq.shape[-1] # len(self) + self.topk_log_probs.masked_fill_(self.is_finished, -1e10) + + self.is_finished = self.is_finished.to('cpu') + self.top_beam_finished |= self.is_finished[:, 0].eq(1) + predictions = self.alive_seq.view(_B_old, self.beam_size, step) + attention = ( + self.alive_attn.view( + step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) + if self.alive_attn is not None else None) + non_finished_batch = [] + for i in range(self.is_finished.size(0)): + b = self._batch_offset[i] + finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1) + # Store finished hypothesis for this batch. + for j in finished_hyp: # Beam level: finished beam j in batch i + self.hypotheses[b].append(( + self.topk_scores[i, j], + predictions[i, j, 1:], # Ignore start token + attention[:, i, j, :self.memory_length] + if attention is not None else None)) + # End condition is the top beam finished and we can return + # n_best hypotheses. + finish_flag = self.top_beam_finished[i] != 0 + if finish_flag and len(self.hypotheses[b]) >= self.n_best: + best_hyp = sorted( + self.hypotheses[b], key=lambda x: x[0], reverse=True) + for n, (score, pred, attn) in enumerate(best_hyp): + if n >= self.n_best: + break + self.scores[b].append(score.item()) + self.predictions[b].append(pred) + self.attention[b].append( + attn if attn is not None else []) + else: + non_finished_batch.append(i) + non_finished = torch.tensor(non_finished_batch) + + if len(non_finished) == 0: + self.done = True + return + + _B_new = non_finished.shape[0] + # Remove finished batches for the next step + self.top_beam_finished = self.top_beam_finished.index_select( + 0, non_finished) + self._batch_offset = self._batch_offset.index_select(0, non_finished) + non_finished = non_finished.to(self.topk_ids.device) + self.topk_log_probs = self.topk_log_probs.index_select( + 0, non_finished) + self._batch_index = self._batch_index.index_select(0, non_finished) + self.select_indices = self._batch_index.view(_B_new * self.beam_size) + self.alive_seq = predictions.index_select(0, non_finished) \ + .view(-1, self.alive_seq.size(-1)) + self.topk_scores = self.topk_scores.index_select(0, non_finished) + self.topk_ids = self.topk_ids.index_select(0, non_finished) + + if self.alive_attn is not None: + inp_seq_len = self.alive_attn.size(-1) + self.alive_attn = attention.index_select(1, non_finished) \ + .view(step - 1, _B_new * self.beam_size, inp_seq_len) + diff --git a/rxnscribe/inference/decode_strategy.py b/rxnscribe/inference/decode_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..394c83e951c29ed9b19f41d16a22eef3e802ab35 --- /dev/null +++ b/rxnscribe/inference/decode_strategy.py @@ -0,0 +1,60 @@ +import torch +from copy import deepcopy + + +class DecodeStrategy(object): + def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length, + return_attention=False, return_hidden=False): + self.pad = pad + self.bos = bos + self.eos = eos + + self.batch_size = batch_size + self.parallel_paths = parallel_paths + # result catching + self.predictions = [[] for _ in range(batch_size)] + self.scores = [[] for _ in range(batch_size)] + self.attention = [[] for _ in range(batch_size)] + self.hidden = [[] for _ in range(batch_size)] + + self.alive_attn = None + self.alive_hidden = None + + self.min_length = min_length + self.max_length = max_length + + n_paths = batch_size * parallel_paths + self.return_attention = return_attention + self.return_hidden = return_hidden + + self.done = False + + def initialize(self, memory_bank, device=None): + if device is None: + device = torch.device('cpu') + self.alive_seq = torch.full( + [self.batch_size * self.parallel_paths, 1], self.bos, + dtype=torch.long, device=device) + self.is_finished = torch.zeros( + [self.batch_size, self.parallel_paths], + dtype=torch.uint8, device=device) + + return None, memory_bank + + def __len__(self): + return self.alive_seq.shape[1] + + def ensure_min_length(self, log_probs): + if len(self) <= self.min_length: + log_probs[:, self.eos] = -1e20 # forced non-end + + def ensure_max_length(self): + if len(self) == self.max_length + 1: + self.is_finished.fill_(1) + + def advance(self, log_probs, attn): + raise NotImplementedError() + + def update_finished(self): + raise NotImplementedError + diff --git a/rxnscribe/inference/greedy_search.py b/rxnscribe/inference/greedy_search.py new file mode 100644 index 0000000000000000000000000000000000000000..3878dcc620caa71a298133fc35358abe4c34ad95 --- /dev/null +++ b/rxnscribe/inference/greedy_search.py @@ -0,0 +1,123 @@ +import torch +from .decode_strategy import DecodeStrategy + + +def sample_with_temperature(logits, sampling_temp, keep_topk): + """Select next tokens randomly from the top k possible next tokens. + + Samples from a categorical distribution over the ``keep_topk`` words using + the category probabilities ``logits / sampling_temp``. + """ + + if sampling_temp == 0.0 or keep_topk == 1: + # argmax + topk_scores, topk_ids = logits.topk(1, dim=-1) + if sampling_temp > 0: + topk_scores /= sampling_temp + else: + logits = torch.div(logits, sampling_temp) + if keep_topk > 0: + top_values, top_indices = torch.topk(logits, keep_topk, dim=1) + kth_best = top_values[:, -1].view([-1, 1]) + kth_best = kth_best.repeat([1, logits.shape[1]]).float() + ignore = torch.lt(logits, kth_best) + logits = logits.masked_fill(ignore, -10000) + + dist = torch.distributions.Multinomial(logits=logits, total_count=1) + topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True) + topk_scores = logits.gather(dim=1, index=topk_ids) + + return topk_ids, topk_scores + + +class GreedySearch(DecodeStrategy): + """Select next tokens randomly from the top k possible next tokens. + """ + + def __init__(self, pad, bos, eos, batch_size, min_length, max_length, + return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1): + super().__init__( + pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden) + self.sampling_temp = sampling_temp + self.keep_topk = keep_topk + self.topk_scores = None + + def initialize(self, memory_bank, device=None): + fn_map_state = None + + if device is None: + device = memory_bank.device + + self.memory_length = memory_bank.size(1) + super().initialize(memory_bank, device) + + self.select_indices = torch.arange( + self.batch_size, dtype=torch.long, device=device) + self.original_batch_idx = torch.arange( + self.batch_size, dtype=torch.long, device=device) + + return fn_map_state, memory_bank + + @property + def current_predictions(self): + return self.alive_seq[:, -1] + + @property + def batch_offset(self): + return self.select_indices + + def _pick(self, log_probs): + """Function used to pick next tokens. + """ + topk_ids, topk_scores = sample_with_temperature( + log_probs, self.sampling_temp, self.keep_topk) + return topk_ids, topk_scores + + def advance(self, log_probs, attn=None, hidden=None, label=None): + """Select next tokens randomly from the top k possible next tokens. + """ + self.ensure_min_length(log_probs) + topk_ids, self.topk_scores = self._pick(log_probs) + self.is_finished = topk_ids.eq(self.eos) + if label is not None: + label = label.view_as(self.is_finished) + self.is_finished = label.eq(self.eos) + self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1) + + if self.return_attention: + if self.alive_attn is None: + self.alive_attn = attn + else: + self.alive_attn = torch.cat([self.alive_attn, attn], 1) + if self.return_hidden: + if self.alive_hidden is None: + self.alive_hidden = hidden + else: + self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1) + self.ensure_max_length() + + def update_finished(self): + """Finalize scores and predictions.""" + finished_batches = self.is_finished.view(-1).nonzero() + for b in finished_batches.view(-1): + b_orig = self.original_batch_idx[b] + # scores/predictions/attention are lists, + # (to be compatible with beam-search) + self.scores[b_orig].append(self.topk_scores[b, 0].item()) + self.predictions[b_orig].append(self.alive_seq[b, 1:]) + self.attention[b_orig].append( + self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else []) + self.hidden[b_orig].append( + self.alive_hidden[b, :] if self.alive_hidden is not None else []) + self.done = self.is_finished.all() + if self.done: + return + is_alive = ~self.is_finished.view(-1) + self.alive_seq = self.alive_seq[is_alive] + if self.alive_attn is not None: + self.alive_attn = self.alive_attn[is_alive] + if self.alive_hidden is not None: + self.alive_hidden = self.alive_hidden[is_alive] + self.select_indices = is_alive.nonzero().view(-1) + self.original_batch_idx = self.original_batch_idx[is_alive] + # select_indices is equal to original_batch_idx for greedy search? diff --git a/rxnscribe/interface.py b/rxnscribe/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..1b520c588321a86284e759646711765583c2e7cc --- /dev/null +++ b/rxnscribe/interface.py @@ -0,0 +1,299 @@ +import os +import argparse +from typing import List +import PIL +import torch +from torch.profiler import profile, record_function, ProfilerActivity +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from .pix2seq import build_pix2seq_model +from .tokenizer import get_tokenizer +from .dataset import make_transforms +from .data import postprocess_reactions, postprocess_bboxes, postprocess_coref_results, ReactionImageData, ImageData, CorefImageData + +from molscribe import MolScribe +from huggingface_hub import hf_hub_download +import easyocr + + +class RxnScribe: + + def __init__(self, model_path, device=None): + """ + RxnScribe Interface + :param model_path: path of the model checkpoint. + :param device: torch device, defaults to be CPU. + """ + args = self._get_args() + args.format = 'reaction' + states = torch.load(model_path, map_location=torch.device('cpu')) + if device is None: + device = torch.device('cpu') + self.device = device + self.tokenizer = get_tokenizer(args) + self.model = self.get_model(args, self.tokenizer, self.device, states['state_dict']) + self.transform = make_transforms('test', augment=False, debug=False) + self.molscribe = self.get_molscribe() + self.ocr_model = self.get_ocr_model() + + def _get_args(self): + parser = argparse.ArgumentParser() + # * Backbone + parser.add_argument('--backbone', default='resnet50', type=str, + help="Name of the convolutional backbone to use") + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)") + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + # * Transformer + parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer") + parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=1024, type=int, + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer") + parser.add_argument('--nheads', default=8, type=int, + help="Number of attention heads inside the transformer's attentions") + parser.add_argument('--pre_norm', action='store_true') + # Data + parser.add_argument('--format', type=str, default='reaction') + parser.add_argument('--input_size', type=int, default=1333) + + args = parser.parse_args([]) + args.pix2seq = True + args.pix2seq_ckpt = None + args.pred_eos = True + args.is_coco = False + args.use_hf_transformer = False + return args + + def get_model(self, args, tokenizer, device, model_states): + def remove_prefix(state_dict): + return {k.replace('model.', ''): v for k, v in state_dict.items()} + + model = build_pix2seq_model(args, tokenizer[args.format]) + model.load_state_dict(remove_prefix(model_states), strict=False) + model.to(device) + model.eval() + return model + + def get_molscribe(self): + ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth") + molscribe = MolScribe(ckpt_path, device=self.device) + return molscribe + + def get_ocr_model(self): + reader = easyocr.Reader(['en'], gpu=(self.device.type == 'cuda')) + return reader + + def predict_images(self, input_images: List, batch_size=16, molscribe=False, ocr=False): + # images: a list of PIL images + device = self.device + tokenizer = self.tokenizer['reaction'] + predictions = [] + for idx in range(0, len(input_images), batch_size): + batch_images = input_images[idx:idx+batch_size] + images, refs = zip(*[self.transform(image) for image in batch_images]) + images = torch.stack(images, dim=0).to(device) + with torch.no_grad(): + pred_seqs, pred_scores = self.model(images, max_len=tokenizer.max_len) + for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)): + reactions = tokenizer.sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs[i]['scale']) + reactions = postprocess_reactions( + reactions, + image=input_images[i], + molscribe=self.molscribe if molscribe else None, + ocr=self.ocr_model if ocr else None + ) + predictions.append(reactions) + return predictions + + def predict_image(self, image, **kwargs): + predictions = self.predict_images([image], **kwargs) + return predictions[0] + + def predict_image_files(self, image_files: List, **kwargs): + input_images = [] + for path in image_files: + image = PIL.Image.open(path).convert("RGB") + input_images.append(image) + return self.predict_images(input_images, **kwargs) + + def predict_image_file(self, image_file: str, **kwargs): + predictions = self.predict_image_files([image_file], **kwargs) + return predictions[0] + + def draw_predictions(self, predictions, image=None, image_file=None): + results = [] + assert image or image_file + data = ReactionImageData(predictions=predictions, image=image, image_file=image_file) + h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width) + for r in data.pred_reactions: + fig, ax = plt.subplots(figsize=(w, h)) + fig.tight_layout() + canvas = FigureCanvasAgg(fig) + ax.imshow(data.image) + ax.axis('off') + r.draw(ax) + canvas.draw() + buf = canvas.buffer_rgba() + results.append(np.asarray(buf)) + plt.close(fig) + return results + + def draw_predictions_combined(self, predictions, image=None, image_file=None): + assert image or image_file + data = ReactionImageData(predictions=predictions, image=image, image_file=image_file) + h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width) + n = len(data.pred_reactions) + fig, axes = plt.subplots(n, 1, figsize=(w, h * n)) + if n == 1: + axes = [axes] + fig.tight_layout(rect=(0.02, 0.02, 0.99, 0.99)) + canvas = FigureCanvasAgg(fig) + for i, r in enumerate(data.pred_reactions): + ax = axes[i] + ax.imshow(data.image) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f'reaction # {i}', fontdict={'fontweight': 'bold', 'fontsize': 14}) + r.draw(ax) + canvas.draw() + buf = canvas.buffer_rgba() + result_image = np.asarray(buf) + plt.close(fig) + return result_image + +class MolDetect: + + def __init__(self, model_path, device = None, coref = False): + """ + MolDetect Interface + :param model_path: path of the model checkpoint. + :param device: torch device, defaults to be CPU. + """ + args = self._get_args() + if not coref: args.format = 'bbox' + else: args.format = 'coref' + states = torch.load(model_path, map_location = torch.device('cpu')) + if device is None: + device = torch.device('cpu') + self.device = device + self.tokenizer = get_tokenizer(args) + self.model = self.get_model(args, self.tokenizer, self.device, states['state_dict']) + self.transform = make_transforms('test', augment=False, debug=False) + self.ocr_model = self.get_ocr_model() + self.molscribe = self.get_molscribe() + + def _get_args(self): + parser = argparse.ArgumentParser() + # * Backbone + parser.add_argument('--backbone', default='resnet50', type=str, + help="Name of the convolutional backbone to use") + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)") + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + # * Transformer + parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer") + parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=1024, type=int, + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer") + parser.add_argument('--nheads', default=8, type=int, + help="Number of attention heads inside the transformer's attentions") + parser.add_argument('--pre_norm', action='store_true') + # Data + parser.add_argument('--format', type=str, default='reaction') + parser.add_argument('--input_size', type=int, default=1333) + + args = parser.parse_args([]) + args.pix2seq = True + args.pix2seq_ckpt = None + args.pred_eos = True + args.is_coco = False + args.use_hf_transformer = True + return args + + + def get_model(self, args, tokenizer, device, model_states): + def remove_prefix(state_dict): + return {k.replace('model.', ''): v for k, v in state_dict.items()} + + model = build_pix2seq_model(args, tokenizer[args.format]) + model.load_state_dict(remove_prefix(model_states), strict=False) + model.to(device) + model.eval() + return model + + def get_molscribe(self): + ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth") + molscribe = MolScribe(ckpt_path, device=self.device) + return molscribe + + def get_ocr_model(self): + reader = easyocr.Reader(['en'], gpu = (self.device.type == 'cuda')) + return reader + + def predict_images(self, input_images: List, batch_size = 16, molscribe = False, coref = False, ocr = False): + device = self.device + if not coref: + tokenizer = self.tokenizer['bbox'] + else: + tokenizer = self.tokenizer['coref'] + predictions = [] + for idx in range(0, len(input_images), batch_size): + batch_images = input_images[idx:idx+batch_size] + images, refs = zip(*[self.transform(image) for image in batch_images]) + images = torch.stack(images, dim=0).to(device) + with torch.no_grad(): + pred_seqs, pred_scores = self.model(images, max_len=tokenizer.max_len) + for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)): + bboxes = tokenizer.sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs[i]['scale']) + if coref: + bboxes = postprocess_coref_results(bboxes, image = input_images[i], molscribe = self.molscribe if molscribe else None, ocr = self.ocr_model if ocr else None) + if not coref: + bboxes = postprocess_bboxes(bboxes, image = input_images[i], molscribe = self.molscribe if molscribe else None) + predictions.append(bboxes) + return predictions + + def predict_image(self, image, molscribe = False, coref = False, ocr = False): + predictions = self.predict_images([image], molscribe = molscribe, coref = coref, ocr = ocr) + return predictions[0] + + def predict_image_files(self, image_files: List, batch_size = 16, molscribe = False, coref = False, ocr = False): + input_images = [] + for path in image_files: + image = PIL.Image.open(path).convert("RGB") + input_images.append(image) + return self.predict_images(input_images, batch_size = batch_size, molscribe = molscribe, coref = coref, ocr = ocr) + + def predict_image_file(self, image_file: str, molscribe = False, coref = False, ocr = False, **kwargs): + predictions = self.predict_image_files([image_file], molscribe = molscribe, coref = coref, ocr = ocr) + return predictions[0] + + def draw_bboxes(self, predictions, image=None, image_file=None, coref = False): + results = [] + assert image or image_file + if not coref: data = ImageData(predictions = predictions, image = image, image_file = image_file) + else: data = CorefImageData(predictions = predictions['bboxes'], image = image, image_file = image_file) + h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width) + fig, ax = plt.subplots(figsize = (w, h)) + fig.tight_layout() + canvas = FigureCanvasAgg(fig) + ax.imshow(data.image) + ax.axis('off') + data.draw_prediction(ax, data.image) + canvas.draw() + buf = canvas.buffer_rgba() + results.append(np.asarray(buf)) + plt.close(fig) + return results + + + diff --git a/rxnscribe/loss.py b/rxnscribe/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..97005bf9ffb9d95bb5745ce1d4e79f0440edeff9 --- /dev/null +++ b/rxnscribe/loss.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """ + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + """ + def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super(LabelSmoothingLoss, self).__init__() + + smoothing_value = label_smoothing / (tgt_vocab_size - 2) + one_hot = torch.full((tgt_vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer('one_hot', one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output, target): + """ + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size + """ + # assuming output is raw logits + # convert to log_probs + log_probs = F.log_softmax(output, dim=-1) + + model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob.scatter_(1, target.unsqueeze(1), self.confidence) + model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) + + # reduction mean or sum? + return F.kl_div(log_probs, model_prob, reduction='batchmean') + + +class SequenceLoss(nn.Module): + + def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[], punish_first = False): + super(SequenceLoss, self).__init__() + if ignore_indices: + ignore_index = ignore_indices[0] + self.ignore_index = ignore_index + self.ignore_indices = ignore_indices + self.punish_first = punish_first + if label_smoothing == 0: + self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean') + # Cross entropy = KL divergence + constant + else: + self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index) + + def forward(self, output, target): + """ + :param output: [batch, len, vocab] + :param target: [batch, len] + :return: + """ + batch_size, max_len, vocab_size = output.size() + output = output.reshape(-1, vocab_size) + target = target.reshape(-1) + for idx in self.ignore_indices: + if idx != self.ignore_index: + target.masked_fill_((target == idx), self.ignore_index) + if self.punish_first: + loss = 10* self.criterion(output[:1, :], target[:1]) +self.criterion(output, target) + else: + loss = self.criterion(output, target) + return loss + + +class Criterion(nn.Module): + + def __init__(self, args, tokenizer): + super(Criterion, self).__init__() + criterion = {} + format = args.format + tn = tokenizer[format] + criterion[format] = SequenceLoss(args.label_smoothing, len(tn), ignore_index=tn.PAD_ID, punish_first = args.punish_first) + self.criterion = nn.ModuleDict(criterion) + + def forward(self, results, refs): + losses = {} + for format_ in results: + predictions, targets, *_ = results[format_] + loss_ = self.criterion[format_](predictions, targets) + if type(loss_) is dict: + losses.update(loss_) + else: + if loss_.numel() > 1: + loss_ = loss_.mean() + losses[format_] = loss_ + return losses diff --git a/rxnscribe/model.py b/rxnscribe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..db197725949cd19ee82521399dc81d5c65e3c019 --- /dev/null +++ b/rxnscribe/model.py @@ -0,0 +1,260 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import timm + +from .inference import GreedySearch, BeamSearch +from .transformer import TransformerDecoder, Embeddings + + +class Encoder(nn.Module): + def __init__(self, args, pretrained=False): + super().__init__() + model_name = args.encoder + self.model_name = model_name + if model_name.startswith('resnet'): + self.model_type = 'resnet' + self.cnn = timm.create_model(model_name, pretrained=pretrained) + self.n_features = self.cnn.num_features # encoder_dim + self.cnn.global_pool = nn.Identity() + self.cnn.fc = nn.Identity() + elif model_name.startswith('swin'): + self.model_type = 'swin' + self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False, + use_checkpoint=args.use_checkpoint) + self.n_features = self.transformer.num_features + self.transformer.head = nn.Identity() + elif 'efficientnet' in model_name: + self.model_type = 'efficientnet' + self.cnn = timm.create_model(model_name, pretrained=pretrained) + self.n_features = self.cnn.num_features + self.cnn.global_pool = nn.Identity() + self.cnn.classifier = nn.Identity() + else: + raise NotImplemented + + def swin_forward(self, transformer, x): + x = transformer.patch_embed(x) + if transformer.absolute_pos_embed is not None: + x = x + transformer.absolute_pos_embed + x = transformer.pos_drop(x) + + def layer_forward(layer, x, hiddens): + for blk in layer.blocks: + if not torch.jit.is_scripting() and layer.use_checkpoint: + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + H, W = layer.input_resolution + B, L, C = x.shape + hiddens.append(x.view(B, H, W, C)) + if layer.downsample is not None: + x = layer.downsample(x) + return x, hiddens + + hiddens = [] + for layer in transformer.layers: + x, hiddens = layer_forward(layer, x, hiddens) + x = transformer.norm(x) # B L C + hiddens[-1] = x.view_as(hiddens[-1]) + return x, hiddens + + def forward(self, x, refs=None): + if self.model_type in ['resnet', 'efficientnet']: + features = self.cnn(x) + features = features.permute(0, 2, 3, 1) + hiddens = [] + elif self.model_type == 'swin': + if 'patch' in self.model_name: + features, hiddens = self.swin_forward(self.transformer, x) + else: + features, hiddens = self.transformer(x) + else: + raise NotImplemented + return features, hiddens + + +class TransformerDecoderBase(nn.Module): + + def __init__(self, args): + super().__init__() + self.args = args + + self.enc_trans_layer = nn.Sequential( + nn.Linear(args.encoder_dim, args.dec_hidden_size) + # nn.LayerNorm(args.dec_hidden_size, eps=1e-6) + ) + self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None + + self.decoder = TransformerDecoder( + num_layers=args.dec_num_layers, + d_model=args.dec_hidden_size, + heads=args.dec_attn_heads, + d_ff=args.dec_hidden_size * 4, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=args.hidden_dropout, + attention_dropout=args.attn_dropout, + max_relative_positions=args.max_relative_positions, + aan_useffn=False, + full_context_alignment=False, + alignment_layer=0, + alignment_heads=0, + pos_ffn_activation_fn='gelu' + ) + + def enc_transform(self, encoder_out): + batch_size = encoder_out.size(0) + encoder_dim = encoder_out.size(-1) + encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) + max_len = encoder_out.size(1) + device = encoder_out.device + if self.enc_pos_emb: + pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0) + encoder_out = encoder_out + pos_emb + encoder_out = self.enc_trans_layer(encoder_out) + return encoder_out + + +class TransformerDecoderAR(TransformerDecoderBase): + + def __init__(self, args, tokenizer): + super().__init__(args) + self.tokenizer = tokenizer + self.vocab_size = len(self.tokenizer) + self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True) + self.embeddings = Embeddings( + word_vec_size=args.dec_hidden_size, + word_vocab_size=self.vocab_size, + word_padding_idx=tokenizer.PAD_ID, + position_encoding=True, + dropout=args.hidden_dropout) + + def dec_embedding(self, tgt, step=None): + pad_idx = self.embeddings.word_padding_idx + tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2) # [B, 1, T_tgt] + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # batch x len x embedding_dim + return emb, tgt_pad_mask + + def forward(self, encoder_out, labels, label_lengths): + batch_size, max_len, _ = encoder_out.size() + memory_bank = self.enc_transform(encoder_out) + + tgt = labels.unsqueeze(-1) # (b, t, 1) + tgt_emb, tgt_pad_mask = self.dec_embedding(tgt) + dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask) + + logits = self.output_layer(dec_out) # (b, t, h) -> (b, t, v) + return logits[:, :-1], labels[:, 1:], dec_out + + def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256): + batch_size, max_len, _ = encoder_out.size() + memory_bank = self.enc_transform(encoder_out) + + if beam_size == 1: + decode_strategy = GreedySearch( + sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length, + pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID, + return_attention=False, return_hidden=True) + else: + decode_strategy = BeamSearch( + beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length, + pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID, + return_attention=False) + + # adapted from onmt.translate.translator + results = { + "predictions": None, + "scores": None, + "attention": None + } + + # (2) prep decode_strategy. Possibly repeat src objects. + _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank) + + # (3) Begin decoding step by step: + for step in range(decode_strategy.max_length): + tgt = decode_strategy.current_predictions.view(-1, 1, 1) + tgt_emb, tgt_pad_mask = self.dec_embedding(tgt) + dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, + tgt_pad_mask=tgt_pad_mask, step=step) + + attn = dec_attn.get("std", None) + + dec_logits = self.output_layer(dec_out) # [b, t, h] => [b, t, v] + dec_logits = dec_logits.squeeze(1) + log_probs = F.log_softmax(dec_logits, dim=-1) + + if self.tokenizer.output_constraint: + output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()] + output_mask = torch.tensor(output_mask, device=log_probs.device) + log_probs.masked_fill_(output_mask, -10000) + + decode_strategy.advance(log_probs, attn, dec_out) + any_finished = decode_strategy.is_finished.any() + if any_finished: + decode_strategy.update_finished() + if decode_strategy.done: + break + + select_indices = decode_strategy.select_indices + if any_finished: + # Reorder states. + memory_bank = memory_bank.index_select(0, select_indices) + self.map_state(lambda state, dim: state.index_select(dim, select_indices)) + + results["scores"] = decode_strategy.scores + results["predictions"] = decode_strategy.predictions + results["attention"] = decode_strategy.attention + results["hidden"] = decode_strategy.hidden + + return results["predictions"], results['scores'], results["hidden"] + + # adapted from onmt.decoders.transformer + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + if self.decoder.state["cache"] is not None: + _recursive_map(self.decoder.state["cache"]) + + +class Decoder(nn.Module): + + def __init__(self, args, tokenizer): + super(Decoder, self).__init__() + self.args = args + self.formats = args.formats + self.tokenizer = tokenizer + decoder = {} + for format_ in args.formats: + decoder[format_] = TransformerDecoderAR(args, tokenizer[format_]) + self.decoder = nn.ModuleDict(decoder) + + def forward(self, encoder_out, hiddens, refs): + results = {} + for format_ in self.formats: + labels, label_lengths = refs[format_] + results[format_] = self.decoder[format_](encoder_out, labels, label_lengths) + return results + + def decode(self, encoder_out, hiddens, refs=None, beam_size=1, n_best=1): + results = {} + predictions = {} + beam_predictions = {} + for format_ in self.formats: + max_len = self.tokenizer[format_].max_len + results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len) + outputs, scores, *_ = results[format_] + beam_preds = [[self.tokenizer[format_].sequence_to_data(x.tolist()) for x in pred] for pred in outputs] + beam_predictions[format_] = (beam_preds, scores) + predictions[format_] = [preds[0] for preds in beam_preds] + return predictions, beam_predictions diff --git a/rxnscribe/pix2seq/__init__.py b/rxnscribe/pix2seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcc52bb49f98bdcc0fd2e0088c338bbbe65c2a8 --- /dev/null +++ b/rxnscribe/pix2seq/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .pix2seq import build_pix2seq_model diff --git a/rxnscribe/pix2seq/__pycache__/__init__.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aede85a7ed31432df74cee9bacb2625c2fc19ea Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/__init__.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/__init__.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fd339c3ca8624dc33c5bc1709a64995ac0be9e0 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/__init__.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c0bfcb720c33098ddf285e8b5c95e5982abcc61 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbf94d6f4c486a66ad255b77f37695a5eb1f71e8 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/attention_layer.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/backbone.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b85ca666048f769bc96224256d17be219183a5a Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/backbone.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/backbone.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1501ab1967acf3981a5e06288d54b34b436215ef Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/backbone.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/misc.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..245a6f99e6b584d075c69bdc58aa8335c5fd1b5f Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/misc.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/misc.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aa0e3aae224966da1648d4170e09e5bff34c15e Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/misc.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ee25e5c1e464a6c1146442d921c0fd83cd9c243 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccd596f6924919f7facbb8082ffcd33b6aa1411e Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/pix2seq.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c628a6c4c3bf92d4078fd53f82fca08c6fc211d0 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2461d4e1181df170b730f5459284187ef60dbcc3 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/transformer.cpython-310.pyc b/rxnscribe/pix2seq/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9d294915efc72af674f0bfd6375e5a44c55f039 Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/transformer.cpython-310.pyc differ diff --git a/rxnscribe/pix2seq/__pycache__/transformer.cpython-38.pyc b/rxnscribe/pix2seq/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b2c04a67d82b1c8d3f7a1e84de8b4d7ab8fafdd Binary files /dev/null and b/rxnscribe/pix2seq/__pycache__/transformer.cpython-38.pyc differ diff --git a/rxnscribe/pix2seq/attention_layer.py b/rxnscribe/pix2seq/attention_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7ddf98672787801d999645ced04887d0c13de31c --- /dev/null +++ b/rxnscribe/pix2seq/attention_layer.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, dropout=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3) + self.attn_drop = nn.Dropout(dropout) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, pre_kv=None, attn_mask=None): + N, B, C = x.shape + qkv = self.qkv(x).reshape(N, B, 3, self.num_heads, C // self.num_heads).permute(2, 1, 3, 0, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + if not self.training: + k = torch.cat([pre_kv[0], k], dim=2) + v = torch.cat([pre_kv[1], v], dim=2) + pre_kv = torch.stack([k, v], dim=0) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn.masked_fill_(attn_mask, float('-inf')) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(2, 0, 1, 3).reshape(N, B, C) + x = self.proj(x) + return x, pre_kv diff --git a/rxnscribe/pix2seq/backbone.py b/rxnscribe/pix2seq/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..f489ff23abeac88a624b95506c18ed2ebd1221d4 --- /dev/null +++ b/rxnscribe/pix2seq/backbone.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from .misc import NestedTensor, is_main_process +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) + # weights="IMAGENET1K_V1" + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = True + return_interm_layers = False + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/rxnscribe/pix2seq/misc.py b/rxnscribe/pix2seq/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..bc430d31d4c91c29d4caa10452015364f79d1548 --- /dev/null +++ b/rxnscribe/pix2seq/misc.py @@ -0,0 +1,604 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor +from bisect import bisect_right +from torch.optim.lr_scheduler import _LRScheduler + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + if len(batch) > 2: + batch[0] = nested_tensor_from_tensor_list(batch[0] + batch[1], batch[2] + batch[3]) + return tuple([batch[0], batch[2] + batch[3]]) + else: + batch[0] = nested_tensor_from_tensor_list(batch[0], batch[1]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor], target_list=None): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + if target_list is not None: + for img, pad_img, m, target in zip(tensor_list, tensor, mask, target_list): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + size = target["size"] + m[:size[0], :size[1]] = False + else: + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_local_size(): + if not is_dist_avail_and_initialized(): + return 1 + return int(os.environ['LOCAL_SIZE']) + + +def get_local_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return int(os.environ['LOCAL_RANK']) + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_url = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + args.dist_url = 'env://' + args.world_size = ntasks + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__.split('+')[0][2:]) < 7.0: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +class NoScaler: + state_dict_key = "no_scaler" + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + loss.backward() + if clip_grad is not None and clip_grad > 0: + assert parameters is not None + torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + optimizer.step() + + +class WarmupLinearDecayLR(_LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_factor: float = 0.001, + warmup_iters: int = 10, + warmup_method: str = "linear", + end_epoch: int = 300, + final_lr_factor: float = 0.003, + last_epoch: int = -1, + ): + """ + Multi Step LR with warmup + + Args: + optimizer (torch.optim.Optimizer): optimizer used. + warmup_factor (float): lr = warmup_factor * base_lr + warmup_iters (int): iters to warmup + warmup_method (str): warmup method in ["constant", "linear", "burnin"] + last_epoch(int): The index of last epoch. Default: -1. + """ + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.end_epoch = end_epoch + assert 0 < final_lr_factor < 1 + self.final_lr_factor = final_lr_factor + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor) + linear_decay_factor = _get_lr_linear_decay_factor_at_iter( + self.last_epoch, self.warmup_iters, self.end_epoch, self.final_lr_factor) + return [ + base_lr * warmup_factor * linear_decay_factor for base_lr in self.base_lrs + ] + + def _get_closed_form_lr(self): + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor) + linear_decay_factor = _get_lr_linear_decay_factor_at_iter( + self.last_epoch, self.warmup_iters, self.end_epoch, self.final_lr_factor) + return [ + base_lr * warmup_factor * linear_decay_factor for base_lr in self.base_lrs + ] + + +def _get_lr_linear_decay_factor_at_iter(iter: int, start_epoch: int, end_epoch: int, + final_lr_factor: float): + assert iter <= end_epoch + if iter <= start_epoch: + return 1.0 + alpha = (iter - start_epoch) / (end_epoch - start_epoch) + lr_step = final_lr_factor * alpha + 1 - alpha + + return lr_step + + +def _get_warmup_factor_at_iter(method: str, iter: int, warmup_iters: int, + warmup_factor: float) -> float: + """ + Return the learning rate warmup factor at a specific iteration. + See https://arxiv.org/abs/1706.02677 for more details. + + Args: + method (str): warmup method; either "constant" or "linear". + iter (int): iteration at which to calculate the warmup factor. + warmup_iters (int): the number of warmup iterations. + warmup_factor (float): the base warmup factor (the meaning changes according + to the method used). + + Returns: + float: the effective warmup factor at the given iteration. + """ + if iter >= warmup_iters: + return 1.0 + + if method == "constant": + return warmup_factor + elif method == "linear": + alpha = iter / warmup_iters + return warmup_factor * (1 - alpha) + alpha + elif method == "burnin": + return (iter / warmup_iters)**4 + else: + raise ValueError("Unknown warmup method: {}".format(method)) diff --git a/rxnscribe/pix2seq/pix2seq.py b/rxnscribe/pix2seq/pix2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..a047b3e19a37f6386cb50bdad4f08aeea310226e --- /dev/null +++ b/rxnscribe/pix2seq/pix2seq.py @@ -0,0 +1,217 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Pix2Seq model and criterion classes. +""" +import torch +from torch.profiler import profile, record_function, ProfilerActivity +import torch.nn.functional as F +from torch import nn + +from .misc import nested_tensor_from_tensor_list +from .backbone import build_backbone +from .transformer import build_transformer +from transformers import GenerationConfig + +import numpy as np + + +class Pix2Seq(nn.Module): + """ This is the Pix2Seq module that performs object detection """ + def __init__(self, backbone, transformer, use_hf = False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_bins: number of bins for each side of the input image + """ + super().__init__() + self.transformer = transformer + hidden_dim = 256 if use_hf else transformer.d_model + self.input_proj = nn.Sequential( + nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)), + nn.GroupNorm(32, hidden_dim)) + self.backbone = backbone + + self.use_hf = use_hf + + + + def forward(self, image_tensor, targets=None, max_len=500, cheat = None): + """  + image_tensor: + The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all vocabulary. + Shape= [batch_size, num_sequence, num_vocal] + """ + + if isinstance(image_tensor, (list, torch.Tensor)): + image_tensor = nested_tensor_from_tensor_list(image_tensor) + features, pos = self.backbone(image_tensor) + #print(len(features)) + #print(pos.size()) + ''' + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack = True, record_shapes=True) as prof: + with record_function("model_inference"): + features, pos = self.backbone(image_tensor) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + prof.export_stacks("/tmp/profiler_stacks_cuda_A6000_16_backbone.txt", "self_cuda_time_total") + ''' + src, mask = features[-1].decompose() + assert mask is not None + mask = torch.zeros_like(mask).bool() + + src = self.input_proj(src) + + + + + if self.use_hf: + if targets is not None: + ''' + logits = self.transformer(src) + + + + + + + + input_seq, input_len = targets + + logits = logits.reshape(-1, 2094) + + loss = self.loss_fn(logits, input_seq.view(-1)) + + return loss, loss + ''' + + + ''' + output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1]) + return output_logits[:, :-1] + ''' + #print(input_seq) + input_seq, input_len = targets + input_seq = input_seq[:, 1:] + bs = src.shape[0] + src = src.flatten(2).permute(0, 2, 1) + #b x c x h x w to b x hw x c + pos_embed = pos[-1].flatten(2).permute(0, 2, 1) + max_len = input_seq.size(1) + indices = torch.arange(max_len).unsqueeze(0).expand_as(input_seq).to(src.device) + mask = indices >= input_len - torch.ones(input_len.shape).to(src.device) + masked_input_seq = input_seq.masked_fill(mask, -100) + #print("input_seq "+str(input_seq)) + #print("masked_input "+str(masked_input_seq)) + #src = src + pos_embed #unclear if this line is needed... + ''' + decoder_input = torch.cat( + [ + nn.Embedding(1, 256).to(src.device).weight.unsqueeze(0).repeat(bs, 1, 1), + nn.Embedding(2092, 256).to(src.device)(input_seq) + ], dim = 1 + ) + ''' + #decoder_mask = torch.full(decoder_input.shape[:2], False, dtype = torch.bool).to(src.device) + #decoder_mask[:, 0] = True + output = self.transformer(inputs_embeds = src,labels = masked_input_seq) + #print("output logits " + str(torch.argmax(output["logits"], dim = 2)) + "target labels "+ str(masked_input_seq)) + + #print(output["logits"].shape) + + return output["logits"], output["loss"] + else: + ''' + logits = self.transformer(src) + + print(logits.shape) + + return self.transformer(src).argmax(dim = 1), self.transformer(src).argmax(dim = 1) + ''' + + #with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack = True, record_shapes=True) as prof: + # with record_function("model_inference"): + #print(pos[-1]) + #output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len) + ''' + flatten src from B x C x H x W into B x HW x C and pass in as input_embeds + potentially flatten pos[-1] as well and add to input embeds + ''' + bs = src.shape[0] + src = src.flatten(2).permute(0, 2, 1) + generation_config = GenerationConfig(max_new_tokens = max_len, bos_token_id = 2002, eos_token_id = 2092, pad_token_id = 2001, output_hidden_states = True) + #output = self.transformer.generate(inputs_embeds = src, generation_config = generation_config, return_dict_in_generate=True, output_scores=True) + #transition_scores = self.transformer.compute_transition_scores(output.sequences, output.scores, normalize_logits=True) + #for tok, score in zip(output.sequences[0], transition_scores[0]): + # print(f"| {tok:5d} | {score.to('cpu').numpy():.3f} | {np.exp(score.to('cpu').numpy()):.2%}") + #print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + #prof.export_stacks("/tmp/profiler_stacks_cpu_A6000_16_decoder.txt", "self_cpu_time_total") + #print("loss "+str(output.loss)) + + #encoder_outputs = self.transformer.encoder(inputs_embeds = src) + ''' + print(cheat) + print("own predictions") + print(cheat['coref'][0][:, :3]) + print(self.transformer.decoder(input_ids = cheat['coref'][0][:, :3].to(src.device), encoder_hidden_states = torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim = 2)) + print(self.transformer.decoder(input_ids = cheat['coref'][0][:, :4].to(src.device), encoder_hidden_states = torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim = 2)) + print(self.transformer.decoder(input_ids = cheat['coref'][0][:, :5].to(src.device), encoder_hidden_states = torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim = 2)) + ''' + + #input_seq, input_len = cheat['bbox'] + #input_seq = input_seq[:, 1:] + #b x c x h x w to b x hw x c + #max_len = input_seq.size(1) + #indices = torch.arange(max_len).unsqueeze(0).expand_as(input_seq).to(src.device) + #mask = indices >= input_len - torch.ones(input_len.shape).to(src.device) + #masked_input_seq = input_seq.masked_fill(mask, -100) + #output = self.transformer(inputs_embeds = src,labels = masked_input_seq) + #print("output logits " + str(torch.argmax(output["logits"], dim = 2)) + "target labels "+ str(masked_input_seq)) + outputs = self.transformer.generate(inputs_embeds = src, generation_config = generation_config) + + return outputs, outputs + else: + if targets is not None: + input_seq, input_len = targets + output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1]) + return output_logits[:, :-1] + else: + output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len) + return output_seqs, output_scores + + + +def build_pix2seq_model(args, tokenizer): + # the `num_classes` naming here is somewhat misleading. + # it indeed corresponds to `max_obj_id + 1`, where max_obj_id + # is the maximum id for a class in your dataset. For example, + # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. + # As another example, for a dataset that has a single class with id 1, + # you should pass `num_classes` to be 2 (max_obj_id + 1). + # For more details on this, check the following discussion + # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 + + + + + backbone = build_backbone(args) + transformer = build_transformer(args, tokenizer) + + model = Pix2Seq(backbone, transformer, use_hf = args.use_hf_transformer) + + if args.pix2seq_ckpt is not None: + checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu') + if args.use_hf_transformer: + new_dict = {} + #print(checkpoint['state_dict'].keys()) + for key in checkpoint['state_dict']: + new_dict[key[6:]] = checkpoint['state_dict'][key] + model.load_state_dict(new_dict, strict = False) + else: + model.load_state_dict(checkpoint['model']) + + return model diff --git a/rxnscribe/pix2seq/position_encoding.py b/rxnscribe/pix2seq/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..189461a0f158e45bc2733a0c73b4ed1608ccd6f5 --- /dev/null +++ b/rxnscribe/pix2seq/position_encoding.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from .misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = torch.ones_like(mask, dtype=torch.bool) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/rxnscribe/pix2seq/transformer.py b/rxnscribe/pix2seq/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ffcf8e3f3102f56c0689c0e137915690fd7637 --- /dev/null +++ b/rxnscribe/pix2seq/transformer.py @@ -0,0 +1,390 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Pix2Seq Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from .attention_layer import Attention + +from transformers import EncoderDecoderConfig, EncoderDecoderModel, AutoConfig, BertConfig + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, + activation="relu", normalize_before=False, num_vocal=2094, + pred_eos=False, tokenizer=None): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + self._reset_parameters() + + self.num_vocal = num_vocal + self.vocal_classifier = nn.Linear(d_model, num_vocal) + self.det_embed = nn.Embedding(1, d_model) + self.vocal_embed = nn.Embedding(self.num_vocal - 2, d_model) + self.pred_eos = pred_eos + + self.d_model = d_model + self.nhead = nhead + self.num_decoder_layers = num_decoder_layers + self.tokenizer = tokenizer + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, input_seq, mask, pos_embed, max_len=500): + """ + Args: + src: shape[B, C, H, W] + input_seq: shape[B, 501, C] for training and shape[B, 1, C] for inference + mask: shape[B, H, W] + pos_embed: shape[B, C, H, W] + """ + # flatten NxCxHxW to HWxNxC + bs = src.shape[0] + src = src.flatten(2).permute(2, 0, 1) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + pre_kv = [torch.as_tensor([[], []], device=memory.device) + for _ in range(self.num_decoder_layers)] + + if self.training: + input_seq = input_seq.clamp(max=self.num_vocal - 3) + input_embed = torch.cat( + [self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1), + self.vocal_embed(input_seq)], dim=1) + input_embed = input_embed.transpose(0, 1) + num_seq = input_embed.shape[0] + self_attn_mask = torch.triu(torch.ones((num_seq, num_seq)), diagonal=1).bool().to(input_embed.device) + hs, pre_kv = self.decoder( + input_embed, + memory, + memory_key_padding_mask=mask, + pos=pos_embed, + pre_kv_list=pre_kv, + self_attn_mask=self_attn_mask) + # hs: N x B x D + pred_seq_logits = self.vocal_classifier(hs.transpose(0, 1)) + return pred_seq_logits + else: + end = torch.zeros(bs).bool().to(memory.device) + end_lens = torch.zeros(bs).long().to(memory.device) + input_embed = self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1).transpose(0, 1) + states, pred_token = [None] * bs, [None] * bs + pred_seq, pred_scores = [], [] + for seq_i in range(max_len): + hs, pre_kv = self.decoder( + input_embed, + memory, + memory_key_padding_mask=mask, + pos=pos_embed, + pre_kv_list=pre_kv) + # hs: N x B x D + logits = self.vocal_classifier(hs.transpose(0, 1)) + log_probs = F.log_softmax(logits, dim=-1) + if self.tokenizer.output_constraint: + states, output_masks = self.tokenizer.update_states_and_masks(states, pred_token) + output_masks = torch.tensor(output_masks, device=logits.device).unsqueeze(1) + log_probs.masked_fill_(output_masks, -10000) + if not self.pred_eos: + log_probs[:, :, self.tokenizer.EOS_ID] = -10000 + + score, pred_token = log_probs.max(dim=-1) + pred_seq.append(pred_token) + + pred_scores.append(score) + + if self.pred_eos: + stop_state = pred_token.squeeze(1).eq(self.tokenizer.EOS_ID) + end_lens += seq_i * (~end * stop_state) + end = (stop_state + end).bool() + if end.all() and seq_i > 4: + break + + token = log_probs[:, :, :self.num_vocal - 2].argmax(dim=-1) + input_embed = self.vocal_embed(token.transpose(0, 1)) + + if not self.pred_eos: + end_lens = end_lens.fill_(max_len) + pred_seq = torch.cat(pred_seq, dim=1) + pred_seq = [seq[:end_idx] for end_idx, seq in zip(end_lens, pred_seq)] + pred_scores = torch.cat(pred_scores, dim=1) + pred_scores = [scores[:end_idx] for end_idx, scores in zip(end_lens, pred_scores)] + return pred_seq, pred_scores + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, memory_key_padding_mask, pos, pre_kv_list=None, self_attn_mask=None): + output = tgt + cur_kv_list = [] + for layer, pre_kv in zip(self.layers, pre_kv_list): + output, cur_kv = layer( + output, + memory, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + self_attn_mask=self_attn_mask, + pre_kv=pre_kv) + cur_kv_list.append(cur_kv) + + if self.norm is not None: + output = self.norm(output) + + return output, cur_kv_list + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_key_padding_mask, pos) + return self.forward_post(src, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = Attention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + self_attn_mask: Optional[Tensor] = None, + pre_kv=None, + ): + tgt2, pre_kv = self.self_attn(tgt, pre_kv=pre_kv, attn_mask=self_attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=tgt, + key=self.with_pos_embed(memory, pos), + value=memory, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt, pre_kv + + def forward_pre( + self, + tgt, + memory, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + self_attn_mask: Optional[Tensor] = None, + pre_kv=None, + ): + tgt2 = self.norm1(tgt) + tgt2, pre_kv = self.self_attn(tgt2, pre_kv=pre_kv, attn_mask=self_attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=tgt2, + key=self.with_pos_embed(memory, pos), + value=memory, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt, pre_kv + + def forward( + self, + tgt, + memory, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + self_attn_mask: Optional[Tensor] = None, + pre_kv=None, + ): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_key_padding_mask, pos, self_attn_mask, pre_kv) + return self.forward_post(tgt, memory, memory_key_padding_mask, pos, self_attn_mask, pre_kv) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args, tokenizer): + if args.use_hf_transformer: + num_vocal = len(tokenizer) + encoder_config = BertConfig(max_position_embeddings = 1764, hidden_size = 256, num_attention_heads = 4, vocab_size = num_vocal, num_hidden_layers = 4, intermediate_size = 1024) + decoder_config = BertConfig(max_position_embeddings = 1764, hidden_size = 256, num_attention_heads = 4, vocab_size = num_vocal, is_decoder = True, num_hidden_layers = 4, intermediate_size = 1024) + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config, add_pooling_layer = False, decoder_add_pooling_layer = False) + + model = EncoderDecoderModel(config=config) + model.config.vocab_size = num_vocal + model.config.decoder_start_token_id = tokenizer.SOS_ID + model.config.pad_token_id = tokenizer.PAD_ID + model.config.eos_token_id = tokenizer.EOS_ID + model.encoder.embeddings.word_embeddings = None + model.encoder.pooler = None + return model + else: + num_vocal = len(tokenizer) + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + num_vocal=num_vocal, + pred_eos=args.pred_eos, + tokenizer=tokenizer + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/rxnscribe/tokenizer.py b/rxnscribe/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..67ee915b849edd12f38f998e0ea69224501fbca2 --- /dev/null +++ b/rxnscribe/tokenizer.py @@ -0,0 +1,691 @@ +import json +import copy +import random +import numpy as np + + +PAD = '<pad>' +SOS = '<sos>' +EOS = '<eos>' +UNK = '<unk>' +MASK = '<mask>' + +Rxn = '[Rxn]' # Reaction +Rct = '[Rct]' # Reactant +Prd = '[Prd]' # Product +Cnd = '[Cnd]' # Condition +Idt = '[Idt]' # Identifier +Mol = '[Mol]' # Molecule +Txt = '[Txt]' # Text +Sup = '[Sup]' # Supplement +Noise = '[Nos]' + + +class ReactionTokenizer(object): + + def __init__(self, input_size=100, sep_xy=True, pix2seq=False): + self.stoi = {} + self.itos = {} + self.pix2seq = pix2seq + self.maxx = input_size # height + self.maxy = input_size # width + self.sep_xy = sep_xy + self.special_tokens = [PAD, SOS, EOS, UNK, MASK] + self.tokens = [Rxn, Rct, Prd, Cnd, Idt, Mol, Txt, Sup, Noise] + self.fit_tokens(self.tokens) + + def __len__(self): + if self.pix2seq: + return 2094 + if self.sep_xy: + return self.offset + self.maxx + self.maxy + else: + return self.offset + max(self.maxx, self.maxy) + + @property + def max_len(self): + return 256 + + @property + def PAD_ID(self): + return self.stoi[PAD] + + @property + def SOS_ID(self): + return self.stoi[SOS] + + @property + def EOS_ID(self): + return self.stoi[EOS] + + @property + def UNK_ID(self): + return self.stoi[UNK] + + @property + def NOISE_ID(self): + return self.stoi[Noise] + + @property + def offset(self): + return 0 if self.pix2seq else len(self.stoi) + + @property + def output_constraint(self): + return True + + def fit_tokens(self, tokens): + vocab = self.special_tokens + tokens + if self.pix2seq: + for i, s in enumerate(vocab): + self.stoi[s] = 2001 + i + self.stoi[EOS] = len(self) - 2 + # self.stoi[Noise] = len(self) - 1 + else: + for i, s in enumerate(vocab): + self.stoi[s] = i + self.itos = {item[1]: item[0] for item in self.stoi.items()} + self.bbox_category_to_token = {1: Mol, 2: Txt, 3: Idt, 4: Sup} + self.token_to_bbox_category = {item[1]: item[0] for item in self.bbox_category_to_token.items()} + + def is_x(self, x): + return 0 <= x - self.offset < self.maxx + + def is_y(self, y): + if self.sep_xy: + return self.maxx <= y - self.offset < self.maxx + self.maxy + return 0 <= y - self.offset < self.maxy + + def x_to_id(self, x): + if x < -0.001 or x > 1.001: + print(x) + else: + x = min(max(x, 0), 1) + assert 0 <= x <= 1 + return self.offset + round(x * (self.maxx - 1)) + + def y_to_id(self, y): + if y < -0.001 or y > 1.001: + print(y) + else: + y = min(max(y, 0), 1) + assert 0 <= y <= 1 + if self.sep_xy: + return self.offset + self.maxx + round(y * (self.maxy - 1)) + return self.offset + round(y * (self.maxy - 1)) + + def id_to_x(self, id, scale=1): + if not self.is_x(id): + return -1 + return (id - self.offset) / (self.maxx - 1) / scale + + def id_to_y(self, id, scale=1): + if not self.is_y(id): + return -1 + if self.sep_xy: + return (id - self.offset - self.maxx) / (self.maxy - 1) * scale + return (id - self.offset) / (self.maxy - 1) / scale + + def update_state(self, state, idx): + if state is None: + new_state = (Rxn, 'e') + else: + if state[1] == 'x1': + new_state = (state[0], 'y1') + elif state[1] == 'y1': + new_state = (state[0], 'x2') + elif state[1] == 'x2': + new_state = (state[0], 'y2') + elif state[1] == 'y2': + new_state = (state[0], 'c') + elif state[1] == 'c': + if self.is_x(idx): + new_state = (state[0], 'x1') + else: + new_state = (state[0], 'e') + else: + if state[0] == Rct: + if self.is_x(idx): + new_state = (Cnd, 'x1') + else: + new_state = (Cnd, 'e') + elif state[0] == Cnd: + new_state = (Prd, 'x1') + elif state[0] == Prd: + new_state = (Rxn, 'e') + elif state[0] == Rxn: + if self.is_x(idx): + new_state = (Rct, 'x1') + else: + new_state = (EOS, 'e') + else: + new_state = (EOS, 'e') + return new_state + + def output_mask(self, state): + # mask: True means forbidden + mask = np.array([True] * len(self)) + if state[1] in ['y1', 'c']: + mask[self.offset:self.offset+self.maxx] = False + if state[1] in ['x1', 'x2']: + if self.sep_xy: + mask[self.offset+self.maxx:self.offset+self.maxx+self.maxy] = False + else: + mask[self.offset:self.offset+self.maxy] = False + if state[1] == 'y2': + for token in [Idt, Mol, Txt, Sup]: + mask[self.stoi[token]] = False + if state[1] == 'c': + mask[self.stoi[state[0]]] = False + if state[1] == 'e': + if state[0] in [Rct, Cnd, Rxn]: + mask[self.offset:self.offset + self.maxx] = False + if state[0] == Rct: + mask[self.stoi[Cnd]] = False + if state[0] == Prd: + mask[self.stoi[Rxn]] = False + mask[self.stoi[Noise]] = False + if state[0] in [Rxn, EOS]: + mask[self.EOS_ID] = False + return mask + + def update_states_and_masks(self, states, ids): + new_states = [self.update_state(state, idx) for state, idx in zip(states, ids)] + masks = np.array([self.output_mask(state) for state in new_states]) + return new_states, masks + + def bbox_to_sequence(self, bbox, category): + sequence = [] + x1, y1, x2, y2 = bbox + if x1 >= x2 or y1 >= y2: + return [] + sequence.append(self.x_to_id(x1)) + sequence.append(self.y_to_id(y1)) + sequence.append(self.x_to_id(x2)) + sequence.append(self.y_to_id(y2)) + if category in self.bbox_category_to_token: + sequence.append(self.stoi[self.bbox_category_to_token[category]]) + else: + sequence.append(self.stoi[Noise]) + return sequence + + def sequence_to_bbox(self, sequence, scale=[1, 1]): + if len(sequence) < 5: + return None + x1, y1 = self.id_to_x(sequence[0], scale[0]), self.id_to_y(sequence[1], scale[1]) + x2, y2 = self.id_to_x(sequence[2], scale[0]), self.id_to_y(sequence[3], scale[1]) + if x1 == -1 or y1 == -1 or x2 == -1 or y2 == -1 or x1 >= x2 or y1 >= y2 or sequence[4] not in self.itos: + return None + category = self.itos[sequence[4]] + if category not in [Mol, Txt, Idt, Sup]: + return None + return {'category': category, 'bbox': (round(x1,3), round(y1,3), round(x2,3), round(y2,3)), 'category_id': self.token_to_bbox_category[category]} + + def perturb_reaction(self, reaction, boxes): + reaction = copy.deepcopy(reaction) + options = [] + options.append(0) # Option 0: add + if not(len(reaction['reactants']) == 1 and len(reaction['conditions']) == 0 and len(reaction['products']) == 1): + options.append(1) # Option 1: delete + options.append(2) # Option 2: move + choice = random.choice(options) + if choice == 0: + key = random.choice(['reactants', 'conditions', 'products']) + # TODO: insert to a random position + # We simply add a random box, which may be a duplicate box in this reaction + reaction[key].append(random.randrange(len(boxes))) + if choice == 1 or choice == 2: + options = [] + for key, val in [('reactants', 1), ('conditions', 0), ('products', 1)]: + if len(reaction[key]) > val: + options.append(key) + key = random.choice(options) + idx = random.randrange(len(reaction[key])) + del_box = reaction[key][idx] + reaction[key] = reaction[key][:idx] + reaction[key][idx+1:] + if choice == 2: + options = ['reactants', 'conditions', 'products'] + options.remove(key) + newkey = random.choice(options) + reaction[newkey].append(del_box) + return reaction + + def augment_reaction(self, reactions, data): + area, boxes, labels = data['area'], data['boxes'], data['labels'] + nonempty_boxes = [i for i in range(len(area)) if area[i] > 0] + if len(nonempty_boxes) == 0: + return None + if len(reactions) == 0 or random.randrange(100) < 20: + num_reactants = random.randint(1, 3) + num_conditions = random.randint(0, 3) + num_products = random.randint(1, 3) + reaction = { + 'reactants': random.choices(nonempty_boxes, k=num_reactants), + 'conditions': random.choices(nonempty_boxes, k=num_conditions), + 'products': random.choices(nonempty_boxes, k=num_products) + } + else: + assert len(reactions) > 0 + reaction = self.perturb_reaction(random.choice(reactions), boxes) + return reaction + + def reaction_to_sequence(self, reaction, data, shuffle_bbox=False): + reaction = copy.deepcopy(reaction) + area, boxes, labels = data['area'], data['boxes'], data['labels'] + # If reactants or products are empty (because of image cropping), skip the reaction + if all([area[i] == 0 for i in reaction['reactants']]) or all([area[i] == 0 for i in reaction['products']]): + return [] + if shuffle_bbox: + random.shuffle(reaction['reactants']) + random.shuffle(reaction['conditions']) + random.shuffle(reaction['products']) + sequence = [] + for idx in reaction['reactants']: + if area[idx] == 0: + continue + sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item()) + sequence.append(self.stoi[Rct]) + for idx in reaction['conditions']: + if area[idx] == 0: + continue + sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item()) + sequence.append(self.stoi[Cnd]) + for idx in reaction['products']: + if area[idx] == 0: + continue + sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item()) + sequence.append(self.stoi[Prd]) + sequence.append(self.stoi[Rxn]) + return sequence + + def data_to_sequence(self, data, rand_order=False, shuffle_bbox=False, add_noise=False, mix_noise=False): + sequence = [self.SOS_ID] + sequence_out = [self.SOS_ID] + reactions = copy.deepcopy(data['reactions']) + reactions_seqs = [] + for reaction in reactions: + seq = self.reaction_to_sequence(reaction, data, shuffle_bbox=shuffle_bbox) + reactions_seqs.append([seq, seq]) + noise_seqs = [] + if add_noise: + total_len = sum(len(seq) for seq, seq_out in reactions_seqs) + while total_len < self.max_len: + reaction = self.augment_reaction(reactions, data) + if reaction is None: + break + seq = self.reaction_to_sequence(reaction, data) + if len(seq) == 0: + continue + if mix_noise: + seq[-1] = self.NOISE_ID + seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID] + else: + seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID] + noise_seqs.append([seq, seq_out]) + total_len += len(seq) + if rand_order: + random.shuffle(reactions_seqs) + reactions_seqs += noise_seqs + if mix_noise: + random.shuffle(reactions_seqs) + for seq, seq_out in reactions_seqs: + sequence += seq + sequence_out += seq_out + sequence.append(self.EOS_ID) + sequence_out.append(self.EOS_ID) + return sequence, sequence_out + + def sequence_to_data(self, sequence, scores=None, scale=None): + reactions = [] + i = 0 + cur_reaction = {'reactants': [], 'conditions': [], 'products': []} + flag = 'reactants' + if len(sequence) > 0 and sequence[0] == self.SOS_ID: + i += 1 + while i < len(sequence): + if sequence[i] == self.EOS_ID: + break + if sequence[i] in self.itos: + if self.itos[sequence[i]] in [Rxn, Noise]: + cur_reaction['label'] = self.itos[sequence[i]] + if len(cur_reaction['reactants']) > 0 and len(cur_reaction['products']) > 0: + reactions.append(cur_reaction) + cur_reaction = {'reactants': [], 'conditions': [], 'products': []} + flag = 'reactants' + elif self.itos[sequence[i]] == Rct: + flag = 'conditions' + elif self.itos[sequence[i]] == Cnd: + flag = 'products' + elif self.itos[sequence[i]] == Prd: + flag = None + elif i+5 <= len(sequence) and flag is not None: + bbox = self.sequence_to_bbox(sequence[i:i+5], scale) + if bbox is not None: + cur_reaction[flag].append(bbox) + i += 4 + i += 1 + return reactions + + def sequence_to_tokens(self, sequence): + return [self.itos[x] if x in self.itos else x for x in sequence] + + +class BboxTokenizer(ReactionTokenizer): + + def __init__(self, input_size=100, sep_xy=True, pix2seq=False): + super(BboxTokenizer, self).__init__(input_size, sep_xy, pix2seq) + + @property + def max_len(self): + return 500 + + @property + def output_constraint(self): + return False + + def random_category(self): + return random.choice(list(self.bbox_category_to_token.keys())) + # return random.choice([random.choice(list(self.bbox_category_to_token.keys())), self.NOISE_ID]) + + def random_bbox(self): + _x1, _y1, _x2, _y2 = random.random(), random.random(), random.random(), random.random() + x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2) + category = self.random_category() + return [x1, y1, x2, y2], category + + def jitter_bbox(self, bbox, ratio=0.2): + x1, y1, x2, y2 = bbox + w, h = x2 - x1, y2 - y1 + _x1 = x1 + random.uniform(-w*ratio, w*ratio) + _y1 = y1 + random.uniform(-h*ratio, h*ratio) + _x2 = x2 + random.uniform(-w * ratio, w * ratio) + _y2 = y2 + random.uniform(-h * ratio, h * ratio) + x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2) + category = self.random_category() + return np.clip([x1, y1, x2, y2], 0, 1), category + + def augment_box(self, bboxes): + if len(bboxes) == 0: + return self.random_bbox() + if random.random() < 0.5: + return self.random_bbox() + else: + return self.jitter_bbox(random.choice(bboxes)) + + def split_heuristic_helper(self, toprocess): + maxy = 0 + for pair in toprocess: + if pair[0][1]>maxy: + maxy = pair[0][1] + numbuckets = int(maxy//500 + 1) + + buckets = {} + for i in range(numbuckets): + buckets[i] = [] + + for pair in toprocess: + buckets[int(pair[0][1]//500)].append(pair) + + for bucket in buckets: + buckets[bucket] = sorted(buckets[bucket], key = lambda x: x[0][0]) + toreturn = [] + + for bucket in buckets: + toreturn+=buckets[bucket] + + return toreturn + + def data_to_sequence(self, data, add_noise=False, rand_order=False, split_heuristic=False): + sequence = [self.SOS_ID] + sequence_out = [self.SOS_ID] + if rand_order: + perm = np.random.permutation(len(data['boxes'])) + boxes = data['boxes'][perm].tolist() + labels = data['labels'][perm].tolist() + elif split_heuristic: + to_process = list(zip(data['boxes'].tolist(), data['labels'].tolist())) + processed = self.split_heuristic_helper(to_process) + boxes = [item[0] for item in processed] + labels = [item[1] for item in processed] + else: + boxes = data['boxes'].tolist() + labels = data['labels'].tolist() + for bbox, category in zip(boxes, labels): + seq = self.bbox_to_sequence(bbox, category) + sequence += seq + # sequence[-1] = self.random_category() + sequence_out += seq + if add_noise: + while len(sequence) < self.max_len: + bbox, category = self.augment_box(boxes) + sequence += self.bbox_to_sequence(bbox, category) + sequence_out += [self.PAD_ID] * 4 + [self.NOISE_ID] + sequence.append(self.EOS_ID) + sequence_out.append(self.EOS_ID) + return sequence, sequence_out + + def sequence_to_data(self, sequence, scores=None, scale=None): + bboxes = [] + i = 0 + #print(sequence) + if len(sequence) > 0 and sequence[0] == self.SOS_ID: + i += 1 + while i < len(sequence): + if sequence[i] == self.EOS_ID: + break + if i+4 < len(sequence): + bbox = self.sequence_to_bbox(sequence[i:i+5], scale) + if bbox is not None: + if scores is not None: + bbox['score'] = scores[i + 4] + bboxes.append(bbox) + i += 4 + i += 1 + return bboxes + +class CorefTokenizer(ReactionTokenizer): + + def __init__(self, input_size=100, sep_xy=True, pix2seq=False): + super(CorefTokenizer, self).__init__(input_size, sep_xy, pix2seq) + + @property + def max_len(self): + return 500 + + @property + def output_constraint(self): + return False + + def split_heuristic_helper(self, toprocess): + maxy = 0 + compress = [] + for pair in toprocess: + if pair[1] == 1 or pair[1] == 2: + compress.append([pair]) + else: + compress[-1].append(pair) + + for pair in toprocess: + if pair[0][1] > maxy and (pair[1] == 1 or pair[1] ==2): + maxy = pair[0][1] + numbuckets = int(maxy//500 + 1) + + buckets = {} + for i in range(numbuckets): + buckets[i] = [] + + for bbox_group in compress: + buckets[int(bbox_group[0][0][1]//500)].append(bbox_group) + + for bucket in buckets: + buckets[bucket] = sorted(buckets[bucket], key = lambda x: x[0][0][0]) + toreturn = [] + + for bucket in buckets: + for bbox_group in buckets[bucket]: + toreturn+=bbox_group + + return toreturn + + def coref_tokenize(self, boxes, labels, corefs, split_heuristic = False): + coref_dict = {} + for pair in corefs: + if pair[0] in coref_dict: + coref_dict[pair[0]].append(pair[1]) + else: + coref_dict[pair[0]] = [pair[1]] + #coref_dict = {pair[0]: pair[1] for pair in corefs} + toreturn_boxes = [] + toreturn_labels = [] + + for i, label in enumerate(labels): + if i in coref_dict: + toreturn_boxes.append(boxes[i]) + toreturn_labels.append(labels[i]) + for index in coref_dict[i]: + + toreturn_boxes.append(boxes[index]) + toreturn_labels.append(labels[index]) + elif label == 1: + toreturn_boxes.append(boxes[i]) + toreturn_labels.append(labels[i]) + ''' + for pair in corefs: + for entry in pair: + toreturn_boxes.append(boxes[entry]) + toreturn_labels.append(labels[entry]) + ''' + if split_heuristic: + returned = self.split_heuristic_helper(list(zip(toreturn_boxes, toreturn_labels))) + toreturn_boxes = [r[0] for r in returned] + toreturn_labels = [r[1] for r in returned] + ''' + if True: + for i, label in enumerate(labels): + if label == 2: + toreturn_boxes.append(boxes[i]) + toreturn_labels.append(labels[i]) + ''' + return toreturn_boxes, toreturn_labels + + def data_to_sequence(self, data, add_noise = False, rand_order = False, split_heuristic = False): + sequence = [self.SOS_ID] + sequence_out = [self.SOS_ID] + if rand_order: + #TODO + pass + else: + boxes, labels = self.coref_tokenize(data['boxes'].tolist(), data['labels'].tolist(), data['corefs'], split_heuristic) + for bbox, category in zip(boxes, labels): + + seq = self.bbox_to_sequence(bbox, category) + sequence += seq + # sequence[-1] = self.random_category() + sequence_out += seq + if add_noise: + pass + #TODO + ''' + while len(sequence) < self.max_len: + bbox, category = self.augment_box(boxes) + sequence += self.bbox_to_sequence(bbox, category) + sequence_out += [self.PAD_ID] * 4 + [self.NOISE_ID] + ''' + + #sequence = sequence[:6] + #sequence_out = sequence_out[:6] + sequence.append(self.EOS_ID) + sequence_out.append(self.EOS_ID) + return sequence, sequence_out + + def sequence_to_data(self, sequence, scores=None, scale=None): + bboxes = [] + i = 0 + if len(sequence) > 0 and sequence[0] == self.SOS_ID: + i += 1 + while i < len(sequence): + if sequence[i] == self.EOS_ID: + break + if i+4 < len(sequence): + bbox = self.sequence_to_bbox(sequence[i:i+5], scale) + if bbox is not None: + if scores is not None: + bbox['score'] = scores[i + 4] + bboxes.append(bbox) + i += 4 + i += 1 + return {'bboxes': bboxes, 'corefs': self.bbox_to_coref(bboxes)} + + def bbox_to_coref(self, bboxes): + corefs = [] + + for i in range(len(bboxes) - 1): + if bboxes[i]['category_id'] == 1 or bboxes[i]['category_id'] == 2: + j = i + 1 + while j < len(bboxes) and bboxes[j]['category_id'] == 3: + corefs.append([i, j]) + j += 1 + + return corefs + +class CocoTokenizer(BboxTokenizer): + + def __init__(self, input_size=100, sep_xy=True, pix2seq=False): + super(CocoTokenizer, self).__init__(input_size, sep_xy, pix2seq) + self.index_to_class = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46, 41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90} + self.class_to_index = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 13: 11, 14: 12, 15: 13, 16: 14, 17: 15, 18: 16, 19: 17, 20: 18, 21: 19, 22: 20, 23: 21, 24: 22, 25: 23, 27: 24, 28: 25, 31: 26, 32: 27, 33: 28, 34: 29, 35: 30, 36: 31, 37: 32, 38: 33, 39: 34, 40: 35, 41: 36, 42: 37, 43: 38, 44: 39, 46: 40, 47: 41, 48: 42, 49: 43, 50: 44, 51: 45, 52: 46, 53: 47, 54: 48, 55: 49, 56: 50, 57: 51, 58: 52, 59: 53, 60: 54, 61: 55, 62: 56, 63: 57, 64: 58, 65: 59, 67: 60, 70: 61, 72: 62, 73: 63, 74: 64, 75: 65, 76: 66, 77: 67, 78: 68, 79: 69, 80: 70, 81: 71, 82: 72, 84: 73, 85: 74, 86: 75, 87: 76, 88: 77, 89: 78, 90: 79} + + @property + def max_len(self): + return 700 + + def random_category(self): + return random.choice(list(self.class_to_index.keys())) + + + def bbox_to_sequence(self, bbox, category): + sequence = [] + x1, y1, x2, y2 = bbox + if x1 >= x2 or y1 >= y2: + return [] + sequence.append(self.x_to_id(x1)) + sequence.append(self.y_to_id(y1)) + sequence.append(self.x_to_id(x2)) + sequence.append(self.y_to_id(y2)) + + sequence.append(2006+self.class_to_index[category]) + + + return sequence + + def sequence_to_bbox(self, sequence, scale=[1, 1]): + if len(sequence) < 5: + return None + x1, y1 = self.id_to_x(sequence[0], scale[0]), self.id_to_y(sequence[1], scale[1]) + x2, y2 = self.id_to_x(sequence[2], scale[0]), self.id_to_y(sequence[3], scale[1]) + if x1 == -1 or y1 == -1 or x2 == -1 or y2 == -1 or x1 >= x2 or y1 >= y2: + return None + if sequence[4] - 2006 in self.index_to_class: + category = self.index_to_class[sequence[4] - 2006] + else: + category = -1 + return { 'bbox': (x1, y1, x2, y2), 'category_id': category} + + + +def get_tokenizer(args): + tokenizer = {} + if args.pix2seq: + args.coord_bins = 2000 + args.sep_xy = False + format = args.format + if format == 'reaction': + tokenizer[format] = ReactionTokenizer(args.coord_bins, args.sep_xy, args.pix2seq) + if format == 'bbox': + if args.is_coco: + tokenizer[format] = CocoTokenizer(args.coord_bins, args.sep_xy, args.pix2seq) + else: + tokenizer[format] = BboxTokenizer(args.coord_bins, args.sep_xy, args.pix2seq) + if format == 'coref': + tokenizer[format] = CorefTokenizer(args.coord_bins, args.sep_xy, args.pix2seq) + return tokenizer \ No newline at end of file diff --git a/rxnscribe/transformer/__init__.py b/rxnscribe/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c953b157c20f4c8bd13345f0a846fd70e0815e --- /dev/null +++ b/rxnscribe/transformer/__init__.py @@ -0,0 +1,3 @@ +from .decoder import TransformerDecoder +from .embedding import Embeddings +from .swin_transformer import swin_base, swin_large diff --git a/rxnscribe/transformer/decoder.py b/rxnscribe/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a04a96aa4c472fd00d3e8a9d470bd62d380b3202 --- /dev/null +++ b/rxnscribe/transformer/decoder.py @@ -0,0 +1,487 @@ +""" +Implementation of "Attention is All You Need" and of +subsequent transformer based architectures +""" + +import torch +import torch.nn as nn + +from onmt.decoders.decoder import DecoderBase +from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules.position_ffn import PositionwiseFeedForward +from onmt.modules.position_ffn import ActivationFunction +from onmt.utils.misc import sequence_mask + + +class TransformerDecoderLayerBase(nn.Module): + def __init__( + self, + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="scaled-dot", + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=False, + alignment_heads=0, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + """ + Args: + d_model (int): the dimension of keys/values/queries in + :class:`MultiHeadedAttention`, also the input size of + the first-layer of the :class:`PositionwiseFeedForward`. + heads (int): the number of heads for MultiHeadedAttention. + d_ff (int): the second-layer of the + :class:`PositionwiseFeedForward`. + dropout (float): dropout in residual, self-attn(dot) and + feed-forward + attention_dropout (float): dropout in context_attn (and + self-attn(avg)) + self_attn_type (string): type of self-attention scaled-dot, + average + max_relative_positions (int): + Max distance between inputs in relative positions + representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for + alignment + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + pos_ffn_activation_fn (ActivationFunction): + activation function choice for PositionwiseFeedForward layer + + """ + super(TransformerDecoderLayerBase, self).__init__() + + if self_attn_type == "scaled-dot": + self.self_attn = MultiHeadedAttention( + heads, + d_model, + dropout=attention_dropout, + max_relative_positions=max_relative_positions, + ) + elif self_attn_type == "average": + self.self_attn = AverageAttention( + d_model, dropout=attention_dropout, aan_useffn=aan_useffn + ) + + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, + pos_ffn_activation_fn + ) + self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) + self.drop = nn.Dropout(dropout) + self.full_context_alignment = full_context_alignment + self.alignment_heads = alignment_heads + + def forward(self, *args, **kwargs): + """Extend `_forward` for (possibly) multiple decoder pass: + Always a default (future masked) decoder forward pass, + Possibly a second future aware decoder pass for joint learn + full context alignement, :cite:`garg2019jointly`. + + Args: + * All arguments of _forward. + with_align (bool): whether return alignment attention. + + Returns: + (FloatTensor, FloatTensor, FloatTensor or None): + + * output ``(batch_size, T, model_dim)`` + * top_attn ``(batch_size, T, src_len)`` + * attn_align ``(batch_size, T, src_len)`` or None + """ + with_align = kwargs.pop("with_align", False) + output, attns = self._forward(*args, **kwargs) + top_attn = attns[:, 0, :, :].contiguous() + attn_align = None + if with_align: + if self.full_context_alignment: + # return _, (B, Q_len, K_len) + _, attns = self._forward(*args, **kwargs, future=True) + + if self.alignment_heads > 0: + attns = attns[:, : self.alignment_heads, :, :].contiguous() + # layer average attention across heads, get ``(B, Q, K)`` + # Case 1: no full_context, no align heads -> layer avg baseline + # Case 2: no full_context, 1 align heads -> guided align + # Case 3: full_context, 1 align heads -> full cte guided align + attn_align = attns.mean(dim=1) + return output, top_attn, attn_align + + def update_dropout(self, dropout, attention_dropout): + self.self_attn.update_dropout(attention_dropout) + self.feed_forward.update_dropout(dropout) + self.drop.p = dropout + + def _forward(self, *args, **kwargs): + raise NotImplementedError + + def _compute_dec_mask(self, tgt_pad_mask, future): + tgt_len = tgt_pad_mask.size(-1) + if not future: # apply future_mask, result mask in (B, T, T) + future_mask = torch.ones( + [tgt_len, tgt_len], + device=tgt_pad_mask.device, + dtype=torch.uint8, + ) + future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) + # BoolTensor was introduced in pytorch 1.2 + try: + future_mask = future_mask.bool() + except AttributeError: + pass + dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) + else: # only mask padding, result mask in (B, 1, T) + dec_mask = tgt_pad_mask + return dec_mask + + def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step): + if isinstance(self.self_attn, MultiHeadedAttention): + return self.self_attn( + inputs_norm, + inputs_norm, + inputs_norm, + mask=dec_mask, + layer_cache=layer_cache, + attn_type="self", + ) + elif isinstance(self.self_attn, AverageAttention): + return self.self_attn( + inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step + ) + else: + raise ValueError( + f"self attention {type(self.self_attn)} not supported" + ) + + +class TransformerDecoderLayer(TransformerDecoderLayerBase): + """Transformer Decoder layer block in Pre-Norm style. + Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, + providing better converge speed and performance. This is also the actual + implementation in tensor2tensor and also avalable in fairseq. + See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. + + .. mermaid:: + + graph LR + %% "*SubLayer" can be self-attn, src-attn or feed forward block + A(input) --> B[Norm] + B --> C["*SubLayer"] + C --> D[Drop] + D --> E((+)) + A --> E + E --> F(out) + + """ + + def __init__( + self, + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="scaled-dot", + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=False, + alignment_heads=0, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + """ + Args: + See TransformerDecoderLayerBase + """ + super(TransformerDecoderLayer, self).__init__( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type, + max_relative_positions, + aan_useffn, + full_context_alignment, + alignment_heads, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + self.context_attn = MultiHeadedAttention( + heads, d_model, dropout=attention_dropout + ) + self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) + + def update_dropout(self, dropout, attention_dropout): + super(TransformerDecoderLayer, self).update_dropout( + dropout, attention_dropout + ) + self.context_attn.update_dropout(attention_dropout) + + def _forward( + self, + inputs, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=None, + step=None, + future=False, + ): + """A naive forward pass for transformer decoder. + + # T: could be 1 in the case of stepwise decoding or tgt_len + + Args: + inputs (FloatTensor): ``(batch_size, T, model_dim)`` + memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)`` + src_pad_mask (bool): ``(batch_size, 1, src_len)`` + tgt_pad_mask (bool): ``(batch_size, 1, T)`` + layer_cache (dict or None): cached layer info when stepwise decode + step (int or None): stepwise decoding counter + future (bool): If set True, do not apply future_mask. + + Returns: + (FloatTensor, FloatTensor): + + * output ``(batch_size, T, model_dim)`` + * attns ``(batch_size, head, T, src_len)`` + + """ + dec_mask = None + + if inputs.size(1) > 1: + # masking is necessary when sequence length is greater than one + dec_mask = self._compute_dec_mask(tgt_pad_mask, future) + + inputs_norm = self.layer_norm_1(inputs) + + query, _ = self._forward_self_attn( + inputs_norm, dec_mask, layer_cache, step + ) + + query = self.drop(query) + inputs + + query_norm = self.layer_norm_2(query) + mid, attns = self.context_attn( + memory_bank, + memory_bank, + query_norm, + mask=src_pad_mask, + layer_cache=layer_cache, + attn_type="context", + ) + output = self.feed_forward(self.drop(mid) + query) + + return output, attns + + +class TransformerDecoderBase(DecoderBase): + def __init__(self, d_model, copy_attn, alignment_layer): + super(TransformerDecoderBase, self).__init__() + + # Decoder State + self.state = {} + + # previously, there was a GlobalAttention module here for copy + # attention. But it was never actually used -- the "copy" attention + # just reuses the context attention. + self._copy = copy_attn + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + self.alignment_layer = alignment_layer + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.dec_layers, + opt.dec_rnn_size, + opt.heads, + opt.transformer_ff, + opt.copy_attn, + opt.self_attn_type, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, + embeddings, + opt.max_relative_positions, + opt.aan_useffn, + opt.full_context_alignment, + opt.alignment_layer, + alignment_heads=opt.alignment_heads, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + ) + + def init_state(self, src, memory_bank, enc_hidden): + """Initialize decoder state.""" + self.state["src"] = src + self.state["cache"] = None + + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + if self.state["src"] is not None: + self.state["src"] = fn(self.state["src"], 1) + if self.state["cache"] is not None: + _recursive_map(self.state["cache"]) + + def detach_state(self): + raise NotImplementedError + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def update_dropout(self, dropout, attention_dropout): + self.embeddings.update_dropout(dropout) + for layer in self.transformer_layers: + layer.update_dropout(dropout, attention_dropout) + + +class TransformerDecoder(TransformerDecoderBase): + """The Transformer decoder from "Attention is All You Need". + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + BB[multi-head src-attn] + C[feed forward] + O[output] + A --> B + B --> BB + BB --> C + C --> O + + + Args: + num_layers (int): number of decoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for alignment + alignment_layer (int): N° Layer to supervise with for alignment guiding + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + """ + + def __init__( + self, + num_layers, + d_model, + heads, + d_ff, + copy_attn, + self_attn_type, + dropout, + attention_dropout, + max_relative_positions, + aan_useffn, + full_context_alignment, + alignment_layer, + alignment_heads, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + super(TransformerDecoder, self).__init__( + d_model, copy_attn, alignment_layer + ) + + self.transformer_layers = nn.ModuleList( + [ + TransformerDecoderLayer( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type=self_attn_type, + max_relative_positions=max_relative_positions, + aan_useffn=aan_useffn, + full_context_alignment=full_context_alignment, + alignment_heads=alignment_heads, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + for i in range(num_layers) + ] + ) + + def detach_state(self): + self.state["src"] = self.state["src"].detach() + + def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache(memory_bank) + + batch_size, src_len, src_dim = memory_bank.size() + device = memory_bank.device + if src_pad_mask is None: + src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device) + output = tgt_emb + batch_size, tgt_len, tgt_dim = tgt_emb.size() + if tgt_pad_mask is None: + tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device) + + future = kwargs.pop("future", False) + with_align = kwargs.pop("with_align", False) + attn_aligns = [] + hiddens = [] + + for i, layer in enumerate(self.transformer_layers): + layer_cache = ( + self.state["cache"]["layer_{}".format(i)] + if step is not None + else None + ) + output, attn, attn_align = layer( + output, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align, + future=future + ) + hiddens.append(output) + if attn_align is not None: + attn_aligns.append(attn_align) + + output = self.layer_norm(output) # (B, L, D) + + attns = {"std": attn} + if self._copy: + attns["copy"] = attn + if with_align: + attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` + # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg + + # TODO change the way attns is returned dict => list or tuple (onnx) + return output, attns, hiddens + + def _init_cache(self, memory_bank): + self.state["cache"] = {} + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None} + self.state["cache"]["layer_{}".format(i)] = layer_cache + diff --git a/rxnscribe/transformer/embedding.py b/rxnscribe/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..39647774d7f183690a3443bdf608479088fb0f5d --- /dev/null +++ b/rxnscribe/transformer/embedding.py @@ -0,0 +1,260 @@ +""" Embeddings module """ +import math +import warnings + +import torch +import torch.nn as nn + +from onmt.modules.util_class import Elementwise + + +class SequenceTooLongError(Exception): + pass + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. + + Implementation based on "Attention Is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + Args: + dropout (float): dropout parameter + dim (int): embedding size + """ + + def __init__(self, dropout, dim, max_len=5000): + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(1) + super(PositionalEncoding, self).__init__() + self.register_buffer('pe', pe) + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + + def forward(self, emb, step=None): + """Embed inputs. + + Args: + emb (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + step (int or NoneType): If stepwise (``seq_len = 1``), use + the encoding for this position. + """ + + emb = emb * math.sqrt(self.dim) + step = step or 0 + if self.pe.size(0) < step + emb.size(0): + raise SequenceTooLongError( + f"Sequence is {emb.size(0) + step} but PositionalEncoding is" + f" limited to {self.pe.size(0)}. See max_len argument." + ) + emb = emb + self.pe[step:emb.size(0)+step] + emb = self.dropout(emb) + return emb + + +class Embeddings(nn.Module): + """Words embeddings for encoder/decoder. + + Additionally includes ability to add sparse input features + based on "Linguistic Input Features Improve Neural Machine Translation" + :cite:`sennrich2016linguistic`. + + + .. mermaid:: + + graph LR + A[Input] + C[Feature 1 Lookup] + A-->B[Word Lookup] + A-->C + A-->D[Feature N Lookup] + B-->E[MLP/Concat] + C-->E + D-->E + E-->F[Output] + + Args: + word_vec_size (int): size of the dictionary of embeddings. + word_padding_idx (int): padding index for words in the embeddings. + feat_padding_idx (List[int]): padding index for a list of features + in the embeddings. + word_vocab_size (int): size of dictionary of embeddings for words. + feat_vocab_sizes (List[int], optional): list of size of dictionary + of embeddings for each feature. + position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding` + feat_merge (string): merge action for the features embeddings: + concat, sum or mlp. + feat_vec_exponent (float): when using `-feat_merge concat`, feature + embedding size is N^feat_dim_exponent, where N is the + number of values the feature takes. + feat_vec_size (int): embedding dimension for features when using + `-feat_merge mlp` + dropout (float): dropout probability. + freeze_word_vecs (bool): freeze weights of word vectors. + """ + + def __init__(self, word_vec_size, + word_vocab_size, + word_padding_idx, + position_encoding=False, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + feat_padding_idx=[], + feat_vocab_sizes=[], + dropout=0, + sparse=False, + freeze_word_vecs=False): + self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, + feat_vec_size, feat_padding_idx) + + if feat_padding_idx is None: + feat_padding_idx = [] + self.word_padding_idx = word_padding_idx + + self.word_vec_size = word_vec_size + + # Dimensions and padding for constructing the word embedding matrix + vocab_sizes = [word_vocab_size] + emb_dims = [word_vec_size] + pad_indices = [word_padding_idx] + + # Dimensions and padding for feature embedding matrices + # (these have no effect if feat_vocab_sizes is empty) + if feat_merge == 'sum': + feat_dims = [word_vec_size] * len(feat_vocab_sizes) + elif feat_vec_size > 0: + feat_dims = [feat_vec_size] * len(feat_vocab_sizes) + else: + feat_dims = [int(vocab ** feat_vec_exponent) + for vocab in feat_vocab_sizes] + vocab_sizes.extend(feat_vocab_sizes) + emb_dims.extend(feat_dims) + pad_indices.extend(feat_padding_idx) + + # The embedding matrix look-up tables. The first look-up table + # is for words. Subsequent ones are for features, if any exist. + emb_params = zip(vocab_sizes, emb_dims, pad_indices) + embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse) + for vocab, dim, pad in emb_params] + emb_luts = Elementwise(feat_merge, embeddings) + + # The final output size of word + feature vectors. This can vary + # from the word vector size if and only if features are defined. + # This is the attribute you should access if you need to know + # how big your embeddings are going to be. + self.embedding_size = (sum(emb_dims) if feat_merge == 'concat' + else word_vec_size) + + # The sequence of operations that converts the input sequence + # into a sequence of embeddings. At minimum this consists of + # looking up the embeddings for each word and feature in the + # input. Model parameters may require the sequence to contain + # additional operations as well. + super(Embeddings, self).__init__() + self.make_embedding = nn.Sequential() + self.make_embedding.add_module('emb_luts', emb_luts) + + if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0: + in_dim = sum(emb_dims) + mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU()) + self.make_embedding.add_module('mlp', mlp) + + self.position_encoding = position_encoding + + if self.position_encoding: + pe = PositionalEncoding(dropout, self.embedding_size) + self.make_embedding.add_module('pe', pe) + + if freeze_word_vecs: + self.word_lut.weight.requires_grad = False + + def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent, + feat_vec_size, feat_padding_idx): + if feat_merge == "sum": + # features must use word_vec_size + if feat_vec_exponent != 0.7: + warnings.warn("Merging with sum, but got non-default " + "feat_vec_exponent. It will be unused.") + if feat_vec_size != -1: + warnings.warn("Merging with sum, but got non-default " + "feat_vec_size. It will be unused.") + elif feat_vec_size > 0: + # features will use feat_vec_size + if feat_vec_exponent != -1: + warnings.warn("Not merging with sum and positive " + "feat_vec_size, but got non-default " + "feat_vec_exponent. It will be unused.") + else: + if feat_vec_exponent <= 0: + raise ValueError("Using feat_vec_exponent to determine " + "feature vec size, but got feat_vec_exponent " + "less than or equal to 0.") + n_feats = len(feat_vocab_sizes) + if n_feats != len(feat_padding_idx): + raise ValueError("Got unequal number of feat_vocab_sizes and " + "feat_padding_idx ({:d} != {:d})".format( + n_feats, len(feat_padding_idx))) + + @property + def word_lut(self): + """Word look-up table.""" + return self.make_embedding[0][0] + + @property + def emb_luts(self): + """Embedding look-up table.""" + return self.make_embedding[0] + + def load_pretrained_vectors(self, emb_file): + """Load in pretrained embeddings. + + Args: + emb_file (str) : path to torch serialized embeddings + """ + + if emb_file: + pretrained = torch.load(emb_file) + pretrained_vec_size = pretrained.size(1) + if self.word_vec_size > pretrained_vec_size: + self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained + elif self.word_vec_size < pretrained_vec_size: + self.word_lut.weight.data \ + .copy_(pretrained[:, :self.word_vec_size]) + else: + self.word_lut.weight.data.copy_(pretrained) + + def forward(self, source, step=None): + """Computes the embeddings for words and features. + + Args: + source (LongTensor): index tensor ``(len, batch, nfeat)`` + + Returns: + FloatTensor: Word embeddings ``(len, batch, embedding_size)`` + """ + + if self.position_encoding: + for i, module in enumerate(self.make_embedding._modules.values()): + if i == len(self.make_embedding._modules.values()) - 1: + source = module(source, step=step) + else: + source = module(source) + else: + source = self.make_embedding(source) + + return source + + def update_dropout(self, dropout): + if self.position_encoding: + self._modules['make_embedding'][1].dropout.p = dropout + diff --git a/rxnscribe/transformer/swin_transformer.py b/rxnscribe/transformer/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7f144d1042799562c14833f90657697338a1ec33 --- /dev/null +++ b/rxnscribe/transformer/swin_transformer.py @@ -0,0 +1,677 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import checkpoint_filter_fn, _init_vit_weights + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def get_attn_mask(self, H, W, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + return attn_mask + + def forward(self, x, H, W): + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_mask = self.get_attn_mask(Hp, Wp, x.device) + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + # assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + H, W = x.shape[1:3] + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x, H, W + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W, hiddens): + for blk in self.blocks: + if not torch.jit.is_scripting() and self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, H, W) + else: + x = blk(x, H, W) + hiddens.append(x) + if self.downsample is not None: + x, H, W = self.downsample(x, H, W) + return x, H, W + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) + H, W = x.shape[2:] + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + if weight_init.startswith('jax'): + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + self.apply(_init_vit_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x): + x, H, W = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + hiddens = [] + for layer in self.layers: + x, H, W = layer(x, H, W, hiddens) + x = self.norm(x) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return x, hiddens + + # def forward(self, x): + # x = self.forward_features(x) + # x = self.head(x) + # return x + + +def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + + +@register_model +def swin_base(pretrained=False, **kwargs): + """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large(pretrained=False, **kwargs): + """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_small(pretrained=False, **kwargs): + """ Swin-S @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def swin_tiny_patch4_window7_224(pretrained=False, **kwargs): +# """ Swin-T @ 224x224, trained ImageNet-1k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) +# return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs): +# """ Swin-B @ 384x384, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) +# return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs): +# """ Swin-B @ 224x224, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) +# return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs): +# """ Swin-L @ 384x384, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) +# return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) +# +# +# @register_model +# def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): +# """ Swin-L @ 224x224, trained ImageNet-22k +# """ +# model_kwargs = dict( +# patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) +# return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) diff --git a/rxnscribe/transforms.py b/rxnscribe/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6f359c5501a20208a433cc73e3b0f87e44b434dc --- /dev/null +++ b/rxnscribe/transforms.py @@ -0,0 +1,498 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Transforms and data augmentation for both image + bbox. +""" +import random +import math + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +import numpy as np + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + # target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + # remove elements for which the boxes or masks that have zero area + # if "boxes" in target or "masks" in target: + # # favor boxes selection when defining which elements to keep + # # this is compatible with previous implementation + # if "boxes" in target: + # cropped_boxes = target['boxes'].reshape(-1, 2, 2) + # keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + # else: + # keep = target['masks'].flatten(1).any(1) + # + # for field in fields: + # target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + return flipped_image, target + + +def rotate90(image, target): + rotated_image = image.rotate(90, expand=1) + + w, h = rotated_image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [1, 2, 3, 0]] * torch.as_tensor([1, -1, 1, -1]) + torch.as_tensor([0, h, 0, h]) + target["boxes"] = boxes + + return rotated_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image.size[::-1]) + if "masks" in target: + target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomReactionCrop(object): + def __init__(self): + pass + + def __call__(self, img, target): + w, h = img.size + boxes = target["boxes"] + x_avail = [1] * w + y_avail = [1] * h + for reaction in target['reactions']: + ids = reaction['reactants'] + reaction['conditions'] + reaction['products'] + rboxes = boxes[ids].round().int() + rmin, _ = rboxes.min(dim=0) + rmax, _ = rboxes.max(dim=0) + x1, x2 = (rmin[0].item(), rmax[2].item()) + for i in range(x1, x2): + x_avail[i] = 0 + y1, y2 = (rmin[1].item(), rmax[3].item()) + for i in range(y1, y2): + y_avail[i] = 0 + + def sample_from_avail(w): + spans = [] + left, right = 0, 0 + while right < len(w): + while right < len(w) and w[left] == w[right]: + right += 1 + if w[left] == 1: + spans.append((left, right)) + left, right = right + 1, right + 1 + if w[0] == 0: + spans = [(0, 0)] + spans + if w[-1] == 0: + spans = spans + [(len(w), len(w))] + if len(spans) < 2: + w1 = random.randint(0, len(w)) + w2 = random.randint(0, len(w)) + else: + spans = random.sample(spans, 2) + w1 = random.randint(*spans[0]) + w2 = random.randint(*spans[1]) + return min(w1, w2), max(w1, w2) + + x1, x2 = sample_from_avail(x_avail) + y1, y2 = sample_from_avail(y_avail) + region = (y1, x1, y2-y1, x2-x1) + if x2-x1 < 30 or y2-y1 < 30: + # Cropped region too small + return img, target + else: + return crop(img, target, region) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomRotate(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return rotate90(img, target) + return img, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class Resize(object): + def __init__(self, size): + assert isinstance(size, (list, tuple)) + self.size = size + + def __call__(self, img, target=None): + return resize(img, target, self.size) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + def __init__(self, mean, std, debug=False): + self.mean = mean + self.std = std + self.debug = debug + + def __call__(self, image, target=None): + if not self.debug: + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes.clamp(min=0, max=1) + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target=None): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class LargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + """ + + def __init__(self, output_size=1333, aug_scale_min=0.3, aug_scale_max=2.0): + self.desired_size = output_size + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + self.random = (aug_scale_min != 1) or (aug_scale_max != 1) + + def rescale_target(self, scaled_size, image_size, target): + # compute rescaled targets + image_scale = scaled_size / image_size + ratio_height, ratio_width = image_scale + + target = target.copy() + + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + return target + + def crop_target(self, region, target): + i, j, h, w = region + fields = ["labels", "area"] + + target = target.copy() + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + # Do not remove the boxes with zero area. Tokenizer does it instead. + # if "boxes" in target: + # # favor boxes selection when defining which elements to keep + # # this is compatible with previous implementation + # cropped_boxes = target['boxes'].reshape(-1, 2, 2) + # keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + # for field in fields: + # target[field] = target[field][keep] + return target + + def pad_target(self, padding, target): + # padding: left, top, right, bottom + target = target.copy() + if "boxes" in target: + left, top, right, bottom = padding + target["boxes"][:, 0::2] += left + target["boxes"][:, 1::2] += top + return target + + def __call__(self, image, target=None): + image_size = image.size + image_size = torch.tensor(image_size[::-1]) + if target is None: + target = {} + + # out_desired_size = (self.desired_size * image_size / max(image_size)).round().int() + out_desired_size = torch.tensor([self.desired_size, self.desired_size]) + + random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1]) + scaled_size = (image_size * scale).round().int().clamp(min=1) + + scaled_image = F.resize(image, scaled_size.tolist()) + + if target is not None: + target = self.rescale_target(scaled_size, image_size, target) + + # randomly crop or pad images + delta = scaled_size - out_desired_size + output_image = scaled_image + + w, h = scaled_image.size + target["scale"] = [w / self.desired_size, h / self.desired_size] + + if delta.lt(0).any(): + padding = torch.clamp(-delta, min=0) + if self.random: + padding1 = (torch.rand(1) * padding).round().int() + padding2 = padding - padding1 + padding = padding1.tolist()[::-1] + padding2.tolist()[::-1] + else: + padding = [0, 0] + padding.tolist()[::-1] + output_image = F.pad(output_image, padding, 255) + # output_image = F.pad(scaled_image, [0, 0, padding[1].item(), padding[0].item()]) + if target is not None: + target = self.pad_target(padding, target) + + if delta.gt(0).any(): + # Selects non-zero random offset (x, y) if scaled image is larger than desired_size. + max_offset = torch.clamp(delta, min=0) + if self.random: + offset = (max_offset * torch.rand(2)).floor().int() + else: + offset = torch.zeros(2) + region = (offset[0].item(), offset[1].item(), out_desired_size[0].item(), out_desired_size[1].item()) + output_image = F.crop(output_image, *region) + if target is not None: + target = self.crop_target(region, target) + + return output_image, target + + +class RandomDistortion(object): + """ + Distort image w.r.t hue, saturation and exposure. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, prob=0.5): + self.prob = prob + self.tfm = T.ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, img, target=None): + if np.random.random() < self.prob: + return self.tfm(img), target + else: + return img, target diff --git a/rxnscribe/utils.py b/rxnscribe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e91b42d249e88a50458f4ee2109a1ff9b171ce16 --- /dev/null +++ b/rxnscribe/utils.py @@ -0,0 +1,14 @@ +import json + + +def merge_predictions(results): + if len(results) == 0: + return {} + formats = results[0][1].keys() + predictions = {format_: {} for format_ in formats} + for format_ in formats: + for indices, batch_preds in results: + for idx, preds in zip(indices, batch_preds[format_]): + predictions[format_][idx] = preds + predictions[format_] = [predictions[format_][i] for i in range(len(predictions[format_]))] + return predictions