Upload 3 files

- configuration_hat.py +1 -1
- modelling_hat.py +0 -1
- tokenization_hat.py +22 -0
configuration_hat.py CHANGED

@@ -147,4 +147,4 @@ class HATOnnxConfig(OnnxConfig):
         ("input_ids", {0: "batch", 1: "sequence"}),
         ("attention_mask", {0: "batch", 1: "sequence"}),
     ]
-)
+)
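For context, the changed lines sit inside the `inputs` mapping that `HATOnnxConfig` exposes for ONNX export. Below is a minimal sketch of how such a mapping is typically declared in a `transformers` `OnnxConfig` subclass; the surrounding `inputs` property and `OrderedDict` wrapper are assumptions, only the two axis entries appear in the diff:

from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class HATOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Dynamic axes for ONNX export: axis 0 is the batch dimension,
        # axis 1 the (variable) sequence length.
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )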
modelling_hat.py CHANGED

@@ -2357,4 +2357,3 @@ def off_diagonal(x):
     assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
 
-
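The context lines come from the `off_diagonal` helper in `modelling_hat.py`, a standard trick for selecting every element of a square matrix except its diagonal (commonly seen in redundancy-reduction style losses). A minimal sketch of the full helper plus a worked example; the `n, m = x.shape` line is assumed from the usual form of this function and is not shown in the hunk:

import torch


def off_diagonal(x):
    # Flatten, drop the last element, and reshape to (n - 1, n + 1): each row of
    # that view starts on a diagonal element, so slicing off column 0 keeps only
    # the off-diagonal entries.
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


x = torch.arange(9).view(3, 3)
print(off_diagonal(x))  # tensor([1, 2, 3, 5, 6, 7]) -- the diagonal 0, 4, 8 is removed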
tokenization_hat.py CHANGED

@@ -246,4 +246,26 @@ class HATTokenizer:
             flat_input[:chunk_size-1],
             torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
         ))
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoModel"):
+        """
+        Register this class with a given auto class. This should only be used for custom models as the ones in the
+        library are already mapped with an auto class.
+        <Tip warning={true}>
+        This API is experimental and may have some slight breaking changes in the next releases.
+        </Tip>
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
+                The auto class to register this new model with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
 
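The added method mirrors the `register_for_auto_class` helper that `transformers` defines on its own base classes: it only records the target auto class in `cls._auto_class`; in the standard custom-code flow, that attribute is what the save/push machinery consults to write an `auto_map` entry so the class can later be loaded with `trust_remote_code=True`. A hedged usage sketch, assuming `HATTokenizer` otherwise plugs into that flow (the repo id below is hypothetical):

from tokenization_hat import HATTokenizer  # the custom tokenizer defined in this commit

# Point the auto_map entry at AutoTokenizer instead of the "AutoModel" default.
HATTokenizer.register_for_auto_class("AutoTokenizer")

# Passing a class object also works; the method falls back to its __name__:
# HATTokenizer.register_for_auto_class(transformers.AutoTokenizer)

# Downstream, users would then load the custom tokenizer via the auto class:
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("my-org/hat-model", trust_remote_code=True)  # hypothetical repo id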