Spaces:
Running
Running
vickeee465
commited on
Commit
·
7780172
1
Parent(s):
ec0a067
sunburst chart
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ from transformers import AutoModelForSequenceClassification
|
|
8 |
from transformers import AutoTokenizer
|
9 |
import gradio as gr
|
10 |
import matplotlib.pyplot as plt
|
|
|
11 |
import seaborn as sns
|
12 |
|
13 |
PATH = '/data/' # at least 150GB storage needs to be attached
|
@@ -99,6 +100,31 @@ def plot_emotion_heatmap(heatmap_data):
|
|
99 |
plt.tight_layout()
|
100 |
return fig
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def plot_emotion_barplot(heatmap_data):
|
104 |
most_probable_emotions = heatmap_data.idxmax(axis=0)
|
@@ -133,8 +159,9 @@ def predict_wrapper(text, language):
|
|
133 |
print(results_heatmap)
|
134 |
|
135 |
figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
|
|
|
136 |
output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
|
137 |
-
return results, figure, output_info
|
138 |
|
139 |
|
140 |
with gr.Blocks() as demo:
|
@@ -163,7 +190,7 @@ with gr.Blocks() as demo:
|
|
163 |
predict_button.click(
|
164 |
fn=predict_wrapper,
|
165 |
inputs=[input_text, language_choice],
|
166 |
-
outputs=[result_table, plot, model_info]
|
167 |
)
|
168 |
|
169 |
if __name__ == "__main__":
|
|
|
8 |
from transformers import AutoTokenizer
|
9 |
import gradio as gr
|
10 |
import matplotlib.pyplot as plt
|
11 |
+
import plotly.express as px
|
12 |
import seaborn as sns
|
13 |
|
14 |
PATH = '/data/' # at least 150GB storage needs to be attached
|
|
|
100 |
plt.tight_layout()
|
101 |
return fig
|
102 |
|
103 |
+
def plot_sunburst_chart(heatmap_data):
|
104 |
+
data = []
|
105 |
+
for item in heatmap_data:
|
106 |
+
sentence = item['sentence']
|
107 |
+
emotions = item['emotions']
|
108 |
+
for i, score in enumerate(emotions):
|
109 |
+
data.append({
|
110 |
+
'root': 'All Sentences',
|
111 |
+
'sentence': sentence,
|
112 |
+
'emotion': id2label[i],
|
113 |
+
'score': float(score)
|
114 |
+
})
|
115 |
+
|
116 |
+
df = pd.DataFrame(data)
|
117 |
+
|
118 |
+
# Plot sunburst
|
119 |
+
fig = px.sunburst(
|
120 |
+
df,
|
121 |
+
path=['root', 'sentence', 'emotion'],
|
122 |
+
values='score',
|
123 |
+
color='emotion',
|
124 |
+
title='Sentence-level Emotion Confidences'
|
125 |
+
)
|
126 |
+
|
127 |
+
return fig
|
128 |
|
129 |
def plot_emotion_barplot(heatmap_data):
|
130 |
most_probable_emotions = heatmap_data.idxmax(axis=0)
|
|
|
159 |
print(results_heatmap)
|
160 |
|
161 |
figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
|
162 |
+
sunburst_chart = plot_sunburst_chart(results_heatmap)
|
163 |
output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
|
164 |
+
return results, figure, sunburst_chart, output_info
|
165 |
|
166 |
|
167 |
with gr.Blocks() as demo:
|
|
|
190 |
predict_button.click(
|
191 |
fn=predict_wrapper,
|
192 |
inputs=[input_text, language_choice],
|
193 |
+
outputs=[result_table, plot, sunburst_chart, model_info]
|
194 |
)
|
195 |
|
196 |
if __name__ == "__main__":
|