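"""Streamlit app for multi-tenant question answering over uploaded PDFs.

PDFs are parsed with PDFMiner, split into chunks, embedded with a
sentence-transformers model, and stored in a per-tenant Chroma vector store.
Questions are answered by a Hugging Face causal LM prompted with the
retrieved chunks, answering strictly from the documents.
"""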
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# NOTE: these import paths target the pre-0.1.x langchain package layout; on
# newer releases the loaders, embeddings, and vector stores live under
# langchain_community, and the splitter under langchain_text_splitters.
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
import os

# Initialize session state: per-tenant vector stores, the model/tokenizer, and the active tenant
if 'vectordb' not in st.session_state:
    st.session_state.vectordb = {}
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None
if 'tenant' not in st.session_state:
    st.session_state.tenant = "default_tenant"  # Default tenant

st.title("PDF Question Answering System")

# Tenant selection
st.sidebar.title("Settings")
tenant = st.sidebar.text_input("Enter your tenant:", value=st.session_state.tenant)
st.session_state.tenant = tenant  # Update the tenant in session state

# File uploader for PDFs
def load_pdfs():
    uploaded_files = st.file_uploader("Upload your PDF files", type=['pdf'], accept_multiple_files=True)
    if uploaded_files and st.button("Process PDFs"):
        with st.spinner("Processing PDFs..."):
            # Save uploaded files temporarily
            temp_paths = []
            for file in uploaded_files:
                temp_path = f"temp_{file.name}"
                with open(temp_path, "wb") as f:
                    f.write(file.getbuffer())
                temp_paths.append(temp_path)

            # Load PDFs
            documents = []
            for pdf_path in temp_paths:
                loader = PDFMinerLoader(pdf_path)
                docs = loader.load()
                for d in docs:
                    d.metadata["source"] = pdf_path
                documents.extend(docs)

            # Clean up temporary files
            for path in temp_paths:
                os.remove(path)

            # Split documents
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            splits = text_splitter.split_documents(documents)

            # Create embeddings and vector store for the current tenant
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

            # Give each tenant its own persist directory so tenants do not
            # share (and silently mix) a single Chroma collection
            db_directory = os.path.join("path/to/store/chroma_db", st.session_state.tenant)  # Update the base path

            if st.session_state.tenant not in st.session_state.vectordb:
                st.session_state.vectordb[st.session_state.tenant] = Chroma.from_documents(
                    documents=splits,
                    embedding=embeddings,
                    persist_directory=db_directory
                )
            else:
                # Update the existing vector store for the tenant
                st.session_state.vectordb[st.session_state.tenant].add_documents(splits)

            st.success("PDFs processed successfully!")
            return True
    return False

# Load model and tokenizer
@st.cache_resource
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",  # place the model on GPU when one is available (requires accelerate)
    )
    model.eval()
    return model, tokenizer

def generate_response(prompt, model, tokenizer, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # without this, temperature/top_p are ignored under greedy decoding
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.15,
        )
    # Decode only the newly generated tokens: slicing the decoded string by
    # len(prompt) breaks whenever the tokenizer does not round-trip the prompt
    # exactly.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def combine_documents_and_answer(retrieved_docs, question, model, tokenizer):
    context = "\n".join(doc.page_content for doc in retrieved_docs)
    prompt = f"""You are an assistant tasked with answering questions based SOLELY on the provided context.
Do not use any external knowledge or information not present in the given context.
If the question is unrelated to the provided context, respond only with "I can't tell you this, ask something from the provided context."
DO NOT INCLUDE YOUR OWN OPINION. IMPORTANT: Your answer should be well structured and meaningful.
Your answer should cover every detail mentioned in the context. Answer the following question in detail, using only the context:
Question: {question}
Context:
{context}
Answer:"""
    return generate_response(prompt, model, tokenizer)
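
# Note on prompt size (an estimate, not enforced in code): with k=4 retrieved
# chunks of ~1,000 characters each, the stuffed context is roughly 4,000
# characters, which fits comfortably within the context window of typical
# small causal LMs.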

# Main app logic
def main():
    if torch.cuda.is_available():
        st.sidebar.success("GPU is available!")
    else:
        st.sidebar.warning("GPU is not available. This app may run slowly on CPU.")

    # Model path input
    model_path = st.sidebar.text_input("Enter a Hugging Face model ID or local path:",
                                       placeholder="waqasali1707/llama_3.2_3B_4_bit_Quan")

    # Load PDFs first
    if st.session_state.tenant not in st.session_state.vectordb:
        pdfs_processed = load_pdfs()
        if not pdfs_processed:
            st.info("Please upload PDF files and click 'Process PDFs' to continue.")
            return
    
    # Load model if path is provided and model isn't loaded
    if model_path and st.session_state.model is None:
        with st.spinner("Loading model..."):
            try:
                st.session_state.model, st.session_state.tokenizer = load_model(model_path)
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Question answering interface
    if st.session_state.tenant in st.session_state.vectordb and st.session_state.model is not None:
        question = st.text_area("Enter your question:", height=100)
        
        if st.button("Get Answer"):
            if question:
                with st.spinner("Generating answer..."):
                    try:
                        # Get relevant documents
                        retriever = st.session_state.vectordb[st.session_state.tenant].as_retriever(search_kwargs={"k": 4})
                        retrieved_docs = retriever.get_relevant_documents(question)
                        
                        # Generate answer
                        answer = combine_documents_and_answer(
                            retrieved_docs, 
                            question,
                            st.session_state.model,
                            st.session_state.tokenizer
                        )
                        
                        # Display answer
                        st.subheader("Answer:")
                        st.write(answer)
                        
                        # Display sources
                        st.subheader("Sources:")
                        sources = set(doc.metadata["source"] for doc in retrieved_docs)
                        for source in sources:
                            st.write(f"- {os.path.basename(source)}")
                            
                    except Exception as e:
                        st.error(f"Error generating answer: {str(e)}")
            else:
                st.warning("Please enter a question.")

if __name__ == "__main__":
    main()
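
# To try the app locally (file name assumed; adjust to your own):
#   pip install streamlit torch transformers accelerate langchain chromadb \
#       sentence-transformers pdfminer.six
#   streamlit run app.py
# Upload one or more PDFs, click "Process PDFs", enter a causal-LM checkpoint
# in the sidebar, and ask questions about the documents.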