Ali2206 committed on
Commit
57027dc
·
verified ·
1 Parent(s): 398d7f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -21
app.py CHANGED
@@ -30,7 +30,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
33
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
34
  "tool_files": {
35
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
36
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
@@ -40,16 +39,6 @@ CONFIG = {
40
  }
41
  }
42
 
43
- def generate_tool_embeddings(agent):
44
- tu = ToolUniverse(tool_files=CONFIG["tool_files"])
45
- tu.load_tools()
46
- embedding_tensor = agent.rag_model.load_tool_desc_embedding(tu)
47
- if embedding_tensor is not None:
48
- torch.save(embedding_tensor, CONFIG["embedding_filename"])
49
- logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
50
- else:
51
- logger.warning("Embedding generation returned None")
52
-
53
  def prepare_tool_files():
54
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
55
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
@@ -74,8 +63,6 @@ def create_agent():
74
  seed=42,
75
  additional_default_tools=["DirectResponse", "RequireClarification"]
76
  )
77
- if not os.path.exists(CONFIG["embedding_filename"]):
78
- generate_tool_embeddings(agent)
79
  agent.init_model()
80
  return agent
81
  except Exception as e:
@@ -88,7 +75,7 @@ def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_ag
88
 
89
  message = msg.strip()
90
  chat_history.append({"role": "user", "content": message})
91
- formatted_history = chat_history # format as list of dicts for run_gradio_chat
92
 
93
  try:
94
  response_generator = agent.run_gradio_chat(
@@ -101,20 +88,18 @@ def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_ag
101
  conversation=conversation,
102
  max_round=max_round,
103
  seed=42,
104
- call_agent_level=0,
105
  sub_agent_task=None
106
  )
107
 
108
  collected = ""
109
  for chunk in response_generator:
110
- if isinstance(chunk, list):
111
- for msg in chunk:
112
- if isinstance(msg, dict) and "content" in msg:
113
- collected += msg["content"]
114
- elif isinstance(chunk, dict) and "content" in chunk:
115
  collected += chunk["content"]
116
  elif isinstance(chunk, str):
117
  collected += chunk
 
 
118
 
119
  chat_history.append({"role": "assistant", "content": collected or "⚠️ No content returned."})
120
 
@@ -145,7 +130,8 @@ def main():
145
  global agent
146
  agent = create_agent()
147
  demo = create_demo(agent)
148
- demo.launch(share=True)
 
149
 
150
  if __name__ == "__main__":
151
  main()
 
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
33
  "tool_files": {
34
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
35
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
 
39
  }
40
  }
41
 
 
 
 
 
 
 
 
 
 
 
42
  def prepare_tool_files():
43
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
44
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
 
63
  seed=42,
64
  additional_default_tools=["DirectResponse", "RequireClarification"]
65
  )
 
 
66
  agent.init_model()
67
  return agent
68
  except Exception as e:
 
75
 
76
  message = msg.strip()
77
  chat_history.append({"role": "user", "content": message})
78
+ formatted_history = [(m["role"], m["content"]) for m in chat_history if "role" in m and "content" in m]
79
 
80
  try:
81
  response_generator = agent.run_gradio_chat(
 
88
  conversation=conversation,
89
  max_round=max_round,
90
  seed=42,
91
+ call_agent_level=None,
92
  sub_agent_task=None
93
  )
94
 
95
  collected = ""
96
  for chunk in response_generator:
97
+ if isinstance(chunk, dict) and "content" in chunk:
 
 
 
 
98
  collected += chunk["content"]
99
  elif isinstance(chunk, str):
100
  collected += chunk
101
+ elif chunk is not None:
102
+ collected += str(chunk)
103
 
104
  chat_history.append({"role": "assistant", "content": collected or "⚠️ No content returned."})
105
 
 
130
  global agent
131
  agent = create_agent()
132
  demo = create_demo(agent)
133
+ print("Exiting after embedding generation. Please restart the Space manually.")
134
+ exit()
135
 
136
  if __name__ == "__main__":
137
  main()