#!/usr/bin/env python3
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
import torch
import re
import numpy as np
from datasets import load_dataset

device = "cpu"
dtype = torch.float32

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", low_cpu_mem_usage=True, torch_dtype=dtype
)
model.to(device)

STREAMING_INTERVAL = 0.33  # in seconds
SAMPLING_RATE = 16_000
INTERVAL_LENGTH = int(STREAMING_INTERVAL * SAMPLING_RATE)  # samples per chunk

ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
audio_array = np.concatenate([x["array"] for x in ds["audio"]])

# fake streaming by decoding the growing audio buffer every STREAMING_INTERVAL seconds
start_idx = 0
fully_decoded = ""
for end_idx in range(INTERVAL_LENGTH, audio_array.shape[-1], INTERVAL_LENGTH):
    input_audio = audio_array[start_idx:end_idx]
    # windows longer than 30 seconds need long-form handling: pad to the longest
    # input instead of truncating, and return the attention mask
    processor_kwargs = (
        {"padding": "longest", "truncation": False, "return_attention_mask": True}
        if input_audio.shape[0] / SAMPLING_RATE > 30.0
        else {}
    )
    inputs = processor(
        input_audio,
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt",
        **processor_kwargs,
    )
    inputs = inputs.to(dtype=dtype, device=device)
    # predict timestamp tokens so segment boundaries can be located in the output
    tokens = model.generate(
        **inputs,
        return_timestamps=True,
    )
    sequences = processor.batch_decode(tokens, decode_with_timestamps=True)[0]
    sequences_no_special = processor.batch_decode(tokens, skip_special_tokens=True)[0]
    # two consecutive timestamp tokens (e.g. <|5.44|><|5.44|>) mark a completed segment
    regex_search = re.findall(r"<\|[\d\.]+\|><\|[\d\.]+\|>", sequences)
    regex_split = re.split(r"<\|[\d\.]+\|><\|[\d\.]+\|>", sequences)
    # cut the input audio only once at least two segment boundaries and more than
    # five words past the first boundary have been detected
    if len(regex_search) > 1 and len("".join(regex_split[1:]).split()) > 5:
        # convert the first boundary's timestamp (in seconds) to a sample index
        cut_idx = int(SAMPLING_RATE * float(regex_search[0].split("|><|")[0][2:]))
        start_idx += cut_idx
        # note: the whole window transcript is committed here, so text after the
        # cut point is re-decoded (and may repeat) in the next window
        fully_decoded += sequences_no_special
        sequences_no_special = ""
    print(fully_decoded + sequences_no_special)
    print(f"Elapsed time: {end_idx / SAMPLING_RATE}s")
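
# ---------------------------------------------------------------------------
# Illustration of the boundary-detection logic above. This is a minimal sketch
# using a hand-written example string (not real model output): transcripts
# decoded with timestamps delimit segments with tokens like <|5.44|>, so a
# completed segment boundary appears as two consecutive timestamp tokens.
example = (
    "<|0.00|> Mister Quilter is the apostle of the middle classes"
    "<|5.44|><|5.44|> and we are glad to welcome his gospel.<|8.24|>"
)
boundaries = re.findall(r"<\|[\d\.]+\|><\|[\d\.]+\|>", example)
segments = re.split(r"<\|[\d\.]+\|><\|[\d\.]+\|>", example)
print(boundaries)  # ['<|5.44|><|5.44|>'] -> one completed segment boundary
print(segments)  # text before and after the boundary
# the first boundary's start time converts to the sample index used for the cut
cut_seconds = float(boundaries[0].split("|><|")[0][2:])
print(int(SAMPLING_RATE * cut_seconds))  # 87040 samples, i.e. 5.44 s at 16 kHz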