File size: 4,436 Bytes
c7c2507 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot_emotion_confusion_matrix(results_df, emotion_columns):
    """Plot per-emotion prediction outcomes as an annotated heatmap.

    The ``'true emotions'`` and ``'predict emotions'`` columns of
    ``results_df`` are read row by row. A cell holding a string is split on
    whitespace into a set of labels; any non-string cell (for example a
    missing value) is treated as carrying no labels.

    For every emotion in ``emotion_columns`` three counts are drawn:

    * "Correctly Identified"   - the label is in both the true and the
      predicted set of a row.
    * "Incorrectly Identified" - the label was predicted but is not a true
      label of the row.
    * "Undefined"              - the label is a true label of the row but
      was not predicted.

    Args:
        results_df: DataFrame providing the ``'true emotions'`` and
            ``'predict emotions'`` columns.
        emotion_columns: Iterable of emotion labels to report; its order is
            the row order of the heatmap.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If either required column is absent from ``results_df``.
    """
    # Imported locally so the block does not depend on a file-level import.
    from collections import Counter

    # Materialise once so a one-shot iterable can be reused for the index.
    emotions = list(emotion_columns)

    correct_count = Counter()
    incorrect_count = Counter()
    undefined_count = Counter()

    # Iterate the two columns directly instead of DataFrame.iterrows(),
    # which builds a full Series for every row.
    for true_cell, predicted_cell in zip(
        results_df['true emotions'], results_df['predict emotions']
    ):
        true_emotions = set(true_cell.split()) if isinstance(true_cell, str) else set()
        predicted_emotions = (
            set(predicted_cell.split()) if isinstance(predicted_cell, str) else set()
        )
        # Set algebra yields each outcome group for the row in one step;
        # labels outside `emotions` are tallied but never read back.
        correct_count.update(true_emotions & predicted_emotions)
        incorrect_count.update(predicted_emotions - true_emotions)
        undefined_count.update(true_emotions - predicted_emotions)

    # A Counter returns 0 for labels it never saw, so every emotion gets a row.
    heatmap_df = pd.DataFrame(
        {
            "Correctly Identified": [correct_count[emotion] for emotion in emotions],
            "Incorrectly Identified": [incorrect_count[emotion] for emotion in emotions],
            "Undefined": [undefined_count[emotion] for emotion in emotions],
        },
        index=emotions,
    )

    num_examples = len(results_df)

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="Blues", fmt="d", cbar=False)
    plt.title(f"Emotion Prediction Confusion Matrix (Examples: {num_examples})")
    plt.xlabel("Prediction Status")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
def plot_true_emotion_frequency(results_df, emotion_columns):
    """Plot how many rows carry each emotion as a true label.

    The ``'true emotions'`` column of ``results_df`` is read cell by cell.
    A string cell is split on whitespace into a set of labels, so a label
    counts at most once per row; non-string cells (for example missing
    values) contribute nothing.

    Args:
        results_df: DataFrame providing the ``'true emotions'`` column.
        emotion_columns: Iterable of emotion labels to report; its order is
            the row order of the heatmap.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If ``'true emotions'`` is absent from ``results_df``.
    """
    # Imported locally so the block does not depend on a file-level import.
    from collections import Counter

    # Materialise once so a one-shot iterable can be reused for the index.
    emotions = list(emotion_columns)

    true_emotion_count = Counter()
    # Iterate the column directly instead of DataFrame.iterrows(), which
    # builds a full Series for every row.
    for cell in results_df['true emotions']:
        if isinstance(cell, str):
            # set() keeps a label repeated inside one cell from counting twice.
            true_emotion_count.update(set(cell.split()))

    # A Counter returns 0 for labels it never saw, so every emotion gets a row.
    heatmap_df = pd.DataFrame(
        {"True Emotion Count": [true_emotion_count[emotion] for emotion in emotions]},
        index=emotions,
    )

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="YlGnBu", fmt="d", cbar=False)
    plt.title(f"True Emotion Frequency (Examples: {len(results_df)})")
    plt.xlabel("True Emotion Count")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
def plot_predicted_emotion_frequency(results_df, emotion_columns):
    """Plot how many rows carry each emotion as a predicted label.

    The ``'predict emotions'`` column of ``results_df`` is read cell by
    cell. A string cell is split on whitespace into a set of labels, so a
    label counts at most once per row; non-string cells (for example
    missing values) contribute nothing.

    Args:
        results_df: DataFrame providing the ``'predict emotions'`` column.
        emotion_columns: Iterable of emotion labels to report; its order is
            the row order of the heatmap.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If ``'predict emotions'`` is absent from ``results_df``.
    """
    # Imported locally so the block does not depend on a file-level import.
    from collections import Counter

    # Materialise once so a one-shot iterable can be reused for the index.
    emotions = list(emotion_columns)

    predicted_emotion_count = Counter()
    # Iterate the column directly instead of DataFrame.iterrows(), which
    # builds a full Series for every row.
    for cell in results_df['predict emotions']:
        if isinstance(cell, str):
            # set() keeps a label repeated inside one cell from counting twice.
            predicted_emotion_count.update(set(cell.split()))

    # A Counter returns 0 for labels it never saw, so every emotion gets a row.
    heatmap_df = pd.DataFrame(
        {
            "Predicted Emotion Count": [
                predicted_emotion_count[emotion] for emotion in emotions
            ]
        },
        index=emotions,
    )

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="YlOrRd", fmt="d", cbar=False)
    plt.title(f"Predicted Emotion Frequency (Examples: {len(results_df)})")
    plt.xlabel("Predicted Emotion Count")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
# Emotion labels to report; the order here is the row order of every heatmap.
emotion_columns = [
    "admiration", "amusement", "anger", "annoyance",
    "approval", "caring", "confusion", "curiosity",
    "desire", "disappointment", "disapproval", "disgust",
    "embarrassment", "excitement", "fear", "gratitude",
    "grief", "joy", "love", "nervousness",
    "optimism", "pride", "realization", "relief",
    "remorse", "sadness", "surprise", "neutral",
]

# Results table consumed by the plotting functions, which read its
# 'true emotions' and 'predict emotions' columns.
# NOTE(review): "Dstasets" looks like a misspelling of "Datasets" - confirm
# against the directory name on disk before changing the path.
csv_path = "RuBert-tiny2-EmotionsDetected/Dstasets/Emotions_detected.csv"
results_df = pd.read_csv(csv_path)

# Each call opens its own figure and shows it before the next one starts.
plot_true_emotion_frequency(results_df, emotion_columns)
plot_predicted_emotion_frequency(results_df, emotion_columns)
plot_emotion_confusion_matrix(results_df, emotion_columns)
|