farzadab committed on
Commit 9551b2b · verified · 1 Parent(s): 098bd49

Update ultravox_model.py

Files changed (1):
  ultravox_model.py  +142 -1
ultravox_model.py CHANGED
@@ -12,7 +12,148 @@ import transformers.models
 
 # We must use relative import in this directory to allow uploading to HF Hub
 from . import ultravox_config
-from . import whisper_model_modified
+
+# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
+# see this issue for the commentary: https://github.com/huggingface/transformers/issues/25744
+#
+# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import transformers
+import transformers.modeling_outputs
+from transformers.models.whisper import modeling_whisper as whisper
+
+
+class WhisperEncoder(whisper.WhisperEncoder):
+    """
+    Encoder portion of OpenAI's Whisper model.
+
+    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
+    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
+    2. allow less than 30 second of audio padding to be passed in:
+        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
+        - embed_pos is now sliced to match the length of `inputs_embeds`
+
+    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
+    """
+
+    base_model_prefix = "model.encoder"
+
+    def forward(
+        self,
+        input_features,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        expected_seq_length = (
+            self.config.max_source_positions
+            * self.conv1.stride[0]
+            * self.conv2.stride[0]
+        )
+        if input_features.shape[-1] > expected_seq_length:
+            raise ValueError(
+                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
+            )
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+        inputs_embeds = inputs_embeds.permute(0, 2, 1)
+        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        None,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        None,
+                        layer_head_mask=(
+                            head_mask[idx] if head_mask is not None else None
+                        ),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
+        return transformers.modeling_outputs.BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+        )
 
 
 class UltravoxModel(
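
Not part of the commit: a minimal usage sketch of what the relaxed length check in the new WhisperEncoder enables, namely loading the encoder weights directly from a stock Whisper checkpoint (via the updated base_model_prefix) and feeding it mel features shorter than the usual 30-second window. The import path `ultravox_model` and the `openai/whisper-small` checkpoint are illustrative assumptions, not something this commit prescribes.

import numpy as np
import torch
from transformers import WhisperFeatureExtractor

# Assumed import path: the modified encoder defined in this commit's ultravox_model.py.
from ultravox_model import WhisperEncoder

# base_model_prefix = "model.encoder" lets the encoder weights be pulled straight
# out of a full Whisper checkpoint, without instantiating the decoder.
encoder = WhisperEncoder.from_pretrained("openai/whisper-small")
encoder.eval()

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

# 5 seconds of silent 16 kHz audio, well under Whisper's 30-second window.
audio = np.zeros(5 * 16_000, dtype=np.float32)

# padding="longest" keeps the mel features at the audio's true length (~500 frames)
# instead of padding out to 3000 frames; the relaxed ValueError check accepts this.
features = feature_extractor(
    audio, sampling_rate=16_000, padding="longest", return_tensors="pt"
)

with torch.no_grad():
    out = encoder(features.input_features)

# Roughly 250 frames rather than the 1500 a fully padded 30-second input yields.
print(out.last_hidden_state.shape)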