mskov committed
Commit 8e87013 · 1 Parent(s): e29cd5a

Update app.py

Files changed (1)
  1. app.py +19 -9
app.py CHANGED
```diff
@@ -23,13 +23,6 @@ from transformers import AutoModelForSequenceClassification, pipeline, WhisperTo
 
 model_cache = {}
 
-# Building prediction function for gradio
-emo_dict = {
-    'sad': 'Sad',
-    'hap': 'Happy',
-    'ang': 'Anger',
-    'neu': 'Neutral'
-}
 
 # static classes for now, but it would be best ot have the user select from multiple, and to enter their own
 class_options = {
@@ -57,12 +50,14 @@ def slider_logic(slider):
     return threshold
 
 # Create a Gradio interface with audio file and text inputs
-def classify_toxicity(audio_file, slider):
+def classify_toxicity(audio_file, selected_sounds, slider):
     # Transcribe the audio file using Whisper ASR
     if audio_file != None:
         transcribed_text = pipe(audio_file)["text"]
     else:
         transcribed_text = text_input
+
+    selected_class_names = selected_sounds.split(",")
 
     threshold = slider_logic(slider)
     model = whisper.load_model("large")
@@ -95,6 +90,20 @@ def classify_toxicity(audio_file, slider):
     )
     average_logprobs -= internal_lm_average_logprobs
     scores = average_logprobs.softmax(-1).tolist()
+
+    class_score_dict = {class_name: score for class_name, score in zip(class_names, scores)}
+    for selected_class_name in selected_class_names:
+        if selected_class_name in class_score_dict:
+            score = class_score_dict[selected_class_name]
+            if score > threshold:
+                print(f"Threshold exceeded for class '{selected_class_name}': Score = {score:.4f}")
+
+
+    '''
+    for class_name, score in class_score_dict.items():
+        if score > threshold:
+            print(f"Threshold exceeded for class '{class_name}': Score = {score:.4f}")
+    '''
     holder1 = {class_name: score for class_name, score in zip(class_names, scores)}
     # miso_label_dict = {label: score for label, score in classify_anxiety[0].items()}
     holder2 = ""
@@ -114,6 +123,7 @@ def positive_affirmations():
 with gr.Blocks() as iface:
     show_state = gr.State([])
     with gr.Column():
+        miso_sounds = gr.CheckboxGroup(["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"])
         sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
     with gr.Column():
         aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
@@ -121,6 +131,6 @@ with gr.Blocks() as iface:
     with gr.Column():
         # out_val = gr.Textbox()
         out_class = gr.Label()
-    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, sense_slider], outputs=out_class)
+    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, miso_sounds, sense_slider], outputs=out_class)
 
 iface.launch()
```
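
Note on the new `selected_sounds` argument: `gr.CheckboxGroup` normally passes its selected choices to the callback as a Python list, while the committed `selected_sounds.split(",")` assumes a comma-separated string. Below is a minimal sketch, not the app's actual code, of a threshold check that tolerates either form; the function and variable names are illustrative.

```python
def flag_selected_classes(selected_sounds, class_names, scores, threshold):
    # CheckboxGroup usually yields a list of selected choices; fall back to
    # splitting a comma-separated string, which is what the committed code assumes.
    if isinstance(selected_sounds, str):
        selected_class_names = [s.strip() for s in selected_sounds.split(",") if s.strip()]
    else:
        selected_class_names = list(selected_sounds)

    # Pair each class name with its softmax score, then flag the selected
    # classes whose score exceeds the slider-derived threshold.
    class_score_dict = dict(zip(class_names, scores))
    flagged = {}
    for name in selected_class_names:
        score = class_score_dict.get(name)
        if score is not None and score > threshold:
            print(f"Threshold exceeded for class '{name}': Score = {score:.4f}")
            flagged[name] = score
    return flagged


# Example with made-up scores and a 0.5 threshold:
print(flag_selected_classes(
    ["chewing", "sniffling"],
    ["chewing", "breathing", "sniffling"],
    [0.62, 0.10, 0.28],
    threshold=0.5,
))
```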
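
Since the callback's output is wired to `out_class = gr.Label()`, a `{label: score}` dict such as `holder1` can be returned directly and `gr.Label` will render it as a confidence readout. A self-contained sketch of that wiring, with placeholder scores and component names that are assumptions rather than taken from app.py:

```python
import gradio as gr

def score_demo(_input):
    # Placeholder scores; the real app derives them from Whisper log-probs.
    return {"chewing": 0.62, "breathing": 0.10, "sniffling": 0.28}

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Placeholder input")
    out = gr.Label(label="Class scores")  # renders a {label: confidence} dict
    gr.Button("Run").click(fn=score_demo, inputs=inp, outputs=out)

# demo.launch()  # uncomment to try locally
```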