Update app.py
Browse files
app.py
CHANGED
@@ -4,19 +4,16 @@ import logging
|
|
4 |
import torch
|
5 |
from txagent import TxAgent
|
6 |
import gradio as gr
|
7 |
-
from huggingface_hub import hf_hub_download
|
8 |
from tooluniverse import ToolUniverse
|
9 |
-
from tqdm import tqdm
|
10 |
|
11 |
-
# Configuration -
|
12 |
CONFIG = {
|
13 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
14 |
-
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
15 |
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
|
16 |
"tool_files": {
|
17 |
"new_tool": "./data/new_tool.json"
|
18 |
-
}
|
19 |
-
"load_from_hub": True # Flag to load directly from Hugging Face
|
20 |
}
|
21 |
|
22 |
# Logging setup
|
@@ -47,7 +44,7 @@ def load_embeddings(agent):
|
|
47 |
except Exception as e:
|
48 |
logger.error(f"Failed to load embeddings: {e}")
|
49 |
|
50 |
-
logger.info("Generating tool embeddings
|
51 |
try:
|
52 |
tools = agent.tooluniverse.get_all_tools()
|
53 |
descriptions = [tool["description"] for tool in tools]
|
@@ -71,7 +68,7 @@ class TxAgentApp:
|
|
71 |
try:
|
72 |
logger.info("Initializing TxAgent with remote models...")
|
73 |
|
74 |
-
# Initialize
|
75 |
self.agent = TxAgent(
|
76 |
CONFIG["model_name"],
|
77 |
CONFIG["rag_model_name"],
|
@@ -80,11 +77,10 @@ class TxAgentApp:
|
|
80 |
enable_checker=True,
|
81 |
step_rag_num=10,
|
82 |
seed=100,
|
83 |
-
additional_default_tools=["DirectResponse", "RequireClarification"]
|
84 |
-
local_files_only=False # Force loading from Hugging Face Hub
|
85 |
)
|
86 |
|
87 |
-
logger.info("Loading
|
88 |
self.agent.init_model()
|
89 |
|
90 |
logger.info("Preparing embeddings...")
|
|
|
4 |
import torch
|
5 |
from txagent import TxAgent
|
6 |
import gradio as gr
|
|
|
7 |
from tooluniverse import ToolUniverse
|
|
|
8 |
|
9 |
+
# Configuration - Using remote Hugging Face models
|
10 |
CONFIG = {
|
11 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
12 |
+
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
13 |
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
|
14 |
"tool_files": {
|
15 |
"new_tool": "./data/new_tool.json"
|
16 |
+
}
|
|
|
17 |
}
|
18 |
|
19 |
# Logging setup
|
|
|
44 |
except Exception as e:
|
45 |
logger.error(f"Failed to load embeddings: {e}")
|
46 |
|
47 |
+
logger.info("Generating tool embeddings...")
|
48 |
try:
|
49 |
tools = agent.tooluniverse.get_all_tools()
|
50 |
descriptions = [tool["description"] for tool in tools]
|
|
|
68 |
try:
|
69 |
logger.info("Initializing TxAgent with remote models...")
|
70 |
|
71 |
+
# Initialize without local_files_only parameter
|
72 |
self.agent = TxAgent(
|
73 |
CONFIG["model_name"],
|
74 |
CONFIG["rag_model_name"],
|
|
|
77 |
enable_checker=True,
|
78 |
step_rag_num=10,
|
79 |
seed=100,
|
80 |
+
additional_default_tools=["DirectResponse", "RequireClarification"]
|
|
|
81 |
)
|
82 |
|
83 |
+
logger.info("Loading models from Hugging Face Hub...")
|
84 |
self.agent.init_model()
|
85 |
|
86 |
logger.info("Preparing embeddings...")
|