vickeee465 committed
Commit 3fd2db3 · 1 Parent(s): 8383fbb

added figure logic

Files changed (1)
1. app.py +27 -2
app.py CHANGED
@@ -5,6 +5,8 @@ import numpy as np
 from transformers import AutoModelForSequenceClassification
 from transformers import AutoTokenizer
 import gradio as gr
+import matplotlib.pyplot as plt
+import seaborn as sns
 
 PATH = '/data/' # at least 150GB storage needs to be attached
 os.environ['TRANSFORMERS_CACHE'] = PATH
@@ -70,6 +72,25 @@ def get_most_probable_label(probs):
     probability = f"{round(100 * probs.max(), 2)}%"
     return label, probability
 
+def prepare_heatmap_data(data):
+    heatmap_data = pd.DataFrame(0, index=range(len(data)), columns=emotion_mapping.values())
+    for idx, item in enumerate(data):
+        for idy, confidence in enumerate(item["emotions"]):
+            emotion = emotion_mapping[idy]
+            heatmap_data.at[idx, emotion] = confidence
+    heatmap_data.index = [item["sentence"] for item in data]
+    return heatmap_data
+
+def plot_emotion_heatmap(data):
+    heatmap_data = prepare_heatmap_data(data)
+    fig = plt.figure(figsize=(10, len(data) * 0.5 + 2))
+    sns.heatmap(heatmap_data, annot=True, cmap="coolwarm", cbar=True, linewidths=0.5, linecolor='gray')
+    plt.title("Emotion Confidence Heatmap")
+    plt.xlabel("Emotions")
+    plt.ylabel("Sentences")
+    plt.tight_layout()
+    return fig
+
 def predict_wrapper(text, language):
     model_id = build_huggingface_path(language)
     tokenizer_id = "xlm-roberta-large"
@@ -78,13 +99,17 @@ def predict_wrapper(text, language):
     sentences = split_sentences(text, spacy_model)
 
     results = []
+    results_heatmap = []
     for sentence in sentences:
         probs = predict(sentence, model_id, tokenizer_id)
         label, probability = get_most_probable_label(probs)
         results.append([sentence, label, probability])
+        results_heatmap.append({"sentence": sentence, "emotions": probs})
 
+    figure = plot_emotion_heatmap(results_heatmap)
     output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
-    return results, output_info
+    return results, figure, output_info
+
 
 with gr.Blocks() as demo:
     with gr.Row():
@@ -108,7 +133,7 @@ with gr.Blocks() as demo:
     predict_button.click(
         fn=predict_wrapper,
         inputs=[input_text, language_choice],
-        outputs=[result_table, model_info]
+        outputs=[result_table, "plot", model_info]
     )
 
 if __name__ == "__main__":
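
For reference, the two helpers added here can be exercised outside the Gradio app. Below is a minimal standalone sketch, not the committed code verbatim: the emotion_mapping dict and its labels are placeholders (app.py is assumed to define its own mapping elsewhere), and the pandas/seaborn/matplotlib imports are spelled out explicitly.

# Standalone sketch of the heatmap path introduced in this commit.
# emotion_mapping below is a placeholder; the real app.py is assumed to define its own labels.
import matplotlib
matplotlib.use("Agg")  # headless backend, as on a server
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

emotion_mapping = {0: "anger", 1: "fear", 2: "joy", 3: "sadness", 4: "surprise", 5: "neutral"}

def prepare_heatmap_data(data):
    # One row per sentence, one column per emotion, filled with model confidences.
    heatmap_data = pd.DataFrame(0.0, index=range(len(data)), columns=list(emotion_mapping.values()))
    for idx, item in enumerate(data):
        for idy, confidence in enumerate(item["emotions"]):
            heatmap_data.at[idx, emotion_mapping[idy]] = confidence
    heatmap_data.index = [item["sentence"] for item in data]
    return heatmap_data

def plot_emotion_heatmap(data):
    # data is a list of {"sentence": str, "emotions": sequence of floats}.
    heatmap_data = prepare_heatmap_data(data)
    fig = plt.figure(figsize=(10, len(data) * 0.5 + 2))
    sns.heatmap(heatmap_data, annot=True, cmap="coolwarm", cbar=True, linewidths=0.5, linecolor="gray")
    plt.title("Emotion Confidence Heatmap")
    plt.xlabel("Emotions")
    plt.ylabel("Sentences")
    plt.tight_layout()
    return fig

if __name__ == "__main__":
    dummy = [
        {"sentence": "I love this.", "emotions": [0.01, 0.02, 0.90, 0.03, 0.03, 0.01]},
        {"sentence": "This is terrible.", "emotions": [0.70, 0.10, 0.02, 0.15, 0.02, 0.01]},
    ]
    plot_emotion_heatmap(dummy).savefig("heatmap.png")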
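
The click handler now lists three outputs but passes the literal string "plot" for the figure slot. Inside gr.Blocks, outputs are normally component instances, so a gr.Plot component added to the layout is the usual way to display the returned matplotlib figure. The sketch below shows that wiring; every component name and choice here is an illustrative assumption, not taken from app.py.

# Sketch of wiring the returned figure to a gr.Plot component inside gr.Blocks.
# All component names and option values below are illustrative assumptions.
import gradio as gr

def predict_wrapper(text, language):
    # Stand-in for the real function, which returns (table rows, matplotlib figure, info HTML).
    ...

with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Textbox(lines=6, label="Input text")
        language_choice = gr.Dropdown(choices=["english", "hungarian"], label="Language")
    predict_button = gr.Button("Predict")
    result_table = gr.Dataframe(headers=["sentence", "label", "probability"])
    plot = gr.Plot(label="Emotion Confidence Heatmap")  # component instance rather than the string "plot"
    model_info = gr.Markdown()

    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, model_info],
    )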