Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,39 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
-
import
|
|
|
4 |
|
5 |
MODELS = {
|
6 |
-
"ruRoberta-large": "sberbank-ai/ruRoberta-large",
|
7 |
"rubert-tiny2": "cointegrated/rubert-tiny2",
|
8 |
-
"
|
|
|
|
|
9 |
}
|
10 |
|
|
|
|
|
|
|
|
|
|
|
11 |
def classify(model_name: str, item: str, categories: str) -> str:
|
12 |
-
#
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
model=MODELS[model_name],
|
18 |
-
device=-1
|
19 |
-
)
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
"- картофель → овощи\n"
|
25 |
-
"Категория для '{}' → "
|
26 |
-
)
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
candidate_labels=[c.strip().lower() for c in categories.split(",")],
|
31 |
-
hypothesis_template=hypothesis_template,
|
32 |
-
multi_label=False
|
33 |
-
)
|
34 |
|
35 |
-
|
36 |
-
return "Категория не определена"
|
37 |
-
|
38 |
-
return f"{result['labels'][0].capitalize()} ({result['scores'][0]:.2f})"
|
39 |
|
40 |
iface = gr.Interface(
|
41 |
fn=classify,
|
@@ -44,11 +42,7 @@ iface = gr.Interface(
|
|
44 |
gr.Textbox(label="Товар"),
|
45 |
gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
|
46 |
],
|
47 |
-
outputs=gr.Textbox(label="Результат")
|
48 |
-
examples=[
|
49 |
-
["ruRoberta-large", "Аккумуляторная дрель", "Инструменты, Техника"],
|
50 |
-
["rubert-tiny2", "Свёкла кормовая", "Овощи, Фураж"]
|
51 |
-
]
|
52 |
)
|
53 |
-
|
54 |
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
5 |
|
6 |
# Display name shown in the UI -> Hugging Face Hub checkpoint id.
MODELS = {
    "rubert-tiny2": "cointegrated/rubert-tiny2",
    "sbert": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "LaBSE": "sentence-transformers/LaBSE",
    "ruRoberta": "sberbank-ai/ruRoberta-large",
}
|
12 |
|
13 |
+
def get_embeddings(model, tokenizer, texts):
    """Embed *texts* with *model* and return one vector per text.

    Each text is tokenized (padded to the longest item in the batch and
    truncated to 512 tokens), passed through the model, and represented by
    the hidden state at position 0 of its sequence.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable turning a list of strings into a mapping of
            PyTorch tensors.
        texts: List of strings to embed.

    Returns:
        A NumPy array with one row per input text.
    """
    # Imported locally so the block does not rely on a file-level torch import;
    # return_tensors="pt" already requires PyTorch to be installed.
    import torch

    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # Position 0 is presumably the [CLS]/<s> token for the listed checkpoints
    # -- NOTE(review): confirm this pooling suits the sentence-transformers models.
    first_token_states = outputs.last_hidden_state[:, 0]
    # .cpu() keeps the NumPy conversion working if the model lives on another device.
    return first_token_states.detach().cpu().numpy()
|
17 |
+
|
18 |
# Tokenizer/model pairs already loaded in this process, keyed by MODELS name.
_LOADED_MODELS = {}


def _load_model(model_name):
    """Return the (tokenizer, model) pair for *model_name*, loading it only once."""
    if model_name not in _LOADED_MODELS:
        checkpoint = MODELS[model_name]
        _LOADED_MODELS[model_name] = (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModel.from_pretrained(checkpoint),
        )
    return _LOADED_MODELS[model_name]


def classify(model_name: str, item: str, categories: str) -> str:
    """Pick the category whose embedding is closest to the item's embedding.

    Args:
        model_name: Key of ``MODELS`` selecting the checkpoint to use.
        item: Text describing the item to classify.
        categories: Comma-separated list of candidate category names.

    Returns:
        ``"<category> (<cosine similarity with two decimals>)"`` for the best
        match, or ``"Категория не определена"`` when the item is blank or no
        non-empty category was supplied.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    # Drop blank entries (e.g. from a trailing comma) so an empty string can
    # never be embedded and returned as the winning "category".
    labels = [label.strip() for label in (categories or "").split(",")]
    labels = [label for label in labels if label]
    if not labels or not (item or "").strip():
        return "Категория не определена"

    # Load the tokenizer and model; cached so repeated requests do not
    # re-read the checkpoint every time.
    tokenizer, model = _load_model(model_name)

    # Embed the item together with the categories in a single batch.
    embeddings = get_embeddings(model, tokenizer, [item] + labels)

    # Compare the item (row 0) against every category (remaining rows).
    item_embedding = embeddings[0].reshape(1, -1)
    category_embeddings = embeddings[1:]
    similarities = cosine_similarity(item_embedding, category_embeddings)[0]
    best_idx = int(np.argmax(similarities))

    return f"{labels[best_idx]} ({similarities[best_idx]:.2f})"
|
|
|
|
|
|
|
37 |
|
38 |
iface = gr.Interface(
|
39 |
fn=classify,
|
|
|
42 |
gr.Textbox(label="Товар"),
|
43 |
gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
|
44 |
],
|
45 |
+
outputs=gr.Textbox(label="Результат")
|
|
|
|
|
|
|
|
|
46 |
)
|
47 |
+
|
48 |
iface.launch()
|