Ali2206 commited on
Commit
a893249
·
verified ·
1 Parent(s): 57027dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -30,6 +30,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
33
  "tool_files": {
34
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
35
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
@@ -39,6 +40,16 @@ CONFIG = {
39
  }
40
  }
41
 
 
 
 
 
 
 
 
 
 
 
42
  def prepare_tool_files():
43
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
44
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
@@ -63,6 +74,10 @@ def create_agent():
63
  seed=42,
64
  additional_default_tools=["DirectResponse", "RequireClarification"]
65
  )
 
 
 
 
66
  agent.init_model()
67
  return agent
68
  except Exception as e:
@@ -130,8 +145,7 @@ def main():
130
  global agent
131
  agent = create_agent()
132
  demo = create_demo(agent)
133
- print("Exiting after embedding generation. Please restart the Space manually.")
134
- exit()
135
 
136
  if __name__ == "__main__":
137
  main()
 
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
33
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
34
  "tool_files": {
35
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
36
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
 
40
  }
41
  }
42
 
43
+ def generate_tool_embeddings(agent):
44
+ tu = ToolUniverse(tool_files=CONFIG["tool_files"])
45
+ tu.load_tools()
46
+ embedding_tensor = agent.rag_model.generate_tool_desc_embedding(tu)
47
+ if embedding_tensor is not None:
48
+ torch.save(embedding_tensor, CONFIG["embedding_filename"])
49
+ logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
50
+ else:
51
+ logger.warning("Embedding generation returned None")
52
+
53
  def prepare_tool_files():
54
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
55
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
 
74
  seed=42,
75
  additional_default_tools=["DirectResponse", "RequireClarification"]
76
  )
77
+ if not os.path.exists(CONFIG["embedding_filename"]):
78
+ generate_tool_embeddings(agent)
79
+ else:
80
+ logger.info("Embedding file found. Skipping embedding generation.")
81
  agent.init_model()
82
  return agent
83
  except Exception as e:
 
145
  global agent
146
  agent = create_agent()
147
  demo = create_demo(agent)
148
+ demo.launch(share=True)
 
149
 
150
  if __name__ == "__main__":
151
  main()