Create app.py
app.py
ADDED
@@ -0,0 +1,165 @@
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
import os

# Initialize session state for storing the vector database
if 'vectordb' not in st.session_state:
    st.session_state.vectordb = None
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None

st.title("PDF Question Answering System")

# File uploader for PDFs
def load_pdfs():
    uploaded_files = st.file_uploader("Upload your PDF files", type=['pdf'], accept_multiple_files=True)
    if uploaded_files and st.button("Process PDFs"):
        with st.spinner("Processing PDFs..."):
            # Save uploaded files temporarily
            temp_paths = []
            for file in uploaded_files:
                temp_path = f"temp_{file.name}"
                with open(temp_path, "wb") as f:
                    f.write(file.getbuffer())
                temp_paths.append(temp_path)

            # Load PDFs
            documents = []
            for pdf_path in temp_paths:
                loader = PDFMinerLoader(pdf_path)
                doc = loader.load()
                for d in doc:
                    d.metadata["source"] = pdf_path
                documents.extend(doc)

            # Clean up temporary files
            for path in temp_paths:
                os.remove(path)

            # Split documents
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            splits = text_splitter.split_documents(documents)

            # Create embeddings and vector store
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
            st.session_state.vectordb = Chroma.from_documents(documents=splits, embedding=embeddings)

            st.success("PDFs processed successfully!")
            return True
    return False

# Load model and tokenizer
@st.cache_resource
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",  # place the model on the GPU when one is available (requires accelerate)
    )
    model.eval()
    return model, tokenizer

def generate_response(prompt, model, tokenizer, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.15
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

def combine_documents_and_answer(retrieved_docs, question, model, tokenizer):
    context = "\n".join(doc.page_content for doc in retrieved_docs)
    prompt = f"""You are an assistant tasked with answering questions based SOLELY on the provided context.
Do not use any external knowledge or information not present in the given context.
If the question is unrelated to the provided context, respond only with "I can't tell you this, ask something from the provided context." DO NOT INCLUDE YOUR OWN OPINION.

IMPORTANT: Your answer should be well structured and meaningful. Stop once the answer is complete; do not generate or repeat nonsensical sentences.
Your answer should cover every detail mentioned in the context.
Now answer the following question from the context in detail:

Question: {question}

Context:
{context}

Answer:"""
    return generate_response(prompt, model, tokenizer)

# Main app logic
def main():
    if torch.cuda.is_available():
        st.sidebar.success("GPU is available!")
    else:
        st.sidebar.warning("GPU is not available. This app may run slowly on CPU.")

    # Model path input
    model_path = st.sidebar.text_input("Enter the path to your model:",
                                       placeholder="waqasali1707/llama_3.2_3B_4_bit_Quan")

    # Load PDFs first
    if st.session_state.vectordb is None:
        pdfs_processed = load_pdfs()
        if not pdfs_processed:
            st.info("Please upload PDF files and click 'Process PDFs' to continue.")
            return

    # Load model if path is provided and model isn't loaded
    if model_path and st.session_state.model is None:
        with st.spinner("Loading model..."):
            try:
                st.session_state.model, st.session_state.tokenizer = load_model(model_path)
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Question answering interface
    if st.session_state.vectordb is not None and st.session_state.model is not None:
        question = st.text_area("Enter your question:", height=100)

        if st.button("Get Answer"):
            if question:
                with st.spinner("Generating answer..."):
                    try:
                        # Get relevant documents
                        retriever = st.session_state.vectordb.as_retriever(search_kwargs={"k": 4})
                        retrieved_docs = retriever.get_relevant_documents(question)

                        # Generate answer
                        answer = combine_documents_and_answer(
                            retrieved_docs,
                            question,
                            st.session_state.model,
                            st.session_state.tokenizer
                        )

                        # Display answer
                        st.subheader("Answer:")
                        st.write(answer)

                        # Display sources
                        st.subheader("Sources:")
                        sources = set(doc.metadata["source"] for doc in retrieved_docs)
                        for source in sources:
                            st.write(f"- {os.path.basename(source)}")

                    except Exception as e:
                        st.error(f"Error generating answer: {str(e)}")
            else:
                st.warning("Please enter a question.")

if __name__ == "__main__":
    main()
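
A Space running this app.py also needs its Python dependencies declared. The commit does not include a requirements.txt; the sketch below is inferred from the imports above, and the version pin and exact package set are assumptions rather than part of this change:

streamlit
torch
transformers
accelerate             # needed by from_pretrained for low_cpu_mem_usage / device_map
langchain<0.2          # the app uses the old-style langchain.* import paths
chromadb               # backend for the Chroma vector store
pdfminer.six           # backend for PDFMinerLoader
sentence-transformers  # backend for HuggingFaceEmbeddings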