Trabis committed
Commit f061157 · verified · 1 parent: e53c98f

Update app.py

Files changed (1): app.py (+176 -176)
app.py CHANGED
@@ -1,210 +1,210 @@
- # import gradio as gr
- # from langchain_mistralai.chat_models import ChatMistralAI
- # from langchain.prompts import ChatPromptTemplate
- # from langchain_deepseek import ChatDeepSeek
- # from langchain_google_genai import ChatGoogleGenerativeAI
- # import os
- # from pathlib import Path
- # import json
- # import faiss
- # import numpy as np
- # from langchain.schema import Document
- # import pickle
- # import re
- # import requests
- # from functools import lru_cache
- # import torch
- # from sentence_transformers import SentenceTransformer
- # from sentence_transformers.cross_encoder import CrossEncoder
- # import threading
- # from queue import Queue
- # import concurrent.futures
- # from typing import Generator, Tuple, Iterator
- # import time
-
- # class OptimizedRAGLoader:
- #     def __init__(self,
- #                  docs_folder: str = "./docs",
- #                  splits_folder: str = "./splits",
- #                  index_folder: str = "./index"):

- #         self.docs_folder = Path(docs_folder)
- #         self.splits_folder = Path(splits_folder)
- #         self.index_folder = Path(index_folder)

- #         # Create folders if they don't exist
- #         for folder in [self.splits_folder, self.index_folder]:
- #             folder.mkdir(parents=True, exist_ok=True)

- #         # File paths
- #         self.splits_path = self.splits_folder / "splits.json"
- #         self.index_path = self.index_folder / "faiss.index"
- #         self.documents_path = self.index_folder / "documents.pkl"

- #         # Initialize components
- #         self.index = None
- #         self.indexed_documents = None

- #         # Initialize encoder model
- #         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #         self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
- #         self.encoder.to(self.device)
- #         self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)

- #         # Initialize thread pool
- #         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

- #         # Initialize response cache
- #         self.response_cache = {}

- #     @lru_cache(maxsize=1000)
- #     def encode(self, text: str):
- #         """Cached encoding function"""
- #         with torch.no_grad():
- #             embeddings = self.encoder.encode(
- #                 text,
- #                 convert_to_numpy=True,
- #                 normalize_embeddings=True
- #             )
- #         return embeddings

- #     def batch_encode(self, texts: list):
- #         """Batch encoding for multiple texts"""
- #         with torch.no_grad():
- #             embeddings = self.encoder.encode(
- #                 texts,
- #                 batch_size=32,
- #                 convert_to_numpy=True,
- #                 normalize_embeddings=True,
- #                 show_progress_bar=False
- #             )
- #         return embeddings
-
- #     def load_and_split_texts(self):
- #         if self._splits_exist():
- #             return self._load_existing_splits()

- #         documents = []
- #         futures = []

- #         for file_path in self.docs_folder.glob("*.txt"):
- #             future = self.executor.submit(self._process_file, file_path)
- #             futures.append(future)

- #         for future in concurrent.futures.as_completed(futures):
- #             documents.extend(future.result())

- #         self._save_splits(documents)
- #         return documents

- #     def _process_file(self, file_path):
- #         with open(file_path, 'r', encoding='utf-8') as file:
- #             text = file.read()
- #         chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

- #         return [
- #             Document(
- #                 page_content=chunk,
- #                 metadata={
- #                     'source': file_path.name,
- #                     'chunk_id': i,
- #                     'total_chunks': len(chunks)
- #                 }
- #             )
- #             for i, chunk in enumerate(chunks)
- #         ]
-
- #     def load_index(self) -> bool:
- #         """
- #         Load the FAISS index and the associated documents if they exist.
-
- #         Returns:
- #             bool: True if the index was loaded, False otherwise
- #         """
- #         if not self._index_exists():
- #             print("No index found.")
- #             return False
-
- #         print("Loading existing index...")
- #         try:
- #             # Load the FAISS index
- #             self.index = faiss.read_index(str(self.index_path))
-
- #             # Load the associated documents
- #             with open(self.documents_path, 'rb') as f:
- #                 self.indexed_documents = pickle.load(f)
-
- #             print(f"Index loaded with {self.index.ntotal} vectors")
- #             return True
-
- #         except Exception as e:
- #             print(f"Error loading index: {e}")
- #             return False
-
- #     def create_index(self, documents=None):
- #         if documents is None:
- #             documents = self.load_and_split_texts()

- #         if not documents:
- #             return False

- #         texts = [doc.page_content for doc in documents]
- #         embeddings = self.batch_encode(texts)

- #         dimension = embeddings.shape[1]
- #         self.index = faiss.IndexFlatL2(dimension)

- #         if torch.cuda.is_available():
- #             # Use GPU for FAISS if available
- #             res = faiss.StandardGpuResources()
- #             self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

- #         self.index.add(np.array(embeddings).astype('float32'))
- #         self.indexed_documents = documents

- #         # Save index and documents
- #         cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
- #         faiss.write_index(cpu_index, str(self.index_path))

- #         with open(self.documents_path, 'wb') as f:
- #             pickle.dump(documents, f)

- #         return True

- #     def _index_exists(self) -> bool:
- #         """Check whether the index and the associated documents exist"""
- #         return self.index_path.exists() and self.documents_path.exists()

- #     def get_retriever(self, k: int = 10):
- #         if self.index is None:
- #             if not self.load_index():
- #                 if not self.create_index():
- #                     raise ValueError("Unable to load or create index")

- #         def retriever_function(query: str) -> list:
- #             # Check cache first
- #             cache_key = f"{query}_{k}"
- #             if cache_key in self.response_cache:
- #                 return self.response_cache[cache_key]

- #             query_embedding = self.encode(query)

- #             distances, indices = self.index.search(
- #                 np.array([query_embedding]).astype('float32'),
- #                 k
- #             )

- #             results = [
- #                 self.indexed_documents[idx]
- #                 for idx in indices[0]
- #                 if idx != -1
- #             ]

- #             # Cache the results
- #             self.response_cache[cache_key] = results
- #             return results

- #         return retriever_function
+ import gradio as gr
+ from langchain_mistralai.chat_models import ChatMistralAI
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_deepseek import ChatDeepSeek
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ import os
+ from pathlib import Path
+ import json
+ import faiss
+ import numpy as np
+ from langchain.schema import Document
+ import pickle
+ import re
+ import requests
+ from functools import lru_cache
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.cross_encoder import CrossEncoder
+ import threading
+ from queue import Queue
+ import concurrent.futures
+ from typing import Generator, Tuple, Iterator
+ import time
+
+ class OptimizedRAGLoader:
+     def __init__(self,
+                  docs_folder: str = "./docs",
+                  splits_folder: str = "./splits",
+                  index_folder: str = "./index"):

+         self.docs_folder = Path(docs_folder)
+         self.splits_folder = Path(splits_folder)
+         self.index_folder = Path(index_folder)

+         # Create folders if they don't exist
+         for folder in [self.splits_folder, self.index_folder]:
+             folder.mkdir(parents=True, exist_ok=True)

+         # File paths
+         self.splits_path = self.splits_folder / "splits.json"
+         self.index_path = self.index_folder / "faiss.index"
+         self.documents_path = self.index_folder / "documents.pkl"

+         # Initialize components
+         self.index = None
+         self.indexed_documents = None

+         # Initialize encoder model
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
+         self.encoder.to(self.device)
+         self.reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", trust_remote_code=True)

+         # Initialize thread pool
+         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

+         # Initialize response cache
+         self.response_cache = {}

+     @lru_cache(maxsize=1000)
+     def encode(self, text: str):
+         """Cached encoding function"""
+         with torch.no_grad():
+             embeddings = self.encoder.encode(
+                 text,
+                 convert_to_numpy=True,
+                 normalize_embeddings=True
+             )
+         return embeddings

+     def batch_encode(self, texts: list):
+         """Batch encoding for multiple texts"""
+         with torch.no_grad():
+             embeddings = self.encoder.encode(
+                 texts,
+                 batch_size=32,
+                 convert_to_numpy=True,
+                 normalize_embeddings=True,
+                 show_progress_bar=False
+             )
+         return embeddings
+
+     def load_and_split_texts(self):
+         if self._splits_exist():
+             return self._load_existing_splits()

+         documents = []
+         futures = []

+         for file_path in self.docs_folder.glob("*.txt"):
+             future = self.executor.submit(self._process_file, file_path)
+             futures.append(future)

+         for future in concurrent.futures.as_completed(futures):
+             documents.extend(future.result())

+         self._save_splits(documents)
+         return documents

+     def _process_file(self, file_path):
+         with open(file_path, 'r', encoding='utf-8') as file:
+             text = file.read()
+         chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

+         return [
+             Document(
+                 page_content=chunk,
+                 metadata={
+                     'source': file_path.name,
+                     'chunk_id': i,
+                     'total_chunks': len(chunks)
+                 }
+             )
+             for i, chunk in enumerate(chunks)
+         ]
+
+     def load_index(self) -> bool:
+         """
+         Load the FAISS index and the associated documents if they exist.
+
+         Returns:
+             bool: True if the index was loaded, False otherwise
+         """
+         if not self._index_exists():
+             print("No index found.")
+             return False
+
+         print("Loading existing index...")
+         try:
+             # Load the FAISS index
+             self.index = faiss.read_index(str(self.index_path))
+
+             # Load the associated documents
+             with open(self.documents_path, 'rb') as f:
+                 self.indexed_documents = pickle.load(f)
+
+             print(f"Index loaded with {self.index.ntotal} vectors")
+             return True
+
+         except Exception as e:
+             print(f"Error loading index: {e}")
+             return False
+
+     def create_index(self, documents=None):
+         if documents is None:
+             documents = self.load_and_split_texts()

+         if not documents:
+             return False

+         texts = [doc.page_content for doc in documents]
+         embeddings = self.batch_encode(texts)

+         dimension = embeddings.shape[1]
+         self.index = faiss.IndexFlatL2(dimension)

+         if torch.cuda.is_available():
+             # Use GPU for FAISS if available
+             res = faiss.StandardGpuResources()
+             self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

+         self.index.add(np.array(embeddings).astype('float32'))
+         self.indexed_documents = documents

+         # Save index and documents
+         cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
+         faiss.write_index(cpu_index, str(self.index_path))

+         with open(self.documents_path, 'wb') as f:
+             pickle.dump(documents, f)

+         return True

+     def _index_exists(self) -> bool:
+         """Check whether the index and the associated documents exist"""
+         return self.index_path.exists() and self.documents_path.exists()

+     def get_retriever(self, k: int = 10):
+         if self.index is None:
+             if not self.load_index():
+                 if not self.create_index():
+                     raise ValueError("Unable to load or create index")

+         def retriever_function(query: str) -> list:
+             # Check cache first
+             cache_key = f"{query}_{k}"
+             if cache_key in self.response_cache:
+                 return self.response_cache[cache_key]

+             query_embedding = self.encode(query)

+             distances, indices = self.index.search(
+                 np.array([query_embedding]).astype('float32'),
+                 k
+             )

+             results = [
+                 self.indexed_documents[idx]
+                 for idx in indices[0]
+                 if idx != -1
+             ]

+             # Cache the results
+             self.response_cache[cache_key] = results
+             return results

+         return retriever_function

  # # # Initialize components
  # # mistral_api_key = os.getenv("mistral_api_key")
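
Note: load_and_split_texts calls _splits_exist, _load_existing_splits, and _save_splits, none of which appear in this hunk. If they are not defined further down in app.py, something like the sketch below would be needed; the JSON layout is an assumption, inferred only from self.splits_path pointing at splits.json in __init__.

# Hypothetical methods for OptimizedRAGLoader (assumed, not shown in this diff):
# persist the split documents as JSON under self.splits_path.
import json
from langchain.schema import Document

def _splits_exist(self) -> bool:
    # self.splits_path is defined in __init__ as splits_folder / "splits.json"
    return self.splits_path.exists()

def _save_splits(self, documents: list) -> None:
    payload = [
        {"page_content": doc.page_content, "metadata": doc.metadata}
        for doc in documents
    ]
    with open(self.splits_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

def _load_existing_splits(self) -> list:
    with open(self.splits_path, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    return [
        Document(page_content=item["page_content"], metadata=item["metadata"])
        for item in payload
    ]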
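
A smaller point on encode: @lru_cache(maxsize=1000) on a bound method works here because the only argument is a string, but functools.lru_cache keeps self (and therefore the loaded model) referenced by a module-level cache for the life of the process. A per-instance dictionary avoids that; the sketch below assumes a hypothetical self._embedding_cache = {} added in __init__, which is not in the commit.

# Sketch of a per-instance alternative to @lru_cache on encode.
import numpy as np
import torch

def encode(self, text: str) -> np.ndarray:
    # Assumption: __init__ also sets self._embedding_cache = {}
    cached = self._embedding_cache.get(text)
    if cached is not None:
        return cached
    with torch.no_grad():
        embedding = self.encoder.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
    self._embedding_cache[text] = embedding
    return embedding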
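
One retrieval-quality detail: the model card for intfloat/multilingual-e5-large recommends prefixing inputs with "query: " and "passage: ", while the code above encodes raw text. A minimal wrapper, assuming no other changes to the class, might look like this; both helper names are hypothetical.

# Hypothetical helpers (not in the commit): apply the E5 task prefixes
# before delegating to the encode methods defined above.
def encode_query(loader: "OptimizedRAGLoader", query: str):
    return loader.encode(f"query: {query}")

def encode_passages(loader: "OptimizedRAGLoader", texts: list):
    return loader.batch_encode([f"passage: {t}" for t in texts])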
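
Finally, __init__ loads a CrossEncoder reranker, but nothing in this hunk calls it. A minimal usage sketch under that assumption (the query string and the top-5 cutoff are illustrative): retrieve a candidate set with the FAISS retriever, then let the cross-encoder reorder it.

# Sketch only: wire the retriever and the otherwise-unused reranker together.
loader = OptimizedRAGLoader()
retriever = loader.get_retriever(k=10)

query = "example question"  # hypothetical query
candidates = retriever(query)

# CrossEncoder.predict scores (query, passage) pairs; higher means more relevant.
pairs = [(query, doc.page_content) for doc in candidates]
scores = loader.reranker.predict(pairs)

# Keep the five best-scoring documents.
ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
top_docs = [doc for _, doc in ranked[:5]]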