File size: 3,686 Bytes
aa975e0
 
cec858f
aa975e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575d2bc
 
 
 
 
 
 
 
cec858f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa975e0
 
 
 
 
 
 
 
 
 
 
e3bfea3
aa975e0
 
9b8b66e
aa975e0
 
 
 
 
 
796fe47
cec858f
796fe47
575d2bc
796fe47
 
aa975e0
 
 
 
 
cec858f
 
 
 
 
796fe47
 
cec858f
e5fe883
cec858f
aa975e0
 
cec858f
3bdb0a7
9f40288
3bdb0a7
 
 
 
 
552f630
d71e26f
552f630
 
66daa12
95a86c7
552f630
6f3f9f8
552f630
cec858f
 
aa975e0
cec858f
 
 
aa975e0
 
 
cec858f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import torch
import spacy
import numpy as np
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr

# Root directory used for every cache location configured below.
PATH = '/data/' # at least 150GB storage needs to be attached
# Point the Hugging Face and torch cache environment variables at PATH.
# NOTE(review): these are assigned after `torch` and `transformers` have
# already been imported at the top of the file; if those libraries read the
# variables at import time the values may not take effect -- confirm.
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH

# Access token passed as `token=` when the classifier is loaded in predict().
# It is read from the "hf_read" environment variable; if that variable is not
# set, this line raises KeyError while the module is being imported.
HF_TOKEN = os.environ["hf_read"]

# Display names for sentiment classes, keyed by class index.
# Not referenced by the other code visible in this file.
SENTIMENT_LABEL_NAMES = dict(enumerate((
    "Negative",
    "No sentiment or Neutral sentiment",
    "Positive",
)))

# Choices offered by the language dropdown in the UI.
LANGUAGES = [
    "Czech",
    "English",
    "French",
    "German",
    "Hungarian",
    "Polish",
    "Slovakian",
]

# Emotion name for each class index; used to label the arg-max prediction.
id2label = dict(enumerate((
    "Anger",
    "Fear",
    "Disgust",
    "Sadness",
    "Joy",
    "None of Them",
)))
def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load the spaCy pipeline *model_name*, downloading it if necessary.

    Args:
        model_name: Name of the spaCy pipeline package to load.

    Returns:
        The loaded spaCy pipeline object.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # The first load failed with OSError: fetch the package, then retry
        # the load once.
        spacy.cli.download(model_name)
        return spacy.load(model_name)

def split_sentences(text, model):
    """Split *text* into sentences using only the pipeline's "senter" step.

    Args:
        text: The raw text to segment.
        model: A loaded spaCy pipeline that provides a "senter" component.

    Returns:
        A list with the text of each detected sentence, in document order.

    Note:
        The pipeline is modified in place; the components disabled here are
        not re-enabled before returning.
    """
    # Turn off every currently active component, then switch the sentence
    # segmenter back on so that nothing else runs on the text.
    active_components = model.pipe_names
    model.disable_pipes(active_components)
    model.enable_pipe("senter")

    return [span.text for span in model(text).sents]

def build_huggingface_path(language: str):
    """Return the Hugging Face model id to use for *language*.

    "Czech" and "Slovakian" (exact, case-sensitive match) are routed to the
    visegradmedia-emotion model; every other value gets the pooled
    poltextlab model.
    """
    uses_visegrad_model = language in ("Czech", "Slovakian")
    return (
        "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
        if uses_visegrad_model
        else "poltextlab/xlm-roberta-large-pooled-MORES"
    )

# Loaded (model, tokenizer) pairs keyed by (model_id, tokenizer_id).
# Loading a checkpoint is far more expensive than running one sentence through
# it, so each pair is created once and reused by later predict() calls.
_MODEL_CACHE = {}


def _load_model_and_tokenizer(model_id, tokenizer_id):
    """Return the (model, tokenizer) pair for the given ids, loading it once."""
    key = (model_id, tokenizer_id)
    if key not in _MODEL_CACHE:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder="offload",
            token=HF_TOKEN,
        )
        model.eval()  # inference only
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        _MODEL_CACHE[key] = (model, tokenizer)
    return _MODEL_CACHE[key]


def predict(text, model_id, tokenizer_id, max_length=64):
    """Classify *text* and return the class probabilities.

    Args:
        text: The text (one sentence in this app) to classify.
        model_id: Hugging Face id of the sequence-classification model.
        tokenizer_id: Hugging Face id of the tokenizer to use with it.
        max_length: Maximum number of tokens kept; longer input is truncated.

    Returns:
        A 1-D numpy array holding the softmax of the model's logits.
    """
    model, tokenizer = _load_model_and_tokenizer(model_id, tokenizer_id)

    inputs = tokenizer(text,
                       max_length=max_length,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")

    # The tokenizer returns CPU tensors while device_map="auto" may have put
    # the model on an accelerator; move the inputs next to the model. A "meta"
    # device means the weights are offloaded, so the inputs are left alone.
    device = model.device
    if device.type != "meta":
        inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs

def get_most_probable_label(probs, sentence="", labels=None):
    """Build one result row for the most probable class in *probs*.

    Args:
        probs: 1-D numpy array of class probabilities, as returned by
            predict().
        sentence: The text the probabilities belong to; it becomes the first
            element of the row. Previously the function read an undefined
            name here and raised NameError on every call.
        labels: Optional mapping from class index to label name. Defaults to
            the module-level ``id2label``.

    Returns:
        ``[sentence, label, probability]`` where ``probability`` is the top
        probability as a percentage rounded to two decimals, e.g. ``"75.0%"``.
    """
    if labels is None:
        labels = id2label
    best_index = int(probs.argmax())
    probability = f"{round(100 * probs.max(), 2)}%"
    return [sentence, labels[best_index], probability]

def predict_wrapper(text, language):
    """Split *text* into sentences and classify each one.

    Args:
        text: The raw input text from the UI.
        language: One of the language names offered in the dropdown; it
            selects which model is used.

    Returns:
        A tuple ``(results, output_info)``: ``results`` is a list of
        ``[sentence, label, probability]`` rows, one per non-blank sentence,
        and ``output_info`` is an HTML snippet naming the model that was used.
    """
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    # Treat a missing value as empty text so the splitter always gets a str.
    sentences = split_sentences(text or "", spacy_model)

    results = []
    for sentence in sentences:
        if not sentence.strip():
            # Whitespace-only segments carry nothing to classify.
            continue
        probs = predict(sentence, model_id, tokenizer_id)
        results.append(get_most_probable_label(probs, sentence))

    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, output_info

# Build the Gradio interface. Components are created inside nested layout
# context managers, so the order of the statements below defines the layout.
with gr.Blocks() as demo:
    # First row: one column with the text input, one column with the language
    # selector and the submit button.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")

    # Second row: table that receives the rows returned by predict_wrapper.
    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True # important
        )
    # Third row: receives the model-information string from predict_wrapper.
    with gr.Row():
        model_info = gr.Markdown()

    # Clicking the button calls predict_wrapper(text, language) and routes its
    # two return values to the table and the info component, in that order.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, model_info]
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()