"""Molecule extraction from reaction images with GPT-assisted symbol correction.

The module loads a ChemIEToolkit model and a RxnScribe model at import time.
It exposes three helpers that return molecule/coreference detections as JSON
strings (each dropping a different set of per-bbox keys), and two pipelines
that let GPT-4o correct the detected atom symbols before SMILES and molfile
are regenerated from the corrected graph.
"""

import base64
import json
import os
import sys

# Third-party imports are kept in the order in which they first appeared in
# the original file.
import torch
from chemietoolkit import ChemIEToolkit, utils
import cv2
from PIL import Image
from rxnscribe import RxnScribe
from molscribe.chemistry import _convert_graph_to_smiles
import numpy as np
from openai import AzureOpenAI

# sys.path.append('./RxnScribe-main/')

ckpt_path = "./pix2seq_reaction_full.ckpt"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model1 = RxnScribe(ckpt_path, device=device)
# Built once here; the original constructed the toolkit twice at import time
# (the first instance was immediately overwritten) and again on every call
# to the processing functions.
model = ChemIEToolkit(device=device)

_AZURE_ENDPOINT = "https://hkust.azure-api.net"
_AZURE_API_VERSION = '2024-06-01'
_GPT_MODEL = 'gpt-4o'
_SYSTEM_PROMPT = 'You are a helpful assistant.'


def _load_rgb_image(image_path: str):
    """Open ``image_path`` and return an RGB copy, closing the file handle."""
    with Image.open(image_path) as raw_image:
        return raw_image.convert('RGB')


def _extract_raw_corefs(image_path: str) -> list:
    """Run the module-level toolkit on one image and return its raw output."""
    return model.extract_molecule_corefs_from_figures([_load_rgb_image(image_path)])


def _extract_corefs_json(image_path: str, drop_keys) -> str:
    """Return the detections as a JSON string with ``drop_keys`` removed.

    Each key in ``drop_keys`` is removed from every entry of every item's
    ``"bboxes"`` list; keys that are absent are ignored. The JSON string is
    also printed to stdout, as in the original implementation.
    """
    coref_results = _extract_raw_corefs(image_path)
    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in drop_keys:
                bbox.pop(key, None)  # tolerate missing keys
    serialized = json.dumps(coref_results)
    print(serialized)
    return serialized


def get_multi_molecular(image_path: str) -> str:
    """Return molecule/coref detections for the image as a JSON string.

    The per-bbox keys ``category``, ``molfile``, ``symbols``, ``atoms``,
    ``bonds``, ``category_id``, ``score`` and ``corefs`` are removed before
    serialization. The JSON string is also printed.
    """
    return _extract_corefs_json(
        image_path,
        ["category", "molfile", "symbols", 'atoms', "bonds",
         'category_id', 'score', 'corefs'],
    )


def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Return molecule/coref detections for the image as a JSON string.

    Same as :func:`get_multi_molecular` but the per-bbox ``bbox`` key is
    removed as well. The JSON string is also printed.
    """
    return _extract_corefs_json(
        image_path,
        ["category", "bbox", "molfile", "symbols", 'atoms', "bonds",
         'category_id', 'score', 'corefs'],
    )


def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> str:
    """Return molecule/coref detections for the image as a JSON string.

    The per-bbox keys ``coords``, ``edges``, ``molfile``, ``atoms``,
    ``bonds``, ``category_id``, ``score`` and ``corefs`` are removed, so
    ``symbols`` (among others) is retained. The JSON string is also printed.
    This is the function exposed to GPT as a tool.
    """
    return _extract_corefs_json(
        image_path,
        ["coords", "edges", "molfile", 'atoms', "bonds",
         'category_id', 'score', 'corefs'],
    )


# Tool schema advertised to GPT; the single tool maps to the function above.
_TOOLS = [
    {
        'type': 'function',
        'function': {
            'name': 'get_multi_molecular_text_to_correct_withatoms',
            'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
            'parameters': {
                'type': 'object',
                'properties': {
                    'image_path': {
                        'type': 'string',
                        'description': 'The path to the reaction image.',
                    },
                },
                'required': ['image_path'],
                'additionalProperties': False,
            },
        },
    },
]

_TOOL_MAP = {
    'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
}


def _encode_image(image_path: str) -> str:
    """Return the file's bytes as a base64-encoded ASCII string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def _base_messages(prompt: str, base64_image: str) -> list:
    """Build the system + user (text and inline image) messages."""
    return [
        {'role': 'system', 'content': _SYSTEM_PROMPT},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {
                    'type': 'image_url',
                    'image_url': {'url': f'data:image/png;base64,{base64_image}'},
                },
            ],
        },
    ]


def _run_tool_calls(tool_calls, image_path: str) -> list:
    """Execute each requested tool and return the resulting tool messages.

    The tool is always invoked with the caller's ``image_path``; the
    arguments proposed by the model are not used.

    Raises:
        ValueError: if the model requests a tool that is not in the tool map.
    """
    tool_messages = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in _TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")
        tool_result = _TOOL_MAP[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                f'{tool_name}': (tool_result),
            }),
            'tool_call_id': tool_call.id,
        })
    return tool_messages


def _update_symbols_in_atoms(input1, input2):
    """Copy corrected ``symbols`` from ``input1`` into ``input2`` in place.

    Items and their ``bboxes`` are paired positionally. For each pair where
    the ``input1`` bbox has ``symbols``, they replace the ``input2`` bbox's
    ``symbols``; if that bbox also has ``atoms`` of the same length, each
    atom's ``atom_symbol`` is updated too. Length mismatches print a warning
    and are skipped. Returns ``input2``.
    """
    for item1, item2 in zip(input1, input2):
        bboxes1 = item1.get('bboxes', [])
        bboxes2 = item2.get('bboxes', [])
        if len(bboxes1) != len(bboxes2):
            print("Warning: Mismatched number of bboxes!")
            continue
        for bbox1, bbox2 in zip(bboxes1, bboxes2):
            if 'symbols' not in bbox1:
                continue
            symbols = bbox1['symbols']
            bbox2['symbols'] = symbols
            if 'atoms' not in bbox2:
                continue
            atoms = bbox2['atoms']
            if len(symbols) != len(atoms):
                print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                continue
            for atom, symbol in zip(atoms, symbols):
                atom['atom_symbol'] = symbol
    return input2


def _update_smiles_and_molfile(input_data, conversion_function):
    """Regenerate ``smiles`` and ``molfile`` for every complete bbox in place.

    For each bbox that has ``coords``, ``symbols`` and ``edges``,
    ``conversion_function(coords, symbols, edges)`` is called; the first two
    elements of the returned 3-tuple replace ``smiles`` and ``molfile``.
    Each generated SMILES is printed. Returns ``input_data``.
    """
    for item in input_data:
        for bbox in item.get('bboxes', []):
            if not all(key in bbox for key in ['coords', 'symbols', 'edges']):
                continue
            new_smiles, new_molfile, _ = conversion_function(
                bbox['coords'], bbox['symbols'], bbox['edges']
            )
            print(f" Generated 'smiles': {new_smiles}")
            bbox['smiles'] = new_smiles
            bbox['molfile'] = new_molfile
    return input_data


def _correct_molecules_with_gpt(image_path: str, prompt_path: str, *, log_result: bool = False) -> list:
    """Shared pipeline behind the two public ``process_reaction_image_*`` functions.

    Steps: check the API key, send the prompt and image to GPT with the
    extraction tool available, run any requested tool calls, ask GPT for the
    final JSON answer, merge the corrected symbols into a fresh raw
    detection, and regenerate SMILES/molfile.

    Raises:
        RuntimeError: if ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: if GPT requests an unknown tool.
    """
    api_key = os.getenv("CHEMEAGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing CHEMEAGLE_API_KEY environment variable")

    client = AzureOpenAI(
        api_key=api_key,
        api_version=_AZURE_API_VERSION,
        azure_endpoint=_AZURE_ENDPOINT,
    )

    base64_image = _encode_image(image_path)
    with open(prompt_path, 'r') as prompt_file:
        prompt = prompt_file.read()

    first_response = client.chat.completions.create(
        model=_GPT_MODEL,
        temperature=0,
        response_format={'type': 'json_object'},
        messages=_base_messages(prompt, base64_image),
        tools=_TOOLS,
    )
    first_message = first_response.choices[0].message

    # ``tool_calls`` is None when the model answers directly; the original
    # iterated over it unconditionally and crashed with a TypeError.
    tool_calls = first_message.tool_calls or []
    if tool_calls:
        tool_messages = _run_tool_calls(tool_calls, image_path)
        final_response = client.chat.completions.create(
            model=_GPT_MODEL,
            messages=[
                *_base_messages(prompt, base64_image),
                first_message,
                *tool_messages,
            ],
            response_format={'type': 'json_object'},
            temperature=0,
        )
        content = final_response.choices[0].message.content
    else:
        # No tool was requested: treat the direct answer as the final one.
        content = first_message.content

    # Wrapped in a list so it pairs positionally with the per-figure list
    # returned by the toolkit.
    gpt_output = [json.loads(content)]

    coref_results = _extract_raw_corefs(image_path)
    merged = _update_symbols_in_atoms(gpt_output, coref_results)
    updated_data = _update_smiles_and_molfile(merged, _convert_graph_to_smiles)
    if log_result:
        print(f"updated_mol_data:{updated_data}")
    return updated_data


def process_reaction_image_with_multiple_products_and_text(image_path: str) -> list:
    """Extract molecules from an image and let GPT correct their atom symbols.

    Uses the prompt stored in ``./prompt_getmolecular.txt``.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: The toolkit's detection results, with ``symbols``,
        ``atom_symbol``, ``smiles`` and ``molfile`` updated where GPT
        supplied matching corrections.

    Raises:
        RuntimeError: if ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: if GPT requests an unknown tool.
    """
    return _correct_molecules_with_gpt(image_path, './prompt_getmolecular.txt')


def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> list:
    """Variant of the correction pipeline using the ``correctR`` prompt.

    Uses the prompt stored in ``./prompt_getmolecular_correctR.txt`` and
    additionally prints the final updated data.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: The toolkit's detection results, with ``symbols``,
        ``atom_symbol``, ``smiles`` and ``molfile`` updated where GPT
        supplied matching corrections.

    Raises:
        RuntimeError: if ``CHEMEAGLE_API_KEY`` is not set.
        ValueError: if GPT requests an unknown tool.
    """
    return _correct_molecules_with_gpt(
        image_path,
        './prompt_getmolecular_correctR.txt',
        log_result=True,
    )