herme committed b17da2c · 1 Parent(s): c3fc696

Update app.py

Files changed (1): app.py (+79, -2)
app.py CHANGED
@@ -1,5 +1,82 @@
+from typing import Dict
+
 import gradio as gr
+import whisper
+from whisper.tokenizer import get_tokenizer
 
-examples = [["The Moon's orbit around Earth has"], ["There once was a pineapple"]]
+import classify
 
-gr.Interface.load("huggingface/EleutherAI/gpt-j-6B", examples=examples).launch();
+model_cache = {}
+
+
+def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
+    class_names = class_names.split(",")
+    tokenizer = get_tokenizer(multilingual=".en" not in model_name)
+
+    if model_name not in model_cache:
+        model = whisper.load_model(model_name)
+        model_cache[model_name] = model
+    else:
+        model = model_cache[model_name]
+
+    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
+        model=model,
+        class_names=class_names,
+        tokenizer=tokenizer,
+    )
+    audio_features = classify.calculate_audio_features(audio_path, model)
+    average_logprobs = classify.calculate_average_logprobs(
+        model=model,
+        audio_features=audio_features,
+        class_names=class_names,
+        tokenizer=tokenizer,
+    )
+    average_logprobs -= internal_lm_average_logprobs
+    scores = average_logprobs.softmax(-1).tolist()
+    return {class_name: score for class_name, score in zip(class_names, scores)}
+
+
+def main():
+    CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking]"
+    AUDIO_PATHS = [
+        "./data/(dog)1-100032-A-0.wav",
+        "./data/(helicopter)1-181071-A-40.wav",
+        "./data/(laughing)1-1791-A-26.wav",
+        "./data/(chirping_birds)1-34495-A-14.wav",
+        "./data/(clock_tick)1-21934-A-38.wav",
+    ]
+    EXAMPLES = []
+    for audio_path in AUDIO_PATHS:
+        EXAMPLES.append([audio_path, CLASS_NAMES, "small"])
+
+    DESCRIPTION = (
+        '<div style="text-align: center;">'
+        "<p>This demo allows you to try out zero-shot audio classification using "
+        "<a href=https://github.com/openai/whisper>Whisper</a>.</p>"
+        "<p>Github: <a href=https://github.com/jumon/zac>https://github.com/jumon/zac</a></p>"
+        "<p>Example audio files are from the <a href=https://github.com/karolpiczak/ESC-50>ESC-50"
+        "</a> dataset (CC BY-NC 3.0).</p></div>"
+    )
+
+    demo = gr.Interface(
+        fn=zero_shot_classify,
+        inputs=[
+            gr.Audio(source="upload", type="filepath", label="Audio File"),
+            gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
+            gr.Radio(
+                choices=["tiny", "base", "small", "medium", "large"],
+                value="small",
+                label="Model Name",
+            ),
+        ],
+        outputs="label",
+        examples=EXAMPLES,
+        title="Zero-shot Audio Classification using Whisper",
+        description=DESCRIPTION,
+    )
+
+    demo.launch()
+
+
+if __name__ == "__main__":
+    main()
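
The core of the new `zero_shot_classify` function is the score computation at the end: per-class average log-probabilities of the class-name tokens given the audio, corrected by the decoder's internal-LM log-probabilities, then normalized with a softmax. The sketch below illustrates just that arithmetic with plain tensors; the `classify.*` helpers themselves are not part of this commit, so the input values here are made up for illustration.

```python
# Minimal sketch of the scoring step in zero_shot_classify, with invented
# numbers standing in for the outputs of the (not shown) classify.* helpers.
import torch

class_names = ["[dog barking]", "[laughing]", "[clock ticking]"]

# Stand-in for classify.calculate_average_logprobs: average log-prob of each
# class name's tokens when the decoder is conditioned on the audio.
average_logprobs = torch.tensor([-2.1, -4.3, -3.9])

# Stand-in for classify.calculate_internal_lm_average_logprobs: the same
# quantity without audio conditioning, i.e. the decoder's prior over the text.
internal_lm_average_logprobs = torch.tensor([-3.0, -3.2, -3.1])

# Subtracting the internal-LM term discounts class names that are simply
# likely a priori; softmax turns the corrected scores into probabilities.
scores = (average_logprobs - internal_lm_average_logprobs).softmax(-1)
print(dict(zip(class_names, scores.tolist())))
```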
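For a quick smoke test without the Gradio UI, the function can also be called directly. This assumes you are running inside the Space's repository, so that the `classify` module and the `./data` example files are available; the `if __name__ == "__main__"` guard keeps the import from launching the demo.

```python
# Hypothetical direct invocation of the new function, bypassing the UI.
from app import zero_shot_classify

scores = zero_shot_classify(
    "./data/(dog)1-100032-A-0.wav",
    "[dog barking],[helicopter whirring],[laughing]",
    "tiny",  # smallest Whisper checkpoint, fastest to download for a test
)
print(max(scores, key=scores.get))  # highest-scoring class name
```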