import concurrent.futures
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from loguru import logger

from rag_demo.preprocessing.base import EmbeddedChunk
from rag_demo.rag.base.query import EmbeddedQuery, Query

from .prompt_templates import AnswerGenerationTemplate
from .query_expansion import QueryExpansion
from .reranker import Reranker

# Load HF_API_TOKEN (and any other secrets) from a local .env file.
load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""

    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
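    """Retrieval pipeline: expand a query, search the vector store in parallel,
    rerank the hits, and generate an answer grounded in the top chunks.
    """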
    def __init__(self, mock: bool = False) -> None:
        # Forward the mock flag to the query-expansion and reranking components
        # (e.g. to avoid real model calls in tests).
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)

    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list:
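        """Expand the query into several variants, search the vector store with
        each in parallel, deduplicate the merged hits, and rerank them down to
        the k best chunks.
        """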
        query_model = Query.from_str(query)

        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries.",
        )

        # Fan the expanded queries out across a thread pool; each search is an
        # I/O-bound call to the vector store, so threads give cheap parallelism.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]

            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]
            # Merge the per-query result lists and drop duplicate chunks.
            n_k_documents = flatten(n_k_documents)
            n_k_documents = list(set(n_k_documents))

        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if len(n_k_documents) > 0:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
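        """Embed a single query and retrieve its k nearest chunks from the vector store."""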
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        # The query must be embedded with the same model that produced the chunk
        # embeddings so that query and chunk vectors are comparable.
        api = InferenceClient(
            model="intfloat/multilingual-e5-large-instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        embedded_query: EmbeddedQuery = EmbeddedQuery(
            embedding=api.feature_extraction(query.content),
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks

    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
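        """Rerank the retrieved chunks against the query and keep the keep_top_k best."""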
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
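        """Build a prompt from the reranked chunks and query the LLM for an answer."""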
        # Concatenate the reranked chunks into a single context block, one
        # "Document:" entry per chunk.
        context = ""
        for chunk in reranked_chunks:
            context += "\n Document: "
            context += chunk.content
        api = InferenceClient(
            model="meta-llama/Llama-3.3-70B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        answer_generation_template = AnswerGenerationTemplate()
        prompt = answer_generation_template.create_template(context, query)
        logger.info(prompt)
        response = api.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=8192,
        )
        return response.choices[0].message.content

    def rag(self, query: str) -> tuple[str, list[str]]:
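        """Run the full pipeline: retrieve, rerank, and answer.

        Returns the generated answer together with the deduplicated source
        filenames (PDF names without the .pdf extension). Note that search()
        already reranks its merged results, so the rerank() call here mainly
        re-scores the same top documents.
        """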
        docs = self.search(query, k=10)
        reranked_docs = self.rerank(query, docs, keep_top_k=10)
        sources = {
            doc.metadata["filename"].split(".pdf")[0] for doc in reranked_docs
        }
        return self.generate_answer(query, reranked_docs), list(sources)
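

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the pipeline itself).
    # Assumes HF_API_TOKEN is available via the environment or a .env file and
    # that the vector store behind EmbeddedChunk has already been populated.
    pipeline = RAGPipeline()
    answer, sources = pipeline.rag("What topics do the indexed documents cover?")
    logger.info(f"Answer: {answer}")
    logger.info(f"Sources: {sources}")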