Update ultravox_processing.py
ultravox_processing.py  CHANGED  (+26 −23)
@@ -20,6 +20,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         "Wav2Vec2Processor",
         "SeamlessM4TFeatureExtractor",
         "WhisperProcessor",
+        "Wav2Vec2BertProcessor",
     )
     tokenizer_class = (
         "PreTrainedTokenizer",
@@ -128,12 +129,27 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         """
         # TODO: Add support for multiple audio and text inputs.
         data = {}
+        audio_embed_frames = 0
         if audio is not None and len(audio) > 0:
+            if self.audio_padding == "max_length":
+                # 30 seconds is the expected length for Whisper
+                assert sampling_rate is not None, "Sampling rate must be provided."
+                audio_len = [30 * sampling_rate] * len(audio)
+            else:
+                audio_len = [a.shape[-1] for a in audio]
+            # It's guaranteed that the number of frames is less than or equal to this amount.
+            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
+            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
+            nb_encoder_frames = [int(round(a / self.encoder_ds_factor + 1e-4)) for a in audio_len]
+            audio_embed_frames = [int(np.ceil(n / self.stack_factor)) for n in nb_encoder_frames]
+            data["audio_token_len"] = audio_embed_frames
+
             # Main audio processing. The processor is model-specific.
             x = self.audio_processor(
                 audio,
                 sampling_rate=sampling_rate,
                 padding="longest",
+                max_length=max(audio_len),
                 return_attention_mask=True,
                 **kwargs,
             )
@@ -142,21 +158,13 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             else:
                 data["audio_values"] = x.input_values
                 data["audio_len"] = x.attention_mask.sum(-1) - 1
-            def cnn_out_len(in_len, kernel, stride=1, padding=1, dilation=1):
-                return np.floor((in_len + (2*padding) - (dilation * (kernel - 1)) - 1)/stride + 1)
-            def stack_frame_len(T):
-                T_pad = ((T + self.stack_factor - 1) // self.stack_factor) * self.stack_factor
-                return ((T_pad + self.stack_factor) // self.stack_factor).astype(int)
-            nb_encoder_frames = cnn_out_len(cnn_out_len(data["audio_len"], kernel=3), kernel=3, stride=2)
-            data["audio_token_len"] = stack_frame_len(nb_encoder_frames)
 
         if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            processed_text = []
+            #assert isinstance(
+            #    text, str
+            #), "Text must be a string. Batch mode not supported yet."
             data["audio_token_start_idx"] = []
-            for
+            for t in text:
                 assert self.audio_placeholder in t
                 if "audio_token_len" not in data:
                     raise ValueError(
@@ -165,24 +173,19 @@ class UltravoxProcessor(transformers.ProcessorMixin):
 
                 start_idx = len(
                     self.tokenizer.encode(
-                        t.
+                        t[: t.index(self.audio_placeholder)],
                        add_special_tokens=False,
                     )
                 )
                 data["audio_token_start_idx"].append(start_idx)
 
-
-
-
-
-                    self.audio_placeholder,
-                    self.audio_token_replacement * data["audio_token_len"][i],
-                )
-                processed_text.append(t)
-
+            # Replace the audio placeholder with the audio token.
+            # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+            # where the number of </s> is the number of audio frames.
+            text = [t.replace(self.audio_placeholder, self.audio_token_replacement * data["audio_token_len"][i]) for i, t in enumerate(text)]
 
         # Special tokens like BOS should already have been added by the caller.
-        data.update(self.tokenizer(
+        data.update(self.tokenizer(text, add_special_tokens=False, padding=True, **kwargs))
 
         return transformers.BatchFeature(data=data, tensor_type=return_tensors)
 
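The added length computation can be checked in isolation: the number of audio tokens is the number of encoder output frames (raw audio samples divided by the encoder downsampling factor), divided by the stack factor and rounded up. Below is a minimal sketch of the same arithmetic, assuming Whisper-style values of encoder_ds_factor = 320 samples per encoder frame, stack_factor = 8, and a 16 kHz sampling rate; none of these values are stated in this diff.

import math

def audio_token_len(num_samples: int, encoder_ds_factor: int = 320, stack_factor: int = 8) -> int:
    # Mirrors the added processor code: samples -> encoder frames -> stacked audio tokens.
    # The defaults are assumed Whisper/Ultravox-style values, not taken from this diff.
    nb_encoder_frames = int(round(num_samples / encoder_ds_factor + 1e-4))
    return int(math.ceil(nb_encoder_frames / stack_factor))

# With audio_padding == "max_length", audio_len is fixed at 30 s of samples per clip:
print(audio_token_len(30 * 16000))  # 480000 samples -> 1500 frames -> 188 tokens

Because the estimate rounds up, it can over-estimate for encoders that downsample slightly differently, which is why the diff's own comment notes that StackAudioFrames pads the audio embeddings so an over-estimation does not cause issues.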
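The text-side change can also be exercised on its own: audio_token_start_idx counts the tokens that precede the placeholder in each prompt, and the placeholder is then expanded into audio_token_len copies of the replacement token. A toy sketch follows, with a whitespace split standing in for self.tokenizer.encode and with "<|audio|>" / "</s>" mirroring self.audio_placeholder and self.audio_token_replacement from the diff's comment; it is illustrative only, not the processor's actual tokenization.

# Toy stand-ins; the real processor uses its tokenizer and configured special tokens.
audio_placeholder = "<|audio|>"
audio_token_replacement = "</s>"
audio_token_len = [3]  # one entry per text, as computed from the audio above

text = ["Transcribe\n<|audio|>"]
# Whitespace split is only a stand-in for counting tokens before the placeholder.
start_idx = [len(t[: t.index(audio_placeholder)].split()) for t in text]
text = [
    t.replace(audio_placeholder, audio_token_replacement * audio_token_len[i])
    for i, t in enumerate(text)
]
print(start_idx)  # [1] with the toy whitespace "tokenizer"
print(text)       # ['Transcribe\n</s></s></s>']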