import io
from typing import List

import fitz  # PyMuPDF
import requests
import streamlit as st
from bs4 import BeautifulSoup
from pydantic import BaseModel, Field

from langchain.llms import LlamaCpp
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
# DocArrayInMemorySearch requires the `docarray` package to be installed.
from langchain.vectorstores import DocArrayInMemorySearch

class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container as they arrive."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each new token and re-render the accumulated text.
        self.text += token
        self.container.markdown(self.text)
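
# StreamHandler is not wired into the chain below, so answers only render once
# they are complete. A minimal sketch of how it could be attached (an
# assumption, not part of the original flow): bind a handler to a Streamlit
# placeholder and pass it to the LLM via `callbacks`, e.g.
#
#   placeholder = st.empty()
#   llm = LlamaCpp(..., streaming=True, callbacks=[StreamHandler(placeholder)])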


class DocArrayDoc(BaseModel):
    """Pydantic schema for records held in the DocArray in-memory store.

    The app never instantiates this class directly; it documents the shape of
    what gets stored (text, embedding vector, metadata).
    """

    text: str = Field(default="")
    embedding: List[float]
    metadata: dict = Field(default_factory=dict)


@st.cache_data
def get_page_urls(url):
    """Collect links on the page at `url` that stay within the same site.

    Only absolute hrefs that start with the base URL are kept, and the crawl
    is one level deep: linked pages are not crawled in turn.
    """
    page = requests.get(url)
    soup = BeautifulSoup(page.content, 'html.parser')
    links = [link['href'] for link in soup.find_all('a', href=True) if link['href'].startswith(url)]
    links.append(url)
    return set(links)
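
# Usage sketch (hypothetical URL, for illustration only):
#   get_page_urls("https://example.com")
#   -> {"https://example.com", "https://example.com/about", ...}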


@st.cache_data
def process_pdf(file):
    """Extract the plain text of an uploaded PDF, pages joined by newlines."""
    doc = fitz.open(stream=file.read(), filetype="pdf")
    texts = [page.get_text() for page in doc]
    return '\n'.join(texts)


def get_url_content(url):
    """Fetch a URL and return a (url, extracted_text) pair for PDF or HTML."""
    response = requests.get(url)
    if url.endswith('.pdf'):
        pdf = io.BytesIO(response.content)
        doc = fitz.open(stream=pdf, filetype="pdf")
        return (url, ''.join(page.get_text() for page in doc))
    else:
        soup = BeautifulSoup(response.content, 'html.parser')
        # 'wpb_content_element' targets WPBakery page-builder markup; adjust
        # this selector to match the structure of the site being indexed.
        content = soup.find_all('div', class_='wpb_content_element')
        text = ' '.join([c.get_text().strip() for c in content])
        return (url, text)


@st.cache_data
def get_all_content(urls):
    """Fetch and extract the text of every URL; cached so pages are only
    downloaded once per session."""
    return [get_url_content(url) for url in urls]


@st.cache_resource
def get_retriever(contents):
    """Build an in-memory vector store over (source, text) pairs and return a
    retriever using maximal marginal relevance (MMR) search."""
    documents = [Document(page_content=text, metadata={'url': source})
                 for (source, text) in contents]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    docs = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = DocArrayInMemorySearch.from_documents(docs, embeddings)
    # Fetch 10 candidates, return the 5 most relevant yet diverse chunks.
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10})
    return retriever
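
# Quick sanity check (a sketch, separate from the app flow): the retriever can
# be queried directly to inspect which chunks would back a given question, e.g.
#
#   docs = retriever.get_relevant_documents("What is this site about?")
#   for doc in docs:
#       print(doc.metadata['url'], doc.page_content[:80])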


@st.cache_resource
def create_chain(_retriever):
    # The leading underscore tells Streamlit not to hash the retriever when
    # caching this resource.
    n_gpu_layers = 10  # layers to offload to the GPU (tune for your hardware)
    n_batch = 2048     # tokens evaluated in parallel per batch
    llm = LlamaCpp(
        model_path="models/mistral-7b-instruct-v0.1.Q5_0.gguf",
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        n_ctx=2048,
        temperature=0,
        verbose=False,
        streaming=True,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=_retriever, memory=memory, verbose=False
    )
    return qa_chain


st.set_page_config(page_title="Your own AI-Chat!")
st.header("Your own AI-Chat!")

system_prompt = st.text_area(
    label="System Prompt",
    value="You are a helpful AI assistant who answers questions accurately.",
    key="system_prompt")

input_type = st.radio("Choose an input method:", ['URL', 'Upload PDF'])
if input_type == 'URL':
    base_url = st.text_input("Enter the site URL here:", key="base_url")
    if base_url:
        urls = get_page_urls(base_url)
        retriever = get_retriever(get_all_content(urls))
        llm_chain = create_chain(retriever)
elif input_type == 'Upload PDF':
    uploaded_file = st.file_uploader("Upload your PDF here:", type="pdf")
    if uploaded_file:
        pdf_text = process_pdf(uploaded_file)
        # Feed the extracted text straight into the retriever as a
        # (source, text) pair; there is no URL to fetch for an upload.
        retriever = get_retriever([(uploaded_file.name, pdf_text)])
        llm_chain = create_chain(retriever)


# Only show the chat once a content source has been loaded.
if 'retriever' in locals() and retriever:
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you today?"}]
    if "current_response" not in st.session_state:
        st.session_state.current_response = ""

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_prompt = st.chat_input("Your message here", key="user_input")
    if user_prompt:
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        with st.chat_message("user"):
            st.markdown(user_prompt)
        # Render the answer immediately rather than waiting for the next rerun.
        with st.chat_message("assistant"):
            response = llm_chain.run(user_prompt)
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})