CamiloVega committed
Commit 0a80de4 · verified · 1 Parent(s): 985ad05

Update app.py

Files changed (1):
  1. app.py +143 -82
app.py CHANGED
@@ -11,8 +11,9 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain_community.llms import HuggingFacePipeline
 from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from huggingface_hub import login
+import bitsandbytes as bnb
 
 # Configure logging
 logging.basicConfig(
@@ -30,11 +31,17 @@ class RAGSystem:
     """Main RAG system class."""
 
     def __init__(self):
+        # Initialize device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {self.device}")
+
+        # Initialize folders
         self.upload_folder = UPLOAD_FOLDER
         if os.path.exists(self.upload_folder):
             shutil.rmtree(self.upload_folder)
         os.makedirs(self.upload_folder, exist_ok=True)
 
+        # Set limits
         self.max_files = 5
         self.max_file_size = 10 * 1024 * 1024  # 10 MB
         self.supported_formats = ['.pdf', '.txt', '.docx']
@@ -45,7 +52,7 @@ class RAGSystem:
         self.qa_chain = None
         self.documents = []
 
-        # Initialize embeddings once
+        # Initialize embeddings
         self.initialize_embeddings()
 
     def initialize_embeddings(self):
@@ -53,12 +60,107 @@ class RAGSystem:
         try:
             self.embeddings = HuggingFaceEmbeddings(
                 model_name=EMBEDDING_MODEL,
-                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
+                model_kwargs={
+                    'device': self.device,
+                    'torch_dtype': torch.float32,
+                }
             )
+            logger.info("Embeddings initialized successfully")
         except Exception as e:
             logger.error(f"Error initializing embeddings: {str(e)}")
             raise
 
+    def initialize_llm(self):
+        """Initialize the language model and QA chain."""
+        try:
+            # Get Hugging Face token
+            hf_token = os.environ.get('HUGGINGFACE_TOKEN')
+            if not hf_token:
+                raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")
+
+            # Login to Hugging Face
+            login(token=hf_token)
+
+            # Configure model loading based on device
+            if self.device == "cuda":
+                model_config = {
+                    'torch_dtype': torch.float16,
+                    'device_map': "auto",
+                }
+            else:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float32,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_use_double_quant=True,
+                )
+                model_config = {
+                    'quantization_config': quantization_config,
+                    'device_map': "auto",
+                    'torch_dtype': torch.float32,
+                    'low_cpu_mem_usage': True,
+                }
+
+            # Initialize tokenizer and model
+            tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_NAME,
+                token=hf_token,
+                trust_remote_code=True
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                token=hf_token,
+                trust_remote_code=True,
+                **model_config
+            )
+
+            # Create pipeline
+            pipe_config = {
+                "model": model,
+                "tokenizer": tokenizer,
+                "max_new_tokens": 512,
+                "temperature": 0.1,
+                "device_map": "auto",
+                "torch_dtype": torch.float32 if self.device == "cpu" else torch.float16,
+            }
+
+            if self.device == "cpu":
+                pipe_config["model"] = pipe_config["model"].to('cpu')
+
+            pipe = pipeline("text-generation", **pipe_config)
+
+            # Create QA chain
+            llm = HuggingFacePipeline(pipeline=pipe)
+
+            prompt_template = """
+            Context: {context}
+
+            Based on the context above, please provide a clear and concise answer to the following question.
+            If the information is not in the context, explicitly state so.
+
+            Question: {question}
+            """
+
+            PROMPT = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context", "question"]
+            )
+
+            self.qa_chain = RetrievalQA.from_chain_type(
+                llm=llm,
+                chain_type="stuff",
+                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
+                return_source_documents=True,
+                chain_type_kwargs={"prompt": PROMPT}
+            )
+
+            logger.info("LLM initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Error initializing LLM: {str(e)}")
+            raise
+
     def validate_file(self, file_path: str, file_size: int) -> bool:
         """Validate uploaded file."""
         if file_size > self.max_file_size:
@@ -105,7 +207,6 @@ class RAGSystem:
     def update_vector_store(self, new_documents: List):
         """Update vector store with new documents."""
         try:
-            # Process documents
             text_splitter = RecursiveCharacterTextSplitter(
                 chunk_size=500,
                 chunk_overlap=50,
@@ -113,74 +214,17 @@ class RAGSystem:
             )
             chunks = text_splitter.split_documents(new_documents)
 
-            # Create or update vector store
             if self.vector_store is None:
                 self.vector_store = FAISS.from_documents(chunks, self.embeddings)
             else:
                 self.vector_store.add_documents(chunks)
 
+            logger.info(f"Vector store updated with {len(chunks)} chunks")
+
         except Exception as e:
             logger.error(f"Error updating vector store: {str(e)}")
             raise
 
-    def initialize_llm(self):
-        """Initialize the language model and QA chain."""
-        try:
-            # Get Hugging Face token
-            hf_token = os.environ.get('HUGGINGFACE_TOKEN')
-            if not hf_token:
-                raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")
-
-            # Login to Hugging Face
-            login(token=hf_token)
-
-            # Initialize model and tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-            model = AutoModelForCausalLM.from_pretrained(
-                MODEL_NAME,
-                torch_dtype=torch.float16,
-                device_map="auto"
-            )
-
-            # Create pipeline
-            pipe = pipeline(
-                "text-generation",
-                model=model,
-                tokenizer=tokenizer,
-                max_new_tokens=512,
-                temperature=0.1,
-                device_map="auto"
-            )
-
-            llm = HuggingFacePipeline(pipeline=pipe)
-
-            # Create QA chain
-            prompt_template = """
-            Context: {context}
-
-            Based on the context above, please provide a clear and concise answer to the following question.
-            If the information is not in the context, explicitly state so.
-
-            Question: {question}
-            """
-
-            PROMPT = PromptTemplate(
-                template=prompt_template,
-                input_variables=["context", "question"]
-            )
-
-            self.qa_chain = RetrievalQA.from_chain_type(
-                llm=llm,
-                chain_type="stuff",
-                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
-                return_source_documents=True,
-                chain_type_kwargs={"prompt": PROMPT}
-            )
-
-        except Exception as e:
-            logger.error(f"Error initializing LLM: {str(e)}")
-            raise
-
     def process_upload(self, files: List[gr.File]) -> str:
         """Process uploaded files and initialize/update the system."""
         if not files:
@@ -191,7 +235,6 @@ class RAGSystem:
             if current_files + len(files) > self.max_files:
                 return f"Maximum number of documents ({self.max_files}) exceeded"
 
-            # Process each file
             processed_files = []
             new_documents = []
             for file in files:
@@ -199,15 +242,13 @@ class RAGSystem:
                 new_documents.extend(documents)
                 processed_files.append(os.path.basename(file.name))
 
-            # Update vector store with new documents
             self.update_vector_store(new_documents)
             self.documents.extend(new_documents)
 
-            # Initialize LLM if not already initialized
            if self.qa_chain is None:
                 self.initialize_llm()
 
-            return f"Successfully processed and initialized: {', '.join(processed_files)}"
+            return f"Successfully processed: {', '.join(processed_files)}"
 
         except Exception as e:
             return f"Error: {str(e)}"
@@ -270,7 +311,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     """)
 
     with gr.Row():
-        # Sidebar for document upload
         with gr.Column(scale=1):
             with gr.Group():
                 gr.HTML("""
@@ -295,7 +335,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 )
                 gr.HTML("</div>")
 
-        # Main chat area
         with gr.Column(scale=3):
             chatbot = gr.Chatbot(
                 show_label=False,
@@ -335,17 +374,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             </div>
         </div>
     """)
-
-    # Add custom CSS
-    demo.css = """
-    .container {
-        border-radius: 0.5rem;
-        margin: 0.5rem;
-    }
-    #file-upload {
-        margin-bottom: 1rem;
-    }
-    """
 
     # Set up event handlers
     file_output.upload(
@@ -362,5 +390,38 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     clear.click(lambda: None, None, chatbot)
 
-if __name__ == "__main__":
-    demo.launch()
+if __name__ == "__main__":
+    # Log system information
+    logger.info("Starting Easy RAG system...")
+    logger.info(f"PyTorch version: {torch.__version__}")
+    logger.info(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
+    else:
+        logger.info("Running on CPU mode with optimizations")
+
+    # Check for HUGGINGFACE_TOKEN
+    if not os.environ.get('HUGGINGFACE_TOKEN'):
+        logger.warning("HUGGINGFACE_TOKEN not found in environment variables")
+        logger.warning("Please set it before running the application")
+        print("Please set your HUGGINGFACE_TOKEN environment variable")
+        print("Example: export HUGGINGFACE_TOKEN=your_token_here")
+        exit(1)
+
+    # Create upload directory if it doesn't exist
+    if not os.path.exists(UPLOAD_FOLDER):
+        os.makedirs(UPLOAD_FOLDER)
+        logger.info(f"Created upload directory: {UPLOAD_FOLDER}")
+
+    try:
+        # Launch the Gradio interface
+        demo.launch(
+            share=False,  # Set to True if you want to create a public link
+            server_name="0.0.0.0",  # Listen on all network interfaces
+            server_port=7860,  # Default Gradio port
+            show_error=True,
+            enable_queue=True
+        )
+    except Exception as e:
+        logger.error(f"Error launching Gradio interface: {str(e)}")
+        raise
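
Note on the quantized loading path introduced above: the CPU branch of initialize_llm() builds a 4-bit BitsAndBytesConfig before loading the model. The snippet below is a minimal, self-contained sketch of that loading pattern, separate from the commit itself; the checkpoint id is a hypothetical placeholder (app.py defines its own MODEL_NAME), and bitsandbytes 4-bit loading generally assumes a CUDA-capable environment.

# Minimal sketch (illustration only, not part of this commit): load a causal LM
# with the same 4-bit BitsAndBytesConfig settings used in initialize_llm().
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical placeholder checkpoint

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",  # assumes a CUDA device is available for bitsandbytes
    low_cpu_mem_usage=True,
)

# Wrap in a text-generation pipeline, as app.py does before handing it to LangChain.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=64)
print(pipe("What does RAG stand for?")[0]["generated_text"])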