herme committed on
Commit bdaccf4 · 1 parent: 150d5bd

Update app.py

Files changed (1)
  1. app.py +80 -107
app.py CHANGED
@@ -1,109 +1,82 @@
- import whisper
- import gradio as gr
- import datetime
-
- import subprocess
-
- import torch
- import pyannote.audio
- from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
-
- from pyannote.audio import Audio
- from pyannote.core import Segment
-
- import wave
- import contextlib
-
- from sklearn.cluster import AgglomerativeClustering
- import numpy as np
-
- model = whisper.load_model("large-v2")
- embedding_model = PretrainedSpeakerEmbedding(
-     "speechbrain/spkrec-ecapa-voxceleb",
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- )
-
- def transcribe(audio, num_speakers):
-   path, error = convert_to_wav(audio)
-   if error is not None:
-     return error
-
-   duration = get_duration(path)
-   if duration > 4 * 60 * 60:
-     return "Audio duration too long"
-
-   result = model.transcribe(path)
-   segments = result["segments"]
-
-   num_speakers = min(max(round(num_speakers), 1), len(segments))
-   if len(segments) == 1:
-     segments[0]['speaker'] = 'SPEAKER 1'
-   else:
-     embeddings = make_embeddings(path, segments, duration)
-     add_speaker_labels(segments, embeddings, num_speakers)
-   output = get_output(segments)
-   return output

- def convert_to_wav(path):
-   if path[-3:] != 'wav':
-     new_path = '.'.join(path.split('.')[:-1]) + '.wav'
-     try:
-       subprocess.call(['ffmpeg', '-i', path, new_path, '-y'])
-     except:
-       return path, 'Error: Could not convert file to .wav'
-     path = new_path
-   return path, None
-
- def get_duration(path):
-   with contextlib.closing(wave.open(path,'r')) as f:
-     frames = f.getnframes()
-     rate = f.getframerate()
-     return frames / float(rate)
-
- def make_embeddings(path, segments, duration):
-   embeddings = np.zeros(shape=(len(segments), 192))
-   for i, segment in enumerate(segments):
-     embeddings[i] = segment_embedding(path, segment, duration)
-   return np.nan_to_num(embeddings)
-
- audio = Audio()
-
- def segment_embedding(path, segment, duration):
-   start = segment["start"]
-   # Whisper overshoots the end timestamp in the last segment
-   end = min(duration, segment["end"])
-   clip = Segment(start, end)
-   waveform, sample_rate = audio.crop(path, clip)
-   return embedding_model(waveform[None])
-
- def add_speaker_labels(segments, embeddings, num_speakers):
-   clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
-   labels = clustering.labels_
-   for i in range(len(segments)):
-     segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
-
- def time(secs):
-   return datetime.timedelta(seconds=round(secs))
-
- def get_output(segments):
-   output = ''
-   for (i, segment) in enumerate(segments):
-     if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
-       if i != 0:
-         output += '\n\n'
-       output += segment["speaker"] + ' ' + str(time(segment["start"])) + '\n\n'
-     output += segment["text"][1:] + ' '
-   return output
-
- gr.Interface(
-   title = 'Prueba Whisper Audio to Text ',
-   fn=transcribe,
-   inputs=[
-     gr.inputs.Audio(source="upload", type="filepath"),
-     gr.inputs.Number(default=2, label="Number of Speakers")
-
-   ],
-   outputs=[
-     gr.outputs.Textbox(label='Transcript')
  ]
- ).launch()
+ from typing import Dict

+ import gradio as gr
+ import whisper
+ from whisper.tokenizer import get_tokenizer
+
+ import classify
+
+ model_cache = {}
+
+
+ def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
+     class_names = class_names.split(",")
+     tokenizer = get_tokenizer(multilingual=".en" not in model_name)
+
+     if model_name not in model_cache:
+         model = whisper.load_model(model_name)
+         model_cache[model_name] = model
+     else:
+         model = model_cache[model_name]
+
+     internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
+         model=model,
+         class_names=class_names,
+         tokenizer=tokenizer,
+     )
+     audio_features = classify.calculate_audio_features(audio_path, model)
+     average_logprobs = classify.calculate_average_logprobs(
+         model=model,
+         audio_features=audio_features,
+         class_names=class_names,
+         tokenizer=tokenizer,
+     )
+     average_logprobs -= internal_lm_average_logprobs
+     scores = average_logprobs.softmax(-1).tolist()
+     return {class_name: score for class_name, score in zip(class_names, scores)}
+
+
+ def main():
+     CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking]"
+     AUDIO_PATHS = [
+         "./data/(dog)1-100032-A-0.wav",
+         "./data/(helicopter)1-181071-A-40.wav",
+         "./data/(laughing)1-1791-A-26.wav",
+         "./data/(chirping_birds)1-34495-A-14.wav",
+         "./data/(clock_tick)1-21934-A-38.wav",
    ]
+     EXAMPLES = []
+     for audio_path in AUDIO_PATHS:
+         EXAMPLES.append([audio_path, CLASS_NAMES, "small"])
+
+     DESCRIPTION = (
+         '<div style="text-align: center;">'
+         "<p>This demo allows you to try out zero-shot audio classification using "
+         "<a href=https://github.com/openai/whisper>Whisper</a>.</p>"
+         "<p>Github: <a href=https://github.com/jumon/zac>https://github.com/jumon/zac</a></p>"
+         "<p>Example audio files are from the <a href=https://github.com/karolpiczak/ESC-50>ESC-50"
+         "</a> dataset (CC BY-NC 3.0).</p></div>"
+     )
+
+     demo = gr.Interface(
+         fn=zero_shot_classify,
+         inputs=[
+             gr.Audio(source="upload", type="filepath", label="Audio File"),
+             gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
+             gr.Radio(
+                 choices=["tiny", "base", "small", "medium", "large"],
+                 value="small",
+                 label="Model Name",
+             ),
+         ],
+         outputs="label",
+         examples=EXAMPLES,
+         title="Zero-shot Audio Classification using Whisper",
+         description=DESCRIPTION,
+     )
+
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()
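
The new app.py delegates all scoring to a local classify module whose implementation is not part of this commit; only its call sites appear in the diff above. The snippet below is a minimal sketch of the kind of scoring those calls imply: it averages the Whisper decoder's log-probabilities over the tokens of one candidate class name, conditioned on the encoded audio. The helper average_logprob, the prompt construction, and the file example.wav are illustrative assumptions rather than the repo's actual classify.py.

# Illustrative sketch only (assumed helper, not the actual classify.py).
import torch
import whisper
from whisper.tokenizer import get_tokenizer


def average_logprob(model, tokenizer, audio_features, class_name):
    # Prompt: Whisper's start-of-transcript sequence (without timestamps),
    # followed by the candidate text; only the candidate tokens are scored.
    prompt = list(tokenizer.sot_sequence_including_notimestamps)
    text_tokens = tokenizer.encode(" " + class_name.strip())
    tokens = torch.tensor([prompt + text_tokens], device=model.device)
    with torch.no_grad():
        logits = model.logits(tokens, audio_features)  # (1, seq_len, vocab)
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    # The distribution at position i predicts the token at position i + 1.
    total = 0.0
    for i, token in enumerate(text_tokens):
        total += logprobs[0, len(prompt) + i - 1, token].item()
    return total / len(text_tokens)


if __name__ == "__main__":
    model = whisper.load_model("tiny")
    tokenizer = get_tokenizer(multilingual=model.is_multilingual)
    # Standard Whisper preprocessing: 30 s log-Mel spectrogram -> encoder features.
    audio = whisper.pad_or_trim(whisper.load_audio("example.wav"))  # hypothetical file
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    audio_features = model.embed_audio(mel.unsqueeze(0))
    for name in ["[dog barking]", "[clock ticking]"]:
        print(name, average_logprob(model, tokenizer, audio_features, name))

In app.py, subtracting internal_lm_average_logprobs before the softmax appears intended to cancel the decoder's text-only prior over the class names, so the final scores reflect audio-conditioned evidence rather than how probable each label string is on its own.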