import os
import time
import streamlit as st
from htmlTemplates import css, bot_template, user_template
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from pdfminer.high_level import extract_text
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline

# Configuration: Chroma persistence directory and embedding model name.
persist_directory = 'db'
embeddings_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
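# Note: all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; any
# other sentence-transformers model name could be substituted here, at the
# cost of re-embedding whatever is already persisted under 'db'.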

def get_pdf_text(pdf_path):
    """Extract raw text from a PDF; pdfminer accepts a path or a file-like object."""
    return extract_text(pdf_path)

def get_pdf_text_chunks(pdf_text):
    """Split the extracted text into overlapping chunks for embedding."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    return text_splitter.split_text(text=pdf_text)

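# Illustrative only ('sample.pdf' is a placeholder path, not part of this app):
#   chunks = get_pdf_text_chunks(get_pdf_text('sample.pdf'))
#   print(len(chunks), len(chunks[0]))  # number of chunks, length of the first
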
def create_vector_store(text_chunks):
    """Embed the text chunks and persist them to a local Chroma index."""
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma.from_texts(texts=text_chunks, persist_directory=persist_directory, embedding=embeddings)
    db.persist()
    return db

def get_vector_store(target_source_chunks):
    """Reload the persisted Chroma index and return a retriever over it."""
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    return retriever

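# Sketch of the reload path (assumes the 'db' directory from a previous run exists):
#   retriever = get_vector_store(target_source_chunks=4)
# This reopens the persisted index instead of re-embedding the PDF, which is
# useful across app restarts.
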
def get_conversation_chain(retriever):
    """Build a RetrievalQA chain backed by a local Hugging Face causal LM."""
    tokenizer = AutoTokenizer.from_pretrained("red1xe/Llama-2-7B-codeGPT")
    model = AutoModelForCausalLM.from_pretrained("red1xe/Llama-2-7B-codeGPT")
    # RetrievalQA expects a LangChain LLM, not a raw transformers model,
    # so wrap the model in a text-generation pipeline first.
    pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, max_new_tokens=512)
    llm = HuggingFacePipeline(pipeline=pipe)
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    chain = RetrievalQA.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
    )
    return chain

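# Illustrative wiring of the helpers above, outside Streamlit ('sample.pdf'
# and the question are placeholders):
#   chunks = get_pdf_text_chunks(get_pdf_text('sample.pdf'))
#   retriever = create_vector_store(chunks).as_retriever(search_kwargs={"k": 4})
#   chain = get_conversation_chain(retriever)
#   print(chain({'query': 'What does this document cover?'})['result'])
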
def handle_userinput(user_question):
    if st.session_state.conversation is None:
        st.warning("Please load the Vectorstore first!")
        return
    with st.spinner('Thinking...'):
        start_time = time.time()
        response = st.session_state.conversation({'query': user_question})
        end_time = time.time()
    # The buffer memory injects 'chat_history' into the chain inputs, so it
    # comes back in the response mapping alongside 'result'.
    st.session_state.chat_history = response['chat_history']
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
    st.write('Elapsed time: {:.2f} seconds'.format(end_time - start_time))
    st.balloons()

def main():
    st.set_page_config(page_title='Java Copilot :coffee:', page_icon=':rocket:', layout='wide')
    st.sidebar.title(':gear: Parameters')
    # Note: only target_source_chunks is consumed below; the context and batch
    # sliders are surfaced for tuning but not yet wired into the model.
    model_n_ctx = st.sidebar.slider('Model N_CTX', min_value=128, max_value=2048, value=1024, step=2)
    model_n_batch = st.sidebar.slider('Model N_BATCH', min_value=1, max_value=model_n_ctx, value=512, step=2)
    target_source_chunks = st.sidebar.slider('Target Source Chunks', min_value=1, max_value=10, value=4, step=1)
    st.write(css, unsafe_allow_html=True)
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    st.header('Java Copilot :coffee:')
    st.subheader('Upload your PDF file and start chatting with it!')
    user_question = st.text_input('Enter your message here:')
    pdf_file = st.file_uploader("Upload PDF", type=['pdf'])
    if st.button('Start Chain'):
        with st.spinner('Work in progress...'):
            if pdf_file is not None:
                pdf_text = get_pdf_text(pdf_file)
                pdf_text_chunks = get_pdf_text_chunks(pdf_text)
                st.session_state.vector_store = create_vector_store(pdf_text_chunks)
                # RetrievalQA needs a retriever, not the raw vector store,
                # so convert it here using the sidebar's k value.
                st.session_state.conversation = get_conversation_chain(
                    retriever=st.session_state.vector_store.as_retriever(
                        search_kwargs={"k": target_source_chunks}
                    ),
                )
                st.success('Vectorstore created successfully! You can start chatting now!')
            else:
                st.warning('Please upload a PDF file first!')
    if user_question:
        handle_userinput(user_question)

if __name__ == '__main__':
    main()
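
# To launch (assuming this file is saved as app.py):
#   streamlit run app.py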