"""Gradio front-end that boots a TxAgent instance and serves a chat UI.

The script patches ``ToolRAGModel.load_tool_desc_embedding`` so that the
pre-computed tool embeddings named in ``CONFIG`` are loaded from disk and
resized to match the number of tools exposed by ToolUniverse, generates the
``new_tool.json`` tool list when it is missing, builds the agent and launches
the Gradio demo.
"""

import datetime
import json
import logging
import os
import random
import sys
from importlib.resources import files
from typing import Any

import torch
from txagent import TxAgent
from tooluniverse import ToolUniverse
import gradio as gr

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Determine the directory where the current file is located
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Configuration
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json')
    }
}

chat_css = """
.gr-button { font-size: 20px !important; }
.gr-button svg { width: 32px !important; height: 32px !important; }
"""

# Sentinel distinguishing "object exposes no tool accessor" from an accessor
# that legitimately returned a falsy value such as an empty list.
_NO_TOOLS = object()


def _get_tools(tooluniverse):
    """Return the tools exposed by *tooluniverse*, or ``_NO_TOOLS``.

    Prefers the ``get_all_tools()`` method and falls back to the ``tools``
    attribute; ``_NO_TOOLS`` is returned when the object has neither.
    """
    if hasattr(tooluniverse, 'get_all_tools'):
        return tooluniverse.get_all_tools()
    if hasattr(tooluniverse, 'tools'):
        return tooluniverse.tools
    return _NO_TOOLS


def _resize_embeddings(embeddings, tool_count):
    """Truncate or pad *embeddings* so its first dimension equals *tool_count*.

    Padding repeats the last embedding row. The row is taken with a slice
    (``[-1:]``) rather than an index so it keeps the leading dimension and can
    be concatenated with the original tensor by ``torch.cat``.

    Raises:
        ValueError: if padding is required but *embeddings* is empty, because
            there is no row to repeat.
    """
    embedding_count = len(embeddings)
    if tool_count <= embedding_count:
        return embeddings[:tool_count]
    if embedding_count == 0:
        raise ValueError("Cannot pad an empty embedding tensor")
    filler = embeddings[-1:]
    return torch.cat([embeddings] + [filler] * (tool_count - embedding_count))


def safe_load_embeddings(filepath: str) -> Any:
    """Load the object stored at *filepath* with ``torch.load``.

    A restricted load (``weights_only=True``) is attempted first; if it raises,
    the load is retried with ``weights_only=False``.

    Returns:
        Whatever ``torch.load`` returns, or ``None`` when both attempts fail
        (the failure is logged).
    """
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {e}")
    try:
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only point this at trusted embedding files.
        return torch.load(filepath, weights_only=False)
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        return None


def patch_embedding_loading():
    """Replace ``ToolRAGModel.load_tool_desc_embedding`` with a local loader.

    The replacement loads the file named by ``CONFIG["embedding_filename"]``
    and resizes the result to the number of tools reported by the ToolUniverse
    instance it is given. It returns ``True`` on success and ``False`` on any
    failure.

    Raises:
        Exception: anything raised while importing ``ToolRAGModel`` or
            installing the patch is logged and re-raised.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse):
            try:
                embedding_path = CONFIG["embedding_filename"]
                if not os.path.exists(embedding_path):
                    logger.error(f"Embedding file not found: {embedding_path}")
                    return False

                embeddings = safe_load_embeddings(embedding_path)
                self.tool_desc_embedding = embeddings
                if embeddings is None:
                    # safe_load_embeddings has already logged the cause.
                    return False

                tools = _get_tools(tooluniverse)
                if tools is _NO_TOOLS:
                    logger.error("No method found to access tools from ToolUniverse")
                    return False

                current_count = len(tools)
                embedding_count = len(embeddings)
                if current_count != embedding_count:
                    logger.warning(
                        f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})"
                    )
                    self.tool_desc_embedding = _resize_embeddings(embeddings, current_count)
                    if current_count < embedding_count:
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        logger.info(f"Padded embeddings to match {current_count} tools")
                return True
            except Exception as e:
                # Best-effort boundary: report the failure to the caller via
                # the boolean return value instead of propagating.
                logger.error(f"Failed to load embeddings: {e}")
                return False

        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")
    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {e}")
        raise


def prepare_tool_files():
    """Create ``data/new_tool.json`` from ToolUniverse if it does not exist.

    The ``data`` directory next to this file is created when missing. Failures
    while generating the file are logged and swallowed.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    new_tool_path = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(new_tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    try:
        tools = _get_tools(ToolUniverse())
        if tools is _NO_TOOLS:
            tools = []
            logger.error("Could not access tools from ToolUniverse")

        # Serialize before opening the file: if the tools are not JSON
        # serializable, no truncated file is left behind to block
        # regeneration on the next start.
        payload = json.dumps(tools, indent=2)
        with open(new_tool_path, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info(f"Saved {len(tools)} tools to {new_tool_path}")
    except Exception as e:
        logger.error(f"Failed to prepare tool files: {e}")


def create_agent():
    """Patch embedding loading, prepare tool files and build the TxAgent.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called.

    Raises:
        Exception: any error from patching, constructing or initialising the
            agent is logged and re-raised.
    """
    patch_embedding_loading()
    prepare_tool_files()
    try:
        agent = TxAgent(
            CONFIG["model_name"],
            CONFIG["rag_model_name"],
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=['DirectResponse', 'RequireClarification']
        )
        agent.init_model()
        return agent
    except Exception as e:
        logger.error(f"Failed to create agent: {e}")
        raise


def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Run one chat turn and return the updated message history.

    The user message is appended to *history*, the module-level ``agent`` (set
    by ``main``) is asked for a response, and every yielded chunk is
    concatenated into a single assistant message. Dict chunks contribute their
    ``"content"`` value; any other chunk contributes ``str(chunk)``.

    Returns:
        A new list: the previous history plus the user and assistant messages.
    """
    # The chatbot value may be None before the first message is sent.
    updated_history = (history or []) + [{"role": "user", "content": message}]
    # NOTE(review): the positional argument order is kept exactly as in the
    # original; confirm it matches the installed TxAgent.run_gradio_chat
    # signature and what that generator actually yields.
    response_generator = agent.run_gradio_chat(
        updated_history, temperature, max_new_tokens, max_tokens,
        multi_agent, conversation, max_round
    )
    parts = []
    for chunk in response_generator:
        if isinstance(chunk, dict):
            # A missing or None "content" contributes nothing.
            parts.append(chunk.get("content") or "")
        else:
            parts.append(str(chunk))
    updated_history.append({"role": "assistant", "content": "".join(parts)})
    return updated_history


def create_demo(agent):
    """Build the Gradio Blocks UI for the chat.

    *agent* is accepted for interface compatibility; the click handler
    ``respond`` reads the module-level ``agent`` rather than this parameter.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(css=chat_css) as demo:
        chatbot = gr.Chatbot(label="TxAgent", type="messages")
        with gr.Row():
            msg = gr.Textbox(label="Your question")
        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")
        with gr.Row():
            submit = gr.Button("Ask TxAgent")
        submit.click(
            respond,
            inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
            outputs=[chatbot]
        )
    return demo


def main():
    """Create the agent, build the demo and launch it.

    Raises:
        Exception: any startup failure is logged and re-raised.
    """
    # `respond` looks the agent up as a module-level global.
    global agent
    try:
        agent = create_agent()
        demo = create_demo(agent)
        demo.launch()
    except Exception as e:
        logger.error(f"Application failed to start: {e}")
        raise


if __name__ == "__main__":
    main()