hh1199 committed on
Commit
6d95508
·
verified ·
1 Parent(s): 13bdacf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -21
app.py CHANGED
@@ -10,39 +10,42 @@ MODELS = {
10
  "ruRoberta": "sberbank-ai/ruRoberta-large"
11
  }
12
 
13
- def get_embeddings(model, tokenizer, texts):
14
- inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
 
 
 
 
 
 
15
  outputs = model(**inputs)
16
  return outputs.last_hidden_state[:, 0].detach().numpy()
17
 
18
  def classify(model_name: str, item: str, categories: str) -> str:
19
- # Загрузка модели и токенизатора
20
  tokenizer = AutoTokenizer.from_pretrained(MODELS[model_name])
21
  model = AutoModel.from_pretrained(MODELS[model_name])
22
 
23
- # Подготовка текстов
24
- texts = [item] + [c.strip() for c in categories.split(",")]
25
 
26
- # Получение эмбеддингов
27
- embeddings = get_embeddings(model, tokenizer, texts)
 
 
 
28
 
29
- # Сравнение с категориями
30
- item_embedding = embeddings[0].reshape(1, -1)
31
- category_embeddings = embeddings[1:]
32
-
33
- similarities = cosine_similarity(item_embedding, category_embeddings)[0]
34
  best_idx = np.argmax(similarities)
35
 
36
- return f"{texts[1:][best_idx]} ({similarities[best_idx]:.2f})"
37
 
38
- iface = gr.Interface(
39
  fn=classify,
40
  inputs=[
41
- gr.Dropdown(list(MODELS.keys()), label="Модель"),
42
- gr.Textbox(label="Товар"),
43
- gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
44
  ],
45
- outputs=gr.Textbox(label="Результат")
46
- )
47
-
48
- iface.launch()
 
10
  "ruRoberta": "sberbank-ai/ruRoberta-large"
11
  }
12
 
13
def get_embeddings(model, tokenizer, text):
    """Embed *text* after wrapping it in the product-classification prompt.

    The text is inserted into the template ``"Товар: {text}. Категория:"``,
    tokenized (truncated to 512 tokens) and passed through *model*. The
    hidden state at the first token position of each sequence is returned.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``last_hidden_state``
            tensor (presumably a Hugging Face ``AutoModel`` -- confirm
            against the caller).
        tokenizer: Callable that turns the prompt into model inputs with
            ``return_tensors="pt"``.
        text: Text to embed; it is interpolated into the prompt as-is.

    Returns:
        A NumPy array holding the first-position hidden state, one row per
        input sequence (a single row for a single string).
    """
    # Local import: the file's import header is not part of this block.
    import torch

    # Add the prompt around the raw text.
    prompted_text = f"Товар: {text}. Категория:"
    inputs = tokenizer(
        prompted_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Inference only: disabling autograd avoids building a graph for every
    # request, which saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # .cpu() keeps the NumPy conversion valid if the model runs on another device.
    return outputs.last_hidden_state[:, 0].detach().cpu().numpy()
23
 
24
# One-slot cache {model_name: (tokenizer, model)}. Only the most recently used
# model is kept so memory use stays at one loaded model, as before.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return the (tokenizer, model) pair for *model_name*, loading it on first use."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        checkpoint = MODELS[model_name]
        cached = (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModel.from_pretrained(checkpoint),
        )
        _MODEL_CACHE.clear()
        _MODEL_CACHE[model_name] = cached
    return cached


def classify(model_name: str, item: str, categories: str) -> str:
    """Pick the category whose embedding is closest to the item's embedding.

    Args:
        model_name: Key into ``MODELS`` selecting the checkpoint to use.
        item: Product description to classify.
        categories: Comma-separated category names. Surrounding whitespace
            is stripped and blank entries (e.g. from a trailing comma) are
            ignored.

    Returns:
        ``"<category> (<similarity>)"`` with the cosine similarity formatted
        to two decimals.

    Raises:
        ValueError: If *categories* contains no non-empty name.
        KeyError: If *model_name* is not a key of ``MODELS``.
    """
    # Parse once so the names used for the result line up with the embeddings.
    category_names = [name.strip() for name in categories.split(",")]
    category_names = [name for name in category_names if name]
    if not category_names:
        raise ValueError("categories must contain at least one non-empty name")

    # Reuse the loaded checkpoint instead of reloading it on every request.
    tokenizer, model = _load_model(model_name)

    # Item embedding.
    item_embedding = get_embeddings(model, tokenizer, item)

    # Category embeddings, stacked into one matrix (one row per category).
    # NOTE(review): get_embeddings wraps every text it receives in the
    # product prompt, so category names are prompted too -- confirm intended.
    category_embeddings = np.vstack(
        [get_embeddings(model, tokenizer, name) for name in category_names]
    )

    # Compare the item against every category and keep the best match.
    similarities = cosine_similarity(item_embedding, category_embeddings)[0]
    best_idx = int(np.argmax(similarities))

    return f"{category_names[best_idx]} ({similarities[best_idx]:.2f})"
42
 
43
# Gradio UI: model selector, product text, and comma-separated categories in;
# the best-matching category (with its similarity score) out.
demo = gr.Interface(
    fn=classify,
    inputs=[
        # Preselect the first model so classify never receives None for
        # model_name when the user submits without touching the dropdown.
        gr.Dropdown(list(MODELS), value=next(iter(MODELS), None)),
        gr.Textbox(),
        gr.Textbox(value="Инструменты, Овощи, Техника"),
    ],
    outputs=gr.Textbox(),
)

# Start the server only when run as a script, so importing this module
# (tests, tooling) does not block on a running web server.
if __name__ == "__main__":
    demo.launch()