# app.py — TxAgent Hugging Face Space (user: Ali2206, revision 90e4214, 5.43 kB)
import gradio as gr
import os
import logging
from txagent import TxAgent
# Configure logging
# INFO level so model-initialization progress messages are visible at runtime;
# a module-level logger (PEP 282 convention) is shared by all code in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TxAgentApp:
def __init__(self):
self.agent = self._initialize_agent()
def _initialize_agent(self):
"""Initialize the TxAgent with proper parameters"""
try:
logger.info("Initializing TxAgent...")
agent = TxAgent(
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
enable_finish=True,
enable_rag=True,
enable_summary=False,
init_rag_num=0,
step_rag_num=10,
summary_mode='step',
summary_skip_last_k=0,
summary_context_length=None,
force_finish=True,
avoid_repeat=True,
seed=42,
enable_checker=True,
enable_chat=False,
additional_default_tools=["DirectResponse", "RequireClarification"]
)
logger.info("Model loading complete")
return agent
except Exception as e:
logger.error(f"Initialization failed: {str(e)}")
raise
def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
"""Handle streaming responses with Gradio"""
try:
if not isinstance(message, str) or len(message.strip()) <= 10:
yield "Please provide a valid message longer than 10 characters."
return
response_generator = self.agent.run_gradio_chat(
message=message.strip(),
history=chat_history,
temperature=temperature,
max_new_tokens=max_new_tokens,
max_token=max_tokens,
call_agent=multi_agent,
conversation=conversation_state,
max_round=max_round,
seed=42
)
full_response = ""
for chunk in response_generator:
if isinstance(chunk, dict) and "content" in chunk:
content = chunk["content"]
elif isinstance(chunk, str):
content = chunk
else:
content = str(chunk)
full_response += content
yield full_response
except Exception as e:
logger.error(f"Error in respond function: {str(e)}")
yield f"⚠️ Error: {str(e)}"
def create_demo():
    """Assemble the Gradio Blocks interface for the TxAgent assistant.

    Returns:
        gr.Blocks: the fully wired (but not yet launched) demo.
    """
    application = TxAgentApp()

    soft_theme = gr.themes.Soft(spacing_size="sm", radius_size="none")
    with gr.Blocks(title="TxAgent Medical AI", theme=soft_theme) as demo:
        gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

        with gr.Row(equal_height=False):
            # Left column: conversation display plus the query input.
            with gr.Column(scale=2):
                chat_display = gr.Chatbot(
                    height=650,
                    bubble_full_width=False,
                    render_markdown=True,
                )
                query_box = gr.Textbox(
                    label="Your medical query",
                    placeholder="Enter your biomedical question...",
                    lines=5,
                    max_lines=10,
                )
            # Right column: generation parameters and action buttons.
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Parameters", open=False):
                    creativity = gr.Slider(0, 1, value=0.3, label="Creativity")
                    response_cap = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                    total_cap = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    round_cap = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent_flag = gr.Checkbox(value=False, label="Multi-Agent Mode")
                send_button = gr.Button("Submit", variant="primary")
                reset_button = gr.Button("Clear History")

        convo_state = gr.State([])

        # Both Enter-in-textbox and the Submit button run the same pipeline:
        # stream the response into the chatbot, then blank the input box.
        chat_inputs = [
            query_box, chat_display, creativity, response_cap,
            total_cap, multi_agent_flag, convo_state, round_cap,
        ]
        for trigger in (query_box.submit, send_button.click):
            trigger(
                application.respond,
                chat_inputs,
                chat_display,
            ).then(lambda: "", None, query_box)

        reset_button.click(
            lambda: ([], []),
            None,
            [chat_display, convo_state],
        )

    return demo
def main():
    """Entry point: build the Gradio demo and serve it on all interfaces.

    Raises:
        Exception: re-raised (after logging) if startup fails, so the
        process exits non-zero instead of hanging silently.
    """
    try:
        logger.info("Starting TxAgent application...")
        interface = create_demo()
        logger.info("Launching Gradio interface...")
        interface.launch(
            server_name="0.0.0.0",  # listen on every interface (container use)
            server_port=7860,       # conventional Hugging Face Spaces port
            share=False,
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise


if __name__ == "__main__":
    main()