"""Streamlit chat app: crawl a site, embed its pages with sentence-transformers,
and answer questions about them with a local LlamaCpp model."""

import io

import fitz  # PyMuPDF, used to extract text from linked PDFs
import requests
import streamlit as st
from bs4 import BeautifulSoup
from docarray import Document, DocumentArray
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document as LCDocument  # avoid clashing with docarray's Document
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseRetriever
from sentence_transformers import SentenceTransformer
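# NOTE: this sketch assumes docarray<2 (the v1 Document/DocumentArray API) and a
# pre-split langchain release (hence the `langchain.*` import paths above); newer
# versions of both libraries moved or renamed several of these classes.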

class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container as they are generated."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)

class SimpleEmbeddingRetriever(BaseRetriever):
    """Bare-bones retriever over an in-memory DocumentArray of embedded pages."""

    # Pydantic-style BaseRetriever: state is declared as fields, not in __init__.
    documents: DocumentArray
    model: SentenceTransformer
    num_documents: int = 5

    class Config:
        arbitrary_types_allowed = True  # DocumentArray/SentenceTransformer aren't pydantic types

    def _get_relevant_documents(self, query: str, **kwargs):
        # Embed the query with the corpus model; docarray ranks the stored
        # documents by cosine similarity and writes them to `query_doc.matches`.
        query_doc = Document(text=query)
        query_doc.embedding = self.model.encode([query])[0]
        DocumentArray([query_doc]).match(self.documents, metric='cosine', limit=self.num_documents)
        # The chain expects langchain Documents, not (text, score) tuples.
        return [
            LCDocument(page_content=match.text, metadata=dict(match.tags))
            for match in query_doc.matches
        ]

@st.cache_data
def get_page_urls(url):
    """Collect same-site links found on the landing page, plus the page itself."""
    try:
        page = requests.get(url, timeout=30)
        page.raise_for_status()
        soup = BeautifulSoup(page.content, 'html.parser')
        links = [
            link['href']
            for link in soup.find_all('a')
            if 'href' in link.attrs and link['href'].startswith(url) and link['href'] != url
        ]
        links.append(url)
        return set(links)
    except requests.RequestException as e:
        st.error(f"Failed to load page: {e}")
        return set()

def get_url_content(url):
    """Fetch one URL and return its text as a single-document DocumentArray."""
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        if url.endswith('.pdf'):
            pdf = io.BytesIO(response.content)
            doc = fitz.open(stream=pdf, filetype="pdf")
            text = ''.join(page.get_text("text") for page in doc)
        else:
            # Site-specific: the page text lives in `wpb_content_element` divs.
            soup = BeautifulSoup(response.content, 'html.parser')
            content = soup.find_all('div', class_='wpb_content_element')
            text = ' '.join(c.get_text().strip() for c in content if c.get_text().strip())

        document = Document(text=text, tags={'url': url})
        return DocumentArray([document])
    except Exception as e:
        st.error(f"Failed to process URL content: {e}")
        return DocumentArray()

@st.cache_resource
def get_retriever(urls):
    """Download every page, embed the texts in one batch, and build the retriever."""
    documents = DocumentArray()
    for url in urls:
        content = get_url_content(url)
        if content:
            documents.extend(content)

    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode([doc.text for doc in documents], show_progress_bar=True)
    for doc, emb in zip(documents, embeddings):
        doc.embedding = emb

    return SimpleEmbeddingRetriever(documents=documents, model=model)

@st.cache_resource
def create_chain(_retriever):
    # The leading underscore tells st.cache_resource not to hash the retriever.
    n_gpu_layers = 5  # layers to offload to the GPU; tune to the available VRAM
    n_batch = 512

    llm = LlamaCpp(
        model_path="models/mistral-7b-instruct-v0.1.Q5_0.gguf",
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        n_ctx=2048,
        temperature=0,
        verbose=False,
        streaming=True,
    )

    # The memory key must match the chain's `chat_history` prompt variable.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=_retriever, memory=memory, verbose=False
    )
    return qa_chain

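# The GGUF weights are assumed to sit in a local `models/` directory; one plausible
# source is the quantized Mistral-7B-Instruct-v0.1 build on Hugging Face
# (e.g. TheBloke/Mistral-7B-Instruct-v0.1-GGUF), but any chat-capable GGUF model works.
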
st.set_page_config(
    page_title="Your own AI-Chat!"
)

st.header("Your own AI-Chat!")

if "base_url" not in st.session_state:
    st.session_state.base_url = ""

base_url = st.text_input("Enter the site url here", key="base_url")

# Everything below needs the retriever, so halt the rerun until a URL is entered.
if st.session_state.base_url == "":
    st.stop()

urls = get_page_urls(base_url)
retriever = get_retriever(urls)

if "messages" not in st.session_state: |
|
st.session_state.messages = [ |
|
{"role": "assistant", "content": "How may I help you today?"} |
|
] |
|
|
|
if "current_response" not in st.session_state: |
|
st.session_state.current_response = "" |
|
|
|
|
|
|
|
for message in st.session_state.messages: |
|
with st.chat_message(message["role"]): |
|
st.markdown(message["content"]) |
|
|
|
|
|
|
|
|
|
llm_chain = create_chain(retriever) |
|
|
|
|
|
if user_prompt := st.chat_input("Your message here", key="user_input"):
    st.session_state.messages.append(
        {"role": "user", "content": user_prompt}
    )

    with st.chat_message("user"):
        st.markdown(user_prompt)

    # Stream the answer into the assistant bubble as it is generated, using
    # the StreamHandler defined above.
    with st.chat_message("assistant"):
        stream_handler = StreamHandler(st.empty())
        response = llm_chain.run(user_prompt, callbacks=[stream_handler])

    st.session_state.messages.append(
        {"role": "assistant", "content": response}
    )
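
# To try it locally (assuming this file is saved as app.py and the packages
# above are installed):
#
#   streamlit run app.py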