File size: 1,844 Bytes
7ab3fbc 0803d70 7ab3fbc 0a1dfe8 0803d70 3ee972d 7ab3fbc 6d95508 0803d70 0a1dfe8 0803d70 6d95508 0a1dfe8 6d95508 0a1dfe8 6d95508 0803d70 0a1dfe8 6d95508 1515adb 6d95508 7ab3fbc 6d95508 86bd5a4 6d95508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
# Registry of selectable encoders: short display name (offered in the UI
# dropdown) -> checkpoint identifier handed to ``from_pretrained``.
# Insertion order is the order in which the choices are listed.
MODELS = {
    "rubert-tiny2": "cointegrated/rubert-tiny2",
    "sbert": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "LaBSE": "sentence-transformers/LaBSE",
    "ruRoberta": "sberbank-ai/ruRoberta-large",
}
def get_embeddings(model, tokenizer, text):
    """Embed ``text``, wrapped in the product/category prompt, as a vector.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``last_hidden_state`` tensor.
        tokenizer: Callable that turns the prompt string into a mapping of
            PyTorch tensors (it is called with ``return_tensors="pt"``).
        text: Raw text to insert into the prompt template.

    Returns:
        A NumPy array with the hidden state at sequence position 0 for every
        input row. Only one string is tokenized here, so presumably the shape
        is ``(1, hidden_size)`` -- confirm against the model in use.
    """
    # Wrap the raw text in the prompt template before encoding.
    prompted_text = f"Товар: {text}. Категория:"
    inputs = tokenizer(
        prompted_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Inference only: disabling autograd avoids building a graph that would
    # be discarded immediately, which saves memory and time per call.
    with torch.no_grad():
        outputs = model(**inputs)
    # Position 0 of the sequence axis is used as the sentence representation.
    return outputs.last_hidden_state[:, 0].detach().numpy()
# Bounded so that switching between models cannot keep every checkpoint
# resident in memory at once.
@lru_cache(maxsize=2)
def _load_model(model_name):
    """Load (once) the tokenizer/model pair registered under ``model_name``."""
    checkpoint = MODELS[model_name]
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModel.from_pretrained(checkpoint),
    )


def classify(model_name: str, item: str, categories: str) -> str:
    """Pick the category whose embedding is closest to the item's embedding.

    Args:
        model_name: Key of ``MODELS`` selecting which encoder to use.
        item: Product description to classify.
        categories: Comma-separated category names; surrounding whitespace is
            stripped and blank entries are ignored.

    Returns:
        The best-matching category name followed by its cosine similarity in
        parentheses, formatted to two decimals, e.g. ``"Name (0.87)"``.

    Raises:
        gr.Error: If ``model_name`` is not a known model or no non-blank
            category was supplied.
    """
    if model_name not in MODELS:
        raise gr.Error("Выберите модель из списка.")

    # Split once and drop blanks so a trailing comma or an empty textbox does
    # not produce an empty "category" that could win the comparison.
    category_names = [
        name for name in (part.strip() for part in categories.split(",")) if name
    ]
    if not category_names:
        raise gr.Error("Укажите хотя бы одну категорию (через запятую).")

    # Cached: loading a checkpoint on every request dominates the latency.
    tokenizer, model = _load_model(model_name)

    item_embedding = get_embeddings(model, tokenizer, item)
    # NOTE(review): get_embeddings wraps category names in the same prompt
    # template as the item -- confirm that this is intended.
    category_embeddings = np.vstack(
        [get_embeddings(model, tokenizer, name) for name in category_names]
    )

    # Row 0 holds the item's similarity to each category, in input order.
    similarities = cosine_similarity(item_embedding, category_embeddings)[0]
    best_idx = int(np.argmax(similarities))
    return f"{category_names[best_idx]} ({similarities[best_idx]:.2f})"
# Build the web UI: model picker, item text and comma-separated categories in;
# the predicted category (with its similarity score) out.
demo = gr.Interface(
    fn=classify,
    inputs=[
        # Pre-select the first model so the handler is never called with no
        # model chosen.
        gr.Dropdown(list(MODELS), value=next(iter(MODELS))),
        gr.Textbox(),
        gr.Textbox(value="Инструменты, Овощи, Техника"),
    ],
    outputs=gr.Textbox(),
)
demo.launch()