Spaces:
Running
Running
vickeee465
commited on
Commit
·
0f6522c
1
Parent(s):
2fe6a76
giving a try to crazy imshow thingie
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer
|
|
9 |
import gradio as gr
|
10 |
import matplotlib.pyplot as plt
|
11 |
from matplotlib.colors import LinearSegmentedColormap
|
|
|
12 |
import plotly.express as px
|
13 |
import seaborn as sns
|
14 |
|
@@ -99,44 +100,31 @@ def prepare_heatmap_data(data):
|
|
99 |
heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
|
100 |
return heatmap_data
|
101 |
|
102 |
-
|
103 |
def plot_emotion_heatmap(heatmap_data):
|
104 |
-
#
|
105 |
-
|
106 |
-
# Normalize all values to [0, 1] for each emotion
|
107 |
normalized_data = heatmap_data.copy()
|
108 |
for row in heatmap_data.index:
|
109 |
max_val = heatmap_data.loc[row].max()
|
110 |
-
if max_val > 0
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
vmin=0,
|
131 |
-
vmax=1,
|
132 |
-
ax=ax if i == 0 else ax, # reuse same axis
|
133 |
-
)
|
134 |
-
|
135 |
-
# Format axis
|
136 |
-
ax.set_xticks(np.arange(len(heatmap_data.columns)) + 0.5)
|
137 |
-
ax.set_xticklabels(heatmap_data.columns, rotation=0, ha='center')
|
138 |
-
ax.set_yticks(np.arange(len(heatmap_data.index)) + 0.5)
|
139 |
-
ax.set_yticklabels(heatmap_data.index, rotation=0, ha='right')
|
140 |
|
141 |
ax.set_xlabel("Sentences")
|
142 |
ax.set_ylabel("Emotions")
|
|
|
9 |
import gradio as gr
|
10 |
import matplotlib.pyplot as plt
|
11 |
from matplotlib.colors import LinearSegmentedColormap
|
12 |
+
import matplotlib.colors as mcolors
|
13 |
import plotly.express as px
|
14 |
import seaborn as sns
|
15 |
|
|
|
100 |
heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
|
101 |
return heatmap_data
|
102 |
|
|
|
103 |
def plot_emotion_heatmap(heatmap_data):
|
104 |
+
# Normalize values to [0, 1] per row (emotion)
|
|
|
|
|
105 |
normalized_data = heatmap_data.copy()
|
106 |
for row in heatmap_data.index:
|
107 |
max_val = heatmap_data.loc[row].max()
|
108 |
+
normalized_data.loc[row] = heatmap_data.loc[row] / max_val if max_val > 0 else 0
|
109 |
+
|
110 |
+
# Build custom RGB color matrix
|
111 |
+
color_matrix = np.empty((len(normalized_data.index), len(normalized_data.columns), 3))
|
112 |
+
for i, emotion in enumerate(normalized_data.index):
|
113 |
+
base_rgb = mcolors.to_rgb(emotion_colors[emotion])
|
114 |
+
for j, val in enumerate(normalized_data.loc[emotion]):
|
115 |
+
# Linear interpolation from white to base color
|
116 |
+
color = tuple(1 - val * (1 - c) for c in base_rgb)
|
117 |
+
color_matrix[i, j] = color
|
118 |
+
|
119 |
+
fig, ax = plt.subplots(figsize=(len(normalized_data.columns) * 0.5 + 4, len(normalized_data.index) * 0.5 + 2))
|
120 |
+
|
121 |
+
ax.imshow(color_matrix, aspect='auto')
|
122 |
+
|
123 |
+
# Ticks and labels
|
124 |
+
ax.set_xticks(np.arange(len(normalized_data.columns)))
|
125 |
+
ax.set_xticklabels(normalized_data.columns, rotation=0, ha='center')
|
126 |
+
ax.set_yticks(np.arange(len(normalized_data.index)))
|
127 |
+
ax.set_yticklabels(normalized_data.index, rotation=0, ha='right')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
ax.set_xlabel("Sentences")
|
130 |
ax.set_ylabel("Emotions")
|