Update app.py
Browse files
app.py
CHANGED
|
@@ -2,10 +2,10 @@ import os
|
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
import torch
|
| 5 |
-
from txagent import TxAgent
|
| 6 |
import gradio as gr
|
| 7 |
from tooluniverse import ToolUniverse
|
| 8 |
import warnings
|
|
|
|
| 9 |
|
| 10 |
# Suppress specific warnings
|
| 11 |
warnings.filterwarnings("ignore", category=UserWarning)
|
|
@@ -28,6 +28,7 @@ logging.basicConfig(
|
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
def prepare_tool_files():
|
|
|
|
| 31 |
os.makedirs("./data", exist_ok=True)
|
| 32 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
| 33 |
logger.info("Generating tool list using ToolUniverse...")
|
|
@@ -37,7 +38,7 @@ def prepare_tool_files():
|
|
| 37 |
json.dump(tools, f, indent=2)
|
| 38 |
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
| 39 |
|
| 40 |
-
def safe_load_embeddings(filepath):
|
| 41 |
"""Safely load embeddings with proper weights_only handling"""
|
| 42 |
try:
|
| 43 |
# First try with weights_only=True (secure mode)
|
|
@@ -54,7 +55,7 @@ def patch_embedding_loading():
|
|
| 54 |
|
| 55 |
original_load = ToolRAGModel.load_tool_desc_embedding
|
| 56 |
|
| 57 |
-
def patched_load(self, tooluniverse):
|
| 58 |
try:
|
| 59 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
| 60 |
logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
|
|
@@ -64,7 +65,7 @@ def patch_embedding_loading():
|
|
| 64 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
| 65 |
|
| 66 |
# Handle tool count mismatch
|
| 67 |
-
tools = tooluniverse.
|
| 68 |
current_count = len(tools)
|
| 69 |
embedding_count = len(self.tool_desc_embedding)
|
| 70 |
|
|
@@ -101,7 +102,8 @@ class TxAgentApp:
|
|
| 101 |
self.agent = None
|
| 102 |
self.is_initialized = False
|
| 103 |
|
| 104 |
-
def initialize(self):
|
|
|
|
| 105 |
if self.is_initialized:
|
| 106 |
return "✅ Already initialized"
|
| 107 |
|
|
@@ -111,8 +113,8 @@ class TxAgentApp:
|
|
| 111 |
|
| 112 |
logger.info("Initializing TxAgent...")
|
| 113 |
self.agent = TxAgent(
|
| 114 |
-
CONFIG["model_name"],
|
| 115 |
-
CONFIG["rag_model_name"],
|
| 116 |
tool_files_dict=CONFIG["tool_files"],
|
| 117 |
force_finish=True,
|
| 118 |
enable_checker=True,
|
|
@@ -131,36 +133,51 @@ class TxAgentApp:
|
|
| 131 |
logger.error(f"Initialization failed: {str(e)}")
|
| 132 |
return f"❌ Initialization failed: {str(e)}"
|
| 133 |
|
| 134 |
-
def chat(self, message, history):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
if not self.is_initialized:
|
| 136 |
-
return
|
| 137 |
|
| 138 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
response = ""
|
| 140 |
-
# Modified to use the correct parameter name (max_length instead of max_tokens)
|
| 141 |
for chunk in self.agent.run_gradio_chat(
|
| 142 |
message=message,
|
| 143 |
-
history=
|
| 144 |
temperature=0.3,
|
| 145 |
max_new_tokens=1024,
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
conversation=
|
| 149 |
max_round=30
|
| 150 |
):
|
| 151 |
-
response
|
| 152 |
|
| 153 |
-
# Format response
|
| 154 |
-
return [
|
| 155 |
-
{"role": "user", "content": message},
|
| 156 |
-
{"role": "assistant", "content": response}
|
| 157 |
-
]
|
| 158 |
|
| 159 |
except Exception as e:
|
| 160 |
logger.error(f"Chat error: {str(e)}")
|
| 161 |
-
return
|
| 162 |
|
| 163 |
-
def create_interface():
|
|
|
|
| 164 |
app = TxAgentApp()
|
| 165 |
|
| 166 |
with gr.Blocks(
|
|
@@ -195,25 +212,23 @@ def create_interface():
|
|
| 195 |
inputs=msg
|
| 196 |
)
|
| 197 |
|
| 198 |
-
def wrapper_initialize():
|
|
|
|
| 199 |
status = app.initialize()
|
| 200 |
return status, gr.update(interactive=False)
|
| 201 |
|
| 202 |
-
def wrapper_chat(message, chat_history):
|
| 203 |
-
response = app.chat(message, chat_history)
|
| 204 |
-
if isinstance(response, dict): # Error case
|
| 205 |
-
return chat_history + [response]
|
| 206 |
-
return response # Normal case
|
| 207 |
-
|
| 208 |
init_btn.click(
|
| 209 |
fn=wrapper_initialize,
|
| 210 |
outputs=[init_status, init_btn]
|
| 211 |
)
|
| 212 |
|
| 213 |
msg.submit(
|
| 214 |
-
fn=
|
| 215 |
inputs=[msg, chatbot],
|
| 216 |
outputs=chatbot
|
|
|
|
|
|
|
|
|
|
| 217 |
)
|
| 218 |
|
| 219 |
clear_btn.click(
|
|
|
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
import torch
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
from tooluniverse import ToolUniverse
|
| 7 |
import warnings
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
|
| 10 |
# Suppress specific warnings
|
| 11 |
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
def prepare_tool_files():
|
| 31 |
+
"""Ensure tool files exist and are populated"""
|
| 32 |
os.makedirs("./data", exist_ok=True)
|
| 33 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
| 34 |
logger.info("Generating tool list using ToolUniverse...")
|
|
|
|
| 38 |
json.dump(tools, f, indent=2)
|
| 39 |
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
| 40 |
|
| 41 |
+
def safe_load_embeddings(filepath: str) -> Any:
|
| 42 |
"""Safely load embeddings with proper weights_only handling"""
|
| 43 |
try:
|
| 44 |
# First try with weights_only=True (secure mode)
|
|
|
|
| 55 |
|
| 56 |
original_load = ToolRAGModel.load_tool_desc_embedding
|
| 57 |
|
| 58 |
+
def patched_load(self, tooluniverse: ToolUniverse) -> bool:
|
| 59 |
try:
|
| 60 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
| 61 |
logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
|
|
|
|
| 65 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
| 66 |
|
| 67 |
# Handle tool count mismatch
|
| 68 |
+
tools = tooluniverse.get_all_tools() # Use get_all_tools() instead of direct access
|
| 69 |
current_count = len(tools)
|
| 70 |
embedding_count = len(self.tool_desc_embedding)
|
| 71 |
|
|
|
|
| 102 |
self.agent = None
|
| 103 |
self.is_initialized = False
|
| 104 |
|
| 105 |
+
def initialize(self) -> str:
|
| 106 |
+
"""Initialize the TxAgent with all required components"""
|
| 107 |
if self.is_initialized:
|
| 108 |
return "✅ Already initialized"
|
| 109 |
|
|
|
|
| 113 |
|
| 114 |
logger.info("Initializing TxAgent...")
|
| 115 |
self.agent = TxAgent(
|
| 116 |
+
model_name=CONFIG["model_name"],
|
| 117 |
+
rag_model_name=CONFIG["rag_model_name"],
|
| 118 |
tool_files_dict=CONFIG["tool_files"],
|
| 119 |
force_finish=True,
|
| 120 |
enable_checker=True,
|
|
|
|
| 133 |
logger.error(f"Initialization failed: {str(e)}")
|
| 134 |
return f"❌ Initialization failed: {str(e)}"
|
| 135 |
|
| 136 |
+
def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
|
| 137 |
+
"""
|
| 138 |
+
Handle chat interactions with the TxAgent
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
message: User input message
|
| 142 |
+
history: Chat history in format [[user_msg, bot_msg], ...]
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Updated chat history
|
| 146 |
+
"""
|
| 147 |
if not self.is_initialized:
|
| 148 |
+
return history + [["", "⚠️ Please initialize the model first"]]
|
| 149 |
|
| 150 |
try:
|
| 151 |
+
# Convert history to the format TxAgent expects
|
| 152 |
+
tx_history = []
|
| 153 |
+
for user_msg, bot_msg in history:
|
| 154 |
+
tx_history.append({"role": "user", "content": user_msg})
|
| 155 |
+
if bot_msg: # Only add bot response if it exists
|
| 156 |
+
tx_history.append({"role": "assistant", "content": bot_msg})
|
| 157 |
+
|
| 158 |
+
# Generate response
|
| 159 |
response = ""
|
|
|
|
| 160 |
for chunk in self.agent.run_gradio_chat(
|
| 161 |
message=message,
|
| 162 |
+
history=tx_history,
|
| 163 |
temperature=0.3,
|
| 164 |
max_new_tokens=1024,
|
| 165 |
+
max_token=8192, # Note: Using max_token instead of max_length
|
| 166 |
+
call_agent=False,
|
| 167 |
+
conversation=None,
|
| 168 |
max_round=30
|
| 169 |
):
|
| 170 |
+
response = chunk # Get the final response
|
| 171 |
|
| 172 |
+
# Format response for Gradio Chatbot
|
| 173 |
+
return history + [[message, response]]
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
except Exception as e:
|
| 176 |
logger.error(f"Chat error: {str(e)}")
|
| 177 |
+
return history + [["", f"Error: {str(e)}"]]
|
| 178 |
|
| 179 |
+
def create_interface() -> gr.Blocks:
|
| 180 |
+
"""Create the Gradio interface"""
|
| 181 |
app = TxAgentApp()
|
| 182 |
|
| 183 |
with gr.Blocks(
|
|
|
|
| 212 |
inputs=msg
|
| 213 |
)
|
| 214 |
|
| 215 |
+
def wrapper_initialize() -> tuple:
|
| 216 |
+
"""Wrapper for initialization with UI updates"""
|
| 217 |
status = app.initialize()
|
| 218 |
return status, gr.update(interactive=False)
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
init_btn.click(
|
| 221 |
fn=wrapper_initialize,
|
| 222 |
outputs=[init_status, init_btn]
|
| 223 |
)
|
| 224 |
|
| 225 |
msg.submit(
|
| 226 |
+
fn=app.chat,
|
| 227 |
inputs=[msg, chatbot],
|
| 228 |
outputs=chatbot
|
| 229 |
+
).then(
|
| 230 |
+
lambda: "", # Clear message box
|
| 231 |
+
outputs=msg
|
| 232 |
)
|
| 233 |
|
| 234 |
clear_btn.click(
|