CamiloVega commited on
Commit
f498b40
ยท
verified ยท
1 Parent(s): 11822c7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Dict
4
+ import torch
5
+ import gradio as gr
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.vectorstores import FAISS
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.llms import HuggingFacePipeline
12
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
13
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(levelname)s - %(message)s'
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+ MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
23
+ UPLOAD_FOLDER = "uploaded_docs"
24
+
25
+ class DocumentManager:
26
+ """Class to manage document uploads and processing."""
27
+
28
+ def __init__(self):
29
+ self.upload_folder = UPLOAD_FOLDER
30
+ os.makedirs(self.upload_folder, exist_ok=True)
31
+ self.max_files = 5
32
+ self.max_file_size = 10 * 1024 * 1024 # 10 MB
33
+ self.supported_formats = ['.pdf', '.txt', '.docx']
34
+ self.documents = []
35
+
36
+ def validate_file(self, file):
37
+ if os.path.getsize(file.name) > self.max_file_size:
38
+ raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
39
+
40
+ ext = os.path.splitext(file.name)[1].lower()
41
+ if ext not in self.supported_formats:
42
+ raise ValueError(f"Unsupported file format. Supported formats: {', '.join(self.supported_formats)}")
43
+
44
+ def load_document(self, file_path: str) -> List:
45
+ ext = os.path.splitext(file_path)[1].lower()
46
+ try:
47
+ if ext == '.pdf':
48
+ loader = PyPDFLoader(file_path)
49
+ elif ext == '.txt':
50
+ loader = TextLoader(file_path)
51
+ elif ext == '.docx':
52
+ loader = Docx2txtLoader(file_path)
53
+ else:
54
+ raise ValueError(f"Unsupported file format: {ext}")
55
+
56
+ documents = loader.load()
57
+ for doc in documents:
58
+ doc.metadata.update({
59
+ 'source': os.path.basename(file_path),
60
+ 'type': 'uploaded'
61
+ })
62
+ return documents
63
+
64
+ except Exception as e:
65
+ logger.error(f"Error loading {file_path}: {str(e)}")
66
+ raise
67
+
68
+ def process_upload(self, files: List) -> str:
69
+ if len(os.listdir(self.upload_folder)) + len(files) > self.max_files:
70
+ raise ValueError(f"Maximum number of documents ({self.max_files}) exceeded")
71
+
72
+ processed_files = []
73
+ for file in files:
74
+ try:
75
+ self.validate_file(file)
76
+ save_path = os.path.join(self.upload_folder, file.name)
77
+ file.save(save_path)
78
+ docs = self.load_document(save_path)
79
+ self.documents.extend(docs)
80
+ processed_files.append(file.name)
81
+ except Exception as e:
82
+ logger.error(f"Error processing {file.name}: {str(e)}")
83
+ return f"Error processing {file.name}: {str(e)}"
84
+
85
+ return f"Successfully processed files: {', '.join(processed_files)}"
86
+
87
+ class RAGSystem:
88
+ """Main RAG system class."""
89
+
90
+ def __init__(self, model_name: str = MODEL_NAME):
91
+ self.model_name = model_name
92
+ self.document_manager = DocumentManager()
93
+ self.embeddings = None
94
+ self.vector_store = None
95
+ self.qa_chain = None
96
+ self.is_initialized = False
97
+
98
+ def initialize_system(self, documents: List = None):
99
+ """Initialize RAG system with provided documents."""
100
+ try:
101
+ if not documents:
102
+ raise ValueError("No documents provided for initialization")
103
+
104
+ # Initialize text splitter
105
+ text_splitter = RecursiveCharacterTextSplitter(
106
+ chunk_size=500,
107
+ chunk_overlap=50,
108
+ separators=["\n\n", "\n", ". ", " ", ""]
109
+ )
110
+
111
+ # Process documents
112
+ chunks = text_splitter.split_documents(documents)
113
+
114
+ # Initialize embeddings
115
+ self.embeddings = HuggingFaceEmbeddings(
116
+ model_name="intfloat/multilingual-e5-large",
117
+ model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
118
+ )
119
+
120
+ # Create vector store
121
+ self.vector_store = FAISS.from_documents(chunks, self.embeddings)
122
+
123
+ # Initialize LLM pipeline
124
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
125
+ model = AutoModelForCausalLM.from_pretrained(
126
+ self.model_name,
127
+ torch_dtype=torch.float16,
128
+ device_map="auto"
129
+ )
130
+
131
+ pipe = pipeline(
132
+ "text-generation",
133
+ model=model,
134
+ tokenizer=tokenizer,
135
+ max_new_tokens=512,
136
+ temperature=0.1,
137
+ device_map="auto"
138
+ )
139
+
140
+ llm = HuggingFacePipeline(pipeline=pipe)
141
+
142
+ # Create prompt template
143
+ prompt_template = """
144
+ Context: {context}
145
+
146
+ Based on the context above, please provide a clear and concise answer to the following question.
147
+ If the information is not in the context, explicitly state so.
148
+
149
+ Question: {question}
150
+ """
151
+
152
+ PROMPT = PromptTemplate(
153
+ template=prompt_template,
154
+ input_variables=["context", "question"]
155
+ )
156
+
157
+ # Set up QA chain
158
+ self.qa_chain = RetrievalQA.from_chain_type(
159
+ llm=llm,
160
+ chain_type="stuff",
161
+ retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
162
+ return_source_documents=True,
163
+ chain_type_kwargs={"prompt": PROMPT}
164
+ )
165
+
166
+ self.is_initialized = True
167
+ return "System initialized successfully"
168
+
169
+ except Exception as e:
170
+ logger.error(f"Error during system initialization: {str(e)}")
171
+ return f"Error: {str(e)}"
172
+
173
+ def generate_response(self, question: str) -> Dict:
174
+ """Generate response for a given question."""
175
+ if not self.is_initialized:
176
+ return {"error": "System not initialized. Please upload documents first."}
177
+
178
+ try:
179
+ result = self.qa_chain({"query": question})
180
+
181
+ response = {
182
+ 'answer': result['result'],
183
+ 'sources': []
184
+ }
185
+
186
+ for doc in result['source_documents']:
187
+ source = {
188
+ 'title': doc.metadata.get('source', 'Unknown'),
189
+ 'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
190
+ }
191
+ response['sources'].append(source)
192
+
193
+ return response
194
+
195
+ except Exception as e:
196
+ logger.error(f"Error generating response: {str(e)}")
197
+ return {"error": str(e)}
198
+
199
+ # Initialize RAG system
200
+ rag_system = RAGSystem()
201
+
202
+ def process_file_upload(files):
203
+ """Handle file uploads and system initialization."""
204
+ try:
205
+ upload_result = rag_system.document_manager.process_upload(files)
206
+ if "Error" in upload_result:
207
+ return upload_result
208
+
209
+ init_result = rag_system.initialize_system(rag_system.document_manager.documents)
210
+ return f"{upload_result}\n{init_result}"
211
+ except Exception as e:
212
+ return f"Error: {str(e)}"
213
+
214
+ def process_query(message, history):
215
+ """Process user query and generate response."""
216
+ try:
217
+ if not rag_system.is_initialized:
218
+ return history + [(message, "Please upload documents first.")]
219
+
220
+ response = rag_system.generate_response(message)
221
+ if "error" in response:
222
+ return history + [(message, f"Error: {response['error']}")]
223
+
224
+ answer = response['answer']
225
+ sources = set([source['title'] for source in response['sources']])
226
+ if sources:
227
+ answer += "\n\n๐Ÿ“š Sources:\n" + "\n".join([f"โ€ข {source}" for source in sources])
228
+
229
+ return history + [(message, answer)]
230
+ except Exception as e:
231
+ return history + [(message, f"Error: {str(e)}")]
232
+
233
+ # Create Gradio interface
234
+ demo = gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}")
235
+
236
+ with demo:
237
+ gr.HTML("""
238
+ <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
239
+ <h1 style="color: #2d333a;">๐Ÿค– Easy RAG</h1>
240
+ <p style="color: #4a5568;">
241
+ A simple and powerful RAG system for your documents
242
+ </p>
243
+ </div>
244
+ """)
245
+
246
+ with gr.Row():
247
+ file_output = gr.File(
248
+ file_count="multiple",
249
+ label="Upload Documents (PDF, TXT, DOCX - Max 5 files, 10MB each)"
250
+ )
251
+
252
+ upload_button = gr.Button("Upload and Initialize")
253
+ system_output = gr.Textbox(label="System Status")
254
+
255
+ chatbot = gr.Chatbot(
256
+ show_label=False,
257
+ container=True,
258
+ height=400,
259
+ show_copy_button=True
260
+ )
261
+
262
+ with gr.Row():
263
+ message = gr.Textbox(
264
+ placeholder="Ask a question about your documents...",
265
+ show_label=False,
266
+ container=False,
267
+ scale=8
268
+ )
269
+ clear = gr.Button("๐Ÿ—‘๏ธ Clear", size="sm", scale=1)
270
+
271
+ gr.HTML("""
272
+ <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
273
+ background-color: #f8f9fa; border-radius: 10px;">
274
+ <div style="margin-bottom: 15px;">
275
+ <h3 style="color: #2d333a;">๐Ÿ” About Easy RAG</h3>
276
+ <p style="color: #666; font-size: 14px;">
277
+ A powerful RAG system that lets you query your documents using:
278
+ </p>
279
+ <ul style="list-style: none; color: #666; font-size: 14px;">
280
+ <li>๐Ÿ”น LLM: Llama-2-7b-chat-hf</li>
281
+ <li>๐Ÿ”น Embeddings: multilingual-e5-large</li>
282
+ <li>๐Ÿ”น Vector Store: FAISS</li>
283
+ </ul>
284
+ </div>
285
+ <div style="border-top: 1px solid #ddd; padding-top: 15px;">
286
+ <p style="color: #666; font-size: 14px;">
287
+ Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
288
+ target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
289
+ </p>
290
+ </div>
291
+ </div>
292
+ """)
293
+
294
+ # Set up event handlers
295
+ upload_button.click(
296
+ process_file_upload,
297
+ inputs=[file_output],
298
+ outputs=[system_output]
299
+ )
300
+
301
+ message.submit(
302
+ process_query,
303
+ inputs=[message, chatbot],
304
+ outputs=[chatbot]
305
+ )
306
+
307
+ clear.click(lambda: None, None, chatbot)
308
+
309
+ demo.launch()