Update app.py
Browse files
app.py
CHANGED
@@ -2,11 +2,16 @@ import os
|
|
2 |
import json
|
3 |
import torch
|
4 |
import logging
|
|
|
5 |
import gradio as gr
|
|
|
6 |
from importlib.resources import files
|
7 |
from txagent import TxAgent
|
8 |
from tooluniverse import ToolUniverse
|
9 |
|
|
|
|
|
|
|
10 |
logging.basicConfig(
|
11 |
level=logging.INFO,
|
12 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
@@ -48,6 +53,9 @@ def patch_embedding_loading():
|
|
48 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
49 |
return False
|
50 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
|
|
|
|
|
|
51 |
|
52 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
53 |
if len(tools) != len(self.tool_desc_embedding):
|
@@ -75,6 +83,11 @@ def create_agent():
|
|
75 |
patch_embedding_loading()
|
76 |
prepare_tool_files()
|
77 |
try:
|
|
|
|
|
|
|
|
|
|
|
78 |
agent = TxAgent(
|
79 |
CONFIG["model_name"],
|
80 |
CONFIG["rag_model_name"],
|
@@ -83,7 +96,7 @@ def create_agent():
|
|
83 |
enable_checker=True,
|
84 |
step_rag_num=10,
|
85 |
seed=42,
|
86 |
-
additional_default_tools=
|
87 |
)
|
88 |
agent.init_model()
|
89 |
return agent
|
@@ -91,7 +104,6 @@ def create_agent():
|
|
91 |
logger.error(f"Agent initialization failed: {e}")
|
92 |
raise
|
93 |
|
94 |
-
# ✅ FIXED: Proper message formatting
|
95 |
def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
|
96 |
if not isinstance(msg, str) or len(msg.strip()) <= 10:
|
97 |
return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
|
@@ -147,4 +159,4 @@ def main():
|
|
147 |
demo.launch(share=False)
|
148 |
|
149 |
if __name__ == "__main__":
|
150 |
-
main()
|
|
|
2 |
import json
|
3 |
import torch
|
4 |
import logging
|
5 |
+
import numpy
|
6 |
import gradio as gr
|
7 |
+
import torch.serialization
|
8 |
from importlib.resources import files
|
9 |
from txagent import TxAgent
|
10 |
from tooluniverse import ToolUniverse
|
11 |
|
12 |
+
# Patch PyTorch to allow loading old numpy pickles
|
13 |
+
torch.serialization.add_safe_globals([numpy.core.multiarray._reconstruct])
|
14 |
+
|
15 |
logging.basicConfig(
|
16 |
level=logging.INFO,
|
17 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
|
53 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
54 |
return False
|
55 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
56 |
+
if self.tool_desc_embedding is None:
|
57 |
+
logger.error("Tool embedding file could not be loaded.")
|
58 |
+
return False
|
59 |
|
60 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
61 |
if len(tools) != len(self.tool_desc_embedding):
|
|
|
83 |
patch_embedding_loading()
|
84 |
prepare_tool_files()
|
85 |
try:
|
86 |
+
tu = ToolUniverse()
|
87 |
+
tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
|
88 |
+
available_tool_names = [t["name"] for t in tools]
|
89 |
+
additional_default_tools = [t for t in ["DirectResponse", "RequireClarification"] if t in available_tool_names]
|
90 |
+
|
91 |
agent = TxAgent(
|
92 |
CONFIG["model_name"],
|
93 |
CONFIG["rag_model_name"],
|
|
|
96 |
enable_checker=True,
|
97 |
step_rag_num=10,
|
98 |
seed=42,
|
99 |
+
additional_default_tools=additional_default_tools
|
100 |
)
|
101 |
agent.init_model()
|
102 |
return agent
|
|
|
104 |
logger.error(f"Agent initialization failed: {e}")
|
105 |
raise
|
106 |
|
|
|
107 |
def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
|
108 |
if not isinstance(msg, str) or len(msg.strip()) <= 10:
|
109 |
return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
|
|
|
159 |
demo.launch(share=False)
|
160 |
|
161 |
if __name__ == "__main__":
|
162 |
+
main()
|