|
import os |
|
import sys |
|
import torch |
|
import json |
|
import logging |
|
import gradio as gr |
|
from importlib.resources import files |
|
from txagent import TxAgent |
|
from tooluniverse import ToolUniverse |
|
|
|
|
|
# Record layout shared by every log line emitted by this app.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Configure the root logger at import time: INFO and above, timestamped.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by all functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Absolute directory of this script; used to locate the local ``data/`` folder.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Process-wide settings for native/tokenizer threading.
# NOTE(review): ``torch`` is imported at the top of this file, before this
# assignment runs. MKL most likely reads MKL_THREADING_LAYER when its runtime
# is loaded, so this value may have no effect unless it is set before
# ``import torch`` — confirm.
os.environ.update(
    MKL_THREADING_LAYER="GNU",
    TOKENIZERS_PARALLELISM="false",
)
|
|
|
def _tooluniverse_data_path(filename):
    """Return the path (as ``str``) of *filename* inside the ``tooluniverse.data`` package."""
    return str(files('tooluniverse.data').joinpath(filename))


# Static configuration for the agent and its tool catalogue.
CONFIG = {
    # Model identifiers handed to TxAgent (agent LLM, then tool-retrieval model).
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Precomputed tool-description embeddings.
    # NOTE(review): relative path, so it resolves against the process CWD.
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    # Tool-definition JSON files, keyed by tool family.
    "tool_files": {
        "opentarget": _tooluniverse_data_path('opentarget_tools.json'),
        "fda_drug_label": _tooluniverse_data_path('fda_drug_labeling_tools.json'),
        "special_tools": _tooluniverse_data_path('special_tools.json'),
        "monarch": _tooluniverse_data_path('monarch_tools.json'),
        # Local file; prepare_tool_files() generates it when missing.
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}
|
|
|
# Custom CSS passed to gr.Blocks: sets button text to 20px and button SVG
# icons to 32x32px.
chat_css = """
.gr-button { font-size: 20px !important; }
.gr-button svg { width: 32px !important; height: 32px !important; }
"""
|
|
|
def safe_load_embeddings(filepath: str) -> object:
    """Load a ``torch.save`` file, preferring PyTorch's restricted loader.

    ``torch.load(..., weights_only=True)`` is tried first. If it raises, the
    file is loaded again with ``weights_only=False``.

    Args:
        filepath: Path of the file to load.

    Returns:
        Whatever ``torch.load`` returns, or ``None`` if both attempts fail.
        Failures are logged, never raised.
    """
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as secure_err:
        logger.warning(f"Secure load failed, trying with weights_only=False: {secure_err}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file.
    # NOTE(review): only acceptable if the embedding file's provenance is
    # trusted — confirm.
    try:
        return torch.load(filepath, weights_only=False)
    except Exception as unsafe_err:
        logger.error(f"Failed to load embeddings: {unsafe_err}")
        return None
|
|
|
def _align_embedding_rows(embeddings, target_count):
    """Truncate or pad ``embeddings`` along dim 0 so it has ``target_count`` rows.

    Padding repeats the last row. Assumes a ``torch.Tensor`` with at least one
    row whenever padding is needed.
    """
    row_count = len(embeddings)
    if target_count <= row_count:
        return embeddings[:target_count]
    # Slice (not index) the last row so it keeps its leading dimension.
    # torch.cat needs inputs of equal rank; embeddings[-1] would drop a
    # dimension and make the concatenation fail.
    last_row = embeddings[-1:]
    padding = last_row.expand(target_count - row_count, *last_row.shape[1:])
    return torch.cat([embeddings, padding])


def patch_embedding_loading():
    """Replace ``ToolRAGModel.load_tool_desc_embedding`` with a local loader.

    The replacement loads ``CONFIG["embedding_filename"]`` through
    ``safe_load_embeddings``. It then truncates or pads the embedding rows so
    their count matches the number of tools the ToolUniverse exposes.

    Raises:
        Exception: re-raised (after logging) if ``txagent.toolrag`` cannot be
            imported or the patch cannot be applied.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse):
            """Load tool-description embeddings and align them to the tool count.

            Returns:
                True on success, False on any failure (failures are logged).
            """
            try:
                # NOTE(review): relative path, resolved against the process CWD.
                embedding_path = CONFIG["embedding_filename"]
                if not os.path.exists(embedding_path):
                    logger.error(f"Embedding file not found: {embedding_path}")
                    return False

                self.tool_desc_embedding = safe_load_embeddings(embedding_path)
                if self.tool_desc_embedding is None:
                    # The loader already logged the cause; stop before len(None).
                    return False

                if hasattr(tooluniverse, 'get_all_tools'):
                    tools = tooluniverse.get_all_tools()
                elif hasattr(tooluniverse, 'tools'):
                    tools = tooluniverse.tools
                else:
                    logger.error("No method found to access tools from ToolUniverse")
                    return False

                current_count = len(tools)
                embedding_count = len(self.tool_desc_embedding)

                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
                    if embedding_count == 0:
                        # There is no last row to repeat as padding.
                        logger.error("Cannot pad an empty embedding tensor")
                        return False
                    self.tool_desc_embedding = _align_embedding_rows(self.tool_desc_embedding, current_count)
                    if current_count < embedding_count:
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        logger.info(f"Padded embeddings to match {current_count} tools")

                return True

            except Exception as e:
                # logger.exception keeps the traceback, which the bare message lost.
                logger.exception(f"Failed to load embeddings: {e}")
                return False

        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {e}")
        raise
|
|
|
def prepare_tool_files():
    """Generate ``CONFIG["tool_files"]["new_tool"]`` from ToolUniverse if it is missing.

    This is best-effort: any failure is logged and swallowed so start-up can
    continue. The JSON is written to a sibling ``.tmp`` file and then renamed
    into place. A failed dump (for example, a tool object that ``json`` cannot
    serialize) therefore never leaves a truncated output file. Such a file
    would make the existence check below skip regeneration on every later run.
    """
    output_path = CONFIG["tool_files"]["new_tool"]
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    if os.path.exists(output_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tmp_path = output_path + ".tmp"
    try:
        tu = ToolUniverse()
        if hasattr(tu, 'get_all_tools'):
            tools = tu.get_all_tools()
        elif hasattr(tu, 'tools'):
            tools = tu.tools
        else:
            tools = []
            logger.error("Could not access tools from ToolUniverse")

        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(tools, f, indent=2)
        # Same directory, so a single rename swaps in the complete file.
        os.replace(tmp_path, output_path)
        logger.info(f"Saved {len(tools)} tools to {output_path}")

    except Exception as e:
        logger.error(f"Failed to prepare tool files: {e}")
        # Remove any partial temp file; failing to do so is not fatal.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
|
|
|
def create_agent():
    """Build and initialise the TxAgent used by the app.

    Patches embedding loading and prepares the local tool file first. These
    steps run outside the ``try``, so their errors are not reported as
    agent-creation failures. Then constructs the agent and calls
    ``init_model()``.

    Returns:
        The initialised ``TxAgent`` instance.

    Raises:
        Exception: anything raised by agent construction or ``init_model()``,
            re-raised after logging.
    """
    patch_embedding_loading()
    prepare_tool_files()

    try:
        agent_options = dict(
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=['DirectResponse', 'RequireClarification'],
        )
        built_agent = TxAgent(CONFIG["model_name"], CONFIG["rag_model_name"], **agent_options)
        built_agent.init_model()
    except Exception as exc:
        logger.error(f"Failed to create agent: {exc}")
        raise
    return built_agent
|
|
|
def _as_chat_message(item):
    """Convert one element of a yielded history snapshot into a Chatbot message dict.

    Dicts pass through unchanged. Objects exposing ``role`` and ``content``
    attributes are converted, keeping a non-empty ``metadata`` attribute.
    Anything else becomes an assistant message holding ``str(item)``.
    """
    if isinstance(item, dict):
        return item
    role = getattr(item, "role", None)
    content = getattr(item, "content", None)
    if role is None or content is None:
        return {"role": "assistant", "content": str(item)}
    message = {"role": role, "content": content}
    metadata = getattr(item, "metadata", None)
    if metadata:
        message["metadata"] = metadata
    return message


def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Handle one chat turn: validate *msg*, run the agent, return the new history.

    Args:
        msg: The user's message text.
        chat_history: Prior Chatbot history as a list of role/content dicts.
        temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round:
            Forwarded unchanged to ``agent.run_gradio_chat``.

    Returns:
        A new history list; *chat_history* itself is not mutated.
        - If *msg* is empty or has 10 or fewer characters after stripping,
          only a validation reply is appended.
        - Otherwise the user message is appended, followed by the agent's
          reply, or an ``"Error: ..."`` message if the agent raised.
    """
    prior_history = list(chat_history or [])
    if not msg or len(msg.strip()) <= 10:
        return prior_history + [{"role": "assistant", "content": "Please provide a valid message with a string longer than 10 characters."}]

    updated_history = prior_history + [{"role": "user", "content": msg}]
    # Previously print()ed to stdout; user queries may be sensitive, so keep at DEBUG.
    logger.debug("User message: %r", msg)
    logger.debug("Chat history: %r", updated_history)

    # Give the agent its own copy so it cannot mutate our lists.
    agent_history = list(prior_history)
    try:
        # NOTE(review): argument order follows upstream TxAgent's
        # run_gradio_chat(message, history, temperature, max_new_tokens,
        # max_token, call_agent, conversation, max_round). The old call omitted
        # `message`, shifting every argument. Confirm against the installed
        # txagent version.
        response_generator = agent.run_gradio_chat(
            msg,
            agent_history,
            temperature,
            max_new_tokens,
            max_tokens,
            multi_agent,
            conversation,
            max_round,
        )

        # List chunks are treated as cumulative snapshots (the last one wins).
        # Dict/str chunks are concatenated as text, as before.
        snapshot = None
        streamed_text = ""
        for chunk in response_generator:
            if isinstance(chunk, list):
                snapshot = chunk
            elif isinstance(chunk, dict):
                streamed_text += chunk.get("content") or ""
            else:
                streamed_text += str(chunk)

        if snapshot is None:
            reply = [{"role": "assistant", "content": streamed_text}]
        else:
            if snapshot is agent_history:
                # The generator appended to the list we passed; keep only the new turn.
                snapshot = snapshot[len(prior_history):]
            reply = [_as_chat_message(item) for item in snapshot]
            if streamed_text:
                reply.append({"role": "assistant", "content": streamed_text})

    except Exception as e:
        logger.exception("Agent failed while generating a response")
        updated_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
    else:
        updated_history.extend(reply)

    return updated_history
|
|
|
def create_demo(agent):
    """Build (but do not launch) the Gradio Blocks UI.

    Args:
        agent: The TxAgent instance.
            NOTE(review): unused here; ``respond`` reads the module-level
            ``agent`` set by ``main()``.

    Returns:
        The ``gr.Blocks`` app.
    """
    with gr.Blocks(css=chat_css) as demo:
        # respond() consumes and returns history as role/content dicts, so the
        # Chatbot must use the "messages" format, not the legacy tuples format.
        chatbot = gr.Chatbot(label="TxAgent", render_markdown=True, type="messages")
        msg = gr.Textbox(label="Your question", placeholder="Type your biomedical query...", scale=6)

        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")

        with gr.Row():
            submit = gr.Button("Ask TxAgent")

        # Created inside the Blocks context so it is registered with the app.
        conversation = gr.State([])
        inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation, max_rounds]

        submit.click(respond, inputs=inputs, outputs=[chatbot])
        # Pressing Enter in the textbox triggers the same handler as the button.
        msg.submit(respond, inputs=inputs, outputs=[chatbot])

    return demo
|
|
|
def main():
    """Create the agent, build the UI and launch it; log and re-raise on failure."""
    # respond() looks up ``agent`` at module scope, so bind it there.
    global agent
    try:
        agent = create_agent()
        app = create_demo(agent)
        app.launch()
    except Exception as err:
        logger.error(f"Application failed to start: {err}")
        raise
|
|
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|