Trabis committed
Commit 08bb753 (verified) · Parent: 2a881cf

Update app.py

Files changed (1): app.py (+23, -3)
app.py CHANGED
@@ -13,6 +13,7 @@ import requests
 from functools import lru_cache
 import torch
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.cross_encoder import CrossEncoder
 import threading
 from queue import Queue
 import concurrent.futures
@@ -46,6 +47,7 @@ class OptimizedRAGLoader:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
         self.encoder.to(self.device)
+        self.reranker = model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",trust_remote_code=True)
 
         # Initialize thread pool
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
@@ -212,7 +214,7 @@ llm = ChatMistralAI(
 )
 
 rag_loader = OptimizedRAGLoader()
-retriever = rag_loader.get_retriever(k=4) # Reduced k for faster retrieval
+retriever = rag_loader.get_retriever(k=30) # Reduced k for faster retrieval
 
 # Cache for processed questions
 question_cache = {}
@@ -335,9 +337,21 @@ def process_question(question: str) -> Iterator[str]:
         return
 
     relevant_docs = retriever(question)
-    context = "\n".join([doc.page_content for doc in relevant_docs])
+    # context = "\n".join([doc.page_content for doc in relevant_docs])
+
+
+    context = [doc.page_content for doc in relevant_docs]
+    text_pairs = [[question, text] for text in context]
+    scores = rag_loader.reranker.predict(text_pairs)
+
+    scored_docs = list(zip(scores, context, relevant_docs))
+    # scored_docs.sort(reverse=True)
+    scored_docs.sort(key=lambda x: x[0], reverse=True)
+    reranked_docs = [d[2].page_content for d in scored_docs][:6]
+
+
     prompt = prompt_template.format_messages(
-        context=context,
+        context=reranked_docs,
         question=question
     )
     full_response = ""
@@ -353,6 +367,12 @@ def process_question(question: str) -> Iterator[str]:
     # yield full_response + "\n\n\nالمصادر المحتملة :\n" + "".join(sources)
     sources = [doc.metadata.get("source") for doc in relevant_docs]
     sources = list(set([os.path.splitext(source)[0] for source in sources]))
+
+
+    sources = [d[2].metadata['source'] for d in scored_docs][:6]
+    sources = list(set([os.path.splitext(source)[0] for source in sources]))
+
+
     yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
     # yield full_response + "\n\n\nالمصادر المحتملة:\n" + "\n".join([doc.metadata.get("source") for doc in relevant_docs])
     question_cache[question] = (full_response, relevant_docs)
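
Taken together, the changes turn the pipeline into a retrieve-then-rerank setup: the bi-encoder retriever now pulls a larger candidate pool (k=30 instead of k=4), the new CrossEncoder scores each (question, passage) pair, and only the six best passages are formatted into the prompt. The standalone sketch below illustrates that reranking step with the same cross-encoder model; the rerank() helper, its top_n default, and the sample passages are illustrative assumptions, not code from app.py.

# Minimal sketch of the cross-encoder reranking step introduced in this commit.
# The rerank() helper and the sample passages are hypothetical; only the model
# name and the predict()-then-sort pattern come from the diff above.
from sentence_transformers.cross_encoder import CrossEncoder

reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)

def rerank(question: str, passages: list[str], top_n: int = 6) -> list[str]:
    # Score every (question, passage) pair; higher score means more relevant.
    pairs = [[question, passage] for passage in passages]
    scores = reranker.predict(pairs)
    # Sort candidates by score and keep the top_n passages for the prompt.
    ranked = sorted(zip(scores, passages), key=lambda pair: pair[0], reverse=True)
    return [passage for _, passage in ranked[:top_n]]

if __name__ == "__main__":
    candidates = [
        "The capital of Morocco is Rabat.",
        "Couscous is traditionally served on Fridays.",
        "Rabat lies on the Atlantic coast.",
    ]
    print(rerank("What is the capital of Morocco?", candidates, top_n=2))

Retrieving 30 candidates and keeping only the 6 top-scored ones trades a little extra retrieval and scoring latency for a more precise context, while the prompt stays close to its previous size (6 passages instead of 4); the source list yielded to the user is likewise restricted to the reranked documents.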