Ali2206 committed on
Commit
f2d6e83
·
verified ·
1 Parent(s): 58c988e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -23
app.py CHANGED
@@ -6,11 +6,11 @@ from txagent import TxAgent
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
 
9
- # Configuration - Using your existing embedding file
10
  CONFIG = {
11
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
12
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
13
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
14
  "tool_files": {
15
  "new_tool": "./data/new_tool.json"
16
  }
@@ -33,21 +33,45 @@ def prepare_tool_files():
33
  json.dump(tools, f, indent=2)
34
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
35
 
36
- def load_embeddings(agent):
37
- embedding_path = CONFIG["embedding_filename"]
38
- try:
39
- if os.path.exists(embedding_path):
40
- logger.info(f"✅ Loading existing embeddings from {embedding_path}")
41
- embeddings = torch.load(embedding_path)
42
- agent.rag_model.tool_desc_embedding = embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return True
44
- else:
45
- logger.error(f"❌ Embedding file not found at {embedding_path}")
46
- logger.info("Please ensure the embedding file is in the root directory")
47
  return False
48
- except Exception as e:
49
- logger.error(f"Failed to load embeddings: {str(e)}")
50
- return False
51
 
52
  class TxAgentApp:
53
  def __init__(self):
@@ -59,9 +83,10 @@ class TxAgentApp:
59
  return "✅ Already initialized"
60
 
61
  try:
62
- logger.info("Initializing TxAgent...")
 
63
 
64
- # Initialize TxAgent
65
  self.agent = TxAgent(
66
  CONFIG["model_name"],
67
  CONFIG["rag_model_name"],
@@ -73,15 +98,9 @@ class TxAgentApp:
73
  additional_default_tools=["DirectResponse", "RequireClarification"]
74
  )
75
 
76
- # Initialize models
77
  logger.info("Loading models...")
78
  self.agent.init_model()
79
 
80
- # Load embeddings
81
- logger.info("Loading embeddings...")
82
- if not load_embeddings(self.agent):
83
- return "❌ Failed to load embeddings - check logs"
84
-
85
  self.is_initialized = True
86
  return "✅ TxAgent initialized successfully"
87
 
@@ -123,6 +142,7 @@ def create_interface():
123
  ) as demo:
124
  gr.Markdown("""
125
  # 🧠 TxAgent: Therapeutic Reasoning AI
 
126
  """)
127
 
128
  with gr.Row():
 
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
 
9
+ # Configuration with hardcoded embedding file
10
  CONFIG = {
11
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
12
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
13
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
14
  "tool_files": {
15
  "new_tool": "./data/new_tool.json"
16
  }
 
33
  json.dump(tools, f, indent=2)
34
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
35
 
36
+ def patch_toolrag_class():
37
+ """Monkey-patch the ToolRAG class to use our embedding file and handle tool count mismatch"""
38
+ from txagent.toolrag import ToolRAG
39
+
40
+ original_load = ToolRAG.load_tool_desc_embedding
41
+
42
+ def patched_load(self, tooluniverse):
43
+ try:
44
+ # Load our specific embedding file
45
+ self.tool_desc_embedding = torch.load(CONFIG["embedding_filename"])
46
+
47
+ # Get current tools and their count
48
+ tools = tooluniverse.get_all_tools()
49
+ current_tool_count = len(tools)
50
+ embedding_count = len(self.tool_desc_embedding)
51
+
52
+ # If counts don't match, truncate or pad as needed
53
+ if current_tool_count != embedding_count:
54
+ logger.warning(f"Tool count mismatch! Tools: {current_tool_count}, Embeddings: {embedding_count}")
55
+
56
+ if current_tool_count < embedding_count:
57
+ # Truncate embeddings to match tool count
58
+ self.tool_desc_embedding = self.tool_desc_embedding[:current_tool_count]
59
+ logger.warning(f"Truncated embeddings to {current_tool_count} vectors")
60
+ else:
61
+ # Pad with zeros (last embedding) if tools > embeddings
62
+ last_embedding = self.tool_desc_embedding[-1]
63
+ padding = [last_embedding] * (current_tool_count - embedding_count)
64
+ self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
65
+ logger.warning(f"Padded embeddings with {current_tool_count - embedding_count} vectors")
66
+
67
  return True
68
+
69
+ except Exception as e:
70
+ logger.error(f"Failed to load embeddings: {str(e)}")
71
  return False
72
+
73
+ # Apply the patch
74
+ ToolRAG.load_tool_desc_embedding = patched_load
75
 
76
  class TxAgentApp:
77
  def __init__(self):
 
83
  return "✅ Already initialized"
84
 
85
  try:
86
+ # Apply our patch before initialization
87
+ patch_toolrag_class()
88
 
89
+ logger.info("Initializing TxAgent...")
90
  self.agent = TxAgent(
91
  CONFIG["model_name"],
92
  CONFIG["rag_model_name"],
 
98
  additional_default_tools=["DirectResponse", "RequireClarification"]
99
  )
100
 
 
101
  logger.info("Loading models...")
102
  self.agent.init_model()
103
 
 
 
 
 
 
104
  self.is_initialized = True
105
  return "✅ TxAgent initialized successfully"
106
 
 
142
  ) as demo:
143
  gr.Markdown("""
144
  # 🧠 TxAgent: Therapeutic Reasoning AI
145
+ ### (Using pre-loaded embeddings)
146
  """)
147
 
148
  with gr.Row():