# ChemEagle_API / app.py
# Last commit by CYF200127: 08d0a04 "Update app.py" (verified)
import ast
import contextlib
import html
import json
import os
import re
import tempfile

import cairosvg
import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D

from main import ChemEagle  # the API key is supplied via an environment variable
# Static image paths used by the UI.
example_diagram: str = "examples/exp.png"  # schematic-diagram image (initial value and handler output)
rdkit_image: str = "examples/rdkit.png"  # placeholder shown while there are no reaction SMILES
def _html_span(text, color):
    """Return *text*, HTML-escaped, wrapped in a ``<span>`` with CSS colour *color*."""
    return f"<span style='color:{color}'>{html.escape(str(text))}</span>"


def _as_list(value):
    """Normalise an optional JSON field to a list (``None`` -> ``[]``, scalar -> ``[scalar]``)."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


# Parse the structured data returned by ChemEagle.
def parse_reactions(output_json):
    """Format ChemEagle reaction data as colour-coded HTML plus reaction SMILES.

    Args:
        output_json: ChemEagle output, either as a dict or as JSON text
            (``str``/``bytes``) that decodes to an object. Reactions are read
            from its ``"reactions"`` list; each reaction may carry
            ``reaction_id``, ``reactants``, ``condition``, ``products`` and
            ``additional_info``. Missing or ``null`` fields are treated as empty.

    Returns:
        ``(detailed_output, smiles_output)``: ``detailed_output`` holds one
        HTML snippet per reaction (reactants in blue, conditions in red,
        products in orange, followed by a black "Full Reaction" line);
        ``smiles_output`` holds one ``reactants>>products`` SMILES string per
        reaction.

    Raises:
        TypeError: if the input is neither a dict nor JSON text decoding to an
            object.
        json.JSONDecodeError: if JSON text input is malformed.
    """
    if isinstance(output_json, (str, bytes, bytearray)):
        reactions_data = json.loads(output_json)
    else:
        reactions_data = output_json
    if not isinstance(reactions_data, dict):
        raise TypeError(
            "parse_reactions expects a dict or JSON object text, got "
            f"{type(reactions_data).__name__}"
        )

    detailed_output = []
    smiles_output = []
    for reaction in _as_list(reactions_data.get("reactions")):
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = [r.get("smiles", "Unknown") for r in _as_list(reaction.get("reactants"))]
        products = [p.get("smiles", "Unknown") for p in _as_list(reaction.get("products"))]
        # Each condition is displayed as "<smiles or text>[<role>]".
        conditions = [
            f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
            for c in _as_list(reaction.get("condition"))
        ]
        # A bare string is wrapped by _as_list rather than iterated char by char.
        additional_str = [
            str(x) for x in _as_list(reaction.get("additional_info")) if x is not None
        ]

        # All model-provided text is escaped because it is rendered by gr.HTML.
        tail_str = ", ".join(
            [_html_span(c, "black") for c in conditions]
            + [html.escape(x) for x in additional_str]
        )
        full_reaction = (
            "<span style='color:black'>"
            f"{html.escape('.'.join(reactants))}>>"
            f"{'.'.join(_html_span(p, 'black') for p in products)} | {tail_str}"
            "</span>"
        )

        reaction_output = (
            f"<b>Reaction: </b> {html.escape(str(reaction_id))}<br>"
            f" Reactants: {_html_span(', '.join(reactants), 'blue')}<br>"
            f" Conditions: {', '.join(_html_span(c, 'red') for c in conditions)}<br>"
            f" Products: {', '.join(_html_span(p, 'orange') for p in products)}<br>"
            f" additional_info: {html.escape(', '.join(additional_str))}<br>"
            f" <b>Full Reaction:</b> {full_reaction}<br>"
            "<br>"
        )
        detailed_output.append(reaction_output)
        # Plain (unescaped) reaction SMILES for RDKit rendering.
        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(products)}")
    return detailed_output, smiles_output
# Core handler: uses only the user's API key and the uploaded image.
def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and prepare every UI output.

    Args:
        api_key: Key typed by the user. It is exported as the
            ``CHEMEAGLE_API_KEY`` environment variable for ChemEagle to read;
            ``None`` is treated as an empty key and surrounding whitespace is
            stripped.
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``.

    Returns:
        ``(html, smiles, diagram, json_path)``: the reaction HTML snippets
        joined by blank lines, the list of reaction SMILES, the static example
        schematic-diagram path, and the path of a JSON file holding
        ChemEagle's output.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload a reaction image first.")

    # NOTE(review): os.environ is process-wide, so concurrent users can
    # overwrite each other's key before ChemEagle reads it. A real fix needs
    # ChemEagle to accept the key as an argument.
    os.environ["CHEMEAGLE_API_KEY"] = (api_key or "").strip()

    # Per-request directory: fixed file names in the working directory would
    # let concurrent requests clobber each other's image and JSON files.
    work_dir = tempfile.mkdtemp(prefix="chemeagle_")
    image_path = os.path.join(work_dir, "temp_image.png")
    try:
        image.save(image_path)
        chemeagle_result = ChemEagle(image_path)
    finally:
        with contextlib.suppress(OSError):
            os.remove(image_path)

    # ChemEagle output may be JSON text (parse_reactions accepts both); decode
    # it so the downloaded file holds the object, not one quoted string.
    if isinstance(chemeagle_result, (str, bytes, bytearray)):
        chemeagle_result = json.loads(chemeagle_result)
    detailed, smiles = parse_reactions(chemeagle_result)

    json_path = os.path.join(work_dir, "output.json")
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)
    # HTML, SMILES list, schematic diagram, JSON download.
    return "\n\n".join(detailed), smiles, example_diagram, json_path
# Build the Gradio interface.


def _split_reaction_smiles(text):
    r"""Recover the list of reaction SMILES from the hidden SMILES textbox value.

    ``process_chem_image`` returns a Python list for this textbox, which
    arrives here in its ``str()`` form, e.g. ``"['CCO>>C=C', 'C>>CC']"``.
    That form is decoded with ``ast.literal_eval`` so repr escaping is undone
    (stereo SMILES such as ``C/C=C\C`` would otherwise keep a doubled
    backslash) and ``"[]"`` yields no reactions. Any other text falls back
    to comma-splitting and stripping list punctuation. Empty items are dropped.
    """
    text = text.strip()
    if text.startswith("[") and text.endswith("]"):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None  # not a list literal, e.g. a bare "[Na+]>>[Na]"
        if isinstance(parsed, (list, tuple)):
            return [str(item).strip() for item in parsed if str(item).strip()]
    cleaned = (re.sub(r"^\s*\[?'?|'\]?\s*$", "", part) for part in text.split(","))
    return [item for item in cleaned if item]


def _atom_mapping_remover(rxn):
    """Reset the atom-map number of every atom in *rxn*'s templates to 0, in place."""
    for reactant in rxn.GetReactants():
        for atom in reactant.GetAtoms():
            atom.SetAtomMapNum(0)
    for product in rxn.GetProducts():
        for atom in product.GetAtoms():
            atom.SetAtomMapNum(0)
    return rxn


def _molblock_copy(mol):
    """Round-trip a reaction template through a MolBlock; raise if RDKit rejects it."""
    # Presumably done to get sanitized molecules with 2D coordinates for
    # drawing — kept from the original implementation.
    copy = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
    if copy is None:
        raise ValueError("RDKit could not rebuild a reaction template")
    return copy


def _render_reaction_png(smiles, png_file):
    """Draw a reaction SMILES with RDKit (ACS 1996 style) and save it as *png_file*.

    Returns *png_file* on success, or ``None`` when RDKit/cairosvg fail (the
    error is printed) so one bad reaction does not break the whole panel.
    """
    try:
        parsed = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
        rxn = AllChem.ChemicalReaction()
        for mol in parsed.GetReactants():
            rxn.AddReactantTemplate(_molblock_copy(mol))
        for mol in parsed.GetProducts():
            rxn.AddProductTemplate(_molblock_copy(mol))
        _atom_mapping_remover(rxn)

        # Scale to the mean bond length of the first of the (up to) two
        # leading reactants that has any bonds; 1.0 otherwise (this also
        # covers reactions with no reactants).
        leading = list(rxn.GetReactants())[:2]
        bond_length_reference = next(
            (Draw.MeanBondLength(m) for m in leading if m.GetNumBonds() > 0), 1.0
        )

        drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
        dopts = drawer.drawOptions()
        dopts.padding = 0.1
        dopts.includeRadicals = True
        Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
        dopts.bondLineWidth = 1.5
        drawer.DrawReaction(rxn)
        drawer.FinishDrawing()
        cairosvg.svg2png(
            bytestring=drawer.GetDrawingText().encode("utf-8"), write_to=png_file
        )
    except Exception as exc:  # UI boundary: show the SMILES even if drawing fails
        print(f"Could not render reaction SMILES {smiles!r}: {exc}")
        return None
    return png_file


with gr.Blocks() as demo:
    gr.Markdown(
        """
<center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>
Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
"""
    )
    with gr.Row():
        # ---- Left column: image upload, API key and buttons ----
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            api_key_input = gr.Textbox(
                label="Your API-Key",
                placeholder="Type your OpenAI_API_KEY",
                type="password",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")
        # ---- Middle column: parsed reactions and schematic diagram ----
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")
        # ---- Right column: per-reaction SMILES + RDKit images, JSON download ----
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden holder for the SMILES list; show_split re-renders from it.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            @gr.render(inputs=smiles_output)  # re-run whenever smiles_output changes
            def show_split(inputs):
                """Render one SMILES textbox and one RDKit image per reaction.

                Shows placeholder components when there is nothing to render.
                """
                smiles_list = _split_reaction_smiles(inputs) if isinstance(inputs, str) else []
                if not smiles_list:
                    return gr.Textbox(label="SMILES of Reaction i"), gr.Image(
                        value=rdkit_image, label="RDKit Image of Reaction i", height=100
                    )
                # Per-render directory so concurrent users never share image files.
                out_dir = tempfile.mkdtemp(prefix="chemeagle_rdkit_")
                components = []
                for i, smiles in enumerate(smiles_list):
                    png_file = _render_reaction_png(
                        smiles, os.path.join(out_dir, f"reaction_{i+1}.png")
                    )
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {i}",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    if png_file is not None:
                        components.append(
                            gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}")
                        )
                return components

            download_json = gr.File(label="Download JSON File")
    gr.Examples(
        examples=[
            ["examples/reaction1.jpg", ""],
            ["examples/reaction2.png", ""],
            ["examples/reaction3.png", ""],
            ["examples/reaction4.png", ""],
        ],
        inputs=[image_input, api_key_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )
    # ---- Clear and Run bindings ----
    clear_btn.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json],
    )
    run_btn.click(
        process_chem_image,
        inputs=[api_key_input, image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
    )
# Stylesheet for the Run button, which is created with elem_id="submit-btn".
_SUBMIT_BUTTON_CSS = """
#submit-btn {
background-color: #FF914D;
color: white;
font-weight: bold;
}
"""
demo.css = _SUBMIT_BUTTON_CSS
demo.launch()