import gradio as gr
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.prompts import ChatPromptTemplate
import os
from pathlib import Path
import json
import faiss
import numpy as np
from langchain.schema import Document
import pickle
import re
import torch
from sentence_transformers import SentenceTransformer
import concurrent.futures
class OptimizedRAGLoader:
    def __init__(self,
                 docs_folder: str = "./docs",
                 splits_folder: str = "./splits",
                 index_folder: str = "./index"):
        """
        Initialize the RAG loader.
        Args:
            docs_folder: folder containing the source documents
            splits_folder: folder where the text chunks are stored
            index_folder: folder where the FAISS index is stored
        """
self.docs_folder = Path(docs_folder) | |
self.splits_folder = Path(splits_folder) | |
self.index_folder = Path(index_folder) | |
# Create folders if they don't exist | |
for folder in [self.splits_folder, self.index_folder]: | |
folder.mkdir(parents=True, exist_ok=True) | |
# File paths | |
self.splits_path = self.splits_folder / "splits.json" | |
self.index_path = self.index_folder / "faiss.index" | |
self.documents_path = self.index_folder / "documents.pkl" | |
# Initialize components | |
self.index = None | |
self.indexed_documents = None | |
# Initialize encoder model | |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
self.encoder = SentenceTransformer("intfloat/multilingual-e5-large") | |
self.encoder.to(self.device) | |
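        # Note: multilingual-e5-large produces 1024-dimensional embeddings (per
        # its model card); this single encoder instance backs both encode() and
        # batch_encode() below.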
# Initialize thread pool | |
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) | |
# Initialize response cache | |
self.response_cache = {} | |
    def encode(self, text: str):
        """Encode a single text (a query) into a normalized embedding vector.
        No caching happens here; query results are cached at the retriever level."""
with torch.no_grad(): | |
embeddings = self.encoder.encode( | |
text, | |
convert_to_numpy=True, | |
normalize_embeddings=True | |
) | |
return embeddings | |
def batch_encode(self, texts: list): | |
"""Batch encoding for multiple texts""" | |
with torch.no_grad(): | |
embeddings = self.encoder.encode( | |
texts, | |
batch_size=32, | |
convert_to_numpy=True, | |
normalize_embeddings=True, | |
show_progress_bar=False | |
) | |
return embeddings | |
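
    # Note: the multilingual-e5 model card recommends prefixing inputs with
    # "query: " / "passage: " for best retrieval quality. A hedged sketch, not
    # applied above since the index and the queries must be encoded the same way:
    #
    #   query_emb = loader.encode("query: " + question)
    #   passage_embs = loader.batch_encode(["passage: " + t for t in texts])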
    def load_and_split_texts(self):
        """Load .txt files from docs_folder, split them into sentence chunks in
        parallel, and persist the chunks to a single JSON file."""
        if self._splits_exist():
            return self._load_existing_splits()
documents = [] | |
futures = [] | |
for file_path in self.docs_folder.glob("*.txt"): | |
future = self.executor.submit(self._process_file, file_path) | |
futures.append(future) | |
for future in concurrent.futures.as_completed(futures): | |
documents.extend(future.result()) | |
self._save_splits(documents) | |
return documents | |
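
    # Note: file loading is I/O-bound, so the shared ThreadPoolExecutor gives a
    # modest speedup despite the GIL; as_completed() yields results in completion
    # order, so chunk ordering across files is not deterministic (within-file
    # order is preserved via the chunk_id metadata).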
def _process_file(self, file_path): | |
with open(file_path, 'r', encoding='utf-8') as file: | |
text = file.read() | |
        # Split on sentence-ending punctuation (., !, ?) followed by whitespace
        chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
return [ | |
Document( | |
page_content=chunk, | |
metadata={ | |
'source': file_path.name, | |
'chunk_id': i, | |
'total_chunks': len(chunks) | |
} | |
) | |
for i, chunk in enumerate(chunks) | |
] | |

    def _splits_exist(self) -> bool:
        """Check whether the splits file already exists."""
        return self.splits_path.exists()

    def _save_splits(self, documents: list):
        """Save all split documents to a single JSON file."""
        splits_data = {
            'splits': [
                {'text': doc.page_content, 'metadata': doc.metadata}
                for doc in documents
            ]
        }
        with open(self.splits_path, 'w', encoding='utf-8') as f:
            json.dump(splits_data, f, ensure_ascii=False, indent=2)

    def _load_existing_splits(self) -> list:
        """Load the splits back from the single JSON file."""
        with open(self.splits_path, 'r', encoding='utf-8') as f:
            splits_data = json.load(f)
        return [
            Document(page_content=split['text'], metadata=split['metadata'])
            for split in splits_data['splits']
        ]
    def load_index(self) -> bool:
        """
        Load the FAISS index and the associated documents if they exist.
        Returns:
            bool: True if the index was loaded, False otherwise
        """
        if not self._index_exists():
            print("No index found.")
            return False
        print("Loading existing index...")
        try:
            # Load the FAISS index
            self.index = faiss.read_index(str(self.index_path))
            # Load the associated documents
            with open(self.documents_path, 'rb') as f:
                self.indexed_documents = pickle.load(f)
            print(f"Index loaded with {self.index.ntotal} vectors")
            return True
        except Exception as e:
            print(f"Error while loading the index: {e}")
            return False
    def create_index(self, documents=None):
        if documents is None:
            documents = self.load_and_split_texts()
        if not documents:
            return False
        texts = [doc.page_content for doc in documents]
        embeddings = self.batch_encode(texts)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        # Use GPU FAISS only if the GPU build is actually installed; a CUDA-capable
        # torch does not guarantee faiss-gpu is present
        use_faiss_gpu = torch.cuda.is_available() and hasattr(faiss, "StandardGpuResources")
        if use_faiss_gpu:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
        self.index.add(np.array(embeddings).astype('float32'))
        self.indexed_documents = documents
        # Persist the index (FAISS can only serialize a CPU index) and documents
        cpu_index = faiss.index_gpu_to_cpu(self.index) if use_faiss_gpu else self.index
        faiss.write_index(cpu_index, str(self.index_path))
        with open(self.documents_path, 'wb') as f:
            pickle.dump(documents, f)
        return True
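
    # With normalize_embeddings=True every vector is unit length, so L2 distance
    # and cosine similarity produce the same ranking (||a - b||^2 = 2 - 2*cos(a, b)
    # on the unit sphere); faiss.IndexFlatIP would be an equivalent choice here.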
    def _index_exists(self) -> bool:
        """Check whether the index and the associated documents exist."""
        return self.index_path.exists() and self.documents_path.exists()
    def get_retriever(self, k: int = 10):
        """Build a LangChain-compatible retrieval function that returns the k
        most similar documents for a query."""
if self.index is None: | |
if not self.load_index(): | |
if not self.create_index(): | |
raise ValueError("Unable to load or create index") | |
def retriever_function(query: str) -> list: | |
# Check cache first | |
cache_key = f"{query}_{k}" | |
if cache_key in self.response_cache: | |
return self.response_cache[cache_key] | |
query_embedding = self.encode(query) | |
distances, indices = self.index.search( | |
np.array([query_embedding]).astype('float32'), | |
k | |
) | |
results = [ | |
self.indexed_documents[idx] | |
for idx in indices[0] | |
if idx != -1 | |
] | |
# Cache the results | |
self.response_cache[cache_key] = results | |
return results | |
return retriever_function | |
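
# Minimal usage sketch of the retriever on its own (hypothetical query string;
# assumes ./docs already contains .txt files):
#
#   loader = OptimizedRAGLoader()
#   search = loader.get_retriever(k=5)
#   for doc in search("some question"):
#       print(doc.metadata["source"], "->", doc.page_content[:80])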
# Initialize components | |
llm = ChatMistralAI(
    model="mistral-large-latest",
    mistral_api_key=os.environ.get("MISTRAL_API_KEY"),  # never hardcode API keys
    temperature=0.1  # low temperature for focused, deterministic answers
)
rag_loader = OptimizedRAGLoader() | |
retriever = rag_loader.get_retriever(k=5)  # small k keeps retrieval fast and prompts short
# Cache for processed questions | |
question_cache = {} | |
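# Note: question_cache and the retriever's response_cache both grow without
# bound. For a long-running app, a bounded LRU cache is a reasonable sketch
# (an assumption, not part of the original design):
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=256)
#   def process_question_cached(question: str) -> tuple[str, str]:
#       return process_question(question)

# System prompt below (in Arabic): "You are a helpful assistant that answers
# questions in Arabic using the provided information. Use the following
# information to answer the question: {context}. If the information is not
# sufficient to answer the question fully, make that clear. Answer concisely
# and accurately."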
prompt_template = ChatPromptTemplate.from_messages([ | |
("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة. | |
استخدم المعلومات التالية للإجابة على السؤال: | |
{context} | |
إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك. | |
أجب بشكل موجز ودقيق."""), | |
("human", "{question}") | |
]) | |
def process_question(question: str) -> tuple[str, str]:
    """Answer a question and return both the answer and the retrieved context."""
    # Check cache first
    if question in question_cache:
        return question_cache[question]
relevant_docs = retriever(question) | |
context = "\n".join([doc.page_content for doc in relevant_docs]) | |
prompt = prompt_template.format_messages( | |
context=context, | |
question=question | |
) | |
    response = llm.invoke(prompt)  # __call__ on chat models is deprecated in recent LangChain
result = (response.content, context) | |
# Cache the result | |
question_cache[question] = result | |
return result | |
# Custom CSS for right-aligned text in textboxes | |
custom_css = """ | |
.rtl-text { | |
text-align: right !important; | |
direction: rtl !important; | |
} | |
.rtl-text textarea { | |
text-align: right !important; | |
direction: rtl !important; | |
} | |
""" | |
# Gradio interface with RTL text support
with gr.Blocks(css=custom_css) as iface: | |
with gr.Column(): | |
        input_text = gr.Textbox(
            label="السؤال",  # "Question"
            placeholder="اكتب سؤالك هنا...",  # "Type your question here..."
            lines=2,
            elem_classes="rtl-text"
        )
with gr.Row(): | |
answer_box = gr.Textbox( | |
label="الإجابة", | |
lines=4, | |
elem_classes="rtl-text" | |
) | |
context_box = gr.Textbox( | |
label="السياق المستخدم", | |
lines=8, | |
elem_classes="rtl-text" | |
) | |
        submit_btn = gr.Button("إرسال")  # "Submit"
submit_btn.click( | |
fn=process_question, | |
inputs=input_text, | |
outputs=[answer_box, context_box], | |
api_name="predict" | |
) | |
if __name__ == "__main__": | |
iface.launch( | |
share=True, | |
server_name="0.0.0.0", | |
server_port=7860, | |
max_threads=3, # Controls concurrency | |
show_error=True | |
) |
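
# Note: share=True only matters for local runs; hosted platforms such as
# Hugging Face Spaces serve the app directly and typically ignore the flag.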