# test / app.py
# Ali2206's picture
# Update app.py
# 81ad366 verified
# raw
# history blame
# 3.72 kB
import gradio as gr
from txagent import TxAgent
from tooluniverse import ToolUniverse
import os
import logging
# Module-level logger, plus root logging configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TxAgentApp:
    """Hold a ``TxAgent`` instance and expose a streaming ``respond`` callback.

    The agent is built eagerly in ``__init__``; a construction failure is
    logged and re-raised, so creating a ``TxAgentApp`` can raise.
    """

    def __init__(self):
        """Create the agent immediately and store it on ``self.agent``."""
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Construct and return the ``TxAgent`` used by this app.

        Returns:
            The newly constructed ``TxAgent`` instance.

        Raises:
            Exception: Whatever the ``TxAgent`` constructor raises; it is
                logged together with its traceback and re-raised unchanged.
        """
        logger.info("Initializing TxAgent with A100 optimizations...")
        try:
            # NOTE(review): these keyword arguments are forwarded verbatim;
            # confirm that the installed TxAgent constructor accepts all of
            # them (device_map, torch_dtype, enable_xformers, max_model_len).
            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                device_map="auto",
                torch_dtype="auto",
                enable_xformers=True,
                max_model_len=8192,  # Optimized for A100 80GB
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error drops.
            logger.exception("Initialization failed: %s", e)
            raise
        logger.info("Model loading complete")
        return agent

    def respond(self, message, history, temperature=0.3, max_new_tokens=2048):
        """Stream the agent's reply to ``message`` as a sequence of strings.

        Args:
            message: The user's query, forwarded to the agent as ``message``.
            history: Prior conversation, forwarded to the agent as ``history``.
            temperature: Sampling temperature forwarded to the agent.
            max_new_tokens: Generation length limit forwarded to the agent.

        Yields:
            For each chunk produced by ``agent.run_gradio_chat``: the value
            under ``"content"`` (``""`` when the key is absent) if the chunk
            is a dict, or the chunk itself if it is a string. Chunks of any
            other type are skipped. If anything raises, a single
            ``"⚠️ Error: ..."`` string is yielded and the stream ends.
        """
        try:
            # NOTE(review): stream=True is passed through as-is; confirm that
            # run_gradio_chat accepts this keyword.
            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                stream=True,
            )
            for chunk in response_generator:
                if isinstance(chunk, dict):
                    yield chunk.get("content", "")
                elif isinstance(chunk, str):
                    yield chunk
        except Exception as e:
            # Report the failure to the user instead of breaking the stream,
            # but keep the full traceback in the log.
            logger.exception("Generation error: %s", e)
            yield f"⚠️ Error: {str(e)}"
# Initialize the app. Constructing TxAgentApp builds the agent immediately,
# so this statement runs at import time.
app = TxAgentApp()


def _stream_chat(message, history):
    """Adapt ``app.respond`` to the value a messages-type ``gr.Chatbot`` needs.

    ``app.respond`` yields bare strings, but a Chatbot output component must
    receive the whole conversation as a list of messages. This wrapper
    appends the user's turn and the latest reply text to the existing history.

    Args:
        message: Text submitted by the user.
        history: Current Chatbot value (a list of role/content dicts, or
            ``None`` when empty).

    Yields:
        The conversation so far as a list of ``{"role", "content"}`` dicts.
    """
    transcript = list(history or [])
    transcript.append({"role": "user", "content": message})
    # Show the user's turn right away, before the agent produces anything.
    yield transcript
    for partial in app.respond(message, history):
        # Each yield replaces the displayed reply with the latest chunk.
        # NOTE(review): assumes chunks are cumulative text; if the agent emits
        # deltas, accumulate them here instead — confirm against TxAgent.
        yield transcript + [{"role": "assistant", "content": partial}]


# Gradio 5.23 interface
with gr.Blocks(
    title="TxAgent Medical AI",
    theme=gr.themes.Soft(spacing_size="sm", radius_size="none"),
) as demo:
    gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                type="messages",  # history is a list of role/content dicts
                height=650,
                bubble_full_width=False,
                avatar_images=(
                    "https://example.com/user.png",  # Replace with actual avatars
                    "https://example.com/bot.png",
                ),
            )

        with gr.Column(scale=1):
            # NOTE(review): these three controls are displayed but are not
            # passed to any event handler below, so changing them has no
            # effect on generation.
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                max_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                rag_toggle = gr.Checkbox(value=True, label="Enable RAG")

            msg = gr.Textbox(
                label="Your medical query",
                placeholder="Enter your biomedical question...",
                lines=5,
                max_lines=10,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear History")

    # Chat interface: pressing Enter in the textbox and clicking Submit share
    # one listener, so the "chat" API endpoint is registered exactly once.
    gr.on(
        triggers=[msg.submit, submit_btn.click],
        fn=_stream_chat,
        inputs=[msg, chatbot],
        outputs=chatbot,
        api_name="chat",
    ).then(
        # Empty the textbox once the reply has finished streaming.
        lambda: "", None, msg
    )

    clear_btn.click(
        lambda: [], None, chatbot
    )
# Launch configuration
if __name__ == "__main__":
    # Blocks.queue() in Gradio 4+ rejects `concurrency_count`; the per-event
    # default is set with `default_concurrency_limit` instead. `max_size`
    # caps how many requests may wait in the queue.
    demo.queue(
        default_concurrency_limit=5,
        max_size=20,
    ).launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        # NOTE(review): relative path; confirm icon.png exists in the
        # directory the app is started from.
        favicon_path="icon.png",
    )