# Hugging Face Space (status: Running)
import os | |
import torch | |
import spacy | |
import spaces | |
import numpy as np | |
import pandas as pd | |
from transformers import AutoModelForSequenceClassification | |
from transformers import AutoTokenizer | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import LinearSegmentedColormap | |
import matplotlib.colors as mcolors | |
import plotly.express as px | |
import seaborn as sns | |
# Root directory shared by every model/dataset cache used by this app.
PATH = '/data/' # at least 150GB storage needs to be attached

# Point the Hugging Face and PyTorch cache locations at PATH.
# NOTE(review): these variables are assigned after `torch` and `transformers`
# were imported at the top of the file; libraries that read their cache
# location at import time may not pick them up -- confirm.
os.environ.update(
    dict.fromkeys(
        ('TRANSFORMERS_CACHE', 'HF_HOME', 'HF_DATASETS_CACHE', 'TORCH_HOME'),
        PATH,
    )
)

# Hub access token; a missing "hf_read" variable raises KeyError at import.
HF_TOKEN = os.environ["hf_read"]
# Class index -> display name for sentiment predictions.
SENTIMENT_LABEL_NAMES = dict(
    enumerate(("Negative", "No sentiment or Neutral sentiment", "Positive"))
)

# Languages offered in the UI dropdown.
LANGUAGES = [
    "Czech",
    "English",
    "French",
    "German",
    "Hungarian",
    "Polish",
    "Slovakian",
]

# Class index -> emotion name, in the order of the classifier outputs.
id2label = dict(
    enumerate(("Anger", "Fear", "Disgust", "Sadness", "Joy", "None of Them"))
)

# Emotion name -> base hex color used by the plots.
emotion_colors = dict(
    [
        ("Anger", "#D96459"),
        ("Fear", "#6A8EAE"),
        ("Disgust", "#A4C639"),
        ("Sadness", "#9DBCD4"),
        ("Joy", "#F3E9A8"),
        ("None of Them", "#C0C0C0"),
    ]
)
def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load the spaCy pipeline *model_name*, downloading it if loading fails.

    An ``OSError`` from ``spacy.load`` is treated as "package not available":
    the model is downloaded through ``spacy.cli.download`` and loaded again.
    Errors from that second attempt propagate to the caller.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # First load failed: fetch the package, then retry exactly once.
        spacy.cli.download(model_name)
        return spacy.load(model_name)
def split_sentences(text, model):
    """Split *text* into sentence strings using *model*'s ``senter`` pipe.

    The pipeline is modified in place: every currently listed pipe is
    disabled and only ``"senter"`` is switched back on before *text* is
    processed.

    Returns a list with the text of each sentence, in document order.
    """
    # Turn everything off first, then re-enable just the sentence splitter.
    model.disable_pipes(model.pipe_names)
    model.enable_pipe("senter")
    parsed = model(text)
    return [span.text for span in parsed.sents]
def build_huggingface_path(language: str):
    """Return the Hugging Face model id used for *language*.

    "Czech" and "Slovakian" map to the Visegrad media model; every other
    value falls back to the pooled XLM-RoBERTa emotion model.
    """
    visegrad_languages = ("Czech", "Slovakian")
    if language in visegrad_languages:
        return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
    return "poltextlab/xlm-roberta-large-pooled-emotions6"
# Loaded (model, tokenizer) pairs keyed by (model_id, tokenizer_id), so each
# checkpoint is read once per process instead of once per sentence.
_LOADED_MODELS = {}


def _get_model_and_tokenizer(model_id, tokenizer_id):
    """Return the (model, tokenizer) pair for the given ids, loading on first use."""
    cache_key = (model_id, tokenizer_id)
    if cache_key not in _LOADED_MODELS:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder="offload",
            token=HF_TOKEN,
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        model.eval()
        _LOADED_MODELS[cache_key] = (model, tokenizer)
    return _LOADED_MODELS[cache_key]


def predict(text, model_id, tokenizer_id):
    """Classify *text* and return the class probabilities.

    Args:
        text: Input passed to the tokenizer; it is truncated to 64 tokens
            and not padded.
        model_id: Hub id of the sequence-classification checkpoint.
        tokenizer_id: Hub id of the tokenizer.

    Returns:
        A flattened NumPy array with the softmax of the model logits.

    The model and tokenizer are loaded on the first call for a given
    ``(model_id, tokenizer_id)`` pair and reused afterwards; previously both
    were re-created with ``from_pretrained`` on every call.
    """
    model, tokenizer = _get_model_and_tokenizer(model_id, tokenizer_id)

    inputs = tokenizer(text,
                       max_length=64,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs
def get_most_probable_label(probs):
    """Return the top emotion name and its probability as a percent string.

    *probs* must provide ``argmax()`` and ``max()``; the index returned by
    ``argmax()`` is looked up in ``id2label``. The percentage is rounded to
    two decimals and suffixed with ``%``.
    """
    top_index = probs.argmax()
    top_percent = round(100 * probs.max(), 2)
    return id2label[top_index], f"{top_percent}%"
def prepare_heatmap_data(data):
    """Build an emotions-by-sentences DataFrame of rounded confidences.

    Args:
        data: Sequence of dicts, each with a ``"sentence"`` string and an
            ``"emotions"`` object exposing ``tolist()`` (one score per class
            index of ``id2label``).

    Returns:
        A DataFrame whose index holds the emotion names and whose columns
        are the first 18 characters of each sentence followed by ``"..."``.
        Scores are rounded to 4 decimals; cells never written stay ``0.0``.
    """
    emotion_names = list(id2label.values())
    frame = pd.DataFrame(0.0, index=emotion_names, columns=range(len(data)))

    for column, entry in enumerate(data):
        for class_id, score in enumerate(entry["emotions"].tolist()):
            frame.at[id2label[class_id], column] = round(score, 4)

    # Swap the positional column ids for truncated sentence previews.
    frame.columns = [entry["sentence"][:18] + "..." for entry in data]
    return frame
def plot_emotion_heatmap(heatmap_data):
    """Draw a sentence-by-emotion heatmap colored per emotion.

    Args:
        heatmap_data: DataFrame with emotion names as index and one column
            per sentence (the column labels become the y tick labels).

    Returns:
        The matplotlib ``Figure`` holding the heatmap.

    Each sentence is scaled so that its strongest emotion has intensity 1
    (a sentence whose maximum is not positive gets intensity 0 everywhere).
    Every cell blends from white to the emotion's base color by that
    intensity.

    The computation is positional rather than label-based: sentence labels
    are truncated previews and can repeat, and ``.loc`` lookups on a
    duplicated label return a DataFrame, which made the previous
    ``max_val > 0`` check raise.
    """
    # Transpose: now rows = sentences, columns = emotions
    per_sentence = heatmap_data.T
    sentence_labels = list(per_sentence.index)
    emotion_names = list(per_sentence.columns)
    scores = per_sentence.to_numpy(dtype=float)

    # Normalize each row (sentence-wise); rows without a positive max stay 0.
    row_max = scores.max(axis=1, keepdims=True)
    intensity = np.divide(
        scores, row_max, out=np.zeros_like(scores), where=row_max > 0
    )

    # Create color matrix: blend from white to each emotion's base color.
    base_rgb = np.array(
        [mcolors.to_rgb(emotion_colors[name]) for name in emotion_names]
    )
    color_matrix = 1.0 - intensity[:, :, np.newaxis] * (1.0 - base_rgb[np.newaxis, :, :])

    fig, ax = plt.subplots(
        figsize=(len(emotion_names) * 0.8 + 2, len(sentence_labels) * 0.5 + 2)
    )
    ax.imshow(color_matrix, aspect='auto')

    # Set ticks and labels
    ax.set_xticks(np.arange(len(emotion_names)))
    ax.set_xticklabels(emotion_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticks(np.arange(len(sentence_labels)))
    ax.set_yticklabels(sentence_labels, rotation=0, fontsize=10)
    ax.set_xlabel("Emotions")
    ax.set_ylabel("Sentences")

    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
def plot_average_emotion_pie(heatmap_data):
    """Draw a pie chart of the mean emotion score across all sentences.

    Args:
        heatmap_data: Sequence of dicts whose ``'emotions'`` entry holds the
            per-class scores of one sentence (despite the parameter name,
            this is the raw result list, not a DataFrame).

    Returns:
        The matplotlib ``Figure`` containing the pie chart.
    """
    score_rows = np.array([entry['emotions'] for entry in heatmap_data])
    mean_scores = score_rows.mean(axis=0)
    labels = [id2label[position] for position, _ in enumerate(mean_scores)]

    # optional: remove emotions with near-zero average
    # (the 0.01 threshold can be changed)
    kept = [(name, share) for name, share in zip(labels, mean_scores) if share > 0.01]
    labels_filtered = [name for name, _ in kept]
    sizes_filtered = [share for _, share in kept]
    slice_colors = [emotion_colors[name] for name in labels_filtered]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pie(
        sizes_filtered,
        labels=labels_filtered,
        autopct='%1.1f%%',
        startangle=140,
        textprops={'fontsize': 12},
        colors=slice_colors,
    )
    ax.axis('equal')  # Equal aspect ratio ensures a circle
    plt.title("Average Emotion Confidence Across Sentences", fontsize=14)
    return fig
def plot_emotion_barplot(heatmap_data):
    """Draw a horizontal bar plot of how often each emotion is the top one.

    Args:
        heatmap_data: DataFrame with emotion names as index and one column
            per sentence.

    Returns:
        The matplotlib ``Figure`` with one bar per emotion, sorted by
        descending relative frequency. Emotions that never win are shown
        with frequency 0.
    """
    # Winning emotion (index label of the max) for every sentence column.
    winners = heatmap_data.idxmax(axis=0)
    win_counts = winners.value_counts()
    total_wins = win_counts.sum()

    shares = win_counts.reindex(heatmap_data.index, fill_value=0) / total_wins
    emotion_frequencies = shares.sort_values(ascending=False)
    bar_colors = [emotion_colors[name] for name in emotion_frequencies.index]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(
        x=emotion_frequencies.values,
        y=emotion_frequencies.index,
        palette=bar_colors,
        ax=ax,
    )
    ax.set_title("Relative Frequencies of Predicted Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")
    plt.tight_layout()
    return fig
def predict_wrapper(text, language):
    """Run sentence-level emotion prediction and build all UI outputs.

    Args:
        text: Raw input text from the UI.
        language: One of the dropdown languages; selects the model.

    Returns:
        A 5-tuple ``(results, figure, piechart, heatmap, output_info)``:
        the table rows ``[sentence, label, confidence]``, the bar plot,
        the pie chart, the heatmap, and an HTML snippet naming the model.

    Raises:
        gr.Error: If *text* is empty or yields no sentences. Without this
            guard the plotting helpers failed later with an unrelated error.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text, spacy_model)
    if not sentences:
        raise gr.Error("Please enter some text to analyze.")

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    # Build the emotions-by-sentences frame once and share it between plots.
    heatmap_frame = prepare_heatmap_data(results_heatmap)
    figure = plot_emotion_barplot(heatmap_frame)
    heatmap = plot_emotion_heatmap(heatmap_frame)
    piechart = plot_average_emotion_pie(results_heatmap)

    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, figure, piechart, heatmap, output_info
# Gradio UI: an input row (text + language + submit button), followed by one
# row per output (table, bar plot, pie chart, heatmap), each paired with a
# side column of placeholder text, and a final row naming the model used.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation -- confirm the layout against the running app.
with gr.Blocks() as demo:
    # Filler text shown next to each output until real descriptions exist.
    placeholder = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
    # Input row: text box on the left, language selector and button on the right.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")
    # Per-sentence prediction table.
    with gr.Row():
        with gr.Column(scale=7):
            result_table = gr.Dataframe(
                headers=["Sentence", "Prediction", "Confidence"],
                column_widths=["65%", "25%", "10%"],
                wrap=True # important
            )
        with gr.Column(scale=3):
            gr.Markdown(placeholder)
    # Bar plot of relative emotion frequencies.
    with gr.Row():
        with gr.Column(scale=7):
            plot = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)
    # Pie chart of average emotion confidence.
    with gr.Row():
        with gr.Column(scale=7):
            piechart = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)
    # Sentence-by-emotion heatmap.
    with gr.Row():
        with gr.Column(scale=7):
            heatmap = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(placeholder)
    # Note naming the model that produced the prediction.
    with gr.Row():
        model_info = gr.Markdown()
    # The outputs list follows the order of predict_wrapper's return tuple.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, piechart, heatmap, model_info]
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()