Spaces:
Running
Running
File size: 5,742 Bytes
aa975e0 cec858f aa975e0 2d748e6 aa975e0 7c543b8 ac4450b aa975e0 575d2bc cec858f aa975e0 e3bfea3 aa975e0 9b8b66e aa975e0 796fe47 cec858f 796fe47 575d2bc 796fe47 b742f2b aa975e0 5b27629 ac4450b ed412e3 5b27629 f9b69e4 5b27629 3838a79 ed412e3 5b27629 ed412e3 ac4450b 8d02664 18c170e ed412e3 27ca4aa 98bd063 ac4450b e252a02 aa975e0 cec858f ac4450b cec858f 796fe47 2e2f372 ac4450b cec858f 0d8846e e252a02 e5fe883 ac4450b aa975e0 cec858f 3bdb0a7 9f40288 3bdb0a7 552f630 d71e26f 552f630 66daa12 95a86c7 552f630 169b92f 6f3f9f8 552f630 cec858f aa975e0 cec858f 169b92f cec858f aa975e0 cec858f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import functools
import os

import torch
import spacy
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
# Root directory that every cache location configured below points to.
PATH = '/data/' # at least 150GB storage needs to be attached
# Redirect the Hugging Face, datasets and torch caches to PATH.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported above; libraries that read them at import time may not pick
# them up -- confirm, and move the assignments before the imports if so.
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH
# Token passed to `from_pretrained` when loading the classifier. Read from the
# "hf_read" environment variable; a missing variable raises KeyError on import.
HF_TOKEN = os.environ["hf_read"]
# Sentiment class names; not referenced anywhere else in the visible code.
SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
# Choices offered by the language dropdown of the UI.
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]
# Maps a classifier output index to the emotion name shown to the user.
id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them"
}
def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load a spaCy pipeline, downloading its package first when needed.

    Args:
        model_name: Name of the spaCy package to load.

    Returns:
        The pipeline object returned by ``spacy.load``.

    Raises:
        OSError: If loading still fails after the download attempt.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # An OSError from the first load is treated as "package not
        # installed": fetch it once and fall through to a second load.
        spacy.cli.download(model_name)
    return spacy.load(model_name)
def split_sentences(text, model):
    """Split *text* into sentences with the pipeline's "senter" component.

    The pipeline is modified in place: every currently active component is
    disabled and only "senter" is switched back on before the text is
    processed.

    Args:
        text: The text to segment.
        model: A spaCy pipeline that contains a "senter" component.

    Returns:
        A list with the text of each detected sentence, in document order.
    """
    # Turn everything off first, then re-enable just the sentence splitter,
    # so no component that is unnecessary for segmentation runs.
    active_components = model.pipe_names
    model.disable_pipes(active_components)
    model.enable_pipe("senter")
    return [span.text for span in model(text).sents]
def build_huggingface_path(language: str):
    """Return the Hugging Face model id used for *language*.

    Czech and Slovakian use a dedicated model; every other value falls back
    to the pooled multilingual model.
    """
    uses_dedicated_model = language in ("Czech", "Slovakian")
    return (
        "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
        if uses_dedicated_model
        else "poltextlab/xlm-roberta-large-pooled-MORES"
    )
@functools.lru_cache(maxsize=2)
def _load_classifier(model_id):
    """Load the sequence-classification model for *model_id* in eval mode.

    Cached so repeated predictions reuse the loaded weights instead of
    calling ``from_pretrained`` again; the cache is bounded so only a
    couple of models are ever kept in memory at once.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    model.eval()
    return model


@functools.lru_cache(maxsize=2)
def _load_tokenizer(tokenizer_id):
    """Load (once per id) the tokenizer identified by *tokenizer_id*."""
    return AutoTokenizer.from_pretrained(tokenizer_id)


def predict(text, model_id, tokenizer_id):
    """Return the class probabilities the model assigns to *text*.

    The model and tokenizer are loaded on first use and reused afterwards,
    so calling this once per sentence no longer reloads them every time.

    Args:
        text: The text to classify; it is truncated to 64 tokens.
        model_id: Hugging Face id of the classification model.
        tokenizer_id: Hugging Face id of the tokenizer.

    Returns:
        A flat NumPy array holding the softmax of the model's logits.
    """
    model = _load_classifier(model_id)
    tokenizer = _load_tokenizer(tokenizer_id)

    inputs = tokenizer(text,
                       max_length=64,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs
def get_most_probable_label(probs):
    """Return the name and confidence of the highest-scoring class.

    Args:
        probs: Array-like of class probabilities exposing ``argmax`` and
            ``max``; positions are looked up in ``id2label``.

    Returns:
        A ``(label, probability)`` tuple where ``probability`` is the top
        score as a percentage rounded to two decimals, followed by "%".
    """
    winner = probs.argmax()
    percentage = round(100 * probs.max(), 2)
    return id2label[winner], f"{percentage}%"
def prepare_heatmap_data(data):
    """Arrange per-sentence emotion scores as an emotions-by-sentences table.

    Args:
        data: Sequence of dicts, each with a "sentence" string and an
            "emotions" array (anything offering ``tolist()``) whose
            positions correspond to the keys of ``id2label``.

    Returns:
        A DataFrame with one row per emotion name and one column per
        sentence. Cells hold the score rounded to four decimals; column
        labels are the first 18 characters of the sentence followed by
        "...".
    """
    emotion_names = list(id2label.values())
    table = pd.DataFrame(0.0, index=emotion_names, columns=range(len(data)))

    for column, entry in enumerate(data):
        for position, score in enumerate(entry["emotions"].tolist()):
            table.at[id2label[position], column] = round(score, 4)

    # Replace the positional column labels with short sentence previews.
    table.columns = [entry["sentence"][:18] + "..." for entry in data]
    return table
def plot_emotion_heatmap(heatmap_data):
    """Draw the emotion scores as a heatmap and return the figure.

    Args:
        heatmap_data: DataFrame of scores; its columns go on the x axis
            (labelled "Sentences") and its rows on the y axis (labelled
            "Emotions").

    Returns:
        The matplotlib figure containing the heatmap.
    """
    # Widen the canvas by half an inch per column so labels stay readable.
    width = len(heatmap_data.columns) * 0.5 + 4
    fig = plt.figure(figsize=(width, 5))

    sns.heatmap(
        heatmap_data,
        annot=False,
        cmap="coolwarm",
        cbar=True,
        linewidths=0.5,
        linecolor='gray',
    )

    plt.xlabel("Sentences")
    plt.ylabel("Emotions")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)

    fig.subplots_adjust(left=0.2, right=0.95, top=0.9, bottom=0.2)
    fig.tight_layout()
    return fig
def plot_emotion_barplot(heatmap_data):
    """Plot how often each emotion is the top prediction across sentences.

    Args:
        heatmap_data: DataFrame laid out as produced by
            ``prepare_heatmap_data``: one row per emotion, one column per
            sentence, cells holding the score.

    Returns:
        The matplotlib figure with one horizontal bar per emotion that won
        at least one sentence, annotated with its relative frequency.
    """
    # Rows are emotions and columns are sentences, so the winning emotion of
    # each sentence is the row label of the column-wise maximum (axis=0).
    # axis=1 would instead return the top *sentence* for each emotion.
    most_probable_emotions = heatmap_data.idxmax(axis=0)
    emotion_counts = most_probable_emotions.value_counts()

    # Normalize to get relative frequencies
    emotion_frequencies = (emotion_counts / emotion_counts.sum()).sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=emotion_frequencies.values, y=emotion_frequencies.index, palette="coolwarm", ax=ax)
    ax.set_title("Relative Frequencies of Most Probable Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")

    # Print each value just to the right of the end of its bar.
    for i, value in enumerate(emotion_frequencies.values):
        ax.text(value + 0.01, i, f"{value:.2f}", va='center')

    plt.tight_layout()
    return fig
def predict_wrapper(text, language):
    """Classify the emotion of every sentence in *text* for the UI.

    Args:
        text: Raw user input; it is split into sentences before prediction.
        language: Language name used to select the model.

    Returns:
        A tuple ``(results, figure, output_info)``: ``results`` is a list of
        ``[sentence, label, probability]`` rows, ``figure`` is the bar plot
        of the most probable emotions, and ``output_info`` is an HTML
        snippet naming the model that was used.

    Raises:
        gr.Error: If *text* is empty or contains no sentence to classify.
    """
    # Without any sentence the plotting step would fail with an opaque
    # error, so reject empty input with a message the UI can display.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text, spacy_model)
    if not sentences:
        raise gr.Error("Please enter some text to analyze.")

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, figure, output_info
# UI layout: input text and controls on top, then the per-sentence result
# table, the emotion plot and a note about which model was used.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=6,
                label="Input",
                placeholder="Enter your text here...",
            )
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(
                    choices=LANGUAGES,
                    label="Language",
                    value="English",
                )
            with gr.Row():
                predict_button = gr.Button("Submit")

    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True,  # important
        )

    with gr.Row():
        plot = gr.Plot()

    with gr.Row():
        model_info = gr.Markdown()

    # Run the prediction on click and fan its three return values out to
    # the table, the plot and the model note.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, model_info],
    )

if __name__ == "__main__":
    demo.launch()
|