import os
import json
import logging
from importlib.resources import files

import numpy
import torch
import torch.serialization
import gradio as gr
from txagent import TxAgent
from tooluniverse import ToolUniverse

# Allow torch.load (weights-only safe loading) to deserialize legacy numpy
# types embedded in older checkpoint/embedding files.
torch.serialization.add_safe_globals([
    numpy.core.multiarray._reconstruct,
    numpy.ndarray,
    numpy.dtype,
    numpy.dtypes.Float32DType,
])

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

current_dir = os.path.dirname(os.path.abspath(__file__))

CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}


def generate_tool_embeddings(agent):
    """Build and cache the tool-description embedding tensor used for RAG retrieval."""
    tu = ToolUniverse(tool_files=CONFIG["tool_files"])
    tu.load_tools()
    embedding_tensor = agent.rag_model.load_tool_desc_embedding(tu)
    if embedding_tensor is not None:
        torch.save(embedding_tensor, CONFIG["embedding_filename"])
        logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
    else:
        logger.warning("Embedding generation returned None")


def prepare_tool_files():
    """Ensure the local data directory and new_tool.json exist before agent startup."""
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
        try:
            tu = ToolUniverse()
            tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
            with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                json.dump(tools, f, indent=2)
        except Exception as e:
            logger.error(f"Tool generation failed: {e}")


def create_agent():
    prepare_tool_files()
    try:
        agent = TxAgent(
            CONFIG["model_name"],
            CONFIG["rag_model_name"],
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=42,
            additional_default_tools=["DirectResponse", "RequireClarification"],
        )
        # Build the embedding cache once; later runs reuse the saved file.
        if not os.path.exists(CONFIG["embedding_filename"]):
            generate_tool_embeddings(agent)
        agent.init_model()
        return agent
    except Exception as e:
        logger.error(f"Agent initialization failed: {e}")
        raise


def respond(msg, chat_history, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation, max_round):
    if not isinstance(msg, str) or len(msg.strip()) <= 10:
        return chat_history + [{
            "role": "assistant",
            "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."
        }]

    message = msg.strip()
    chat_history.append({"role": "user", "content": message})
    formatted_history = chat_history  # run_gradio_chat expects a list of role/content dicts

    try:
        response_generator = agent.run_gradio_chat(
            message=message,
            history=formatted_history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation,
            max_round=max_round,
            seed=42,
            call_agent_level=0,
            sub_agent_task=None,
        )

        # Flatten whatever the generator yields (message lists, dicts, or raw
        # strings) into a single assistant reply.
        collected = ""
        for chunk in response_generator:
            if isinstance(chunk, list):
                for item in chunk:
                    if isinstance(item, dict) and "content" in item:
                        collected += item["content"]
            elif isinstance(chunk, dict) and "content" in chunk:
                collected += chunk["content"]
            elif isinstance(chunk, str):
                collected += chunk

        chat_history.append({"role": "assistant", "content": collected or "⚠️ No content returned."})
    except Exception as e:
        chat_history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})

    return chat_history


def create_demo(agent):
    # Note: respond() reads the module-level `agent` set in main().
    with gr.Blocks(css=".gr-button { font-size: 18px !important; }") as demo:
        chatbot = gr.Chatbot(label="TxAgent", type="messages", render_markdown=True)
        msg = gr.Textbox(label="Your question", placeholder="Ask a biomedical question...", scale=6)
        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")
        submit = gr.Button("Ask TxAgent")
        conversation_state = gr.State([])
        submit.click(
            respond,
            inputs=[msg, chatbot, temp, max_new_tokens, max_tokens,
                    multi_agent, conversation_state, max_rounds],
            outputs=[chatbot],
        )
    return demo


def main():
    global agent
    agent = create_agent()
    demo = create_demo(agent)
    demo.launch(share=True)


if __name__ == "__main__":
    main()