Ali2206 committed on
Commit
1a87180
·
verified ·
1 Parent(s): 1ac55b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -48
app.py CHANGED
@@ -9,11 +9,12 @@ from importlib.resources import files
9
  from txagent import TxAgent
10
  from tooluniverse import ToolUniverse
11
 
12
- # Patch PyTorch to allow loading old numpy pickles
13
  torch.serialization.add_safe_globals([
14
  numpy.core.multiarray._reconstruct,
15
  numpy.ndarray,
16
- numpy.dtype
 
17
  ])
18
 
19
  logging.basicConfig(
@@ -29,7 +30,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
29
  CONFIG = {
30
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
31
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
32
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
33
  "tool_files": {
34
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
35
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
@@ -39,42 +40,15 @@ CONFIG = {
39
  }
40
  }
41
 
42
- def safe_load_embeddings(filepath):
43
- try:
44
- return torch.load(filepath, weights_only=True)
45
- except Exception as e:
46
- logger.warning(f"Retrying with weights_only=False due to: {e}")
47
- try:
48
- return torch.load(filepath, weights_only=False)
49
- except Exception as e:
50
- logger.error(f"Failed to load embeddings: {e}")
51
- return None
52
-
53
- def patch_embedding_loading():
54
- from txagent.toolrag import ToolRAGModel
55
- def patched_load(self, tooluniverse):
56
- try:
57
- if not os.path.exists(CONFIG["embedding_filename"]):
58
- return False
59
- self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
60
- if self.tool_desc_embedding is None:
61
- logger.error("Tool embedding file could not be loaded.")
62
- return False
63
-
64
- tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
65
- if len(tools) != len(self.tool_desc_embedding):
66
- logger.warning("Tool count mismatch.")
67
- if len(self.tool_desc_embedding) > len(tools):
68
- self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
69
- else:
70
- padding = self.tool_desc_embedding[-1].unsqueeze(0).repeat(len(tools) - len(self.tool_desc_embedding), 1)
71
- self.tool_desc_embedding = torch.cat([self.tool_desc_embedding, padding], dim=0)
72
- return True
73
- except Exception as e:
74
- logger.error(f"Embedding load failed: {e}")
75
- return False
76
-
77
- ToolRAGModel.load_tool_desc_embedding = patched_load
78
 
79
  def prepare_tool_files():
80
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
@@ -88,14 +62,8 @@ def prepare_tool_files():
88
  logger.error(f"Tool generation failed: {e}")
89
 
90
  def create_agent():
91
- patch_embedding_loading()
92
  prepare_tool_files()
93
  try:
94
- tu = ToolUniverse()
95
- tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
96
- available_tool_names = [t["name"] for t in tools]
97
- additional_default_tools = [t for t in ["DirectResponse", "RequireClarification"] if t in available_tool_names]
98
-
99
  agent = TxAgent(
100
  CONFIG["model_name"],
101
  CONFIG["rag_model_name"],
@@ -104,8 +72,10 @@ def create_agent():
104
  enable_checker=True,
105
  step_rag_num=10,
106
  seed=42,
107
- additional_default_tools=additional_default_tools
108
  )
 
 
109
  agent.init_model()
110
  return agent
111
  except Exception as e:
@@ -114,8 +84,7 @@ def create_agent():
114
 
115
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
116
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
117
- chat_history.append({"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."})
118
- return chat_history
119
 
120
  message = msg.strip()
121
  chat_history.append({"role": "user", "content": message})
 
9
  from txagent import TxAgent
10
  from tooluniverse import ToolUniverse
11
 
12
+ # Allow loading old numpy types with torch.load
13
  torch.serialization.add_safe_globals([
14
  numpy.core.multiarray._reconstruct,
15
  numpy.ndarray,
16
+ numpy.dtype,
17
+ numpy.dtypes.Float32DType
18
  ])
19
 
20
  logging.basicConfig(
 
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
33
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
34
  "tool_files": {
35
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
36
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
 
40
  }
41
  }
42
 
43
+ def generate_tool_embeddings(agent):
44
+ tu = ToolUniverse(tool_files=CONFIG["tool_files"])
45
+ tu.load_tools()
46
+ embedding_tensor = agent.rag_model.generate_tool_desc_embedding(tu)
47
+ if embedding_tensor is not None:
48
+ torch.save(embedding_tensor, CONFIG["embedding_filename"])
49
+ logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
50
+ else:
51
+ logger.warning("Embedding generation returned None")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def prepare_tool_files():
54
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
 
62
  logger.error(f"Tool generation failed: {e}")
63
 
64
  def create_agent():
 
65
  prepare_tool_files()
66
  try:
 
 
 
 
 
67
  agent = TxAgent(
68
  CONFIG["model_name"],
69
  CONFIG["rag_model_name"],
 
72
  enable_checker=True,
73
  step_rag_num=10,
74
  seed=42,
75
+ additional_default_tools=["DirectResponse", "RequireClarification"]
76
  )
77
+ if not os.path.exists(CONFIG["embedding_filename"]):
78
+ generate_tool_embeddings(agent)
79
  agent.init_model()
80
  return agent
81
  except Exception as e:
 
84
 
85
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
86
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
87
+ return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
 
88
 
89
  message = msg.strip()
90
  chat_history.append({"role": "user", "content": message})