Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import Dictionary | |
class MaskedLMDictionary(Dictionary):
    """
    Vocabulary used by masked language modelling tasks.

    Works like the base Dictionary, with one additional symbol (the mask
    token) registered once the base class has finished initialising.
    """

    def __init__(self, pad="<pad>", eos="</s>", unk="<unk>", mask="<mask>"):
        super().__init__(pad=pad, eos=eos, unk=unk)

        # Remember the textual form of the mask token, then register it and
        # keep whatever add_symbol returns for it.
        self.mask_word = mask
        mask_idx = self.add_symbol(mask)
        self.mask_index = mask_idx

        # nspecial is reset to the number of symbols present at this point,
        # i.e. after the mask token has been added.
        self.nspecial = len(self.symbols)

    def mask(self):
        """Return the index recorded for the mask symbol."""
        return self.mask_index
class BertDictionary(MaskedLMDictionary):
    """
    Vocabulary used by the BERT task.

    Builds on MaskedLMDictionary and additionally registers a cls symbol
    and a sep symbol, in that order.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)

        # Textual forms of the two extra tokens.
        self.cls_word, self.sep_word = cls, sep

        # Register cls first, then sep, keeping the value add_symbol
        # returns for each.
        cls_idx = self.add_symbol(cls)
        self.cls_index = cls_idx
        sep_idx = self.add_symbol(sep)
        self.sep_index = sep_idx

        # nspecial is reset to the number of symbols present at this point,
        # i.e. after cls and sep have been added.
        self.nspecial = len(self.symbols)

    def cls(self):
        """Return the index recorded for the cls symbol."""
        return self.cls_index

    def sep(self):
        """Return the index recorded for the sep symbol."""
        return self.sep_index