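"""RAG pipeline: expands a user query, runs the expanded queries against the
embedded-chunk store in parallel, reranks the retrieved chunks, and generates
an answer with a hosted LLM via the Hugging Face Inference API."""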
import concurrent.futures
import os
from loguru import logger
from qdrant_client.models import FieldCondition, Filter, MatchValue
from huggingface_hub import InferenceClient
from rag_demo.preprocessing.base import (
    EmbeddedChunk,
)
from rag_demo.rag.base.query import EmbeddedQuery, Query
from .query_expansion import QueryExpansion
from .reranker import Reranker
from .prompt_templates import AnswerGenerationTemplate
from dotenv import load_dotenv

load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""
    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)
    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list:
        query_model = Query.from_str(query)

        # Expand the user query into several semantically related queries.
        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries.",
        )

        # Run one vector search per expanded query in parallel.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]
            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]
            n_k_documents = flatten(n_k_documents)
            # Deduplicate chunks retrieved by more than one expanded query.
            n_k_documents = list(set(n_k_documents))
        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if len(n_k_documents) > 0:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents
    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        # Embed the query with the same model used to embed the documents.
        api = InferenceClient(
            model="intfloat/multilingual-e5-large-instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        embedded_query: EmbeddedQuery = EmbeddedQuery(
            embedding=api.feature_extraction(query.content),
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks
    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )
        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents
    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        # Concatenate the reranked chunks into a single context string.
        context = ""
        for chunk in reranked_chunks:
            context += "\n Document: "
            context += chunk.content

        api = InferenceClient(
            model="meta-llama/Llama-3.3-70B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )

        answer_generation_template = AnswerGenerationTemplate()
        prompt = answer_generation_template.create_template(context, query)
        logger.info(prompt)

        response = api.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=8192,
        )
        return response.choices[0].message.content
    def rag(self, query: str) -> tuple[str, list[str]]:
        docs = self.search(query, k=10)
        reranked_docs = self.rerank(query, docs, keep_top_k=10)

        # Return the generated answer together with the deduplicated source filenames.
        return (
            self.generate_answer(query, reranked_docs),
            list({doc.metadata["filename"].split(".pdf")[0] for doc in reranked_docs}),
        )
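

# Example usage (not part of the original file): a minimal sketch that assumes
# HF_API_TOKEN is set in the environment and the EmbeddedChunk collection has
# already been populated. The query string below is purely illustrative.
if __name__ == "__main__":
    pipeline = RAGPipeline()
    answer, sources = pipeline.rag("What does the installation guide say about GPU support?")
    print(answer)
    print(f"Sources: {sources}")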