import gradio as gr
import logging
import threading
from txagent import TxAgent
from tooluniverse import ToolUniverse
from importlib.resources import files

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The TxAgent instance is created lazily by load_model() so that importing this
# module does not block on model loading.
tx_app = None
# Guards the one-time initialization of tx_app: the page-load hook and chat
# requests may call load_model() from different Gradio worker threads.
_tx_app_lock = threading.Lock()


def init_txagent():
    """Build a TxAgent wired to the bundled ToolUniverse tool files and load its model.

    Returns:
        The TxAgent instance after ``init_model()`` has completed.
    """
    logger.info("Initializing TxAgent...")
    data_dir = files("tooluniverse.data")
    tool_files = {
        "opentarget": str(data_dir.joinpath("opentarget_tools.json")),
        "fda_drug_label": str(data_dir.joinpath("fda_drug_labeling_tools.json")),
        "special_tools": str(data_dir.joinpath("special_tools.json")),
        "monarch": str(data_dir.joinpath("monarch_tools.json")),
    }
    logger.info("Using tool files at: %s", tool_files)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode="step",
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"],
    )
    agent.init_model()
    logger.info("Model loading complete")
    return agent


def load_model():
    """Initialize the global ``tx_app`` exactly once; later calls are no-ops.

    Safe to call from several threads: initialization runs under a lock and
    the ``None`` check is repeated inside it. If ``init_txagent`` raises, the
    exception propagates and ``tx_app`` stays ``None``, so the next call retries.
    """
    global tx_app
    if tx_app is not None:
        return
    with _tx_app_lock:
        if tx_app is None:
            logger.info("🔥 Loading TxAgent model")
            tx_app = init_txagent()


def _get_agent():
    """Return the shared TxAgent, loading it first if necessary."""
    load_model()
    return tx_app


def _to_pairs(history):
    """Normalize Chatbot history into a list of (user, assistant) pairs.

    Tuple-format history (Gradio's default for ``gr.Chatbot``) is returned as a
    shallow copy. Messages-format history (dicts with ``role``/``content``) is
    paired up so it renders correctly in a tuple-format Chatbot. Dicts missing
    either key, and roles other than user/assistant, are skipped. ``None`` or
    empty history yields ``[]``.
    """
    if not history:
        return []
    if not isinstance(history[0], dict):
        return list(history)

    pairs = []
    pending_user = None
    for item in history:
        if not isinstance(item, dict) or "role" not in item or "content" not in item:
            continue
        if item["role"] == "user":
            if pending_user is not None:
                # Two user turns in a row: close the first with no reply.
                pairs.append((pending_user, None))
            pending_user = item["content"]
        elif item["role"] == "assistant":
            pairs.append((pending_user, item["content"]))
            pending_user = None
    if pending_user is not None:
        pairs.append((pending_user, None))
    return pairs


def _chunk_text(chunk):
    """Extract the text carried by one streamed chunk from ``run_gradio_chat``.

    NOTE(review): the exact chunk type yielded by TxAgent.run_gradio_chat is not
    visible here; dicts (``content`` key) and plain strings are handled, and
    anything else is stringified as the original code did — confirm against
    the TxAgent version in use.
    """
    if isinstance(chunk, dict):
        content = chunk.get("content")
        return "" if content is None else str(content)
    if isinstance(chunk, str):
        return chunk
    return str(chunk)


def respond(message, chat_history, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation_state, max_round):
    """Stream a TxAgent answer to *message* into the Chatbot.

    Args:
        message: The user's query from the textbox.
        chat_history: Current Chatbot value (tuple pairs or messages dicts).
        temperature: Sampling temperature forwarded to the agent.
        max_new_tokens: Per-generation token limit (cast to int).
        max_tokens: Total token budget, forwarded as ``max_token`` (cast to int).
        multi_agent: Forwarded as ``call_agent``.
        conversation_state: Per-session ``gr.State`` value forwarded as
            ``conversation``.
        max_round: Maximum agent rounds (cast to int).

    Yields:
        The chat history as (user, assistant) pairs, with the latest exchange
        updated as each chunk arrives, or with a warning/error entry appended.
    """
    history = _to_pairs(chat_history)

    if not isinstance(message, str) or len(message.strip()) <= 10:
        # This function is a generator: the value must be yielded, since a
        # returned value never reaches the UI.
        yield history + [("", "Please provide a valid message longer than 10 characters.")]
        return

    try:
        agent = _get_agent()
        response = ""
        for chunk in agent.run_gradio_chat(
            message=message.strip(),
            history=history,
            temperature=temperature,
            # Sliders may deliver floats; token counts and round limits must be ints.
            max_new_tokens=int(max_new_tokens),
            max_token=int(max_tokens),
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=int(max_round),
            seed=42,
        ):
            response += _chunk_text(chunk)
            # Tuple-format Chatbot entries are (user_message, bot_message).
            yield history + [(message, response)]
    except Exception as e:
        logger.exception("Error in respond function: %s", e)
        yield history + [("", f"⚠️ Error: {str(e)}")]


# Define Gradio UI
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")

    chatbot = gr.Chatbot(label="Conversation", height=600)
    msg = gr.Textbox(
        label="Your medical query",
        placeholder="Enter your biomedical question...",
        lines=3,
    )

    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        # step=1 keeps integer-valued settings integral.
        max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, step=1, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")

    multi_agent = gr.Checkbox(label="Multi-Agent Mode")
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    conversation_state = gr.State([])

    respond_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens,
                      multi_agent, conversation_state, max_rounds]
    submit.click(respond, respond_inputs, chatbot)
    msg.submit(respond, respond_inputs, chatbot)
    # Reset both the visible chat and the agent's conversation state.
    clear.click(lambda: ([], []), None, [chatbot, conversation_state])

    # Start loading the model when the page is first opened; load_model() is
    # idempotent, so later page loads (and respond()) reuse the same instance.
    app.load(load_model, None, None)