Update ultravox_processing.py
ultravox_processing.py CHANGED (+26 -31)

```diff
@@ -10,7 +10,6 @@ from .ultravox_config import UltravoxConfig
 class UltravoxProcessor(transformers.ProcessorMixin):
     """
     Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
-
     Args:
         audio_processor: The audio processor for the audio encoder.
         tokenizer: The tokenizer for the language model.
@@ -100,7 +99,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
         audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
         of the above two methods for more information.
-
         Args:
             text (`str`, `List[str]`):
                 The sequence to be encoded. Sequence can be a string or (pretokenized string).
@@ -113,15 +111,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
                 you are doing.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
-
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
-
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
               `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
@@ -133,7 +128,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         """
         # TODO: Add support for multiple audio and text inputs.
         data = {}
-        audio_embed_frames = 0
         if audio is not None and len(audio) > 0:
             # Main audio processing. The processor is model-specific.
             x = self.audio_processor(
@@ -151,10 +145,10 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             def cnn_out_len(in_len, kernel, stride=1, padding=1, dilation=1):
                 return np.floor((in_len + (2*padding) - (dilation * (kernel - 1)) - 1)/stride + 1)
             def stack_frame_len(T):
-                T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-                return
-            nb_encoder_frames =
-            data["audio_token_len"] =
+                T_pad = ((T + self.stack_factor - 1) // self.stack_factor) * self.stack_factor
+                return ((T_pad + self.stack_factor) // self.stack_factor).astype(int)
+            nb_encoder_frames = cnn_out_len(cnn_out_len(data["audio_len"], kernel=3), kernel=3, stride=2)
+            data["audio_token_len"] = stack_frame_len(nb_encoder_frames)
 
         if text is not None:
             assert isinstance(
@@ -162,29 +156,30 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             ), "Text must be a list."
             processed_text = []
             data["audio_token_start_idx"] = []
-            for t in text:
-                assert self.audio_placeholder in t
-                if "audio_token_len" not in data:
-                    raise ValueError(
-                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
-                    )
-
-                start_idx = len(
-                    self.tokenizer.encode(
-                        t[: t.index(self.audio_placeholder)],
-                        add_special_tokens=False,
-                    )
-                )
-                data["audio_token_start_idx"].append(start_idx)
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe <|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                t = t.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
-                )
-                processed_text.append(t)
+            for i, t in enumerate(text):
+                assert self.audio_placeholder in t
+                if "audio_token_len" not in data:
+                    raise ValueError(
+                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
+                    )
+
+                start_idx = len(
+                    self.tokenizer.encode(
+                        t.split(self.audio_placeholder)[0],
+                        add_special_tokens=False,
+                    )
+                )
+                data["audio_token_start_idx"].append(start_idx)
+
+                # Replace the audio placeholder with the audio token.
+                # e.g. "Transcribe <|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+                # where the number of </s> is the number of audio frames.
+                t = t.replace(
+                    self.audio_placeholder,
+                    self.audio_token_replacement * data["audio_token_len"][i],
+                )
+                processed_text.append(t)
+
 
         # Special tokens like BOS should already have been added by the caller.
         data.update(self.tokenizer(processed_text, add_special_tokens=False, padding='longest', **kwargs))
```
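For context on the new length arithmetic: the hunk above derives `audio_token_len` by pushing the mel-frame count through the encoder's two convolutions and then the stack/downsample step. Below is a minimal standalone sketch of that math, assuming a Whisper-style conv stem (kernel 3, padding 1, strides 1 and 2) and a hypothetical `stack_factor` of 8; the real values come from the model config.

```python
import numpy as np

STACK_FACTOR = 8  # assumption for illustration; the processor reads self.stack_factor

def cnn_out_len(in_len, kernel, stride=1, padding=1, dilation=1):
    # Standard Conv1d output-length formula, as in the diff.
    return np.floor((in_len + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

def stack_frame_len(T):
    # Pad T up to a multiple of STACK_FACTOR, then count stacked frames.
    T_pad = ((T + STACK_FACTOR - 1) // STACK_FACTOR) * STACK_FACTOR
    return ((T_pad + STACK_FACTOR) // STACK_FACTOR).astype(int)

audio_len = np.array([3000.0])  # 30 s of audio at 100 mel frames per second
nb_encoder_frames = cnn_out_len(cnn_out_len(audio_len, kernel=3), kernel=3, stride=2)
print(nb_encoder_frames)                   # [1500.]
print(stack_frame_len(nb_encoder_frames))  # [189]
```

Note that `(T_pad + STACK_FACTOR) // STACK_FACTOR` reserves one frame beyond the plain ceiling division: with `stack_factor = 8`, 1500 encoder frames pad to 1504 and yield 189 tokens rather than 188.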
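The rewritten text loop does two things per prompt: it records `audio_token_start_idx` as the number of tokens in the text before the placeholder, and it expands the placeholder into `data["audio_token_len"][i]` copies of the replacement token. A toy illustration of the same two steps, using GPT-2's tokenizer and its EOS token as stand-ins for Ultravox's actual tokenizer and replacement token:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
audio_placeholder = "<|audio|>"
audio_token_replacement = tokenizer.eos_token      # stand-in filler token
audio_token_len = 8                                # would be data["audio_token_len"][i]

t = f"Transcribe {audio_placeholder}"

# The first audio token lands right after the tokens of the preceding text.
start_idx = len(
    tokenizer.encode(t.split(audio_placeholder)[0], add_special_tokens=False)
)

# Splice in one filler token per audio frame.
t = t.replace(audio_placeholder, audio_token_replacement * audio_token_len)

print(start_idx)  # token count of "Transcribe "
print(t)          # "Transcribe " followed by 8 filler tokens
```

This is also why the loop now raises when no audio was passed: without `audio_token_len` there is no way to know how many filler tokens to splice in.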
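The final context line ("Special tokens like BOS should already have been added by the caller") matters for this bookkeeping: both the `tokenizer.encode` prefix measurement and the final `self.tokenizer(...)` call use `add_special_tokens=False`, so the recorded start index stays aligned with the full sequence. A small sanity check of that invariant, again with GPT-2 as a hypothetical stand-in:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical stand-in
filler = tokenizer.eos_token
prefix = "Transcribe "
processed = prefix + filler * 8  # text after placeholder expansion

start_idx = len(tokenizer.encode(prefix, add_special_tokens=False))
input_ids = tokenizer(processed, add_special_tokens=False)["input_ids"]
filler_id = tokenizer.convert_tokens_to_ids(filler)

# With add_special_tokens=False on both calls, the filler span sits exactly
# at [start_idx, start_idx + 8); a BOS inserted by the tokenizer would shift it.
assert input_ids[start_idx : start_idx + 8] == [filler_id] * 8
```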