opex792 commited on
Commit
bcb5bfd
·
verified ·
1 Parent(s): bea8834

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -7
app.py CHANGED
@@ -12,6 +12,7 @@ from urllib.parse import urlparse
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
 
15
 
16
  # Настройка логирования
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -31,12 +32,20 @@ db_params = {
31
  "sslmode": "require"
32
  }
33
 
34
- # Загружаем модель
35
  model_name = "BAAI/bge-m3"
36
  logging.info(f"Загрузка модели {model_name}...")
37
  model = SentenceTransformer(model_name)
38
  logging.info("Модель загружена успешно.")
39
 
 
 
 
 
 
 
 
 
40
  # Имена таблиц
41
  embeddings_table = "movie_embeddings"
42
  query_cache_table = "query_cache"
@@ -207,7 +216,7 @@ def process_batch(batch):
207
 
208
  try:
209
  for movie in batch:
210
- embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
211
  string_crc32 = calculate_crc32(embedding_string)
212
 
213
  # Проверяем существующий эмбеддинг
@@ -289,6 +298,27 @@ def get_movie_embeddings(conn):
289
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
290
  return movie_embeddings
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def search_movies(query, top_k=20):
293
  """Выполняет поиск фильмов по запросу."""
294
  global search_in_progress
@@ -332,23 +362,29 @@ def search_movies(query, top_k=20):
332
  FROM {embeddings_table} m, query_embedding
333
  ORDER BY similarity DESC
334
  LIMIT %s
335
- """, (query_crc32, top_k))
336
 
337
  results = cur.fetchall()
338
- logging.info(f"Найдено {len(results)} результатов поиска.")
339
  except Exception as e:
340
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
341
  results = []
 
 
 
 
 
 
342
 
343
  output = ""
344
- for movie_id, similarity in results:
345
  # Находим фильм по ID
346
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
347
  if movie:
348
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
349
- output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
350
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
351
- output += f"<p><strong>Релевантность:</strong> {similarity:.4f}</p>\n"
352
  output += "<hr>\n"
353
 
354
  search_time = time.time() - start_time
 
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
15
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
 
17
  # Настройка логирования
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
32
  "sslmode": "require"
33
  }
34
 
35
+ # Загружаем модель-энкодер
36
  model_name = "BAAI/bge-m3"
37
  logging.info(f"Загрузка модели {model_name}...")
38
  model = SentenceTransformer(model_name)
39
  logging.info("Модель загружена успешно.")
40
 
41
+ # Загружаем модель-реранкер
42
+ reranker_name = 'BAAI/bge-reranker-v2-m3'
43
+ logging.info(f"Загрузка модели реранкера {reranker_name}...")
44
+ reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
45
+ reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
46
+ reranker_model.eval()
47
+ logging.info("Модель реранкера загружена успешно.")
48
+
49
  # Имена таблиц
50
  embeddings_table = "movie_embeddings"
51
  query_cache_table = "query_cache"
 
216
 
217
  try:
218
  for movie in batch:
219
+ embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
220
  string_crc32 = calculate_crc32(embedding_string)
221
 
222
  # Проверяем существующий эмбеддинг
 
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
+ def rerank_results(query, results, conn):
302
+ """Ранжирует результаты поиска с помощью реранкера."""
303
+ if not results:
304
+ return []
305
+
306
+ pairs = []
307
+ movie_ids = []
308
+ for movie_id, _ in results:
309
+ movie = next((m for m in movies_data if m['id'] == movie_id), None)
310
+ if movie:
311
+ movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
312
+ pairs.append([query, movie_info])
313
+ movie_ids.append(movie_id)
314
+
315
+ with torch.no_grad():
316
+ inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
317
+ scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float().cpu().numpy()
318
+
319
+ reranked_results = sorted(zip(movie_ids, scores), key=lambda x: x[1], reverse=True)
320
+ return reranked_results
321
+
322
  def search_movies(query, top_k=20):
323
  """Выполняет поиск фильмов по запросу."""
324
  global search_in_progress
 
362
  FROM {embeddings_table} m, query_embedding
363
  ORDER BY similarity DESC
364
  LIMIT %s
365
+ """, (query_crc32, top_k * 2)) # Увеличиваем лимит для последующего реранкинга
366
 
367
  results = cur.fetchall()
368
+ logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
369
  except Exception as e:
370
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
371
  results = []
372
+
373
+ # Применяем реранкер
374
+ reranked_results = rerank_results(query, results, conn)
375
+
376
+ # Ограничиваем количество результатов после реранкинга
377
+ reranked_results = reranked_results[:top_k]
378
 
379
  output = ""
380
+ for movie_id, score in reranked_results:
381
  # Находим фильм по ID
382
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
383
  if movie:
384
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
385
+ output += f"<p><strong>Жанры:</strong> {movie['genresList']}</p>\n"
386
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
387
+ output += f"<p><strong>Релевантность:</strong> {score:.4f}</p>\n" # Используем score от реранкера
388
  output += "<hr>\n"
389
 
390
  search_time = time.time() - start_time