Commit af99e83
Parent(s): 9e40d21
Add HAT implementation files

tokenization_hat.py CHANGED (+7 -2)
@@ -12,7 +12,7 @@
 # limitations under the License.
 """Tokenization classes for HAT."""
 import torch
-from transformers import
+from transformers import RobertaTokenizer, BertTokenizer
 from .configuration_hat import HATConfig
 from transformers.utils import logging
 try:
@@ -92,7 +92,11 @@ class HATTokenizer:
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-
+        try:
+            tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        except:
+            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return cls(tokenizer=tokenizer)
 
     def save_pretrained(self, *args, **kwargs):
         return self._tokenizer.save_pretrained( *args, **kwargs)
@@ -242,3 +246,4 @@ class HATTokenizer:
             flat_input[:chunk_size-1],
             torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
         ))
+
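For reference, the new from_pretrained classmethod first attempts to load the underlying subword tokenizer as a RobertaTokenizer and falls back to BertTokenizer if that fails, then wraps the result via cls(tokenizer=tokenizer). A minimal usage sketch, assuming the module is importable as tokenization_hat; the checkpoint name and output directory below are placeholders, not part of the commit:

    from tokenization_hat import HATTokenizer

    # from_pretrained tries RobertaTokenizer first, then BertTokenizer,
    # and wraps whichever loads successfully in a HATTokenizer.
    tokenizer = HATTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint

    # save_pretrained delegates to the wrapped subword tokenizer.
    tokenizer.save_pretrained("./hat-tokenizer")  # placeholder output directory

One caveat in the committed code: the bare except: swallows every exception, so any RobertaTokenizer load error (including a bad kwarg) silently becomes a BertTokenizer attempt; except Exception or except OSError would be a narrower choice.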
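The context lines in the last hunk show how a flattened token sequence is padded to a fixed chunk: flat_input is truncated to chunk_size - 1 tokens and max(0, chunk_size - len(flat_input) - 1) pad ids fill the remainder, leaving one position free, presumably for a special token. A small worked sketch, assuming the two lines sit inside a torch.cat call and using an arbitrary chunk_size and pad id:

    import torch

    chunk_size = 8     # assumed value; not specified in the diff
    pad_token_id = 1   # placeholder pad id
    flat_input = torch.tensor([5, 6, 7], dtype=torch.int)

    # Truncate to chunk_size - 1 tokens, then pad the shortfall with pad ids,
    # mirroring the two context lines in the final hunk.
    padded = torch.cat((
        flat_input[:chunk_size - 1],
        torch.tensor([pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int),
    ))
    print(padded)  # tensor([5, 6, 7, 1, 1, 1, 1], dtype=torch.int32)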