# ChemEagle_API / get_molecular_agent.py
# Hosting-page metadata: uploader avatar "CYF200127's picture";
# commit "Upload 162 files" (1f516b6, verified).
import sys
import torch
import json
from chemietoolkit import ChemIEToolkit
import cv2
from PIL import Image
import json
import sys
#sys.path.append('./RxnScribe-main/')
import torch
from rxnscribe import RxnScribe
import json
import sys
import torch
import json
# NOTE(review): `model` is assigned again further down in this module with an
# equivalent constructor call, so the instance created here is replaced; this
# early load looks redundant — confirm nothing between the two relies on it.
model = ChemIEToolkit(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
from molscribe.chemistry import _convert_graph_to_smiles
import base64
import torch
import json
from PIL import Image
import numpy as np
from chemietoolkit import ChemIEToolkit, utils
from openai import AzureOpenAI
import os
# Path to the RxnScribe checkpoint, relative to the working directory.
ckpt_path = "./pix2seq_reaction_full.ckpt"

# Use the GPU when torch reports one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module-wide model instances; `model` is the toolkit that the extraction
# functions in this module call.
model1 = RxnScribe(ckpt_path, device=device)
model = ChemIEToolkit(device=device)
def get_multi_molecular(image_path: str) -> str:
    """Extract molecule/coreference detections from an image as JSON text.

    The image is passed to the module-level ``model`` and each entry under
    ``"bboxes"`` is slimmed down before the result is serialized.

    Args:
        image_path: Path to the figure image to analyse.

    Returns:
        The JSON-encoded output of
        ``model.extract_molecule_corefs_from_figures`` with the keys listed
        in ``dropped_keys`` removed from every bbox entry. The same text is
        also printed to stdout.
    """
    dropped_keys = (
        "category", "molfile", "symbols", "atoms", "bonds",
        "category_id", "score", "corefs",
    )

    # The context manager guarantees the file handle is released; convert()
    # hands back an independent in-memory image, so it stays usable afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate entries that lack the key

    # Serialize once and reuse the text for both the log line and the result.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Extract molecule/coreference detections, without boxes, as JSON text.

    The image is passed to the module-level ``model`` and each entry under
    ``"bboxes"`` is slimmed down before the result is serialized. Compared
    with ``get_multi_molecular`` this variant additionally drops the
    ``"bbox"`` key.

    Args:
        image_path: Path to the figure image to analyse.

    Returns:
        The JSON-encoded output of
        ``model.extract_molecule_corefs_from_figures`` with the keys listed
        in ``dropped_keys`` removed from every bbox entry. The same text is
        also printed to stdout.
    """
    dropped_keys = (
        "category", "bbox", "molfile", "symbols", "atoms", "bonds",
        "category_id", "score", "corefs",
    )

    # The context manager guarantees the file handle is released; convert()
    # hands back an independent in-memory image, so it stays usable afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate entries that lack the key

    # Serialize once and reuse the text for both the log line and the result.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> str:
    """Extract molecule/coreference detections, keeping symbols, as JSON text.

    The image is passed to the module-level ``model`` and each entry under
    ``"bboxes"`` is slimmed down before the result is serialized. Unlike the
    sibling extractors, ``"symbols"``, ``"category"`` and ``"bbox"`` are
    kept, while the graph geometry (``"coords"``, ``"edges"``) is removed.

    Args:
        image_path: Path to the figure image to analyse.

    Returns:
        The JSON-encoded output of
        ``model.extract_molecule_corefs_from_figures`` with the keys listed
        in ``dropped_keys`` removed from every bbox entry. The same text is
        also printed to stdout.
    """
    dropped_keys = (
        "coords", "edges", "molfile", "atoms", "bonds",
        "category_id", "score", "corefs",
    )

    # The context manager guarantees the file handle is released; convert()
    # hands back an independent in-memory image, so it stays usable afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in dropped_keys:
                bbox.pop(key, None)  # tolerate entries that lack the key

    # Serialize once and reuse the text for both the log line and the result.
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized
def _encode_image(image_path: str) -> str:
    """Return the bytes of the file at ``image_path`` Base64-encoded as text."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def _extract_raw_corefs(image_path: str) -> list:
    """Run the module-level toolkit on the image and return its unfiltered output."""
    # The context manager releases the file handle; convert() returns an
    # independent in-memory image that remains valid afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    return model.extract_molecule_corefs_from_figures([image])


def _update_symbols_in_atoms(input1, input2):
    """Copy corrected ``symbols`` from ``input1`` into ``input2`` in place.

    Items and their ``'bboxes'`` entries are paired positionally. For each
    pair, the ``'symbols'`` list from ``input1`` replaces the one in
    ``input2`` and, when the target has ``'atoms'``, every atom's
    ``'atom_symbol'`` is updated to match.

    A corrected list whose length differs from the target's atom list is
    rejected *before* anything is written, so a bbox never ends up with
    symbols that disagree with its atoms (and therefore with the graph the
    SMILES is later rebuilt from).

    Args:
        input1: Corrected data; a list of dicts with a ``'bboxes'`` list.
        input2: Data to update; same structure as ``input1``.

    Returns:
        ``input2`` (the same object, mutated).
    """
    for item1, item2 in zip(input1, input2):
        bboxes1 = item1.get('bboxes', [])
        bboxes2 = item2.get('bboxes', [])
        if len(bboxes1) != len(bboxes2):
            print("Warning: Mismatched number of bboxes!")
            continue
        for bbox1, bbox2 in zip(bboxes1, bboxes2):
            if 'symbols' not in bbox1:
                continue
            symbols = bbox1['symbols']
            atoms = bbox2.get('atoms')
            # Validate first: writing mismatched symbols would desynchronize
            # them from the atoms of this bbox.
            if atoms is not None and len(symbols) != len(atoms):
                print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                continue
            bbox2['symbols'] = symbols
            if atoms is not None:
                for atom, symbol in zip(atoms, symbols):
                    atom['atom_symbol'] = symbol
    return input2


def _update_smiles_and_molfile(input_data, conversion_function):
    """Regenerate ``'smiles'`` and ``'molfile'`` for every bbox that has a graph.

    Args:
        input_data: List of dicts, each with a ``'bboxes'`` list.
        conversion_function: Callable taking ``(coords, symbols, edges)`` and
            returning a 3-tuple whose first two elements are the new SMILES
            and molfile.

    Returns:
        ``input_data`` (the same object, mutated). Entries missing any of
        ``'coords'``, ``'symbols'`` or ``'edges'`` are left untouched.
    """
    required_keys = ('coords', 'symbols', 'edges')
    for item in input_data:
        for bbox in item.get('bboxes', []):
            if not all(key in bbox for key in required_keys):
                continue
            new_smiles, new_molfile, _ = conversion_function(
                bbox['coords'], bbox['symbols'], bbox['edges']
            )
            print(f" Generated 'smiles': {new_smiles}")
            bbox['smiles'] = new_smiles
            bbox['molfile'] = new_molfile
    return input_data


def _correct_molecules_with_gpt(image_path: str, prompt_path: str) -> list:
    """Shared pipeline: ask GPT to correct atom symbols, then rebuild SMILES.

    Steps:
      1. Send the prompt read from ``prompt_path`` plus the image to the
         Azure OpenAI ``gpt-4o`` deployment, offering the
         ``get_multi_molecular_text_to_correct_withatoms`` tool.
      2. Execute every requested tool call and send the results back to
         obtain the final JSON answer. If the model requested no tool, its
         first reply is used as the answer.
      3. Run the toolkit on the image, merge the corrected symbols into the
         raw detections and regenerate SMILES/molfile from the graph.

    Args:
        image_path: Path to the figure image.
        prompt_path: Path to the text file holding the user prompt.

    Returns:
        The toolkit detections (a list) with corrected ``symbols``,
        ``atom_symbol`` values, ``smiles`` and ``molfile``.

    Raises:
        RuntimeError: If ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: If the model requests a tool that is not registered.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    client = AzureOpenAI(
        api_key=api_key,
        api_version='2024-06-01',
        azure_endpoint="https://hkust.azure-api.net",
    )

    base64_image = _encode_image(image_path)
    with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # Built here rather than at module level so the lookup of the tool
    # function happens at call time.
    tool_map = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    tool_messages = []
    # `tool_calls` is None when the model answers without calling a tool.
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        # The tool is always invoked with the caller's image_path; the
        # arguments supplied by the model are not used.
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    if tool_messages:
        # Second round: let the model produce its answer from the tool output.
        response = client.chat.completions.create(
            model='gpt-4o',
            messages=[*messages, assistant_message, *tool_messages],
            response_format={'type': 'json_object'},
            temperature=0,
        )

    gpt_output = [json.loads(response.choices[0].message.content)]

    coref_results = _extract_raw_corefs(image_path)
    corrected = _update_symbols_in_atoms(gpt_output, coref_results)
    return _update_smiles_and_molfile(corrected, _convert_graph_to_smiles)


def process_reaction_image_with_multiple_products_and_text(image_path: str) -> list:
    """Correct the molecules detected in an image using ``prompt_getmolecular.txt``.

    Args:
        image_path: Path to the figure image.

    Returns:
        The toolkit detections (a list) with GPT-corrected symbols and
        regenerated ``smiles``/``molfile`` values.

    Raises:
        RuntimeError: If ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: If the model requests an unknown tool.
    """
    return _correct_molecules_with_gpt(image_path, './prompt_getmolecular.txt')


def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> list:
    """Correct the molecules detected in an image using ``prompt_getmolecular_correctR.txt``.

    Same pipeline as ``process_reaction_image_with_multiple_products_and_text``
    but driven by a different prompt file; the final data is also printed.

    Args:
        image_path: Path to the figure image.

    Returns:
        The toolkit detections (a list) with GPT-corrected symbols and
        regenerated ``smiles``/``molfile`` values.

    Raises:
        RuntimeError: If ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: If the model requests an unknown tool.
    """
    updated_data = _correct_molecules_with_gpt(image_path, './prompt_getmolecular_correctR.txt')
    print(f"updated_mol_data:{updated_data}")
    return updated_data