Commit e3a99d1
1 Parent(s): 895ac06
Add HAT implementation files
Files changed: modelling_hat.py +20 -0
modelling_hat.py CHANGED
@@ -1093,6 +1093,26 @@ class HATForMaskedLM(HATPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+        """Tie or clone module weights depending of whether we are using TorchScript or not"""
+        if self.config.torchscript:
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+        else:
+            output_embeddings.weight = input_embeddings.weight
+
+        if getattr(output_embeddings, "bias", None) is not None:
+            output_embeddings.bias.data = nn.functional.pad(
+                output_embeddings.bias.data,
+                (
+                    0,
+                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
+                ),
+                "constant",
+                0,
+            )
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+
     @add_start_docstrings_to_model_forward(HAT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,