Spaces:
Sleeping
Sleeping
import ast
import html
import json
import os
import re
import tempfile

import cairosvg
import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D

from main import ChemEagle  # the API key is supplied through an environment variable
# Static image shown in the "Schematic Diagram" panel; process_chem_image
# also returns it for that panel after every run.
example_diagram = "examples/exp.png"
# Placeholder picture shown in the RDKit preview slot while there is no
# reaction SMILES to render.
rdkit_image = "examples/rdkit.png"
# Parse the structured data returned by ChemEagle.
def _as_list(value):
    """Return *value* as a list: ``None`` -> ``[]``, scalar -> ``[scalar]``."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    # A bare string/dict would otherwise be iterated character by character
    # (or key by key), which is never what the payload means.
    return [value]


def _first_present(entry, *keys, default="Unknown"):
    """Return the first non-``None`` value among *keys* of *entry* as a string."""
    if not isinstance(entry, dict):
        return default if entry is None else str(entry)
    for key in keys:
        value = entry.get(key)
        if value is not None:
            return str(value)
    return default


def _escape(value):
    """HTML-escape *value* for use as element text (quotes are left alone)."""
    return html.escape(str(value), quote=False)


def _span(color, text):
    """Wrap already-escaped *text* in a coloured ``<span>``."""
    return f"<span style='color:{color}'>{text}</span>"


def parse_reactions(output_json):
    """Format reaction data as colour-coded HTML and reaction SMILES.

    Args:
        output_json: A JSON string or an already-decoded dict holding a
            top-level ``"reactions"`` list. ``None`` (or JSON ``null``) is
            treated as "no reactions".

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of two lists with one
        entry per reaction: an HTML fragment, and a plain
        ``reactants>>products`` string (components joined with ``.``).

    Raises:
        json.JSONDecodeError: If ``output_json`` is a string that is not
            valid JSON.
        TypeError: If the decoded payload is neither ``None`` nor a dict.
    """
    if isinstance(output_json, str):
        reactions_data = json.loads(output_json)
    else:
        reactions_data = output_json

    if reactions_data is None:
        return [], []
    if not isinstance(reactions_data, dict):
        raise TypeError(
            "parse_reactions expects a dict or a JSON object string, "
            f"got {type(reactions_data).__name__}"
        )

    detailed_output = []
    smiles_output = []
    for reaction in _as_list(reactions_data.get("reactions")):
        if not isinstance(reaction, dict):
            continue

        reaction_id = _first_present(reaction, "reaction_id", default="Unknown ID")
        reactants = [
            _first_present(r, "smiles") for r in _as_list(reaction.get("reactants"))
        ]
        products = [
            _first_present(p, "smiles") for p in _as_list(reaction.get("products"))
        ]

        # The values end up inside HTML, so escape everything that is
        # interpolated; the machine-readable SMILES below stay raw.
        reactant_html = [_escape(s) for s in reactants]
        product_html = [_escape(s) for s in products]
        # A condition shows its SMILES when present, otherwise its text.
        condition_html = [
            f"{_escape(_first_present(c, 'smiles', 'text'))}"
            f"[{_escape(_first_present(c, 'role'))}]"
            for c in _as_list(reaction.get("condition"))
        ]
        additional_html = [
            _escape(x)
            for x in _as_list(reaction.get("additional_info"))
            if x is not None
        ]

        red_conditions = [_span("red", c) for c in condition_html]
        black_conditions = [_span("black", c) for c in condition_html]
        orange_products = [_span("orange", p) for p in product_html]
        black_products = [_span("black", p) for p in product_html]

        # Full reaction line: reactants>>products | conditions, extra info.
        tail_str = ", ".join(black_conditions + additional_html)
        full_reaction = _span(
            "black",
            f"{'.'.join(reactant_html)}>>{'.'.join(black_products)} | {tail_str}",
        )

        # Detailed, colour-coded description of the reaction.
        reaction_output = "".join(
            [
                f"<b>Reaction: </b> {_escape(reaction_id)}<br>",
                f" Reactants: {_span('blue', ', '.join(reactant_html))}<br>",
                f" Conditions: {', '.join(red_conditions)}<br>",
                f" Products: {', '.join(orange_products)}<br>",
                f" additional_info: {', '.join(additional_html)}<br>",
                f" <b>Full Reaction:</b> {full_reaction}<br>",
                "<br>",
            ]
        )
        detailed_output.append(reaction_output)
        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(products)}")

    return detailed_output, smiles_output
# Core handler: needs only the API key and the uploaded image.
def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and build the values for the UI.

    Args:
        api_key: Key typed into the UI. It is exported as the
            ``CHEMEAGLE_API_KEY`` environment variable before ChemEagle is
            called; ``None`` is treated as an empty key.
        image: Uploaded image object; it must provide ``save(path)``.

    Returns:
        A 4-tuple ``(html, smiles, diagram, json_path)``: the per-reaction
        HTML fragments joined by blank lines, the list of reaction SMILES,
        the static schematic-diagram path, and the path of the JSON file
        written for download.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # Without this guard the user only sees an AttributeError traceback.
        raise gr.Error("Please upload a reaction image before pressing Run.")

    # os.environ only accepts strings; a cleared textbox can deliver None.
    # NOTE(review): the environment is process-wide, so concurrent requests
    # share (and overwrite) this key. Passing the key to ChemEagle directly
    # would be safer if its API allows it.
    os.environ["CHEMEAGLE_API_KEY"] = (api_key or "").strip()

    # One directory per request so simultaneous users cannot overwrite each
    # other's upload or download file. The directory is left in the system
    # temp location because the JSON file must outlive this call.
    work_dir = tempfile.mkdtemp(prefix="chemeagle_")

    # Save the uploaded picture where ChemEagle can read it.
    image_path = os.path.join(work_dir, "temp_image.png")
    image.save(image_path)

    # Call ChemEagle (per the original author's note it reads the key via
    # os.getenv internally).
    chemeagle_result = ChemEagle(image_path)
    if isinstance(chemeagle_result, str):
        # Decode once so the download file holds the JSON document itself
        # rather than a quoted, escaped JSON string.
        chemeagle_result = json.loads(chemeagle_result)

    # Parse the result for display.
    detailed, smiles = parse_reactions(chemeagle_result)

    # Write the JSON file offered for download.
    json_path = os.path.join(work_dir, "output.json")
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)

    # HTML, reaction SMILES, schematic diagram, JSON download.
    return "\n\n".join(detailed), smiles, example_diagram, json_path
# Build the Gradio interface.

# Styling for the "Run" button (elem_id="submit-btn").
_CUSTOM_CSS = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""


def _split_reaction_smiles(value):
    """Return the reaction SMILES strings encoded in *value*.

    The hidden textbox appears to receive the ``str()`` of a Python list
    (e.g. ``"['CCO>>CC=O', 'CC>>C']"``), so that form is decoded with
    ``ast.literal_eval``. Any other text falls back to a comma split with
    the surrounding list/quote decoration stripped. Blank input gives ``[]``.
    """
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    if not isinstance(value, str) or not value.strip():
        return []
    try:
        decoded = ast.literal_eval(value.strip())
    except (ValueError, SyntaxError):
        decoded = None
    if isinstance(decoded, (list, tuple)):
        return [str(item) for item in decoded]
    # Only the outer "['" / "']" decoration is removed here. Brackets inside
    # a SMILES (e.g. "[Na+]") are part of the structure and must be kept.
    return [re.sub(r"^\s*\[?'?|'\]?\s*$", "", item) for item in value.split(",")]


def _strip_atom_maps(rxn):
    """Set the atom-map number of every atom in *rxn* to 0 and return *rxn*."""
    for reactant in rxn.GetReactants():
        for atom in reactant.GetAtoms():
            atom.SetAtomMapNum(0)
    for product in rxn.GetProducts():
        for atom in product.GetAtoms():
            atom.SetAtomMapNum(0)
    return rxn


def _load_reaction(smiles):
    """Parse reaction *smiles* into an RDKit reaction prepared for drawing.

    Each reactant/product is rebuilt from its mol block and atom-map numbers
    are cleared. Returns ``None`` if RDKit yields no reaction object.

    Raises:
        ValueError: If a component cannot be rebuilt from its mol block.
    """
    parsed = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
    if not parsed:
        return None
    # NOTE(review): the mol-block round trip presumably gives every template
    # the coordinates the drawing code needs - confirm before removing it.
    rebuilt = AllChem.ChemicalReaction()
    for mol in parsed.GetReactants():
        mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
        if mol is None:
            raise ValueError(f"could not rebuild a reactant of {smiles!r}")
        rebuilt.AddReactantTemplate(mol)
    for mol in parsed.GetProducts():
        mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
        if mol is None:
            raise ValueError(f"could not rebuild a product of {smiles!r}")
        rebuilt.AddProductTemplate(mol)
    return _strip_atom_maps(rebuilt)


def _reference_bond_length(rxn):
    """Mean bond length of the first of the first two reactants that has bonds.

    Falls back to 1.0 when neither has bonds or the reaction has no reactants.
    """
    for index in range(min(rxn.GetNumReactantTemplates(), 2)):
        template = rxn.GetReactantTemplate(index)
        if template.GetNumBonds() > 0:
            return Draw.MeanBondLength(template)
    return 1.0


def _draw_reaction_png(rxn, svg_path, png_path):
    """Draw *rxn* with ACS-1996 settings to *svg_path*, then convert to *png_path*."""
    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    dopts = drawer.drawOptions()
    dopts.padding = 0.1
    dopts.includeRadicals = True
    Draw.SetACS1996Mode(dopts, _reference_bond_length(rxn) * 0.55)
    dopts.bondLineWidth = 1.5
    drawer.DrawReaction(rxn)
    drawer.FinishDrawing()
    with open(svg_path, "w") as f:
        f.write(drawer.GetDrawingText())
    cairosvg.svg2png(url=svg_path, write_to=png_path)


with gr.Blocks(css=_CUSTOM_CSS) as demo:
    gr.Markdown(
        """
        <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>
        Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
        """
    )

    with gr.Row():
        # Left column: upload, API key and buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            api_key_input = gr.Textbox(
                label="Your API-Key",
                placeholder="Type your OpenAI_API_KEY",
                type="password",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        # Middle column: parsed reactions and schematic diagram.
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        # Right column: per-reaction SMILES, RDKit rendering and JSON download.
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden carrier for the reaction SMILES; show_split below turns
            # its value into visible per-reaction components.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            # gr.render binds the textbox to the rendering logic: the
            # function re-runs and rebuilds its components when the value
            # changes. Without the decorator show_split is never called.
            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render one SMILES textbox and one RDKit image per reaction.

                With no reaction SMILES available a placeholder textbox and
                image are shown instead. A reaction RDKit cannot parse or
                draw keeps its textbox but gets no image, so one bad entry
                does not abort the others.
                """
                smiles_list = _split_reaction_smiles(inputs)
                if not smiles_list:
                    return (
                        gr.Textbox(label="SMILES of Reaction i"),
                        gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100),
                    )

                # Per-render directory: fixed file names in the working
                # directory would be shared by all concurrent sessions.
                render_dir = tempfile.mkdtemp(prefix="chemeagle_rxn_")
                components = []
                for i, smiles in enumerate(smiles_list):
                    png_file = os.path.join(render_dir, f"reaction_{i+1}.png")
                    svg_file = os.path.join(render_dir, f"reaction_{i+1}.svg")
                    try:
                        rxn = _load_reaction(smiles)
                        if rxn is None:
                            continue
                        _draw_reaction_png(rxn, svg_file, png_file)
                    except (ValueError, RuntimeError):
                        # RDKit rejected this reaction; still show its text.
                        components.append(
                            gr.Textbox(
                                value=smiles,
                                label=f"SMILES of Reaction {i}",
                                show_copy_button=True,
                                interactive=False,
                            )
                        )
                        continue
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {i}",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    components.append(
                        gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}")
                    )
                return components

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=[
            ["examples/reaction1.jpg", ""],
            ["examples/reaction2.png", ""],
            ["examples/reaction3.png", ""],
            ["examples/reaction4.png", ""],
        ],
        inputs=[image_input, api_key_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )

    # Wire up the Clear and Run buttons.
    clear_btn.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json],
    )
    run_btn.click(
        process_chem_image,
        inputs=[api_key_input, image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
    )

demo.launch()