import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

from dotenv import load_dotenv

load_dotenv()  # load API keys from .env
mistral_api_key = os.getenv("MISTRAL_API_KEY")

from pathlib import Path
from typing import Literal

from langchain import hub
from langchain.chains import (
    create_history_aware_retriever,
    create_retrieval_chain,
)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.vectorstores import Chroma, FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
def load_chunk_persist_pdf() -> Chroma:
    """Load every PDF in data/pdf/, split it into chunks, and persist a Chroma store."""
    pdf_folder_path = os.path.join(os.getcwd(), "data", "pdf")
    documents = []
    for file in os.listdir(pdf_folder_path):
        if file.endswith('.pdf'):
            pdf_path = os.path.join(pdf_folder_path, file)
            loader = PyPDFLoader(pdf_path)
            documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunked_documents = text_splitter.split_documents(documents)
    os.makedirs("data/chroma_store/", exist_ok=True)
    vectorstore = Chroma.from_documents(
        documents=chunked_documents,
        embedding=MistralAIEmbeddings(),
        persist_directory=os.path.join(os.getcwd(), "data", "chroma_store"),
    )
    vectorstore.persist()
    return vectorstore
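
# A hedged alternative (assumption, not part of the original script): on later runs
# the persisted store can be reopened directly instead of re-chunking and
# re-embedding every PDF. Left commented out so the default rebuild below still runs.
# vectorstore = Chroma(
#     persist_directory=os.path.join(os.getcwd(), "data", "chroma_store"),
#     embedding_function=MistralAIEmbeddings(),
# )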
vectorstore = load_chunk_persist_pdf()
retriever = vectorstore.as_retriever()
# prompt = hub.pull("rlm/rag-prompt")  # community RAG prompt; superseded by the custom prompt below
# Data model
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "websearch"] = Field(
        ...,
        description="Given a user question, choose whether to route it to web search or to the vectorstore.",
    )
# LLM with function calling (structured output)
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
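
# A minimal sketch (assumption, not in the original script): bind RouteQuery to the
# LLM via structured output so the model returns a datasource choice. The routing
# system-prompt text here is illustrative.
structured_llm_router = llm.with_structured_output(RouteQuery)
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You route user questions either to a vectorstore of fitness and nutrition PDFs or to web search."),
        ("human", "{question}"),
    ]
)
question_router = route_prompt | structured_llm_router
# e.g. question_router.invoke({"question": "What is creatine?"}).datasource -> "vectorstore"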
prompt = ChatPromptTemplate.from_template(
    """
    You are a professional AI coach specialized in fitness, bodybuilding, and nutrition.
    Adapt to the user: if they are a beginner, use simple words. Be gentle and motivating.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, say that you don't know and refer the user to a nutritionist or a doctor.
    Use three sentences maximum and keep the answer concise.
    Question: {question}
    Context: {context}
    Answer:
    """,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def format_docs(docs):
    """Join retrieved documents into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)

# RAG chain: retrieve context, format it, fill the prompt, call the LLM, parse to a string
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises")) | |
# print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program")) |