|
import functools
import re

import gradio as gr
from transformers import pipeline
|
|
|
# Maps the label shown in the "Модель" dropdown to the Hugging Face Hub
# checkpoint id that classify() loads into a zero-shot pipeline.
# NOTE(review): the zero-shot-classification pipeline is designed for
# checkpoints fine-tuned on NLI (it looks for an "entailment" label in the
# model config); these look like plain encoders — confirm the scores they
# produce are meaningful, or swap in NLI fine-tunes.
MODELS = {
    "ruRoberta-large": "sberbank-ai/ruRoberta-large",
    "rubert-tiny2": "cointegrated/rubert-tiny2",
    "multilingual-e5": "intfloat/multilingual-e5-base",
}
|
|
|
# Runs of anything that is not a Latin/Cyrillic letter or a digit are turned
# into a single space. Latin letters and digits are kept so brand and model
# names ("Bosch", "iPhone 15", "USB-кабель") still reach the classifier.
_NON_TEXT_RE = re.compile(r"[^0-9a-zA-Zа-яА-ЯёЁ]+")

# Top scores below this threshold are reported as "category not determined".
_MIN_SCORE = 0.3

# NLI hypothesis: the zero-shot pipeline substitutes each *candidate label*
# into "{}" and scores whether the product name (the premise) entails it.
_HYPOTHESIS_TEMPLATE = "Этот товар относится к категории «{}»."


@functools.lru_cache(maxsize=None)
def _get_classifier(model_id):
    """Build the CPU zero-shot pipeline for *model_id* once and reuse it."""
    # classify() only passes values taken from MODELS, so the cache stays small.
    return pipeline("zero-shot-classification", model=model_id, device=-1)


def _normalize_item(item):
    """Return *item* lower-cased, with non-alphanumeric runs collapsed to one space."""
    return _NON_TEXT_RE.sub(" ", item or "").strip().lower()


def _parse_labels(categories):
    """Split a comma-separated string into unique, non-empty, lower-case labels."""
    stripped = (part.strip().lower() for part in (categories or "").split(","))
    # dict.fromkeys drops duplicates while preserving the user's order.
    return list(dict.fromkeys(label for label in stripped if label))


def classify(model_name: str, item: str, categories: str) -> str:
    """Assign a product name to one of the user-supplied categories.

    Args:
        model_name: Key of ``MODELS`` selected in the UI.
        item: Free-text product name.
        categories: Comma-separated list of candidate categories.

    Returns:
        ``"<Category> (<score>)"`` for the best-scoring category, or a
        human-readable message when the input is empty, the model is unknown,
        or the best score is below ``_MIN_SCORE``.
    """
    text = _normalize_item(item)
    if not text:
        # The pipeline raises ValueError on an empty sequence.
        return "Введите название товара"

    labels = _parse_labels(categories)
    if not labels:
        # The pipeline also raises ValueError when there are no labels.
        return "Укажите хотя бы одну категорию"

    model_id = MODELS.get(model_name)
    if model_id is None:
        # E.g. the dropdown submitted None because nothing was selected.
        return f"Неизвестная модель: {model_name}"

    result = _get_classifier(model_id)(
        text,
        candidate_labels=labels,
        hypothesis_template=_HYPOTHESIS_TEMPLATE,
        multi_label=False,
    )

    # Labels/scores come back sorted by descending score.
    best_label, best_score = result["labels"][0], result["scores"][0]
    if best_score < _MIN_SCORE:
        return "Категория не определена"
    return f"{best_label.capitalize()} ({best_score:.2f})"
|
|
|
# Web UI: model picker, product name and comma-separated categories in,
# classify() result out.
iface = gr.Interface(
    fn=classify,
    inputs=[
        # Pre-select a model so the dropdown never submits an empty value.
        gr.Dropdown(list(MODELS), value=next(iter(MODELS)), label="Модель"),
        gr.Textbox(label="Товар"),
        gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника"),
    ],
    outputs=gr.Textbox(label="Результат"),
    examples=[
        ["ruRoberta-large", "Аккумуляторная дрель", "Инструменты, Техника"],
        ["rubert-tiny2", "Свёкла кормовая", "Овощи, Фураж"],
    ],
)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()