Ali2206 commited on
Commit
1ee16da
·
verified ·
1 Parent(s): 694bd3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -38
app.py CHANGED
@@ -33,45 +33,54 @@ def prepare_tool_files():
33
  json.dump(tools, f, indent=2)
34
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
35
 
36
- def patch_toolrag_class():
37
- """Monkey-patch the ToolRAG class to use our embedding file and handle tool count mismatch"""
38
- from txagent.toolrag import ToolRAG
39
-
40
- original_load = ToolRAG.load_tool_desc_embedding
41
-
42
- def patched_load(self, tooluniverse):
43
- try:
44
- # Load our specific embedding file
45
- self.tool_desc_embedding = torch.load(CONFIG["embedding_filename"])
46
-
47
- # Get current tools and their count
48
- tools = tooluniverse.get_all_tools()
49
- current_tool_count = len(tools)
50
- embedding_count = len(self.tool_desc_embedding)
51
 
52
- # If counts don't match, truncate or pad as needed
53
- if current_tool_count != embedding_count:
54
- logger.warning(f"Tool count mismatch! Tools: {current_tool_count}, Embeddings: {embedding_count}")
55
-
56
- if current_tool_count < embedding_count:
57
- # Truncate embeddings to match tool count
58
- self.tool_desc_embedding = self.tool_desc_embedding[:current_tool_count]
59
- logger.warning(f"Truncated embeddings to {current_tool_count} vectors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  else:
61
- # Pad with zeros (last embedding) if tools > embeddings
62
- last_embedding = self.tool_desc_embedding[-1]
63
- padding = [last_embedding] * (current_tool_count - embedding_count)
64
- self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
65
- logger.warning(f"Padded embeddings with {current_tool_count - embedding_count} vectors")
66
 
67
- return True
68
-
69
- except Exception as e:
70
- logger.error(f"Failed to load embeddings: {str(e)}")
71
- return False
72
-
73
- # Apply the patch
74
- ToolRAG.load_tool_desc_embedding = patched_load
 
 
75
 
76
  class TxAgentApp:
77
  def __init__(self):
@@ -84,7 +93,7 @@ class TxAgentApp:
84
 
85
  try:
86
  # Apply our patch before initialization
87
- patch_toolrag_class()
88
 
89
  logger.info("Initializing TxAgent...")
90
  self.agent = TxAgent(
@@ -149,7 +158,11 @@ def create_interface():
149
  init_btn = gr.Button("Initialize Model", variant="primary")
150
  init_status = gr.Textbox(label="Status", interactive=False)
151
 
152
- chatbot = gr.Chatbot(height=500, label="Conversation")
 
 
 
 
153
  msg = gr.Textbox(label="Your clinical question")
154
  clear_btn = gr.Button("Clear Chat")
155
 
 
33
  json.dump(tools, f, indent=2)
34
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
35
 
36
+ def patch_embedding_loading():
37
+ """Monkey-patch the embedding loading functionality"""
38
+ try:
39
+ # Try to get the RAG model class dynamically
40
+ from txagent.txagent import TxAgent as TxAgentClass
41
+ original_init = TxAgentClass.__init__
42
+
43
+ def patched_init(self, *args, **kwargs):
44
+ # First let the original initialization happen
45
+ original_init(self, *args, **kwargs)
 
 
 
 
 
46
 
47
+ # Then handle the embeddings our way
48
+ try:
49
+ if os.path.exists(CONFIG["embedding_filename"]):
50
+ logger.info(f"Loading embeddings from {CONFIG['embedding_filename']}")
51
+ self.rag_model.tool_desc_embedding = torch.load(CONFIG["embedding_filename"])
52
+
53
+ # Handle tool count mismatch
54
+ tools = self.tooluniverse.get_all_tools()
55
+ current_count = len(tools)
56
+ embedding_count = len(self.rag_model.tool_desc_embedding)
57
+
58
+ if current_count != embedding_count:
59
+ logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
60
+
61
+ if current_count < embedding_count:
62
+ self.rag_model.tool_desc_embedding = self.rag_model.tool_desc_embedding[:current_count]
63
+ logger.info(f"Truncated embeddings to match {current_count} tools")
64
+ else:
65
+ last_embedding = self.rag_model.tool_desc_embedding[-1]
66
+ padding = [last_embedding] * (current_count - embedding_count)
67
+ self.rag_model.tool_desc_embedding = torch.cat(
68
+ [self.rag_model.tool_desc_embedding] + padding
69
+ )
70
+ logger.info(f"Padded embeddings to match {current_count} tools")
71
  else:
72
+ logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
 
 
 
 
73
 
74
+ except Exception as e:
75
+ logger.error(f"Failed to load embeddings: {str(e)}")
76
+
77
+ # Apply the patch
78
+ TxAgentClass.__init__ = patched_init
79
+ logger.info("Successfully patched embedding loading")
80
+
81
+ except Exception as e:
82
+ logger.error(f"Failed to patch embedding loading: {str(e)}")
83
+ raise
84
 
85
  class TxAgentApp:
86
  def __init__(self):
 
93
 
94
  try:
95
  # Apply our patch before initialization
96
+ patch_embedding_loading()
97
 
98
  logger.info("Initializing TxAgent...")
99
  self.agent = TxAgent(
 
158
  init_btn = gr.Button("Initialize Model", variant="primary")
159
  init_status = gr.Textbox(label="Status", interactive=False)
160
 
161
+ chatbot = gr.Chatbot(
162
+ height=500,
163
+ label="Conversation",
164
+ type="messages" # Fixing the deprecation warning
165
+ )
166
  msg = gr.Textbox(label="Your clinical question")
167
  clear_btn = gr.Button("Clear Chat")
168