opex792 commited on
Commit
94d93d6
·
verified ·
1 Parent(s): b7a12de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -41
app.py CHANGED
@@ -12,7 +12,6 @@ from urllib.parse import urlparse
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
15
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
 
17
  # Настройка логирования
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -32,20 +31,12 @@ db_params = {
32
  "sslmode": "require"
33
  }
34
 
35
- # Загружаем модель-энкодер
36
  model_name = "BAAI/bge-m3"
37
  logging.info(f"Загрузка модели {model_name}...")
38
  model = SentenceTransformer(model_name)
39
  logging.info("Модель загружена успешно.")
40
 
41
- # Загружаем модель-реранкер
42
- reranker_name = 'BAAI/bge-reranker-v2-m3'
43
- logging.info(f"Загрузка модели реранкера {reranker_name}...")
44
- reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
45
- reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
46
- reranker_model.eval()
47
- logging.info("Модель реранкера загружена успешно.")
48
-
49
  # Имена таблиц
50
  embeddings_table = "movie_embeddings"
51
  query_cache_table = "query_cache"
@@ -298,27 +289,6 @@ def get_movie_embeddings(conn):
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
- def rerank_results(query, results, conn):
302
- """Ранжирует результаты поиска с помощью реранкера."""
303
- if not results:
304
- return []
305
-
306
- pairs = []
307
- movie_ids = []
308
- for movie_id, _ in results:
309
- movie = next((m for m in movies_data if m['id'] == movie_id), None)
310
- if movie:
311
- movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
312
- pairs.append([query, movie_info])
313
- movie_ids.append(movie_id)
314
-
315
- with torch.no_grad():
316
- inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
317
- scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float().cpu().numpy()
318
-
319
- reranked_results = sorted(zip(movie_ids, scores), key=lambda x: x[1], reverse=True)
320
- return reranked_results
321
-
322
  def search_movies(query, top_k=20):
323
  """Выполняет поиск фильмов по запросу."""
324
  global search_in_progress
@@ -362,29 +332,23 @@ def search_movies(query, top_k=20):
362
  FROM {embeddings_table} m, query_embedding
363
  ORDER BY similarity DESC
364
  LIMIT %s
365
- """, (query_crc32, top_k * 2)) # Увеличиваем лимит для последующего реранкинга
366
 
367
  results = cur.fetchall()
368
- logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
369
  except Exception as e:
370
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
371
  results = []
372
-
373
- # Применяем реранкер
374
- reranked_results = rerank_results(query, results, conn)
375
-
376
- # Ограничиваем количество результатов после реранкинга
377
- reranked_results = reranked_results[:top_k]
378
 
379
  output = ""
380
- for movie_id, score in reranked_results:
381
  # Находим фильм по ID
382
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
383
  if movie:
384
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
385
  output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
386
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
387
- output += f"<p><strong>Релевантность:</strong> {score:.4f}</p>\n" # Используем score от реранкера
388
  output += "<hr>\n"
389
 
390
  search_time = time.time() - start_time
 
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
 
15
 
16
  # Настройка логирования
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
31
  "sslmode": "require"
32
  }
33
 
34
+ # Загружаем модель
35
  model_name = "BAAI/bge-m3"
36
  logging.info(f"Загрузка модели {model_name}...")
37
  model = SentenceTransformer(model_name)
38
  logging.info("Модель загружена успешно.")
39
 
 
 
 
 
 
 
 
 
40
  # Имена таблиц
41
  embeddings_table = "movie_embeddings"
42
  query_cache_table = "query_cache"
 
289
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
290
  return movie_embeddings
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def search_movies(query, top_k=20):
293
  """Выполняет поиск фильмов по запросу."""
294
  global search_in_progress
 
332
  FROM {embeddings_table} m, query_embedding
333
  ORDER BY similarity DESC
334
  LIMIT %s
335
+ """, (query_crc32, top_k))
336
 
337
  results = cur.fetchall()
338
+ logging.info(f"Найдено {len(results)} результатов поиска.")
339
  except Exception as e:
340
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
341
  results = []
 
 
 
 
 
 
342
 
343
  output = ""
344
+ for movie_id, similarity in results:
345
  # Находим фильм по ID
346
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
347
  if movie:
348
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
349
  output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
350
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
351
+ output += f"<p><strong>Релевантность:</strong> {similarity:.4f}</p>\n"
352
  output += "<hr>\n"
353
 
354
  search_time = time.time() - start_time