File size: 4,810 Bytes
9aeb1dd 0b9e159 37d892a 0b9e159 dcb29df 229805b 410d25f 66e2fa0 81ad366 79fb3cd 37d892a dc06321 37d892a dc06321 90e4214 0b9e159 dc06321 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 90e4214 0b9e159 dcb29df dc06321 dcb29df 0b9e159 dcb29df 83c8341 dcb29df 0b9e159 dc06321 0b9e159 dcb29df 81ad366 37d892a 83c8341 37d892a 83c8341 37d892a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import logging
import multiprocessing
from txagent import TxAgent
from tooluniverse import ToolUniverse
from importlib.resources import files
# Send INFO-level (and more severe) records to a root console handler.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Holds the TxAgentApp instance the Gradio callbacks are bound to; it is
# reassigned once the agent wrapper is constructed further down the module.
tx_app = None
def init_txagent():
    """Create, configure and load the TxAgent behind the Gradio UI.

    Steps:
      1. Force the ``spawn`` multiprocessing start method.
      2. Resolve four ToolUniverse tool-definition JSON files shipped in the
         ``tooluniverse.data`` package.
      3. Construct a ``TxAgent`` with a fixed model/RAG configuration and
         call ``init_model()`` on it.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has returned.

    Raises:
        Exception: any error from the steps above is logged together with
            its traceback and then re-raised unchanged.
    """
    try:
        # force=True replaces a start method that may already have been set
        # in this process instead of raising RuntimeError.
        # NOTE(review): presumably "spawn" is required by the model backend's
        # worker processes — confirm.
        multiprocessing.set_start_method("spawn", force=True)
        logger.info("Initializing TxAgent...")

        # Resolve the package data directory once instead of once per file.
        data_dir = files("tooluniverse.data")
        tool_files = {
            "opentarget": str(data_dir.joinpath("opentarget_tools.json")),
            "fda_drug_label": str(data_dir.joinpath("fda_drug_labeling_tools.json")),
            "special_tools": str(data_dir.joinpath("special_tools.json")),
            "monarch": str(data_dir.joinpath("monarch_tools.json")),
        }
        logger.info("Using tool files at: %s", tool_files)

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict=tool_files,
            enable_finish=True,
            enable_rag=True,
            enable_summary=False,
            init_rag_num=0,
            step_rag_num=10,
            summary_mode="step",
            summary_skip_last_k=0,
            summary_context_length=None,
            force_finish=True,
            avoid_repeat=True,
            seed=42,
            enable_checker=True,
            enable_chat=False,
            additional_default_tools=["DirectResponse", "RequireClarification"],
        )
        agent.init_model()
        logger.info("Model loading complete")
        return agent
    except Exception as e:
        # logger.exception also records the traceback, which logger.error
        # dropped; the exception itself is re-raised untouched.
        logger.exception("Initialization failed: %s", e)
        raise
class TxAgentApp:
    """Gradio-facing wrapper that owns one initialized TxAgent."""

    def __init__(self):
        self.agent = init_txagent()

    def respond(self, message, chat_history, temperature, max_new_tokens,
                max_tokens, multi_agent, conversation_state, max_round):
        """Stream the agent's answer to ``message`` into the chatbot.

        This is a generator: every yielded list is the complete chat history
        to display. Rows are ``(user_text, bot_text)`` pairs, and status
        messages use an empty user side, ``("", text)``.

        Args:
            message: Text typed by the user; must be a ``str`` longer than
                10 characters after stripping.
            chat_history: Current chatbot value; ``None`` is treated as empty.
                It is never mutated.
            temperature: Sampling temperature forwarded to the agent.
            max_new_tokens: Per-generation token cap (cast to ``int``).
            max_tokens: Total token budget, forwarded as ``max_token``
                (cast to ``int``).
            multi_agent: Forwarded to the agent as ``call_agent``.
            conversation_state: Forwarded to the agent as ``conversation``.
            max_round: Maximum number of agent rounds (cast to ``int``).

        Yields:
            The previous history plus one ``(query, response_so_far)`` row
            after each streamed chunk, a single warning row for an invalid
            message, or an error row if anything raises.
        """
        # Copy so the caller's list is never modified; None -> empty history.
        history = list(chat_history or [])
        try:
            if not isinstance(message, str) or len(message.strip()) <= 10:
                # respond() is a generator, so a returned value would be
                # discarded; the warning has to be yielded to be displayed.
                yield history + [("", "Please provide a valid message longer than 10 characters.")]
                return
            if history and isinstance(history[0], dict):
                # NOTE(review): this turns messages-format dicts into
                # (role, content) tuples, not (user, bot) pairs. The Chatbot
                # built in this file sets no type="messages", so this path is
                # presumably unused here; confirm what run_gradio_chat
                # expects before changing it.
                history = [(h["role"], h["content"]) for h in history if "role" in h and "content" in h]

            query = message.strip()
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=query,
                # Hand over a copy in case the agent mutates its history.
                history=list(history),
                temperature=temperature,
                # Gradio sliders deliver floats; token/round counts are ints.
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=int(max_round),
                seed=42,
            ):
                # dict chunks contribute their "content", str chunks are
                # appended as-is, anything else is stringified.
                if isinstance(chunk, dict):
                    response += chunk.get("content", "")
                elif isinstance(chunk, str):
                    response += chunk
                else:
                    response += str(chunk)
                # One (user, assistant) row per exchange.
                yield history + [(query, response)]
        except Exception as e:
            logger.exception("Error in respond function: %s", e)
            yield history + [("", f"⚠️ Error: {e}")]
# Build the agent wrapper at import time. init_txagent() logs and re-raises
# on failure, so a broken setup stops the module from loading.
tx_app = TxAgentApp()

# Define the Gradio UI.
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
    chatbot = gr.Chatbot(
        label="Conversation",
        height=600,
    )
    msg = gr.Textbox(
        label="Your medical query",
        placeholder="Enter your biomedical question...",
        lines=3,
    )
    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
    multi_agent = gr.Checkbox(label="Multi-Agent Mode")
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    # Per-session value forwarded to the agent as `conversation`.
    conversation_state = gr.State([])

    # Both triggers share one input list, ordered like respond()'s parameters.
    chat_inputs = [
        msg, chatbot, temp, max_new_tokens, max_tokens,
        multi_agent, conversation_state, max_rounds,
    ]
    submit.click(tx_app.respond, chat_inputs, chatbot)
    msg.submit(tx_app.respond, chat_inputs, chatbot)
    # Reset the visible history and the conversation state together.
    # NOTE(review): this assumes the agent may accumulate context in the
    # conversation list; resetting only the chatbot could leave it stale.
    clear.click(lambda: ([], []), None, [chatbot, conversation_state])

# A host may import this module and serve the module-level `app`; when the
# file is run directly (e.g. `python app.py`), start the server here.
if __name__ == "__main__":
    app.launch()
|