"""Gradio front-end for TxAgent.

Builds a TxAgent (LLM + ToolRAG model + tool definition files) and serves a
simple chat UI on 0.0.0.0:7860.
"""

import os

# Set these before torch (and anything else that may load MKL or
# tokenizers) is imported: MKL reads MKL_THREADING_LAYER when the library is
# first loaded, so assigning it after `import torch` is too late to matter.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json  # noqa: E402
import logging  # noqa: E402
import random  # noqa: E402,F401  (not referenced below; kept for compatibility)
from importlib.resources import files  # noqa: E402

import gradio as gr  # noqa: E402
import torch  # noqa: E402
from tooluniverse import ToolUniverse  # noqa: E402
from txagent import TxAgent  # noqa: E402

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

current_dir = os.path.dirname(os.path.abspath(__file__))

# Configuration - Update paths as needed
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "path_to_your_embeddings.pt",  # Update this path
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}


def safe_load_embeddings(filepath: str):
    """Load an object saved with ``torch.save``, with a permissive fallback.

    First tries ``torch.load(..., weights_only=True)``; if that raises, retries
    with ``weights_only=False``.

    Args:
        filepath: Path to the saved file.

    Returns:
        The loaded object, or None if the file does not exist or both load
        attempts fail (failures are logged, not raised).

    SECURITY: the ``weights_only=False`` fallback unpickles arbitrary objects
    and can execute code embedded in the file. Only point this at trusted
    files.
    """
    # A missing file would fail both attempts; don't run the insecure
    # fallback just to rediscover that.
    if not os.path.isfile(filepath):
        logger.error("Failed to load embeddings: file not found: %s", filepath)
        return None
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning("Secure load failed, trying without weights_only: %s", e)
    try:
        return torch.load(filepath, weights_only=False)
    except Exception as e:
        logger.error("Failed to load embeddings: %s", e)
        return None


def _load_tools_from_files():
    """Concatenate the JSON tool lists from every existing CONFIG tool file.

    Files that are missing, unreadable, not valid JSON, or whose top-level
    value is not a list are skipped (with a warning for the latter cases).
    """
    tools = []
    for tool_file in CONFIG["tool_files"].values():
        if not os.path.exists(tool_file):
            continue
        try:
            with open(tool_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            logger.warning("Skipping unreadable tool file %s: %s", tool_file, e)
            continue
        if isinstance(data, list):
            tools.extend(data)
        else:
            # list.extend() on a dict would silently add only its keys.
            logger.warning(
                "Skipping tool file %s: expected a JSON list, got %s",
                tool_file, type(data).__name__,
            )
    return tools


def get_tools_from_universe(tooluniverse):
    """Return the tool definitions exposed by a ToolUniverse instance.

    Probes, in order, ``get_all_tools()``, a ``tools`` attribute and
    ``list_tools()``. If none exist, falls back to reading the JSON files
    listed in ``CONFIG["tool_files"]``.

    Returns:
        Whatever the probed accessor returns, or, on the file fallback, a
        non-empty list of tools or None when nothing could be loaded.
    """
    if hasattr(tooluniverse, 'get_all_tools'):
        return tooluniverse.get_all_tools()
    if hasattr(tooluniverse, 'tools'):
        return tooluniverse.tools
    if hasattr(tooluniverse, 'list_tools'):
        return tooluniverse.list_tools()

    logger.error("Could not find any tool access method in ToolUniverse")
    # Try to load from files directly as fallback
    tools = _load_tools_from_files()
    return tools if tools else None


def prepare_tool_files():
    """Ensure ``data/`` exists and that ``new_tool.json`` has been generated.

    When ``CONFIG["tool_files"]["new_tool"]`` is absent, the tool list is
    pulled from a fresh ToolUniverse and written there as indented JSON.
    Failures are logged rather than raised.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    new_tool_path = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(new_tool_path):
        return

    logger.info("Generating tool list...")
    try:
        tools = get_tools_from_universe(ToolUniverse())
        if not tools:
            logger.error("No tools could be loaded")
            return
        # Serialize before opening the file: if a tool is not JSON
        # serializable we must not leave a truncated new_tool.json behind,
        # because its mere existence makes later runs skip regeneration.
        payload = json.dumps(tools, indent=2)
        with open(new_tool_path, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info("Saved %d tools", len(tools))
    except Exception as e:
        logger.exception("Tool file preparation failed: %s", e)


def create_agent():
    """Prepare tool files, then build and initialize a TxAgent.

    Returns:
        The TxAgent after ``init_model()`` has been called.

    Raises:
        Any exception from TxAgent construction or ``init_model()`` is logged
        and re-raised.
    """
    prepare_tool_files()
    try:
        agent = TxAgent(
            model_name=CONFIG["model_name"],
            rag_model_name=CONFIG["rag_model_name"],
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=['DirectResponse', 'RequireClarification'],
        )
        agent.init_model()
        return agent
    except Exception as e:
        logger.exception("Agent creation failed: %s", e)
        raise


def _collect_text(chunks):
    """Concatenate an iterable of response chunks into one string.

    Dict chunks contribute their "content" value (missing/None -> "");
    any other chunk contributes ``str(chunk)``.
    """
    parts = []
    for chunk in chunks:
        if isinstance(chunk, dict):
            content = chunk.get("content")
            parts.append("" if content is None else str(content))
        else:
            parts.append(str(chunk))
    return "".join(parts)


def format_response(history, message):
    """Append *message* to *history* as a ``[None, text]`` pair.

    A str is used as-is, a dict contributes its "content", any other iterable
    is treated as a stream of chunks, and anything else is ``str()``-ed.
    Note this produces the legacy pair layout, not the ``type="messages"``
    layout used by the chatbot in ``create_demo``.
    """
    if isinstance(message, str):
        text = message
    elif isinstance(message, dict):
        # Consistent with how dict chunks inside a stream are handled.
        text = _collect_text([message])
    elif hasattr(message, '__iter__'):
        text = _collect_text(message)
    else:
        text = str(message)
    return history + [[None, text]]


def _history_to_messages(chat_history):
    """Normalize Gradio chat history into ``{"role", "content"}`` dicts.

    Accepts both the ``type="messages"`` layout (a list of role/content dicts)
    and the legacy ``(user, bot)`` pair layout. Entries with empty content
    are dropped.
    """
    messages = []
    for item in chat_history or []:
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if role and content:
                messages.append({"role": role, "content": content})
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, bot_msg = item
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
    return messages


def create_demo(agent):
    """Build the Gradio Blocks UI around *agent*.

    The chatbot uses ``type="messages"``, so history flowing through the
    submit handler is a list of ``{"role", "content"}`` dicts.
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(
            height=800,
            label='TxAgent',
            show_copy_button=True,
            type="messages",  # Use the modern message format
        )
        msg = gr.Textbox(label="Input", placeholder="Type your question...")
        # ClearButton resets both components itself; no extra handler needed.
        gr.ClearButton([msg, chatbot])

        def respond(message, chat_history):
            """Run the agent on *message*; return (cleared input, new history)."""
            history = list(chat_history or [])
            if not message or not str(message).strip():
                return "", history

            agent_history = _history_to_messages(history)
            try:
                # NOTE(review): the argument layout (full message list as the
                # first argument) and keyword names (max_tokens, multi_agent)
                # must match TxAgent.run_gradio_chat's signature — TODO
                # confirm. A mismatch surfaces below as an "Error: ..." reply.
                response = agent.run_gradio_chat(
                    agent_history + [{"role": "user", "content": message}],
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_tokens=81920,
                    multi_agent=False,
                    conversation=[],
                    max_round=30,
                )
                full_response = _collect_text(response)
            except Exception as e:
                logger.exception("Error in response handling: %s", e)
                full_response = f"Error: {str(e)}"

            # type="messages" expects role/content dicts, not (user, bot) tuples.
            return "", history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": full_response},
            ]

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
    return demo


def main():
    """Main application entry point: build the agent and launch the UI."""
    try:
        agent = create_agent()
        demo = create_demo(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.exception("Application failed to start: %s", e)
        raise


if __name__ == "__main__":
    main()