opex792's picture
Update app.py
4725242 verified
raw
history blame
12.1 kB
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import os
import time
import threading
import queue
import torch
import psycopg2
import zlib
import numpy as np
from urllib.parse import urlparse
# Настройки базы данных PostgreSQL
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL is None:
raise ValueError("DATABASE_URL environment variable not set.")
parsed_url = urlparse(DATABASE_URL)
db_params = {
"host": parsed_url.hostname,
"port": parsed_url.port,
"database": parsed_url.path.lstrip("/"),
"user": parsed_url.username,
"password": parsed_url.password,
"sslmode": "require"
}
# Загружаем модель
model_name = "BAAI/bge-m3"
model = SentenceTransformer(model_name)
# Имена таблиц
embeddings_table = "movie_embeddings"
query_cache_table = "query_cache"
# Максимальный размер таблицы кэша запросов в байтах (50MB)
MAX_CACHE_SIZE = 50 * 1024 * 1024
# Загружаем данные из файла movies.json
try:
import json
with open("movies.json", "r", encoding="utf-8") as f:
movies_data = json.load(f)
except FileNotFoundError:
print("Ошибка: Файл movies.json не найден.")
movies_data = []
# Очередь для необработанных фильмов
movies_queue = queue.Queue()
# Флаг, указывающий, что обработка фильмов завершена
processing_complete = False
# Флаг, указывающий, что выполняется поиск
search_in_progress = False
# Блокировка для доступа к базе данных
db_lock = threading.Lock()
# Размер пакета для обработки эмбеддингов
batch_size = 32
def get_db_connection():
"""Устанавливает соединение с базой данных."""
try:
conn = psycopg2.connect(**db_params)
return conn
except Exception as e:
print(f"Ошибка подключения к базе данных: {e}")
return None
def setup_database():
"""Настраивает базу данных: создает расширение, таблицы и индексы."""
conn = get_db_connection()
if conn is None:
return
with conn.cursor() as cur:
# Создаем расширение pgvector
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Создаем таблицу для хранения эмбеддингов фильмов
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {embeddings_table} (
movie_id INTEGER PRIMARY KEY,
embedding_crc32 BIGINT,
string_crc32 BIGINT,
model_name TEXT,
embedding float8[]
);
CREATE INDEX IF NOT EXISTS idx_movie_embeddings_crc32 ON {embeddings_table} (string_crc32);
""")
# Создаем таблицу для кэширования запросов
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {query_cache_table} (
query_crc32 BIGINT PRIMARY KEY,
query TEXT,
model_name TEXT,
embedding float8[],
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_query_cache_crc32 ON {query_cache_table} (query_crc32);
CREATE INDEX IF NOT EXISTS idx_query_cache_created ON {query_cache_table} (created_at);
""")
conn.commit()
conn.close()
# Настраиваем базу данных при запуске
setup_database()
def calculate_crc32(text):
"""Вычисляет CRC32 для строки."""
return zlib.crc32(text.encode('utf-8')) & 0xFFFFFFFF
def encode_string(text):
"""Кодирует строку в эмбеддинг."""
return model.encode(text, convert_to_tensor=True, normalize_embeddings=True)
def get_movies_without_embeddings():
"""Получает список фильмов, для которых нужно создать эмбеддинги."""
conn = get_db_connection()
if conn is None:
return []
movies_to_process = []
with conn.cursor() as cur:
# Получаем список ID фильмов, которые уже есть в базе
cur.execute(f"SELECT movie_id FROM {embeddings_table}")
existing_ids = {row[0] for row in cur.fetchall()}
# Фильтруем только те фильмы, которых нет в базе
for movie in movies_data:
if movie['id'] not in existing_ids:
movies_to_process.append(movie)
conn.close()
return movies_to_process
def vector_to_list(vector):
"""Преобразует вектор PyTorch в список float."""
return vector.detach().cpu().numpy().tolist()
def list_to_vector(lst):
"""Преобразует список float в вектор PyTorch."""
return torch.tensor(lst)
def get_embedding_from_db(conn, table_name, crc32_column, crc32_value, model_name):
"""Получает эмбеддинг из базы данных."""
with conn.cursor() as cur:
cur.execute(f"SELECT embedding FROM {table_name} WHERE {crc32_column} = %s AND model_name = %s",
(crc32_value, model_name))
result = cur.fetchone()
if result and result[0]:
return list_to_vector(result[0])
return None
def insert_embedding(conn, table_name, movie_id, embedding_crc32, string_crc32, embedding):
"""Вставляет эмбеддинг в базу данных."""
embedding_list = vector_to_list(embedding)
with conn.cursor() as cur:
try:
cur.execute(f"""
INSERT INTO {table_name}
(movie_id, embedding_crc32, string_crc32, model_name, embedding)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (movie_id) DO NOTHING
""", (movie_id, embedding_crc32, string_crc32, model_name, embedding_list))
conn.commit()
return True
except Exception as e:
print(f"Ошибка при вставке эмбеддинга: {e}")
conn.rollback()
return False
def process_movies():
"""Обрабатывает фильмы, создавая для них эмбеддинги."""
global processing_complete
# Получаем список фильмов, которые нужно обработать
movies_to_process = get_movies_without_embeddings()
if not movies_to_process:
print("Все фильмы уже обработаны.")
processing_complete = True
return
# Добавляем фильмы в очередь
for movie in movies_to_process:
movies_queue.put(movie)
conn = get_db_connection()
if conn is None:
processing_complete = True
return
while True:
if search_in_progress:
time.sleep(1)
continue
batch = []
while not movies_queue.empty() and len(batch) < batch_size:
try:
movie = movies_queue.get_nowait()
batch.append(movie)
except queue.Empty:
break
if not batch:
break
print(f"Обработка пакета из {len(batch)} фильмов...")
for movie in batch:
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
string_crc32 = calculate_crc32(embedding_string)
# Проверяем существующий эмбеддинг
existing_embedding = get_embedding_from_db(conn, embeddings_table, "string_crc32", string_crc32, model_name)
if existing_embedding is None:
embedding = encode_string(embedding_string)
embedding_crc32 = calculate_crc32(str(vector_to_list(embedding)))
if insert_embedding(conn, embeddings_table, movie['id'], embedding_crc32, string_crc32, embedding):
print(f"Сохранен эмбеддинг для '{movie['name']}'")
else:
print(f"Ошибка сохранения эмбеддинга для '{movie['name']}'")
else:
print(f"Эмбеддинг для '{movie['name']}' уже существует")
conn.close()
processing_complete = True
print("Обработка фильмов завершена")
def get_movie_embeddings(conn):
"""Загружает все эмбеддинги фильмов из базы данных."""
movie_embeddings = {}
with conn.cursor() as cur:
cur.execute(f"""
SELECT e.movie_id, e.embedding
FROM {embeddings_table} e
""")
for movie_id, embedding in cur.fetchall():
# Находим название фильма по ID
for movie in movies_data:
if movie['id'] == movie_id:
movie_embeddings[movie['name']] = list_to_vector(embedding)
break
return movie_embeddings
def search_movies(query, top_k=10):
"""Выполняет поиск фильмов по запросу."""
global search_in_progress
search_in_progress = True
start_time = time.time()
try:
conn = get_db_connection()
if conn is None:
return "<p>Ошибка подключения к базе данных</p>"
query_crc32 = calculate_crc32(query)
query_embedding = get_embedding_from_db(conn, query_cache_table, "query_crc32", query_crc32, model_name)
if query_embedding is None:
query_embedding = encode_string(query)
embedding_list = vector_to_list(query_embedding)
with conn.cursor() as cur:
cur.execute(f"""
INSERT INTO {query_cache_table} (query_crc32, query, model_name, embedding)
VALUES (%s, %s, %s, %s)
ON CONFLICT (query_crc32) DO NOTHING
""", (query_crc32, query, model_name, embedding_list))
conn.commit()
movie_embeddings = get_movie_embeddings(conn)
similarities = []
for title, movie_embedding in movie_embeddings.items():
similarity = util.pytorch_cos_sim(query_embedding, movie_embedding).item()
similarities.append((title, similarity))
similarities.sort(key=lambda x: x[1], reverse=True)
top_results = similarities[:top_k]
results_html = "<ol>"
for title, score in top_results:
results_html += f"<li><strong>{title}</strong> (Сходство: {score:.4f})</li>"
results_html += "</ol>"
search_time = time.time() - start_time
conn.close()
return f"<p>Время поиска: {search_time:.2f} сек</p>{results_html}"
finally:
search_in_progress = False
# Запускаем обработку фильмов в отдельном потоке
processing_thread = threading.Thread(target=process_movies)
processing_thread.start()
# Создаем интерфейс Gradio
iface = gr.Interface(
fn=search_movies,
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для поиска фильмов..."),
outputs=gr.HTML(label="Результаты поиска"),
title="Семантический поиск фильмов",
description="Введите описание фильма, который вы ищете, и система найдет наиболее похожие фильмы."
)
# Запускаем интерфейс
iface.launch()