asahi417 committed
Commit 986454a · verified · 1 Parent(s): 8069744

Upload KotobaWhisperPipeline

Files changed (1)
  1. kotoba_whisper.py +22 -6
kotoba_whisper.py CHANGED
@@ -49,6 +49,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                  device: Union[int, "torch.device"] = None,
                  device_diarizarization: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
+                 return_unique_speaker: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
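The new return_unique_speaker flag is consumed in postprocess below: when set, each chunk keeps only its first overlapping speaker label instead of the de-duplicated list of all overlapping labels. A minimal usage sketch, assuming the file is shipped as a custom pipeline on the Hub (the model id is illustrative) and relying on transformers forwarding extra keyword arguments to the pipeline constructor:

from transformers import pipeline

# Model id is an assumption for illustration; any repo that registers
# KotobaWhisperPipeline as its custom pipeline would behave the same way.
pipe = pipeline(
    model="kotoba-tech/kotoba-whisper-v2.2",
    trust_remote_code=True,
    return_unique_speaker=True,  # flag introduced in this commit
)
result = pipe("sample_audio.wav")
print(result["chunks"][0]["speaker"])  # single-element list, e.g. ["SPEAKER_00"]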
@@ -58,6 +59,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         if type(device_diarizarization) is str:
             device_diarizarization = torch.device(device_diarizarization)
         self.model_speaker_diarization = SpeakerDiarization(model_diarizarization, device_diarizarization)
+        self.return_unique_speaker = return_unique_speaker
         super().__init__(
             model=model,
             feature_extractor=feature_extractor,
@@ -192,6 +194,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         )
 
         # custom processing for Whisper timestamps and word-level timestamps
+        generate_kwargs["return_timestamps"] = True
         if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
             generate_kwargs["input_features"] = inputs
         else:
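Forcing return_timestamps to True guarantees that every decoded chunk carries a (start, end) timestamp pair; the speaker-assignment loop in postprocess aligns chunks against the diarization timeline through exactly these pairs. Illustratively, the chunk structure the later code relies on (values made up):

# Illustrative shape of the ASR output consumed by postprocess().
outputs = {
    "text": "こんにちは。お元気ですか。",
    "chunks": [
        {"timestamp": (0.0, 1.8), "text": "こんにちは。"},
        {"timestamp": (1.8, 3.5), "text": "お元気ですか。"},
    ],
}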
@@ -215,7 +218,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                     *args,
                     **kwargs):
         assert len(model_outputs) > 0
-        audio_array = list(model_outputs)[0].pop("audio_array")
+        audio_array = list(model_outputs)[0]["audio_array"]
         sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
         timelines = sd.get_timeline()
         outputs = super().postprocess(
@@ -229,35 +232,48 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         new_chunks = []
         while True:
             if pointer_ts == len(timelines):
-                new_chunks += outputs["chunks"][pointer_chunk:]
+                ts = timelines[-1]
+                for chunk in outputs["chunks"][pointer_chunk:]:
+                    chunk["speaker"] = sd.get_labels(ts)
+                    new_chunks.append(chunk)
                 break
             if pointer_chunk == len(outputs["chunks"]):
                 break
             ts = timelines[pointer_ts]
+
             chunk = outputs["chunks"][pointer_chunk]
             if "speaker" not in chunk:
-                chunk["speaker"] = set()
+                chunk["speaker"] = []
+
             start, end = chunk["timestamp"]
             if ts.end <= start:
-                chunk["speaker"].update(sd.get_labels(ts))
                 pointer_ts += 1
             elif end <= ts.start:
+                if len(chunk["speaker"]) == 0:
+                    chunk["speaker"] += list(sd.get_labels(ts))
                 new_chunks.append(chunk)
                 pointer_chunk += 1
             else:
+                chunk["speaker"] += list(sd.get_labels(ts))
                 if ts.end >= end:
                     new_chunks.append(chunk)
                     pointer_chunk += 1
                 else:
-                    chunk["speaker"].update(sd.get_labels(ts))
                     pointer_ts += 1
         for i in new_chunks:
             if "speaker" in i:
-                i["speaker"] = list(i["speaker"])
+                if self.return_unique_speaker:
+                    i["speaker"] = [i["speaker"][0]]
+                else:
+                    i["speaker"] = list(set(i["speaker"]))
             else:
                 i["speaker"] = []
         outputs["chunks"] = new_chunks
         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
         outputs["speakers"] = sd.labels()
+        outputs.pop("audio_array")
+        for s in outputs["speakers"]:
+            outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
+            outputs[f"chunks/{s}"] = [c for c in outputs["chunks"] if s in c["speaker"]]
         return outputs
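The rewritten loop is a two-pointer sweep over the diarization timeline and the transcription chunks: whichever interval ends first advances its pointer, and every overlapping segment contributes its speaker label to the current chunk. A self-contained sketch of the same logic, with plain (start, end, label) tuples standing in for the pyannote timeline objects (function name and data are illustrative, not part of the commit):

def assign_speakers(chunks, segments):
    # chunks: [{"timestamp": (start, end), "text": str}, ...] in time order
    # segments: [(start, end, label), ...] diarization segments in time order
    new_chunks, p_seg, p_chunk = [], 0, 0
    while True:
        if p_seg == len(segments):
            # Timeline exhausted: label the remaining chunks with the last
            # segment's speaker, as the updated postprocess() does.
            last_label = segments[-1][2]
            for chunk in chunks[p_chunk:]:
                chunk["speaker"] = [last_label]
                new_chunks.append(chunk)
            break
        if p_chunk == len(chunks):
            break
        seg_start, seg_end, label = segments[p_seg]
        chunk = chunks[p_chunk]
        chunk.setdefault("speaker", [])
        start, end = chunk["timestamp"]
        if seg_end <= start:
            p_seg += 1                          # segment lies before the chunk: skip it
        elif end <= seg_start:
            if not chunk["speaker"]:
                chunk["speaker"].append(label)  # fallback: nearest upcoming speaker
            new_chunks.append(chunk)
            p_chunk += 1
        else:
            chunk["speaker"].append(label)      # overlap: record the speaker
            if seg_end >= end:
                new_chunks.append(chunk)        # chunk fully covered: emit it
                p_chunk += 1
            else:
                p_seg += 1                      # segment ends first: advance it
    return new_chunks

# Example: two chunks, two diarization segments (illustrative values).
chunks = [{"timestamp": (0.0, 1.8), "text": "こんにちは。"},
          {"timestamp": (1.8, 3.5), "text": "お元気ですか。"}]
segments = [(0.0, 1.7, "SPEAKER_00"), (1.9, 3.6, "SPEAKER_01")]
print(assign_speakers(chunks, segments))
# -> first chunk labelled ["SPEAKER_00"], second ["SPEAKER_01"]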
 
 
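With the per-speaker aggregation added at the end of postprocess, the returned dictionary also exposes one text/<label> and one chunks/<label> entry per detected speaker, in addition to the merged transcript. An illustrative result for a two-speaker recording (labels and text made up):

{
    "text": "こんにちは。お元気ですか。",
    "speakers": ["SPEAKER_00", "SPEAKER_01"],
    "chunks": [
        {"timestamp": (0.0, 1.8), "text": "こんにちは。", "speaker": ["SPEAKER_00"]},
        {"timestamp": (1.8, 3.5), "text": "お元気ですか。", "speaker": ["SPEAKER_01"]},
    ],
    "text/SPEAKER_00": "こんにちは。",
    "chunks/SPEAKER_00": [{"timestamp": (0.0, 1.8), "text": "こんにちは。", "speaker": ["SPEAKER_00"]}],
    "text/SPEAKER_01": "お元気ですか。",
    "chunks/SPEAKER_01": [{"timestamp": (1.8, 3.5), "text": "お元気ですか。", "speaker": ["SPEAKER_01"]}],
}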