|
import os |
|
import json |
|
import torch |
|
import logging |
|
import numpy |
|
import gradio as gr |
|
from importlib.resources import files |
|
from txagent import TxAgent |
|
from tooluniverse import ToolUniverse |
|
|
|
|
|
# Configure root logging once at import time; every module that calls
# logging.getLogger(...) inherits this format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Process-wide environment tweaks, applied before any model code runs:
# - MKL_THREADING_LAYER=GNU avoids MKL/OpenMP threading-layer clashes when
#   torch and numpy coexist (presumably why it was set — TODO confirm).
# - TOKENIZERS_PARALLELISM=false silences the HF tokenizers fork warning.
_RUNTIME_ENV = {
    "MKL_THREADING_LAYER": "GNU",
    "TOKENIZERS_PARALLELISM": "false",
}
os.environ.update(_RUNTIME_ENV)

# Absolute directory containing this file; used to locate local data files.
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
# Tool-definition files shipped inside the tooluniverse package data,
# keyed by the short name TxAgent uses for each tool family.
_PACKAGED_TOOL_FILES = {
    "opentarget": "opentarget_tools.json",
    "fda_drug_label": "fda_drug_labeling_tools.json",
    "special_tools": "special_tools.json",
    "monarch": "monarch_tools.json",
}

# Central app configuration: HF model identifiers plus resolved paths to
# every tool file the agent loads.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "tool_files": {
        **{
            tool_key: str(files('tooluniverse.data').joinpath(file_name))
            for tool_key, file_name in _PACKAGED_TOOL_FILES.items()
        },
        # Locally generated tool list, written on first run (see prepare_tool_files).
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}
|
|
|
class TxAgentApp:
    """Gradio front-end around a TxAgent biomedical reasoning agent.

    Constructs the agent (model weights, RAG model, tool files) eagerly in
    __init__ and exposes a chat UI via create_demo().
    """

    def __init__(self):
        # Populated by initialize_agent(); respond() requires it to be set.
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Initialize the TxAgent with proper error handling.

        Raises:
            Exception: re-raised after logging if agent construction or
                model initialization fails — the app cannot run without it.
        """
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")

            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,  # fixed seed for reproducible generations
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )

            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")

        except Exception as e:
            logger.error(f"Failed to initialize agent: {e}")
            raise

    def prepare_tool_files(self):
        """Ensure the local data directory and new_tool.json exist.

        On first run, dumps the full ToolUniverse tool list to
        CONFIG["tool_files"]["new_tool"] so TxAgent can load it.

        Raises:
            Exception: re-raised after logging on any filesystem or
                serialization error.
        """
        try:
            os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
            if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
                logger.info("Creating new_tool.json...")
                tu = ToolUniverse()
                # ToolUniverse's API differs across versions; prefer the
                # accessor, fall back to the attribute, default to empty.
                tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
                with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                    json.dump(tools, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to prepare tool files: {e}")
            raise

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
        """Handle a user message and append the agent's reply to the history.

        Args:
            msg: Raw user input from the textbox.
            chat_history: List of {"role", "content"} message dicts (messages format).
            temperature: Sampling temperature forwarded to the agent.
            max_new_tokens: Per-turn generation cap.
            max_tokens: Total token budget forwarded as ``max_token``.
            multi_agent: Whether to enable multi-agent tool calling.
            conversation: Opaque conversation state forwarded to the agent.
            max_round: Maximum reasoning/tool-call rounds.

        Returns:
            The updated chat_history; on error an assistant message with the
            error text is appended instead of raising.
        """
        try:
            # Reject trivially short or non-string inputs before touching the model.
            if not isinstance(msg, str) or len(msg.strip()) <= 10:
                return chat_history + [{"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}]

            message = msg.strip()
            chat_history.append({"role": "user", "content": message})
            # Flatten message dicts into (role, content) pairs for the agent —
            # assumes run_gradio_chat accepts pair-style history; TODO confirm.
            formatted_history = [(m["role"], m["content"]) for m in chat_history if "role" in m and "content" in m]

            logger.info(f"Processing message: {message[:100]}...")

            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=formatted_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
                seed=42
            )

            # Drain the generator into one string; chunks may be dicts with a
            # "content" key, plain strings, or other printable objects.
            collected = ""
            for chunk in response_generator:
                if isinstance(chunk, dict) and "content" in chunk:
                    collected += chunk["content"]
                elif isinstance(chunk, str):
                    collected += chunk
                elif chunk is not None:
                    collected += str(chunk)

            chat_history.append({"role": "assistant", "content": collected or "No response generated."})
            return chat_history

        except Exception as e:
            logger.error(f"Error in respond function: {e}")
            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return chat_history

    def create_demo(self):
        """Create and return the Gradio Blocks interface."""
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")

            with gr.Row():
                with gr.Column(scale=3):
                    # type="messages" is required for the {"role","content"}
                    # dicts that respond() produces; without it modern Gradio
                    # expects (user, bot) tuple pairs and rendering breaks.
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=600,
                        type="messages"
                    )
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3
                    )
                    submit = gr.Button("Ask", variant="primary")

                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")

            # Same wiring for the button click and Enter in the textbox
            # (previously the textbox had no submit handler at all).
            respond_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds]
            submit.click(
                self.respond,
                inputs=respond_inputs,
                outputs=[chatbot]
            )
            msg.submit(
                self.respond,
                inputs=respond_inputs,
                outputs=[chatbot]
            )
            clear_btn.click(lambda: [], None, chatbot, queue=False)

        return demo
|
|
|
def main():
    """Build the TxAgent app and serve its Gradio UI on port 7860.

    Raises:
        Exception: re-raised after logging if agent startup or the
            Gradio launch fails.
    """
    # share=True publishes a public Gradio link in addition to binding
    # on all interfaces locally.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    try:
        logger.info("Starting TxAgent application...")
        application = TxAgentApp()
        interface = application.create_demo()

        logger.info("Launching Gradio interface...")
        interface.launch(**launch_options)
    except Exception as e:
        logger.error(f"Application failed to start: {e}")
        raise
|
|
|
# Entry-point guard: launch the app only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()