import ast
import html
import logging
import os
import gradio as gr
import json
from main import ChemEagle  # ChemEagle takes its API key from an environment variable
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import cairosvg
import re
import torch  # NOTE(review): not referenced in this file; kept as in the original

example_diagram = "examples/exp.png"
rdkit_image = "examples/rdkit.png"

# Styling for the "Run" button (elem_id="submit-btn").
CUSTOM_CSS = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""

logger = logging.getLogger(__name__)

# Strips a leading "['" / trailing "']" left over after comma-splitting a list repr.
_LIST_WRAPPER_RE = re.compile(r"^\s*\[?'?|'\]?\s*$")


# ---------------------------------------------------------------------------
# Parsing ChemEagle's structured output
# ---------------------------------------------------------------------------

def _as_list(value):
    """Return *value* as a list: None -> [], a lone string -> [value]."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _format_condition(condition):
    """Format one condition entry as "<smiles or text>[<role>]"."""
    label = condition.get("smiles", condition.get("text", "Unknown"))
    role = condition.get("role", "Unknown")
    return f"{label}[{role}]"


def parse_reactions(output_json):
    """Parse ChemEagle's reaction data into display HTML and reaction SMILES.

    Args:
        output_json: A JSON string or an already-decoded dict. Reactions are
            read from its "reactions" list; each reaction may carry
            "reaction_id", "reactants", "condition", "products" and
            "additional_info".

    Returns:
        A tuple ``(detailed_output, smiles_output)``:
        ``detailed_output`` holds one HTML snippet per reaction (for the
        gr.HTML panel), ``smiles_output`` one "reactants>>products" SMILES
        per reaction.

    Raises:
        TypeError: if *output_json* is not a str/dict, or the JSON string
            does not decode to an object.
        json.JSONDecodeError: if *output_json* is a str that is not valid JSON.
    """
    if isinstance(output_json, str):
        reactions_data = json.loads(output_json)
    elif isinstance(output_json, dict):
        reactions_data = output_json
    else:
        raise TypeError(
            f"expected a JSON string or dict, got {type(output_json).__name__}"
        )
    if not isinstance(reactions_data, dict):
        raise TypeError("ChemEagle output must decode to a JSON object")

    detailed_output = []
    smiles_output = []
    for reaction in reactions_data.get("reactions") or []:
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = [
            str(r.get("smiles", "Unknown")) for r in reaction.get("reactants") or []
        ]
        conditions = [_format_condition(c) for c in reaction.get("condition") or []]
        products = [
            str(p.get("smiles", "Unknown")) for p in reaction.get("products") or []
        ]
        additional_str = [
            str(x) for x in _as_list(reaction.get("additional_info")) if x is not None
        ]

        reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products)}"
        tail_str = ", ".join(conditions + additional_str)
        full_reaction = f"{reaction_smiles} | {tail_str}"

        # The snippet is shown in a gr.HTML panel, where bare newlines collapse,
        # so lines are separated with <br>. Values come from ChemEagle's output
        # and are escaped so a stray "<" or "&" cannot break the markup.
        lines = [
            f"Reaction: {reaction_id}",
            f" Reactants: {', '.join(reactants)}",
            f" Conditions: {', '.join(conditions)}",
            f" Products: {', '.join(products)}",
            f" additional_info: {', '.join(additional_str)}",
            f" Full Reaction: {full_reaction}",
        ]
        detailed_output.append(
            "".join(html.escape(line) + "<br>" for line in lines) + "<br>"
        )
        smiles_output.append(reaction_smiles)

    return detailed_output, smiles_output


# ---------------------------------------------------------------------------
# Core processing: API key + image -> all output panels
# ---------------------------------------------------------------------------

def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and build every output panel.

    Args:
        api_key: Key exported as CHEMEAGLE_API_KEY for ChemEagle to read.
            ``None`` (e.g. after pressing Clear) is treated as "".
        image: The uploaded picture as a PIL image (gr.Image(type="pil")).

    Returns:
        ``(html, smiles_list, diagram_path, json_path)``, matching the
        outputs wired to the Run button.

    Raises:
        gr.Error: if no image was uploaded (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload a reaction image first.")

    # NOTE(review): this sets process-wide state, so concurrent sessions share
    # the last key written; ChemEagle is expected to read it via os.getenv.
    os.environ["CHEMEAGLE_API_KEY"] = api_key or ""

    image_path = "temp_image.png"
    image.save(image_path)

    chemeagle_result = ChemEagle(image_path)
    # Decode once so parsing and the JSON download see the same object;
    # dumping a raw JSON string would write it as a quoted string literal.
    if isinstance(chemeagle_result, str):
        chemeagle_result = json.loads(chemeagle_result)

    detailed, smiles = parse_reactions(chemeagle_result)

    json_path = "output.json"
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)

    return "\n\n".join(detailed), smiles, example_diagram, json_path


# ---------------------------------------------------------------------------
# RDKit rendering helpers for the per-reaction panels
# ---------------------------------------------------------------------------

def _split_reaction_smiles(value):
    """Recover the list of reaction SMILES held by the hidden SMILES textbox.

    process_chem_image returns a Python list for that textbox; it presumably
    arrives here as the list's str() form, e.g. "['A>>B', 'C>>D']" (TODO
    confirm against the installed Gradio version). Plain comma-separated
    text and real lists are accepted too. Empty entries are dropped.
    """
    if isinstance(value, (list, tuple)):
        items = list(value)
    else:
        text = (value or "").strip()
        if not text:
            return []
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            items = list(parsed)
        else:
            items = [_LIST_WRAPPER_RE.sub("", item) for item in text.split(",")]
    return [str(item).strip() for item in items if str(item).strip()]


def _roundtrip_molblock(mol):
    """Pass *mol* through a MolBlock round-trip, keeping *mol* if that fails."""
    converted = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
    return mol if converted is None else converted


def _render_reaction_png(smiles, png_file):
    """Parse one reaction SMILES and draw it as a PNG at *png_file*.

    Exceptions raised by RDKit or CairoSVG propagate to the caller.
    """
    rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
    if rxn is None:
        raise ValueError(f"could not parse reaction SMILES {smiles!r}")

    # Rebuild the reaction from MolBlock round-tripped copies of its templates.
    rebuilt = AllChem.ChemicalReaction()
    for mol in rxn.GetReactants():
        rebuilt.AddReactantTemplate(_roundtrip_molblock(mol))
    for mol in rxn.GetProducts():
        rebuilt.AddProductTemplate(_roundtrip_molblock(mol))
    rxn = rebuilt

    # Clear atom-map numbers so they are not drawn on the picture.
    for templates in (rxn.GetReactants(), rxn.GetProducts()):
        for mol in templates:
            for atom in mol.GetAtoms():
                atom.SetAtomMapNum(0)

    # Scale from the first of the first two reactants that has any bonds;
    # 1.0 if none do (or the reaction has no reactants at all).
    bond_length_reference = 1.0
    for idx in range(min(2, rxn.GetNumReactantTemplates())):
        template = rxn.GetReactantTemplate(idx)
        if template.GetNumBonds() > 0:
            bond_length_reference = Draw.MeanBondLength(template)
            break

    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    dopts = drawer.drawOptions()
    dopts.padding = 0.1
    dopts.includeRadicals = True
    Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
    # Set after ACS mode so this line width is the one used.
    dopts.bondLineWidth = 1.5
    drawer.DrawReaction(rxn)
    drawer.FinishDrawing()

    svg_content = drawer.GetDrawingText()
    cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=png_file)


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(
        """
# ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction

Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
"""
    )

    with gr.Row():
        # ---- Left column: upload + API key + buttons ----
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            api_key_input = gr.Textbox(
                label="Your API-Key",
                placeholder="Type your OpenAI_API_KEY",
                type="password",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        # ---- Middle column: parsed reactions + schematic diagram ----
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        # ---- Right column: per-reaction SMILES + RDKit images + JSON download ----
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden carrier for the SMILES list; show_split renders it.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render a SMILES textbox and an RDKit image for each reaction."""
                smiles_list = _split_reaction_smiles(inputs)
                if not smiles_list:
                    # Placeholder panels shown when there is nothing to render.
                    return (
                        gr.Textbox(label="SMILES of Reaction i"),
                        gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100),
                    )

                components = []
                for i, smiles in enumerate(smiles_list, start=1):
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {i}",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    png_file = f"reaction_{i}.png"
                    try:
                        _render_reaction_png(smiles, png_file)
                    except Exception:
                        # Best effort: one unparsable reaction must not hide the rest.
                        logger.exception("Could not render reaction %d (%s)", i, smiles)
                        continue
                    components.append(
                        gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}")
                    )
                return components

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=[
            ["examples/reaction1.jpg", ""],
            ["examples/reaction2.png", ""],
            ["examples/reaction3.png", ""],
            ["examples/reaction4.png", ""],
        ],
        inputs=[image_input, api_key_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )

    # ---- Clear / Run bindings ----
    clear_btn.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json],
    )
    run_btn.click(
        process_chem_image,
        inputs=[api_key_input, image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
    )

demo.launch()