opex792 commited on
Commit
512ff06
·
verified ·
1 Parent(s): 905f70e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -42
app.py CHANGED
@@ -81,9 +81,6 @@ batch_size = 32
81
  # Количество потоков для параллельной обработки
82
  num_threads = 5
83
 
84
- # Количество потоков для параллельного реранкинга
85
- rerank_threads = 5 # Подберите оптимальное значение
86
-
87
  def get_db_connection():
88
  """Устанавливает соединение с базой данных."""
89
  try:
@@ -301,57 +298,24 @@ def get_movie_embeddings(conn):
301
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
302
  return movie_embeddings
303
 
304
- def rerank_batch(query, batch):
305
- """Переранжирует пакет результатов с помощью реранкера."""
 
306
  pairs = []
307
  movie_ids = []
308
- for movie_id, _ in batch:
309
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
310
  if movie:
311
  movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
312
  pairs.append([query, movie_info])
313
  movie_ids.append(movie_id)
 
314
 
315
  with torch.no_grad():
316
  inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
317
  scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
318
 
319
- return list(zip(movie_ids, scores.tolist()))
320
-
321
- def rerank_results(query, results):
322
- """Переранжирует результаты поиска с помощью реранкера."""
323
- logging.info(f"Начало переранжирования для запроса: '{query}'")
324
- reranked_results = []
325
-
326
- with ThreadPoolExecutor(max_workers=rerank_threads) as executor:
327
- futures = []
328
- batch = []
329
- batch_num = 0
330
- for i, result in enumerate(results):
331
- batch.append(result)
332
- if len(batch) >= batch_size: # Отправляем на реранк батчами
333
- logging.info(f"Отправка на переранжирование батча {batch_num+1} ({len(batch)} фильмов)")
334
- future = executor.submit(rerank_batch, query, batch)
335
- futures.append(future)
336
- batch = []
337
- batch_num += 1
338
-
339
- # Обработка остатка
340
- if batch:
341
- logging.info(f"Отправка на переранжирование батча {batch_num+1} ({len(batch)} фильмов)")
342
- future = executor.submit(rerank_batch, query, batch)
343
- futures.append(future)
344
-
345
- # Сбор результатов
346
- for i, future in enumerate(futures):
347
- try:
348
- batch_result = future.result()
349
- reranked_results.extend(batch_result)
350
- logging.info(f"Завершен реранк батча {i+1}")
351
- except Exception as e:
352
- logging.error(f"Ошибка при переранжировании батча {i+1}: {e}")
353
-
354
- reranked_results = sorted(reranked_results, key=lambda x: x[1], reverse=True)
355
  logging.info("Переранжирование завершено.")
356
  return reranked_results
357
 
 
81
  # Количество потоков для параллельной обработки
82
  num_threads = 5
83
 
 
 
 
84
  def get_db_connection():
85
  """Устанавливает соединение с базой данных."""
86
  try:
 
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
+ def rerank_results(query, results):
302
+ """Переранжирует результаты поиска с помощью реранкера."""
303
+ logging.info(f"Начало переранжирования для запроса: '{query}'")
304
  pairs = []
305
  movie_ids = []
306
+ for i, (movie_id, _) in enumerate(results):
307
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
308
  if movie:
309
  movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
310
  pairs.append([query, movie_info])
311
  movie_ids.append(movie_id)
312
+ logging.info(f"Обработка фильма для реранка {i+1}/{len(results)}: {movie['name']}")
313
 
314
  with torch.no_grad():
315
  inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
316
  scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
317
 
318
+ reranked_results = sorted(zip(movie_ids, scores.tolist()), key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  logging.info("Переранжирование завершено.")
320
  return reranked_results
321