import gradio as gr
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.prompts import ChatPromptTemplate
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
import os
from pathlib import Path
import json
import faiss
import numpy as np
from langchain.schema import Document
import pickle
import re
import requests
from functools import lru_cache
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import threading
from queue import Queue
import concurrent.futures
from typing import Generator, Tuple, Iterator
import time
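
# RAG pipeline for a Moroccan-law question-answering assistant: .txt documents are split into
# sentence chunks, embedded with intfloat/multilingual-e5-large, indexed with FAISS, reranked
# with a cross-encoder, and answered by an LLM. The active entry point is the FastAPI app at
# the bottom of the file; the large commented-out block in the middle is an earlier Gradio UI.
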
class OptimizedRAGLoader:
    def __init__(self,
                 docs_folder: str = "./docs",
                 splits_folder: str = "./splits",
                 index_folder: str = "./index"):
        self.docs_folder = Path(docs_folder)
        self.splits_folder = Path(splits_folder)
        self.index_folder = Path(index_folder)
        # Create folders if they don't exist
        for folder in [self.splits_folder, self.index_folder]:
            folder.mkdir(parents=True, exist_ok=True)
        # File paths
        self.splits_path = self.splits_folder / "splits.json"
        self.index_path = self.index_folder / "faiss.index"
        self.documents_path = self.index_folder / "documents.pkl"
        # Initialize components
        self.index = None
        self.indexed_documents = None
        # Initialize encoder model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
        self.encoder.to(self.device)
        # Cross-encoder used to rerank retrieved chunks
        self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)
        # Initialize thread pool
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
        # Initialize response cache
        self.response_cache = {}

    def encode(self, text: str):
        """Encode a single text into a normalized embedding."""
        with torch.no_grad():
            embeddings = self.encoder.encode(
                text,
                convert_to_numpy=True,
                normalize_embeddings=True
            )
        return embeddings

    def batch_encode(self, texts: list):
        """Batch encoding for multiple texts."""
        with torch.no_grad():
            embeddings = self.encoder.encode(
                texts,
                batch_size=32,
                convert_to_numpy=True,
                normalize_embeddings=True,
                show_progress_bar=False
            )
        return embeddings

    def load_and_split_texts(self):
        if self._splits_exist():
            return self._load_existing_splits()
        documents = []
        futures = []
        for file_path in self.docs_folder.glob("*.txt"):
            future = self.executor.submit(self._process_file, file_path)
            futures.append(future)
        for future in concurrent.futures.as_completed(futures):
            documents.extend(future.result())
        self._save_splits(documents)
        return documents

    def _process_file(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        # Split on sentence boundaries (after ., ! or ?)
        chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
        return [
            Document(
                page_content=chunk,
                metadata={
                    'source': file_path.name,
                    'chunk_id': i,
                    'total_chunks': len(chunks)
                }
            )
            for i, chunk in enumerate(chunks)
        ]
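
    # The three methods below are called by load_and_split_texts() but were not defined in the
    # original file. These are minimal sketches of what they likely look like, assuming the
    # splits are stored in self.splits_path as a JSON list of {"page_content", "metadata"} records.
    def _splits_exist(self) -> bool:
        return self.splits_path.exists()

    def _load_existing_splits(self):
        with open(self.splits_path, 'r', encoding='utf-8') as f:
            records = json.load(f)
        return [Document(page_content=r["page_content"], metadata=r["metadata"]) for r in records]

    def _save_splits(self, documents):
        records = [{"page_content": d.page_content, "metadata": d.metadata} for d in documents]
        with open(self.splits_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False, indent=2)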

    def load_index(self) -> bool:
        """
        Load the FAISS index and the associated documents if they exist.
        Returns:
            bool: True if the index was loaded, False otherwise.
        """
        if not self._index_exists():
            print("No index found.")
            return False
        print("Loading existing index...")
        try:
            # Load the FAISS index
            self.index = faiss.read_index(str(self.index_path))
            # Load the associated documents
            with open(self.documents_path, 'rb') as f:
                self.indexed_documents = pickle.load(f)
            print(f"Index loaded with {self.index.ntotal} vectors")
            return True
        except Exception as e:
            print(f"Error while loading the index: {e}")
            return False

    def create_index(self, documents=None):
        if documents is None:
            documents = self.load_and_split_texts()
        if not documents:
            return False
        texts = [doc.page_content for doc in documents]
        embeddings = self.batch_encode(texts)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        if torch.cuda.is_available():
            # Use GPU for FAISS if available
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
        self.index.add(np.array(embeddings).astype('float32'))
        self.indexed_documents = documents
        # Save index and documents (FAISS can only serialize a CPU index)
        cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
        faiss.write_index(cpu_index, str(self.index_path))
        with open(self.documents_path, 'wb') as f:
            pickle.dump(documents, f)
        return True

    def _index_exists(self) -> bool:
        """Check whether the index and the associated documents exist."""
        return self.index_path.exists() and self.documents_path.exists()

    def get_retriever(self, k: int = 10):
        if self.index is None:
            if not self.load_index():
                if not self.create_index():
                    raise ValueError("Unable to load or create index")

        def retriever_function(query: str) -> list:
            # Check cache first
            cache_key = f"{query}_{k}"
            if cache_key in self.response_cache:
                return self.response_cache[cache_key]
            query_embedding = self.encode(query)
            distances, indices = self.index.search(
                np.array([query_embedding]).astype('float32'),
                k
            )
            results = [
                self.indexed_documents[idx]
                for idx in indices[0]
                if idx != -1
            ]
            # Cache the results
            self.response_cache[cache_key] = results
            return results

        return retriever_function
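
# Example usage (this mirrors the active initialization block further down in this file):
#   rag_loader = OptimizedRAGLoader()
#   retriever = rag_loader.get_retriever(k=5)
#   docs = retriever("Some question about Moroccan law")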

# # # Initialize components
# # mistral_api_key = os.getenv("mistral_api_key")
# # llm = ChatMistralAI(
# #     model="mistral-large-latest",
# #     mistral_api_key=mistral_api_key,
# #     temperature=0.01,
# #     streaming=True,
# # )

# # deepseek_api_key = os.getenv("DEEPSEEK_KEY")
# # llm = ChatDeepSeek(
# #     model="deepseek-chat",
# #     temperature=0,
# #     api_key=deepseek_api_key,
# #     streaming=True,
# # )

# gemini_api_key = os.getenv("GEMINI_KEY")
# llm = ChatGoogleGenerativeAI(
#     model="gemini-1.5-pro",
#     temperature=0,
#     google_api_key=gemini_api_key,
#     disable_streaming=True,
# )

# rag_loader = OptimizedRAGLoader()
# retriever = rag_loader.get_retriever(k=5)  # Reduced k for faster retrieval

# # Cache for processed questions
# question_cache = {}

# prompt_template = ChatPromptTemplate.from_messages([
#     ("system", """Vous êtes un assistant juridique expert qualifié. Analysez et répondez aux questions juridiques avec précision.
#     PROCESSUS D'ANALYSE :
#     1. Analysez le contexte fourni : {context}
#     2. Utilisez la recherche web si la reponse n'existe pas dans le contexte
#     3. Privilégiez les sources officielles et la jurisprudence récente
#     Question à traiter : {question}
#     """),
#     ("human", "{question}")
# ])

# import gradio as gr

# # CSS to customize the appearance
# css = """
# /* Global RTL reset */
# *, *::before, *::after {
#     direction: rtl !important;
#     text-align: right !important;
# }
# body {
#     font-family: 'Amiri', sans-serif; /* Amiri Arabic font */
#     background-color: black; /* Page background */
#     color: black !important; /* Black text */
#     direction: rtl !important; /* Arabic text aligned to the right */
# }
# .gradio-container {
#     direction: rtl !important; /* RTL alignment for the whole interface */
# }
# /* Form elements */
# input[type="text"],
# .gradio-textbox input,
# textarea {
#     border-radius: 20px;
#     padding: 10px 15px;
#     border: 2px solid #000;
#     font-size: 16px;
#     width: 80%;
#     margin: 0 auto;
#     text-align: right !important;
# }
# /* Placeholder style overrides */
# input::placeholder,
# textarea::placeholder {
#     text-align: right !important;
#     direction: rtl !important;
# }
# /* Buttons */
# .gradio-button {
#     border-radius: 20px;
#     font-size: 16px;
#     background-color: #007BFF;
#     color: white;
#     padding: 10px 20px;
#     margin: 10px auto;
#     border: none;
#     width: 80%;
#     display: block;
# }
# .gradio-button:hover {
#     background-color: #0056b3;
# }
# .gradio-chatbot .message {
#     border-radius: 20px;
#     padding: 10px;
#     margin: 10px 0;
#     background-color: #f1f1f1;
#     border: 1px solid #ddd;
#     width: 80%;
#     text-align: right !important;
#     direction: rtl !important;
# }
# /* User messages */
# .gradio-chatbot .user-message {
#     margin-right: auto;
#     background-color: #e3f2fd;
#     text-align: right !important;
#     direction: rtl !important;
# }
# /* Assistant messages */
# .gradio-chatbot .assistant-message {
#     margin-right: auto;
#     background-color: #f1f1f1;
#     text-align: right;
# }
# /* RTL fixes for specific elements */
# .gradio-textbox textarea {
#     text-align: right !important;
# }
# .gradio-dropdown div {
#     text-align: right !important;
# }
# """

# # Modified process_question function to better work with tuples
# def process_question(question: str) -> Iterator[str]:
#     if question in question_cache:
#         response, docs = question_cache[question]
#         sources = [doc.metadata.get("source") for doc in docs]
#         sources = list(set([os.path.splitext(source)[0] for source in sources]))
#         yield response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
#         return
#     relevant_docs = retriever(question)
#     # Reranking with cross-encoder
#     context = [doc.page_content for doc in relevant_docs]
#     text_pairs = [[question, text] for text in context]
#     scores = rag_loader.reranker.predict(text_pairs)
#     scored_docs = list(zip(scores, context, relevant_docs))
#     scored_docs.sort(key=lambda x: x[0], reverse=True)
#     reranked_docs = [d[2].page_content for d in scored_docs][:10]
#     prompt = prompt_template.format_messages(
#         context=reranked_docs,
#         question=question
#     )
#     full_response = ""
#     try:
#         for chunk in llm.stream(prompt):
#             if isinstance(chunk, str):
#                 current_chunk = chunk
#             else:
#                 current_chunk = chunk.content
#             full_response += current_chunk
#             sources = [d[2].metadata['source'] for d in scored_docs][:10]
#             sources = list(set([os.path.splitext(source)[0] for source in sources]))
#             yield full_response + "\n\n\nالمصادر المحتملة :\n" + "\n".join(sources)
#         question_cache[question] = (full_response, relevant_docs)
#     except Exception as e:
#         yield f"Error during processing: {str(e)}"

# # Updated gradio_stream function for the 'messages' format
# def gradio_stream(question: str, chat_history: list) -> Iterator[list]:
#     # chat_history now contains the user message added by user_input
#     # Add a placeholder for the assistant's response
#     chat_history.append({"role": "assistant", "content": ""})
#     try:
#         # Stream the response using the existing process_question generator
#         for partial_response in process_question(question):
#             # Update the content of the last message (the assistant's placeholder)
#             chat_history[-1]["content"] = partial_response
#             yield chat_history  # Yield the entire updated history list
#     except Exception as e:
#         # Update the assistant's message with the error
#         chat_history[-1]["content"] = f"Error: {str(e)}"
#         yield chat_history  # Yield the history with the error message

# # Gradio interface
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("<h2 style='text-align: center !important;'>هذا تطبيق للاجابة على الأسئلة المتعلقة بالقوانين المغربية</h2>")
#     with gr.Row():
#         message = gr.Textbox(label="أدخل سؤالك", placeholder="اكتب سؤالك هنا", elem_id="question_input")
#     with gr.Row():
#         send = gr.Button("بحث", elem_id="search_button")
#     with gr.Row():
#         chatbot = gr.Chatbot(label="", type="messages")  # Use the 'messages' format

#     # Updated user_input function for the 'messages' format
#     def user_input(user_message, chat_history):
#         # chat_history is already a list of message dicts; append the new user message
#         return "", chat_history + [{"role": "user", "content": user_message}]

#     send.click(user_input, [message, chatbot], [message, chatbot], queue=False)
#     send.click(gradio_stream, [message, chatbot], chatbot)

# demo.launch(share=True)

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from langchain_google_genai import ChatGoogleGenerativeAI
import uvicorn
import asyncio
import os  # Make sure 'os' is imported if you use it for API keys, etc.

# --- Your imports (Document, LLM, PromptTemplate, etc.) ---
# from langchain_google_genai import ChatGoogleGenerativeAI
# from langchain.prompts import ChatPromptTemplate
# ... other required imports ...
# from your_rag_module import OptimizedRAGLoader  # Make sure the class is importable

# --- Global variables (initialized to None) ---
rag_loader = None
llm = None
retriever = None
prompt_template = None
initialization_error = None  # Stores an initialization error, if any

# --- Robust initialization block ---
print("--- Starting Application Initialization ---")
try:
    # Initialize the LLM
    print("Initializing LLM...")
    gemini_api_key = os.getenv("GEMINI_KEY")
    if not gemini_api_key:
        raise ValueError("GEMINI_KEY environment variable not set.")
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        temperature=0.1,
        google_api_key=gemini_api_key,
        disable_streaming=True,
    )
    print("LLM Initialized.")

    # Initialize the RAG loader and the retriever
    print("Initializing RAG Loader...")
    # Make sure OptimizedRAGLoader is defined or imported correctly
    rag_loader = OptimizedRAGLoader()  # This line may fail (model/index loading)
    print("RAG Loader Initialized. Getting Retriever...")
    retriever = rag_loader.get_retriever(k=15)  # This line depends on rag_loader
    print("Retriever Initialized.")

    # Initialize the prompt template
    print("Initializing Prompt Template...")
    # NOTE: the full system prompt (elided here) must also contain the {context} placeholder,
    # since get_llm_response_stream() below fills both {context} and {question}.
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", """أنت مساعد قانوني خبير... (your full system prompt here) ...السؤال المطلوب الإجابة عليه: {question}"""),
        ("human", "{question}")
    ])
    print("Prompt Template Initialized.")
    print("--- Application Initialization Successful ---")
except Exception as e:
    print("!!!!!!!!!! FATAL INITIALIZATION ERROR !!!!!!!!!!")
    print(f"Error during startup: {e}")
    import traceback
    traceback.print_exc()  # Print the full traceback in the logs
    initialization_error = str(e)  # Store the error for the API
    # Global variables are left as None if initialization fails

# --- FastAPI App ---
app = FastAPI()

# --- Modified backend function ---
# (get_llm_response_stream - keep the previous version that handles SSE streaming.)
# Make sure it uses the global variables llm, retriever and prompt_template.
async def get_llm_response_stream(question: str):
    # *** Crucial check at the start of the function ***
    if initialization_error:
        yield f"data: Critical error during server initialization: {initialization_error}\n\n"
        return
    if not retriever or not llm or not prompt_template:
        yield "data: Error: one or more server components (LLM, Retriever, Prompt) are not initialized.\n\n"
        return
    # *** End of the check ***
    print(f"API processing question: {question}")
    try:
        # Use the global 'retriever'
        relevant_docs = retriever(question)
        # Build the context string and the list of source names
        context_str = "\n\n".join([f"المصدر: {doc.metadata.get('source', 'غير معروف')}\nالمحتوى: {doc.page_content}" for doc in relevant_docs]) if relevant_docs else "لا يوجد سياق"
        sources = sorted(list(set([os.path.splitext(doc.metadata.get("source", "غير معروف"))[0] for doc in relevant_docs]))) if relevant_docs else []
        sources_str = "\n\n\nالمصادر المحتملة التي تم الرجوع إليها:\n- " + "\n- ".join(sources) if sources else ""
        if not relevant_docs:
            # Handle the case where there are no documents
            yield "data: لم أتمكن من العثور على معلومات ذات صلة في المستندات.\n\n"
            # Optional: call the LLM without context, or stop here
            return
        # Use the global 'prompt_template'
        prompt = prompt_template.format_messages(context=context_str, question=question)
        full_response = ""
        # Use the global 'llm'
        stream = llm.stream(prompt)
        for chunk in stream:
            content = chunk.content if hasattr(chunk, 'content') else str(chunk)
            if content:
                # SSE requires every line of a multi-line payload to be prefixed with "data: "
                formatted_chunk = content.replace('\n', '\ndata: ')
                yield f"data: {formatted_chunk}\n\n"  # SSE format
                full_response += content
        # Send the sources at the end
        if sources_str:
            # Make sure sources_str is SSE-safe since it contains line breaks
            sources_sse = sources_str.replace('\n', '\ndata: ')
            yield f"data: {sources_sse}\n\n"
        # End-of-stream signal (optional but useful for the JS client)
        yield "event: end\ndata: Stream finished\n\n"
    except Exception as e:
        print(f"Error during API LLM generation: {e}")
        import traceback
        traceback.print_exc()  # Print the error in the server logs
        yield f"data: حدث خطأ أثناء معالجة طلبك: {str(e)}\n\n"
        yield "event: error\ndata: Stream error\n\n"  # Signal an error to the client

# --- API endpoint ---
@app.post("/ask")
async def handle_ask(request: Request):
    # Check right away whether global initialization failed
    if initialization_error:
        raise HTTPException(status_code=500, detail=f"Server initialization error: {initialization_error}")
    try:
        data = await request.json()
        question = data.get("question")
        if not question:
            raise HTTPException(status_code=400, detail="Missing 'question' in the JSON request")
        # Return the streamed response
        return StreamingResponse(get_llm_response_stream(question), media_type="text/event-stream")
    except HTTPException:
        # Let FastAPI handle HTTP errors (e.g. the 400 above) instead of wrapping them in a 500
        raise
    except Exception as e:
        print(f"Error in /ask endpoint: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
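
# A minimal sketch of how a client could consume the /ask SSE stream. The real front-end lives
# in static/script.js (not shown here); this helper is only an illustration, is never called by
# the application itself, and its name and base_url are assumptions. It reuses the `requests`
# import from the top of the file.
def _example_sse_client(question: str, base_url: str = "http://localhost:7860"):
    with requests.post(f"{base_url}/ask", json={"question": question}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                print(line[len("data: "):])  # each SSE data line carries a chunk of the answer
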
# --- Serve the static files (HTML/JS/CSS) ---
# Make sure the 'static' folder exists and contains index.html, script.js and style.css.
# This mount must be registered after the /ask route so it does not shadow it.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# --- Server startup ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)