File size: 5,430 Bytes
9aeb1dd dec1312 0b9e159 410d25f 66e2fa0 81ad366 79fb3cd dc06321 81ad366 dc06321 81ad366 dec1312 dc06321 dec1312 81ad366 0b9e159 dec1312 0b9e159 dec1312 0b9e159 dec1312 dc06321 81ad366 dc06321 81ad366 dc06321 90e4214 0b9e159 dc06321 0b9e159 dc06321 0b9e159 90e4214 0b9e159 dc06321 0b9e159 dc06321 0b9e159 dc06321 0b9e159 dc06321 0b9e159 81ad366 0b9e159 81ad366 0b9e159 81ad366 0b9e159 90e4214 0b9e159 81ad366 70839bb 0b9e159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import gradio as gr
import os
import logging
from txagent import TxAgent
# Module-wide logging setup: send INFO-and-above records through the root
# logger's default handler, and give this module its own named logger.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
class TxAgentApp:
    """Gradio-facing wrapper that owns one TxAgent and streams its replies.

    The agent is constructed once in ``__init__`` and reused by every call
    to :meth:`respond`.
    """

    # Queries whose stripped length is at most this many characters are
    # rejected by ``respond`` without calling the agent.
    MIN_MESSAGE_LENGTH = 10
    # Seed passed both to the TxAgent constructor and to every
    # ``run_gradio_chat`` call, so the two values cannot drift apart.
    SEED = 42

    def __init__(self):
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Construct the TxAgent with this app's fixed configuration.

        Returns:
            The constructed ``TxAgent`` instance.

        Raises:
            Exception: anything raised by the ``TxAgent`` constructor is
                logged with its traceback and re-raised unchanged.
        """
        logger.info("Initializing TxAgent...")
        try:
            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                enable_finish=True,
                enable_rag=True,
                enable_summary=False,
                init_rag_num=0,
                step_rag_num=10,
                summary_mode="step",
                summary_skip_last_k=0,
                summary_context_length=None,
                force_finish=True,
                avoid_repeat=True,
                seed=self.SEED,
                enable_checker=True,
                enable_chat=False,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )
        except Exception as e:
            # logger.exception records the traceback, which logger.error dropped.
            logger.exception("Initialization failed: %s", e)
            raise
        # NOTE(review): nothing in this method explicitly loads model weights.
        # If TxAgent defers loading to a separate call (e.g. an ``init_model``
        # method), it must be invoked here before the message below is true —
        # confirm against the TxAgent API.
        logger.info("Model loading complete")
        return agent

    @staticmethod
    def _chunk_to_text(chunk):
        """Normalize one streamed chunk from the agent into text.

        Dicts contribute their ``"content"`` value, strings pass through
        unchanged, and anything else is converted with ``str()``.
        """
        # NOTE(review): the shape of chunks yielded by run_gradio_chat is not
        # visible here. If the agent yields whole history lists, the str()
        # fallback renders their repr — confirm against TxAgent.
        if isinstance(chunk, dict) and "content" in chunk:
            return chunk["content"]
        if isinstance(chunk, str):
            return chunk
        return str(chunk)

    def respond(self, message, chat_history, temperature, max_new_tokens,
                max_tokens, multi_agent, conversation_state, max_round):
        """Stream the agent's answer to ``message`` into a ``gr.Chatbot``.

        This is a generator used as a Gradio event handler whose output is a
        Chatbot. A Chatbot's value is a list of ``[user, bot]`` pairs, so
        every yield is the whole chat history: the caller's history, plus one
        new pair whose bot text grows as chunks arrive. The caller's list is
        never mutated.

        Args:
            message: The user's query. Non-strings, and strings whose stripped
                length is at most ``MIN_MESSAGE_LENGTH``, are rejected.
            chat_history: Current Chatbot value. ``None`` is treated as empty.
                It is forwarded to the agent unchanged.
            temperature: Sampling temperature forwarded to the agent.
            max_new_tokens: Per-response token cap (cast to ``int``).
            max_tokens: Total token budget, forwarded as ``max_token``
                (cast to ``int``).
            multi_agent: Forwarded to the agent as ``call_agent``.
            conversation_state: Per-session state forwarded as ``conversation``.
            max_round: Maximum agent rounds (cast to ``int``).

        Yields:
            The updated chat history after each streamed chunk, or a single
            history ending in a validation or error message.
        """
        history = list(chat_history or [])
        user_text = message.strip() if isinstance(message, str) else ""

        if len(user_text) <= self.MIN_MESSAGE_LENGTH:
            history.append([
                user_text,
                f"Please provide a valid message longer than {self.MIN_MESSAGE_LENGTH} characters.",
            ])
            yield history
            return

        history.append([user_text, ""])
        try:
            response_generator = self.agent.run_gradio_chat(
                message=user_text,
                history=chat_history,
                temperature=temperature,
                # Slider values may arrive as floats; counts must be integers.
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=int(max_round),
                seed=self.SEED,
            )
            full_response = ""
            for chunk in response_generator:
                full_response += self._chunk_to_text(chunk)
                history[-1][1] = full_response
                yield history
        except Exception as e:
            logger.exception("Error in respond function: %s", e)
            history[-1][1] = f"⚠️ Error: {e}"
            yield history
def create_demo():
    """Assemble the TxAgent Gradio Blocks interface.

    One TxAgentApp is instantiated (building its agent). The page shows a
    chat column beside a collapsible parameter panel. Pressing Enter in the
    query box, or clicking Submit, streams ``TxAgentApp.respond`` into the
    chatbot and then empties the query box. Clear History resets both the
    chatbot and the conversation state.

    Returns:
        The configured, not-yet-launched ``gr.Blocks`` instance.
    """
    app = TxAgentApp()
    soft_theme = gr.themes.Soft(spacing_size="sm", radius_size="none")

    with gr.Blocks(title="TxAgent Medical AI", theme=soft_theme) as demo:
        gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

        with gr.Row(equal_height=False):
            # Wide left column: conversation transcript plus the query box.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=650, bubble_full_width=False, render_markdown=True)
                msg = gr.Textbox(
                    label="Your medical query",
                    placeholder="Enter your biomedical question...",
                    lines=5,
                    max_lines=10,
                )

            # Narrow right column: generation settings and action buttons.
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Parameters", open=False):
                    temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(value=False, label="Multi-Agent Mode")
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear History")

        conversation_state = gr.State([])

        # Chat interface: both triggers share the same handler, inputs, and
        # follow-up step that blanks the textbox once streaming ends.
        handler_inputs = [
            msg, chatbot, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation_state, max_rounds,
        ]
        for trigger in (msg.submit, submit_btn.click):
            trigger(app.respond, handler_inputs, chatbot).then(lambda: "", None, msg)

        clear_btn.click(lambda: ([], []), None, [chatbot, conversation_state])

    return demo
def main():
    """Build the Gradio interface and serve it.

    The bind address and port default to ``0.0.0.0`` and ``7860``. They can
    be overridden through the ``GRADIO_SERVER_NAME`` and
    ``GRADIO_SERVER_PORT`` environment variables. Passing hard-coded values
    to ``launch`` would otherwise shadow Gradio's own handling of those
    variables. They are read before the (expensive) UI/agent construction,
    so a malformed port fails fast.

    Raises:
        Exception: any startup failure (including a non-integer
            ``GRADIO_SERVER_PORT``) is logged with its traceback and
            re-raised.
    """
    try:
        logger.info("Starting TxAgent application...")
        server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
        server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
        demo = create_demo()
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name=server_name,
            server_port=server_port,
            share=False,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Application failed to start: %s", e)
        raise
# Script entry point. The stray "|" scraping artifact that followed main()
# was a syntax error and has been removed.
if __name__ == "__main__":
    main()