import os
from operator import itemgetter

from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import JsonOutputParser
from langchain.prompts import PromptTemplate

from lib.models import MODELS_MAP
from lib.utils import format_docs, retrieve_answer, load_embeddings
from lib.entities import LLMEvalResult


def create_retriever(llm_name, db_path, docs, collection_name="local-rag"):
    # Split the documents into overlapping chunks for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=60)
    splits = text_splitter.split_documents(docs)

    embeddings = load_embeddings(llm_name)

    # Build a new Chroma index on the first run; otherwise reuse the persisted one.
    if not os.path.exists(db_path):
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embeddings,
            persist_directory=db_path,
            collection_name=collection_name,
        )
    else:
        vectorstore = Chroma(
            persist_directory=db_path,
            embedding_function=embeddings,
            collection_name=collection_name,
        )

    retriever = vectorstore.as_retriever()
    return retriever


def create_qa_chain(llm, retriever, prompts_text):
    initial_prompt_text = prompts_text["initial_prompt"]
    qa_eval_prompt_text = prompts_text["evaluation_prompt"]

    # Prompt that answers the question from the retrieved context.
    initial_prompt = PromptTemplate(
        template=initial_prompt_text,
        input_variables=["question", "context"],
    )

    # The evaluation step must return JSON matching the LLMEvalResult schema.
    json_parser = JsonOutputParser(pydantic_object=LLMEvalResult)
    qa_eval_prompt = PromptTemplate(
        template=qa_eval_prompt_text,
        input_variables=["question", "answer"],
        partial_variables={"format_instructions": json_parser.get_format_instructions()},
    )

    # Variant of the evaluation prompt that also receives the retrieved context.
    qa_eval_prompt_with_context = PromptTemplate(
        template=qa_eval_prompt_text,
        input_variables=["question", "answer", "context"],
        partial_variables={"format_instructions": json_parser.get_format_instructions()},
    )

    chain = (
        # 1. Retrieve and format the context, keeping the raw question alongside it.
        RunnableParallel(context=retriever | format_docs, question=RunnablePassthrough()) |
        # 2. Generate an answer; pass the question and context through unchanged.
        RunnableParallel(answer=initial_prompt | llm | retrieve_answer, question=itemgetter("question"), context=itemgetter("context")) |
        # 3. Fill the evaluation prompt with the question/answer pair.
        RunnableParallel(input=qa_eval_prompt, context=itemgetter("context"), answer=itemgetter("answer")) |
        # 4. Ask the LLM to grade its own answer.
        RunnableParallel(evaluation=itemgetter("input") | llm, context=itemgetter("context"), answer=itemgetter("answer")) |
        # 5. Parse the evaluation JSON and return answer, evaluation, and context.
        RunnableParallel(output=itemgetter("answer"), evaluation=itemgetter("evaluation") | json_parser, context=itemgetter("context"))
    )
    return chain
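

# --- Usage sketch (illustrative) -------------------------------------------
# A minimal example of how the two helpers above might be wired together.
# Everything below is an assumption about the surrounding project: the
# "llama3" model key, the MODELS_MAP factory call, the prompt texts, and the
# document loader are hypothetical, not confirmed parts of this repo.
#
#   from langchain_community.document_loaders import PyPDFLoader
#
#   docs = PyPDFLoader("docs/handbook.pdf").load()   # hypothetical source document
#   retriever = create_retriever("llama3", "./chroma_db", docs)
#   llm = MODELS_MAP["llama3"]()                      # assumed factory in lib.models
#   prompts_text = {
#       "initial_prompt": "Answer using only this context:\n{context}\n\nQuestion: {question}",
#       "evaluation_prompt": "Grade the answer.\n{format_instructions}\nQ: {question}\nA: {answer}",
#   }
#   chain = create_qa_chain(llm, retriever, prompts_text)
#   result = chain.invoke("What is the refund policy?")
#   print(result["output"], result["evaluation"], sep="\n")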