vickeee465 commited on
Commit
0f6522c
·
1 Parent(s): 2fe6a76

giving a try to crazy imshow thingie

Browse files
Files changed (1) hide show
  1. app.py +22 -34
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer
9
  import gradio as gr
10
  import matplotlib.pyplot as plt
11
  from matplotlib.colors import LinearSegmentedColormap
 
12
  import plotly.express as px
13
  import seaborn as sns
14
 
@@ -99,44 +100,31 @@ def prepare_heatmap_data(data):
99
  heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
100
  return heatmap_data
101
 
102
-
103
  def plot_emotion_heatmap(heatmap_data):
104
- # heatmap_data = heatmap_data.T # Transpose to: rows = emotions, cols = sentences
105
-
106
- # Normalize all values to [0, 1] for each emotion
107
  normalized_data = heatmap_data.copy()
108
  for row in heatmap_data.index:
109
  max_val = heatmap_data.loc[row].max()
110
- if max_val > 0:
111
- normalized_data.loc[row] = heatmap_data.loc[row] / max_val
112
- else:
113
- normalized_data.loc[row] = 0
114
-
115
- fig, ax = plt.subplots(figsize=(len(heatmap_data.columns) * 0.5 + 4, len(heatmap_data.index) * 0.5 + 2))
116
-
117
- print(heatmap_data) # debug
118
- for i, emotion in enumerate(heatmap_data.index):
119
- # Create custom colormap for each row
120
- base_color = emotion_colors[emotion]
121
- cmap = LinearSegmentedColormap.from_list(f"{emotion}_map", ["#ffffff", base_color])
122
-
123
- # Create heatmap for one row at a time
124
- sns.heatmap(
125
- pd.DataFrame([normalized_data.loc[emotion].values], index=[emotion], columns=normalized_data.columns),
126
- cmap=cmap,
127
- cbar=False,
128
- linewidths=0.5,
129
- linecolor='gray',
130
- vmin=0,
131
- vmax=1,
132
- ax=ax if i == 0 else ax, # reuse same axis
133
- )
134
-
135
- # Format axis
136
- ax.set_xticks(np.arange(len(heatmap_data.columns)) + 0.5)
137
- ax.set_xticklabels(heatmap_data.columns, rotation=0, ha='center')
138
- ax.set_yticks(np.arange(len(heatmap_data.index)) + 0.5)
139
- ax.set_yticklabels(heatmap_data.index, rotation=0, ha='right')
140
 
141
  ax.set_xlabel("Sentences")
142
  ax.set_ylabel("Emotions")
 
9
  import gradio as gr
10
  import matplotlib.pyplot as plt
11
  from matplotlib.colors import LinearSegmentedColormap
12
+ import matplotlib.colors as mcolors
13
  import plotly.express as px
14
  import seaborn as sns
15
 
 
100
  heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
101
  return heatmap_data
102
 
 
103
  def plot_emotion_heatmap(heatmap_data):
104
+ # Normalize values to [0, 1] per row (emotion)
 
 
105
  normalized_data = heatmap_data.copy()
106
  for row in heatmap_data.index:
107
  max_val = heatmap_data.loc[row].max()
108
+ normalized_data.loc[row] = heatmap_data.loc[row] / max_val if max_val > 0 else 0
109
+
110
+ # Build custom RGB color matrix
111
+ color_matrix = np.empty((len(normalized_data.index), len(normalized_data.columns), 3))
112
+ for i, emotion in enumerate(normalized_data.index):
113
+ base_rgb = mcolors.to_rgb(emotion_colors[emotion])
114
+ for j, val in enumerate(normalized_data.loc[emotion]):
115
+ # Linear interpolation from white to base color
116
+ color = tuple(1 - val * (1 - c) for c in base_rgb)
117
+ color_matrix[i, j] = color
118
+
119
+ fig, ax = plt.subplots(figsize=(len(normalized_data.columns) * 0.5 + 4, len(normalized_data.index) * 0.5 + 2))
120
+
121
+ ax.imshow(color_matrix, aspect='auto')
122
+
123
+ # Ticks and labels
124
+ ax.set_xticks(np.arange(len(normalized_data.columns)))
125
+ ax.set_xticklabels(normalized_data.columns, rotation=0, ha='center')
126
+ ax.set_yticks(np.arange(len(normalized_data.index)))
127
+ ax.set_yticklabels(normalized_data.index, rotation=0, ha='right')
 
 
 
 
 
 
 
 
 
 
128
 
129
  ax.set_xlabel("Sentences")
130
  ax.set_ylabel("Emotions")