vickeee465 commited on
Commit
3ad91fd
·
1 Parent(s): 7e66e8d

color coding logic for heatmap

Browse files
Files changed (1) hide show
  1. app.py +37 -23
app.py CHANGED
@@ -100,33 +100,47 @@ def prepare_heatmap_data(data):
100
  return heatmap_data
101
 
102
 
103
- def plot_emotion_heatmap(heatmap_data):
104
- heatmap_data = heatmap_data.T
105
-
106
- fig = plt.figure(figsize=(len(heatmap_data.columns) * 0.5 + 4, len(heatmap_data.index) * 0.5 + 2))
107
-
108
- cmap = LinearSegmentedColormap.from_list("white_to_grey", ["#ffffff", "#aaaaaa"])
109
-
110
- sns.heatmap(
111
- heatmap_data,
112
- annot=False,
113
- cmap=cmap,
114
- cbar=True,
115
- linewidths=0.5,
116
- linecolor='gray',
117
- vmin=0,
118
- vmax=1
119
- )
120
-
121
- plt.xlabel("Emotions")
122
- plt.ylabel("Sentences")
123
- plt.xticks(rotation=0, ha='center')
124
- plt.yticks(rotation=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  plt.tight_layout()
126
 
127
  return fig
128
 
129
-
130
  def plot_average_emotion_pie(heatmap_data):
131
  all_emotion_scores = np.array([item['emotions'] for item in heatmap_data])
132
  mean_scores = all_emotion_scores.mean(axis=0)
 
100
  return heatmap_data
101
 
102
 
103
+ def plot_emotion_heatmap_colored(heatmap_data, emotion_colors):
104
+ heatmap_data = heatmap_data.T # Transpose to: rows = emotions, cols = sentences
105
+
106
+ # Normalize all values to [0, 1] for each emotion
107
+ normalized_data = heatmap_data.copy()
108
+ for row in heatmap_data.index:
109
+ max_val = heatmap_data.loc[row].max()
110
+ if max_val > 0:
111
+ normalized_data.loc[row] = heatmap_data.loc[row] / max_val
112
+ else:
113
+ normalized_data.loc[row] = 0
114
+
115
+ fig, ax = plt.subplots(figsize=(len(heatmap_data.columns) * 0.5 + 4, len(heatmap_data.index) * 0.5 + 2))
116
+
117
+ for i, emotion in enumerate(heatmap_data.index):
118
+ # Create custom colormap for each row
119
+ base_color = emotion_colors[emotion]
120
+ cmap = LinearSegmentedColormap.from_list(f"{emotion}_map", ["#ffffff", base_color])
121
+
122
+ # Create heatmap for one row at a time
123
+ sns.heatmap(
124
+ pd.DataFrame([normalized_data.loc[emotion].values], index=[emotion], columns=normalized_data.columns),
125
+ cmap=cmap,
126
+ cbar=False,
127
+ linewidths=0.5,
128
+ linecolor='gray',
129
+ vmin=0,
130
+ vmax=1,
131
+ ax=ax if i == 0 else ax, # reuse same axis
132
+ )
133
+
134
+ # Format axis
135
+ ax.set_xticklabels(heatmap_data.columns, rotation=0, ha='center')
136
+ ax.set_yticklabels(heatmap_data.index, rotation=0, ha='right')
137
+
138
+ ax.set_xlabel("Sentences")
139
+ ax.set_ylabel("Emotions")
140
  plt.tight_layout()
141
 
142
  return fig
143
 
 
144
  def plot_average_emotion_pie(heatmap_data):
145
  all_emotion_scores = np.array([item['emotions'] for item in heatmap_data])
146
  mean_scores = all_emotion_scores.mean(axis=0)