Trabis committed
Commit 83bd64d · verified · 1 Parent(s): 89f5e63

Update app.py

Files changed (1)
  1. app.py +1067 -245
app.py CHANGED
@@ -1,3 +1,448 @@
1
  import gradio as gr
2
  from langchain_mistralai.chat_models import ChatMistralAI
3
  from langchain.prompts import ChatPromptTemplate
@@ -19,44 +464,63 @@ from sentence_transformers.cross_encoder import CrossEncoder
19
  import threading
20
  from queue import Queue
21
  import concurrent.futures
22
- from typing import Generator, Tuple, Iterator
23
  import time
24
 
 
25
  class OptimizedRAGLoader:
26
  def __init__(self,
27
  docs_folder: str = "./docs",
28
  splits_folder: str = "./splits",
29
  index_folder: str = "./index"):
30
-
31
  self.docs_folder = Path(docs_folder)
32
  self.splits_folder = Path(splits_folder)
33
  self.index_folder = Path(index_folder)
34
-
35
  # Create folders if they don't exist
36
  for folder in [self.splits_folder, self.index_folder]:
37
  folder.mkdir(parents=True, exist_ok=True)
38
-
39
  # File paths
40
  self.splits_path = self.splits_folder / "splits.json"
41
  self.index_path = self.index_folder / "faiss.index"
42
  self.documents_path = self.index_folder / "documents.pkl"
43
-
44
  # Initialize components
45
  self.index = None
46
  self.indexed_documents = None
47
-
48
  # Initialize encoder model
 
49
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
51
  self.encoder.to(self.device)
52
- self.reranker = model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",trust_remote_code=True)
53
-
 
 
54
  # Initialize thread pool
55
  self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
56
-
57
  # Initialize response cache
58
  self.response_cache = {}
59
-
60
  @lru_cache(maxsize=1000)
61
  def encode(self, text: str):
62
  """Cached encoding function"""
@@ -64,362 +528,720 @@ class OptimizedRAGLoader:
64
  embeddings = self.encoder.encode(
65
  text,
66
  convert_to_numpy=True,
67
- normalize_embeddings=True
 
68
  )
69
  return embeddings
70
-
71
  def batch_encode(self, texts: list):
72
  """Batch encoding for multiple texts"""
73
  with torch.no_grad():
74
  embeddings = self.encoder.encode(
75
  texts,
76
- batch_size=32,
77
  convert_to_numpy=True,
78
  normalize_embeddings=True,
79
- show_progress_bar=False
 
80
  )
81
  return embeddings
82
 
83
- def load_and_split_texts(self):
84
  if self._splits_exist():
85
- return self._load_existing_splits()
86
-
87
  documents = []
88
  futures = []
89
-
90
- for file_path in self.docs_folder.glob("*.txt"):
91
  future = self.executor.submit(self._process_file, file_path)
92
  futures.append(future)
93
-
 
94
  for future in concurrent.futures.as_completed(futures):
95
- documents.extend(future.result())
96
-
97
- self._save_splits(documents)
98
  return documents
99
-
100
- def _process_file(self, file_path):
101
- with open(file_path, 'r', encoding='utf-8') as file:
102
- text = file.read()
103
- chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
104
-
105
- return [
106
- Document(
107
- page_content=chunk,
108
- metadata={
109
- 'source': file_path.name,
110
- 'chunk_id': i,
111
- 'total_chunks': len(chunks)
112
- }
113
- )
114
- for i, chunk in enumerate(chunks)
115
- ]
116
 
117
- def load_index(self) -> bool:
118
- """
119
- Charge l'index FAISS et les documents associés s'ils existent
120
 
121
- Returns:
122
- bool: True si l'index a été chargé, False sinon
123
- """
124
  if not self._index_exists():
125
- print("Aucun index trouvé.")
126
  return False
127
 
128
- print("Chargement de l'index existant...")
129
  try:
130
- # Charger l'index FAISS
131
  self.index = faiss.read_index(str(self.index_path))
132
 
133
- # Charger les documents associés
134
  with open(self.documents_path, 'rb') as f:
135
  self.indexed_documents = pickle.load(f)
136
 
137
- print(f"Index chargé avec {self.index.ntotal} vecteurs")
138
  return True
139
 
140
  except Exception as e:
141
- print(f"Erreur lors du chargement de l'index: {e}")
142
  return False
143
 
144
- def create_index(self, documents=None):
 
145
  if documents is None:
146
  documents = self.load_and_split_texts()
147
-
148
  if not documents:
 
149
  return False
150
-
 
151
  texts = [doc.page_content for doc in documents]
 
152
  embeddings = self.batch_encode(texts)
153
-
154
  dimension = embeddings.shape[1]
155
- self.index = faiss.IndexFlatL2(dimension)
156
-
157
  if torch.cuda.is_available():
158
- # Use GPU for FAISS if available
159
- res = faiss.StandardGpuResources()
160
- self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
161
-
162
- self.index.add(np.array(embeddings).astype('float32'))
163
  self.indexed_documents = documents
164
-
165
- # Save index and documents
166
- cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
167
- faiss.write_index(cpu_index, str(self.index_path))
168
-
169
- with open(self.documents_path, 'wb') as f:
170
- pickle.dump(documents, f)
171
-
172
- return True
173
 
174
  def _index_exists(self) -> bool:
175
- """Vérifie si l'index et les documents associés existent"""
176
  return self.index_path.exists() and self.documents_path.exists()
177
 
178
- def get_retriever(self, k: int = 10):
179
- if self.index is None:
180
- if not self.load_index():
181
- if not self.create_index():
182
- raise ValueError("Unable to load or create index")
183
 
184
- def retriever_function(query: str) -> list:
185
- # Check cache first
186
- cache_key = f"{query}_{k}"
187
- if cache_key in self.response_cache:
188
- return self.response_cache[cache_key]
189
 
 
190
  query_embedding = self.encode(query)
191
-
192
- distances, indices = self.index.search(
193
- np.array([query_embedding]).astype('float32'),
194
- k
195
- )
196
-
197
- results = [
198
  self.indexed_documents[idx]
199
  for idx in indices[0]
200
- if idx != -1
201
  ]
202
-
203
- # Cache the results
204
- self.response_cache[cache_key] = results
205
- return results
206
-
207
  return retriever_function
208
 
209
- # # Initialize components
210
- # mistral_api_key = os.getenv("mistral_api_key")
211
- # llm = ChatMistralAI(
212
- # model="mistral-large-latest",
213
- # mistral_api_key=mistral_api_key,
214
- # temperature=0.01,
215
- # streaming=True,
216
- # )
217
 
218
  # deepseek_api_key = os.getenv("DEEPSEEK_KEY")
219
- # llm = ChatDeepSeek(
220
- # model="deepseek-chat",
221
- # temperature=0,
222
- # api_key=deepseek_api_key,
223
- # streaming=True,
224
- # )
225
 
226
 
227
  gemini_api_key = os.getenv("GEMINI_KEY")
228
- llm = ChatGoogleGenerativeAI(
229
- model="gemini-1.5-pro",
230
- temperature=0,
231
- google_api_key=gemini_api_key,
232
- disable_streaming=True,
233
- )
234
 
235
 
236
- rag_loader = OptimizedRAGLoader()
237
- retriever = rag_loader.get_retriever(k=5) # Reduced k for faster retrieval
238
 
239
- # Cache for processed questions
240
  question_cache = {}
241
 
 
 
242
  prompt_template = ChatPromptTemplate.from_messages([
243
- ("system", """Vous êtes un assistant juridique expert qualifié. Analysez et répondez aux questions juridiques avec précision.
244
-
245
- PROCESSUS D'ANALYSE :
246
- 1. Analysez le contexte fourni : {context}
247
- 2. Utilisez la recherche web si la reponse n'existe pas dans le contexte
248
- 3. Privilégiez les sources officielles et la jurisprudence récente
249
-
250
- Question à traiter : {question}
251
- """),
252
- ("human", "{question}")
253
- ])
254
-
255
 
 
257
- import gradio as gr
258
-
 
259
 
260
- # Ajouter du CSS pour personnaliser l'apparence
 
261
  css = """
262
  /* Reset RTL global */
263
  *, *::before, *::after {
264
  direction: rtl !important;
265
  text-align: right !important;
266
  }
267
-
268
  body {
269
- font-family: 'Amiri', sans-serif; /* Utilisation de la police Arabe andalouse */
270
- background-color: black; /* Fond blanc */
271
- color: black !important; /* Texte noir */
272
- direction: rtl !important; /* Texte en arabe aligné à droite */
273
  }
274
-
275
  .gradio-container {
276
- direction: rtl !important; /* Alignement RTL pour toute l'interface */
 
277
  }
278
-
279
- /* Éléments de formulaire */
280
- input[type="text"],
281
- .gradio-textbox input,
282
- textarea {
283
- border-radius: 20px;
284
- padding: 10px 15px;
285
- border: 2px solid #000;
286
- font-size: 16px;
287
- width: 80%;
288
- margin: 0 auto;
289
  text-align: right !important;
290
  }
291
 
292
- /* Surcharge des styles de placeholder */
293
- input::placeholder,
294
- textarea::placeholder {
295
  text-align: right !important;
296
  direction: rtl !important;
 
297
  }
298
 
299
- /* Boutons */
300
  .gradio-button {
301
- border-radius: 20px;
302
- font-size: 16px;
303
- background-color: #007BFF;
304
- color: white;
305
- padding: 10px 20px;
306
- margin: 10px auto;
307
- border: none;
308
- width: 80%;
309
- display: block;
310
  }
311
-
312
  .gradio-button:hover {
313
- background-color: #0056b3;
314
  }
315
 
316
  .gradio-chatbot .message {
317
- border-radius: 20px;
318
- padding: 10px;
319
- margin: 10px 0;
320
- background-color: #f1f1f1;
321
- border: 1px solid #ddd;
322
- width: 80%;
323
  text-align: right !important;
324
  direction: rtl !important;
 
 
325
  }
326
 
327
- /* Messages utilisateur alignés à gauche */
328
- .gradio-chatbot .user-message {
329
- margin-right: auto;
330
- background-color: #e3f2fd;
331
- text-align: right !important;
332
- direction: rtl !important;
333
  }
334
 
335
- /* Messages assistant alignés à droite */
336
- .gradio-chatbot .assistant-message {
337
- margin-right: auto;
338
- background-color: #f1f1f1;
339
- text-align: right
 
340
  }
341
 
342
- /* Corrections RTL pour les éléments spécifiques */
343
- .gradio-textbox textarea {
344
- text-align: right !important;
345
  }
346
 
347
- .gradio-dropdown div {
 
348
  text-align: right !important;
349
  }
350
  """
351
 
352
- # Modified process_question function to better work with tuples
353
  def process_question(question: str) -> Iterator[str]:
354
- if question in question_cache:
355
- response, docs = question_cache[question]
356
- sources = [doc.metadata.get("source") for doc in docs]
357
- sources = list(set([os.path.splitext(source)[0] for source in sources]))
358
- yield response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
 
359
  return
360
-
361
- relevant_docs = retriever(question)
362
-
363
- # Reranking with cross-encoder
364
- context = [doc.page_content for doc in relevant_docs]
365
- text_pairs = [[question, text] for text in context]
366
- scores = rag_loader.reranker.predict(text_pairs)
367
-
368
- scored_docs = list(zip(scores, context, relevant_docs))
369
- scored_docs.sort(key=lambda x: x[0], reverse=True)
370
- reranked_docs = [d[2].page_content for d in scored_docs][:10]
371
-
372
- prompt = prompt_template.format_messages(
373
- context=reranked_docs,
374
- question=question
375
- )
376
-
377
- full_response = ""
378
  try:
379
- for chunk in llm.stream(prompt):
380
- if isinstance(chunk, str):
381
- current_chunk = chunk
382
  else:
383
- current_chunk = chunk.content
384
- full_response += current_chunk
385
-
386
- sources = [d[2].metadata['source'] for d in scored_docs][:10]
387
- sources = list(set([os.path.splitext(source)[0] for source in sources]))
388
-
389
- yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
390
-
391
- question_cache[question] = (full_response, relevant_docs)
392
  except Exception as e:
393
- yield f"Erreur lors du traitement : {str(e)}"
394
 
395
- # Updated gradio_stream function to work with tuples
396
- def gradio_stream(question: str, chat_history: list) -> Iterator[list]:
397
  try:
 
398
  for partial_response in process_question(question):
399
- # Using tuples (user_message, bot_message) format
400
- yield chat_history + [(question, partial_response)]
 
401
  except Exception as e:
402
- yield chat_history + [(question, f"Erreur : {str(e)}")]
403
 
404
- # Gradio interface
405
- with gr.Blocks(css=css) as demo:
406
- gr.Markdown("<h2 style='text-align: center !important;'>هذا تطبيق للاجابة على الأسئلة المتعلقة بالقوانين المغربية</h2>")
407
 
408
  with gr.Row():
409
- message = gr.Textbox(label="أدخل سؤالك", placeholder="اكتب سؤالك هنا", elem_id="question_input")
410
-
411
- with gr.Row():
412
- send = gr.Button("بحث", elem_id="search_button")
 
 
413
 
414
  with gr.Row():
415
- # No type parameter - use Gradio's default
416
- chatbot = gr.Chatbot(label="")
417
 
418
- # Simplified user_input function
419
- def user_input(user_message, chat_history):
420
- return "", chat_history + [(user_message, None)]
421
 
422
- send.click(user_input, [message, chatbot], [message, chatbot], queue=False)
423
- send.click(gradio_stream, [message, chatbot], chatbot)
424
 
425
- demo.launch(share=True)
1
+ # import gradio as gr
2
+ # from langchain_mistralai.chat_models import ChatMistralAI
3
+ # from langchain.prompts import ChatPromptTemplate
4
+ # from langchain_deepseek import ChatDeepSeek
5
+ # from langchain_google_genai import ChatGoogleGenerativeAI
6
+ # import os
7
+ # from pathlib import Path
8
+ # import json
9
+ # import faiss
10
+ # import numpy as np
11
+ # from langchain.schema import Document
12
+ # import pickle
13
+ # import re
14
+ # import requests
15
+ # from functools import lru_cache
16
+ # import torch
17
+ # from sentence_transformers import SentenceTransformer
18
+ # from sentence_transformers.cross_encoder import CrossEncoder
19
+ # import threading
20
+ # from queue import Queue
21
+ # import concurrent.futures
22
+ # from typing import Generator, Tuple, Iterator
23
+ # import time
24
+
25
+ # class OptimizedRAGLoader:
26
+ # def __init__(self,
27
+ # docs_folder: str = "./docs",
28
+ # splits_folder: str = "./splits",
29
+ # index_folder: str = "./index"):
30
+
31
+ # self.docs_folder = Path(docs_folder)
32
+ # self.splits_folder = Path(splits_folder)
33
+ # self.index_folder = Path(index_folder)
34
+
35
+ # # Create folders if they don't exist
36
+ # for folder in [self.splits_folder, self.index_folder]:
37
+ # folder.mkdir(parents=True, exist_ok=True)
38
+
39
+ # # File paths
40
+ # self.splits_path = self.splits_folder / "splits.json"
41
+ # self.index_path = self.index_folder / "faiss.index"
42
+ # self.documents_path = self.index_folder / "documents.pkl"
43
+
44
+ # # Initialize components
45
+ # self.index = None
46
+ # self.indexed_documents = None
47
+
48
+ # # Initialize encoder model
49
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ # self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
51
+ # self.encoder.to(self.device)
52
+ # self.reranker = model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",trust_remote_code=True)
53
+
54
+ # # Initialize thread pool
55
+ # self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
56
+
57
+ # # Initialize response cache
58
+ # self.response_cache = {}
59
+
60
+ # @lru_cache(maxsize=1000)
61
+ # def encode(self, text: str):
62
+ # """Cached encoding function"""
63
+ # with torch.no_grad():
64
+ # embeddings = self.encoder.encode(
65
+ # text,
66
+ # convert_to_numpy=True,
67
+ # normalize_embeddings=True
68
+ # )
69
+ # return embeddings
70
+
71
+ # def batch_encode(self, texts: list):
72
+ # """Batch encoding for multiple texts"""
73
+ # with torch.no_grad():
74
+ # embeddings = self.encoder.encode(
75
+ # texts,
76
+ # batch_size=32,
77
+ # convert_to_numpy=True,
78
+ # normalize_embeddings=True,
79
+ # show_progress_bar=False
80
+ # )
81
+ # return embeddings
82
+
83
+ # def load_and_split_texts(self):
84
+ # if self._splits_exist():
85
+ # return self._load_existing_splits()
86
+
87
+ # documents = []
88
+ # futures = []
89
+
90
+ # for file_path in self.docs_folder.glob("*.txt"):
91
+ # future = self.executor.submit(self._process_file, file_path)
92
+ # futures.append(future)
93
+
94
+ # for future in concurrent.futures.as_completed(futures):
95
+ # documents.extend(future.result())
96
+
97
+ # self._save_splits(documents)
98
+ # return documents
99
+
100
+ # def _process_file(self, file_path):
101
+ # with open(file_path, 'r', encoding='utf-8') as file:
102
+ # text = file.read()
103
+ # chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
104
+
105
+ # return [
106
+ # Document(
107
+ # page_content=chunk,
108
+ # metadata={
109
+ # 'source': file_path.name,
110
+ # 'chunk_id': i,
111
+ # 'total_chunks': len(chunks)
112
+ # }
113
+ # )
114
+ # for i, chunk in enumerate(chunks)
115
+ # ]
116
+
117
+ # def load_index(self) -> bool:
118
+ # """
119
+ # Charge l'index FAISS et les documents associés s'ils existent
120
+
121
+ # Returns:
122
+ # bool: True si l'index a été chargé, False sinon
123
+ # """
124
+ # if not self._index_exists():
125
+ # print("Aucun index trouvé.")
126
+ # return False
127
+
128
+ # print("Chargement de l'index existant...")
129
+ # try:
130
+ # # Charger l'index FAISS
131
+ # self.index = faiss.read_index(str(self.index_path))
132
+
133
+ # # Charger les documents associés
134
+ # with open(self.documents_path, 'rb') as f:
135
+ # self.indexed_documents = pickle.load(f)
136
+
137
+ # print(f"Index chargé avec {self.index.ntotal} vecteurs")
138
+ # return True
139
+
140
+ # except Exception as e:
141
+ # print(f"Erreur lors du chargement de l'index: {e}")
142
+ # return False
143
+
144
+ # def create_index(self, documents=None):
145
+ # if documents is None:
146
+ # documents = self.load_and_split_texts()
147
+
148
+ # if not documents:
149
+ # return False
150
+
151
+ # texts = [doc.page_content for doc in documents]
152
+ # embeddings = self.batch_encode(texts)
153
+
154
+ # dimension = embeddings.shape[1]
155
+ # self.index = faiss.IndexFlatL2(dimension)
156
+
157
+ # if torch.cuda.is_available():
158
+ # # Use GPU for FAISS if available
159
+ # res = faiss.StandardGpuResources()
160
+ # self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
161
+
162
+ # self.index.add(np.array(embeddings).astype('float32'))
163
+ # self.indexed_documents = documents
164
+
165
+ # # Save index and documents
166
+ # cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
167
+ # faiss.write_index(cpu_index, str(self.index_path))
168
+
169
+ # with open(self.documents_path, 'wb') as f:
170
+ # pickle.dump(documents, f)
171
+
172
+ # return True
173
+
174
+ # def _index_exists(self) -> bool:
175
+ # """Vérifie si l'index et les documents associés existent"""
176
+ # return self.index_path.exists() and self.documents_path.exists()
177
+
178
+ # def get_retriever(self, k: int = 10):
179
+ # if self.index is None:
180
+ # if not self.load_index():
181
+ # if not self.create_index():
182
+ # raise ValueError("Unable to load or create index")
183
+
184
+ # def retriever_function(query: str) -> list:
185
+ # # Check cache first
186
+ # cache_key = f"{query}_{k}"
187
+ # if cache_key in self.response_cache:
188
+ # return self.response_cache[cache_key]
189
+
190
+ # query_embedding = self.encode(query)
191
+
192
+ # distances, indices = self.index.search(
193
+ # np.array([query_embedding]).astype('float32'),
194
+ # k
195
+ # )
196
+
197
+ # results = [
198
+ # self.indexed_documents[idx]
199
+ # for idx in indices[0]
200
+ # if idx != -1
201
+ # ]
202
+
203
+ # # Cache the results
204
+ # self.response_cache[cache_key] = results
205
+ # return results
206
+
207
+ # return retriever_function
208
+
209
+ # # # Initialize components
210
+ # # mistral_api_key = os.getenv("mistral_api_key")
211
+ # # llm = ChatMistralAI(
212
+ # # model="mistral-large-latest",
213
+ # # mistral_api_key=mistral_api_key,
214
+ # # temperature=0.01,
215
+ # # streaming=True,
216
+ # # )
217
+
218
+ # # deepseek_api_key = os.getenv("DEEPSEEK_KEY")
219
+ # # llm = ChatDeepSeek(
220
+ # # model="deepseek-chat",
221
+ # # temperature=0,
222
+ # # api_key=deepseek_api_key,
223
+ # # streaming=True,
224
+ # # )
225
+
226
+
227
+ # gemini_api_key = os.getenv("GEMINI_KEY")
228
+ # llm = ChatGoogleGenerativeAI(
229
+ # model="gemini-1.5-pro",
230
+ # temperature=0,
231
+ # google_api_key=gemini_api_key,
232
+ # disable_streaming=True,
233
+ # )
234
+
235
+
236
+ # rag_loader = OptimizedRAGLoader()
237
+ # retriever = rag_loader.get_retriever(k=5) # Reduced k for faster retrieval
238
+
239
+ # # Cache for processed questions
240
+ # question_cache = {}
241
+
242
+ # prompt_template = ChatPromptTemplate.from_messages([
243
+ # ("system", """Vous êtes un assistant juridique expert qualifié. Analysez et répondez aux questions juridiques avec précision.
244
+
245
+ # PROCESSUS D'ANALYSE :
246
+ # 1. Analysez le contexte fourni : {context}
247
+ # 2. Utilisez la recherche web si la reponse n'existe pas dans le contexte
248
+ # 3. Privilégiez les sources officielles et la jurisprudence récente
249
+
250
+ # Question à traiter : {question}
251
+ # """),
252
+ # ("human", "{question}")
253
+ # ])
254
+
255
+
256
+
257
+ # import gradio as gr
258
+
259
+
260
+ # # Ajouter du CSS pour personnaliser l'apparence
261
+ # css = """
262
+ # /* Reset RTL global */
263
+ # *, *::before, *::after {
264
+ # direction: rtl !important;
265
+ # text-align: right !important;
266
+ # }
267
+
268
+ # body {
269
+ # font-family: 'Amiri', sans-serif; /* Utilisation de la police Arabe andalouse */
270
+ # background-color: black; /* Fond blanc */
271
+ # color: black !important; /* Texte noir */
272
+ # direction: rtl !important; /* Texte en arabe aligné à droite */
273
+ # }
274
+
275
+ # .gradio-container {
276
+ # direction: rtl !important; /* Alignement RTL pour toute l'interface */
277
+ # }
278
+
279
+ # /* Éléments de formulaire */
280
+ # input[type="text"],
281
+ # .gradio-textbox input,
282
+ # textarea {
283
+ # border-radius: 20px;
284
+ # padding: 10px 15px;
285
+ # border: 2px solid #000;
286
+ # font-size: 16px;
287
+ # width: 80%;
288
+ # margin: 0 auto;
289
+ # text-align: right !important;
290
+ # }
291
+
292
+ # /* Surcharge des styles de placeholder */
293
+ # input::placeholder,
294
+ # textarea::placeholder {
295
+ # text-align: right !important;
296
+ # direction: rtl !important;
297
+ # }
298
+
299
+ # /* Boutons */
300
+ # .gradio-button {
301
+ # border-radius: 20px;
302
+ # font-size: 16px;
303
+ # background-color: #007BFF;
304
+ # color: white;
305
+ # padding: 10px 20px;
306
+ # margin: 10px auto;
307
+ # border: none;
308
+ # width: 80%;
309
+ # display: block;
310
+ # }
311
+
312
+ # .gradio-button:hover {
313
+ # background-color: #0056b3;
314
+ # }
315
+
316
+ # .gradio-chatbot .message {
317
+ # border-radius: 20px;
318
+ # padding: 10px;
319
+ # margin: 10px 0;
320
+ # background-color: #f1f1f1;
321
+ # border: 1px solid #ddd;
322
+ # width: 80%;
323
+ # text-align: right !important;
324
+ # direction: rtl !important;
325
+ # }
326
+
327
+ # /* Messages utilisateur alignés à gauche */
328
+ # .gradio-chatbot .user-message {
329
+ # margin-right: auto;
330
+ # background-color: #e3f2fd;
331
+ # text-align: right !important;
332
+ # direction: rtl !important;
333
+ # }
334
+
335
+ # /* Messages assistant alignés à droite */
336
+ # .gradio-chatbot .assistant-message {
337
+ # margin-right: auto;
338
+ # background-color: #f1f1f1;
339
+ # text-align: right
340
+ # }
341
+
342
+ # /* Corrections RTL pour les éléments spécifiques */
343
+ # .gradio-textbox textarea {
344
+ # text-align: right !important;
345
+ # }
346
+
347
+ # .gradio-dropdown div {
348
+ # text-align: right !important;
349
+ # }
350
+ # """
351
+
352
+ # # Modified process_question function to better work with tuples
353
+ # def process_question(question: str) -> Iterator[str]:
354
+ # if question in question_cache:
355
+ # response, docs = question_cache[question]
356
+ # sources = [doc.metadata.get("source") for doc in docs]
357
+ # sources = list(set([os.path.splitext(source)[0] for source in sources]))
358
+ # yield response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
359
+ # return
360
+
361
+ # relevant_docs = retriever(question)
362
+
363
+ # # Reranking with cross-encoder
364
+ # context = [doc.page_content for doc in relevant_docs]
365
+ # text_pairs = [[question, text] for text in context]
366
+ # scores = rag_loader.reranker.predict(text_pairs)
367
+
368
+ # scored_docs = list(zip(scores, context, relevant_docs))
369
+ # scored_docs.sort(key=lambda x: x[0], reverse=True)
370
+ # reranked_docs = [d[2].page_content for d in scored_docs][:10]
371
+
372
+ # prompt = prompt_template.format_messages(
373
+ # context=reranked_docs,
374
+ # question=question
375
+ # )
376
+
377
+ # full_response = ""
378
+ # try:
379
+ # for chunk in llm.stream(prompt):
380
+ # if isinstance(chunk, str):
381
+ # current_chunk = chunk
382
+ # else:
383
+ # current_chunk = chunk.content
384
+ # full_response += current_chunk
385
+
386
+ # sources = [d[2].metadata['source'] for d in scored_docs][:10]
387
+ # sources = list(set([os.path.splitext(source)[0] for source in sources]))
388
+
389
+ # yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
390
+
391
+ # question_cache[question] = (full_response, relevant_docs)
392
+ # except Exception as e:
393
+ # yield f"Erreur lors du traitement : {str(e)}"
394
+
395
+ # # Updated gradio_stream function for 'messages' format
396
+ # def gradio_stream(question: str, chat_history: list) -> Iterator[list]:
397
+ # # chat_history now contains the user message added by user_input
398
+ # # Add a placeholder for the assistant's response
399
+ # chat_history.append({"role": "assistant", "content": ""})
400
+
401
+ # try:
402
+ # # Stream the response using the existing process_question generator
403
+ # for partial_response in process_question(question):
404
+ # # Update the content of the last message (the assistant's placeholder)
405
+ # chat_history[-1]["content"] = partial_response
406
+ # yield chat_history # Yield the entire updated history list
407
+ # except Exception as e:
408
+ # # Update the assistant's message with the error
409
+ # chat_history[-1]["content"] = f"Erreur : {str(e)}"
410
+ # yield chat_history # Yield the history with the error message
411
+
412
+ # # Gradio interface
413
+ # with gr.Blocks(css=css) as demo:
414
+ # gr.Markdown("<h2 style='text-align: center !important;'>هذا تطبيق للاجابة على الأسئلة المتعلقة بالقوانين المغربية</h2>")
415
+
416
+ # with gr.Row():
417
+ # message = gr.Textbox(label="أدخل سؤالك", placeholder="اكتب سؤالك هنا", elem_id="question_input")
418
+
419
+ # with gr.Row():
420
+ # send = gr.Button("بحث", elem_id="search_button")
421
+
422
+ # with gr.Row():
423
+ # # No type parameter - use Gradio's default
424
+ # chatbot = gr.Chatbot(label="", type="messages") # Ajout de type="messages"
425
+
426
+ # # Updated user_input function for 'messages' format
427
+ # def user_input(user_message, chat_history):
428
+ # # chat_history is already a list of message dicts
429
+ # # Append the new user message
430
+ # return "", chat_history + [{"role": "user", "content": user_message}]
431
+
432
+ # send.click(user_input, [message, chatbot], [message, chatbot], queue=False)
433
+ # send.click(gradio_stream, [message, chatbot], chatbot)
434
+
435
+ # demo.launch(share=True)
436
+
437
+
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
  import gradio as gr
447
  from langchain_mistralai.chat_models import ChatMistralAI
448
  from langchain.prompts import ChatPromptTemplate
 
464
  import threading
465
  from queue import Queue
466
  import concurrent.futures
467
+ from typing import Generator, Tuple, Iterator, List, Dict
468
  import time
469
 
470
+ # --- (Votre classe OptimizedRAGLoader reste la même) ---
471
  class OptimizedRAGLoader:
472
  def __init__(self,
473
  docs_folder: str = "./docs",
474
  splits_folder: str = "./splits",
475
  index_folder: str = "./index"):
476
+
477
  self.docs_folder = Path(docs_folder)
478
  self.splits_folder = Path(splits_folder)
479
  self.index_folder = Path(index_folder)
480
+
481
  # Create folders if they don't exist
482
  for folder in [self.splits_folder, self.index_folder]:
483
  folder.mkdir(parents=True, exist_ok=True)
484
+
485
  # File paths
486
  self.splits_path = self.splits_folder / "splits.json"
487
  self.index_path = self.index_folder / "faiss.index"
488
  self.documents_path = self.index_folder / "documents.pkl"
489
+
490
  # Initialize components
491
  self.index = None
492
  self.indexed_documents = None
493
+
494
  # Initialize encoder model
495
+ print("Loading Sentence Transformer...")
496
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
497
  self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
498
  self.encoder.to(self.device)
499
+ print("Loading Cross Encoder...")
500
+ self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)
501
+ print("Models loaded.")
502
+
503
  # Initialize thread pool
504
  self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
505
+
506
  # Initialize response cache
507
  self.response_cache = {}
508
+
509
+ # Try loading index on init
510
+ self.load_or_create_index()
511
+
512
+ def load_or_create_index(self):
513
+ """Loads index if exists, otherwise creates it."""
514
+ if not self.load_index():
515
+ print("Index not found, creating new index...")
516
+ if not self.create_index():
517
+ raise RuntimeError("Failed to create index.")
518
+ else:
519
+ print("Index created successfully.")
520
+ else:
521
+ print("Index loaded successfully.")
522
+
523
+
524
  @lru_cache(maxsize=1000)
525
  def encode(self, text: str):
526
  """Cached encoding function"""
 
528
  embeddings = self.encoder.encode(
529
  text,
530
  convert_to_numpy=True,
531
+ normalize_embeddings=True,
532
+ device=self.device # Ensure encoding runs on the correct device
533
  )
534
  return embeddings
535
+
536
  def batch_encode(self, texts: list):
537
  """Batch encoding for multiple texts"""
538
  with torch.no_grad():
539
  embeddings = self.encoder.encode(
540
  texts,
541
+ batch_size=32, # Adjust based on GPU memory
542
  convert_to_numpy=True,
543
  normalize_embeddings=True,
544
+ show_progress_bar=True, # Show progress for potentially long operations
545
+ device=self.device # Ensure encoding runs on the correct device
546
  )
547
  return embeddings
548
 
549
+ def _splits_exist(self) -> bool:
550
+ """Check if split files exist."""
551
+ return self.splits_path.exists()
552
+
553
+ def _load_existing_splits(self) -> List[Document]:
554
+ """Load splits from JSON file."""
555
+ print(f"Loading existing splits from {self.splits_path}...")
556
+ try:
557
+ with open(self.splits_path, 'r', encoding='utf-8') as f:
558
+ splits_data = json.load(f)
559
+ documents = [
560
+ Document(page_content=item['page_content'], metadata=item['metadata'])
561
+ for item in splits_data
562
+ ]
563
+ print(f"Loaded {len(documents)} splits.")
564
+ return documents
565
+ except Exception as e:
566
+ print(f"Error loading splits: {e}. Recreating...")
567
+ return [] # Return empty list to trigger recreation
568
+
569
+ def _save_splits(self, documents: List[Document]):
570
+ """Save splits to JSON file."""
571
+ print(f"Saving {len(documents)} splits to {self.splits_path}...")
572
+ splits_data = [
573
+ {'page_content': doc.page_content, 'metadata': doc.metadata}
574
+ for doc in documents
575
+ ]
576
+ try:
577
+ with open(self.splits_path, 'w', encoding='utf-8') as f:
578
+ json.dump(splits_data, f, ensure_ascii=False, indent=2)
579
+ print("Splits saved successfully.")
580
+ except Exception as e:
581
+ print(f"Error saving splits: {e}")
582
+
583
+
584
+ def load_and_split_texts(self) -> List[Document]:
585
  if self._splits_exist():
586
+ loaded_splits = self._load_existing_splits()
587
+ if loaded_splits: # Check if loading was successful
588
+ return loaded_splits
589
+
590
+ print("Processing documents and creating splits...")
591
  documents = []
592
  futures = []
593
+
594
+ # Ensure docs folder exists
595
+ if not self.docs_folder.is_dir():
596
+ print(f"Error: Docs folder not found at {self.docs_folder}")
597
+ # Create dummy docs folder for Spaces if it doesn't exist
598
+ self.docs_folder.mkdir(parents=True, exist_ok=True)
599
+ print(f"Created empty docs folder at {self.docs_folder}. Please upload text files.")
600
+ # You might want to add a default dummy file here for testing
601
+ # with open(self.docs_folder / "dummy.txt", "w") as f:
602
+ # f.write("This is a dummy file. Please replace with real legal documents.")
603
+ # return [] # Return empty if no real docs
604
+ # Or let it continue to process the dummy file if created
605
+
606
+ doc_files = list(self.docs_folder.glob("*.txt"))
607
+ if not doc_files:
608
+ print(f"No .txt files found in {self.docs_folder}. Cannot create index.")
609
+ # Add a dummy document if none exist to prevent errors downstream?
610
+ # Or handle this case in create_index more gracefully.
611
+ return []
612
+
613
+ print(f"Found {len(doc_files)} files to process.")
614
+ for file_path in doc_files:
615
  future = self.executor.submit(self._process_file, file_path)
616
  futures.append(future)
617
+
618
+ processed_count = 0
619
  for future in concurrent.futures.as_completed(futures):
620
+ try:
621
+ documents.extend(future.result())
622
+ processed_count += 1
623
+ print(f"Processed file {processed_count}/{len(doc_files)}")
624
+ except Exception as e:
625
+ print(f"Error processing file in future: {e}")
626
+
627
+ if documents:
628
+ self._save_splits(documents)
629
+ else:
630
+ print("No documents were successfully processed or split.")
631
  return documents
 
632
 
633
+ def _process_file(self, file_path: Path) -> List[Document]:
634
+ try:
635
+ with open(file_path, 'r', encoding='utf-8') as file:
636
+ text = file.read()
637
+ # Improved splitting: handle more sentence endings and ensure non-empty chunks
638
+ chunks = [s.strip() for s in re.split(r'(?<=[.!?؟؛])\s+', text) if s and s.strip()]
639
+ if not chunks: # Handle empty files or files with no standard sentence endings
640
+ print(f"Warning: No chunks generated for file {file_path.name}. Treating whole file as one chunk.")
641
+ if text.strip(): # If there's content, use it as one chunk
642
+ chunks = [text.strip()]
643
+ else:
644
+ return [] # Skip empty files
645
+
646
+ return [
647
+ Document(
648
+ page_content=chunk,
649
+ metadata={
650
+ 'source': file_path.name,
651
+ 'chunk_id': i,
652
+ 'total_chunks': len(chunks)
653
+ }
654
+ )
655
+ for i, chunk in enumerate(chunks)
656
+ ]
657
+ except Exception as e:
658
+ print(f"Error processing file {file_path.name}: {e}")
659
+ return [] # Return empty list on error for this file
660
+
661
 
662
+ def load_index(self) -> bool:
663
+ """Loads FAISS index and associated documents if they exist."""
 
664
  if not self._index_exists():
665
+ # print("Index files not found.") # Reduced verbosity
666
  return False
667
 
668
+ print(f"Loading existing index from {self.index_path} and documents from {self.documents_path}...")
669
  try:
670
+ # Load FAISS index
671
  self.index = faiss.read_index(str(self.index_path))
672
 
673
+ # If the loaded index was originally GPU, move it back if possible
674
+ if torch.cuda.is_available():
675
+ try:
676
+ print("Moving loaded index to GPU...")
677
+ res = faiss.StandardGpuResources()
678
+ self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
679
+ print("Index successfully moved to GPU.")
680
+ except Exception as gpu_e:
681
+ print(f"Could not move index to GPU, using CPU. Error: {gpu_e}")
682
+
683
+
684
+ # Load associated documents
685
  with open(self.documents_path, 'rb') as f:
686
  self.indexed_documents = pickle.load(f)
687
 
688
+ if not self.indexed_documents:
689
+ print("Warning: Index loaded, but associated documents file is empty.")
690
+ # Consider this a failure case maybe?
691
+ # return False
692
+ elif self.index.ntotal != len(self.indexed_documents):
693
+ print(f"Warning: Index size ({self.index.ntotal}) does not match document count ({len(self.indexed_documents)}). Index might be corrupted or outdated.")
694
+ # Decide how to handle mismatch: rebuild? error? proceed with caution?
695
+ # For now, let's treat it as loaded but potentially problematic.
696
+
697
+ print(f"Index loaded with {self.index.ntotal} vectors.")
698
  return True
699
 
700
+ except FileNotFoundError:
701
+ print("Index files not found during load attempt.")
702
+ self.index = None
703
+ self.indexed_documents = None
704
+ return False
705
  except Exception as e:
706
+ print(f"Error loading index: {e}")
707
+ # Clean up potentially partially loaded state
708
+ self.index = None
709
+ self.indexed_documents = None
710
  return False
711
 
712
+ def create_index(self, documents: List[Document] = None) -> bool:
713
+ """Creates or recreates the FAISS index."""
714
  if documents is None:
715
  documents = self.load_and_split_texts()
716
+
717
  if not documents:
718
+ print("No documents provided or loaded, cannot create index.")
719
  return False
720
+
721
+ print(f"Creating index for {len(documents)} document splits...")
722
  texts = [doc.page_content for doc in documents]
723
+ print("Encoding documents...")
724
  embeddings = self.batch_encode(texts)
725
+
726
+ if embeddings is None or len(embeddings) == 0:
727
+ print("Encoding failed or produced no embeddings.")
728
+ return False
729
+
730
  dimension = embeddings.shape[1]
731
+ print(f"Embeddings created with dimension {dimension}.")
732
+
733
+ # Create CPU index first
734
+ cpu_index = faiss.IndexFlatL2(dimension)
735
+ print("FAISS CPU index created.")
736
+
737
+ # Use GPU for FAISS if available
738
  if torch.cuda.is_available():
739
+ try:
740
+ print("Attempting to use GPU for FAISS indexing...")
741
+ res = faiss.StandardGpuResources()
742
+ self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
743
+ print("Adding embeddings to GPU index...")
744
+ self.index.add(np.array(embeddings).astype('float32'))
745
+ print(f"Embeddings added to GPU index. Index size: {self.index.ntotal}")
746
+ # Save the CPU version of the index
747
+ print(f"Saving CPU version of index to {self.index_path}...")
748
+ faiss.write_index(faiss.index_gpu_to_cpu(self.index), str(self.index_path))
749
+ except Exception as gpu_e:
750
+ print(f"GPU indexing failed: {gpu_e}. Falling back to CPU.")
751
+ self.index = cpu_index # Fallback to CPU index
752
+ print("Adding embeddings to CPU index...")
753
+ self.index.add(np.array(embeddings).astype('float32'))
754
+ print(f"Embeddings added to CPU index. Index size: {self.index.ntotal}")
755
+ print(f"Saving CPU index to {self.index_path}...")
756
+ faiss.write_index(self.index, str(self.index_path))
757
+ else:
758
+ print("GPU not available. Using CPU for FAISS indexing.")
759
+ self.index = cpu_index
760
+ print("Adding embeddings to CPU index...")
761
+ self.index.add(np.array(embeddings).astype('float32'))
762
+ print(f"Embeddings added to CPU index. Index size: {self.index.ntotal}")
763
+ print(f"Saving CPU index to {self.index_path}...")
764
+ faiss.write_index(self.index, str(self.index_path))
765
+
766
+
767
  self.indexed_documents = documents
768
+
769
+ # Save documents
770
+ print(f"Saving associated documents to {self.documents_path}...")
771
+ try:
772
+ with open(self.documents_path, 'wb') as f:
773
+ pickle.dump(documents, f)
774
+ print("Index and documents saved successfully.")
775
+ return True
776
+ except Exception as e:
777
+ print(f"Error saving associated documents: {e}")
778
+ # Should we delete the index file if doc saving fails?
779
+ # self.index_path.unlink(missing_ok=True)
780
+ return False
781
+
782
 
783
  def _index_exists(self) -> bool:
784
+ """Checks if the index and associated document files exist."""
785
  return self.index_path.exists() and self.documents_path.exists()
786
 
787
+ def get_retriever(self, k: int = 10, rerank_k: int = 5):
788
+ """Gets a retriever function that performs FAISS search and cross-encoder reranking."""
789
+ if self.index is None or self.indexed_documents is None:
790
+ print("Index not initialized. Ensure load_or_create_index() was successful.")
791
+ # Attempt to load/create again, or raise error
792
+ self.load_or_create_index()
793
+ if self.index is None or self.indexed_documents is None:
794
+ raise ValueError("Unable to load or create index for retriever.")
795
+
796
+ # Make sure k for FAISS search is >= rerank_k
797
+ faiss_k = max(k, rerank_k)
798
+
799
 
800
+ def retriever_function(query: str) -> list[Document]:
801
+ # Check cache first (optional, consider if caching reranked results is desired)
802
+ # cache_key = f"{query}_{rerank_k}"
803
+ # if cache_key in self.response_cache:
804
+ # return self.response_cache[cache_key]
805
 
806
+ print(f"\nRetriever: Searching for query: '{query[:50]}...'")
807
  query_embedding = self.encode(query)
808
+
809
+ print(f"Searching top {faiss_k} in FAISS index...")
810
+ try:
811
+ distances, indices = self.index.search(
812
+ np.array([query_embedding]).astype('float32'),
813
+ faiss_k
814
+ )
815
+ except Exception as search_e:
816
+ print(f"Error during FAISS search: {search_e}")
817
+ return []
818
+
819
+ # Filter out invalid indices (-1) and get initial documents
820
+ initial_results = [
821
  self.indexed_documents[idx]
822
  for idx in indices[0]
823
+ if idx != -1 and idx < len(self.indexed_documents) # Added bounds check
824
  ]
825
+
826
+ if not initial_results:
827
+ print("No relevant documents found in FAISS search.")
828
+ return []
829
+
830
+ print(f"Found {len(initial_results)} initial candidates. Reranking top {len(initial_results)}...")
831
+
832
+ # Prepare for reranking
833
+ context = [doc.page_content for doc in initial_results]
834
+ text_pairs = [[query, text] for text in context]
835
+
836
+ # Rerank using the cross-encoder
837
+ try:
838
+ scores = self.reranker.predict(text_pairs, show_progress_bar=False) # Don't show progress bar here
839
+ except Exception as rerank_e:
840
+ print(f"Error during reranking: {rerank_e}. Returning initial FAISS results.")
841
+ return initial_results[:rerank_k] # Return top k from initial results
842
+
843
+
844
+ # Combine scores with documents and sort
845
+ scored_docs = list(zip(scores, initial_results))
846
+ scored_docs.sort(key=lambda x: x[0], reverse=True)
847
+
848
+ # Select the top rerank_k documents
849
+ reranked_docs = [doc for score, doc in scored_docs[:rerank_k]]
850
+
851
+ print(f"Reranked results (top {len(reranked_docs)}):")
852
+ # for i, (score, doc) in enumerate(scored_docs[:rerank_k]):
853
+ # print(f" {i+1}. Score: {score:.4f}, Source: {doc.metadata.get('source', 'N/A')}, Chunk: {doc.metadata.get('chunk_id', 'N/A')}")
854
+
855
+
856
+ # Cache results (optional)
857
+ # self.response_cache[cache_key] = reranked_docs
858
+ return reranked_docs
859
+
860
  return retriever_function
861
 
862
+ # --- LLM Initialization ---
863
+ # Choose *one* LLM to uncomment based on your available API key
864
+ # print("Initializing LLM...")
865
+ # mistral_api_key = os.getenv("MISTRAL_API_KEY") # Ensure env var name matches your Space secret
866
+ # if mistral_api_key:
867
+ # llm = ChatMistralAI(
868
+ # model="mistral-large-latest",
869
+ # mistral_api_key=mistral_api_key,
870
+ # temperature=0.01,
871
+ # # streaming=True, # Streaming is handled differently with Gradio 'messages'
872
+ # )
873
+ # print("Using Mistral LLM.")
874
+ # else:
875
+ # print("Mistral API key not found.")
876
 
877
  # deepseek_api_key = os.getenv("DEEPSEEK_KEY")
878
+ # if deepseek_api_key:
879
+ # llm = ChatDeepSeek(
880
+ # model="deepseek-chat",
881
+ # temperature=0.01, # Slightly non-zero for potentially better phrasing
882
+ # api_key=deepseek_api_key,
883
+ # # streaming=True, # Streaming is handled differently with Gradio 'messages'
884
+ # )
885
+ # print("Using DeepSeek LLM.")
886
+ # else:
887
+ # print("Deepseek API key not found.")
888
 
889
 
890
  gemini_api_key = os.getenv("GEMINI_KEY")
891
+ if gemini_api_key:
892
+ try:
893
+ llm = ChatGoogleGenerativeAI(
894
+ model="gemini-1.5-flash", # Using flash for potentially faster responses
895
+ temperature=0.1, # Slightly increased temperature
896
+ google_api_key=gemini_api_key,
897
+ # convert_system_message_to_human=True # Sometimes needed for Gemini
898
+ # streaming=False # Gemini streaming requires specific handling; simpler without for now
899
+ )
900
+ print("Using Google Gemini LLM.")
901
+ except Exception as gemini_e:
902
+ print(f"Error initializing Gemini LLM: {gemini_e}")
903
+ llm = None # Set llm to None if initialization fails
904
+ else:
905
+ print("Gemini API key not found.")
906
+ llm = None
907
 
908
+ # --- RAG Loader and Retriever Initialization ---
909
+ print("Initializing RAG Loader...")
910
+ try:
911
+ rag_loader = OptimizedRAGLoader()
912
+ retriever = rag_loader.get_retriever(k=15, rerank_k=5) # Retrieve more initially, rerank top 5
913
+ print("RAG Loader and Retriever initialized.")
914
+ except Exception as rag_e:
915
+ print(f"FATAL: Could not initialize RAG system: {rag_e}")
916
+ # Optionally exit or provide a dummy retriever/LLM
917
+ retriever = lambda query: [] # Dummy retriever
918
+ llm = None # Ensure LLM is None if RAG fails
919
 
 
 
920
 
921
+ # Cache for processed questions (Consider persistence or size limits if needed)
922
  question_cache = {}
923
 
924
+ # --- Prompt Template ---
925
+ # Adjusted prompt for clarity and conciseness
926
  prompt_template = ChatPromptTemplate.from_messages([
927
+ ("system", """أنت مساعد قانوني خبير ومؤهل في القانون المغربي. مهمتك هي تحليل الأسئلة القانونية والإجابة عليها بدقة بناءً على السياق المقدم.
928
 
929
+ إرشادات:
930
+ 1. حلل السياق التالي بعناية:
931
+ {context}
932
+ 2. استخدم المعلومات من السياق فقط لصياغة إجابتك.
933
+ 3. إذا كانت المعلومات غير كافية أو غير موجودة في السياق للإجابة على السؤال، أشر بوضوح إلى أن السياق المقدم لا يحتوي على الإجابة المطلوبة. لا تختلق معلومات.
934
+ 4. اذكر المصادر (أسماء الملفات) التي استخدمتها من السياق في نهاية إجابتك.
935
+ 5. أجب باللغة العربية وبأسلوب واضح وموجز.
936
 
937
+ السؤال المطلوب الإجابة عليه: {question}"""),
938
+ ("human", "{question}") # Human message might be redundant if question is in system prompt, but often helps guide model role.
939
+ ])
940
 
941
+ # --- CSS Styling ---
942
+ # (CSS remains the same, ensure RTL works as intended)
943
  css = """
944
  /* Reset RTL global */
945
+ :root {
946
+ --input-border-radius: 15px !important; /* Example variable */
947
+ --button-border-radius: 15px !important;
948
+ }
949
+
950
  *, *::before, *::after {
951
  direction: rtl !important;
952
  text-align: right !important;
953
  }
 
954
  body {
955
+ font-family: 'Arial', 'sans-serif'; /* Using a more standard font */
956
+ background-color: #f8f9fa; /* Light gray background */
957
+ color: #343a40; /* Darker text */
958
+ direction: rtl !important;
959
  }
 
960
  .gradio-container {
961
+ direction: rtl !important;
962
+ background-color: #f8f9fa;
963
  }
964
+ /* Input Textbox */
965
+ .gradio-textbox textarea {
966
+ border-radius: var(--input-border-radius) !important;
967
+ padding: 12px 18px !important;
968
+ border: 1px solid #ced4da !important;
969
+ font-size: 16px !important;
970
+ width: 95% !important; /* Adjust width */
971
+ margin: 10px auto !important;
972
+ display: block !important;
 
 
973
  text-align: right !important;
974
+ background-color: #ffffff !important; /* White background for input */
975
+ color: #495057 !important; /* Input text color */
976
+ box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important; /* Subtle shadow */
977
  }
978
 
979
+ /* Placeholder styling */
980
+ .gradio-textbox textarea::placeholder {
 
981
  text-align: right !important;
982
  direction: rtl !important;
983
+ color: #6c757d !important; /* Lighter placeholder text */
984
  }
985
 
986
+ /* Send Button */
987
  .gradio-button {
988
+ border-radius: var(--button-border-radius) !important;
989
+ font-size: 16px !important;
990
+ font-weight: bold !important;
991
+ background-color: #007bff !important; /* Primary blue */
992
+ color: white !important;
993
+ padding: 12px 24px !important;
994
+ margin: 5px auto 15px auto !important; /* Adjust margins */
995
+ border: none !important;
996
+ width: 95% !important; /* Match textbox width */
997
+ display: block !important;
998
+ cursor: pointer !important;
999
+ transition: background-color 0.2s ease-in-out !important;
1000
+ box-shadow: 0 2px 5px rgba(0, 123, 255, 0.3) !important; /* Button shadow */
1001
  }
 
1002
  .gradio-button:hover {
1003
+ background-color: #0056b3 !important; /* Darker blue on hover */
1004
+ }
1005
+
1006
+ /* Chatbot Messages */
1007
+ .gradio-chatbot {
1008
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1) !important; /* Chatbot container shadow */
1009
+ border-radius: 10px !important;
1010
+ background-color: #ffffff !important; /* White background for chat area */
1011
+ padding: 10px !important;
1012
+ }
1013
+
1014
+ .gradio-chatbot .message-wrap { /* Targeting the wrapper for better styling */
1015
+ padding: 5px 0 !important; /* Space between messages */
1016
  }
1017
 
1018
  .gradio-chatbot .message {
1019
+ border-radius: 18px !important; /* Rounded corners for messages */
1020
+ padding: 10px 15px !important;
1021
+ margin: 5px 0 !important; /* Vertical margin */
1022
+ max-width: 85% !important; /* Max width of message bubble */
1023
+ border: none !important; /* Remove default border */
1024
+ box-shadow: 0 1px 3px rgba(0,0,0,0.08) !important; /* Subtle message shadow */
1025
  text-align: right !important;
1026
  direction: rtl !important;
1027
+ word-wrap: break-word; /* Ensure long words break */
1028
+ line-height: 1.5; /* Improve readability */
1029
  }
1030
 
1031
+ /* User Messages (align left conceptually, but RTL makes them appear right-aligned within container) */
1032
+ .gradio-chatbot .user-message .message {
1033
+ margin-left: auto !important; /* Push to the 'end' side in LTR, start in RTL */
1034
+ margin-right: 0 !important;
1035
+ background-color: #e7f5ff !important; /* Light blue for user */
1036
+ color: #0056b3 !important;
1037
  }
1038
 
1039
+ /* Assistant Messages (align right conceptually, appear left-aligned in RTL) */
1040
+ .gradio-chatbot .assistant-message .message {
1041
+ margin-right: auto !important; /* Push to the 'start' side in LTR, end in RTL */
1042
+ margin-left: 0 !important;
1043
+ background-color: #f1f3f5 !important; /* Light gray for assistant */
1044
+ color: #343a40 !important;
1045
  }
1046
 
1047
+ /* Markdown Header */
1048
+ h2 {
1049
+ color: #0056b3 !important; /* Match button hover color */
1050
+ font-weight: bold !important;
1051
+ margin-bottom: 20px !important;
1052
+ text-align: center !important;
1053
+ direction: rtl !important; /* Ensure header is also RTL */
1054
  }
1055
 
1056
+ /* Ensure specific Gradio elements inherit RTL */
1057
+ .gradio-dropdown div, .gradio-checkboxgroup div, .gradio-radio div {
1058
  text-align: right !important;
1059
+ direction: rtl !important;
1060
+ }
1061
+
1062
+ /* Center alignment for elements within rows if needed */
1063
+ .gradio-row {
1064
+ justify-content: center !important; /* Helps center content like buttons/textboxes */
1065
  }
1066
  """
1067
 
1068
+ # --- Backend Processing Function ---
1069
  def process_question(question: str) -> Iterator[str]:
1070
+ """
1071
+ Processes a question using RAG and LLM, yielding the response stream.
1072
+ Includes source attribution.
1073
+ """
1074
+ if not llm:
1075
+ yield "عذراً، النموذج اللغوي غير متاح حالياً. يرجى المحاولة لاحقاً."
1076
  return
1077
+ if not retriever:
1078
+ yield "عذراً، نظام استرجاع المعلومات غير متاح حالياً."
1079
+ return
1080
+
1081
+ # Simple caching check (consider more robust caching)
1082
+ # if question in question_cache:
1083
+ # response, sources_str = question_cache[question]
1084
+ # yield response + sources_str
1085
+ # return
1086
+
1087
+ print(f"Processing question: {question}")
1088
  try:
1089
+ relevant_docs = retriever(question)
1090
+ except Exception as ret_e:
1091
+ print(f"Error during retrieval: {ret_e}")
1092
+ yield f"حدث خطأ أثناء البحث عن المستندات المتعلقة: {str(ret_e)}"
1093
+ return
1094
+
1095
+
1096
+ if not relevant_docs:
1097
+ print("No relevant documents found by retriever.")
1098
+ yield "لم أتمكن من العثور على معلومات ذات صلة في المستندات المتاحة للإجابة على سؤالك."
1099
+ return
1100
+
1101
+ context_str = "\n\n".join([f"المصدر: {doc.metadata.get('source', 'غير معروف')}\nالمحتوى: {doc.page_content}" for doc in relevant_docs])
1102
+ sources = list(set([doc.metadata.get("source", "غير معروف") for doc in relevant_docs]))
1103
+ sources_str = "\n\n\nالمصادر المحتملة التي تم الرجوع إليها:\n- " + "\n- ".join(sources)
1104
+
1105
+ print(f"Context created from {len(relevant_docs)} documents. Generating response...")
1106
+ # print(f"Context sample: {context_str[:200]}...") # Debugging
1107
+
1108
+ try:
1109
+ # Format the prompt using the template
1110
+ prompt = prompt_template.format_messages(
1111
+ context=context_str,
1112
+ question=question
1113
+ )
1114
+
1115
+ # --- Non-Streaming Call (Simpler for Gemini without specific streaming setup) ---
1116
+ # full_response = llm.invoke(prompt) # Use invoke for non-streaming
1117
+ # if isinstance(full_response, str): # Handle potential different return types
1118
+ # response_content = full_response
1119
+ # else:
1120
+ # response_content = full_response.content
1121
+
1122
+ # yield response_content + sources_str # Yield the complete response at once
1123
+
1124
+ # --- Streaming Call (If LLM supports it and is configured correctly) ---
1125
+ full_response = ""
1126
+ stream = llm.stream(prompt)
1127
+ start_time = time.time()
1128
+ first_chunk_received = False
1129
+
1130
+ for chunk in stream:
1131
+ if not first_chunk_received:
1132
+ end_time = time.time()
1133
+ print(f"Time to first chunk: {end_time - start_time:.2f} seconds")
1134
+ first_chunk_received = True
1135
+
1136
+ # Adapt based on Langchain version and LLM provider's chunk structure
1137
+ if hasattr(chunk, 'content'):
1138
+ current_chunk = chunk.content
1139
+ elif isinstance(chunk, str):
1140
+ current_chunk = chunk
1141
  else:
1142
+ print(f"Unexpected chunk type: {type(chunk)}")
1143
+ current_chunk = str(chunk) # Fallback
1144
+
1145
+ if current_chunk: # Avoid adding empty chunks
1146
+ full_response += current_chunk
1147
+ # Yield intermediate response with sources appended
1148
+ yield full_response + sources_str # Appending sources at each step
1149
+
1150
+ if not first_chunk_received: # Handle cases where stream might be empty or fail silently
1151
+ print("No chunks received from LLM stream.")
1152
+ yield "حدث خطأ أو لم يتمكن النموذج من إنشاء رد." + sources_str
1153
+
1154
+ print("LLM response generation complete.")
1155
+ # Cache the final result (optional)
1156
+ # question_cache[question] = (full_response, sources_str)
1157
+
1158
  except Exception as e:
1159
+ print(f"Error during LLM generation: {e}")
1160
+ yield f"حدث خطأ أثناء توليد الإجابة: {str(e)}" + sources_str # Include sources even on error if possible
1161
+
1162
+
1163
+ # --- Gradio Interface Functions ---
1164
+
1165
+ # Function to add user message to history (using 'messages' format)
1166
+ def user_input(user_message: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
1167
+ if not user_message.strip(): # Prevent empty messages
1168
+ return "", chat_history
1169
+ # Append the user's message to the history in the correct format
1170
+ return "", chat_history + [{"role": "user", "content": user_message}]
1171
+
1172
+ # Function to handle the streaming response (using 'messages' format)
1173
+ def gradio_stream(question: str, chat_history: List[Dict[str, str]]) -> Iterator[List[Dict[str, str]]]:
1174
+ if not question.strip(): # Prevent processing empty questions passed from user_input
1175
+ yield chat_history
1176
+ return
1177
+
1178
+ if not llm or not retriever:
1179
+ chat_history.append({"role": "assistant", "content": "عذراً، النظام غير جاهز حالياً. يرجى التأكد من تهيئة المفاتيح والنماذج."})
1180
+ yield chat_history
1181
+ return
1182
+
1183
+ # Add a placeholder for the assistant's response
1184
+ chat_history.append({"role": "assistant", "content": ""})
1185
+ # Use a thinking indicator initially
1186
+ chat_history[-1]["content"] = "جارٍ التفكير والبحث..."
1187
+ yield chat_history # Show "Thinking..." message immediately
1188
+
1189
 
 
 
1190
  try:
1191
+ # Stream the response using the process_question generator
1192
  for partial_response in process_question(question):
1193
+ # Update the content of the last message (the assistant's placeholder)
1194
+ chat_history[-1]["content"] = partial_response
1195
+ yield chat_history # Yield the entire updated history list
1196
  except Exception as e:
1197
+ print(f"Error in gradio_stream calling process_question: {e}")
1198
+ # Update the assistant's message with the error
1199
+ chat_history[-1]["content"] = f"حدث خطأ غير متوقع: {str(e)}"
1200
+ yield chat_history # Yield the history with the error message
1201
 
1202
+
1203
+ # --- Gradio Interface Definition ---
1204
+ print("Building Gradio interface...")
1205
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added a theme
1206
+ gr.Markdown("<h2 style='text-align: center;'>مساعد قانوني مغربي - أسئلة وأجوبة</h2>")
1207
+ gr.Markdown("<p style='text-align: center; color: #6c757d;'>اطرح سؤالك حول القوانين المغربية وسأحاول الإجابة بناءً على المستندات المتوفرة.</p>")
1208
+
1209
+ # Use type="messages" for the chatbot
1210
+ chatbot = gr.Chatbot(label="المحادثة", type="messages", height=500)
1211
 
1212
  with gr.Row():
1213
+ message = gr.Textbox(
1214
+ label="أدخل سؤالك هنا:",
1215
+ placeholder="مثال: ما هي شروط الحصول على تعويض عن حادثة شغل؟",
1216
+ lines=3, # Allow more lines for input
1217
+ elem_id="question_input" # Keep elem_id if needed elsewhere
1218
+ )
1219
 
1220
  with gr.Row():
1221
+ send = gr.Button("إرسال السؤال", variant="primary") # Use variant for emphasis
1222
+
1223
+
1224
+ # Chain the actions:
1225
+ # 1. When send is clicked, call user_input:
1226
+ # - Takes user message (from 'message' textbox) and current history (from 'chatbot')
1227
+ # - Outputs: Clears the 'message' textbox, updates 'chatbot' history with user message
1228
+ # 2. After user_input completes, call gradio_stream:
1229
+ # - Takes the *original* user message (important!) and the *updated* history from step 1.
1230
+ # - Outputs: Streams updates back to the 'chatbot' component.
1231
+ send.click(user_input, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
1232
+ then(gradio_stream, inputs=[message, chatbot], outputs=chatbot)
1233
+
1234
+ # Optional: Allow submitting with Enter key
1235
+ message.submit(user_input, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
1236
+ then(gradio_stream, inputs=[message, chatbot], outputs=chatbot)
1237
 
1238
 
1239
+ print("Gradio Blocks defined.")
 
1240
 
1241
+ # --- Launch the Application ---
1242
+ if __name__ == "__main__":
1243
+ print("Launching Gradio app...")
1244
+ # Set share=False when running locally or in environments like Spaces where it's handled differently
1245
+ # Set debug=True for more detailed logs during development
1246
+ demo.queue() # Enable queue for handling multiple users/requests
1247
+ demo.launch(share=False, debug=True) # share=True can cause issues in some environments