File size: 4,373 Bytes
9aeb1dd 0b9e159 dcb29df 229805b 410d25f 81ad366 79fb3cd 696fd36 37d892a 08baaf7 f15352f 696fd36 f15352f 696fd36 37d892a f15352f 37d892a f15352f 696fd36 83c8341 08baaf7 696fd36 83c8341 f15352f 83c8341 f15352f 83c8341 696fd36 08baaf7 696fd36 f15352f 08baaf7 696fd36 08baaf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import logging
import threading
from importlib.resources import files

import gradio as gr
from txagent import TxAgent
from tooluniverse import ToolUniverse
# Show INFO-level records (e.g. the TxAgent initialization progress logs).
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Shared TxAgent instance; None until the page-load handler assigns it.
tx_app: "TxAgent | None" = None
def init_txagent():
    """Construct a TxAgent, call its ``init_model()``, and return it.

    Tool definition files are resolved from the JSON files shipped inside the
    ``tooluniverse.data`` package and handed to TxAgent as ``tool_files_dict``.

    Returns:
        The TxAgent instance after ``init_model()`` has been called on it.
    """
    logger.info("🔥 Initializing TxAgent...")

    # Resolve the packaged data directory once and map each tool-set key to
    # the absolute path of its JSON definition file.
    data_pkg = files('tooluniverse.data')
    tool_sources = (
        ("opentarget", "opentarget_tools.json"),
        ("fda_drug_label", "fda_drug_labeling_tools.json"),
        ("special_tools", "special_tools.json"),
        ("monarch", "monarch_tools.json"),
    )
    tool_files = {key: str(data_pkg.joinpath(filename)) for key, filename in tool_sources}

    settings = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode='step',
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"],
    )
    agent = TxAgent(**settings)
    agent.init_model()

    logger.info("✅ TxAgent fully initialized")
    return agent
def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
    """Stream a TxAgent answer into the Gradio chatbot.

    This is a generator: Gradio re-renders the chatbot with every yielded
    value. Every exit path therefore *yields* its update. A ``return value``
    inside a generator is discarded, so such a message would never reach the UI.

    Args:
        message: Text typed by the user. It is stripped before validation and
            before being sent to the agent.
        chat_history: Current chatbot value. The Chatbot is built with
            ``type="messages"``, so entries are expected to be dicts with
            ``"role"`` and ``"content"`` keys. Other entries are ignored.
        temperature, max_new_tokens, max_round: Forwarded unchanged to
            ``tx_app.run_gradio_chat``.
        max_tokens: Forwarded as ``run_gradio_chat``'s ``max_token`` argument.
        multi_agent: Forwarded as ``call_agent``.
        conversation_state: Per-session ``gr.State`` value, forwarded as
            ``conversation``.

    Yields:
        The full chatbot value as a list of ``{"role", "content"}`` dicts.
        The normal case ends with the user turn and the (partial) assistant
        reply. The not-ready and invalid-input cases end with a single
        assistant notice instead.
    """
    # Chatbot(type="messages") rejects tuple pairs, so the displayed history
    # is kept as role/content dicts throughout.
    display_history = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in (chat_history or [])
        if isinstance(entry, dict) and "role" in entry and "content" in entry
    ]

    if tx_app is None:
        yield display_history + [
            {"role": "assistant", "content": "⚠️ Model not ready yet. Please wait a few seconds and try again."}
        ]
        return

    text = message.strip() if isinstance(message, str) else ""
    if len(text) <= 10:
        yield display_history + [
            {"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}
        ]
        return

    # The agent still receives history as (role, content) pairs, as before.
    agent_history = [(entry["role"], entry["content"]) for entry in display_history]
    user_turn = {"role": "user", "content": text}
    response = ""
    try:
        for chunk in tx_app.run_gradio_chat(
            message=text,
            history=agent_history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=max_round,
            seed=42,
        ):
            # Chunks may be dicts, plain strings, or other objects; normalise
            # each to text and append it to the streamed reply.
            if isinstance(chunk, dict):
                response += chunk.get("content") or ""
            elif isinstance(chunk, str):
                response += chunk
            else:
                response += str(chunk)
            yield display_history + [user_turn, {"role": "assistant", "content": response}]
    except Exception as e:
        # Top-level UI boundary: log with traceback and show the error in-chat
        # instead of crashing the event handler.
        logger.exception("Error in respond function: %s", e)
        yield display_history + [user_turn, {"role": "assistant", "content": f"⚠️ Error: {e}"}]
# Serialises lazy initialization so concurrent page loads build TxAgent once.
_init_lock = threading.Lock()

# ✅ Top-level app object that HF Spaces can detect
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
    chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
    msg = gr.Textbox(label="Your medical query", placeholder="Enter your biomedical question...", lines=3)
    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
    multi_agent = gr.Checkbox(label="Multi-Agent Mode")
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    conversation_state = gr.State([])

    # Clicking Submit and pressing Enter in the textbox run the same handler
    # with the same inputs, in the order `respond` expects.
    chat_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds]
    submit.click(respond, chat_inputs, chatbot)
    msg.submit(respond, chat_inputs, chatbot)

    # Reset both the visible chat and the per-session conversation state that
    # is passed to the agent.
    # NOTE(review): assumes TxAgent may accumulate context in `conversation`;
    # confirm against run_gradio_chat.
    clear.click(lambda: ([], []), None, [chatbot, conversation_state])

    def initialize_agent():
        """Build the shared TxAgent on the first page load; reuse it afterwards."""
        global tx_app
        # Without this guard every page load (and every refresh while a load is
        # still in progress) would construct and load the model again.
        with _init_lock:
            if tx_app is None:
                tx_app = init_txagent()

    # ✅ Lazy initialization triggered on page load (no hidden button needed).
    app.load(initialize_agent)

if __name__ == "__main__":
    app.launch()
|