AlexHung29629 committed · verified
Commit 3bec836 · 1 Parent(s): 12908ea

Update ultravox_processing.py

Files changed (1)
  1. ultravox_processing.py +18 -19
ultravox_processing.py CHANGED
@@ -134,15 +134,15 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         if self.audio_padding == "max_length":
             # 30 seconds is the expected length for Whisper
             assert sampling_rate is not None, "Sampling rate must be provided."
-            audio_len = 30 * sampling_rate
+            audio_len = [30 * sampling_rate] * len(audio)
         else:
-            audio_len = max([a.shape[-1] for a in audio])
+            audio_len = [a.shape[-1] for a in audio]
         # It's guaranteed that the number of frames is less than or equal to this amount.
         # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
         # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
-        nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
-        audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
-        data["audio_token_len"] = [audio_embed_frames]
+        nb_encoder_frames = [int(round(a / self.encoder_ds_factor + 1e-4)) for a in audio_len]
+        audio_embed_frames = [int(np.ceil(n / self.stack_factor)) for n in nb_encoder_frames]
+        data["audio_token_len"] = audio_embed_frames
 
         # Main audio processing. The processor is model-specific.
         x = self.audio_processor(
@@ -160,10 +160,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         data["audio_len"] = x.attention_mask.sum(-1) - 1
 
         if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            if self.audio_placeholder in text:
+            #assert isinstance(
+            #    text, str
+            #), "Text must be a string. Batch mode not supported yet."
+            data["audio_token_start_idx"] = []
+            for t in text:
+                assert self.audio_placeholder in t
                 if "audio_token_len" not in data:
                     raise ValueError(
                         f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
@@ -171,19 +173,16 @@
 
                 start_idx = len(
                     self.tokenizer.encode(
-                        text[: text.index(self.audio_placeholder)],
+                        t[: t.index(self.audio_placeholder)],
                         add_special_tokens=False,
                     )
                 )
-                data["audio_token_start_idx"] = [start_idx]
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                text = text.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
-                )
+                data["audio_token_start_idx"].append(start_idx)
+
+            # Replace the audio placeholder with the audio token.
+            # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+            # where the number of </s> is the number of audio frames.
+            text = [t.replace(self.audio_placeholder, self.audio_token_replacement * data["audio_token_len"][i]) for i, t in enumerate(text)]
 
             # Special tokens like BOS should already have been added by the caller.
             data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
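For reference, a minimal sketch of the per-clip token-length arithmetic that the first hunk vectorizes (one entry per audio clip instead of a single scalar). The values used below for `encoder_ds_factor` and `stack_factor` are illustrative assumptions, not values read from this repo's config.

```python
import numpy as np

# Illustrative values only -- the real ones come from the processor attributes.
encoder_ds_factor = 320  # assumed: audio samples per encoder frame
stack_factor = 8         # assumed: encoder frames stacked per audio token
sampling_rate = 16_000

# Two clips of different lengths (in samples), as in the batched path.
audio_len = [3 * sampling_rate, 7 * sampling_rate]

# Same arithmetic as the updated hunk, applied per clip.
nb_encoder_frames = [int(round(a / encoder_ds_factor + 1e-4)) for a in audio_len]
audio_embed_frames = [int(np.ceil(n / stack_factor)) for n in nb_encoder_frames]

print(nb_encoder_frames)   # [150, 350]
print(audio_embed_frames)  # [19, 44]
```

Each entry of `audio_embed_frames` is what lands in `data["audio_token_len"]`, so every sample in the batch gets its own audio token count.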
 
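A rough usage sketch of the batched call path this change is aimed at. The checkpoint id, prompts, and dummy audio are placeholders; the keyword names (`text`, `audio`, `sampling_rate`) are assumed from how the processor's `__call__` uses them in this file. Note that each prompt must contain the `<|audio|>` placeholder, since the new loop asserts its presence.

```python
import numpy as np
import transformers

# Hypothetical checkpoint id -- replace with the repo that ships this processor.
processor = transformers.AutoProcessor.from_pretrained(
    "your-org/your-ultravox-checkpoint", trust_remote_code=True
)

sampling_rate = 16_000
audios = [
    np.zeros(2 * sampling_rate, dtype=np.float32),  # placeholder 2 s clip
    np.zeros(5 * sampling_rate, dtype=np.float32),  # placeholder 5 s clip
]
texts = [
    "Transcribe\n<|audio|>",            # every prompt needs the audio placeholder
    "Answer the question\n<|audio|>",
]

inputs = processor(text=texts, audio=audios, sampling_rate=sampling_rate)

# With this commit, these fields are intended to hold one entry per sample:
#   inputs["audio_token_len"]
#   inputs["audio_token_start_idx"]
```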