Spaces:
Sleeping
Sleeping
import sys | |
import torch | |
import json | |
from chemietoolkit import ChemIEToolkit,utils | |
import cv2 | |
from openai import AzureOpenAI | |
import numpy as np | |
from PIL import Image | |
import json | |
from get_molecular_agent import process_reaction_image_with_multiple_products_and_text_correctR | |
from get_reaction_agent import get_reaction_withatoms_correctR | |
import sys | |
from rxnscribe import RxnScribe | |
import json | |
import base64 | |
# Pick the compute device a single time and share it with every model below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shared ChemIEToolkit instance (used for molecule/coreference extraction below).
model = ChemIEToolkit(device=device)
ckpt_path = "./pix2seq_reaction_full.ckpt"
# RxnScribe model loaded from the local checkpoint at ckpt_path.
# NOTE(review): model1 only appears in a commented-out call in this file — confirm it is still needed.
model1 = RxnScribe(ckpt_path, device=device)
import os | |
# with open('api_key.txt', 'r') as api_key_file: | |
# API_KEY = api_key_file.read() | |
def parse_coref_data_with_fallback(data):
    """Pair detected molecules with their label text.

    Every coreference pair in ``data["corefs"]`` yields one entry whose SMILES
    comes from whichever box of the pair has a ``"smiles"`` key and whose texts
    come from whichever box has a ``"text"`` key. Molecule boxes that appear in
    no pair are appended afterwards with a placeholder text.

    Args:
        data: dict with ``"bboxes"`` (list of box dicts) and ``"corefs"``
            (iterable of two-element index pairs into ``"bboxes"``).

    Returns:
        list of ``{"smiles": ..., "texts": [...]}`` dicts: paired entries first,
        in coref order, then unpaired molecules in box order.
    """
    boxes = data["bboxes"]
    links = data["corefs"]

    entries = []
    linked = set()
    # Coref-paired molecules come first.
    for first, second in links:
        mol_box = boxes[first] if "smiles" in boxes[first] else boxes[second]
        label_box = boxes[second] if "text" in boxes[second] else boxes[first]
        entries.append({
            "smiles": mol_box.get("smiles", ""),
            "texts": label_box.get("text", []),
        })
        # Both indices count as used, whichever one held the SMILES.
        linked.update((first, second))

    # Molecules that never appeared in a pair get a placeholder label.
    leftovers = [
        {
            "smiles": box["smiles"],
            "texts": ["There is no label or failed to detect, please recheck the image again"],
        }
        for position, box in enumerate(boxes)
        if "smiles" in box and position not in linked
    ]
    return entries + leftovers
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Detect molecules in an image and pair each one with its label text.

    Runs ``model.extract_molecule_corefs_from_figures`` on the image, strips
    the per-box fields that ``parse_coref_data_with_fallback`` does not use,
    and pairs molecules with labels via that function.

    Args:
        image_path: Path to the image file.

    Returns:
        A JSON string encoding a list of ``{"smiles": ..., "texts": [...]}``
        dicts. (The previous ``-> list`` annotation and docstring were wrong:
        the value has always been a JSON string.)
    """
    # Use a context manager so the file handle is released once the RGB copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    # Only "smiles" and "text" are read downstream; drop everything else.
    dropped_keys = (
        "category", "bbox", "molfile", "symbols", "atoms", "bonds",
        "category_id", "score", "corefs", "coords", "edges",
    )
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # absent keys are ignored

    parsed = parse_coref_data_with_fallback(coref_results[0])
    payload = json.dumps(parsed)  # serialize once; printed and returned
    print(f"coref_results:{payload}")
    return payload
def get_reaction(image_path: str) -> dict:
    '''
    Return the first reaction predicted for an image as a trimmed dictionary.

    Calls ``get_reaction_withatoms_correctR`` and, for the FIRST prediction
    only, keeps:
      * ``reactants`` / ``products``: ``smiles`` (default ``""``) and ``bbox``
        (default ``[]``) per entry;
      * ``conditions``: ``text`` (default ``[]``), ``bbox`` (default ``[]``) and
        ``smiles`` (default ``[]``) per entry.
    Sections missing from the prediction are omitted from the result.

    Args:
        image_path: Path to the reaction image.

    Returns:
        dict mapping section name to a list of entry dicts; ``{}`` when the
        predictor returns nothing (previously this raised IndexError).
    '''
    raw_prediction = get_reaction_withatoms_correctR(image_path)
    if not raw_prediction:
        return {}
    first = raw_prediction[0]

    def molecule_entry(item):
        return {
            "smiles": item.get("smiles", ""),
            "bbox": item.get("bbox", []),
        }

    def condition_entry(item):
        return {
            "text": item.get("text", []),
            "bbox": item.get("bbox", []),
            "smiles": item.get("smiles", []),
        }

    # Insertion order (reactants, conditions, products) is preserved in the output.
    section_builders = {
        'reactants': molecule_entry,
        'conditions': condition_entry,
        'products': molecule_entry,
    }
    return {
        section: [build(item) for item in first[section]]
        for section, build in section_builders.items()
        if section in first
    }
def process_reaction_image(image_path: str) -> dict:
    """
    Extract reactions from a reaction-diagram image using GPT-4o tool calling.

    GPT-4o is shown the prompt from ``./prompt.txt`` together with the image and
    may call ``get_multi_molecular_text_to_correct`` and/or ``get_reaction``.
    The tool results are sent back for a second, JSON-mode completion. Its parsed
    output is combined with the local reaction and coreference extractors and
    passed to ``utils.backout_without_coref``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        dict: ``{"reaction_template": {"reactants": [...], "products": [...]},
        "reactions": {label: {"reactants": ..., "products": ...}, ...}``
        (ordered by label), ``"original_molecule_list": <parsed GPT output>}``.

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is unset.
        ValueError: If GPT requests a tool that is not in the tool map.
    """
    # Configure the API key and the Azure endpoint.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"  # Replace with the actual Azure endpoint
    # The module-level `model` (a ChemIEToolkit on the same device) is reused;
    # building a new one here reloaded every model weight on each call.
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint,
    )

    # Load the image and Base64-encode it for the data URL.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool schema exposed to GPT, and the local implementations behind it.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct',
                'description': 'Extracts the SMILES string and text coref from molecular images.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'Path to the reaction image.'
                        }
                    },
                    'required': ['image_path'],
                    'additionalProperties': False
                }
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_reaction',
                'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    tool_map = {
        'get_multi_molecular_text_to_correct': get_multi_molecular_text_to_correct,
        'get_reaction': get_reaction,
    }

    # Prompt + image; the same list opens both completion requests.
    with open('./prompt.txt', 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # Round 1: GPT decides which tools to call.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    # Run every requested tool. `tool_calls` is None when GPT calls no tool.
    # Each tool always runs on the caller's image_path; GPT's arguments are ignored.
    tool_messages = []
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: hand the tool results back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[*messages, assistant_message, *tool_messages],
        response_format={'type': 'json_object'},
        temperature=0,
    )
    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)

    # Run the local extractors on the same image.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert('RGB'))
    coref_results = model.extract_molecule_corefs_from_figures([image_np])
    raw_reaction = get_reaction_withatoms_correctR(image_path)[0]
    reaction = {
        "reactants": raw_reaction.get('reactants', []),
        "conditions": raw_reaction.get('conditions', []),
        "products": raw_reaction.get('products', []),
    }
    # Wrapped as a one-figure, one-reaction list before being handed to
    # utils.backout_without_coref.
    reaction_results = [{"reactions": [reaction]}]
    print(reaction_results)

    def extract_smiles_details(smiles_data, raw_details):
        """Map each SMILES in smiles_data to the detection fields of a matching bbox.

        Within one entry of raw_details the first matching bbox is used; a later
        entry that also matches overwrites it (the break only ends the inner scan).
        """
        smiles_details = {}
        for smiles in smiles_data:
            for detail in raw_details:
                for bbox in detail.get('bboxes', []):
                    if bbox.get('smiles') == smiles:
                        smiles_details[smiles] = {
                            'category': bbox.get('category'),
                            'bbox': bbox.get('bbox'),
                            'category_id': bbox.get('category_id'),
                            'score': bbox.get('score'),
                            'molfile': bbox.get('molfile'),
                            'atoms': bbox.get('atoms'),
                            'bonds': bbox.get('bonds'),
                        }
                        break
        return smiles_details

    # NOTE(review): gpt_output is the parsed JSON object; if it is a dict this
    # iterates its keys — confirm that is what should be matched against SMILES.
    smiles_details = extract_smiles_details(gpt_output, coref_results)

    # Reaction-template SMILES. Entries lacking 'smiles' are skipped on both
    # sides (products previously raised KeyError instead).
    reactants_array = []
    for reactant in reaction['reactants']:
        if 'smiles' in reactant:
            print(f"SMILES:{reactant['smiles']}")
            reactants_array.append(reactant['smiles'])
    products = [product['smiles'] for product in reaction['products'] if 'smiles' in product]

    # Organize the per-label reactions.
    backed_out = utils.backout_without_coref(
        reaction_results, coref_results, gpt_output, smiles_details, model.molscribe
    )
    # Sorting by label means the dict below is already in label order, so no
    # separate key re-sort is needed.
    backed_out.sort(key=lambda x: x[2])
    extracted_rxns = {
        label: {'reactants': rxn_reactants, 'products': rxn_products}
        for rxn_reactants, rxn_products, label in backed_out
    }
    toadd = {
        "reaction_template": {
            "reactants": reactants_array,
            "products": products,
        },
        "reactions": extracted_rxns,
        "original_molecule_list": gpt_output,
    }
    print(toadd)
    return toadd
def ChemEagle(image_path: str) -> dict:
    """
    Extract reaction data from a chemical-reaction image through GPT-4o tool calling.

    GPT-4o is shown the prompt from ``./prompt_final_simple_version.txt`` and the
    image, and may call the ``process_reaction_image`` tool. The tool output is
    returned to GPT for a second, JSON-mode completion, whose parsed result is
    returned.

    Args:
        image_path (str): Path to the image file.

    Returns:
        dict: The parsed JSON object from the second completion. Its schema is
        set by the prompt file (presumably organized reaction data: reactants,
        products and reaction template — confirm against the prompt).

    Raises:
        RuntimeError: If the ``CHEMEAGLE_API_KEY`` environment variable is unset.
        ValueError: If GPT requests a tool other than ``process_reaction_image``.
    """
    # Configure the API key and the Azure endpoint.
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")
    azure_endpoint = "https://hkust.azure-api.net"  # Replace with the actual Azure endpoint
    # (A ChemIEToolkit used to be constructed here and never used; that load was removed.)
    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint=azure_endpoint,
    )

    # Load the image and Base64-encode it for the data URL.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Tool schema exposed to GPT, and the local implementation behind it.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'process_reaction_image',
                'description': 'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the table, and the original molecular list.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    tool_map = {
        'process_reaction_image': process_reaction_image,
    }

    # Prompt + image; the same list opens both completion requests.
    with open('./prompt_final_simple_version.txt', 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # Round 1: GPT decides whether to call the tool.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    # Run every requested tool. `tool_calls` is None when GPT calls no tool.
    # The tool always runs on the caller's image_path; GPT's arguments are ignored.
    tool_messages = []
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Round 2: hand the tool results back and request the final JSON answer.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[*messages, assistant_message, *tool_messages],
        response_format={'type': 'json_object'},
        temperature=0,
    )
    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)
    return gpt_output