File size: 8,899 Bytes
158b5a1
 
fb05782
315aa5c
158b5a1
6b24c4d
158b5a1
 
 
cb8ef04
467c2e7
0f6522c
7780172
3fd2db3
158b5a1
 
 
 
 
 
 
 
 
 
 
 
50c2025
 
 
 
 
 
 
 
7e66e8d
 
 
 
 
 
 
 
 
fb05782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158b5a1
 
 
 
2fe6a76
158b5a1
315aa5c
158b5a1
 
 
 
 
c1f3033
158b5a1
 
f61fe0b
158b5a1
 
 
 
 
 
b82d609
fb05782
b82d609
50c2025
b82d609
30624e1
158b5a1
40de5cf
3fd2db3
db0e152
40de5cf
0bade3d
 
40de5cf
09dd2f7
db0e152
40de5cf
db0e152
3fd2db3
 
ed10ca1
433b160
 
 
 
3ad91fd
433b160
 
 
0f6522c
433b160
0f6522c
433b160
 
 
 
 
 
 
 
 
0f6522c
 
433b160
0f6522c
433b160
 
0f6522c
433b160
3ad91fd
433b160
 
0efc44c
433b160
3fd2db3
 
ccd8c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e66e8d
 
ccd8c01
 
 
 
 
 
 
acd22e7
8deb410
acd22e7
aba39cd
 
7e66e8d
acd22e7
7e66e8d
b444720
acd22e7
 
 
 
 
158b5a1
 
 
 
fb05782
 
 
 
3fd2db3
fb05782
b82d609
8383fbb
 
3fd2db3
fb05782
6ba52eb
 
 
 
acd22e7
f382753
2c43ece
20f3f13
ccd8c01
3fd2db3
158b5a1
 
8106682
fb05782
23bf035
3c4b3e3
23bf035
 
 
 
 
f7c3109
cc6aa04
8106682
 
 
 
 
 
 
 
6940c7c
8106682
 
 
 
6940c7c
416f5a0
8106682
 
 
 
416f5a0
f382753
8106682
 
 
 
f382753
f516e56
f7c3109
fb05782
 
158b5a1
fb05782
ccd8c01
fb05782
158b5a1
 
 
fb05782
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import os
import torch
import spacy
import spaces
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as mcolors
import plotly.express as px
import seaborn as sns

# Cache root for Hugging Face / PyTorch downloads; the environment variables
# below point the library caches at this directory.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported above. If a library reads them at import time they will not
# take effect. Consider setting them before the imports (verify).
PATH = '/data/' # at least 150GB storage needs to be attached
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH

# Hugging Face access token taken from the "hf_read" environment variable.
# Indexing os.environ raises KeyError at import time if it is unset. The token
# is passed to from_pretrained when the classification model is loaded.
HF_TOKEN = os.environ["hf_read"]

# Sentiment class names.
# NOTE(review): not referenced anywhere in the code visible here.
SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
# Languages offered in the UI dropdown; build_huggingface_path maps each one
# to a classification model.
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]

# Output-class index -> emotion name. The index order presumably matches the
# label order of the classification models (verify against their configs).
id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them"
}

# Display colour (hex) for every emotion name in id2label; used by the plots.
emotion_colors = {
    "Anger": "#D96459",
    "Fear": "#6A8EAE",
    "Disgust": "#A4C639",
    "Sadness": "#9DBCD4",
    "Joy": "#F3E9A8",
    "None of Them": "#C0C0C0"
}


def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load a spaCy pipeline by package name, downloading it on first use.

    Args:
        model_name: Name of the spaCy package to load. Defaults to
            ``"xx_sent_ud_sm"``.

    Returns:
        The loaded spaCy ``Language`` object.

    Any error raised by the second ``spacy.load`` attempt (after the
    download) propagates to the caller.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # spacy.load raises OSError when the package is not available locally.
        # Fetch it once, then fall through to a second load attempt.
        spacy.cli.download(model_name)
    return spacy.load(model_name)

def split_sentences(text, model):
    """Split *text* into sentence strings using the model's ``senter`` pipe.

    Only the ``senter`` component runs. Every other enabled pipe is switched
    off for the duration of the call and restored afterwards, so the caller's
    pipeline is left as it was, except that ``senter`` ends up enabled.

    Args:
        text: Raw input text.
        model: A loaded spaCy ``Language`` that contains a ``senter`` pipe.

    Returns:
        A list with the text of each detected sentence, in document order.
    """
    # select_pipes() can only disable pipes, so make sure senter itself is on.
    # enable_pipe is a no-op when it is already enabled.
    model.enable_pipe("senter")
    # select_pipes replaces the deprecated disable_pipes(). Used as a context
    # manager, it re-enables the pipes it switched off when the block exits.
    with model.select_pipes(enable="senter"):
        doc = model(text)
    return [sent.text for sent in doc.sents]

def build_huggingface_path(language: str):
    """Return the Hugging Face Hub model id to use for *language*.

    Czech and Slovakian use a dedicated checkpoint. Every other language
    (matched case-sensitively) uses the pooled multilingual model.
    """
    uses_dedicated_model = language in ("Czech", "Slovakian")
    return (
        "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
        if uses_dedicated_model
        else "poltextlab/xlm-roberta-large-pooled-emotions6"
    )

@spaces.GPU
def predict(text, model_id, tokenizer_id):
    """Return the model's class probabilities for a single piece of text.

    Loads the sequence-classification model *model_id* (authenticated with
    ``HF_TOKEN``) and the tokenizer *tokenizer_id*. It then tokenizes *text*,
    truncated to 64 tokens, and applies a softmax over the output logits.

    Args:
        text: The text to classify (one sentence in this app).
        model_id: Hugging Face Hub id of the classification model.
        tokenizer_id: Hugging Face Hub id of the tokenizer.

    Returns:
        A flattened 1-D NumPy array of class probabilities.
    """
    # NOTE(review): model and tokenizer are re-loaded on every call (i.e. once
    # per sentence). This is slow; consider loading them once if the GPU
    # runtime allows it.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    inputs = tokenizer(text,
                       max_length=64,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")
    # The tokenizer returns CPU tensors, but device_map="auto" may place the
    # model on a GPU. Move the inputs next to the model's parameters so the
    # forward pass does not fail with a device mismatch.
    inputs = inputs.to(model.device)
    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs

def get_most_probable_label(probs):
    """Pick the highest-scoring class from a probability vector.

    Args:
        probs: NumPy array of class probabilities, indexed like ``id2label``.

    Returns:
        A ``(label, probability)`` tuple. ``label`` is the emotion name of the
        arg-max class. ``probability`` is its probability as a percentage
        rounded to two decimals, formatted as a string ending in ``"%"``.
    """
    top_index = probs.argmax()
    top_percent = round(100 * probs.max(), 2)
    return id2label[top_index], f"{top_percent}%"


def prepare_heatmap_data(data):
    """Tabulate per-sentence emotion confidences as a DataFrame.

    Args:
        data: Sequence of dicts, each with a ``"sentence"`` string and an
            ``"emotions"`` array of class probabilities indexed like
            ``id2label``.

    Returns:
        A DataFrame with one row per emotion name (in ``id2label`` order) and
        one column per entry of *data*. Values are confidences rounded to 4
        decimals; classes absent from an ``"emotions"`` array stay at 0.0.
        Column labels are the first 18 characters of each sentence followed
        by ``"..."``, so they are not guaranteed to be unique.
    """
    table = pd.DataFrame(0.0, index=id2label.values(), columns=range(len(data)))

    for column, entry in enumerate(data):
        for class_index, score in enumerate(entry["emotions"].tolist()):
            table.at[id2label[class_index], column] = round(score, 4)

    table.columns = [f"{entry['sentence'][:18]}..." for entry in data]
    return table

def plot_emotion_heatmap(heatmap_data):
    """Draw a sentence x emotion heatmap, colouring each cell by its emotion.

    Each sentence's confidences are divided by that sentence's maximum, so its
    strongest emotion is drawn at full colour. The other cells are blended
    towards white in proportion. Rows whose maximum is not positive are drawn
    white.

    Args:
        heatmap_data: DataFrame with emotions as rows and sentences as columns
            (as produced by ``prepare_heatmap_data``). Row labels must be keys
            of ``emotion_colors``. Column (sentence) labels may repeat.

    Returns:
        The matplotlib ``Figure``. It is already closed in pyplot so repeated
        calls do not accumulate open figures, but it can still be saved or
        rendered.
    """
    # Transpose: rows = sentences, columns = emotions.
    table = heatmap_data.T
    values = table.to_numpy(dtype=float)

    # Normalise row-wise by position rather than via .loc. Sentences sharing
    # the same truncated label would otherwise turn a row lookup into a
    # multi-row DataFrame and make `max_val > 0` ambiguous.
    row_max = values.max(axis=1, keepdims=True)
    normalized = np.divide(values, row_max, out=np.zeros_like(values), where=row_max > 0)

    # Blend each cell from white (0) to its emotion's base colour (1):
    # channel = 1 - value * (1 - base_channel).
    base_colors = np.array([mcolors.to_rgb(emotion_colors[emotion]) for emotion in table.columns])
    color_matrix = 1 - normalized[:, :, np.newaxis] * (1 - base_colors[np.newaxis, :, :])

    n_sentences, n_emotions = values.shape
    fig, ax = plt.subplots(figsize=(n_emotions * 0.8 + 2, n_sentences * 0.5 + 2))
    ax.imshow(color_matrix, aspect='auto')

    # Set ticks and labels
    ax.set_xticks(np.arange(n_emotions))
    ax.set_xticklabels(table.columns, rotation=45, ha='right', fontsize=10)

    ax.set_yticks(np.arange(n_sentences))
    ax.set_yticklabels(table.index, rotation=0, fontsize=10)

    ax.set_xlabel("Emotions")
    ax.set_ylabel("Sentences")

    fig.tight_layout()
    # Release pyplot's reference; otherwise every request leaks one figure.
    plt.close(fig)
    return fig

def plot_average_emotion_pie(heatmap_data):
    """Draw a pie chart of the mean confidence per emotion across sentences.

    Args:
        heatmap_data: Sequence of dicts whose ``"emotions"`` entry is an array
            of class probabilities indexed like ``id2label``. This is the
            per-sentence results list, not the DataFrame returned by
            ``prepare_heatmap_data``.

    Returns:
        The matplotlib ``Figure``. It is already closed in pyplot so repeated
        calls do not accumulate open figures, but it can still be rendered.
    """
    all_emotion_scores = np.array([item['emotions'] for item in heatmap_data])
    mean_scores = all_emotion_scores.mean(axis=0)

    labels = [id2label[i] for i in range(len(mean_scores))]
    # Drop emotions whose average confidence is negligible (<= 1%) so they do
    # not appear as unreadable slivers. Adjust the threshold as needed.
    kept = [(label, size) for label, size in zip(labels, mean_scores) if size > 0.01]
    labels_filtered = [label for label, _ in kept]
    sizes_filtered = [size for _, size in kept]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pie(
        sizes_filtered,
        labels=labels_filtered,
        autopct='%1.1f%%',
        startangle=140,
        textprops={'fontsize': 12},
        colors=[emotion_colors[label] for label in labels_filtered],
    )

    ax.axis('equal')  # Equal aspect ratio ensures a circle
    ax.set_title("Average Emotion Confidence Across Sentences", fontsize=14)

    # Release pyplot's reference; otherwise every request leaks one figure.
    plt.close(fig)
    return fig

def plot_emotion_barplot(heatmap_data):
    """Plot how often each emotion is the top prediction across sentences.

    Args:
        heatmap_data: DataFrame with emotions as rows and sentences as columns
            (as produced by ``prepare_heatmap_data``). Row labels must be keys
            of ``emotion_colors``.

    Returns:
        The matplotlib ``Figure``. It is already closed in pyplot so repeated
        calls do not accumulate open figures, but it can still be rendered.
    """
    # Top emotion of every sentence: column-wise arg-max over the emotion rows.
    most_probable_emotions = heatmap_data.idxmax(axis=0)
    emotion_counts = most_probable_emotions.value_counts()
    # Include emotions that never won (count 0), convert counts to shares and
    # order them from most to least frequent.
    emotion_frequencies = (
        emotion_counts.reindex(heatmap_data.index, fill_value=0) / emotion_counts.sum()
    ).sort_values(ascending=False)
    palette = [emotion_colors[emotion] for emotion in emotion_frequencies.index]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=emotion_frequencies.values, y=emotion_frequencies.index, palette=palette, ax=ax)
    ax.set_title("Relative Frequencies of Predicted Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")
    fig.tight_layout()

    # Release pyplot's reference; otherwise every request leaks one figure.
    plt.close(fig)
    return fig

def predict_wrapper(text, language):
    """Run sentence-level emotion classification and build every UI output.

    Args:
        text: Raw user input; it is split into sentences with the spaCy model.
        language: One of ``LANGUAGES``; selects the classification model.

    Returns:
        Tuple ``(results, figure, piechart, heatmap, output_info)`` in the
        order of the Gradio outputs:
        - ``results``: rows of ``[sentence, label, confidence]``.
        - ``figure``, ``piechart``, ``heatmap``: three matplotlib figures
          (bar plot, pie chart and heatmap).
        - ``output_info``: an HTML snippet naming the model used.

    Raises:
        gr.Error: If the input contains no text to analyze.
    """
    # Fail early with a readable UI message. With no sentences, the plotting
    # helpers below would crash on empty data (e.g. taking len() of a NaN mean).
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text, spacy_model)
    if not sentences:
        raise gr.Error("Please enter some text to analyze.")

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    # Debug output to the server log.
    print(results)
    print(results_heatmap)

    # Build the emotion x sentence table once; neither plot mutates it.
    heatmap_table = prepare_heatmap_data(results_heatmap)
    figure = plot_emotion_barplot(heatmap_table)
    heatmap = plot_emotion_heatmap(heatmap_table)
    piechart = plot_average_emotion_pie(results_heatmap)
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, figure, piechart, heatmap, output_info


# Gradio UI. Layout, top to bottom:
# - an input row with the text box, the language dropdown and the Submit button;
# - four result rows (results table, bar plot, pie chart, heatmap), each with
#   a side Markdown panel;
# - a final Markdown line naming the model used.
# Components are placed in the order they are created inside the context
# managers, so statement order here defines the page layout.
with gr.Blocks() as demo:
    # Filler text shown in the side panels next to each output.
    # NOTE(review): looks like placeholder copy awaiting real explanations.
    placeholder = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")

    # Per-sentence results: [sentence, predicted emotion, confidence string].
    with gr.Row():
        with gr.Column(scale=7):
            result_table = gr.Dataframe(
                headers=["Sentence", "Prediction", "Confidence"],
                column_widths=["65%", "25%", "10%"],
                wrap=True # important
            )
        with gr.Column(scale=3):
            gr.Markdown(placeholder)
    # Bar plot of how often each emotion is the top prediction.
    with gr.Row():
        with gr.Column(scale=7):
            plot = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)

    # Pie chart of the average confidence per emotion.
    with gr.Row():
        with gr.Column(scale=7):
            piechart = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)

    # Sentence x emotion heatmap.
    with gr.Row():
        with gr.Column(scale=7):
            heatmap = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)

    # Receives the "Prediction was made using ..." link to the model used.
    with gr.Row():
        model_info = gr.Markdown()

    # The outputs list must stay in the same order as predict_wrapper's return
    # tuple: (results, bar plot figure, pie chart, heatmap, model info).
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, piechart, heatmap, model_info]
    )

if __name__ == "__main__":
    demo.launch()