Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pickle | |
import faiss | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
import pandas as pd | |
# ---------------------------------------------------------------------------
# Data, encoders and search indexes.
#
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loads below go through Streamlit's caches: the CSV, the three
# encoders and the pickled indexes are loaded once per server process instead
# of on every rerun.
# ---------------------------------------------------------------------------


@st.cache_data
def _load_catalog(path):
    """Read the series catalogue CSV at *path* into a DataFrame."""
    return pd.read_csv(path)


@st.cache_resource
def _load_sentence_transformer(name):
    """Load a SentenceTransformer model by hub name or local path."""
    return SentenceTransformer(name)


@st.cache_resource
def _load_hf_encoder(name, _device):
    """Load a Hugging Face tokenizer/model pair and move the model to *_device*.

    The leading underscore on ``_device`` tells Streamlit not to hash that
    argument (a ``torch.device``) when building the cache key.
    """
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)
    model.to(_device)
    # from_pretrained already returns eval mode; made explicit because the
    # model is only ever used for inference.
    model.eval()
    return tokenizer, model


@st.cache_resource
def _load_pickled_index(path):
    """Unpickle a prebuilt search index stored at *path*.

    NOTE: pickle can execute arbitrary code on load; only use it for index
    files shipped with this app, never for user-supplied files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): row order presumably matches the order in which the index
# vectors were built, since search hits are mapped back to rows positionally.
df = _load_catalog('data/cleaned_df.csv')

labse_model = _load_sentence_transformer('sentence-transformers/LaBSE')
distilbert_tokenizer, distilbert_model = _load_hf_encoder(
    'distilbert-base-multilingual-cased', device)
tiny2_tokenizer, tiny2_model = _load_hf_encoder('cointegrated/rubert-tiny2', device)

labse_index = _load_pickled_index('models/dashas/labse_index.pkl')
distilbert_index = _load_pickled_index('models/dashas/distilbert_index.pkl')
tiny2_index = _load_pickled_index('models/dashas/tiny2_index.pkl')
def search_series(query, model, tokenizer=None, index=None, top_k=5):
    """Return the rows of the module-level ``df`` nearest to *query*.

    Two encoder flavours are supported:

    * ``tokenizer`` given: ``model`` is called as a Hugging Face model on the
      tokenized query (truncated to 128 tokens, moved to ``device``), and the
      query embedding is the mean of ``last_hidden_state`` over the token axis.
    * ``tokenizer`` is None: ``model.encode([query])`` is used directly
      (SentenceTransformer-style).

    Args:
        query: Free-text search string.
        model: Encoder, as described above.
        tokenizer: Optional tokenizer matching ``model``.
        index: Index exposing ``search(vectors, k) -> (distances, ids)``, as
            FAISS indexes do. Required despite the ``None`` default.
        top_k: Maximum number of rows to return.

    Returns:
        The selected rows of ``df``, in the order the index returned them.
        Fewer than ``top_k`` rows come back if the index reports fewer
        neighbours (FAISS pads such results with id -1).

    Raises:
        ValueError: if ``index`` is None.
    """
    if index is None:
        raise ValueError("search_series requires a search index")

    if tokenizer is not None:
        inputs = tokenizer([query], return_tensors="pt", padding=True,
                           truncation=True, max_length=128)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Plain mean over tokens. A single query is never padded, so no
        # attention mask is needed here. NOTE(review): presumably this is the
        # same pooling used to build the index vectors - confirm.
        query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    else:
        query_embedding = model.encode([query])

    # FAISS only accepts C-contiguous float32 matrices.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
    _distances, indices = index.search(query_embedding, top_k)

    row_ids = np.asarray(indices[0])
    # Drop FAISS's -1 "no neighbour" padding; df.iloc would map -1 to the
    # last row and return it as a bogus hit.
    row_ids = row_ids[row_ids >= 0]
    return df.iloc[row_ids]
# ---------------------------------------------------------------------------
# Search UI.
# ---------------------------------------------------------------------------

# Maps each selectable model name to the (model, tokenizer or None, index)
# arguments passed to search_series. Dict order drives the selectbox order.
search_backends = {
    "LaBSE": (labse_model, None, labse_index),
    "DistilBERT": (distilbert_model, distilbert_tokenizer, distilbert_index),
    "tiny2": (tiny2_model, tiny2_tokenizer, tiny2_index),
}

st.title("Умный поиск сериалов")
st.image("images/logo_1.jpeg", width=800)

query = st.text_input("Введите запрос:")
model_choice = st.selectbox("Выберите модель:", list(search_backends))
top_k = st.slider("Количество результатов:", min_value=1, max_value=20, value=5)

if st.button("Найти"):
    # A whitespace-only query is treated the same as an empty one.
    cleaned_query = query.strip()
    if cleaned_query:
        chosen_model, chosen_tokenizer, chosen_index = search_backends[model_choice]
        results = search_series(cleaned_query, chosen_model, chosen_tokenizer,
                                chosen_index, top_k=top_k)
        st.write("Результаты поиска:")
        for _, row in results.iterrows():
            st.write(f"**{row['title']}**")
            st.write(row['description'])
            image_url = row['image_url']
            # Empty CSV cells are read as NaN, which st.image cannot render.
            if pd.notna(image_url):
                st.image(image_url, width=600)
    else:
        st.write("Пожалуйста, введите запрос.")