Update app.py
app.py CHANGED
@@ -13,6 +13,7 @@ import requests
 from functools import lru_cache
 import torch
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.cross_encoder import CrossEncoder
 import threading
 from queue import Queue
 import concurrent.futures
@@ -46,6 +47,7 @@ class OptimizedRAGLoader:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
         self.encoder.to(self.device)
+        self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)

         # Initialize thread pool
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
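For reference, a minimal sketch of how the cross-encoder wired in above scores query/passage pairs. The model name is taken from the diff; the query and passages below are made-up placeholders:

from sentence_transformers.cross_encoder import CrossEncoder

reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
pairs = [
    ["example question", "first candidate passage ..."],   # hypothetical pair
    ["example question", "second candidate passage ..."],
]
scores = reranker.predict(pairs)  # one relevance score per pair; higher means more relevant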
@@ -212,7 +214,7 @@ llm = ChatMistralAI(
 )

 rag_loader = OptimizedRAGLoader()
-retriever = rag_loader.get_retriever(k=
+retriever = rag_loader.get_retriever(k=30)  # Reduced k for faster retrieval

 # Cache for processed questions
 question_cache = {}
@@ -335,9 +337,21 @@ def process_question(question: str) -> Iterator[str]:
         return

     relevant_docs = retriever(question)
-    context = "\n".join([doc.page_content for doc in relevant_docs])
+    # context = "\n".join([doc.page_content for doc in relevant_docs])
+
+
+    context = [doc.page_content for doc in relevant_docs]
+    text_pairs = [[question, text] for text in context]
+    scores = rag_loader.reranker.predict(text_pairs)
+
+    scored_docs = list(zip(scores, context, relevant_docs))
+    # scored_docs.sort(reverse=True)
+    scored_docs.sort(key=lambda x: x[0], reverse=True)
+    reranked_docs = [d[2].page_content for d in scored_docs][:6]
+
+
     prompt = prompt_template.format_messages(
-        context=
+        context=reranked_docs,
         question=question
     )
     full_response = ""
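Taken together, the block added above amounts to a small "retrieve then rerank" step. A sketch under the diff's assumptions (documents expose .page_content, the reranker is a sentence-transformers CrossEncoder); rerank_documents is a hypothetical name, not a function defined in app.py:

def rerank_documents(question, docs, reranker, top_k=6):
    # Score every (question, passage) pair with the cross-encoder,
    # then keep the top_k highest-scoring documents.
    pairs = [[question, doc.page_content] for doc in docs]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
    return [doc for _, doc in ranked[:top_k]]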
@@ -353,6 +367,12 @@ def process_question(question: str) -> Iterator[str]:
     # yield full_response + "\n\n\nالمصادر المحتملة :\n" + "".join(sources)
     sources = [doc.metadata.get("source") for doc in relevant_docs]
     sources = list(set([os.path.splitext(source)[0] for source in sources]))
+
+
+    sources = [d[2].metadata['source'] for d in scored_docs][:6]
+    sources = list(set([os.path.splitext(source)[0] for source in sources]))
+
+
     yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
     # yield full_response + "\n\n\nالمصادر المحتملة:\n" + "\n".join([doc.metadata.get("source") for doc in relevant_docs])
     question_cache[question] = (full_response, relevant_docs)
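One note on the source list built above: list(set(...)) discards the ranking order, so the sources printed under "المصادر المحتملة" (possible sources) come out in arbitrary order. If they should follow relevance order, a first-seen deduplication along these lines would do it (a sketch, not code from this commit):

import os

def unique_sources(docs):
    # Return source names (extension stripped) in first-seen order, without duplicates.
    seen = []
    for doc in docs:
        name = os.path.splitext(doc.metadata.get("source", ""))[0]
        if name and name not in seen:
            seen.append(name)
    return seen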