Ali2206 committed on
Commit
037141b
·
verified ·
1 Parent(s): cd50f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -11,73 +11,60 @@ from tooluniverse import ToolUniverse
11
  CONFIG = {
12
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
13
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
14
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
15
  "local_dir": "./models",
16
  "tool_files": {
17
  "new_tool": "./data/new_tool.json"
18
  }
19
  }
20
 
21
- # Logging setup
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
  def prepare_tool_files():
26
  os.makedirs("./data", exist_ok=True)
27
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
28
- logger.info("Generating tool list using ToolUniverse...")
29
  tu = ToolUniverse()
30
  tools = tu.get_all_tools()
31
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
32
  json.dump(tools, f, indent=2)
33
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
34
 
35
-
36
  def download_model_files():
37
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
38
- print("Downloading model files...")
39
 
40
  snapshot_download(
41
  repo_id=CONFIG["model_name"],
42
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
43
  resume_download=True
44
  )
45
-
46
  snapshot_download(
47
  repo_id=CONFIG["rag_model_name"],
48
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
49
  resume_download=True
50
  )
51
 
52
- try:
53
- hf_hub_download(
54
- repo_id=CONFIG["rag_model_name"],
55
- filename=CONFIG["embedding_filename"],
56
- local_dir=CONFIG["local_dir"],
57
- resume_download=True
58
- )
59
- print("Embeddings file downloaded successfully")
60
- except Exception as e:
61
- print(f"Could not download embeddings file: {e}")
62
- print("Will attempt to generate it instead")
63
-
64
  def generate_embeddings(agent):
65
- embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
66
-
67
- if os.path.exists(embedding_path):
68
- print("Embeddings file already exists")
69
- return
70
-
71
- print("Generating missing tool embeddings...")
72
  try:
 
 
 
 
 
 
73
  tools = agent.tooluniverse.get_all_tools()
74
- descriptions = [tool["description"] for tool in tools]
75
- embeddings = agent.rag_model.generate_embeddings(descriptions)
76
- torch.save(embeddings, embedding_path)
77
  agent.rag_model.tool_desc_embedding = embeddings
78
- print(f"Embeddings saved to {embedding_path}")
 
79
  except Exception as e:
80
- print(f"Failed to generate embeddings: {e}")
81
  raise
82
 
83
  class TxAgentApp:
@@ -124,7 +111,6 @@ class TxAgentApp:
124
  max_round=30
125
  ):
126
  response += chunk
127
-
128
  return history + [(message, response)]
129
  except Exception as e:
130
  return history + [(message, f"Error: {str(e)}")]
 
11
  CONFIG = {
12
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
13
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
14
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
15
  "local_dir": "./models",
16
  "tool_files": {
17
  "new_tool": "./data/new_tool.json"
18
  }
19
  }
20
 
21
+ # Logging
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
  def prepare_tool_files():
26
  os.makedirs("./data", exist_ok=True)
27
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
 
28
  tu = ToolUniverse()
29
  tools = tu.get_all_tools()
30
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
31
  json.dump(tools, f, indent=2)
32
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
33
 
 
34
  def download_model_files():
35
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
36
+ logger.info("Downloading model files...")
37
 
38
  snapshot_download(
39
  repo_id=CONFIG["model_name"],
40
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
41
  resume_download=True
42
  )
 
43
  snapshot_download(
44
  repo_id=CONFIG["rag_model_name"],
45
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
46
  resume_download=True
47
  )
48
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def generate_embeddings(agent):
50
+ """Generates and assigns embeddings manually"""
51
+ path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
 
 
 
 
 
52
  try:
53
+ if os.path.exists(path):
54
+ logger.info("Embedding file already exists, loading...")
55
+ agent.rag_model.tool_desc_embedding = torch.load(path)
56
+ return
57
+
58
+ logger.info("Generating embeddings and saving...")
59
  tools = agent.tooluniverse.get_all_tools()
60
+ descs = [t["description"] for t in tools]
61
+ embeddings = agent.rag_model.generate_embeddings(descs)
62
+ torch.save(embeddings, path)
63
  agent.rag_model.tool_desc_embedding = embeddings
64
+ logger.info("Embeddings generated and saved")
65
+
66
  except Exception as e:
67
+ logger.error(f"Failed to generate embeddings: {e}")
68
  raise
69
 
70
  class TxAgentApp:
 
111
  max_round=30
112
  ):
113
  response += chunk
 
114
  return history + [(message, response)]
115
  except Exception as e:
116
  return history + [(message, f"Error: {str(e)}")]