TomatoCocotree
上传
6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@dataclass
class GPT2BPEConfig(FairseqDataclass):
gpt2_encoder_json: str = field(
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
)
gpt2_vocab_bpe: str = field(
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
)
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
def __init__(self, cfg):
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
return self.bpe.decode(
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
)
def is_beginning_of_word(self, x: str) -> bool:
return self.decode(x).startswith(" ")