poltextlab committed (verified)
Commit 5f2958d · 1 Parent(s): d5088d4

Add e9 interface

Files changed (1):
  1. interfaces/emotion9.py +60 -0

interfaces/emotion9.py ADDED
@@ -0,0 +1,60 @@
+ import os
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ from label_dicts import EMOTION9_LABEL_NAMES
+
+ HF_TOKEN = os.environ["hf_read"]
+
+ languages = [
+     "Czech", "English", "German", "Hungarian", "Polish", "Slovak"
+ ]
+ domains = {
+     "parliamentary speech": "parlspeech",
+ }
+
+ def build_huggingface_path(language: str):
+     language = language.lower()
+     return f"poltextlab/xlm-roberta-large-pooled-{language}-emotions9"
+
+ def predict(text, model_id, tokenizer_id):
+     device = torch.device("cpu")
+     # Load the fine-tuned classifier and its tokenizer; inference runs on CPU.
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_id, low_cpu_mem_usage=True, device_map="auto",
+         offload_folder="offload", token=HF_TOKEN)
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+     model.to(device)
+
+     inputs = tokenizer(text,
+                        max_length=512,
+                        truncation=True,
+                        padding="do_not_pad",
+                        return_tensors="pt").to(device)
+     model.eval()
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     # Convert logits to class probabilities before ranking the labels.
+     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
+
+     NUMS_DICT = {i: key for i, key in enumerate(sorted(EMOTION9_LABEL_NAMES.keys()))}
+     output_pred = {f"[{NUMS_DICT[i]}] {EMOTION9_LABEL_NAMES[NUMS_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
+     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
+     return output_pred, output_info
+
+ def predict_e9(text, language, domain):
+     model_id = build_huggingface_path(language)
+     tokenizer_id = "xlm-roberta-large"
+     return predict(text, model_id, tokenizer_id)
+
+ demo = gr.Interface(
+     fn=predict_e9,
+     inputs=[gr.Textbox(lines=6, label="Input"),
+             gr.Dropdown(languages, label="Language"),
+             gr.Dropdown(domains.keys(), label="Domain")],
+     outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
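
The `predict` helper only assumes that `label_dicts.EMOTION9_LABEL_NAMES` maps integer class indices to human-readable emotion names. A minimal sketch of that expected shape, with hypothetical placeholder labels (the real module in this Space defines the nine emotion categories the models were trained on):

# label_dicts.py -- hypothetical sketch; the actual file ships with the Space
# and maps the nine class indices to the real emotion names.
EMOTION9_LABEL_NAMES = {i: f"emotion_{i}" for i in range(9)}  # placeholder labels only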
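
The module itself only builds the `demo` object; the Space's top-level app is expected to import and serve it. For a quick local smoke test, a sketch along these lines should work, assuming the repository root is on `sys.path` and a Hugging Face read token is available:

# smoke_test.py -- hypothetical local runner, not part of this commit
import os

os.environ.setdefault("hf_read", "<your-hf-read-token>")  # set before the module reads it

from interfaces.emotion9 import demo

demo.launch()  # serves the Gradio UI locally (default: http://127.0.0.1:7860)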