Commit df00910 · Parent(s): 1f63fcf

Create processing_whisper.py

processing_whisper.py · ADDED · +143 -0
@@ -0,0 +1,143 @@
import math

import numpy as np
from transformers import WhisperProcessor


class WhisperPrePostProcessor(WhisperProcessor):
    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)

        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        for i, idx in enumerate(batch_idx):
            chunk_start_idx = all_chunk_start_idx[idx]

            chunk_end_idx = chunk_start_idx + chunk_len

            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            chunk_lens = [chunk.shape[0] for chunk in chunks]
            strides = [
                (int(chunk_l), int(_stride_l), int(_stride_r))
                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
            ]

            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
        stride = None
        if isinstance(inputs, dict):
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            for item in self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            ):
                yield item
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                chunk_len /= sampling_rate
                stride_left /= sampling_rate
                stride_right /= sampling_rate
                output["stride"] = chunk_len, stride_left, stride_right

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}
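
A minimal usage sketch, not part of the committed file: it drives preprocess_batch on dummy audio and inspects the chunk strides it yields. The checkpoint name openai/whisper-tiny and the silent 60-second array are illustrative assumptions, and it relies on the subclass inheriting from_pretrained from WhisperProcessor.

    # Illustrative only: chunk dummy audio and look at the per-chunk strides.
    import numpy as np

    from processing_whisper import WhisperPrePostProcessor

    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-tiny")

    # 60 s of silence at 16 kHz stands in for real speech.
    audio = {"array": np.zeros(16_000 * 60, dtype=np.float32), "sampling_rate": 16_000}

    for batch in processor.preprocess_batch(audio, chunk_length_s=30, stride_length_s=5, batch_size=4):
        # Each batch holds padded log-mel features plus one
        # (chunk_len, stride_left, stride_right) tuple per chunk, all in samples.
        print(batch["input_features"].shape, batch["stride"])

In the full pipeline, each batch of features would be run through a Whisper model, the generated token IDs collected together with their strides, and the resulting list handed to postprocess, which converts the strides back to seconds and lets the tokenizer stitch the overlapping chunk transcriptions into a single text.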