import time
import streamlit as st
from htmlTemplates import css, bot_template, user_template
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from pdfminer.high_level import extract_text
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


# Persistence location and embedding model for the Chroma vector store
persist_directory = 'db'
embeddings_model_name = 'sentence-transformers/all-MiniLM-L6-v2'

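# Extract the raw text from a PDF. pdfminer's extract_text accepts both
# file paths and file-like objects, so Streamlit's UploadedFile works directly.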
def get_pdf_text(pdf_path):
    return extract_text(pdf_path)

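# Split the text into overlapping ~1000-character chunks so each retrieved
# passage stays small enough to fit alongside the question in the model's
# context window.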
def get_pdf_text_chunks(pdf_text):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    return text_splitter.split_text(text=pdf_text)

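# Embed the chunks and persist them to a local Chroma store under `db/`.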
def create_vector_store(target_source_chunks):
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma.from_texts(texts=target_source_chunks, persist_directory=persist_directory, embedding=embeddings)
    db.persist()
    return db

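# Reopen the persisted Chroma store without re-embedding and expose it as a
# retriever. Note: main() currently rebuilds the store on each upload instead
# of calling this, so it is kept for reloading an existing `db/` directory.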
def get_vector_store(target_source_chunks):
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    return retriever

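# Build the question-answering chain. A raw transformers model is not a
# LangChain LLM, so it is wrapped in a text-generation pipeline first.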
def get_conversation_chain(retriever):
    tokenizer = AutoTokenizer.from_pretrained("TinyPixel/Llama-2-7B-bf16-sharded")
    model = AutoModelForCausalLM.from_pretrained("TinyPixel/Llama-2-7B-bf16-sharded")
    # max_new_tokens=512 is an assumed default; tune it to your hardware.
    pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, max_new_tokens=512)
    llm = HuggingFacePipeline(pipeline=pipe)
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    chain = RetrievalQA.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
    )
    return chain


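# Answer a question against the loaded chain and render the accumulated
# chat history, alternating the user and bot message templates.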
def handle_userinput(user_question):
    if st.session_state.conversation is None:
        st.warning("Please load the Vectorstore first!")
        return
    else:
        with st.spinner('Thinking...'):
            start_time = time.time()
            response = st.session_state.conversation({'query': user_question})
            end_time = time.time()

            st.session_state.chat_history = response['chat_history']

            for i, message in enumerate(st.session_state.chat_history):
                if i % 2 == 0:
                    st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
                else:
                    st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)

            st.write('Elapsed time: {:.2f} seconds'.format(end_time - start_time))
            st.balloons()




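# Page layout: sidebar sliders for retrieval/model parameters, a text input
# for questions, and a PDF uploader that builds the vector store on demand.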
def main():

    st.set_page_config(page_title='Java Copilot :coffee:', page_icon=':rocket:', layout='wide')
    # st.sidebar.title() returns a text element, not a context manager,
    # so the sidebar widgets are attached to st.sidebar directly.
    st.sidebar.title(':gear: Parameters')
    # model_n_ctx and model_n_batch are collected for future model tuning
    # but are not yet wired into the chain.
    model_n_ctx = st.sidebar.slider('Model N_CTX', min_value=128, max_value=2048, value=1024, step=2)
    model_n_batch = st.sidebar.slider('Model N_BATCH', min_value=1, max_value=model_n_ctx, value=512, step=2)
    target_source_chunks = st.sidebar.slider('Target Source Chunks', min_value=1, max_value=10, value=4, step=1)
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header('Java Copilot :coffee:')
    st.subheader('Upload your PDF file and start chatting with it!')
    user_question = st.text_input('Enter your message here:')
    pdf_file = st.file_uploader("Upload PDF", type=['pdf'])
    if st.button('Start Chain'):
        with st.spinner('Work in progress...'):
            if pdf_file is not None:
                pdf_text = get_pdf_text(pdf_file)
                pdf_text_chunks = get_pdf_text_chunks(pdf_text)
                st.session_state.vector_store = create_vector_store(pdf_text_chunks)
                # Hand the chain a retriever (not the raw Chroma store),
                # capped at the configured number of source chunks per query.
                st.session_state.conversation = get_conversation_chain(
                    retriever=st.session_state.vector_store.as_retriever(
                        search_kwargs={"k": target_source_chunks},
                    ),
                )
                st.success('Vectorstore created successfully! You can start chatting now!')
            else:
                st.warning('Please upload a PDF file first!')


    if user_question:
        handle_userinput(user_question)


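# Run the app with `streamlit run app.py` (assuming this file is saved as app.py).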
if __name__ == '__main__':
    main()