Spaces:
Sleeping
Sleeping
import streamlit as st | |
from sentence_transformers import SentenceTransformer | |
import pandas as pd | |
import numpy as np | |
import faiss | |
# import openai | |
import spacy | |
from googletrans import Translator | |
from sklearn.metrics.pairwise import cosine_similarity | |
def load_model(): | |
return SentenceTransformer("sentence-transformers/paraphrase-xlm-r-multilingual-v1") | |
def load_data(): | |
df = pd.read_csv('data/final_with_emb.csv') | |
return df | |
def load_embeddings(): | |
return np.load('for_models/embeddings.npy') | |
def load_faiss_index(): | |
index_l2 = faiss.read_index('for_models/faiss_index_l2.bin') | |
index_ip = faiss.read_index('for_models/faiss_index_ip.bin') | |
index_hnsw = faiss.read_index('for_models/faiss_index_hnsw.bin') | |
return {'L2': index_l2, 'IP': index_ip, 'HNSW': index_hnsw} | |
st.title('Рекомендация сериалов') | |
st.markdown( | |
""" | |
<style> | |
.header { | |
font-size: 32px; | |
font-weight: bold; | |
color: #7147e6; | |
margin-bottom: 20px; | |
} | |
.subheader { | |
font-size: 24px; | |
font-weight: 600; | |
color: #7147e6; | |
margin-bottom: 15px; | |
} | |
.paragraph { | |
font-size: 18px; | |
line-height: 1.6; | |
color: #4799e6; | |
margin-bottom: 20px; | |
} | |
.list { | |
font-size: 18px; | |
color: #4799e6; | |
line-height: 1.8; | |
padding-left: 20px; | |
} | |
.service { | |
background-color: #ECF0F1; | |
border-radius: 10px; | |
padding: 20px; | |
margin-bottom: 30px; | |
} | |
.highlight { | |
color: #E74C3C; | |
font-weight: bold; | |
} | |
</style> | |
""", unsafe_allow_html=True | |
) | |
st.markdown('<div class="header">Добро пожаловать на мою страницу!</div>', unsafe_allow_html=True) | |
st.markdown( | |
""" | |
<div class="paragraph"> | |
Этот сервис использует передовые технологии машинного обучения и обработки естественного языка для того, чтобы порекомендовать вам сериалы, которые могут вам понравиться. Мы применяем XLM-RoBERTa для поиска и обработки данных, чтобы вывести наиболее релевантные результаты по вашему запросу. | |
</div> | |
""", unsafe_allow_html=True) | |
st.markdown( | |
""" | |
<div class="subheader">Что умеет сервис?</div> | |
<div class="paragraph"> | |
Cервис предоставляет следующие возможности: | |
</div> | |
<ul class="list"> | |
<li>Поиск сериалов по вашему запросу с использованием различных методов поиска.</li> | |
<li>Перевод информации о сериале в режиме реального времени (если язык - не русский).</li> | |
<li>Вывод информации о сериале, включая название, описание и изображение.</li> | |
<li>Интерактивный поиск с возможностью выбора метода поиска: L2, IP, HNSW.</li> | |
<li>Отображение списка сериалов в удобном формате.</li> | |
</ul> | |
""", unsafe_allow_html=True) | |
def calculate_cosine_similarity(query_emb, embeddings): | |
similarity = cosine_similarity(query_emb, embeddings) | |
return similarity.flatten() | |
def calculate_l2_similarity(query_emb, embeddings): | |
l2_distances = np.linalg.norm(embeddings - query_emb, axis=1) | |
return l2_distances | |
top_k = st.slider('Сколько выдаем рекомендаций?', min_value=1, max_value=20, value=5) | |
def search_similar(query, index_type, top_k=5): | |
query_emb = model.encode([query]).astype(np.float32) | |
if index_type == 'IP': | |
faiss.normalize_L2(query_emb) | |
distances, indices = indexes[index_type].search(query_emb, top_k) | |
# st.write(f"Используемый индекс: {index_type}") | |
# st.write(f"Размер индекса: {indexes[index_type].ntotal}") | |
results = df.iloc[indices[0]] | |
return results, distances[0] | |
translator = Translator() | |
def detect_and_translate(text): | |
detected_lang = translator.detect(text).lang | |
if detected_lang != 'ru': | |
translated_text = translator.translate(text, src=detected_lang, dest='ru').text | |
return translated_text | |
return text | |
nlp = spacy.load('en_core_web_sm') | |
def show_desc(desc, title, max_lines=4): | |
translated_title = detect_and_translate(title) | |
translated_desc = detect_and_translate(desc) | |
doc = nlp(translated_desc) | |
sentence = [sent.text for sent in doc.sents] | |
short_desc = ' '.join(sentence[:max_lines]) | |
st.markdown(f'### {translated_title}') | |
st.write(short_desc) | |
with st.expander('Показать полное описание'): | |
st.write(desc) | |
# client = openai.OpenAI(api_key='сюда свой APIKEY от ChatGPT') | |
def generate_summary(query, title, desc): | |
prompt = f"""Ты – эксперт по кино. Пользователь ищет сериал по запросу: "{query}". | |
Опиши сериал "{title}" коротко и понятно. Объясни, почему он подходит. | |
Описание из базы: {desc} | |
Ответь в формате: | |
- Краткое описание: | |
- Почему стоит посмотреть: | |
""" | |
response = client.chat.completions.create( | |
model="gpt-4", | |
messages=[{"role": "user", "content": prompt}] | |
) | |
return response.choices[0].message.content | |
model = load_model() | |
df = load_data() | |
embeddings = load_embeddings() | |
indexes = load_faiss_index() | |
query = st.text_input('Введите описание сериала', 'Найди мне что-нибудь про автомобили') | |
index_type = st.selectbox('Выберите метод поиска:', ['IP', 'L2', 'HNSW']) | |
if st.button('Начать поиск'): | |
if query: | |
results, scores = search_similar(query, index_type, top_k) | |
st.subheader(f'Результаты c использованием {index_type}:') | |
for _, row in results.iterrows(): | |
title = row['title'] | |
desc = row['description'] | |
image_url = row['image_url'] | |
# summary = generate_summary(query, title, desc) раскоммитить при работе с ChatGPT | |
with st.container(): | |
col1, col2 = st.columns([1, 3]) | |
with col1: | |
st.image(image_url, width=500) | |
with col2: | |
# st.write(summary) если работает ChatGPT | |
show_desc(desc, title) | |
st.markdown('---') | |
query_emb = model.encode([query]).astype(np.float32) | |
cosine_scores = calculate_cosine_similarity(query_emb, embeddings) | |
l2_scores = calculate_l2_similarity(query_emb, embeddings) | |
faiss.normalize_L2(query_emb) | |
distances_hnsw, _ = indexes['HNSW'].search(query_emb, len(df)) | |
hnsw_scores = distances_hnsw[0] | |
df['cosine_similarity'] = cosine_scores | |
df['l2_similarity'] = l2_scores | |
df['hnsw_similarity'] = hnsw_scores | |
df_sorted = df[['title', 'cosine_similarity', 'l2_similarity', 'hnsw_similarity']].sort_values(by='cosine_similarity', ascending=False) | |
st.subheader('Таблица с метриками') | |
st.markdown( | |
""" | |
<style> | |
.stDataFrame { | |
height: 400px; | |
overflow-y: auto; | |
width: 100%; | |
} | |
</style> | |
""", | |
unsafe_allow_html=True | |
) | |
st.dataframe(df_sorted) |