|
import random |
|
import os |
|
import torch |
|
import logging |
|
import json |
|
from importlib.resources import files |
|
from txagent import TxAgent |
|
from tooluniverse import ToolUniverse |
|
import gradio as gr |
|
|
|
|
|
# Root logging configuration for the whole app: INFO and above, with a
# timestamp, logger name and level prefix on every line.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
# Absolute directory containing this script; the generated tool file is
# written beneath it (see the "new_tool" entry of CONFIG).
current_dir = os.path.split(os.path.abspath(__file__))[0]

# Process-wide environment tweaks, applied in this order.
# NOTE(review): torch is imported at the top of the file, before this runs.
# If MKL_THREADING_LAYER is only consulted when MKL is first loaded, setting
# it here may have no effect -- confirm, and move it above the torch import
# if so.
os.environ.update({
    "MKL_THREADING_LAYER": "GNU",
    # Presumably disables Hugging Face tokenizers' internal parallelism
    # (and its fork warning) -- confirm against the tokenizers docs.
    "TOKENIZERS_PARALLELISM": "false",
})
|
|
|
|
|
# Tool-definition JSON files bundled inside the installed ``tooluniverse``
# package, keyed by the name under which they are handed to TxAgent.
_BUNDLED_TOOL_FILES = {
    "opentarget": "opentarget_tools.json",
    "fda_drug_label": "fda_drug_labeling_tools.json",
    "special_tools": "special_tools.json",
    "monarch": "monarch_tools.json",
}
_tooluniverse_data = files('tooluniverse.data')

# Central application configuration: Hugging Face model ids and the
# tool-file mapping passed to TxAgent as ``tool_files_dict``.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # NOTE(review): placeholder path; not referenced by the code visible in
    # this file -- confirm whether it is still needed.
    "embedding_filename": "path_to_your_embeddings.pt",
    "tool_files": {
        **{
            key: str(_tooluniverse_data.joinpath(filename))
            for key, filename in _BUNDLED_TOOL_FILES.items()
        },
        # Generated locally by prepare_tool_files() when it is missing.
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}
|
|
|
def safe_load_embeddings(filepath: str):
    """Load a ``torch.save``-d embeddings file, preferring the restricted loader.

    First tries ``torch.load(filepath, weights_only=True)``. If that fails for
    any reason other than the file being missing, it logs a warning and
    retries with ``weights_only=False``.

    SECURITY: ``weights_only=False`` performs full pickle deserialization,
    which can execute arbitrary code embedded in the file. Only rely on this
    fallback for files from a trusted source.

    Args:
        filepath: Path to a file produced by ``torch.save``.

    Returns:
        The deserialized object, or ``None`` if the file does not exist or
        both load attempts fail (the failure is logged).
    """
    try:
        return torch.load(filepath, weights_only=True)
    except FileNotFoundError as exc:
        # The permissive loader cannot succeed on a missing file; retrying
        # would only emit a misleading "Secure load failed" warning.
        logger.error("Failed to load embeddings: %s", exc)
        return None
    except Exception as exc:
        logger.warning("Secure load failed, trying without weights_only: %s", exc)

    try:
        # Unsafe fallback (see SECURITY note in the docstring).
        return torch.load(filepath, weights_only=False)
    except Exception as exc:
        logger.error("Failed to load embeddings: %s", exc)
        return None
|
|
|
def get_tools_from_universe(tooluniverse):
    """Return the tool definitions exposed by a ToolUniverse instance.

    Tries, in order, a ``get_all_tools()`` method, a ``tools`` attribute and
    a ``list_tools()`` method, returning the first one found unchanged. If
    none exists, falls back to concatenating the JSON lists stored in the
    files of ``CONFIG["tool_files"]`` that exist on disk.

    Args:
        tooluniverse: The ToolUniverse instance (duck-typed).

    Returns:
        The accessor's result, or, on the fallback path, a non-empty list of
        tool entries, or ``None`` if the fallback found nothing usable.
    """
    if hasattr(tooluniverse, 'get_all_tools'):
        return tooluniverse.get_all_tools()
    if hasattr(tooluniverse, 'tools'):
        return tooluniverse.tools
    if hasattr(tooluniverse, 'list_tools'):
        return tooluniverse.list_tools()

    logger.error("Could not find any tool access method in ToolUniverse")

    tools = []
    for tool_file in CONFIG["tool_files"].values():
        if not os.path.exists(tool_file):
            continue
        try:
            # JSON text is UTF-8; don't depend on the platform locale encoding.
            with open(tool_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers JSONDecodeError and UnicodeDecodeError; one bad
            # file should not discard the tools from the others.
            logger.warning("Skipping unreadable tool file %s: %s", tool_file, exc)
            continue
        if isinstance(data, list):
            tools.extend(data)
        else:
            # list.extend() on a dict would add only its keys, not tool entries.
            logger.warning(
                "Skipping tool file %s: expected a JSON list, got %s",
                tool_file,
                type(data).__name__,
            )
    return tools or None
|
|
|
def prepare_tool_files():
    """Generate ``CONFIG["tool_files"]["new_tool"]`` if it does not exist yet.

    Ensures ``<current_dir>/data`` exists. If the target file is missing,
    instantiates ``ToolUniverse``, extracts its tools with
    ``get_tools_from_universe`` and writes them as indented JSON. An existing
    target file is left untouched; delete it to force regeneration.

    Best effort: failures inside generation are logged (with traceback) and
    swallowed so that startup can continue.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    target = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(target):
        return

    logger.info("Generating tool list...")
    try:
        tu = ToolUniverse()
        tools = get_tools_from_universe(tu)
        if not tools:
            logger.error("No tools could be loaded")
            return

        # Serialize before touching the disk, then publish atomically. Dumping
        # straight into the target would leave a truncated file behind if an
        # entry is not JSON-serializable, and since generation is skipped
        # whenever the file exists, that corrupt file would never be rebuilt.
        payload = json.dumps(tools, indent=2)
        tmp_path = target + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, target)
        logger.info("Saved %d tools", len(tools))
    except Exception as exc:
        logger.exception("Tool file preparation failed: %s", exc)
|
|
|
def create_agent():
    """Build a TxAgent from CONFIG, load its model, and return it.

    Runs ``prepare_tool_files()`` first (best effort), warns about any
    configured tool file that is still missing, then constructs ``TxAgent``
    and calls ``init_model()``.

    Returns:
        The initialized TxAgent.

    Raises:
        Exception: whatever ``TxAgent(...)`` or ``init_model()`` raises; it is
        logged together with its traceback and re-raised unchanged.
    """
    prepare_tool_files()

    # prepare_tool_files() swallows its own errors, so surface missing files
    # here instead of letting agent construction fail without context.
    missing = [
        name for name, path in CONFIG["tool_files"].items()
        if not os.path.exists(path)
    ]
    if missing:
        logger.warning("Tool files not found (agent creation may fail): %s", missing)

    try:
        agent = TxAgent(
            model_name=CONFIG["model_name"],
            rag_model_name=CONFIG["rag_model_name"],
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=['DirectResponse', 'RequireClarification'],
        )
        agent.init_model()
    except Exception as exc:
        logger.exception("Agent creation failed: %s", exc)
        raise
    return agent
|
|
|
def format_response(history, message):
    """Append an assistant reply to a pair-style chat history.

    Args:
        history: Existing history as a list of ``[user, bot]`` pairs. It is
            not mutated.
        message: The reply. Handling depends on its type:
            - a ``str`` is used as-is;
            - a ``dict`` contributes its ``"content"`` value;
            - any other iterable (e.g. a streaming generator) is concatenated
              chunk by chunk, where dict chunks contribute their
              ``"content"`` and other chunks their ``str()``;
            - anything else is converted with ``str()``.

    Returns:
        A new list: ``history`` followed by ``[None, <reply text>]``.

    NOTE(review): this builds the pair/tuple format, whereas the Chatbot in
    create_demo uses ``type="messages"``. The helper is not called by the code
    visible in this file.
    """
    def chunk_text(chunk):
        # Dict chunks carry their text under "content"; a missing or None
        # content counts as empty rather than raising on concatenation.
        if isinstance(chunk, dict):
            content = chunk.get("content", "")
            return "" if content is None else str(content)
        return str(chunk)

    if isinstance(message, str):
        text = message
    elif isinstance(message, dict):
        # Same treatment as dict chunks, instead of showing the dict's repr.
        text = chunk_text(message)
    elif hasattr(message, '__iter__'):
        text = "".join(chunk_text(chunk) for chunk in message)
    else:
        text = str(message)
    return history + [[None, text]]
|
|
|
def create_demo(agent):
    """Build the Gradio Blocks chat UI backed by ``agent``.

    The Chatbot is created with ``type="messages"``, so its value is a list
    of ``{"role": ..., "content": ...}`` dicts. The submit handler reads the
    history in that format and returns it in that format.

    Args:
        agent: The TxAgent instance; only its ``run_gradio_chat`` method is
            used.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    def to_agent_history(chat_history):
        """Convert Chatbot history to role/content dicts, dropping empty turns.

        Messages-format dicts are expected; ``(user, bot)`` pairs are also
        accepted defensively.
        """
        converted = []
        for entry in chat_history:
            if isinstance(entry, dict):
                role = entry.get("role")
                content = entry.get("content")
                if role in ("user", "assistant") and content:
                    converted.append({"role": role, "content": content})
                continue
            user_msg, bot_msg = entry
            if user_msg:
                converted.append({"role": "user", "content": user_msg})
            if bot_msg:
                converted.append({"role": "assistant", "content": bot_msg})
        return converted

    def collect_text(response):
        """Concatenate a (possibly streamed) agent response into one string."""
        return "".join(
            chunk.get("content", "") if isinstance(chunk, dict) else str(chunk)
            for chunk in response
        )

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(
            height=800,
            label='TxAgent',
            show_copy_button=True,
            type="messages",
        )
        msg = gr.Textbox(label="Input", placeholder="Type your question...")
        clear = gr.ClearButton([msg, chatbot])

        def respond(message, chat_history):
            """Handle one submission; return the updated messages-format history."""
            history = list(chat_history or [])
            try:
                # NOTE(review): confirm this call matches the signature of
                # TxAgent.run_gradio_chat (keyword names, and whether the first
                # argument should be the message list or a single message).
                response = agent.run_gradio_chat(
                    to_agent_history(history) + [{"role": "user", "content": message}],
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_tokens=81920,
                    multi_agent=False,
                    conversation=[],
                    max_round=30,
                )
                reply = collect_text(response)
            except Exception as exc:
                logger.exception("Error in response handling: %s", exc)
                reply = f"Error: {str(exc)}"
            # Return dicts, not (user, bot) tuples, to match type="messages".
            return history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": reply},
            ]

        msg.submit(respond, [msg, chatbot], [chatbot])
        clear.click(lambda: [], None, [chatbot])

    return demo
|
|
|
def main():
    """Build the agent and the web UI, then serve it on 0.0.0.0:7860.

    Any startup failure is logged and then re-raised to the caller.
    """
    host, port = "0.0.0.0", 7860
    try:
        app = create_demo(create_agent())
        app.launch(server_name=host, server_port=port)
    except Exception as err:
        logger.error(f"Application failed to start: {str(err)}")
        raise


if __name__ == "__main__":
    main()