acumplid committed on
Commit 9803bf8 · 1 Parent(s): 88e6c95

Implemented new rerank

Files changed (1)
  1. rag.py +46 -3
rag.py CHANGED
@@ -1,7 +1,8 @@
 import logging
 import os
 import requests
-
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
 
 
 from langchain_community.vectorstores import FAISS
@@ -15,11 +16,13 @@ class RAG:
     #vectorstore = "vectorestore" # CA only
     vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"
 
-    def __init__(self, hf_token, embeddings_model, model_name):
+    def __init__(self, hf_token, embeddings_model, model_name, rerank_model, rerank_number_contexts):
 
 
         self.model_name = model_name
         self.hf_token = hf_token
+        self.rerank_model = rerank_model
+        self.rerank_number_contexts = rerank_number_contexts
 
         # load vectore store
         embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
@@ -27,10 +30,50 @@ class RAG:
 
         logging.info("RAG loaded!")
 
+    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
+        """
+        Rerank the contexts based on their relevance to the given instruction.
+        """
+
+        rerank_model = self.rerank_model
+
+
+        tokenizer = AutoTokenizer.from_pretrained(rerank_model)
+        model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+
+        def get_score(query, passage):
+            """Calculate the relevance score of a passage with respect to a query."""
+
+
+            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
+
+
+            with torch.no_grad():
+                outputs = model(**inputs)
+
+
+            logits = outputs.logits
+
+
+            score = logits.view(-1, ).float()
+
+
+            return score
+
+        scores = [get_score(instruction, c[0].page_content) for c in contexts]
+        combined = list(zip(contexts, scores))
+        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
+        sorted_texts, _ = zip(*sorted_combined)
+
+        return sorted_texts[:number_of_contexts]
+
     def get_context(self, instruction, number_of_contexts=2):
+        """Retrieve the most relevant contexts for a given instruction."""
+        documentos = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
 
-        documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
+        documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
 
+        print("Reranked documents")
         return documentos
 
     def predict(self, instruction, context, model_parameters):
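
For orientation, below is a minimal usage sketch of the class after this commit. It only illustrates the new constructor signature and the retrieve-then-rerank flow introduced by the diff; the token, model identifiers, and query are placeholder assumptions, not values taken from this repository.

# Hypothetical usage sketch of RAG after this commit.
# All concrete values (token, model names, query) are placeholders, not repo configuration.
from rag import RAG

rag = RAG(
    hf_token="hf_xxx",                          # placeholder Hugging Face token
    embeddings_model="BAAI/bge-m3",             # assumed embedding model, matching the index name
    model_name="generation-model-or-endpoint",  # placeholder generator identifier
    rerank_model="BAAI/bge-reranker-base",      # assumed cross-encoder reranker checkpoint
    rerank_number_contexts=10,                  # fetch 10 candidates from FAISS before reranking
)

# get_context() now over-retrieves rerank_number_contexts candidates from the vector store,
# scores each (query, passage) pair with the cross-encoder, and keeps only the top results.
contexts = rag.get_context("example user question", number_of_contexts=2)

This is the common two-stage retrieval pattern: the bi-encoder similarity search cheaply over-retrieves candidates, and the cross-encoder, which sees the query and passage together, re-scores them so only the most relevant few are passed on to generation.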