File size: 5,742 Bytes
aa975e0
 
cec858f
aa975e0
2d748e6
aa975e0
 
 
7c543b8
ac4450b
aa975e0
 
 
 
 
 
 
 
 
 
 
 
575d2bc
 
 
 
 
 
 
 
cec858f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa975e0
 
 
 
 
 
 
 
 
 
 
e3bfea3
aa975e0
 
9b8b66e
aa975e0
 
 
 
 
 
796fe47
cec858f
796fe47
575d2bc
796fe47
b742f2b
aa975e0
5b27629
ac4450b
ed412e3
5b27629
f9b69e4
 
5b27629
3838a79
ed412e3
5b27629
ed412e3
ac4450b
 
8d02664
18c170e
ed412e3
 
 
27ca4aa
 
98bd063
ac4450b
 
 
e252a02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa975e0
 
 
 
cec858f
 
 
 
ac4450b
cec858f
796fe47
2e2f372
 
ac4450b
cec858f
0d8846e
 
 
 
e252a02
e5fe883
ac4450b
 
aa975e0
 
cec858f
3bdb0a7
9f40288
3bdb0a7
 
 
 
 
552f630
d71e26f
552f630
 
66daa12
95a86c7
552f630
169b92f
 
 
 
6f3f9f8
552f630
cec858f
 
aa975e0
cec858f
169b92f
cec858f
aa975e0
 
 
cec858f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os

# Cache locations for Hugging Face / torch downloads. They are assigned BEFORE the
# heavy libraries below are imported: some of them (notably huggingface_hub /
# transformers) resolve their cache directories from these variables when their
# modules are first imported, so setting them afterwards leaves the default
# cache directory in use.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME — confirm whether it can be dropped.
PATH = '/data/' # at least 150GB storage needs to be attached
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH

import torch  # noqa: E402  (imported after the cache env vars on purpose)
import spacy  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from transformers import AutoModelForSequenceClassification  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402
import gradio as gr  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import seaborn as sns  # noqa: E402

# Hugging Face access token, passed as `token=` when loading the classifier.
# Indexing os.environ raises KeyError at startup if "hf_read" is not set.
HF_TOKEN = os.environ["hf_read"]

SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
# Choices offered in the language dropdown.
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]

# Classifier output index -> emotion name.
# NOTE(review): assumed to match the label order of both emotion models — confirm
# against each model's config.json.
id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them"
}
# Loaded spaCy pipelines, keyed by package name (see load_spacy_model).
_SPACY_MODELS = {}


def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Return the spaCy pipeline *model_name*, downloading it if it is missing.

    Pipelines are loaded once and memoised in ``_SPACY_MODELS``, so repeated
    calls reuse the same object instead of reloading it from disk each time.
    The returned pipeline is therefore shared between callers.

    Args:
        model_name: name of an installable spaCy pipeline package.

    Returns:
        The loaded ``spacy.Language`` object.
    """
    model = _SPACY_MODELS.get(model_name)
    if model is None:
        try:
            model = spacy.load(model_name)
        except OSError:
            # spacy.load raises OSError when the package is not installed;
            # fetch it once and retry.
            spacy.cli.download(model_name)
            model = spacy.load(model_name)
        _SPACY_MODELS[model_name] = model
    return model

def split_sentences(text, model):
    """Split *text* into sentences using the spaCy pipeline *model*.

    Every pipe except ``senter`` is disabled, and ``senter`` is enabled in case
    it was off, so only sentence segmentation runs. The change to *model*'s
    enabled pipes is persistent; it is not restored afterwards.

    Args:
        text: the text to segment.
        model: a spaCy ``Language`` that has a ``senter`` component.

    Returns:
        list[str]: sentence texts with surrounding whitespace stripped;
        whitespace-only segments are dropped.
    """
    # Non-deprecated equivalent of disable_pipes(all) followed by
    # enable_pipe("senter"); list() snapshots the names before mutating.
    for pipe_name in list(model.pipe_names):
        if pipe_name != "senter":
            model.disable_pipe(pipe_name)
    model.enable_pipe("senter")

    doc = model(text)
    # Blank segments (e.g. runs of newlines) carry nothing to classify.
    return [sent.text.strip() for sent in doc.sents if sent.text.strip()]

def build_huggingface_path(language: str):
    """Map a UI language name to the Hugging Face model repository to use.

    "Czech" and "Slovakian" are served by the visegradmedia-emotion model;
    every other value falls back to the pooled poltextlab MORES model.
    """
    uses_visegrad_model = language in ("Czech", "Slovakian")
    if not uses_visegrad_model:
        return "poltextlab/xlm-roberta-large-pooled-MORES"
    return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"

# Loaded classifiers / tokenizers, keyed by Hugging Face repo id (see predict).
_MODEL_CACHE = {}
_TOKENIZER_CACHE = {}


def _get_model(model_id):
    """Load *model_id* once (in eval mode) and return the cached instance."""
    model = _MODEL_CACHE.get(model_id)
    if model is None:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder="offload",
            token=HF_TOKEN,
        )
        model.eval()
        _MODEL_CACHE[model_id] = model
    return model


def _get_tokenizer(tokenizer_id):
    """Load tokenizer *tokenizer_id* once and return the cached instance."""
    tokenizer = _TOKENIZER_CACHE.get(tokenizer_id)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        _TOKENIZER_CACHE[tokenizer_id] = tokenizer
    return tokenizer


def predict(text, model_id, tokenizer_id):
    """Return softmax class probabilities for a single piece of text.

    The model and tokenizer are loaded on first use and cached per id, so
    classifying many sentences does not reload them for each call. The cached
    models stay in memory for the life of the process.

    Args:
        text: the text to classify; tokenized input is truncated to 64 tokens.
        model_id: Hugging Face repo id of the sequence-classification model
            (loaded with ``token=HF_TOKEN``).
        tokenizer_id: Hugging Face repo id of the tokenizer.

    Returns:
        numpy.ndarray: 1-D array of class probabilities (the batch-of-one
        output flattened).
    """
    model = _get_model(model_id)
    tokenizer = _get_tokenizer(tokenizer_id)

    inputs = tokenizer(text,
                       max_length=64,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")

    # With device_map="auto" the weights may be placed on a GPU while the
    # tokenizer returns CPU tensors; move the inputs next to the model.
    # Skip "meta" (offloaded weights) — presumably accelerate's hooks place
    # the inputs themselves in that case.
    device = model.device
    if device.type != "meta":
        inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs

def get_most_probable_label(probs):
    """Return the top-scoring emotion and its probability as a percentage string.

    Args:
        probs: NumPy array of class probabilities indexed like ``id2label``.

    Returns:
        tuple[str, str]: the winning label, and its probability times 100
        rounded to 2 decimals followed by "%".
    """
    top_score = probs.max()
    percent = round(100 * top_score, 2)
    return id2label[probs.argmax()], f"{percent}%"


def prepare_heatmap_data(data, label_width=18):
    """Arrange per-sentence emotion probabilities into an emotions x sentences table.

    Args:
        data: list of dicts with keys "sentence" (str) and "emotions" (NumPy
            array of probabilities ordered like ``id2label``).
        label_width: maximum number of sentence characters kept in a column
            label; only sentences longer than this are cut and suffixed
            with "...".

    Returns:
        pandas.DataFrame indexed by emotion name with one column per sentence.
        Values are rounded to 4 decimals; cells without a probability stay 0.0.
    """
    emotions = list(id2label.values())
    heatmap_data = pd.DataFrame(0.0, index=emotions, columns=range(len(data)))

    for col, row in enumerate(data):
        for idy, confidence in enumerate(row["emotions"].tolist()):
            heatmap_data.at[id2label[idy], col] = round(confidence, 4)

    # Short sentences are shown whole; an ellipsis marks real truncation only.
    heatmap_data.columns = [
        sentence if len(sentence) <= label_width else sentence[:label_width] + "..."
        for sentence in (item["sentence"] for item in data)
    ]
    return heatmap_data

def plot_emotion_heatmap(heatmap_data):
    """Draw a heatmap of emotion confidences (rows) per sentence (columns).

    Args:
        heatmap_data: DataFrame with emotions as index and sentences as columns.

    Returns:
        matplotlib Figure whose width grows with the number of sentences.
    """
    fig, ax = plt.subplots(figsize=(len(heatmap_data.columns) * 0.5 + 4, 5))
    # Draw on an explicit Axes instead of pyplot's implicit "current" one.
    sns.heatmap(heatmap_data, annot=False, cmap="coolwarm", cbar=True,
                linewidths=0.5, linecolor='gray', ax=ax)
    ax.set_xlabel("Sentences")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45)
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    # Drop the figure from pyplot's registry so one figure per call does not
    # accumulate; the returned Figure can still be rendered/saved.
    plt.close(fig)
    return fig


def plot_emotion_barplot(heatmap_data):
    """Plot how often each emotion is the top prediction across sentences.

    Args:
        heatmap_data: DataFrame with emotion names as index and one column per
            sentence holding that sentence's per-emotion confidences.

    Returns:
        matplotlib Figure with one horizontal bar per emotion that wins at
        least one sentence, sorted by relative frequency (descending).
    """
    # Winning emotion of each sentence: idxmax down every column (axis=0)
    # yields the row label, i.e. the emotion name. axis=1 would instead pick,
    # for each emotion, the sentence it scored highest on.
    most_probable_emotions = heatmap_data.idxmax(axis=0)
    emotion_counts = most_probable_emotions.value_counts()

    # Normalize to get relative frequencies
    emotion_frequencies = (emotion_counts / emotion_counts.sum()).sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=emotion_frequencies.values, y=emotion_frequencies.index, palette="coolwarm", ax=ax)

    ax.set_title("Relative Frequencies of Most Probable Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")

    # Annotate each bar with its value just past the bar's end.
    for i, value in enumerate(emotion_frequencies.values):
        ax.text(value + 0.01, i, f"{value:.2f}", va='center')

    fig.tight_layout()
    # Drop the figure from pyplot's registry so one figure per request does not
    # accumulate; the returned Figure can still be rendered/saved.
    plt.close(fig)
    return fig

def predict_wrapper(text, language):
    """Gradio click handler: classify the emotion of every sentence in *text*.

    Args:
        text: raw user input from the textbox (may be empty).
        language: dropdown value; selects the model via build_huggingface_path.

    Returns:
        tuple: (rows, figure, info) — rows is a list of
        [sentence, label, percentage string] for the results table, figure is
        the emotion-frequency bar plot, and info is an HTML snippet linking to
        the model that was used.

    Raises:
        gr.Error: when *text* contains no non-blank sentence, so the user sees a
            message instead of the plotting code running on an empty table.
    """
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    # Blank segments (e.g. runs of newlines) carry nothing to classify.
    sentences = [s for s in split_sentences(text or "", spacy_model) if s.strip()]
    if not sentences:
        raise gr.Error("Please enter some text to analyse.")

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, figure, output_info


# Gradio UI: the text box and language dropdown feed predict_wrapper, whose
# three return values fill the results table, the plot and the model-info line.
with gr.Blocks() as demo:
    # Top row: free-text input on the left, language choice + submit on the right.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                # The selected language determines the model via build_huggingface_path.
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")

    # One row per sentence: [sentence, predicted emotion, confidence string].
    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True # important: let long sentences wrap inside their cells
        )

    # Receives the matplotlib figure returned by predict_wrapper.
    with gr.Row():
        plot = gr.Plot()

    # Receives the HTML link to the model repository used for the prediction.
    with gr.Row():
        model_info = gr.Markdown()

    # Outputs are matched positionally to predict_wrapper's return tuple.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, model_info]
    )

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()