AlexHung29629 committed
Commit bf1a94f · verified · 1 Parent(s): b7af19e

Update ultravox_processing.py

Files changed (1)
  1. ultravox_processing.py +26 -23
ultravox_processing.py CHANGED
@@ -20,6 +20,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         "Wav2Vec2Processor",
         "SeamlessM4TFeatureExtractor",
         "WhisperProcessor",
+        "Wav2Vec2BertProcessor",
     )
     tokenizer_class = (
         "PreTrainedTokenizer",
@@ -128,12 +129,27 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         """
         # TODO: Add support for multiple audio and text inputs.
         data = {}
+        audio_embed_frames = 0
         if audio is not None and len(audio) > 0:
+            if self.audio_padding == "max_length":
+                # 30 seconds is the expected length for Whisper
+                assert sampling_rate is not None, "Sampling rate must be provided."
+                audio_len = [30 * sampling_rate] * len(audio)
+            else:
+                audio_len = [a.shape[-1] for a in audio]
+            # It's guaranteed that the number of frames is less than or equal to this amount.
+            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
+            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
+            nb_encoder_frames = [int(round(a / self.encoder_ds_factor + 1e-4)) for a in audio_len]
+            audio_embed_frames = [int(np.ceil(n / self.stack_factor)) for n in nb_encoder_frames]
+            data["audio_token_len"] = audio_embed_frames
+
             # Main audio processing. The processor is model-specific.
             x = self.audio_processor(
                 audio,
                 sampling_rate=sampling_rate,
                 padding="longest",
+                max_length=max(audio_len),
                 return_attention_mask=True,
                 **kwargs,
             )
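
For intuition, here is a minimal sketch of the new length estimate. It assumes Whisper-style values of `encoder_ds_factor=320` and `stack_factor=8`; both actually come from the processor config, so the constants below are illustrative:

```python
import numpy as np

ENCODER_DS_FACTOR = 320  # assumed: 16 kHz samples -> 50 Hz encoder frames (Whisper)
STACK_FACTOR = 8         # assumed: encoder frames stacked per audio embedding

def estimate_audio_token_len(num_samples: int) -> int:
    # Mirrors the hunk above: samples -> encoder frames -> stacked embeddings.
    nb_encoder_frames = int(round(num_samples / ENCODER_DS_FACTOR + 1e-4))
    return int(np.ceil(nb_encoder_frames / STACK_FACTOR))

# 30 s at 16 kHz (Whisper's max_length padding) -> 480_000 samples
# -> 1500 encoder frames -> ceil(1500 / 8) = 188 audio tokens.
assert estimate_audio_token_len(30 * 16_000) == 188
```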
@@ -142,21 +158,13 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             else:
                 data["audio_values"] = x.input_values
             data["audio_len"] = x.attention_mask.sum(-1) - 1
-            def cnn_out_len(in_len, kernel, stride=1, padding=1, dilation=1):
-                return np.floor((in_len + (2*padding) - (dilation * (kernel - 1)) - 1)/stride + 1)
-            def stack_frame_len(T):
-                T_pad = ((T + self.stack_factor - 1) // self.stack_factor) * self.stack_factor
-                return ((T_pad + self.stack_factor) // self.stack_factor).astype(int)
-            nb_encoder_frames = cnn_out_len(cnn_out_len(data["audio_len"], kernel=3), kernel=3, stride=2)
-            data["audio_token_len"] = stack_frame_len(nb_encoder_frames)

         if text is not None:
-            assert isinstance(
-                text, list
-            ), "Text must be a list."
-            processed_text = []
+            #assert isinstance(
+            #    text, str
+            #), "Text must be a string. Batch mode not supported yet."
             data["audio_token_start_idx"] = []
-            for i, t in enumerate(text):
+            for t in text:
                 assert self.audio_placeholder in t
                 if "audio_token_len" not in data:
                     raise ValueError(
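
For reference, the deleted `cnn_out_len` was the standard Conv1d output-length formula; applied twice with kernel 3, padding 1, and strides 1 and 2 (matching Whisper's two front-end convolutions), it gives the exact encoder frame count that the `encoder_ds_factor` division above now approximates. (Note that the old `stack_frame_len` returned `ceil(T / stack_factor) + 1`, one frame more than the new `np.ceil` estimate.) A sketch of what the deleted helper computed:

```python
import math

def cnn_out_len(in_len: int, kernel: int, stride: int = 1,
                padding: int = 1, dilation: int = 1) -> int:
    # Standard Conv1d output-length formula, as in the deleted helper.
    return math.floor((in_len + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

# Whisper's front end: conv1 (kernel 3, stride 1) then conv2 (kernel 3, stride 2).
# 30 s of audio -> 3000 mel frames -> 3000 -> 1500 encoder frames,
# in agreement with 480_000 samples / encoder_ds_factor(320) = 1500.
assert cnn_out_len(cnn_out_len(3000, kernel=3), kernel=3, stride=2) == 1500
```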
@@ -165,24 +173,19 @@ class UltravoxProcessor(transformers.ProcessorMixin):

                 start_idx = len(
                     self.tokenizer.encode(
-                        t.split(self.audio_placeholder)[0],
+                        t[: t.index(self.audio_placeholder)],
                         add_special_tokens=False,
                     )
                 )
                 data["audio_token_start_idx"].append(start_idx)

-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe <|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                t = t.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * data["audio_token_len"][i],
-                )
-                processed_text.append(t)
-
+            # Replace the audio placeholder with the audio token.
+            # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+            # where the number of </s> is the number of audio frames.
+            text = [t.replace(self.audio_placeholder, self.audio_token_replacement * data["audio_token_len"][i]) for i, t in enumerate(text)]

             # Special tokens like BOS should already have been added by the caller.
-            data.update(self.tokenizer(processed_text, add_special_tokens=False, padding='longest', **kwargs))
+            data.update(self.tokenizer(text, add_special_tokens=False, padding=True, **kwargs))

         return transformers.BatchFeature(data=data, tensor_type=return_tensors)
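
To make the text path concrete, a self-contained sketch of the placeholder expansion above; `<|audio|>` and `</s>` stand in for `self.audio_placeholder` and `self.audio_token_replacement`:

```python
audio_placeholder = "<|audio|>"
audio_token_replacement = "</s>"
audio_token_len = [4]  # one sample with 4 audio embeddings

text = ["Transcribe\n<|audio|>"]

# audio_token_start_idx is the token count of the prefix before the placeholder.
prefix = text[0][: text[0].index(audio_placeholder)]  # "Transcribe\n"

# Expand each placeholder to audio_token_len[i] replacement tokens.
text = [
    t.replace(audio_placeholder, audio_token_replacement * audio_token_len[i])
    for i, t in enumerate(text)
]
assert text == ["Transcribe\n</s></s></s></s>"]
```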
 
 
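Putting the hunks together, a hypothetical batched call after this commit (the checkpoint id is a placeholder and the call shape is inferred from the diff; output keys follow the code above):

```python
import numpy as np
import transformers

# Placeholder repo id; substitute the checkpoint this file ships with.
processor = transformers.AutoProcessor.from_pretrained(
    "fixie-ai/ultravox", trust_remote_code=True
)

texts = ["Transcribe\n<|audio|>", "Translate\n<|audio|>"]
audios = [
    np.zeros(16_000, dtype=np.float32),  # 1 s of 16 kHz audio
    np.zeros(32_000, dtype=np.float32),  # 2 s
]

batch = processor(text=texts, audio=audios, sampling_rate=16_000, return_tensors="pt")
# Per the diff, `batch` carries audio_values, audio_len, audio_token_len,
# and audio_token_start_idx, plus input_ids / attention_mask from the tokenizer.
```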