# Author: Doux Thibault
# Commit 9a30a8c: "rag to streamlit + new pdf"
import os
# Allow tokenizer fork-parallelism; must be set before any HF tokenizer loads.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
from dotenv import load_dotenv
load_dotenv() # load .env api keys
# Mistral key used by ChatMistralAI below; None if missing from .env.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma, FAISS
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_mistralai import MistralAIEmbeddings
from langchain import hub
from langchain.chains import (
create_history_aware_retriever,
create_retrieval_chain,
)
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_mistralai import ChatMistralAI
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from pathlib import Path
def load_chunk_persist_pdf() -> Chroma:
    """Load every PDF under data/pdf/, chunk it, and embed into a persisted Chroma store.

    Returns:
        Chroma: vector store backed by data/chroma_store/ with one entry per
        text chunk, embedded via MistralAIEmbeddings (reads MISTRAL_API_KEY
        from the environment).
    """
    pdf_folder_path = Path.cwd() / "data" / "pdf"
    documents = []
    # sorted() makes ingestion order deterministic so the persisted store is
    # reproducible across runs (os.listdir order is filesystem-dependent).
    for file in sorted(os.listdir(pdf_folder_path)):
        if file.endswith('.pdf'):
            loader = PyPDFLoader(str(pdf_folder_path / file))
            documents.extend(loader.load())
    # chunk_overlap=10 is very small relative to chunk_size=1000 — presumably
    # intentional; verify retrieval quality if answers miss cross-chunk context.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunked_documents = text_splitter.split_documents(documents)
    persist_dir = Path.cwd() / "data" / "chroma_store"
    persist_dir.mkdir(parents=True, exist_ok=True)
    vectorstore = Chroma.from_documents(
        documents=chunked_documents,
        embedding=MistralAIEmbeddings(),
        persist_directory=str(persist_dir),
    )
    vectorstore.persist()
    return vectorstore
# Build (or rebuild) the vector store at import time and expose a retriever.
vectorstore = load_chunk_persist_pdf()
retriever = vectorstore.as_retriever()
# NOTE(review): the original also ran prompt = hub.pull("rlm/rag-prompt") here,
# but that value was immediately shadowed by the ChatPromptTemplate defined
# below and never used — the dead network call has been removed.
# Data model
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    # NOTE(review): this router model is never bound to the LLM in this file
    # (no with_structured_output call visible) — confirm it is used elsewhere
    # or wire it up before relying on routing.
    datasource: Literal["vectorstore", "websearch"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )
# LLM with function call
# temperature=0 for deterministic, factual coaching answers.
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# RAG prompt: {context} is filled with retrieved PDF chunks, {question} with the
# user's input. Wording fixed: "motivative" is not a word, and the spacing/
# pronoun grammar was broken.
prompt = ChatPromptTemplate.from_template(
    """
    You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
    You must adapt to the user: if they are a beginner, use simple words. You are gentle and motivating.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know, and to refer to a nutritionist or a doctor.
    Use three sentences maximum and keep the answer concise.
    Question: {question}
    Context: {context}
    Answer:
    """,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
def format_docs(docs):
    """Flatten retrieved documents into one prompt-ready string.

    Each document's page_content is kept verbatim; chunks are separated by a
    blank line so the LLM can tell them apart.
    """
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)
# LCEL pipeline: retrieve docs -> flatten to text -> fill prompt -> LLM -> plain string.
rag_chain = (
    # The raw question flows both into the retriever (whose documents are
    # joined by format_docs into {context}) and straight through as {question}.
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
# print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program"))