wsntxxn commited on
Commit
80f816b
·
verified ·
1 Parent(s): 69d9fbc

Update hf_wrapper.py

Browse files
Files changed (1) hide show
  1. hf_wrapper.py +8 -0
hf_wrapper.py CHANGED
@@ -1159,6 +1159,14 @@ class Effb2TrmCaptioningModel(PreTrainedModel):
1159
  model = TransformerModel(encoder, decoder)
1160
  self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
1161
 
 
 
 
 
 
 
 
 
1162
  def forward(self,
1163
  audio: torch.Tensor,
1164
  audio_length: Union[List, np.ndarray, torch.Tensor],
 
1159
  model = TransformerModel(encoder, decoder)
1160
  self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
1161
 
1162
+ @classmethod
1163
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
1164
+ model = super().from_pretrained(
1165
+ pretrained_model_name_or_path, *args, **kwargs
1166
+ )
1167
+ model.model.model.decoder.word_embedding.weight = model.model.model.decoder.classifier.weight
1168
+ return model
1169
+
1170
  def forward(self,
1171
  audio: torch.Tensor,
1172
  audio_length: Union[List, np.ndarray, torch.Tensor],