File size: 5,454 Bytes
9aeb1dd dec1312 0b9e159 dcb29df 410d25f 66e2fa0 81ad366 79fb3cd dc06321 81ad366 dc06321 81ad366 dec1312 dc06321 dec1312 dcb29df 81ad366 dcb29df 0b9e159 dec1312 0b9e159 dec1312 0b9e159 dec1312 dc06321 dcb29df 81ad366 dc06321 81ad366 dc06321 90e4214 0b9e159 dc06321 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 90e4214 0b9e159 dcb29df dc06321 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 dc06321 0b9e159 dcb29df 81ad366 0b9e159 81ad366 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 dcb29df 0b9e159 81ad366 0b9e159 90e4214 0b9e159 81ad366 70839bb 0b9e159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import gradio as gr
import os
import logging
from txagent import TxAgent
from tooluniverse import ToolUniverse
# Logging setup: the root handler emits INFO and above, and this module
# logs through a logger named after the module itself.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class TxAgentApp:
    """Gradio-facing wrapper around a single TxAgent instance."""

    # Model identifiers forwarded to TxAgent.
    MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
    RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"

    # Tool-definition files forwarded to TxAgent as ``tool_files_dict``.
    TOOL_FILES = {
        "opentarget": "opentarget_tools.json",
        "fda_drug_label": "fda_drug_labeling_tools.json",
        "special_tools": "special_tools.json",
        "monarch": "monarch_tools.json",
    }

    # A query must be strictly longer than this (after stripping) to be sent.
    MIN_MESSAGE_LENGTH = 10

    def __init__(self, agent=None):
        """Create the app.

        Args:
            agent: Optional pre-built agent exposing ``run_gradio_chat``.
                When omitted, a TxAgent is built and its model initialized
                via ``_initialize_agent``.
        """
        self.agent = agent if agent is not None else self._initialize_agent()

    def _initialize_agent(self):
        """Build a TxAgent with this app's fixed configuration and init its model.

        Returns:
            The initialized ``TxAgent``.

        Raises:
            Exception: any error from construction or ``init_model`` is logged
                with its traceback and re-raised.
        """
        logger.info("Initializing TxAgent...")
        try:
            agent = TxAgent(
                model_name=self.MODEL_NAME,
                rag_model_name=self.RAG_MODEL_NAME,
                # Copy so the agent can never mutate the shared class constant.
                tool_files_dict=dict(self.TOOL_FILES),
                enable_finish=True,
                enable_rag=True,
                enable_summary=False,
                init_rag_num=0,
                step_rag_num=10,
                summary_mode='step',
                summary_skip_last_k=0,
                summary_context_length=None,
                force_finish=True,
                avoid_repeat=True,
                seed=42,
                enable_checker=True,
                enable_chat=False,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )
            # Model initialization is an explicit, separate step after construction.
            agent.init_model()
        except Exception as e:
            logger.exception("Initialization failed: %s", e)
            raise
        logger.info("Model loading complete")
        return agent

    @staticmethod
    def _history_to_pairs(chat_history):
        """Normalize Gradio chat history into a list of ``(user, bot)`` tuples.

        Accepts ``None``, tuples-format history (``[user, bot]`` pairs) or
        messages-format history (dicts with ``role`` and ``content``). A
        ``user`` dict opens a new pair; the following non-user dict fills that
        pair's bot slot, otherwise it becomes a ``(None, content)`` pair.
        Dicts missing ``role`` or ``content`` are skipped.
        """
        pairs = []
        for item in chat_history or []:
            if isinstance(item, dict):
                if "role" not in item or "content" not in item:
                    continue
                if item["role"] == "user":
                    pairs.append((item["content"], None))
                elif pairs and pairs[-1][1] is None:
                    pairs[-1] = (pairs[-1][0], item["content"])
                else:
                    pairs.append((None, item["content"]))
            else:
                pairs.append(tuple(item))
        return pairs

    @staticmethod
    def _chunk_text(chunk):
        """Return the text carried by one chunk streamed from the agent."""
        if isinstance(chunk, dict):
            content = chunk.get("content", "")
            return "" if content is None else str(content)
        if isinstance(chunk, str):
            return chunk
        return str(chunk)

    def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
        """Stream the agent's answer to *message* into the Gradio chatbot.

        This is a generator. Every yielded value is the complete chatbot value
        in Gradio's tuples format: a list of ``(user_text, bot_text)`` pairs.

        Args:
            message: The user's query; must be a str longer than
                ``MIN_MESSAGE_LENGTH`` characters after stripping.
            chat_history: Current chatbot value (``None``, tuples-format pairs
                or messages-format dicts).
            temperature: Forwarded to ``run_gradio_chat``.
            max_new_tokens: Cast to int and forwarded as ``max_new_tokens``.
            max_tokens: Cast to int and forwarded as ``max_token``.
            multi_agent: Forwarded as ``call_agent``.
            conversation_state: Forwarded as ``conversation``.
            max_round: Cast to int and forwarded as ``max_round``.

        Yields:
            Prior history plus ``(query, response_so_far)``, or prior history
            plus ``(None, warning)`` when the message is rejected or an error
            occurs.
        """
        history = []
        try:
            history = self._history_to_pairs(chat_history)
            if not isinstance(message, str) or len(message.strip()) <= self.MIN_MESSAGE_LENGTH:
                # A ``return value`` inside a generator is discarded by Gradio,
                # so the warning must be yielded to be displayed.
                yield history + [(None, f"Please provide a valid message longer than {self.MIN_MESSAGE_LENGTH} characters.")]
                return
            query = message.strip()
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=query,
                # Pass a copy so the agent cannot alter the list we render from.
                history=list(history),
                temperature=temperature,
                # Gradio sliders may deliver floats; token/round counts are ints.
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=int(max_round),
                seed=42,
            ):
                response += self._chunk_text(chunk)
                yield history + [(query, response)]
        except Exception as e:
            logger.exception("Error in respond function: %s", e)
            yield history + [(None, f"⚠️ Error: {str(e)}")]
def create_demo():
    """Build the TxAgent Gradio Blocks interface.

    Constructs a ``TxAgentApp`` (which initializes the agent) and wires its
    ``respond`` generator to both the Submit button and Enter in the textbox.

    Returns:
        The configured ``gr.Blocks`` demo, not yet launched.
    """
    app = TxAgentApp()
    with gr.Blocks(title="TxAgent Medical AI") as demo:
        gr.Markdown("# TxAgent Biomedical Assistant")
        chatbot = gr.Chatbot(
            label="Conversation",
            height=600,
            bubble_full_width=False,
        )
        msg = gr.Textbox(
            label="Your medical query",
            placeholder="Enter your biomedical question...",
            lines=3,
        )
        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            # step=1 keeps count settings integral; without it Gradio derives
            # a step from the range, which is fractional for small ranges.
            max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, step=1, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Multi-Agent Mode")
        submit = gr.Button("Submit")
        clear = gr.Button("Clear")
        conversation_state = gr.State([])

        # Order must match TxAgentApp.respond's positional parameters.
        inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds]
        submit.click(app.respond, inputs, chatbot)
        msg.submit(app.respond, inputs, chatbot)
        # Reset the agent conversation state together with the visible chat,
        # so a cleared chat does not keep feeding stale context to the agent.
        clear.click(lambda: ([], []), None, [chatbot, conversation_state])
    return demo
def main(server_name="0.0.0.0", server_port=7860, share=False):
    """Build the Gradio demo and serve it.

    Args:
        server_name: Address to bind; the default listens on all interfaces.
        server_port: TCP port to serve on.
        share: Forwarded to ``demo.launch``.

    Raises:
        Exception: any failure while building or launching the demo is logged
            with its traceback and re-raised.
    """
    logger.info("Starting TxAgent application...")
    try:
        demo = create_demo()
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name=server_name,
            server_port=server_port,
            share=share,
        )
    except Exception as e:
        logger.exception("Application failed to start: %s", e)
        raise
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()