|
|
|
from transformers import ( |
|
WhisperForConditionalGeneration, |
|
WhisperProcessor, |
|
) |
|
import torch |
|
import re |
|
import numpy as np |
|
from datasets import load_dataset |
|
|
|
# Inference runs on the CPU in full precision. ``dtype`` is used both for the
# model weights and for the processor's output features further down.
device, dtype = "cpu", torch.float32
|
|
|
# Load the Whisper processor (used below both to extract input features and
# to decode generated tokens) and the matching model, in ``dtype`` precision.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
)
# nn.Module.to moves parameters in place and returns the same module.
model = model.to(device)
|
|
|
# Streaming parameters: each streaming step appends INTERVAL_LENGTH new samples
# (STREAMING_INTERVAL seconds) to the audio window being transcribed.
STREAMING_INTERVAL = 0.33  # seconds of new audio per streaming step
SAMPLING_RATE = 16_000  # Hz; the rate handed to the Whisper processor
# round() rather than int(): a float product that lands just below a whole
# number (e.g. 0.29 * 100 == 28.999999999999996) would otherwise be truncated
# and silently drop a sample from every interval.
INTERVAL_LENGTH = round(STREAMING_INTERVAL * SAMPLING_RATE)
|
|
|
|
|
# Build one long test signal by placing every clip of the
# ``librispeech_asr_dummy`` "clean" validation split back to back.
# NOTE(review): assumes the clips are already sampled at SAMPLING_RATE -- confirm.
ds = load_dataset(
    path="hf-internal-testing/librispeech_asr_dummy",
    name="clean",
    split="validation",
)
audio_array = np.concatenate(
    [clip["array"] for clip in ds["audio"]]
)
|
|
|
|
|
# Simulated streaming: transcribe the window audio_array[start_idx:end_idx],
# growing end_idx by INTERVAL_LENGTH samples per step. Once the transcription
# of the window has at least two segment boundaries and enough words after the
# first one, the first segment is treated as stable. Its text is committed to
# ``fully_decoded`` and start_idx is moved past it, so no audio is committed twice.

# A segment boundary in timestamp-decoded text, e.g. "<|6.56|><|6.56|>": the end
# timestamp of one segment immediately followed by the start of the next one.
# The capture group is the end time (seconds) of the earlier segment.
SEGMENT_BOUNDARY_RE = re.compile(r"<\|([\d.]+)\|><\|[\d.]+\|>")
# Any special or timestamp token, e.g. "<|startoftranscript|>" or "<|0.00|>".
SPECIAL_TOKEN_RE = re.compile(r"<\|[^|>]*\|>")
# A first segment is committed only when MORE than this many words follow it.
MIN_WORDS_AFTER_FIRST_SEGMENT = 5

start_idx = 0
end_idx = 0  # defined up front so the loop-final state exists even for short audio
fully_decoded = ""  # committed (stable) transcript
sequences_no_special = ""  # current, not-yet-committed hypothesis
total_samples = audio_array.shape[-1]

# Step one interval past the end and clamp, so the trailing partial interval is
# transcribed as well (range(..., total_samples, ...) would never reach it).
for chunk_end in range(
    INTERVAL_LENGTH, total_samples + INTERVAL_LENGTH, INTERVAL_LENGTH
):
    end_idx = min(chunk_end, total_samples)
    input_audio = audio_array[start_idx:end_idx]

    # Windows longer than 30 s are passed untruncated, padded to the longest
    # input, with an attention mask; shorter ones use the processor defaults.
    processor_kwargs = (
        {"padding": "longest", "truncation": False, "return_attention_mask": True}
        if input_audio.shape[0] / SAMPLING_RATE > 30.0
        else {}
    )
    inputs = processor(
        input_audio,
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt",
        **processor_kwargs,
    )
    inputs = inputs.to(dtype=dtype, device=device)
    tokens = model.generate(**inputs, return_timestamps=True)

    sequences = processor.batch_decode(tokens, decode_with_timestamps=True)[0]
    sequences_no_special = processor.batch_decode(tokens, skip_special_tokens=True)[0]

    boundaries = list(SEGMENT_BOUNDARY_RE.finditer(sequences))
    if len(boundaries) > 1:
        first = boundaries[0]
        remainder = SPECIAL_TOKEN_RE.sub("", sequences[first.end():])
        if len(remainder.split()) > MIN_WORDS_AFTER_FIRST_SEGMENT:
            # Timestamps are relative to the start of the current window.
            start_idx += int(SAMPLING_RATE * float(first.group(1)))
            # Commit only the first segment. Everything after it is re-transcribed
            # from the new window start, so committing it now would duplicate it.
            fully_decoded += SPECIAL_TOKEN_RE.sub("", sequences[: first.start()])
            sequences_no_special = remainder

    print(fully_decoded + sequences_no_special)
    print(f"Passed time: {end_idx / SAMPLING_RATE}")
|
|