hh1199 commited on
Commit
19a9f33
·
verified ·
1 Parent(s): 0847fdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -19
app.py CHANGED
@@ -1,30 +1,18 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- # Загружаем модель
5
- classifier = pipeline("text-classification", model="cointegrated/rubert-tiny2")
6
 
7
- # Категории в порядке, соответствующем меткам модели
8
- CATEGORIES = ["Техника", "Овощи", "Инструменты", "Коробки", "Прочее"] # Измените на ваши категории!
 
 
9
 
10
- def classify(item: str, categories: list) -> str:
11
- prompt = f"""
12
- Товар: {item}
13
- Категории: {", ".join(categories)}.
14
- К какой категории относится товар? Ответь только названием категории.
15
- """
16
- result = classifier(prompt, truncation=True)
17
- print(result)
18
- label_index = int(result[0]['label'].split("_")[1])
19
- #return CATEGORIES[label_index]
20
- return result
21
-
22
- # Интерфейс Gradio
23
  iface = gr.Interface(
24
  fn=classify,
25
  inputs=[
26
- gr.Textbox(label="Название товара cointegrated/rubert-tiny2 ++"),
27
- gr.Textbox(label="Категории (через запятую)", value="Техника, Овощи, Инструменты, Коробки, Прочее")
28
  ],
29
  outputs=gr.Textbox(label="Категория")
30
  )
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ classifier = pipeline("zero-shot-classification", model="cointegrated/rubert-tiny2")
 
5
 
6
+ def classify(item: str, categories: str) -> str:
7
+ categories_list = [c.strip() for c in categories.split(",")]
8
+ result = classifier(item, categories_list, multi_label=False)
9
+ return result['labels'][0]
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  iface = gr.Interface(
12
  fn=classify,
13
  inputs=[
14
+ gr.Textbox(label="Название товара"),
15
+ gr.Textbox(label="Категории (через запятую)", value="Техника, Овощи, Инструменты")
16
  ],
17
  outputs=gr.Textbox(label="Категория")
18
  )