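"""Document Chatbot: a Streamlit app that answers questions about an uploaded PDF.

Pipeline: load the PDF, split it into overlapping chunks, embed the chunks with
a SentenceTransformer model, index them in Chroma, and answer questions with a
RetrievalQA chain backed by a local LaMini-Flan-T5 model.

Run with: streamlit run app.py
"""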
from langchain_community.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
import torch
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
import streamlit as st
import os


def main_process(uploaded_file):
    file_name = uploaded_file.name
    # Create a temporary directory
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)
    # Save the uploaded file to the temporary directory
    temp_path = os.path.join(temp_dir, file_name)
    with open(temp_path, "wb") as temp_file:
        temp_file.write(uploaded_file.getvalue())
    # Process the uploaded file
    loader = UnstructuredFileLoader(temp_path)
    documents = loader.load()
    for document in documents:
        print(document.page_content)
    # We can't feed the whole PDF to the model at once, so we split it into chunks.
    # CharacterTextSplitter produces chunks of 1000 characters that overlap by
    # 400 characters (you can change these values to suit your needs).
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=400)
    texts = text_splitter.split_documents(documents)
    # SentenceTransformerEmbeddings turns each text chunk into a vector.
    # Embeddings are used to find the similarity between the query and the chunks.
    # We use the multi-qa-mpnet-base-dot-v1 model to embed the chunks.
    embeddings = SentenceTransformerEmbeddings(model_name="multi-qa-mpnet-base-dot-v1")
    persist_directory = "chroma/"
    # Chroma stores the embeddings; from_documents embeds the chunks and
    # persist_directory saves the index to disk.
    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
    # To save and reload the vector store later (if needed):
    # db.persist()
    # db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
    # Initialize the tokenizer and base model for text generation
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float32,
    )
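    # device_map="auto" lets the accelerate library place the model on a GPU
    # when one is available; torch.float32 keeps inference working on CPU-only hosts.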
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=512,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
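    # Sampling with a low temperature (0.3) and nucleus sampling (top_p=0.95)
    # keeps answers close to the retrieved text while allowing slight variation.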
    # Wrap the local pipeline as a LangChain LLM
    local_llm = HuggingFacePipeline(pipeline=pipe)
    # Create a RetrievalQA chain: "stuff" inserts the top k=2 retrieved chunks
    # directly into the prompt
    qa_chain = RetrievalQA.from_chain_type(
        llm=local_llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_type="similarity", search_kwargs={"k": 2}),
        return_source_documents=True,
    )
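    # With return_source_documents=True, calling the chain returns a dict with
    # "result" (the answer) and "source_documents" (the retrieved chunks).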
    return qa_chain


st.title("Document Chatbot")
st.write("Upload a PDF file to get started")
uploaded_file = st.file_uploader("Choose a file", type=["pdf"])
if uploaded_file is not None:
    qa_chain = main_process(uploaded_file)
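    # Note: Streamlit reruns this whole script on every interaction, so the
    # index and model are rebuilt for each message; caching main_process with
    # st.cache_resource would avoid the repeated work.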
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Accept user input
if prompt := st.chat_input("What is up?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message in chat message container
with st.chat_message("user"):
st.markdown(prompt)
# Get response from chatbot
with st.chat_message("assitant"):
response = qa_chain(prompt)
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})