CamiloVega committed on
Commit 6cc39b5 · verified · 1 Parent(s): b765cd5

Delete app.py

Files changed (1)
  1. app.py +0 -408
app.py DELETED
@@ -1,408 +0,0 @@
- import os
- import shutil
- import logging
- from typing import List, Dict
- import torch
- import gradio as gr
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain.chains import RetrievalQA
- from langchain.prompts import PromptTemplate
- from langchain_community.llms import HuggingFacePipeline
- from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import login
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- # Constants
- MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
- UPLOAD_FOLDER = "uploaded_docs"
- EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
-
- class RAGSystem:
-     """Main RAG system class."""
-
-     def __init__(self):
-         # Initialize device
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         logger.info(f"Using device: {self.device}")
-
-         # Initialize folders
-         self.upload_folder = UPLOAD_FOLDER
-         if os.path.exists(self.upload_folder):
-             shutil.rmtree(self.upload_folder)
-         os.makedirs(self.upload_folder, exist_ok=True)
-
-         # Set limits
-         self.max_files = 5
-         self.max_file_size = 10 * 1024 * 1024  # 10 MB
-         self.supported_formats = ['.pdf', '.txt', '.docx']
-
-         # Initialize components
-         self.embeddings = None
-         self.vector_store = None
-         self.qa_chain = None
-         self.documents = []
-
-         # Initialize embeddings
-         self.initialize_embeddings()
-
-     def initialize_embeddings(self):
-         """Initialize embedding model."""
-         try:
-             self.embeddings = HuggingFaceEmbeddings(
-                 model_name=EMBEDDING_MODEL,
-                 model_kwargs={'device': self.device},
-                 encode_kwargs={'normalize_embeddings': True}
-             )
-             logger.info(f"Embeddings initialized successfully on {self.device}")
-         except Exception as e:
-             logger.error(f"Error initializing embeddings: {str(e)}")
-             raise
-
-     def validate_file(self, file_path: str, file_size: int) -> bool:
-         """Validate uploaded file."""
-         if file_size > self.max_file_size:
-             raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
-
-         ext = os.path.splitext(file_path)[1].lower()
-         if ext not in self.supported_formats:
-             raise ValueError(f"Unsupported format. Supported: {', '.join(self.supported_formats)}")
-         return True
-
-     def process_file(self, file: gr.File) -> List:
-         """Process a single file and return documents."""
-         try:
-             file_path = file.name
-             file_size = os.path.getsize(file_path)
-             self.validate_file(file_path, file_size)
-
-             # Copy file to upload directory
-             filename = os.path.basename(file_path)
-             save_path = os.path.join(self.upload_folder, filename)
-             shutil.copy2(file_path, save_path)
-
-             # Load documents based on file type
-             ext = os.path.splitext(file_path)[1].lower()
-             if ext == '.pdf':
-                 loader = PyPDFLoader(save_path)
-             elif ext == '.txt':
-                 loader = TextLoader(save_path)
-             else:  # .docx
-                 loader = Docx2txtLoader(save_path)
-
-             documents = loader.load()
-             for doc in documents:
-                 doc.metadata.update({
-                     'source': filename,
-                     'type': 'uploaded'
-                 })
-             return documents
-
-         except Exception as e:
-             logger.error(f"Error processing {file_path}: {str(e)}")
-             raise
-
-     def update_vector_store(self, new_documents: List):
-         """Update vector store with new documents."""
-         try:
-             text_splitter = RecursiveCharacterTextSplitter(
-                 chunk_size=500,
-                 chunk_overlap=50,
-                 separators=["\n\n", "\n", ". ", " ", ""]
-             )
-             chunks = text_splitter.split_documents(new_documents)
-
-             if self.vector_store is None:
-                 self.vector_store = FAISS.from_documents(chunks, self.embeddings)
-             else:
-                 self.vector_store.add_documents(chunks)
-
-             logger.info(f"Vector store updated with {len(chunks)} chunks")
-
-         except Exception as e:
-             logger.error(f"Error updating vector store: {str(e)}")
-             raise
-
-     def initialize_llm(self):
-         """Initialize the language model and QA chain."""
-         try:
-             # Get Hugging Face token
-             hf_token = os.environ.get('HUGGINGFACE_TOKEN')
-             if not hf_token:
-                 raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")
-
-             # Login to Hugging Face
-             login(token=hf_token)
-
-             # Initialize model and tokenizer
-             tokenizer = AutoTokenizer.from_pretrained(
-                 MODEL_NAME,
-                 token=hf_token,
-                 trust_remote_code=True
-             )
-
-             # Configure model loading based on device
-             model_config = {
-                 'device_map': 'auto',
-                 'trust_remote_code': True,
-                 'token': hf_token
-             }
-
-             if self.device == "cuda":
-                 model_config['torch_dtype'] = torch.float16
-             else:
-                 model_config['low_cpu_mem_usage'] = True
-
-             model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_config)
-
-             # Create pipeline
-             pipe = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=tokenizer,
-                 max_new_tokens=512,
-                 temperature=0.1,
-                 device_map="auto"
-             )
-
-             llm = HuggingFacePipeline(pipeline=pipe)
-
-             # Create prompt template
-             prompt_template = """
-             Context: {context}
-
-             Based on the context above, please provide a clear and concise answer to the following question.
-             If the information is not in the context, explicitly state so.
-
-             Question: {question}
-             """
-
-             PROMPT = PromptTemplate(
-                 template=prompt_template,
-                 input_variables=["context", "question"]
-             )
-
-             self.qa_chain = RetrievalQA.from_chain_type(
-                 llm=llm,
-                 chain_type="stuff",
-                 retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
-                 return_source_documents=True,
-                 chain_type_kwargs={"prompt": PROMPT}
-             )
-
-             logger.info("LLM initialized successfully")
-
-         except Exception as e:
-             logger.error(f"Error initializing LLM: {str(e)}")
-             raise
-
-     def process_upload(self, files: List[gr.File]) -> str:
-         """Process uploaded files and initialize/update the system."""
-         if not files:
-             return "Please select files to upload."
-
-         try:
-             current_files = len(os.listdir(self.upload_folder))
-             if current_files + len(files) > self.max_files:
-                 return f"Maximum number of documents ({self.max_files}) exceeded"
-
-             processed_files = []
-             new_documents = []
-             for file in files:
-                 documents = self.process_file(file)
-                 new_documents.extend(documents)
-                 processed_files.append(os.path.basename(file.name))
-
-             self.update_vector_store(new_documents)
-             self.documents.extend(new_documents)
-
-             if self.qa_chain is None:
-                 self.initialize_llm()
-
-             return f"Successfully processed: {', '.join(processed_files)}"
-
-         except Exception as e:
-             return f"Error: {str(e)}"
-
-     def generate_response(self, question: str) -> Dict:
-         """Generate response for a given question."""
-         if not self.qa_chain:
-             return {"error": "System not initialized. Please upload documents first."}
-
-         try:
-             result = self.qa_chain({"query": question})
-
-             response = {
-                 'answer': result['result'],
-                 'sources': []
-             }
-
-             for doc in result['source_documents']:
-                 source = {
-                     'title': doc.metadata.get('source', 'Unknown'),
-                     'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
-                 }
-                 response['sources'].append(source)
-
-             return response
-
-         except Exception as e:
-             logger.error(f"Error generating response: {str(e)}")
-             return {"error": str(e)}
-
- # Initialize system
- rag_system = RAGSystem()
-
- def process_query(message: str, history: List) -> List:
-     """Process user query and return updated history."""
-     try:
-         if not rag_system.qa_chain:
-             return history + [(message, "Please upload documents first.")]
-
-         response = rag_system.generate_response(message)
-         if "error" in response:
-             return history + [(message, f"Error: {response['error']}")]
-
-         answer = response['answer']
-         sources = set([source['title'] for source in response['sources']])
-         if sources:
-             answer += "\n\n📚 Sources:\n" + "\n".join([f"• {source}" for source in sources])
-
-         return history + [(message, answer)]
-     except Exception as e:
-         return history + [(message, f"Error: {str(e)}")]
-
- # Create Gradio interface
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.HTML("""
-         <div style="text-align: center; margin-bottom: 1rem;">
-             <h1 style="color: #2d333a;">🤖 Easy RAG</h1>
-             <p style="color: #4a5568;">A simple and powerful RAG system for your documents</p>
-         </div>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             with gr.Group():
-                 gr.HTML("""
-                     <div style="padding: 1rem; border: 1px solid #e5e7eb; border-radius: 0.5rem; background-color: white;">
-                         <h3 style="margin-top: 0;">📁 Upload Documents</h3>
-                 """)
-                 file_output = gr.File(
-                     file_count="multiple",
-                     label="Select Files",
-                     elem_id="file-upload"
-                 )
-                 gr.HTML("""
-                     <div style="font-size: 0.8em; color: #666;">
-                         <p>• Maximum 5 files</p>
-                         <p>• 10MB per file</p>
-                         <p>• Supported: PDF, TXT, DOCX</p>
-                     </div>
-                 """)
-                 system_output = gr.Textbox(
-                     label="Status",
-                     interactive=False
-                 )
-                 gr.HTML("</div>")
-
-         with gr.Column(scale=3):
-             chatbot = gr.Chatbot(
-                 value=[],
-                 label="Chat",
-                 height=600,
-                 show_copy_button=True
-             )
-
-             with gr.Row():
-                 message = gr.Textbox(
-                     placeholder="Ask a question about your documents...",
-                     show_label=False,
-                     container=False,
-                     scale=8
-                 )
-                 clear = gr.Button("🗑️", size="sm", scale=1)
-
-     gr.HTML("""
-         <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 1rem;
-                     background-color: #f8f9fa; border-radius: 10px;">
-             <div style="margin-bottom: 1rem;">
-                 <h3 style="color: #2d333a;">🔍 About Easy RAG</h3>
-                 <p style="color: #666; font-size: 0.9em;">
-                     Powered by state-of-the-art AI technology:
-                 </p>
-                 <ul style="list-style: none; color: #666; font-size: 0.9em;">
-                     <li>🔹 LLM: Llama-2-7b-chat-hf</li>
-                     <li>🔹 Embeddings: multilingual-e5-large</li>
-                     <li>🔹 Vector Store: FAISS</li>
-                 </ul>
-             </div>
-             <div style="border-top: 1px solid #ddd; padding-top: 1rem;">
-                 <p style="color: #666; font-size: 0.8em;">
-                     Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
-                        target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
-                 </p>
-             </div>
-         </div>
-     """)
-
-     # Set up event handlers
-     file_output.upload(
-         rag_system.process_upload,
-         inputs=[file_output],
-         outputs=[system_output]
-     )
-
-     message.submit(
-         process_query,
-         inputs=[message, chatbot],
-         outputs=[chatbot]
-     )
-
-     clear.click(lambda: None, None, chatbot)
-
- if __name__ == "__main__":
-     # Log system information
-     logger.info("Starting Easy RAG system...")
-     logger.info(f"PyTorch version: {torch.__version__}")
-     logger.info(f"CUDA available: {torch.cuda.is_available()}")
-     if torch.cuda.is_available():
-         logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
-     else:
-         logger.info("Running on CPU mode with optimizations")
-
-     # Check for HUGGINGFACE_TOKEN
-     if not os.environ.get('HUGGINGFACE_TOKEN'):
-         logger.warning("HUGGINGFACE_TOKEN not found in environment variables")
-         logger.warning("Please set it before running the application")
-         print("Please set your HUGGINGFACE_TOKEN environment variable")
-         print("Example: export HUGGINGFACE_TOKEN=your_token_here")
-         exit(1)
-
-     # Get sharing preference from environment
-     share_enabled = os.environ.get('SHARE_APP', 'false').lower() == 'true'
-     if share_enabled:
-         logger.info("Public sharing is enabled - a public URL will be generated")
-
-     try:
-         # Launch the application
-         demo.launch(
-             server_name="0.0.0.0",  # Listen on all network interfaces
-             server_port=7860,       # Default Gradio port
-             share=share_enabled,    # Generate public URL if enabled
-             show_error=True,        # Show detailed error messages
-             quiet=True              # Reduce console output noise
-         )
-     except KeyboardInterrupt:
-         logger.info("Shutting down server...")
-     except Exception as e:
-         logger.error(f"Error launching server: {str(e)}")
-         raise