# Hugging Face Space source file (Space status when captured: Sleeping; file size: 5,642 bytes).
import os
import torch
import spacy
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
# Root directory used for every model/data cache configured below.
PATH = '/data/' # at least 150GB storage needs to be attached
# Point the Hugging Face and torch cache environment variables at PATH.
# NOTE(review): these are assigned after `transformers` has already been
# imported at the top of the file; confirm the libraries still honour them
# when set this late.
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH
# Access token read from the `hf_read` environment variable; indexing
# os.environ raises KeyError at import time when the variable is unset.
HF_TOKEN = os.environ["hf_read"]
# NOTE(review): not referenced anywhere else in the visible code.
SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
# Choices offered by the language dropdown in the UI.
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]
# Maps a class index (position in the probability vector) to its emotion name.
id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them"
}
def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load the spaCy pipeline `model_name`, downloading it on demand.

    The pipeline is loaded directly when possible. If loading raises
    ``OSError``, the package is downloaded through ``spacy.cli.download``
    and the load is attempted once more.

    Args:
        model_name: Name of the spaCy pipeline package to load.

    Returns:
        The loaded spaCy pipeline object.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # The load failed: fetch the package, then fall through to the retry.
        spacy.cli.download(model_name)
    return spacy.load(model_name)
def split_sentences(text, model):
    """Split `text` into sentence strings using the pipeline's ``senter`` pipe.

    Note that `model` is modified: every pipe currently listed in
    ``model.pipe_names`` is disabled and only ``"senter"`` is re-enabled, and
    the pipeline is left in that state afterwards.

    Args:
        text: The raw text to segment.
        model: A spaCy pipeline exposing ``pipe_names``, ``disable_pipes``,
            ``enable_pipe`` and a ``"senter"`` component.

    Returns:
        A list with the text of each sentence, in document order.
    """
    # Switch everything off first, then bring back only the sentence splitter.
    model.disable_pipes(model.pipe_names)
    model.enable_pipe("senter")

    pieces = []
    for span in model(text).sents:
        pieces.append(span.text)
    return pieces
def build_huggingface_path(language: str):
    """Return the Hugging Face model id used for `language`.

    Czech and Slovakian share one model; every other value (including
    unknown languages) falls back to the pooled XLM-RoBERTa model.

    Args:
        language: Language name as offered by the UI, e.g. ``"English"``.

    Returns:
        The repository id of the model to load.
    """
    if language in ("Czech", "Slovakian"):
        return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
    return "poltextlab/xlm-roberta-large-pooled-MORES"
# Loaded models/tokenizers keyed by their Hugging Face id. `predict` is called
# once per sentence, so without these caches every sentence would trigger a
# full `from_pretrained` load of the model and the tokenizer.
_MODEL_CACHE = {}
_TOKENIZER_CACHE = {}


def _get_model(model_id):
    """Return the classification model for `model_id`, loading it on first use."""
    model = _MODEL_CACHE.get(model_id)
    if model is None:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder="offload",
            token=HF_TOKEN,
        )
        # Inference only: switch to evaluation mode once, at load time.
        model.eval()
        _MODEL_CACHE[model_id] = model
    return model


def _get_tokenizer(tokenizer_id):
    """Return the tokenizer for `tokenizer_id`, loading it on first use."""
    tokenizer = _TOKENIZER_CACHE.get(tokenizer_id)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        _TOKENIZER_CACHE[tokenizer_id] = tokenizer
    return tokenizer


def predict(text, model_id, tokenizer_id):
    """Classify `text` and return the softmax probabilities of the model.

    The model and tokenizer are loaded the first time their id is requested
    and reused on later calls.

    Args:
        text: The text to classify; input is truncated to 64 tokens.
        model_id: Hugging Face id of the sequence-classification model.
        tokenizer_id: Hugging Face id of the tokenizer.

    Returns:
        A flattened numpy array holding the softmax over the model's logits
        (presumably one entry per class for a single input text).
    """
    model = _get_model(model_id)
    tokenizer = _get_tokenizer(tokenizer_id)

    inputs = tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs
def get_most_probable_label(probs):
    """Return the top emotion and its confidence as a percentage string.

    Args:
        probs: Array of class probabilities; positions correspond to the
            keys of ``id2label``.

    Returns:
        A ``(label, probability)`` tuple where ``label`` is the emotion name
        of the highest-scoring class and ``probability`` is that score times
        100, rounded to two decimals and suffixed with ``%``.
    """
    top_index = probs.argmax()
    top_percent = round(100 * probs.max(), 2)
    return id2label[top_index], f"{top_percent}%"
def prepare_heatmap_data(data):
    """Build an emotions-by-sentences table of confidences.

    Args:
        data: Sequence of dicts, each with a ``"sentence"`` string and an
            ``"emotions"`` array of per-class confidences (anything that
            provides ``.tolist()``).

    Returns:
        A DataFrame with one row per emotion name in ``id2label`` and one
        column per sentence. Cells hold the confidence rounded to four
        decimals. Each column label is the first 18 characters of the
        sentence followed by ``"..."`` (appended even for shorter sentences).
    """
    frame = pd.DataFrame(0.0, index=id2label.values(), columns=range(len(data)))

    # Fill cell by cell while the columns are still positional integers.
    for column, entry in enumerate(data):
        for class_id, score in enumerate(entry["emotions"].tolist()):
            frame.at[id2label[class_id], column] = round(score, 4)

    # Only now swap the integer columns for shortened sentence labels.
    labels = []
    for entry in data:
        labels.append(entry["sentence"][:18] + "...")
    frame.columns = labels
    return frame
def plot_emotion_heatmap(heatmap_data):
    """Draw `heatmap_data` as a seaborn heatmap and return the figure.

    Args:
        heatmap_data: DataFrame with emotions as rows and sentences as
            columns, as produced by ``prepare_heatmap_data``.

    Returns:
        The matplotlib figure the heatmap was drawn on.
    """
    # Width grows with the number of sentence columns; height is fixed.
    width = len(heatmap_data.columns) * 0.5 + 4
    fig = plt.figure(figsize=(width, 5))

    heatmap_options = {
        "annot": False,
        "cmap": "coolwarm",
        "cbar": True,
        "linewidths": 0.5,
        "linecolor": "gray",
    }
    sns.heatmap(heatmap_data, **heatmap_options)

    plt.xlabel("Sentences")
    plt.ylabel("Emotions")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)

    plt.subplots_adjust(left=0.2, right=0.95, top=0.9, bottom=0.2)
    plt.tight_layout()
    return fig
def plot_emotion_barplot(heatmap_data):
    """Plot how often each emotion is the top prediction across sentences.

    For every column of `heatmap_data` the row with the highest value is
    taken as that sentence's predicted emotion. The counts are divided by
    their total and shown as a horizontal bar chart, highest share first.
    Emotions that never win are included with a share of 0.

    Args:
        heatmap_data: DataFrame with emotions as rows and sentences as
            columns.

    Returns:
        The matplotlib figure containing the bar chart.
    """
    winners = heatmap_data.idxmax(axis=0)
    counts = winners.value_counts()
    # Make sure every emotion appears, even those that never won.
    padded = counts.reindex(heatmap_data.index, fill_value=0)
    shares = (padded / counts.sum()).sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=shares.values, y=shares.index, palette="coolwarm", ax=ax)
    ax.set_title("Relative Frequencies of Predicted Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")
    plt.tight_layout()
    return fig
def predict_wrapper(text, language):
    """Run sentence-level emotion prediction for the Gradio callback.

    The text is split into sentences, each sentence is classified, and the
    results are returned in the order the UI outputs expect.

    Args:
        text: The text entered by the user.
        language: Language name selected in the dropdown; picks the model.

    Returns:
        A tuple ``(rows, figure, info)``: ``rows`` is a list of
        ``[sentence, label, probability]`` entries, ``figure`` is the bar
        chart of predicted emotions, and ``info`` is a markup string linking
        to the model that was used.
    """
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    sentences = split_sentences(text, load_spacy_model())

    table_rows = []
    per_sentence = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        table_rows.append([sentence, label, probability])
        per_sentence.append({"sentence": sentence, "emotions": probs})

    # Debug output of the collected predictions.
    print(table_rows)
    print(per_sentence)

    figure = plot_emotion_barplot(prepare_heatmap_data(per_sentence))
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return table_rows, figure, output_info
# Gradio UI definition. Components are created inside nested layout contexts,
# so the order and nesting of the statements below determine the page layout.
with gr.Blocks() as demo:
    # Top row: text input in one column, language selector and submit button
    # stacked in the other.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")
    # Per-sentence results; the three columns match the rows returned by
    # predict_wrapper ([sentence, label, probability]).
    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True # important (presumably so long sentences wrap inside their cell)
        )
    # Figure returned by predict_wrapper.
    with gr.Row():
        plot = gr.Plot()
    # Note naming the model that produced the prediction.
    with gr.Row():
        model_info = gr.Markdown()
    # Wire the button: the outputs list follows the order of
    # predict_wrapper's return tuple.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, model_info]
    )
# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|