poltextlab commited on
Commit
d5088d4
·
verified ·
1 Parent(s): a3f43ae

revert emotion

Browse files
Files changed (1) hide show
  1. interfaces/emotion.py +9 -8
interfaces/emotion.py CHANGED
@@ -7,20 +7,21 @@ from transformers import AutoModelForSequenceClassification
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
- from label_dicts import EMOTION9_LABEL_NAMES
11
 
12
  HF_TOKEN = os.environ["hf_read"]
13
 
14
  languages = [
15
- "Czech", "English", "German", "Hungarian", "Polish", "Slovak"
16
  ]
17
  domains = {
18
  "parliamentary speech": "parlspeech",
19
  }
20
 
21
  def build_huggingface_path(language: str):
22
- language = language.lower()
23
- return f"poltextlab/xlm-roberta-large-pooled-{language}-emotions9"
 
24
 
25
  def predict(text, model_id, tokenizer_id):
26
  device = torch.device("cpu")
@@ -38,18 +39,18 @@ def predict(text, model_id, tokenizer_id):
38
  with torch.no_grad():
39
  logits = model(**inputs).logits
40
 
41
- NUMS_DICT = {i: key for i, key in enumerate(sorted(EMOTION9_LABEL_NAMES.keys()))}
42
- output_pred = {f"[{NUMS_DICT[i]}] {EMOTION9_LABEL_NAMES[NUMS_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
43
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
44
  return output_pred, output_info
45
 
46
- def predict_e6(text, language, domain):
47
  model_id = build_huggingface_path(language)
48
  tokenizer_id = "xlm-roberta-large"
49
  return predict(text, model_id, tokenizer_id)
50
 
51
  demo = gr.Interface(
52
- fn=predict_e6,
53
  inputs=[gr.Textbox(lines=6, label="Input"),
54
  gr.Dropdown(languages, label="Language"),
55
  gr.Dropdown(domains.keys(), label="Domain")],
 
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
+ from label_dicts import MANIFESTO_LABEL_NAMES
11
 
12
  HF_TOKEN = os.environ["hf_read"]
13
 
14
  languages = [
15
+ "Czech", "English", "French", "German", "Hungarian", "Polish", "Slovak"
16
  ]
17
  domains = {
18
  "parliamentary speech": "parlspeech",
19
  }
20
 
21
  def build_huggingface_path(language: str):
22
+ if language == "Czech" or language == "Slovak":
23
+ return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
24
+ return "poltextlab/xlm-roberta-large-pooled-MORES"
25
 
26
  def predict(text, model_id, tokenizer_id):
27
  device = torch.device("cpu")
 
39
  with torch.no_grad():
40
  logits = model(**inputs).logits
41
 
42
+ probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
43
+ output_pred = {model.config.id2label[i]: probs[i] for i in np.argsort(probs)[::-1]}
44
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
45
  return output_pred, output_info
46
 
47
+ def predict_cap(text, language, domain):
48
  model_id = build_huggingface_path(language)
49
  tokenizer_id = "xlm-roberta-large"
50
  return predict(text, model_id, tokenizer_id)
51
 
52
  demo = gr.Interface(
53
+ fn=predict_cap,
54
  inputs=[gr.Textbox(lines=6, label="Input"),
55
  gr.Dropdown(languages, label="Language"),
56
  gr.Dropdown(domains.keys(), label="Domain")],