Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -23,13 +23,6 @@ from transformers import AutoModelForSequenceClassification, pipeline, WhisperTo
|
|
23 |
|
24 |
model_cache = {}
|
25 |
|
26 |
-
# Building prediction function for gradio
|
27 |
-
emo_dict = {
|
28 |
-
'sad': 'Sad',
|
29 |
-
'hap': 'Happy',
|
30 |
-
'ang': 'Anger',
|
31 |
-
'neu': 'Neutral'
|
32 |
-
}
|
33 |
|
34 |
# static classes for now, but it would be best ot have the user select from multiple, and to enter their own
|
35 |
class_options = {
|
@@ -57,12 +50,14 @@ def slider_logic(slider):
|
|
57 |
return threshold
|
58 |
|
59 |
# Create a Gradio interface with audio file and text inputs
|
60 |
-
def classify_toxicity(audio_file, slider):
|
61 |
# Transcribe the audio file using Whisper ASR
|
62 |
if audio_file != None:
|
63 |
transcribed_text = pipe(audio_file)["text"]
|
64 |
else:
|
65 |
transcribed_text = text_input
|
|
|
|
|
66 |
|
67 |
threshold = slider_logic(slider)
|
68 |
model = whisper.load_model("large")
|
@@ -95,6 +90,20 @@ def classify_toxicity(audio_file, slider):
|
|
95 |
)
|
96 |
average_logprobs -= internal_lm_average_logprobs
|
97 |
scores = average_logprobs.softmax(-1).tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
holder1 = {class_name: score for class_name, score in zip(class_names, scores)}
|
99 |
# miso_label_dict = {label: score for label, score in classify_anxiety[0].items()}
|
100 |
holder2 = ""
|
@@ -114,6 +123,7 @@ def positive_affirmations():
|
|
114 |
with gr.Blocks() as iface:
|
115 |
show_state = gr.State([])
|
116 |
with gr.Column():
|
|
|
117 |
sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
|
118 |
with gr.Column():
|
119 |
aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
|
@@ -121,6 +131,6 @@ with gr.Blocks() as iface:
|
|
121 |
with gr.Column():
|
122 |
# out_val = gr.Textbox()
|
123 |
out_class = gr.Label()
|
124 |
-
submit_btn.click(fn=classify_toxicity, inputs=[aud_input, sense_slider], outputs=out_class)
|
125 |
|
126 |
iface.launch()
|
|
|
23 |
|
24 |
model_cache = {}
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# static classes for now, but it would be best ot have the user select from multiple, and to enter their own
|
28 |
class_options = {
|
|
|
50 |
return threshold
|
51 |
|
52 |
# Create a Gradio interface with audio file and text inputs
|
53 |
+
def classify_toxicity(audio_file, selected_sounds, slider):
|
54 |
# Transcribe the audio file using Whisper ASR
|
55 |
if audio_file != None:
|
56 |
transcribed_text = pipe(audio_file)["text"]
|
57 |
else:
|
58 |
transcribed_text = text_input
|
59 |
+
|
60 |
+
selected_class_names = selected_sounds.split(",")
|
61 |
|
62 |
threshold = slider_logic(slider)
|
63 |
model = whisper.load_model("large")
|
|
|
90 |
)
|
91 |
average_logprobs -= internal_lm_average_logprobs
|
92 |
scores = average_logprobs.softmax(-1).tolist()
|
93 |
+
|
94 |
+
class_score_dict = {class_name: score for class_name, score in zip(class_names, scores)}
|
95 |
+
for selected_class_name in selected_class_names:
|
96 |
+
if selected_class_name in class_score_dict:
|
97 |
+
score = class_score_dict[selected_class_name]
|
98 |
+
if score > threshold:
|
99 |
+
print(f"Threshold exceeded for class '{selected_class_name}': Score = {score:.4f}")
|
100 |
+
|
101 |
+
|
102 |
+
'''
|
103 |
+
for class_name, score in class_score_dict.items():
|
104 |
+
if score > threshold:
|
105 |
+
print(f"Threshold exceeded for class '{class_name}': Score = {score:.4f}")
|
106 |
+
'''
|
107 |
holder1 = {class_name: score for class_name, score in zip(class_names, scores)}
|
108 |
# miso_label_dict = {label: score for label, score in classify_anxiety[0].items()}
|
109 |
holder2 = ""
|
|
|
123 |
with gr.Blocks() as iface:
|
124 |
show_state = gr.State([])
|
125 |
with gr.Column():
|
126 |
+
miso_sounds = gr.CheckboxGroup(["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"])
|
127 |
sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
|
128 |
with gr.Column():
|
129 |
aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
|
|
|
131 |
with gr.Column():
|
132 |
# out_val = gr.Textbox()
|
133 |
out_class = gr.Label()
|
134 |
+
submit_btn.click(fn=classify_toxicity, inputs=[aud_input, miso_sounds, sense_slider], outputs=out_class)
|
135 |
|
136 |
iface.launch()
|