hh1199 committed on
Commit
0803d70
·
verified ·
1 Parent(s): 2536443

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -33
app.py CHANGED
@@ -1,41 +1,39 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
- import re
 
4
 
5
  MODELS = {
6
- "ruRoberta-large": "sberbank-ai/ruRoberta-large",
7
  "rubert-tiny2": "cointegrated/rubert-tiny2",
8
- "multilingual-e5": "intfloat/multilingual-e5-base"
 
 
9
  }
10
 
 
 
 
 
 
11
  def classify(model_name: str, item: str, categories: str) -> str:
12
- # Нормализация текста
13
- item = re.sub(r"[^а-яА-ЯёЁ]", " ", item).lower().strip()
 
 
 
 
14
 
15
- classifier = pipeline(
16
- "zero-shot-classification",
17
- model=MODELS[model_name],
18
- device=-1
19
- )
20
 
21
- hypothesis_template = (
22
- "Примеры категорий:\n"
23
- "- молоток → инструменты\n"
24
- "- картофель → овощи\n"
25
- "Категория для '{}' → "
26
- )
27
 
28
- result = classifier(
29
- item,
30
- candidate_labels=[c.strip().lower() for c in categories.split(",")],
31
- hypothesis_template=hypothesis_template,
32
- multi_label=False
33
- )
34
 
35
- if result['scores'][0] < 0.3:
36
- return "Категория не определена"
37
-
38
- return f"{result['labels'][0].capitalize()} ({result['scores'][0]:.2f})"
39
 
40
  iface = gr.Interface(
41
  fn=classify,
@@ -44,11 +42,7 @@ iface = gr.Interface(
44
  gr.Textbox(label="Товар"),
45
  gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
46
  ],
47
- outputs=gr.Textbox(label="Результат"),
48
- examples=[
49
- ["ruRoberta-large", "Аккумуляторная дрель", "Инструменты, Техника"],
50
- ["rubert-tiny2", "Свёкла кормовая", "Овощи, Фураж"]
51
- ]
52
  )
53
-
54
  iface.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
 
6
  MODELS = {
 
7
  "rubert-tiny2": "cointegrated/rubert-tiny2",
8
+ "sbert": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
9
+ "LaBSE": "sentence-transformers/LaBSE",
10
+ "ruRoberta": "sberbank-ai/ruRoberta-large"
11
  }
12
 
13
def get_embeddings(model, tokenizer, texts):
    """Encode texts into one embedding vector per text.

    The texts are tokenized as a single padded batch (truncated to at most
    512 tokens), run through the model, and the hidden state at the first
    token position of every sequence is returned.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing a
            ``last_hidden_state`` tensor.
        tokenizer: Callable that turns a list of strings into a mapping of
            PyTorch tensors (it is invoked with ``return_tensors="pt"``).
        texts: List of strings to embed.

    Returns:
        A NumPy array with one row per input text.
    """
    # Local import: this module does not import torch at the top level.
    import torch

    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Inference only: without no_grad every call would build and keep an
    # autograd graph that is never used, wasting memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Position 0 of each sequence is used as the sentence vector
    # (presumably the [CLS]/<s> token; depends on the tokenizer).
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid otherwise.
    return outputs.last_hidden_state[:, 0].detach().cpu().numpy()
17
+
# Tokenizer/model pairs that are already loaded, keyed by MODELS key.
# Loading a checkpoint is by far the slowest step, so it is done once per
# process instead of on every request. The cache can hold at most one
# entry per key of MODELS.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return the (tokenizer, model) pair for a MODELS key, loading it once."""
    if model_name not in _MODEL_CACHE:
        checkpoint = MODELS[model_name]
        _MODEL_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModel.from_pretrained(checkpoint),
        )
    return _MODEL_CACHE[model_name]


def classify(model_name: str, item: str, categories: str) -> str:
    """Pick the category whose embedding is closest to the item's embedding.

    Args:
        model_name: Key of ``MODELS`` selecting the checkpoint to use.
        item: Text of the product to classify.
        categories: Comma-separated category names.

    Returns:
        ``"<category> (<cosine similarity, two decimals>)"`` for the best
        matching category, or ``"Категория не определена"`` when the item is
        blank or no non-empty category was supplied.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    # Drop empty names produced by trailing or doubled commas so they are
    # neither embedded nor returned as a "category".
    labels = [name.strip() for name in categories.split(",") if name.strip()]

    # Nothing to compare: bail out before loading a model.
    # cosine_similarity would raise on an empty category matrix.
    if not item.strip() or not labels:
        return "Категория не определена"

    # Load the model and tokenizer (cached after the first request).
    tokenizer, model = _load_model(model_name)

    # Embed the item together with the categories in a single batch.
    embeddings = get_embeddings(model, tokenizer, [item] + labels)

    # Compare the item (row 0) against every category (remaining rows).
    item_embedding = embeddings[0].reshape(1, -1)
    category_embeddings = embeddings[1:]
    similarities = cosine_similarity(item_embedding, category_embeddings)[0]
    best_idx = int(np.argmax(similarities))

    return f"{labels[best_idx]} ({similarities[best_idx]:.2f})"
 
 
 
37
 
38
  iface = gr.Interface(
39
  fn=classify,
 
42
  gr.Textbox(label="Товар"),
43
  gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
44
  ],
45
+ outputs=gr.Textbox(label="Результат")
 
 
 
 
46
  )
47
+
48
  iface.launch()