# gwIAS / rag.py (author: JayWadekar, commit d995ec7 — "Improving RAG prompt")
# Utilities to build a RAG system to query information from the
# gwIAS search pipeline using Langchain
# Thanks to Pablo Villanueva Domingo for sharing his CAMELS template
# https://huggingface.co/spaces/PabloVD/CAMELSDocBot
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# Load documentation from urls
def load_docs():
    """Load documents from the URLs listed (one per line) in urls.txt.

    Returns:
        list: langchain Documents, each with a 'name' metadata entry set
        to its source URL (or a positional fallback) for later reference.
    """
    # Read the URL list; the context manager guarantees the file is
    # closed even if loading fails, and blank lines are skipped so they
    # cannot reach the loader as empty URLs.
    with open("urls.txt", encoding="utf-8") as urlsfile:
        urls = [line.strip() for line in urlsfile if line.strip()]

    # Load, chunk and index the contents of the blog.
    loader = WebBaseLoader(urls)
    docs = loader.load()

    # Add source URLs as document names for reference
    for i, doc in enumerate(docs):
        if 'source' in doc.metadata:
            doc.metadata['name'] = doc.metadata['source']
        else:
            doc.metadata['name'] = f"Document {i+1}"

    print(f"Loaded {len(docs)} documents:")
    for doc in docs:
        print(f" - {doc.metadata.get('name')}")

    return docs
def extract_reference(url):
    """Extract a reference keyword from the GitHub URL"""
    if "blob/main" in url:
        # partition drops the marker; when the trailing slash is absent
        # the marker is not found and the full URL is returned, matching
        # str.split's behavior on an absent separator.
        _, found, tail = url.partition("blob/main/")
        return tail if found else url
    elif "tree/main" in url:
        _, found, tail = url.partition("tree/main/")
        if found:
            # An empty tail means the URL pointed at the repository root.
            return tail or "root"
        return url
    return url
# Join content pages for processing
def format_docs(docs):
    """Join document pages into one string separated by '---' rules,
    tagging each page with a bracketed reference derived from its
    source URL (or 'Unknown source' when no source is recorded)."""
    def _render(doc):
        # One page: content, blank line, then its reference tag.
        origin = doc.metadata.get('source', 'Unknown source')
        return f"{doc.page_content}\n\nReference: [{extract_reference(origin)}]"

    return "\n\n---\n\n".join(_render(doc) for doc in docs)
# Create a RAG chain
def RAG(llm, docs, embeddings):
    """Build a retrieval-augmented generation chain over the documents.

    Args:
        llm: Chat/completion model used to generate the final answer.
        docs: Documents to index (as returned by load_docs()).
        embeddings: Embedding model backing the Chroma vector store.

    Returns:
        A runnable chain mapping a question string to an answer string.

    Raises:
        ValueError: if the pulled hub prompt no longer contains the
        expected "Question: {question}" marker.
    """
    # Split text into overlapping chunks so retrieval can target passages.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)

    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()

    # Prompt basis example for RAG systems
    prompt = hub.pull("rlm/rag-prompt")

    # Replace the hub prompt's generic instructions with custom ones while
    # keeping its Question/Context/Answer scaffold.
    template = prompt.messages[0].prompt.template
    question_marker = "\nQuestion: {question}"
    _, found, suffix = template.partition(question_marker)
    if not found:
        # Fail loudly instead of the bare IndexError the old split-based
        # code would raise if the upstream prompt format changes.
        raise ValueError("Unexpected rag-prompt format: missing "
                         "'Question: {question}' marker")
    combined_template = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, just say that you don't know. "
        "Use six sentences maximum and keep the answer concise. "
        "Write the names of the relevant functions from the retrieved code. "
        "Include the reference IDs in square brackets at the end of your answer."
        # Re-insert the question placeholder: str.partition (like the old
        # str.split) removes the marker, and without it the user's question
        # would be silently dropped from the prompt.
        + question_marker
        + suffix
    )
    prompt.messages[0].prompt.template = combined_template

    # Create the chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain