Update ultravox_processing.py
ultravox_processing.py
CHANGED  +18 -19
@@ -134,15 +134,15 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             if self.audio_padding == "max_length":
                 # 30 seconds is the expected length for Whisper
                 assert sampling_rate is not None, "Sampling rate must be provided."
-                audio_len = 30 * sampling_rate
+                audio_len = [30 * sampling_rate] * len(audio)
             else:
-                audio_len = audio.shape[-1]
+                audio_len = [a.shape[-1] for a in audio]
             # It's guaranteed that the number of frames is less than or equal to this amount.
             # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
             # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
-            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
-            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
-            data["audio_token_len"] = [audio_embed_frames]
+            nb_encoder_frames = [int(round(a / self.encoder_ds_factor + 1e-4)) for a in audio_len]
+            audio_embed_frames = [int(np.ceil(n / self.stack_factor)) for n in nb_encoder_frames]
+            data["audio_token_len"] = audio_embed_frames
 
             # Main audio processing. The processor is model-specific.
             x = self.audio_processor(
@@ -160,10 +160,12 @@
             data["audio_len"] = x.attention_mask.sum(-1) - 1
 
         if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            if self.audio_placeholder in text:
+            #assert isinstance(
+            #    text, str
+            #), "Text must be a string. Batch mode not supported yet."
+            data["audio_token_start_idx"] = []
+            for t in text:
+                assert self.audio_placeholder in t
                 if "audio_token_len" not in data:
                     raise ValueError(
                         f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
@@ -171,19 +173,16 @@
 
                 start_idx = len(
                     self.tokenizer.encode(
-                        text[: text.index(self.audio_placeholder)],
+                        t[: t.index(self.audio_placeholder)],
                         add_special_tokens=False,
                     )
                 )
-                data["audio_token_start_idx"] = [start_idx]
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                text = text.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
-                )
+                data["audio_token_start_idx"].append(start_idx)
+
+            # Replace the audio placeholder with the audio token.
+            # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+            # where the number of </s> is the number of audio frames.
+            text = [t.replace(self.audio_placeholder, self.audio_token_replacement * data["audio_token_len"][i]) for i, t in enumerate(text)]
 
             # Special tokens like BOS should already have been added by the caller.
             data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
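
For reference, the per-clip token-length arithmetic in the first hunk can be sketched standalone. This is a minimal illustration, not the processor itself: the two factors are assumptions (a Whisper-style 320-sample encoder hop and a stack factor of 8); in the real code they come from the processor's encoder_ds_factor and stack_factor attributes.

import numpy as np

# Assumed values for illustration only; the processor reads these from its config.
ENCODER_DS_FACTOR = 320
STACK_FACTOR = 8

def audio_token_len(num_samples: int) -> int:
    # One encoder frame per ENCODER_DS_FACTOR input samples; the small epsilon
    # guards against float round-off at exact multiples, as in the diff.
    nb_encoder_frames = int(round(num_samples / ENCODER_DS_FACTOR + 1e-4))
    # StackAudioFrames groups STACK_FACTOR encoder frames into one audio
    # embedding, so the token count is the ceiling of that ratio.
    return int(np.ceil(nb_encoder_frames / STACK_FACTOR))

# A 10 s clip at 16 kHz: 160_000 samples -> 500 encoder frames -> 63 tokens.
print(audio_token_len(10 * 16_000))  # 63
# With audio_padding="max_length" every clip counts as 30 s: 1500 frames -> 188 tokens.
print(audio_token_len(30 * 16_000))  # 188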
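
The per-prompt loop in the last two hunks can likewise be sketched in isolation. This is a hedged stand-alone version: gpt2's tokenizer stands in for self.tokenizer, the placeholder and replacement strings mirror the comment in the diff, and the per-clip token counts are made-up values from the length math above.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
audio_placeholder = "<|audio|>"
audio_token_replacement = "</s>"

text = ["Transcribe\n<|audio|>", "Answer the question in the clip.\n<|audio|>"]
audio_token_len = [63, 44]  # assumed per-clip counts, one per prompt

audio_token_start_idx = []
for t in text:
    assert audio_placeholder in t
    # Number of text tokens before the placeholder = index where the audio embeddings start.
    start_idx = len(
        tokenizer.encode(t[: t.index(audio_placeholder)], add_special_tokens=False)
    )
    audio_token_start_idx.append(start_idx)

# Expand each placeholder into one replacement token per audio embedding frame.
text = [
    t.replace(audio_placeholder, audio_token_replacement * audio_token_len[i])
    for i, t in enumerate(text)
]
print(audio_token_start_idx)  # prompt tokens before each clip, tokenizer-dependent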
|