import os

# Hugging Face libraries read their cache locations when they are imported
# (transformers, and huggingface_hub via gradio), so these variables must be
# set BEFORE any of the third-party imports below; otherwise they are ignored.
PATH = '/data/'  # at least 150GB storage needs to be attached
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH

import functools  # noqa: E402

import gradio as gr  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402
import spacy  # noqa: E402
import torch  # noqa: E402
from transformers import AutoModelForSequenceClassification  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

HF_TOKEN = os.environ["hf_read"]

SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]

# Index -> emotion name; indices follow the classifier's output positions.
id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them",
}


@functools.lru_cache(maxsize=4)
def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Load a spaCy pipeline, downloading it first if it is not installed.

    The result is cached so the pipeline is loaded once per process rather
    than on every request.
    """
    try:
        model = spacy.load(model_name)
    except OSError:
        spacy.cli.download(model_name)
        model = spacy.load(model_name)
    return model


def split_sentences(text, model):
    """Split *text* into sentences using only the pipeline's "senter" component.

    Whitespace-only sentences are dropped so the classifier never sees empty input.
    """
    # Disable every component except the sentence splitter (not needed for splitting).
    for pipe_name in model.pipe_names:
        if pipe_name != "senter":
            model.disable_pipe(pipe_name)
    model.enable_pipe("senter")
    doc = model(text)
    return [sent.text for sent in doc.sents if sent.text.strip()]


def build_huggingface_path(language: str):
    """Return the Hugging Face model id to use for *language*."""
    if language == "Czech" or language == "Slovakian":
        return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
    return "poltextlab/xlm-roberta-large-pooled-MORES"


@functools.lru_cache(maxsize=2)
def _load_classifier(model_id):
    """Load the sequence-classification model once per model id, in eval mode."""
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    model.eval()
    return model


@functools.lru_cache(maxsize=2)
def _load_tokenizer(tokenizer_id):
    """Load a tokenizer once per tokenizer id."""
    return AutoTokenizer.from_pretrained(tokenizer_id)


def predict(text, model_id, tokenizer_id):
    """Classify one piece of text and return its softmax probabilities.

    Returns a flat numpy array with one probability per classifier label.
    """
    model = _load_classifier(model_id)
    tokenizer = _load_tokenizer(tokenizer_id)
    inputs = tokenizer(text, max_length=64, truncation=True, padding="do_not_pad", return_tensors="pt")
    # device_map="auto" may place the model on a GPU; the inputs must follow it.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs


def get_most_probable_label(probs):
    """Return (label name, probability formatted as a percentage string)."""
    label = id2label[int(probs.argmax())]
    probability = f"{round(100 * probs.max(), 2)}%"
    return label, probability


def prepare_heatmap_data(data):
    """Build a sentences x emotions DataFrame of float confidences.

    *data* is a list of {"sentence": str, "emotions": probabilities} dicts.
    Column i corresponds to id2label[i].
    """
    rows = [np.asarray(item["emotions"], dtype=float) for item in data]
    return pd.DataFrame(
        rows,
        columns=list(id2label.values()),
        index=[item["sentence"] for item in data],
    )


def plot_emotion_heatmap(data):
    """Render the per-sentence emotion confidences as a heatmap figure.

    *data* is the raw list of {"sentence", "emotions"} dicts (not a DataFrame).
    """
    heatmap_data = prepare_heatmap_data(data)
    fig, ax = plt.subplots(figsize=(10, len(data) * 0.5 + 2))
    sns.heatmap(heatmap_data, annot=True, cmap="coolwarm", cbar=True,
                linewidths=0.5, linecolor='gray', ax=ax)
    ax.set_title("Emotion Confidence Heatmap")
    ax.set_xlabel("Emotions")
    ax.set_ylabel("Sentences")
    fig.tight_layout()
    # Unregister from pyplot so figures do not pile up across requests;
    # the Figure object itself can still be rendered by Gradio.
    plt.close(fig)
    return fig


def predict_wrapper(text, language):
    """Gradio callback: classify each sentence of *text* and build the outputs.

    Returns (table rows, heatmap figure or None, model info markdown).
    """
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text or "", spacy_model)
    if not sentences:
        # Nothing to classify; an empty heatmap would make seaborn fail.
        return [], None, "Please enter some text to analyse."

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    figure = plot_emotion_heatmap(results_heatmap)
    output_info = f'Prediction was made using the {model_id} model.'
    return results, figure, output_info


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")

    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True  # important
        )

    with gr.Row():
        # A real component is required here; a bare "plot" string output is not rendered.
        heatmap_plot = gr.Plot()

    with gr.Row():
        model_info = gr.Markdown()

    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, heatmap_plot, model_info]
    )

if __name__ == "__main__":
    demo.launch()