import gradio as gr
import logging
import os
import multiprocessing

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

tx_app = None
TOOL_CACHE_PATH = "/home/user/.cache/tool_embeddings_done"


# Chatbot response function (a generator: Gradio streams each yielded history).
# Early exits use yield rather than return so the warning actually reaches the UI.
def respond(message, chat_history, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation_state, max_round):
    global tx_app
    if tx_app is None:
        yield chat_history + [{"role": "assistant", "content": "⚠️ Model is still loading. Please wait a few seconds and try again."}]
        return

    try:
        if not isinstance(message, str) or len(message.strip()) < 10:
            yield chat_history + [{"role": "assistant", "content": "Please enter a longer message."}]
            return

        # The chatbot uses type="messages" (a list of role/content dicts);
        # convert to (role, content) tuples for tx_app, keeping the original
        # messages-format history for display.
        tool_history = chat_history
        if chat_history and isinstance(chat_history[0], dict):
            tool_history = [(h["role"], h["content"]) for h in chat_history
                            if "role" in h and "content" in h]

        response = ""
        for chunk in tx_app.run_gradio_chat(
            message=message.strip(),
            history=tool_history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=max_round,
            seed=42,
        ):
            # Chunks may arrive as dicts ({"content": ...}) or plain strings;
            # anything else is coerced with str().
            if isinstance(chunk, dict):
                response += chunk.get("content", "")
            elif isinstance(chunk, str):
                response += chunk
            else:
                response += str(chunk)

            yield chat_history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]

    except Exception as e:
        logger.error(f"Respond error: {e}")
        yield chat_history + [{"role": "assistant", "content": f"⚠️ Error: {e}"}]


# === Gradio UI ===
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
    chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
    msg = gr.Textbox(label="Your medical query",
                     placeholder="Type your biomedical question...", lines=3)
    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
    multi_agent = gr.Checkbox(label="Multi-Agent Mode")
    conversation_state = gr.State([])
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    submit.click(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot,
    )
    clear.click(lambda: [], None, chatbot)
    msg.submit(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot,
    )
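
# Optional local smoke test (hypothetical helper, not part of TxAgent): a
# stand-in with the same run_gradio_chat signature as the call in respond()
# lets the UI be exercised without downloading any models. It streams the
# input back word by word, mimicking the chunk protocol respond() expects
# (plain strings or {"content": ...} dicts). Uncomment to try it.
#
# class _FakeTxApp:
#     def run_gradio_chat(self, message, history, temperature, max_new_tokens,
#                         max_token, call_agent, conversation, max_round, seed):
#         for word in f"(echo) {message}".split():
#             yield word + " "
#
# tx_app = _FakeTxApp()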

# === Safe model initialization ===
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    import tooluniverse
    from txagent import TxAgent
    from importlib.resources import files

    # ✅ Patch ToolUniverse to prevent exit() after embedding: exit() raises
    # SystemExit, so the wrapper swallows it and lets startup continue.
    original_infer = tooluniverse.ToolUniverse.infer_tool_embeddings

    def patched_infer(self, *args, **kwargs):
        try:
            original_infer(self, *args, **kwargs)
        except SystemExit:
            pass
        print("✅ Patched: skipping forced exit() after embedding.")

    tooluniverse.ToolUniverse.infer_tool_embeddings = patched_infer

    logger.info("🔥 Initializing TxAgent...")

    tool_files = {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
    }

    tx_app = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode='step',
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"],
    )

    # 🚀 Run full embedding once, then cache a sentinel file so later
    # starts can skip the slow tool-embedding step.
    if not os.path.exists(TOOL_CACHE_PATH):
        tx_app.init_model()
        os.makedirs(os.path.dirname(TOOL_CACHE_PATH), exist_ok=True)
        with open(TOOL_CACHE_PATH, "w") as f:
            f.write("done")
    else:
        tx_app.init_model(skip_tool_embedding=True)

    logger.info("✅ TxAgent ready.")
    app.launch()
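
# To run (assuming the packages imported above are installed, e.g. via
# `pip install gradio txagent tooluniverse`; exact package pins are
# deployment-specific and not confirmed here):
#   python app.py
# The first start computes tool embeddings and writes the sentinel file at
# TOOL_CACHE_PATH; subsequent starts call init_model(skip_tool_embedding=True)
# and boot much faster.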