import logging
import os

import requests
import torch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class RAG:
    """Retrieval-augmented generation over a FAISS vector store, with optional
    cross-encoder reranking of the retrieved contexts."""

    # Fallback answer shown to the user, in Catalan:
    # "Sorry, I could not answer your question."
    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."

    # Available indexes:
    # vectorstore = "index-intfloat_multilingual-e5-small-500-100-CA-ES"  # mixed
    # vectorstore = "vectorestore"  # CA only
    # vectorstore = "test_rag_multilingual_vectorstore"
    vectorstore = "renfevectorstore"

    def __init__(self, hf_token, embeddings_model, model_name, rerank_model,
                 source_metadata="src", source_file=True):
        self.model_name = model_name
        self.hf_token = hf_token
        self.rerank_model = rerank_model
        self.source_metadata = source_metadata
        self.source_file = source_file

        # Load the vector store from disk.
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model,
                                           model_kwargs={'device': 'cpu'})
        self.vector_store = FAISS.load_local(self.vectorstore, embeddings,
                                             allow_dangerous_deserialization=True)
        logging.info("RAG loaded!")

    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
        """Rerank the contexts based on their relevance to the given instruction."""
        rerank_model = self.rerank_model
        if not rerank_model:
            logging.warning("No rerank model specified. Returning original contexts.")
            return contexts[:number_of_contexts]

        logging.info(f"Reranking contexts using model: {rerank_model}")
        tokenizer = AutoTokenizer.from_pretrained(rerank_model)
        model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
        logging.info("Reranker model and tokenizer loaded successfully.")

        def get_score(query, passage):
            """Calculate the relevance score of a passage with respect to a query."""
            inputs = tokenizer(query, passage, return_tensors='pt',
                               truncation=True, padding=True, max_length=512)
            with torch.no_grad():
                outputs = model(**inputs)
            return outputs.logits.view(-1).float()

        scores = [get_score(instruction, context[0].page_content) for context in contexts]
        combined = list(zip(contexts, scores))
        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
        sorted_contexts, _ = zip(*sorted_combined)
        return sorted_contexts[:number_of_contexts]

    def get_context(self, instruction, number_of_contexts=5, number_of_contexts_rerank=2):
        """Retrieve the most relevant contexts for a given instruction."""
        logging.info(f"Retrieving contexts for instruction: {instruction}, "
                     f"number_of_contexts: {number_of_contexts}")
        documentos = self.vector_store.similarity_search_with_score(instruction,
                                                                    k=number_of_contexts)
        logging.info(f"Retrieved {len(documentos)} documents from the vectorstore.")
        if not documentos:
            logging.warning("No documents found in the vectorstore for the given instruction.")
            return []
        documentos = self.rerank_contexts(instruction, documentos,
                                          number_of_contexts=number_of_contexts_rerank)
        logging.info(f"Reranked documents, keeping top {number_of_contexts_rerank} contexts.")
        return documentos

    def predict_dolly(self, instruction, context, model_parameters):
        """Query a Hugging Face Inference endpoint using a Dolly-style prompt."""
        api_key = os.getenv("HF_TOKEN")
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}",
"Content-Type": "application/json" } query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n " #prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>" payload = { "inputs": query, "parameters": model_parameters } response = requests.post(self.model_name, headers=headers, json=payload) return response.json()[0]["generated_text"].split("###")[-1][8:] def predict_completion(self, instruction, context, model_parameters): model = os.getenv("MODEL") if not model: logging.error("No model specified in the environment variable 'MODEL'.") return "Model endpoint not specified." client = OpenAI( base_url=os.getenv("MODEL"), api_key=os.getenv("HF_TOKEN") ) query = f"Context:\n{context}\n\nQuestion:\n{instruction}" chat_completion = client.chat.completions.create( model="tgi", messages=[ {"role": "user", "content": query} ], temperature=model_parameters["temperature"], max_tokens=model_parameters["max_new_tokens"], stream=False, stop=["<|im_end|>"], extra_body = { "presence_penalty": model_parameters["repetition_penalty"] - 2, "do_sample": False } ) response = chat_completion.choices[0].message.content return response def beautiful_context(self, docs, source_metadata="src", source_file=True): """ Create a beautiful context from the retrieved documents. """ text_context = "" full_context = "" source_context = [] for doc in docs: # print("=" *100) # print(doc) # print("#" *100) if source_file: source = doc[0].metadata[source_metadata].split("/")[-1] # If source_file is True, we assume the metadata contains a file path else: # If source_file is False, we assume the metadata contains a URL or some identifier # and we append it directly. source = doc[0].metadata[source_metadata] source_context.append((source, doc[1])) full_context += "Source: " + source + "\n\n" text_context += doc[0].page_content full_context += doc[0].page_content + "\n\n" # print("#·" * 100) # print(f"Text context (len {len(text_context)}): {text_context[:100]}") # print(f"Full context (len {len(full_context)}): {full_context[:100]}") print(f"Source context (len {len(source_context)}): {source_context}") # print("#·" * 100) return text_context, full_context, source_context def get_response(self, prompt: str, model_parameters: dict) -> str: logging.info(f"Retrive query: {prompt}") try: # Retrieve contexts based on the prompt docs = self.get_context(prompt, number_of_contexts=model_parameters["NUM_CHUNKS"], number_of_contexts_rerank=model_parameters["num_chunks_rerank"]) logging.info(f"Retrieved {len(docs)} contexts for the prompt.") text_context, full_context, source = self.beautiful_context(docs, source_metadata=self.source_metadata) # print(text_context) del model_parameters["NUM_CHUNKS"] response = self.predict_completion(prompt, text_context, model_parameters) if not response: return self.NO_ANSWER_MESSAGE return response, full_context, source except Exception as err: logging.error(f"Error while getting response: {err}") print(err)