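"""Gradio front end for the TxAgent biomedical assistant, sized for an A100-class GPU."""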
import gradio as gr
from txagent import TxAgent
from tooluniverse import ToolUniverse
import os
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TxAgentApp:
    def __init__(self):
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Initialize the TxAgent with A100 optimizations."""
        try:
            logger.info("Initializing TxAgent with A100 optimizations...")
            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                device_map="auto",
                torch_dtype="auto",
                enable_xformers=True,
                max_model_len=8192  # Optimized for A100 80GB
            )
            logger.info("Model loading complete")
            return agent
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            raise
    def respond(self, message, history):
        """Stream the agent's reply as incremental updates to the chat history."""
        # The output component is a gr.Chatbot, so each yield must be the full
        # history (a list of [user, assistant] pairs), not a bare text chunk.
        display_history = history + [[message, ""]]
        try:
            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=2048,
                stream=True
            )
            for chunk in response_generator:
                if isinstance(chunk, dict):
                    display_history[-1][1] += chunk.get("content", "")
                elif isinstance(chunk, str):
                    display_history[-1][1] += chunk
                yield display_history
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            display_history[-1][1] = f"⚠️ Error: {str(e)}"
            yield display_history
# Initialize the app
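# The agent (and its model weights) is created once at module load and shared across sessions.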
app = TxAgentApp()
# Gradio 5.23 interface
with gr.Blocks(
    title="TxAgent Medical AI",
    theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
) as demo:
    gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=650,
                bubble_full_width=False,
                avatar_images=(
                    "https://example.com/user.png",  # Replace with actual avatars
                    "https://example.com/bot.png"
                )
            )
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Parameters", open=False):
                # Note: these controls are not yet wired into app.respond,
                # which currently uses fixed temperature/max_new_tokens values.
                temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                max_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                rag_toggle = gr.Checkbox(value=True, label="Enable RAG")
            msg = gr.Textbox(
                label="Your medical query",
                placeholder="Enter your biomedical question...",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear History")
    # Chat interface
    msg.submit(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat"
    ).then(
        lambda: "", None, msg
    )
    submit_btn.click(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat"
    ).then(
        lambda: "", None, msg
    )
    clear_btn.click(
        lambda: [], None, chatbot
    )
# Launch configuration
if __name__ == "__main__":
    demo.queue(
        default_concurrency_limit=5,  # `concurrency_count` was removed in Gradio 4+
        max_size=20
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        favicon_path="icon.png"  # Add favicon
    )