Ali2206 commited on
Commit
58353ee
·
verified ·
1 Parent(s): 869805b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -4,27 +4,28 @@ import logging
4
  import torch
5
  from txagent import TxAgent
6
  import gradio as gr
7
- from huggingface_hub import hf_hub_download, snapshot_download
8
  from tooluniverse import ToolUniverse
9
 
10
  # Configuration
11
  CONFIG = {
12
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
13
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
14
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
15
  "local_dir": "./models",
16
  "tool_files": {
17
  "new_tool": "./data/new_tool.json"
18
  }
19
  }
20
 
21
- # Logging
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
  def prepare_tool_files():
26
  os.makedirs("./data", exist_ok=True)
27
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
 
28
  tu = ToolUniverse()
29
  tools = tu.get_all_tools()
30
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
@@ -33,13 +34,14 @@ def prepare_tool_files():
33
 
34
  def download_model_files():
35
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
36
- logger.info("Downloading model files...")
37
 
38
  snapshot_download(
39
  repo_id=CONFIG["model_name"],
40
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
41
  resume_download=True
42
  )
 
43
  snapshot_download(
44
  repo_id=CONFIG["rag_model_name"],
45
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
@@ -47,24 +49,23 @@ def download_model_files():
47
  )
48
 
49
  def generate_embeddings(agent):
50
- """Generates and assigns embeddings manually"""
51
- path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
52
- try:
53
- if os.path.exists(path):
54
- logger.info("Embedding file already exists, loading...")
55
- agent.rag_model.tool_desc_embedding = torch.load(path)
56
- return
57
 
58
- logger.info("Generating embeddings and saving...")
 
59
  tools = agent.tooluniverse.get_all_tools()
60
- descs = [t["description"] for t in tools]
61
- embeddings = agent.rag_model.generate_embeddings(descs)
62
- torch.save(embeddings, path)
63
  agent.rag_model.tool_desc_embedding = embeddings
64
- logger.info("Embeddings generated and saved")
65
-
66
  except Exception as e:
67
- logger.error(f"Failed to generate embeddings: {e}")
68
  raise
69
 
70
  class TxAgentApp:
@@ -111,6 +112,7 @@ class TxAgentApp:
111
  max_round=30
112
  ):
113
  response += chunk
 
114
  return history + [(message, response)]
115
  except Exception as e:
116
  return history + [(message, f"Error: {str(e)}")]
@@ -147,4 +149,4 @@ if __name__ == "__main__":
147
  prepare_tool_files()
148
  download_model_files()
149
  interface = create_interface()
150
- interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
4
  import torch
5
  from txagent import TxAgent
6
  import gradio as gr
7
+ from huggingface_hub import snapshot_download
8
  from tooluniverse import ToolUniverse
9
 
10
  # Configuration
11
  CONFIG = {
12
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
13
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
14
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
15
  "local_dir": "./models",
16
  "tool_files": {
17
  "new_tool": "./data/new_tool.json"
18
  }
19
  }
20
 
21
+ # Logging setup
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
  def prepare_tool_files():
26
  os.makedirs("./data", exist_ok=True)
27
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
28
+ logger.info("Generating tool list using ToolUniverse...")
29
  tu = ToolUniverse()
30
  tools = tu.get_all_tools()
31
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
 
34
 
35
  def download_model_files():
36
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
37
+ print("Downloading model files...")
38
 
39
  snapshot_download(
40
  repo_id=CONFIG["model_name"],
41
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
42
  resume_download=True
43
  )
44
+
45
  snapshot_download(
46
  repo_id=CONFIG["rag_model_name"],
47
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
 
49
  )
50
 
51
  def generate_embeddings(agent):
52
+ embedding_path = CONFIG["embedding_filename"]
53
+
54
+ if os.path.exists(embedding_path):
55
+ print("Embeddings file already exists — loading...")
56
+ agent.rag_model.tool_desc_embedding = torch.load(embedding_path)
57
+ return
 
58
 
59
+ print("Generating missing tool embeddings...")
60
+ try:
61
  tools = agent.tooluniverse.get_all_tools()
62
+ descriptions = [tool["description"] for tool in tools]
63
+ embeddings = agent.rag_model.generate_embeddings(descriptions)
64
+ torch.save(embeddings, embedding_path)
65
  agent.rag_model.tool_desc_embedding = embeddings
66
+ print(f"Embeddings saved to {embedding_path}")
 
67
  except Exception as e:
68
+ print(f"Failed to generate embeddings: {e}")
69
  raise
70
 
71
  class TxAgentApp:
 
112
  max_round=30
113
  ):
114
  response += chunk
115
+
116
  return history + [(message, response)]
117
  except Exception as e:
118
  return history + [(message, f"Error: {str(e)}")]
 
149
  prepare_tool_files()
150
  download_model_files()
151
  interface = create_interface()
152
+ interface.launch(server_name="0.0.0.0", server_port=7860, share=False)