Update app.py
app.py
CHANGED
@@ -1,3 +1,448 @@
[Removed: the previous app.py (~425 lines): its imports, the OptimizedRAGLoader class, the LLM setup, the prompt template, the CSS, and the Gradio interface. The same code is preserved verbatim below, commented out at the top of the new file (new lines 1-445), ahead of the rewritten version.]
1 |
+
# import gradio as gr
|
2 |
+
# from langchain_mistralai.chat_models import ChatMistralAI
|
3 |
+
# from langchain.prompts import ChatPromptTemplate
|
4 |
+
# from langchain_deepseek import ChatDeepSeek
|
5 |
+
# from langchain_google_genai import ChatGoogleGenerativeAI
|
6 |
+
# import os
|
7 |
+
# from pathlib import Path
|
8 |
+
# import json
|
9 |
+
# import faiss
|
10 |
+
# import numpy as np
|
11 |
+
# from langchain.schema import Document
|
12 |
+
# import pickle
|
13 |
+
# import re
|
14 |
+
# import requests
|
15 |
+
# from functools import lru_cache
|
16 |
+
# import torch
|
17 |
+
# from sentence_transformers import SentenceTransformer
|
18 |
+
# from sentence_transformers.cross_encoder import CrossEncoder
|
19 |
+
# import threading
|
20 |
+
# from queue import Queue
|
21 |
+
# import concurrent.futures
|
22 |
+
# from typing import Generator, Tuple, Iterator
|
23 |
+
# import time
|
24 |
+
|
25 |
+
# class OptimizedRAGLoader:
|
26 |
+
# def __init__(self,
|
27 |
+
# docs_folder: str = "./docs",
|
28 |
+
# splits_folder: str = "./splits",
|
29 |
+
# index_folder: str = "./index"):
|
30 |
+
|
31 |
+
# self.docs_folder = Path(docs_folder)
|
32 |
+
# self.splits_folder = Path(splits_folder)
|
33 |
+
# self.index_folder = Path(index_folder)
|
34 |
+
|
35 |
+
# # Create folders if they don't exist
|
36 |
+
# for folder in [self.splits_folder, self.index_folder]:
|
37 |
+
# folder.mkdir(parents=True, exist_ok=True)
|
38 |
+
|
39 |
+
# # File paths
|
40 |
+
# self.splits_path = self.splits_folder / "splits.json"
|
41 |
+
# self.index_path = self.index_folder / "faiss.index"
|
42 |
+
# self.documents_path = self.index_folder / "documents.pkl"
|
43 |
+
|
44 |
+
# # Initialize components
|
45 |
+
# self.index = None
|
46 |
+
# self.indexed_documents = None
|
47 |
+
|
48 |
+
# # Initialize encoder model
|
49 |
+
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
50 |
+
# self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
|
51 |
+
# self.encoder.to(self.device)
|
52 |
+
# self.reranker = model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",trust_remote_code=True)
|
53 |
+
|
54 |
+
# # Initialize thread pool
|
55 |
+
# self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
56 |
+
|
57 |
+
# # Initialize response cache
|
58 |
+
# self.response_cache = {}
|
59 |
+
|
60 |
+
# @lru_cache(maxsize=1000)
|
61 |
+
# def encode(self, text: str):
|
62 |
+
# """Cached encoding function"""
|
63 |
+
# with torch.no_grad():
|
64 |
+
# embeddings = self.encoder.encode(
|
65 |
+
# text,
|
66 |
+
# convert_to_numpy=True,
|
67 |
+
# normalize_embeddings=True
|
68 |
+
# )
|
69 |
+
# return embeddings
|
70 |
+
|
71 |
+
# def batch_encode(self, texts: list):
|
72 |
+
# """Batch encoding for multiple texts"""
|
73 |
+
# with torch.no_grad():
|
74 |
+
# embeddings = self.encoder.encode(
|
75 |
+
# texts,
|
76 |
+
# batch_size=32,
|
77 |
+
# convert_to_numpy=True,
|
78 |
+
# normalize_embeddings=True,
|
79 |
+
# show_progress_bar=False
|
80 |
+
# )
|
81 |
+
# return embeddings
|
82 |
+
|
83 |
+
# def load_and_split_texts(self):
|
84 |
+
# if self._splits_exist():
|
85 |
+
# return self._load_existing_splits()
|
86 |
+
|
87 |
+
# documents = []
|
88 |
+
# futures = []
|
89 |
+
|
90 |
+
# for file_path in self.docs_folder.glob("*.txt"):
|
91 |
+
# future = self.executor.submit(self._process_file, file_path)
|
92 |
+
# futures.append(future)
|
93 |
+
|
94 |
+
# for future in concurrent.futures.as_completed(futures):
|
95 |
+
# documents.extend(future.result())
|
96 |
+
|
97 |
+
# self._save_splits(documents)
|
98 |
+
# return documents
|
99 |
+
|
100 |
+
# def _process_file(self, file_path):
|
101 |
+
# with open(file_path, 'r', encoding='utf-8') as file:
|
102 |
+
# text = file.read()
|
103 |
+
# chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
|
104 |
+
|
105 |
+
# return [
|
106 |
+
# Document(
|
107 |
+
# page_content=chunk,
|
108 |
+
# metadata={
|
109 |
+
# 'source': file_path.name,
|
110 |
+
# 'chunk_id': i,
|
111 |
+
# 'total_chunks': len(chunks)
|
112 |
+
# }
|
113 |
+
# )
|
114 |
+
# for i, chunk in enumerate(chunks)
|
115 |
+
# ]
|
116 |
+
|
117 |
+
# def load_index(self) -> bool:
|
118 |
+
# """
|
119 |
+
# Charge l'index FAISS et les documents associés s'ils existent
|
120 |
+
|
121 |
+
# Returns:
|
122 |
+
# bool: True si l'index a été chargé, False sinon
|
123 |
+
# """
|
124 |
+
# if not self._index_exists():
|
125 |
+
# print("Aucun index trouvé.")
|
126 |
+
# return False
|
127 |
+
|
128 |
+
# print("Chargement de l'index existant...")
|
129 |
+
# try:
|
130 |
+
# # Charger l'index FAISS
|
131 |
+
# self.index = faiss.read_index(str(self.index_path))
|
132 |
+
|
133 |
+
# # Charger les documents associés
|
134 |
+
# with open(self.documents_path, 'rb') as f:
|
135 |
+
# self.indexed_documents = pickle.load(f)
|
136 |
+
|
137 |
+
# print(f"Index chargé avec {self.index.ntotal} vecteurs")
|
138 |
+
# return True
|
139 |
+
|
140 |
+
# except Exception as e:
|
141 |
+
# print(f"Erreur lors du chargement de l'index: {e}")
|
142 |
+
# return False
|
143 |
+
|
144 |
+
# def create_index(self, documents=None):
|
145 |
+
# if documents is None:
|
146 |
+
# documents = self.load_and_split_texts()
|
147 |
+
|
148 |
+
# if not documents:
|
149 |
+
# return False
|
150 |
+
|
151 |
+
# texts = [doc.page_content for doc in documents]
|
152 |
+
# embeddings = self.batch_encode(texts)
|
153 |
+
|
154 |
+
# dimension = embeddings.shape[1]
|
155 |
+
# self.index = faiss.IndexFlatL2(dimension)
|
156 |
+
|
157 |
+
# if torch.cuda.is_available():
|
158 |
+
# # Use GPU for FAISS if available
|
159 |
+
# res = faiss.StandardGpuResources()
|
160 |
+
# self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
|
161 |
+
|
162 |
+
# self.index.add(np.array(embeddings).astype('float32'))
|
163 |
+
# self.indexed_documents = documents
|
164 |
+
|
165 |
+
# # Save index and documents
|
166 |
+
# cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
|
167 |
+
# faiss.write_index(cpu_index, str(self.index_path))
|
168 |
+
|
169 |
+
# with open(self.documents_path, 'wb') as f:
|
170 |
+
# pickle.dump(documents, f)
|
171 |
+
|
172 |
+
# return True
|
173 |
+
|
174 |
+
# def _index_exists(self) -> bool:
|
175 |
+
# """Vérifie si l'index et les documents associés existent"""
|
176 |
+
# return self.index_path.exists() and self.documents_path.exists()
|
177 |
+
|
178 |
+
# def get_retriever(self, k: int = 10):
|
179 |
+
# if self.index is None:
|
180 |
+
# if not self.load_index():
|
181 |
+
# if not self.create_index():
|
182 |
+
# raise ValueError("Unable to load or create index")
|
183 |
+
|
184 |
+
# def retriever_function(query: str) -> list:
|
185 |
+
# # Check cache first
|
186 |
+
# cache_key = f"{query}_{k}"
|
187 |
+
# if cache_key in self.response_cache:
|
188 |
+
# return self.response_cache[cache_key]
|
189 |
+
|
190 |
+
# query_embedding = self.encode(query)
|
191 |
+
|
192 |
+
# distances, indices = self.index.search(
|
193 |
+
# np.array([query_embedding]).astype('float32'),
|
194 |
+
# k
|
195 |
+
# )
|
196 |
+
|
197 |
+
# results = [
|
198 |
+
# self.indexed_documents[idx]
|
199 |
+
# for idx in indices[0]
|
200 |
+
# if idx != -1
|
201 |
+
# ]
|
202 |
+
|
203 |
+
# # Cache the results
|
204 |
+
# self.response_cache[cache_key] = results
|
205 |
+
# return results
|
206 |
+
|
207 |
+
# return retriever_function
|
208 |
+
|
209 |
+
# # # Initialize components
|
210 |
+
# # mistral_api_key = os.getenv("mistral_api_key")
|
211 |
+
# # llm = ChatMistralAI(
|
212 |
+
# # model="mistral-large-latest",
|
213 |
+
# # mistral_api_key=mistral_api_key,
|
214 |
+
# # temperature=0.01,
|
215 |
+
# # streaming=True,
|
216 |
+
# # )
|
217 |
+
|
218 |
+
# # deepseek_api_key = os.getenv("DEEPSEEK_KEY")
|
219 |
+
# # llm = ChatDeepSeek(
|
220 |
+
# # model="deepseek-chat",
|
221 |
+
# # temperature=0,
|
222 |
+
# # api_key=deepseek_api_key,
|
223 |
+
# # streaming=True,
|
224 |
+
# # )
|
225 |
+
|
226 |
+
|
227 |
+
# gemini_api_key = os.getenv("GEMINI_KEY")
|
228 |
+
# llm = ChatGoogleGenerativeAI(
|
229 |
+
# model="gemini-1.5-pro",
|
230 |
+
# temperature=0,
|
231 |
+
# google_api_key=gemini_api_key,
|
232 |
+
# disable_streaming=True,
|
233 |
+
# )
|
234 |
+
|
235 |
+
|
236 |
+
# rag_loader = OptimizedRAGLoader()
|
237 |
+
# retriever = rag_loader.get_retriever(k=5) # Reduced k for faster retrieval
|
238 |
+
|
239 |
+
# # Cache for processed questions
|
240 |
+
# question_cache = {}
|
241 |
+
|
242 |
+
# prompt_template = ChatPromptTemplate.from_messages([
|
243 |
+
# ("system", """Vous êtes un assistant juridique expert qualifié. Analysez et répondez aux questions juridiques avec précision.
|
244 |
+
|
245 |
+
# PROCESSUS D'ANALYSE :
|
246 |
+
# 1. Analysez le contexte fourni : {context}
|
247 |
+
# 2. Utilisez la recherche web si la reponse n'existe pas dans le contexte
|
248 |
+
# 3. Privilégiez les sources officielles et la jurisprudence récente
|
249 |
+
|
250 |
+
# Question à traiter : {question}
|
251 |
+
# """),
|
252 |
+
# ("human", "{question}")
|
253 |
+
# ])
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
# import gradio as gr
|
258 |
+
|
259 |
+
|
260 |
+
# # Ajouter du CSS pour personnaliser l'apparence
|
261 |
+
# css = """
|
262 |
+
# /* Reset RTL global */
|
263 |
+
# *, *::before, *::after {
|
264 |
+
# direction: rtl !important;
|
265 |
+
# text-align: right !important;
|
266 |
+
# }
|
267 |
+
|
268 |
+
# body {
|
269 |
+
# font-family: 'Amiri', sans-serif; /* Utilisation de la police Arabe andalouse */
|
270 |
+
# background-color: black; /* Fond blanc */
|
271 |
+
# color: black !important; /* Texte noir */
|
272 |
+
# direction: rtl !important; /* Texte en arabe aligné à droite */
|
273 |
+
# }
|
274 |
+
|
275 |
+
# .gradio-container {
|
276 |
+
# direction: rtl !important; /* Alignement RTL pour toute l'interface */
|
277 |
+
# }
|
278 |
+
|
279 |
+
# /* Éléments de formulaire */
|
280 |
+
# input[type="text"],
|
281 |
+
# .gradio-textbox input,
|
282 |
+
# textarea {
|
283 |
+
# border-radius: 20px;
|
284 |
+
# padding: 10px 15px;
|
285 |
+
# border: 2px solid #000;
|
286 |
+
# font-size: 16px;
|
287 |
+
# width: 80%;
|
288 |
+
# margin: 0 auto;
|
289 |
+
# text-align: right !important;
|
290 |
+
# }
|
291 |
+
|
292 |
+
# /* Surcharge des styles de placeholder */
|
293 |
+
# input::placeholder,
|
294 |
+
# textarea::placeholder {
|
295 |
+
# text-align: right !important;
|
296 |
+
# direction: rtl !important;
|
297 |
+
# }
|
298 |
+
|
299 |
+
# /* Boutons */
|
300 |
+
# .gradio-button {
|
301 |
+
# border-radius: 20px;
|
302 |
+
# font-size: 16px;
|
303 |
+
# background-color: #007BFF;
|
304 |
+
# color: white;
|
305 |
+
# padding: 10px 20px;
|
306 |
+
# margin: 10px auto;
|
307 |
+
# border: none;
|
308 |
+
# width: 80%;
|
309 |
+
# display: block;
|
310 |
+
# }
|
311 |
+
|
312 |
+
# .gradio-button:hover {
|
313 |
+
# background-color: #0056b3;
|
314 |
+
# }
|
315 |
+
|
316 |
+
# .gradio-chatbot .message {
|
317 |
+
# border-radius: 20px;
|
318 |
+
# padding: 10px;
|
319 |
+
# margin: 10px 0;
|
320 |
+
# background-color: #f1f1f1;
|
321 |
+
# border: 1px solid #ddd;
|
322 |
+
# width: 80%;
|
323 |
+
# text-align: right !important;
|
324 |
+
# direction: rtl !important;
|
325 |
+
# }
|
326 |
+
|
327 |
+
# /* Messages utilisateur alignés à gauche */
|
328 |
+
# .gradio-chatbot .user-message {
|
329 |
+
# margin-right: auto;
|
330 |
+
# background-color: #e3f2fd;
|
331 |
+
# text-align: right !important;
|
332 |
+
# direction: rtl !important;
|
333 |
+
# }
|
334 |
+
|
335 |
+
# /* Messages assistant alignés à droite */
|
336 |
+
# .gradio-chatbot .assistant-message {
|
337 |
+
# margin-right: auto;
|
338 |
+
# background-color: #f1f1f1;
|
339 |
+
# text-align: right
|
340 |
+
# }
|
341 |
+
|
342 |
+
# /* Corrections RTL pour les éléments spécifiques */
|
343 |
+
# .gradio-textbox textarea {
|
344 |
+
# text-align: right !important;
|
345 |
+
# }
|
346 |
+
|
347 |
+
# .gradio-dropdown div {
|
348 |
+
# text-align: right !important;
|
349 |
+
# }
|
350 |
+
# """
|
351 |
+
|
352 |
+
# # Modified process_question function to better work with tuples
|
353 |
+
# def process_question(question: str) -> Iterator[str]:
|
354 |
+
# if question in question_cache:
|
355 |
+
# response, docs = question_cache[question]
|
356 |
+
# sources = [doc.metadata.get("source") for doc in docs]
|
357 |
+
# sources = list(set([os.path.splitext(source)[0] for source in sources]))
|
358 |
+
# yield response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
|
359 |
+
# return
|
360 |
+
|
361 |
+
# relevant_docs = retriever(question)
|
362 |
+
|
363 |
+
# # Reranking with cross-encoder
|
364 |
+
# context = [doc.page_content for doc in relevant_docs]
|
365 |
+
# text_pairs = [[question, text] for text in context]
|
366 |
+
# scores = rag_loader.reranker.predict(text_pairs)
|
367 |
+
|
368 |
+
# scored_docs = list(zip(scores, context, relevant_docs))
|
369 |
+
# scored_docs.sort(key=lambda x: x[0], reverse=True)
|
370 |
+
# reranked_docs = [d[2].page_content for d in scored_docs][:10]
|
371 |
+
|
372 |
+
# prompt = prompt_template.format_messages(
|
373 |
+
# context=reranked_docs,
|
374 |
+
# question=question
|
375 |
+
# )
|
376 |
+
|
377 |
+
# full_response = ""
|
378 |
+
# try:
|
379 |
+
# for chunk in llm.stream(prompt):
|
380 |
+
# if isinstance(chunk, str):
|
381 |
+
# current_chunk = chunk
|
382 |
+
# else:
|
383 |
+
# current_chunk = chunk.content
|
384 |
+
# full_response += current_chunk
|
385 |
+
|
386 |
+
# sources = [d[2].metadata['source'] for d in scored_docs][:10]
|
387 |
+
# sources = list(set([os.path.splitext(source)[0] for source in sources]))
|
388 |
+
|
389 |
+
# yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
|
390 |
+
|
391 |
+
# question_cache[question] = (full_response, relevant_docs)
|
392 |
+
# except Exception as e:
|
393 |
+
# yield f"Erreur lors du traitement : {str(e)}"
|
394 |
+
|
395 |
+
# # Updated gradio_stream function for 'messages' format
|
396 |
+
# def gradio_stream(question: str, chat_history: list) -> Iterator[list]:
|
397 |
+
# # chat_history now contains the user message added by user_input
|
398 |
+
# # Add a placeholder for the assistant's response
|
399 |
+
# chat_history.append({"role": "assistant", "content": ""})
|
400 |
+
|
401 |
+
# try:
|
402 |
+
# # Stream the response using the existing process_question generator
|
403 |
+
# for partial_response in process_question(question):
|
404 |
+
# # Update the content of the last message (the assistant's placeholder)
|
405 |
+
# chat_history[-1]["content"] = partial_response
|
406 |
+
# yield chat_history # Yield the entire updated history list
|
407 |
+
# except Exception as e:
|
408 |
+
# # Update the assistant's message with the error
|
409 |
+
# chat_history[-1]["content"] = f"Erreur : {str(e)}"
|
410 |
+
# yield chat_history # Yield the history with the error message
|
411 |
+
|
412 |
+
# # Gradio interface
|
413 |
+
# with gr.Blocks(css=css) as demo:
|
414 |
+
# gr.Markdown("<h2 style='text-align: center !important;'>هذا تطبيق للاجابة على الأسئلة المتعلقة بالقوانين المغربية</h2>")
|
415 |
+
|
416 |
+
# with gr.Row():
|
417 |
+
# message = gr.Textbox(label="أدخل سؤالك", placeholder="اكتب سؤالك هنا", elem_id="question_input")
|
418 |
+
|
419 |
+
# with gr.Row():
|
420 |
+
# send = gr.Button("بحث", elem_id="search_button")
|
421 |
+
|
422 |
+
# with gr.Row():
|
423 |
+
# # No type parameter - use Gradio's default
|
424 |
+
# chatbot = gr.Chatbot(label="", type="messages") # Ajout de type="messages"
|
425 |
+
|
426 |
+
# # Updated user_input function for 'messages' format
|
427 |
+
# def user_input(user_message, chat_history):
|
428 |
+
# # chat_history is already a list of message dicts
|
429 |
+
# # Append the new user message
|
430 |
+
# return "", chat_history + [{"role": "user", "content": user_message}]
|
431 |
+
|
432 |
+
# send.click(user_input, [message, chatbot], [message, chatbot], queue=False)
|
433 |
+
# send.click(gradio_stream, [message, chatbot], chatbot)
|
434 |
+
|
435 |
+
# demo.launch(share=True)
|
436 |
+
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
|
441 |
+
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
|
446 |
import gradio as gr
|
447 |
from langchain_mistralai.chat_models import ChatMistralAI
|
448 |
from langchain.prompts import ChatPromptTemplate
|
|
|
464 |
import threading
|
465 |
from queue import Queue
|
466 |
import concurrent.futures
|
467 |
+
from typing import Generator, Tuple, Iterator, List, Dict
|
468 |
import time
|
469 |
|
470 |
+
# --- OptimizedRAGLoader class (revised below) ---
|
471 |
class OptimizedRAGLoader:
|
472 |
def __init__(self,
|
473 |
docs_folder: str = "./docs",
|
474 |
splits_folder: str = "./splits",
|
475 |
index_folder: str = "./index"):
|
476 |
+
|
477 |
self.docs_folder = Path(docs_folder)
|
478 |
self.splits_folder = Path(splits_folder)
|
479 |
self.index_folder = Path(index_folder)
|
480 |
+
|
481 |
# Create folders if they don't exist
|
482 |
for folder in [self.splits_folder, self.index_folder]:
|
483 |
folder.mkdir(parents=True, exist_ok=True)
|
484 |
+
|
485 |
# File paths
|
486 |
self.splits_path = self.splits_folder / "splits.json"
|
487 |
self.index_path = self.index_folder / "faiss.index"
|
488 |
self.documents_path = self.index_folder / "documents.pkl"
|
489 |
+
|
490 |
# Initialize components
|
491 |
self.index = None
|
492 |
self.indexed_documents = None
|
493 |
+
|
494 |
# Initialize encoder model
|
495 |
+
print("Loading Sentence Transformer...")
|
496 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
497 |
self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
|
498 |
self.encoder.to(self.device)
|
499 |
+
print("Loading Cross Encoder...")
|
500 |
+
self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)
|
501 |
+
print("Models loaded.")
|
502 |
+
|
503 |
# Initialize thread pool
|
504 |
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
505 |
+
|
506 |
# Initialize response cache
|
507 |
self.response_cache = {}
|
508 |
+
|
509 |
+
# Try loading index on init
|
510 |
+
self.load_or_create_index()
|
511 |
+
|
512 |
+
def load_or_create_index(self):
|
513 |
+
"""Loads index if exists, otherwise creates it."""
|
514 |
+
if not self.load_index():
|
515 |
+
print("Index not found, creating new index...")
|
516 |
+
if not self.create_index():
|
517 |
+
raise RuntimeError("Failed to create index.")
|
518 |
+
else:
|
519 |
+
print("Index created successfully.")
|
520 |
+
else:
|
521 |
+
print("Index loaded successfully.")
|
522 |
+
|
523 |
+
|
524 |
@lru_cache(maxsize=1000)
|
525 |
def encode(self, text: str):
|
526 |
"""Cached encoding function"""
|
|
|
528 |
embeddings = self.encoder.encode(
|
529 |
text,
|
530 |
convert_to_numpy=True,
|
531 |
+
normalize_embeddings=True,
|
532 |
+
device=self.device # Ensure encoding runs on the correct device
|
533 |
)
|
534 |
return embeddings
|
535 |
+
|
536 |
def batch_encode(self, texts: list):
|
537 |
"""Batch encoding for multiple texts"""
|
538 |
with torch.no_grad():
|
539 |
embeddings = self.encoder.encode(
|
540 |
texts,
|
541 |
+
batch_size=32, # Adjust based on GPU memory
|
542 |
convert_to_numpy=True,
|
543 |
normalize_embeddings=True,
|
544 |
+
show_progress_bar=True, # Show progress for potentially long operations
|
545 |
+
device=self.device # Ensure encoding runs on the correct device
|
546 |
)
|
547 |
return embeddings
|
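Note: since `normalize_embeddings=True` L2-normalizes every vector, the squared L2 distance used by the `faiss.IndexFlatL2` index built in `create_index` below ranks neighbours exactly like cosine similarity (for unit vectors, ||a - b||^2 = 2 - 2 a.b). A minimal NumPy sketch of that identity, independent of the encoder:

import numpy as np

# Two unit-norm vectors: squared L2 distance equals 2 - 2*cosine for normalized inputs.
a = np.array([0.6, 0.8], dtype="float32")
b = np.array([1.0, 0.0], dtype="float32")
cosine = float(a @ b)
sq_l2 = float(((a - b) ** 2).sum())
assert abs(sq_l2 - (2.0 - 2.0 * cosine)) < 1e-6
print(cosine, sq_l2)  # 0.6 and 0.8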
548 |
|
549 |
+
def _splits_exist(self) -> bool:
|
550 |
+
"""Check if split files exist."""
|
551 |
+
return self.splits_path.exists()
|
552 |
+
|
553 |
+
def _load_existing_splits(self) -> List[Document]:
|
554 |
+
"""Load splits from JSON file."""
|
555 |
+
print(f"Loading existing splits from {self.splits_path}...")
|
556 |
+
try:
|
557 |
+
with open(self.splits_path, 'r', encoding='utf-8') as f:
|
558 |
+
splits_data = json.load(f)
|
559 |
+
documents = [
|
560 |
+
Document(page_content=item['page_content'], metadata=item['metadata'])
|
561 |
+
for item in splits_data
|
562 |
+
]
|
563 |
+
print(f"Loaded {len(documents)} splits.")
|
564 |
+
return documents
|
565 |
+
except Exception as e:
|
566 |
+
print(f"Error loading splits: {e}. Recreating...")
|
567 |
+
return [] # Return empty list to trigger recreation
|
568 |
+
|
569 |
+
def _save_splits(self, documents: List[Document]):
|
570 |
+
"""Save splits to JSON file."""
|
571 |
+
print(f"Saving {len(documents)} splits to {self.splits_path}...")
|
572 |
+
splits_data = [
|
573 |
+
{'page_content': doc.page_content, 'metadata': doc.metadata}
|
574 |
+
for doc in documents
|
575 |
+
]
|
576 |
+
try:
|
577 |
+
with open(self.splits_path, 'w', encoding='utf-8') as f:
|
578 |
+
json.dump(splits_data, f, ensure_ascii=False, indent=2)
|
579 |
+
print("Splits saved successfully.")
|
580 |
+
except Exception as e:
|
581 |
+
print(f"Error saving splits: {e}")
|
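For reference, each record in splits.json written by `_save_splits` and read back by `_load_existing_splits` is just the two `Document` fields; a small round-trip sketch (the file name and Arabic sample text are illustrative only):

import json
from langchain.schema import Document

doc = Document(page_content="نص تجريبي.",
               metadata={"source": "sample.txt", "chunk_id": 0, "total_chunks": 1})
records = [{"page_content": doc.page_content, "metadata": doc.metadata}]
with open("splits.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)   # same call as _save_splits
with open("splits.json", "r", encoding="utf-8") as f:
    restored = [Document(page_content=r["page_content"], metadata=r["metadata"])
                for r in json.load(f)]
assert restored[0].page_content == doc.page_content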
582 |
+
|
583 |
+
|
584 |
+
def load_and_split_texts(self) -> List[Document]:
|
585 |
if self._splits_exist():
|
586 |
+
loaded_splits = self._load_existing_splits()
|
587 |
+
if loaded_splits: # Check if loading was successful
|
588 |
+
return loaded_splits
|
589 |
+
|
590 |
+
print("Processing documents and creating splits...")
|
591 |
documents = []
|
592 |
futures = []
|
593 |
+
|
594 |
+
# Ensure docs folder exists
|
595 |
+
if not self.docs_folder.is_dir():
|
596 |
+
print(f"Error: Docs folder not found at {self.docs_folder}")
|
597 |
+
# Create dummy docs folder for Spaces if it doesn't exist
|
598 |
+
self.docs_folder.mkdir(parents=True, exist_ok=True)
|
599 |
+
print(f"Created empty docs folder at {self.docs_folder}. Please upload text files.")
|
600 |
+
# You might want to add a default dummy file here for testing
|
601 |
+
# with open(self.docs_folder / "dummy.txt", "w") as f:
|
602 |
+
# f.write("This is a dummy file. Please replace with real legal documents.")
|
603 |
+
# return [] # Return empty if no real docs
|
604 |
+
# Or let it continue to process the dummy file if created
|
605 |
+
|
606 |
+
doc_files = list(self.docs_folder.glob("*.txt"))
|
607 |
+
if not doc_files:
|
608 |
+
print(f"No .txt files found in {self.docs_folder}. Cannot create index.")
|
609 |
+
# Add a dummy document if none exist to prevent errors downstream?
|
610 |
+
# Or handle this case in create_index more gracefully.
|
611 |
+
return []
|
612 |
+
|
613 |
+
print(f"Found {len(doc_files)} files to process.")
|
614 |
+
for file_path in doc_files:
|
615 |
future = self.executor.submit(self._process_file, file_path)
|
616 |
futures.append(future)
|
617 |
+
|
618 |
+
processed_count = 0
|
619 |
for future in concurrent.futures.as_completed(futures):
|
620 |
+
try:
|
621 |
+
documents.extend(future.result())
|
622 |
+
processed_count += 1
|
623 |
+
print(f"Processed file {processed_count}/{len(doc_files)}")
|
624 |
+
except Exception as e:
|
625 |
+
print(f"Error processing file in future: {e}")
|
626 |
+
|
627 |
+
if documents:
|
628 |
+
self._save_splits(documents)
|
629 |
+
else:
|
630 |
+
print("No documents were successfully processed or split.")
|
631 |
return documents
|
|
|
|
632 |
|
633 |
+
def _process_file(self, file_path: Path) -> List[Document]:
|
634 |
+
try:
|
635 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
636 |
+
text = file.read()
|
637 |
+
# Improved splitting: handle more sentence endings and ensure non-empty chunks
|
638 |
+
chunks = [s.strip() for s in re.split(r'(?<=[.!?؟؛])\s+', text) if s and s.strip()]
|
639 |
+
if not chunks: # Handle empty files or files with no standard sentence endings
|
640 |
+
print(f"Warning: No chunks generated for file {file_path.name}. Treating whole file as one chunk.")
|
641 |
+
if text.strip(): # If there's content, use it as one chunk
|
642 |
+
chunks = [text.strip()]
|
643 |
+
else:
|
644 |
+
return [] # Skip empty files
|
645 |
+
|
646 |
+
return [
|
647 |
+
Document(
|
648 |
+
page_content=chunk,
|
649 |
+
metadata={
|
650 |
+
'source': file_path.name,
|
651 |
+
'chunk_id': i,
|
652 |
+
'total_chunks': len(chunks)
|
653 |
+
}
|
654 |
+
)
|
655 |
+
for i, chunk in enumerate(chunks)
|
656 |
+
]
|
657 |
+
except Exception as e:
|
658 |
+
print(f"Error processing file {file_path.name}: {e}")
|
659 |
+
return [] # Return empty list on error for this file
|
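The splitting regex above breaks on whitespace that follows a Latin or Arabic sentence terminator (., !, ?, ؟, ؛) while keeping the terminator attached to its sentence; a quick standalone check:

import re

sample = "الفصل الأول ينص على ما يلي. هل هناك استثناء؟ نعم، في حالات محددة."
chunks = [s.strip() for s in re.split(r'(?<=[.!?؟؛])\s+', sample) if s and s.strip()]
print(chunks)
assert len(chunks) == 3  # three sentence-level chunks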
660 |
+
|
661 |
|
662 |
+
def load_index(self) -> bool:
|
663 |
+
"""Loads FAISS index and associated documents if they exist."""
|
|
|
664 |
if not self._index_exists():
|
665 |
+
# print("Index files not found.") # Reduced verbosity
|
666 |
return False
|
667 |
|
668 |
+
print(f"Loading existing index from {self.index_path} and documents from {self.documents_path}...")
|
669 |
try:
|
670 |
+
# Load FAISS index
|
671 |
self.index = faiss.read_index(str(self.index_path))
|
672 |
|
673 |
+
# If the loaded index was originally GPU, move it back if possible
|
674 |
+
if torch.cuda.is_available():
|
675 |
+
try:
|
676 |
+
print("Moving loaded index to GPU...")
|
677 |
+
res = faiss.StandardGpuResources()
|
678 |
+
self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
|
679 |
+
print("Index successfully moved to GPU.")
|
680 |
+
except Exception as gpu_e:
|
681 |
+
print(f"Could not move index to GPU, using CPU. Error: {gpu_e}")
|
682 |
+
|
683 |
+
|
684 |
+
# Load associated documents
|
685 |
with open(self.documents_path, 'rb') as f:
|
686 |
self.indexed_documents = pickle.load(f)
|
687 |
|
688 |
+
if not self.indexed_documents:
|
689 |
+
print("Warning: Index loaded, but associated documents file is empty.")
|
690 |
+
# Consider this a failure case maybe?
|
691 |
+
# return False
|
692 |
+
elif self.index.ntotal != len(self.indexed_documents):
|
693 |
+
print(f"Warning: Index size ({self.index.ntotal}) does not match document count ({len(self.indexed_documents)}). Index might be corrupted or outdated.")
|
694 |
+
# Decide how to handle mismatch: rebuild? error? proceed with caution?
|
695 |
+
# For now, let's treat it as loaded but potentially problematic.
|
696 |
+
|
697 |
+
print(f"Index loaded with {self.index.ntotal} vectors.")
|
698 |
return True
|
699 |
|
700 |
+
except FileNotFoundError:
|
701 |
+
print("Index files not found during load attempt.")
|
702 |
+
self.index = None
|
703 |
+
self.indexed_documents = None
|
704 |
+
return False
|
705 |
except Exception as e:
|
706 |
+
print(f"Error loading index: {e}")
|
707 |
+
# Clean up potentially partially loaded state
|
708 |
+
self.index = None
|
709 |
+
self.indexed_documents = None
|
710 |
return False
|
711 |
|
712 |
+
def create_index(self, documents: List[Document] = None) -> bool:
|
713 |
+
"""Creates or recreates the FAISS index."""
|
714 |
if documents is None:
|
715 |
documents = self.load_and_split_texts()
|
716 |
+
|
717 |
if not documents:
|
718 |
+
print("No documents provided or loaded, cannot create index.")
|
719 |
return False
|
720 |
+
|
721 |
+
print(f"Creating index for {len(documents)} document splits...")
|
722 |
texts = [doc.page_content for doc in documents]
|
723 |
+
print("Encoding documents...")
|
724 |
embeddings = self.batch_encode(texts)
|
725 |
+
|
726 |
+
if embeddings is None or len(embeddings) == 0:
|
727 |
+
print("Encoding failed or produced no embeddings.")
|
728 |
+
return False
|
729 |
+
|
730 |
dimension = embeddings.shape[1]
|
731 |
+
print(f"Embeddings created with dimension {dimension}.")
|
732 |
+
|
733 |
+
# Create CPU index first
|
734 |
+
cpu_index = faiss.IndexFlatL2(dimension)
|
735 |
+
print("FAISS CPU index created.")
|
736 |
+
|
737 |
+
# Use GPU for FAISS if available
|
738 |
if torch.cuda.is_available():
|
739 |
+
try:
|
740 |
+
print("Attempting to use GPU for FAISS indexing...")
|
741 |
+
res = faiss.StandardGpuResources()
|
742 |
+
self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
|
743 |
+
print("Adding embeddings to GPU index...")
|
744 |
+
self.index.add(np.array(embeddings).astype('float32'))
|
745 |
+
print(f"Embeddings added to GPU index. Index size: {self.index.ntotal}")
|
746 |
+
# Save the CPU version of the index
|
747 |
+
print(f"Saving CPU version of index to {self.index_path}...")
|
748 |
+
faiss.write_index(faiss.index_gpu_to_cpu(self.index), str(self.index_path))
|
749 |
+
except Exception as gpu_e:
|
750 |
+
print(f"GPU indexing failed: {gpu_e}. Falling back to CPU.")
|
751 |
+
self.index = cpu_index # Fallback to CPU index
|
752 |
+
print("Adding embeddings to CPU index...")
|
753 |
+
self.index.add(np.array(embeddings).astype('float32'))
|
754 |
+
print(f"Embeddings added to CPU index. Index size: {self.index.ntotal}")
|
755 |
+
print(f"Saving CPU index to {self.index_path}...")
|
756 |
+
faiss.write_index(self.index, str(self.index_path))
|
757 |
+
else:
|
758 |
+
print("GPU not available. Using CPU for FAISS indexing.")
|
759 |
+
self.index = cpu_index
|
760 |
+
print("Adding embeddings to CPU index...")
|
761 |
+
self.index.add(np.array(embeddings).astype('float32'))
|
762 |
+
print(f"Embeddings added to CPU index. Index size: {self.index.ntotal}")
|
763 |
+
print(f"Saving CPU index to {self.index_path}...")
|
764 |
+
faiss.write_index(self.index, str(self.index_path))
|
765 |
+
|
766 |
+
|
767 |
self.indexed_documents = documents
|
768 |
+
|
769 |
+
# Save documents
|
770 |
+
print(f"Saving associated documents to {self.documents_path}...")
|
771 |
+
try:
|
772 |
+
with open(self.documents_path, 'wb') as f:
|
773 |
+
pickle.dump(documents, f)
|
774 |
+
print("Index and documents saved successfully.")
|
775 |
+
return True
|
776 |
+
except Exception as e:
|
777 |
+
print(f"Error saving associated documents: {e}")
|
778 |
+
# Should we delete the index file if doc saving fails?
|
779 |
+
# self.index_path.unlink(missing_ok=True)
|
780 |
+
return False
|
781 |
+
|
782 |
|
783 |
def _index_exists(self) -> bool:
|
784 |
+
"""Checks if the index and associated document files exist."""
|
785 |
return self.index_path.exists() and self.documents_path.exists()
|
786 |
|
787 |
+
def get_retriever(self, k: int = 10, rerank_k: int = 5):
|
788 |
+
"""Gets a retriever function that performs FAISS search and cross-encoder reranking."""
|
789 |
+
if self.index is None or self.indexed_documents is None:
|
790 |
+
print("Index not initialized. Ensure load_or_create_index() was successful.")
|
791 |
+
# Attempt to load/create again, or raise error
|
792 |
+
self.load_or_create_index()
|
793 |
+
if self.index is None or self.indexed_documents is None:
|
794 |
+
raise ValueError("Unable to load or create index for retriever.")
|
795 |
+
|
796 |
+
# Make sure k for FAISS search is >= rerank_k
|
797 |
+
faiss_k = max(k, rerank_k)
|
798 |
+
|
799 |
|
800 |
+
def retriever_function(query: str) -> list[Document]:
|
801 |
+
# Check cache first (optional, consider if caching reranked results is desired)
|
802 |
+
# cache_key = f"{query}_{rerank_k}"
|
803 |
+
# if cache_key in self.response_cache:
|
804 |
+
# return self.response_cache[cache_key]
|
805 |
|
806 |
+
print(f"\nRetriever: Searching for query: '{query[:50]}...'")
|
807 |
query_embedding = self.encode(query)
|
808 |
+
|
809 |
+
print(f"Searching top {faiss_k} in FAISS index...")
|
810 |
+
try:
|
811 |
+
distances, indices = self.index.search(
|
812 |
+
np.array([query_embedding]).astype('float32'),
|
813 |
+
faiss_k
|
814 |
+
)
|
815 |
+
except Exception as search_e:
|
816 |
+
print(f"Error during FAISS search: {search_e}")
|
817 |
+
return []
|
818 |
+
|
819 |
+
# Filter out invalid indices (-1) and get initial documents
|
820 |
+
initial_results = [
|
821 |
self.indexed_documents[idx]
|
822 |
for idx in indices[0]
|
823 |
+
if idx != -1 and idx < len(self.indexed_documents) # Added bounds check
|
824 |
]
|
825 |
+
|
826 |
+
if not initial_results:
|
827 |
+
print("No relevant documents found in FAISS search.")
|
828 |
+
return []
|
829 |
+
|
830 |
+
print(f"Found {len(initial_results)} initial candidates. Reranking top {len(initial_results)}...")
|
831 |
+
|
832 |
+
# Prepare for reranking
|
833 |
+
context = [doc.page_content for doc in initial_results]
|
834 |
+
text_pairs = [[query, text] for text in context]
|
835 |
+
|
836 |
+
# Rerank using the cross-encoder
|
837 |
+
try:
|
838 |
+
scores = self.reranker.predict(text_pairs, show_progress_bar=False) # Don't show progress bar here
|
839 |
+
except Exception as rerank_e:
|
840 |
+
print(f"Error during reranking: {rerank_e}. Returning initial FAISS results.")
|
841 |
+
return initial_results[:rerank_k] # Return top k from initial results
|
842 |
+
|
843 |
+
|
844 |
+
# Combine scores with documents and sort
|
845 |
+
scored_docs = list(zip(scores, initial_results))
|
846 |
+
scored_docs.sort(key=lambda x: x[0], reverse=True)
|
847 |
+
|
848 |
+
# Select the top rerank_k documents
|
849 |
+
reranked_docs = [doc for score, doc in scored_docs[:rerank_k]]
|
850 |
+
|
851 |
+
print(f"Reranked results (top {len(reranked_docs)}):")
|
852 |
+
# for i, (score, doc) in enumerate(scored_docs[:rerank_k]):
|
853 |
+
# print(f" {i+1}. Score: {score:.4f}, Source: {doc.metadata.get('source', 'N/A')}, Chunk: {doc.metadata.get('chunk_id', 'N/A')}")
|
854 |
+
|
855 |
+
|
856 |
+
# Cache results (optional)
|
857 |
+
# self.response_cache[cache_key] = reranked_docs
|
858 |
+
return reranked_docs
|
859 |
+
|
860 |
return retriever_function
|
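A usage sketch of the closure returned here (requires a populated ./docs folder and the models above; the question string is only an example):

loader = OptimizedRAGLoader()
retrieve = loader.get_retriever(k=15, rerank_k=5)
docs = retrieve("ما هي شروط صحة العقد؟")
for doc in docs:
    print(doc.metadata.get("source"), "-", doc.page_content[:80])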
861 |
|
862 |
+
# --- LLM Initialization ---
|
863 |
+
# Choose *one* LLM to uncomment based on your available API key
|
864 |
+
# print("Initializing LLM...")
|
865 |
+
# mistral_api_key = os.getenv("MISTRAL_API_KEY") # Ensure env var name matches your Space secret
|
866 |
+
# if mistral_api_key:
|
867 |
+
# llm = ChatMistralAI(
|
868 |
+
# model="mistral-large-latest",
|
869 |
+
# mistral_api_key=mistral_api_key,
|
870 |
+
# temperature=0.01,
|
871 |
+
# # streaming=True, # Streaming is handled differently with Gradio 'messages'
|
872 |
+
# )
|
873 |
+
# print("Using Mistral LLM.")
|
874 |
+
# else:
|
875 |
+
# print("Mistral API key not found.")
|
876 |
|
877 |
# deepseek_api_key = os.getenv("DEEPSEEK_KEY")
|
878 |
+
# if deepseek_api_key:
|
879 |
+
# llm = ChatDeepSeek(
|
880 |
+
# model="deepseek-chat",
|
881 |
+
# temperature=0.01, # Slightly non-zero for potentially better phrasing
|
882 |
+
# api_key=deepseek_api_key,
|
883 |
+
# # streaming=True, # Streaming is handled differently with Gradio 'messages'
|
884 |
+
# )
|
885 |
+
# print("Using DeepSeek LLM.")
|
886 |
+
# else:
|
887 |
+
# print("Deepseek API key not found.")
|
888 |
|
889 |
|
890 |
gemini_api_key = os.getenv("GEMINI_KEY")
|
891 |
+
if gemini_api_key:
|
892 |
+
try:
|
893 |
+
llm = ChatGoogleGenerativeAI(
|
894 |
+
model="gemini-1.5-flash", # Using flash for potentially faster responses
|
895 |
+
temperature=0.1, # Slightly increased temperature
|
896 |
+
google_api_key=gemini_api_key,
|
897 |
+
# convert_system_message_to_human=True # Sometimes needed for Gemini
|
898 |
+
# streaming=False # Gemini streaming requires specific handling; simpler without for now
|
899 |
+
)
|
900 |
+
print("Using Google Gemini LLM.")
|
901 |
+
except Exception as gemini_e:
|
902 |
+
print(f"Error initializing Gemini LLM: {gemini_e}")
|
903 |
+
llm = None # Set llm to None if initialization fails
|
904 |
+
else:
|
905 |
+
print("Gemini API key not found.")
|
906 |
+
llm = None
|
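GEMINI_KEY is read from the environment (a repository secret on Hugging Face Spaces); for a local run one might verify it is visible before launching, for example:

import os

# Hypothetical local check; in production the key comes from a Space secret, never hard-coded.
if not os.getenv("GEMINI_KEY"):
    raise SystemExit("GEMINI_KEY is not set; export it before running app.py")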
907 |
|
908 |
+
# --- RAG Loader and Retriever Initialization ---
|
909 |
+
print("Initializing RAG Loader...")
|
910 |
+
try:
|
911 |
+
rag_loader = OptimizedRAGLoader()
|
912 |
+
retriever = rag_loader.get_retriever(k=15, rerank_k=5) # Retrieve more initially, rerank top 5
|
913 |
+
print("RAG Loader and Retriever initialized.")
|
914 |
+
except Exception as rag_e:
|
915 |
+
print(f"FATAL: Could not initialize RAG system: {rag_e}")
|
916 |
+
# Optionally exit or provide a dummy retriever/LLM
|
917 |
+
retriever = lambda query: [] # Dummy retriever
|
918 |
+
llm = None # Ensure LLM is None if RAG fails
|
919 |
|
|
|
|
|
920 |
|
921 |
+
# Cache for processed questions (Consider persistence or size limits if needed)
|
922 |
question_cache = {}
|
923 |
|
924 |
+
# --- Prompt Template ---
|
925 |
+
# Adjusted prompt for clarity and conciseness
|
926 |
prompt_template = ChatPromptTemplate.from_messages([
|
927 |
+
("system", """أنت مساعد قانوني خبير ومؤهل في القانون المغربي. مهمتك هي تحليل الأسئلة القانونية والإجابة عليها بدقة بناءً على السياق المقدم.
|
|
|
|
|
928 |
|
929 |
+
إرشادات:
|
930 |
+
1. حلل السياق التالي بعناية:
|
931 |
+
{context}
|
932 |
+
2. استخدم المعلومات من السياق فقط لصياغة إجابتك.
|
933 |
+
3. إذا كانت المعلومات غير كافية أو غير موجودة في السياق للإجابة على السؤال، أشر بوضوح إلى أن السياق المقدم لا يحتوي على الإجابة المطلوبة. لا تختلق معلومات.
|
934 |
+
4. اذكر المصادر (أسماء الملفات) التي استخدمتها من السياق في نهاية إجابتك.
|
935 |
+
5. أجب باللغة العربية وبأسلوب واضح وموجز.
|
936 |
|
937 |
+
السؤال المطلوب الإجابة عليه: {question}"""),
|
938 |
+
("human", "{question}") # Human message might be redundant if question is in system prompt, but often helps guide model role.
|
939 |
+
])
|
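`format_messages` fills both placeholders and returns a system/human message pair that can be passed directly to `llm.invoke` or `llm.stream`; a quick inspection sketch with placeholder strings:

msgs = prompt_template.format_messages(
    context="المصدر: sample.txt\nالمحتوى: نص تجريبي.",
    question="ما هو مضمون الفصل الأول؟",
)
for m in msgs:
    print(type(m).__name__, "->", m.content[:60])
# Expected: a SystemMessage followed by a HumanMessage.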
940 |
|
941 |
+
# --- CSS Styling ---
|
942 |
+
# (CSS remains the same, ensure RTL works as intended)
|
943 |
css = """
|
944 |
/* Reset RTL global */
|
945 |
+
:root {
|
946 |
+
--input-border-radius: 15px !important; /* Example variable */
|
947 |
+
--button-border-radius: 15px !important;
|
948 |
+
}
|
949 |
+
|
950 |
*, *::before, *::after {
|
951 |
direction: rtl !important;
|
952 |
text-align: right !important;
|
953 |
}
|
|
|
954 |
body {
|
955 |
+
font-family: 'Arial', 'sans-serif'; /* Using a more standard font */
|
956 |
+
background-color: #f8f9fa; /* Light gray background */
|
957 |
+
color: #343a40; /* Darker text */
|
958 |
+
direction: rtl !important;
|
959 |
}
|
|
|
960 |
.gradio-container {
|
961 |
+
direction: rtl !important;
|
962 |
+
background-color: #f8f9fa;
|
963 |
}
|
964 |
+
/* Input Textbox */
|
965 |
+
.gradio-textbox textarea {
|
966 |
+
border-radius: var(--input-border-radius) !important;
|
967 |
+
padding: 12px 18px !important;
|
968 |
+
border: 1px solid #ced4da !important;
|
969 |
+
font-size: 16px !important;
|
970 |
+
width: 95% !important; /* Adjust width */
|
971 |
+
margin: 10px auto !important;
|
972 |
+
display: block !important;
|
|
|
|
|
973 |
text-align: right !important;
|
974 |
+
background-color: #ffffff !important; /* White background for input */
|
975 |
+
color: #495057 !important; /* Input text color */
|
976 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important; /* Subtle shadow */
|
977 |
}
|
978 |
|
979 |
+
/* Placeholder styling */
|
980 |
+
.gradio-textbox textarea::placeholder {
|
|
|
981 |
text-align: right !important;
|
982 |
direction: rtl !important;
|
983 |
+
color: #6c757d !important; /* Lighter placeholder text */
|
984 |
}
|
985 |
|
986 |
+
/* Send Button */
|
987 |
.gradio-button {
|
988 |
+
border-radius: var(--button-border-radius) !important;
|
989 |
+
font-size: 16px !important;
|
990 |
+
font-weight: bold !important;
|
991 |
+
background-color: #007bff !important; /* Primary blue */
|
992 |
+
color: white !important;
|
993 |
+
padding: 12px 24px !important;
|
994 |
+
margin: 5px auto 15px auto !important; /* Adjust margins */
|
995 |
+
border: none !important;
|
996 |
+
width: 95% !important; /* Match textbox width */
|
997 |
+
display: block !important;
|
998 |
+
cursor: pointer !important;
|
999 |
+
transition: background-color 0.2s ease-in-out !important;
|
1000 |
+
box-shadow: 0 2px 5px rgba(0, 123, 255, 0.3) !important; /* Button shadow */
|
1001 |
}
|
|
|
1002 |
.gradio-button:hover {
|
1003 |
+
background-color: #0056b3 !important; /* Darker blue on hover */
|
1004 |
+
}
|
1005 |
+
|
1006 |
+
/* Chatbot Messages */
|
1007 |
+
.gradio-chatbot {
|
1008 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1) !important; /* Chatbot container shadow */
|
1009 |
+
border-radius: 10px !important;
|
1010 |
+
background-color: #ffffff !important; /* White background for chat area */
|
1011 |
+
padding: 10px !important;
|
1012 |
+
}
|
1013 |
+
|
1014 |
+
.gradio-chatbot .message-wrap { /* Targeting the wrapper for better styling */
|
1015 |
+
padding: 5px 0 !important; /* Space between messages */
|
1016 |
}
|
1017 |
|
1018 |
.gradio-chatbot .message {
|
1019 |
+
border-radius: 18px !important; /* Rounded corners for messages */
|
1020 |
+
padding: 10px 15px !important;
|
1021 |
+
margin: 5px 0 !important; /* Vertical margin */
|
1022 |
+
max-width: 85% !important; /* Max width of message bubble */
|
1023 |
+
border: none !important; /* Remove default border */
|
1024 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.08) !important; /* Subtle message shadow */
|
1025 |
text-align: right !important;
|
1026 |
direction: rtl !important;
|
1027 |
+
word-wrap: break-word; /* Ensure long words break */
|
1028 |
+
line-height: 1.5; /* Improve readability */
|
1029 |
}
|
1030 |
|
1031 |
+
/* User Messages (align left conceptually, but RTL makes them appear right-aligned within container) */
|
1032 |
+
.gradio-chatbot .user-message .message {
|
1033 |
+
margin-left: auto !important; /* Push to the 'end' side in LTR, start in RTL */
|
1034 |
+
margin-right: 0 !important;
|
1035 |
+
background-color: #e7f5ff !important; /* Light blue for user */
|
1036 |
+
color: #0056b3 !important;
|
1037 |
}
|
1038 |
|
1039 |
+
/* Assistant Messages (align right conceptually, appear left-aligned in RTL) */
|
1040 |
+
.gradio-chatbot .assistant-message .message {
|
1041 |
+
margin-right: auto !important; /* Push to the 'start' side in LTR, end in RTL */
|
1042 |
+
margin-left: 0 !important;
|
1043 |
+
background-color: #f1f3f5 !important; /* Light gray for assistant */
|
1044 |
+
color: #343a40 !important;
|
1045 |
}
|
1046 |
|
1047 |
+
/* Markdown Header */
|
1048 |
+
h2 {
|
1049 |
+
color: #0056b3 !important; /* Match button hover color */
|
1050 |
+
font-weight: bold !important;
|
1051 |
+
margin-bottom: 20px !important;
|
1052 |
+
text-align: center !important;
|
1053 |
+
direction: rtl !important; /* Ensure header is also RTL */
|
1054 |
}
|
1055 |
|
1056 |
+
/* Ensure specific Gradio elements inherit RTL */
|
1057 |
+
.gradio-dropdown div, .gradio-checkboxgroup div, .gradio-radio div {
|
1058 |
text-align: right !important;
|
1059 |
+
direction: rtl !important;
|
1060 |
+
}
|
1061 |
+
|
1062 |
+
/* Center alignment for elements within rows if needed */
|
1063 |
+
.gradio-row {
|
1064 |
+
justify-content: center !important; /* Helps center content like buttons/textboxes */
|
1065 |
}
|
1066 |
"""
|
1067 |
|
1068 |
+
# --- Backend Processing Function ---
|
1069 |
def process_question(question: str) -> Iterator[str]:
|
1070 |
+
"""
|
1071 |
+
Processes a question using RAG and LLM, yielding the response stream.
|
1072 |
+
Includes source attribution.
|
1073 |
+
"""
|
1074 |
+
if not llm:
|
1075 |
+
yield "عذراً، النموذج اللغوي غير متاح حالياً. يرجى المحاولة لاحقاً."
|
1076 |
return
|
1077 |
+
if not retriever:
|
1078 |
+
yield "عذراً، نظام استرجاع المعلومات غير متاح حالياً."
|
1079 |
+
return
|
1080 |
+
|
1081 |
+
# Simple caching check (consider more robust caching)
|
1082 |
+
# if question in question_cache:
|
1083 |
+
# response, sources_str = question_cache[question]
|
1084 |
+
# yield response + sources_str
|
1085 |
+
# return
|
1086 |
+
|
1087 |
+
print(f"Processing question: {question}")
|
|
|
|
1088 |
try:
|
1089 |
+
relevant_docs = retriever(question)
|
1090 |
+
except Exception as ret_e:
|
1091 |
+
print(f"Error during retrieval: {ret_e}")
|
1092 |
+
yield f"حدث خطأ أثناء البحث عن المستندات المتعلقة: {str(ret_e)}"
|
1093 |
+
return
|
1094 |
+
|
1095 |
+
|
1096 |
+
if not relevant_docs:
|
1097 |
+
print("No relevant documents found by retriever.")
|
1098 |
+
yield "لم أتمكن من العثور على معلومات ذات صلة في المستندات المتاحة للإجابة على سؤالك."
|
1099 |
+
return
|
1100 |
+
|
1101 |
+
context_str = "\n\n".join([f"المصدر: {doc.metadata.get('source', 'غير معروف')}\nالمحتوى: {doc.page_content}" for doc in relevant_docs])
|
1102 |
+
sources = list(set([doc.metadata.get("source", "غير معروف") for doc in relevant_docs]))
|
1103 |
+
sources_str = "\n\n\nالمصادر المحتملة التي تم الرجوع إليها:\n- " + "\n- ".join(sources)
|
1104 |
+
|
1105 |
+
print(f"Context created from {len(relevant_docs)} documents. Generating response...")
|
1106 |
+
# print(f"Context sample: {context_str[:200]}...") # Debugging
|
1107 |
+
|
1108 |
+
try:
|
1109 |
+
# Format the prompt using the template
|
1110 |
+
prompt = prompt_template.format_messages(
|
1111 |
+
context=context_str,
|
1112 |
+
question=question
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
# --- Non-Streaming Call (Simpler for Gemini without specific streaming setup) ---
|
1116 |
+
# full_response = llm.invoke(prompt) # Use invoke for non-streaming
|
1117 |
+
# if isinstance(full_response, str): # Handle potential different return types
|
1118 |
+
# response_content = full_response
|
1119 |
+
# else:
|
1120 |
+
# response_content = full_response.content
|
1121 |
+
|
1122 |
+
# yield response_content + sources_str # Yield the complete response at once
|
1123 |
+
|
1124 |
+
# --- Streaming Call (If LLM supports it and is configured correctly) ---
|
1125 |
+
full_response = ""
|
1126 |
+
stream = llm.stream(prompt)
|
1127 |
+
start_time = time.time()
|
1128 |
+
first_chunk_received = False
|
1129 |
+
|
1130 |
+
for chunk in stream:
|
1131 |
+
if not first_chunk_received:
|
1132 |
+
end_time = time.time()
|
1133 |
+
print(f"Time to first chunk: {end_time - start_time:.2f} seconds")
|
1134 |
+
first_chunk_received = True
|
1135 |
+
|
1136 |
+
# Adapt based on Langchain version and LLM provider's chunk structure
|
1137 |
+
if hasattr(chunk, 'content'):
|
1138 |
+
current_chunk = chunk.content
|
1139 |
+
elif isinstance(chunk, str):
|
1140 |
+
current_chunk = chunk
|
1141 |
else:
|
1142 |
+
print(f"Unexpected chunk type: {type(chunk)}")
|
1143 |
+
current_chunk = str(chunk) # Fallback
|
1144 |
+
|
1145 |
+
if current_chunk: # Avoid adding empty chunks
|
1146 |
+
full_response += current_chunk
|
1147 |
+
# Yield intermediate response with sources appended
|
1148 |
+
yield full_response + sources_str # Appending sources at each step
|
1149 |
+
|
1150 |
+
if not first_chunk_received: # Handle cases where stream might be empty or fail silently
|
1151 |
+
print("No chunks received from LLM stream.")
|
1152 |
+
yield "حدث خطأ أو لم يتمكن النموذج من إنشاء رد." + sources_str
|
1153 |
+
|
1154 |
+
print("LLM response generation complete.")
|
1155 |
+
# Cache the final result (optional)
|
1156 |
+
# question_cache[question] = (full_response, sources_str)
|
1157 |
+
|
1158 |
except Exception as e:
|
1159 |
+
print(f"Error during LLM generation: {e}")
|
1160 |
+
yield f"حدث خطأ أثناء توليد الإجابة: {str(e)}" + sources_str # Include sources even on error if possible
|
1161 |
+
|
1162 |
+
|
1163 |
+
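# The commented-out caching lines above assume a module-level `question_cache`
# mapping a question to (response, sources_str). A minimal sketch of such a cache
# follows; it is illustrative only and not part of the original app: the helper
# name `cache_answer`, the OrderedDict choice, and the 128-entry cap are assumptions.
from collections import OrderedDict

_QUESTION_CACHE_MAX = 128  # assumed cap, to keep memory bounded

question_cache: "OrderedDict[str, Tuple[str, str]]" = OrderedDict()

def cache_answer(question: str, response: str, sources_str: str) -> None:
    """Store a finished answer, evicting the oldest entries beyond the cap."""
    question_cache[question] = (response, sources_str)
    while len(question_cache) > _QUESTION_CACHE_MAX:
        question_cache.popitem(last=False)  # FIFO eviction of the oldest entry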
# --- Gradio Interface Functions ---

# Add the user's message to the history (the chatbot uses the 'messages' format).
def user_input(user_message: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
    if not user_message.strip():  # Ignore empty messages
        return "", chat_history
    # Clear the textbox and append the user's message in the expected format.
    return "", chat_history + [{"role": "user", "content": user_message}]

# Handle the streaming response (the chatbot uses the 'messages' format).
def gradio_stream(question: str, chat_history: List[Dict[str, str]]) -> Iterator[List[Dict[str, str]]]:
    # By the time this .then() step runs, user_input has already cleared the
    # textbox, so `question` can arrive empty; fall back to the last user
    # message stored in the history.
    if not question.strip():
        user_messages = [m["content"] for m in chat_history if m.get("role") == "user"]
        if not user_messages:
            yield chat_history
            return
        question = user_messages[-1]

    if not llm or not retriever:
        # "Sorry, the system is not ready yet. Please make sure the keys and models are configured."
        chat_history.append({"role": "assistant", "content": "عذراً، النظام غير جاهز حالياً. يرجى التأكد من تهيئة المفاتيح والنماذج."})
        yield chat_history
        return

    # Add a placeholder for the assistant's response and show a thinking
    # indicator ("Thinking and searching...") immediately.
    chat_history.append({"role": "assistant", "content": "جارٍ التفكير والبحث..."})
    yield chat_history

    try:
        # Stream the response using the process_question generator.
        for partial_response in process_question(question):
            # Update the content of the last message (the assistant's placeholder).
            chat_history[-1]["content"] = partial_response
            yield chat_history  # Yield the entire updated history list
    except Exception as e:
        print(f"Error in gradio_stream calling process_question: {e}")
        # Replace the placeholder with the error message ("An unexpected error occurred").
        chat_history[-1]["content"] = f"حدث خطأ غير متوقع: {str(e)}"
        yield chat_history
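# Illustrative only: a quick way to exercise the RAG/streaming pipeline from a
# console, without the Gradio UI. It relies on the same `llm` and `retriever`
# globals as process_question; the helper name is an assumption, and the default
# question reuses the placeholder example from the textbox below.
def debug_stream(question: str = "ما هي شروط الحصول على تعويض عن حادثة شغل؟") -> str:
    """Print the tail of each partial response and return the final text."""
    final = ""
    for partial in process_question(question):
        final = partial
        print(partial[-120:].replace("\n", " "))  # keep the console output short
    return final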
# --- Gradio Interface Definition ---
print("Building Gradio interface...")
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:  # Soft theme for a cleaner look
    # Title: "Moroccan legal assistant - questions and answers"
    gr.Markdown("<h2 style='text-align: center;'>مساعد قانوني مغربي - أسئلة وأجوبة</h2>")
    # Subtitle: "Ask your question about Moroccan law and I will try to answer from the available documents."
    gr.Markdown("<p style='text-align: center; color: #6c757d;'>اطرح سؤالك حول القوانين المغربية وسأحاول الإجابة بناءً على المستندات المتوفرة.</p>")

    # Use type="messages" so the history is a list of {"role", "content"} dicts.
    chatbot = gr.Chatbot(label="المحادثة", type="messages", height=500)

    with gr.Row():
        message = gr.Textbox(
            label="أدخل سؤالك هنا:",  # "Enter your question here:"
            placeholder="مثال: ما هي شروط الحصول على تعويض عن حادثة شغل؟",
            lines=3,  # Allow more lines for input
            elem_id="question_input"  # Keep elem_id if needed elsewhere
        )

    with gr.Row():
        send = gr.Button("إرسال السؤال", variant="primary")  # "Send question"

    # Chain the actions:
    # 1. On click, user_input takes the textbox value and the current history,
    #    clears the textbox, and appends the user message to the chatbot history.
    # 2. gradio_stream then runs with the (now cleared) textbox and the updated
    #    history; it recovers the question from the last user message in the
    #    history and streams updates back into the chatbot component.
    send.click(
        user_input, inputs=[message, chatbot], outputs=[message, chatbot], queue=False
    ).then(
        gradio_stream, inputs=[message, chatbot], outputs=chatbot
    )

    # Optional: allow submitting with the Enter key.
    message.submit(
        user_input, inputs=[message, chatbot], outputs=[message, chatbot], queue=False
    ).then(
        gradio_stream, inputs=[message, chatbot], outputs=chatbot
    )

print("Gradio Blocks defined.")
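# Illustrative only: if the app needs to bound its request backlog, the queue can
# be capped before launch. `max_size` is a standard gr.Blocks.queue() argument;
# the value 32 and the helper name are assumptions, not part of the original app.
def configure_queue(blocks: gr.Blocks, max_pending: int = 32) -> gr.Blocks:
    """Enable the request queue with a cap on pending requests."""
    return blocks.queue(max_size=max_pending)

# Usage (instead of the plain demo.queue() call below):
# configure_queue(demo).launch(share=False, debug=True)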
# --- Launch the Application ---
if __name__ == "__main__":
    print("Launching Gradio app...")
    # Set share=False when running locally or in environments like Spaces where sharing is handled differently.
    # Set debug=True for more detailed logs during development.
    demo.queue()  # Enable the queue for handling multiple users/requests
    demo.launch(share=False, debug=True)  # share=True can cause issues in some environments