Upload KotobaWhisperPipeline
kotoba_whisper.py (+22 -6)
CHANGED
@@ -49,6 +49,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                  device: Union[int, "torch.device"] = None,
                  device_diarizarization: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
+                 return_unique_speaker: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
@@ -58,6 +59,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         if type(device_diarizarization) is str:
             device_diarizarization = torch.device(device_diarizarization)
         self.model_speaker_diarization = SpeakerDiarization(model_diarizarization, device_diarizarization)
+        self.return_unique_speaker = return_unique_speaker
         super().__init__(
             model=model,
             feature_extractor=feature_extractor,
@@ -192,6 +194,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         )

         # custom processing for Whisper timestamps and word-level timestamps
+        generate_kwargs["return_timestamps"] = True
         if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
             generate_kwargs["input_features"] = inputs
         else:
@@ -215,7 +218,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                     *args,
                     **kwargs):
         assert len(model_outputs) > 0
-        audio_array = list(model_outputs)[0]
+        audio_array = list(model_outputs)[0]["audio_array"]
         sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
         timelines = sd.get_timeline()
         outputs = super().postprocess(
@@ -229,35 +232,48 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         new_chunks = []
         while True:
             if pointer_ts == len(timelines):
-
+                ts = timelines[-1]
+                for chunk in outputs["chunks"][pointer_chunk:]:
+                    chunk["speaker"] = sd.get_labels(ts)
+                    new_chunks.append(chunk)
                 break
             if pointer_chunk == len(outputs["chunks"]):
                 break
             ts = timelines[pointer_ts]
+
             chunk = outputs["chunks"][pointer_chunk]
             if "speaker" not in chunk:
-                chunk["speaker"] =
+                chunk["speaker"] = []
+
             start, end = chunk["timestamp"]
             if ts.end <= start:
-                chunk["speaker"].update(sd.get_labels(ts))
                 pointer_ts += 1
             elif end <= ts.start:
+                if len(chunk["speaker"]) == 0:
+                    chunk["speaker"] += list(sd.get_labels(ts))
                 new_chunks.append(chunk)
                 pointer_chunk += 1
             else:
+                chunk["speaker"] += list(sd.get_labels(ts))
                 if ts.end >= end:
                     new_chunks.append(chunk)
                     pointer_chunk += 1
                 else:
-                    chunk["speaker"].update(sd.get_labels(ts))
                     pointer_ts += 1
         for i in new_chunks:
             if "speaker" in i:
-
+                if self.return_unique_speaker:
+                    i["speaker"] = [i["speaker"][0]]
+                else:
+                    i["speaker"] = list(set(i["speaker"]))
             else:
                 i["speaker"] = []
         outputs["chunks"] = new_chunks
         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
         outputs["speakers"] = sd.labels()
+        outputs.pop("audio_array")
+        for s in outputs["speakers"]:
+            outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
+            outputs[f"chunks/{s}"] = [c for c in outputs["chunks"] if s in c["speaker"]]
         return outputs

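Below is a minimal usage sketch for the pipeline after this change; it is not part of the commit. The repository id, the audio file name, and the assumption that extra keyword arguments passed to transformers.pipeline() reach the custom pipeline's constructor are placeholders, while the return_unique_speaker flag and the speakers, text/<speaker>, and chunks/<speaker> output keys come from the code above.

from transformers import pipeline

# Sketch only: model id and audio path are placeholders, not taken from this commit.
pipe = pipeline(
    "automatic-speech-recognition",
    model="kotoba-tech/kotoba-whisper-v2.2",  # assumed repository hosting this custom pipeline
    trust_remote_code=True,                   # load the KotobaWhisperPipeline class from the Hub
    return_unique_speaker=True,               # new flag: keep a single speaker label per chunk
)

result = pipe("sample_audio.wav")             # placeholder audio file
print(result["speakers"])                     # e.g. ["SPEAKER_00", "SPEAKER_01"]
for s in result["speakers"]:
    print(s, result[f"text/{s}"])             # per-speaker transcript keys added in this commit

When return_unique_speaker is left at its default (False), each chunk instead keeps the deduplicated list of all overlapping speaker labels, as the list(set(...)) branch in the postprocess code shows.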