|
import gradio as gr |
|
from sentence_transformers import SentenceTransformer, util |
|
import os |
|
import time |
|
import threading |
|
import queue |
|
import torch |
|
import psycopg2 |
|
import zlib |
|
import numpy as np |
|
from urllib.parse import urlparse |
|
|
|
|
|
from urllib.parse import unquote

DATABASE_URL = os.environ.get("DATABASE_URL")

if DATABASE_URL is None:
    raise ValueError("DATABASE_URL environment variable not set.")

# psycopg2 connection parameters parsed from a URL of the form
# scheme://user:password@host:port/dbname.
parsed_url = urlparse(DATABASE_URL)

# urlparse() does not decode percent-escapes in the user name, password or
# path. Credentials containing reserved characters (e.g. "@" written as "%40")
# must be decoded before they are passed to psycopg2.
db_params = {
    "host": parsed_url.hostname,
    "port": parsed_url.port,
    "database": unquote(parsed_url.path.lstrip("/")),
    "user": unquote(parsed_url.username) if parsed_url.username is not None else None,
    "password": unquote(parsed_url.password) if parsed_url.password is not None else None,
    # Always require an SSL-encrypted connection.
    "sslmode": "require",
}
|
|
|
|
|
# Sentence-embedding model used for both movie texts and search queries.
# The same identifier is written to the database next to every stored
# embedding, so cached vectors can be matched to the model that produced them.
# The model is loaded once, at import time.
model_name = "BAAI/bge-m3"
model = SentenceTransformer(model_name_or_path=model_name)
|
|
|
|
|
# PostgreSQL tables used by this app.
embeddings_table = "movie_embeddings"  # one embedding row per movie_id
query_cache_table = "query_cache"  # cached embeddings of search queries

# Intended upper bound for the query cache, in bytes (50 MiB).
# NOTE(review): not referenced anywhere in the visible code, so nothing here
# actually limits the size of the query cache.
MAX_CACHE_SIZE = 50 << 20
|
|
|
|
|
import json

# Movie catalogue loaded from movies.json. Elsewhere in this file each entry
# is read as a dict with the keys 'id', 'name', 'year', 'genresList' and
# 'description'. A missing file is reported and leaves the catalogue empty.
try:
    with open("movies.json", "r", encoding="utf-8") as movies_file:
        movies_data = json.load(movies_file)
except FileNotFoundError:
    print("Ошибка: Файл movies.json не найден.")
    movies_data = []
|
|
|
|
|
# Movies waiting for an embedding. process_movies() fills and drains it.
movies_queue: queue.Queue = queue.Queue()

# Set to True by process_movies() once it has finished.
processing_complete: bool = False

# Raised by search_movies() while a search runs. process_movies() polls it
# and pauses its work, presumably so searches are not slowed down.
search_in_progress: bool = False

# NOTE(review): this lock is never acquired in the visible code.
db_lock = threading.Lock()

# Maximum number of movies taken from the queue per processing iteration.
batch_size: int = 32
|
|
|
def get_db_connection():
    """Open a new psycopg2 connection using the module-level ``db_params``.

    Returns:
        The new connection. If connecting raises, the error is printed and
        ``None`` is returned instead of propagating the exception.
    """
    try:
        return psycopg2.connect(**db_params)
    except Exception as exc:
        print(f"Ошибка подключения к базе данных: {exc}")
    return None
|
|
|
def setup_database():
    """Create the extension, tables and indexes this app needs, if missing.

    Does nothing when no database connection can be opened. The connection
    is always closed, including when a DDL statement raises. In that case
    the uncommitted work is discarded and the exception propagates.
    """
    conn = get_db_connection()
    if conn is None:
        return

    try:
        with conn.cursor() as cur:
            # NOTE(review): the embedding columns below are float8[], not the
            # pgvector "vector" type, so the visible code never uses this
            # extension.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            # One embedding per movie. string_crc32 is the CRC32 of the text
            # that was embedded and is indexed so rows can be found by content.
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {embeddings_table} (
                    movie_id INTEGER PRIMARY KEY,
                    embedding_crc32 BIGINT,
                    string_crc32 BIGINT,
                    model_name TEXT,
                    embedding float8[]
                );
                CREATE INDEX IF NOT EXISTS idx_movie_embeddings_crc32 ON {embeddings_table} (string_crc32);
            """)

            # Cache of query embeddings keyed by the CRC32 of the query text.
            # NOTE(review): query_crc32 is already the primary key, which has
            # its own index, so idx_query_cache_crc32 looks redundant.
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {query_cache_table} (
                    query_crc32 BIGINT PRIMARY KEY,
                    query TEXT,
                    model_name TEXT,
                    embedding float8[],
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                );
                CREATE INDEX IF NOT EXISTS idx_query_cache_crc32 ON {query_cache_table} (query_crc32);
                CREATE INDEX IF NOT EXISTS idx_query_cache_created ON {query_cache_table} (created_at);
            """)

        conn.commit()
    finally:
        # Closing without a commit (i.e. after an error) discards pending work.
        conn.close()


# Ensure the schema exists before anything else touches the database.
setup_database()
|
|
|
def calculate_crc32(text):
    """Return the unsigned 32-bit CRC32 checksum of ``text`` encoded as UTF-8."""
    payload = text.encode("utf-8")
    checksum = zlib.crc32(payload)
    # Keep only the low 32 bits so the value is always non-negative.
    return checksum & 0xFFFFFFFF
|
|
|
def encode_string(text):
    """Embed ``text`` with the module-level SentenceTransformer ``model``.

    The result is returned as a tensor (``convert_to_tensor=True``) and is
    L2-normalised (``normalize_embeddings=True``), so dot products between
    two results equal their cosine similarity.
    """
    encode_options = {"convert_to_tensor": True, "normalize_embeddings": True}
    return model.encode(text, **encode_options)
|
|
|
def get_movies_without_embeddings():
    """Return the movies from ``movies_data`` that have no stored embedding yet.

    A movie counts as processed when its ``id`` appears as ``movie_id`` in
    the embeddings table.

    Returns:
        List of movie dicts in ``movies_data`` order, or an empty list when
        the database is unreachable.
    """
    conn = get_db_connection()
    if conn is None:
        return []

    try:
        with conn.cursor() as cur:
            cur.execute(f"SELECT movie_id FROM {embeddings_table}")
            existing_ids = {row[0] for row in cur.fetchall()}
    finally:
        # Close even when the query fails, so the connection is not leaked.
        conn.close()

    return [movie for movie in movies_data if movie['id'] not in existing_ids]
|
|
|
def vector_to_list(vector):
    """Convert a PyTorch tensor into (nested) Python lists of floats.

    The tensor is detached from autograd and copied to the CPU first, so
    tensors that require grad or live on another device are accepted.
    """
    host_tensor = vector.detach().cpu()
    as_array = host_tensor.numpy()
    return as_array.tolist()
|
|
|
def list_to_vector(lst):
    """Build a new PyTorch tensor from a (nested) list of numbers.

    ``torch.tensor`` copies the data, and Python floats become torch's
    default floating dtype.
    """
    vector = torch.tensor(lst)
    return vector
|
|
|
def get_embedding_from_db(conn, table_name, crc32_column, crc32_value, model_name):
    """Look up a stored embedding by CRC32 value and model name.

    ``table_name`` and ``crc32_column`` are interpolated into the SQL text,
    so they must be trusted identifiers, never user input. The CRC32 and
    model name are passed as query parameters.

    Returns:
        A tensor built from the ``embedding`` column of a matching row, or
        ``None`` when no row matches or the stored embedding is NULL/empty.
    """
    sql = f"SELECT embedding FROM {table_name} WHERE {crc32_column} = %s AND model_name = %s"
    with conn.cursor() as cur:
        cur.execute(sql, (crc32_value, model_name))
        row = cur.fetchone()

    if not row or not row[0]:
        return None
    return list_to_vector(row[0])
|
|
|
def insert_embedding(conn, table_name, movie_id, embedding_crc32, string_crc32, embedding):
    """Insert one movie-embedding row and commit it.

    The row is tagged with the module-level ``model_name``. If a row for
    ``movie_id`` already exists it is left untouched (``ON CONFLICT DO
    NOTHING``), and this still counts as success.

    Returns:
        True when the statement and commit succeed. False when either raises;
        the error is printed and the transaction is rolled back.
    """
    params = (movie_id, embedding_crc32, string_crc32, model_name, vector_to_list(embedding))
    sql = f"""
INSERT INTO {table_name}
(movie_id, embedding_crc32, string_crc32, model_name, embedding)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (movie_id) DO NOTHING
"""
    with conn.cursor() as cur:
        try:
            cur.execute(sql, params)
            conn.commit()
        except Exception as err:
            print(f"Ошибка при вставке эмбеддинга: {err}")
            conn.rollback()
            return False
    return True
|
|
|
def _movie_embedding_text(movie):
    """Build the text that is embedded for ``movie`` (title, year, genres, description)."""
    return (
        f"Название: {movie['name']}\n"
        f"Год: {movie['year']}\n"
        f"Жанры: {movie['genresList']}\n"
        f"Описание: {movie['description']}"
    )


def _take_batch(max_items):
    """Remove and return up to ``max_items`` movies from ``movies_queue``."""
    batch = []
    while len(batch) < max_items:
        try:
            batch.append(movies_queue.get_nowait())
        except queue.Empty:
            break
    return batch


def _embed_movie(conn, movie):
    """Store an embedding row for ``movie``, reusing a vector for identical text."""
    embedding_string = _movie_embedding_text(movie)
    string_crc32 = calculate_crc32(embedding_string)

    # A row with the same text CRC32 (and model) means this text was already
    # embedded for another movie id. Reuse that vector instead of re-encoding,
    # but still insert a row for this movie id so the movie is searchable.
    # NOTE(review): CRC32 can collide for different texts. In that rare case
    # the reused vector belongs to a different description.
    embedding = get_embedding_from_db(conn, embeddings_table, "string_crc32", string_crc32, model_name)
    if embedding is None:
        embedding = encode_string(embedding_string)
    else:
        print(f"Эмбеддинг для '{movie['name']}' уже существует, используется повторно")

    embedding_crc32 = calculate_crc32(str(vector_to_list(embedding)))
    if insert_embedding(conn, embeddings_table, movie['id'], embedding_crc32, string_crc32, embedding):
        print(f"Сохранен эмбеддинг для '{movie['name']}'")
    else:
        print(f"Ошибка сохранения эмбеддинга для '{movie['name']}'")


def process_movies():
    """Create and store embeddings for every movie that does not have one yet.

    Intended to run in a background thread. Pending movies are pushed onto
    ``movies_queue`` and consumed in batches of up to ``batch_size``. The loop
    pauses while ``search_in_progress`` is set. Whatever happens (nothing to
    do, no database, an exception), the connection is closed and
    ``processing_complete`` is set to True on exit.
    """
    global processing_complete

    conn = None
    try:
        movies_to_process = get_movies_without_embeddings()
        if not movies_to_process:
            print("Все фильмы уже обработаны.")
            return

        for movie in movies_to_process:
            movies_queue.put(movie)

        conn = get_db_connection()
        if conn is None:
            return

        while True:
            # Yield to interactive searches before starting the next batch.
            if search_in_progress:
                time.sleep(1)
                continue

            batch = _take_batch(batch_size)
            if not batch:
                break

            print(f"Обработка пакета из {len(batch)} фильмов...")
            for movie in batch:
                _embed_movie(conn, movie)

        print("Обработка фильмов завершена")
    finally:
        if conn is not None:
            conn.close()
        processing_complete = True
|
|
|
def get_movie_embeddings(conn):
    """Load every stored movie embedding, keyed by movie title.

    Rows whose ``movie_id`` does not appear in ``movies_data`` are skipped,
    as are rows with a NULL/empty embedding. If ``movies_data`` lists the same
    id more than once, the first entry's name is used. If several ids share a
    title, the row fetched last wins.

    Returns:
        dict mapping movie name -> embedding tensor.
    """
    # Build the id -> movie lookup once instead of scanning movies_data for
    # every row (O(rows * movies)). Iterating in reverse makes the first
    # occurrence of a duplicated id win, matching a front-to-back scan.
    movies_by_id = {movie['id']: movie for movie in reversed(movies_data)}

    movie_embeddings = {}
    with conn.cursor() as cur:
        cur.execute(f"""
            SELECT e.movie_id, e.embedding
            FROM {embeddings_table} e
        """)
        for movie_id, embedding in cur.fetchall():
            movie = movies_by_id.get(movie_id)
            if movie is None or not embedding:
                continue
            movie_embeddings[movie['name']] = list_to_vector(embedding)
    return movie_embeddings
|
|
|
def _get_query_embedding(conn, query):
    """Return the embedding for ``query``, using the query cache when possible.

    A cache hit must match the CRC32, the exact query text and the model name,
    so a CRC32 collision between two different queries cannot return the
    wrong embedding. On a miss the query is encoded and cached. A colliding
    query is simply not cached (``ON CONFLICT DO NOTHING``).
    """
    query_crc32 = calculate_crc32(query)
    with conn.cursor() as cur:
        cur.execute(
            f"SELECT embedding FROM {query_cache_table} "
            "WHERE query_crc32 = %s AND query = %s AND model_name = %s",
            (query_crc32, query, model_name),
        )
        row = cur.fetchone()
    if row and row[0]:
        return list_to_vector(row[0])

    embedding = encode_string(query)
    with conn.cursor() as cur:
        cur.execute(f"""
            INSERT INTO {query_cache_table} (query_crc32, query, model_name, embedding)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT (query_crc32) DO NOTHING
        """, (query_crc32, query, model_name, vector_to_list(embedding)))
    conn.commit()
    return embedding


def _rank_movies(query_embedding, movie_embeddings, top_k):
    """Return up to ``top_k`` (title, cosine similarity) pairs, most similar first."""
    if not movie_embeddings:
        return []

    titles = list(movie_embeddings)
    # Vectors loaded from the DB are CPU float32 tensors, while a freshly
    # encoded query may live on the model's device (e.g. CUDA). Move the query
    # to the CPU as float32 before comparing.
    query_vec = query_embedding.detach().to("cpu", torch.float32)
    matrix = torch.stack([movie_embeddings[title] for title in titles]).to(torch.float32)
    scores = util.cos_sim(query_vec, matrix)[0].tolist()

    # sorted() is stable, so equal scores keep their load order.
    ranked = sorted(zip(titles, scores), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


def search_movies(query, top_k=10):
    """Search movies semantically similar to ``query`` and render them as HTML.

    While the search runs, ``search_in_progress`` is set so the background
    embedding worker pauses.

    Args:
        query: Free-text search query.
        top_k: Maximum number of results to show.

    Returns:
        An HTML fragment with the elapsed time and an ordered list of titles
        with their similarity, or an error paragraph when the database is
        unreachable. Other exceptions propagate after cleanup.
    """
    global search_in_progress
    import html  # local import: escape titles before embedding them in HTML

    search_in_progress = True
    start_time = time.time()
    conn = None
    try:
        conn = get_db_connection()
        if conn is None:
            return "<p>Ошибка подключения к базе данных</p>"

        query_embedding = _get_query_embedding(conn, query)
        movie_embeddings = get_movie_embeddings(conn)
        top_results = _rank_movies(query_embedding, movie_embeddings, top_k)

        items = "".join(
            f"<li><strong>{html.escape(str(title))}</strong> (Сходство: {score:.4f})</li>"
            for title, score in top_results
        )
        results_html = f"<ol>{items}</ol>"

        search_time = time.time() - start_time
        return f"<p>Время поиска: {search_time:.2f} сек</p>{results_html}"
    finally:
        # Always release the connection and unpause the background worker.
        if conn is not None:
            conn.close()
        search_in_progress = False
|
|
|
|
|
# Build missing embeddings in the background so the UI can serve searches
# immediately. daemon=True lets the process exit once the Gradio server stops,
# instead of blocking until every remaining movie is embedded. Each embedding
# is committed individually, and movies that already have a row are skipped
# at start-up, so an interrupted run resumes on the next launch.
processing_thread = threading.Thread(
    target=process_movies,
    name="movie-embeddings",
    daemon=True,
)
processing_thread.start()

iface = gr.Interface(
    fn=search_movies,
    inputs=gr.Textbox(lines=2, placeholder="Введите запрос для поиска фильмов..."),
    outputs=gr.HTML(label="Результаты поиска"),
    title="Семантический поиск фильмов",
    description="Введите описание фильма, который вы ищете, и система найдет наиболее похожие фильмы.",
)

iface.launch()
|
|