Update hf_wrapper.py
Browse files- hf_wrapper.py +8 -0
hf_wrapper.py
CHANGED
@@ -1159,6 +1159,14 @@ class Effb2TrmCaptioningModel(PreTrainedModel):
|
|
1159 |
model = TransformerModel(encoder, decoder)
|
1160 |
self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
|
1161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1162 |
def forward(self,
|
1163 |
audio: torch.Tensor,
|
1164 |
audio_length: Union[List, np.ndarray, torch.Tensor],
|
|
|
1159 |
model = TransformerModel(encoder, decoder)
|
1160 |
self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
|
1161 |
|
1162 |
+
@classmethod
|
1163 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
1164 |
+
model = super().from_pretrained(
|
1165 |
+
pretrained_model_name_or_path, *args, **kwargs
|
1166 |
+
)
|
1167 |
+
model.model.model.decoder.word_embedding.weight = model.model.model.decoder.classifier.weight
|
1168 |
+
return model
|
1169 |
+
|
1170 |
def forward(self,
|
1171 |
audio: torch.Tensor,
|
1172 |
audio_length: Union[List, np.ndarray, torch.Tensor],
|