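"""RAG-bot / app.py

Streamlit RAG chat app: crawl a website or read an uploaded PDF, index the
text with MiniLM embeddings in an in-memory DocArray store, and answer
questions with a local Mistral-7B GGUF model served through llama.cpp.
"""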
import io
from typing import List

import requests
import streamlit as st
import fitz  # PyMuPDF
from bs4 import BeautifulSoup
from pydantic import BaseModel, Field

from langchain.llms import LlamaCpp
from langchain.callbacks.base import BaseCallbackHandler
# LangChain's Document (page_content + metadata) is what the text splitter
# and vector store below expect; docarray's Document is not compatible.
from langchain.docstore.document import Document
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter

class StreamHandler(BaseCallbackHandler):
    """Stream LLM tokens into a Streamlit container as they are generated."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)

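# Note: StreamHandler is defined but never attached to the chain below. A
# minimal sketch of wiring it up (assuming a placeholder container) would be:
#
#     box = st.empty()
#     response = llm_chain.run(user_prompt, callbacks=[StreamHandler(box)])
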
class DocArrayDoc(BaseModel):
    """Pydantic schema mirroring a DocArray document (currently unused)."""

    text: str = Field(default="")
    embedding: List[float]
    metadata: dict = Field(default_factory=dict)

@st.cache_data
def get_page_urls(url):
    """Collect same-site links from a page, including the page itself."""
    page = requests.get(url)
    soup = BeautifulSoup(page.content, 'html.parser')
    links = [link['href'] for link in soup.find_all('a', href=True)
             if link['href'].startswith(url)]
    links.append(url)
    return set(links)

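# Note: only absolute hrefs that start with the base URL are kept, so relative
# links are skipped. A sketch that also resolves relative links (an assumption,
# not in the original) would be:
#
#     from urllib.parse import urljoin
#     links = [urljoin(url, a['href']) for a in soup.find_all('a', href=True)]
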
@st.cache_data
def process_pdf(file_bytes):
    """Extract plain text from an in-memory PDF (pass bytes so caching works)."""
    doc = fitz.open(stream=file_bytes, filetype="pdf")
    texts = [page.get_text() for page in doc]
    return '\n'.join(texts)

def get_url_content(url):
    """Fetch a URL and return (url, extracted_text) for a PDF or an HTML page."""
    response = requests.get(url)
    if url.endswith('.pdf'):
        pdf = io.BytesIO(response.content)
        doc = fitz.open(stream=pdf, filetype="pdf")
        return (url, ''.join(page.get_text() for page in doc))
    else:
        soup = BeautifulSoup(response.content, 'html.parser')
        # Site-specific: extract text from WPBakery page-builder blocks.
        content = soup.find_all('div', class_='wpb_content_element')
        text = ' '.join([c.get_text().strip() for c in content])
        return (url, text)

@st.cache_resource
def get_retriever(urls):
    """Fetch each URL and build a retriever over the extracted text."""
    all_content = [get_url_content(url) for url in urls]
    return build_retriever(all_content)


@st.cache_resource
def build_retriever(contents):
    """Index (source, text) pairs and return an MMR retriever."""
    documents = [Document(page_content=text, metadata={'source': source})
                 for (source, text) in contents]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    docs = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = DocArrayInMemorySearch.from_documents(docs, embeddings)
    # MMR trades pure relevance for diversity: fetch 10 candidates, keep 5.
    return db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10})

@st.cache_resource
def create_chain(_retriever):
    """Build a conversational RAG chain over a local llama.cpp model.

    The leading underscore on _retriever tells st.cache_resource not to hash it.
    """
    n_gpu_layers = 10  # transformer layers offloaded to the GPU
    n_batch = 2048     # prompt tokens processed per batch
    llm = LlamaCpp(
        model_path="models/mistral-7b-instruct-v0.1.Q5_0.gguf",
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        n_ctx=2048,     # context window size
        temperature=0,  # deterministic answers
        verbose=False,
        streaming=True,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=_retriever, memory=memory, verbose=False
    )
    return qa_chain

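# Note: the system prompt collected below is never passed into the chain. A
# minimal sketch of wiring it in (assuming your LangChain version supports
# combine_docs_chain_kwargs) would build a PromptTemplate from it and pass:
#
#     ConversationalRetrievalChain.from_llm(
#         llm, retriever=_retriever, memory=memory,
#         combine_docs_chain_kwargs={"prompt": qa_prompt},  # qa_prompt is hypothetical
#     )
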
# Webpage title and header
st.set_page_config(page_title="Your own AI-Chat!")
st.header("Your own AI-Chat!")

system_prompt = st.text_area(
    label="System Prompt",
    value="You are a helpful AI assistant who answers questions accurately.",
    key="system_prompt",
)

input_type = st.radio("Choose an input method:", ['URL', 'Upload PDF'])

retriever = None
llm_chain = None

if input_type == 'URL':
    base_url = st.text_input("Enter the site URL here:", key="base_url")
    if base_url:
        urls = get_page_urls(base_url)
        retriever = get_retriever(urls)
        llm_chain = create_chain(retriever)
elif input_type == 'Upload PDF':
    uploaded_file = st.file_uploader("Upload your PDF here:", type="pdf")
    if uploaded_file:
        pdf_text = process_pdf(uploaded_file.getvalue())
        # Index the extracted text directly instead of treating it as a URL.
        retriever = build_retriever([(uploaded_file.name, pdf_text)])
        llm_chain = create_chain(retriever)

# Interaction and message handling
if retriever is not None:
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you today?"}]
    if "current_response" not in st.session_state:
        st.session_state.current_response = ""

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_prompt = st.chat_input("Your message here", key="user_input")
    if user_prompt:
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        with st.chat_message("user"):
            st.markdown(user_prompt)
        # Run the RAG chain and render the answer in this same run.
        with st.chat_message("assistant"):
            response = llm_chain.run(user_prompt)
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})