Ali2206 committed
Commit 206cae1 · verified · 1 Parent(s): 80b0f9f

Create toolrag.py

Files changed (1):
  1. src/txagent/toolrag.py +67 -0
src/txagent/toolrag.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import json
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from .utils import get_md5
+
+
+ class ToolRAGModel:
+     def __init__(self, rag_model_name):
+         self.rag_model_name = rag_model_name
+         self.rag_model = None
+         self.tool_desc_embedding = None
+         self.tool_name = None
+         self.tool_embedding_path = None
+         self.load_rag_model()
+
+     def load_rag_model(self):
+         self.rag_model = SentenceTransformer(self.rag_model_name)
+         self.rag_model.max_seq_length = 4096
+         self.rag_model.tokenizer.padding_side = "right"
+
+     def load_tool_desc_embedding(self, toolbox):
+         self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
+         all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
+         md5_value = get_md5(str(all_tools_str))
+         print("Computed MD5 for tool embedding:", md5_value)
+
+         self.tool_embedding_path = os.path.join(
+             os.path.dirname(__file__),
+             self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt"
+         )
+
+         if os.path.exists(self.tool_embedding_path):
+             try:
+                 self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
+                 assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
+                     "Tool count mismatch with loaded embeddings."
+                 print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
+                 return
+             except Exception as e:
+                 print(f"⚠️ Failed loading cached embeddings: {e}")
+                 self.tool_desc_embedding = None
+
+         print("\033[93mGenerating new tool_desc_embedding...\033[0m")
+         self.tool_desc_embedding = self.rag_model.encode(
+             all_tools_str, prompt="", normalize_embeddings=True
+         )
+
+         torch.save(self.tool_desc_embedding, self.tool_embedding_path)
+         print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
+
+     def rag_infer(self, query, top_k=5):
+         torch.cuda.empty_cache()
+         queries = [query]
+         query_embeddings = self.rag_model.encode(
+             queries, prompt="", normalize_embeddings=True
+         )
+         if self.tool_desc_embedding is None:
+             raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")
+
+         scores = self.rag_model.similarity(
+             query_embeddings, self.tool_desc_embedding
+         )
+         top_k = min(top_k, len(self.tool_name))
+         top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
+         top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
+         return top_k_tool_names
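
For context, a minimal usage sketch of the new class. The checkpoint name and query below are illustrative, and the toolbox object is assumed to expose refresh_tool_name_desc(), prepare_tool_prompts(), and all_tools, matching the interface the class calls above:

    from txagent.toolrag import ToolRAGModel

    # Illustrative checkpoint: any SentenceTransformer-compatible model name works here.
    rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")

    # Embeds every tool description once, caching the result next to the module
    # as <model>_tool_embedding_<md5>.pt so later runs load from disk.
    rag.load_tool_desc_embedding(toolbox)

    # Returns the names of the top_k tools whose descriptions best match the query.
    tools = rag.rag_infer("find drug interactions for warfarin", top_k=5)
    print(tools)

Note that rag_infer raises a RuntimeError if load_tool_desc_embedding() has not been called first, and the MD5 of the serialized toolbox keys the cache, so any change to the tool set triggers re-embedding.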