Update app.py
app.py
CHANGED
@@ -12,7 +12,6 @@ from urllib.parse import urlparse
 import logging
 from sklearn.preprocessing import normalize
 from concurrent.futures import ThreadPoolExecutor
-from transformers import AutoModelForSequenceClassification, AutoTokenizer

 # Logging setup
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -32,20 +31,12 @@ db_params = {
     "sslmode": "require"
 }

-# Load
+# Load the model
 model_name = "BAAI/bge-m3"
 logging.info(f"Loading model {model_name}...")
 model = SentenceTransformer(model_name)
 logging.info("Model loaded successfully.")

-# Load the reranker model
-reranker_name = 'BAAI/bge-reranker-v2-m3'
-logging.info(f"Loading reranker model {reranker_name}...")
-reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
-reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
-reranker_model.eval()
-logging.info("Reranker model loaded successfully.")
-
 # Table names
 embeddings_table = "movie_embeddings"
 query_cache_table = "query_cache"
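Note: with the reranker loading removed, ranking now rests entirely on the bi-encoder loaded above. As a rough illustration of that remaining path, the sketch below shows how a query is typically embedded with SentenceTransformer("BAAI/bge-m3") and L2-normalized (the sklearn normalize import is still present). The exact encode call used in app.py is outside this diff, so treat the details here as assumptions.

    from sentence_transformers import SentenceTransformer
    from sklearn.preprocessing import normalize

    # Assumed usage: the app embeds the query with the same bge-m3 model it loads above.
    model = SentenceTransformer("BAAI/bge-m3")

    def embed_query(text):
        vec = model.encode([text])   # numpy array of shape (1, dim)
        return normalize(vec)[0]     # L2-normalize so a dot product equals cosine similarity

    if __name__ == "__main__":
        print(embed_query("space adventure movie").shape)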
@@ -298,27 +289,6 @@ def get_movie_embeddings(conn):
         logging.error(f"Error loading movie embeddings: {e}")
     return movie_embeddings

-def rerank_results(query, results, conn):
-    """Re-rank search results with the reranker model."""
-    if not results:
-        return []
-
-    pairs = []
-    movie_ids = []
-    for movie_id, _ in results:
-        movie = next((m for m in movies_data if m['id'] == movie_id), None)
-        if movie:
-            movie_info = f"Title: {movie['name']}\nYear: {movie['year']}\nGenres: {movie['genreslist']}\nDescription: {movie['description']}"
-            pairs.append([query, movie_info])
-            movie_ids.append(movie_id)
-
-    with torch.no_grad():
-        inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
-        scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float().cpu().numpy()
-
-    reranked_results = sorted(zip(movie_ids, scores), key=lambda x: x[1], reverse=True)
-    return reranked_results
-
 def search_movies(query, top_k=20):
     """Run a movie search for the given query."""
     global search_in_progress
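For reference, the deleted rerank_results followed the standard cross-encoder pattern for BAAI/bge-reranker-v2-m3: tokenize each (query, document) pair together, read the single logit per pair as a relevance score, and sort descending. Below is a self-contained sketch of that pattern, in case the second-stage ranking ever needs to be restored outside app.py; the function and variable names here are illustrative, not part of the original code.

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    reranker_name = "BAAI/bge-reranker-v2-m3"
    tokenizer = AutoTokenizer.from_pretrained(reranker_name)
    reranker = AutoModelForSequenceClassification.from_pretrained(reranker_name)
    reranker.eval()

    def rerank(query, docs):
        # Score each (query, document) pair with the cross-encoder.
        pairs = [[query, d] for d in docs]
        with torch.no_grad():
            inputs = tokenizer(pairs, padding=True, truncation=True,
                               return_tensors="pt", max_length=512)
            scores = reranker(**inputs, return_dict=True).logits.view(-1).float().cpu().numpy()
        # A higher logit means more relevant; sort documents by score, descending.
        return sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)

    if __name__ == "__main__":
        print(rerank("space adventure", ["A quiet courtroom drama.", "A sci-fi epic set in deep space."]))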
@@ -362,29 +332,23 @@ def search_movies(query, top_k=20):
             FROM {embeddings_table} m, query_embedding
             ORDER BY similarity DESC
             LIMIT %s
-        """, (query_crc32, top_k
+        """, (query_crc32, top_k))

         results = cur.fetchall()
-        logging.info(f"Found {len(results)}
+        logging.info(f"Found {len(results)} search results.")
     except Exception as e:
         logging.error(f"Error executing the search query: {e}")
         results = []
-
-    # Apply the reranker
-    reranked_results = rerank_results(query, results, conn)
-
-    # Limit the number of results after reranking
-    reranked_results = reranked_results[:top_k]

     output = ""
-    for movie_id,
+    for movie_id, similarity in results:
         # Find the movie by its ID
         movie = next((m for m in movies_data if m['id'] == movie_id), None)
         if movie:
             output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
             output += f"<p><strong>Genres:</strong> {movie['genreslist']}</p>\n"
             output += f"<p><strong>Description:</strong> {movie['description']}</p>\n"
-            output += f"<p><strong>Relevance:</strong> {
+            output += f"<p><strong>Relevance:</strong> {similarity:.4f}</p>\n"
             output += "<hr>\n"

     search_time = time.time() - start_time
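After this change, results go straight from the similarity query into the HTML output, so the final ordering is whatever the SQL similarity column yields; there is no second-stage scoring. Assuming that column is the cosine similarity between L2-normalized bge-m3 embeddings (the full SELECT is outside this hunk, so this is an assumption), the client-side equivalent of the ranking criterion is simply:

    import numpy as np

    def cosine_similarity(a, b):
        # For L2-normalized vectors this reduces to a plain dot product,
        # which is what ORDER BY similarity DESC ranks on.
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))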