Ali2206 commited on
Commit
929325a
·
verified ·
1 Parent(s): e10d63b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -30
app.py CHANGED
@@ -2,10 +2,10 @@ import os
2
  import json
3
  import logging
4
  import torch
5
- from txagent import TxAgent
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
  import warnings
 
9
 
10
  # Suppress specific warnings
11
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -28,6 +28,7 @@ logging.basicConfig(
28
  logger = logging.getLogger(__name__)
29
 
30
  def prepare_tool_files():
 
31
  os.makedirs("./data", exist_ok=True)
32
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
33
  logger.info("Generating tool list using ToolUniverse...")
@@ -37,7 +38,7 @@ def prepare_tool_files():
37
  json.dump(tools, f, indent=2)
38
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
39
 
40
- def safe_load_embeddings(filepath):
41
  """Safely load embeddings with proper weights_only handling"""
42
  try:
43
  # First try with weights_only=True (secure mode)
@@ -54,7 +55,7 @@ def patch_embedding_loading():
54
 
55
  original_load = ToolRAGModel.load_tool_desc_embedding
56
 
57
- def patched_load(self, tooluniverse):
58
  try:
59
  if not os.path.exists(CONFIG["embedding_filename"]):
60
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
@@ -64,7 +65,7 @@ def patch_embedding_loading():
64
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
65
 
66
  # Handle tool count mismatch
67
- tools = tooluniverse.tools # Changed from get_all_tools()
68
  current_count = len(tools)
69
  embedding_count = len(self.tool_desc_embedding)
70
 
@@ -101,7 +102,8 @@ class TxAgentApp:
101
  self.agent = None
102
  self.is_initialized = False
103
 
104
- def initialize(self):
 
105
  if self.is_initialized:
106
  return "✅ Already initialized"
107
 
@@ -111,8 +113,8 @@ class TxAgentApp:
111
 
112
  logger.info("Initializing TxAgent...")
113
  self.agent = TxAgent(
114
- CONFIG["model_name"],
115
- CONFIG["rag_model_name"],
116
  tool_files_dict=CONFIG["tool_files"],
117
  force_finish=True,
118
  enable_checker=True,
@@ -131,36 +133,51 @@ class TxAgentApp:
131
  logger.error(f"Initialization failed: {str(e)}")
132
  return f"❌ Initialization failed: {str(e)}"
133
 
134
- def chat(self, message, history):
 
 
 
 
 
 
 
 
 
 
135
  if not self.is_initialized:
136
- return {"role": "assistant", "content": "⚠️ Please initialize the model first"}
137
 
138
  try:
 
 
 
 
 
 
 
 
139
  response = ""
140
- # Modified to use the correct parameter name (max_length instead of max_tokens)
141
  for chunk in self.agent.run_gradio_chat(
142
  message=message,
143
- history=history,
144
  temperature=0.3,
145
  max_new_tokens=1024,
146
- max_length=8192, # Changed from max_tokens
147
- multi_agent=False,
148
- conversation=[],
149
  max_round=30
150
  ):
151
- response += chunk
152
 
153
- # Format response in the expected messages format
154
- return [
155
- {"role": "user", "content": message},
156
- {"role": "assistant", "content": response}
157
- ]
158
 
159
  except Exception as e:
160
  logger.error(f"Chat error: {str(e)}")
161
- return {"role": "assistant", "content": f"Error: {str(e)}"}
162
 
163
- def create_interface():
 
164
  app = TxAgentApp()
165
 
166
  with gr.Blocks(
@@ -195,25 +212,23 @@ def create_interface():
195
  inputs=msg
196
  )
197
 
198
- def wrapper_initialize():
 
199
  status = app.initialize()
200
  return status, gr.update(interactive=False)
201
 
202
- def wrapper_chat(message, chat_history):
203
- response = app.chat(message, chat_history)
204
- if isinstance(response, dict): # Error case
205
- return chat_history + [response]
206
- return response # Normal case
207
-
208
  init_btn.click(
209
  fn=wrapper_initialize,
210
  outputs=[init_status, init_btn]
211
  )
212
 
213
  msg.submit(
214
- fn=wrapper_chat,
215
  inputs=[msg, chatbot],
216
  outputs=chatbot
 
 
 
217
  )
218
 
219
  clear_btn.click(
 
2
  import json
3
  import logging
4
  import torch
 
5
  import gradio as gr
6
  from tooluniverse import ToolUniverse
7
  import warnings
8
+ from typing import List, Dict, Any
9
 
10
  # Suppress specific warnings
11
  warnings.filterwarnings("ignore", category=UserWarning)
 
28
  logger = logging.getLogger(__name__)
29
 
30
  def prepare_tool_files():
31
+ """Ensure tool files exist and are populated"""
32
  os.makedirs("./data", exist_ok=True)
33
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
34
  logger.info("Generating tool list using ToolUniverse...")
 
38
  json.dump(tools, f, indent=2)
39
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
40
 
41
+ def safe_load_embeddings(filepath: str) -> Any:
42
  """Safely load embeddings with proper weights_only handling"""
43
  try:
44
  # First try with weights_only=True (secure mode)
 
55
 
56
  original_load = ToolRAGModel.load_tool_desc_embedding
57
 
58
+ def patched_load(self, tooluniverse: ToolUniverse) -> bool:
59
  try:
60
  if not os.path.exists(CONFIG["embedding_filename"]):
61
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
 
65
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
66
 
67
  # Handle tool count mismatch
68
+ tools = tooluniverse.get_all_tools() # Use get_all_tools() instead of direct access
69
  current_count = len(tools)
70
  embedding_count = len(self.tool_desc_embedding)
71
 
 
102
  self.agent = None
103
  self.is_initialized = False
104
 
105
+ def initialize(self) -> str:
106
+ """Initialize the TxAgent with all required components"""
107
  if self.is_initialized:
108
  return "✅ Already initialized"
109
 
 
113
 
114
  logger.info("Initializing TxAgent...")
115
  self.agent = TxAgent(
116
+ model_name=CONFIG["model_name"],
117
+ rag_model_name=CONFIG["rag_model_name"],
118
  tool_files_dict=CONFIG["tool_files"],
119
  force_finish=True,
120
  enable_checker=True,
 
133
  logger.error(f"Initialization failed: {str(e)}")
134
  return f"❌ Initialization failed: {str(e)}"
135
 
136
+ def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
137
+ """
138
+ Handle chat interactions with the TxAgent
139
+
140
+ Args:
141
+ message: User input message
142
+ history: Chat history in format [[user_msg, bot_msg], ...]
143
+
144
+ Returns:
145
+ Updated chat history
146
+ """
147
  if not self.is_initialized:
148
+ return history + [["", "⚠️ Please initialize the model first"]]
149
 
150
  try:
151
+ # Convert history to the format TxAgent expects
152
+ tx_history = []
153
+ for user_msg, bot_msg in history:
154
+ tx_history.append({"role": "user", "content": user_msg})
155
+ if bot_msg: # Only add bot response if it exists
156
+ tx_history.append({"role": "assistant", "content": bot_msg})
157
+
158
+ # Generate response
159
  response = ""
 
160
  for chunk in self.agent.run_gradio_chat(
161
  message=message,
162
+ history=tx_history,
163
  temperature=0.3,
164
  max_new_tokens=1024,
165
+ max_token=8192, # Note: Using max_token instead of max_length
166
+ call_agent=False,
167
+ conversation=None,
168
  max_round=30
169
  ):
170
+ response = chunk # Get the final response
171
 
172
+ # Format response for Gradio Chatbot
173
+ return history + [[message, response]]
 
 
 
174
 
175
  except Exception as e:
176
  logger.error(f"Chat error: {str(e)}")
177
+ return history + [["", f"Error: {str(e)}"]]
178
 
179
+ def create_interface() -> gr.Blocks:
180
+ """Create the Gradio interface"""
181
  app = TxAgentApp()
182
 
183
  with gr.Blocks(
 
212
  inputs=msg
213
  )
214
 
215
+ def wrapper_initialize() -> tuple:
216
+ """Wrapper for initialization with UI updates"""
217
  status = app.initialize()
218
  return status, gr.update(interactive=False)
219
 
 
 
 
 
 
 
220
  init_btn.click(
221
  fn=wrapper_initialize,
222
  outputs=[init_status, init_btn]
223
  )
224
 
225
  msg.submit(
226
+ fn=app.chat,
227
  inputs=[msg, chatbot],
228
  outputs=chatbot
229
+ ).then(
230
+ lambda: "", # Clear message box
231
+ outputs=msg
232
  )
233
 
234
  clear_btn.click(