# NOTE: stray text from the hosting web page ("Spaces: Runtime error"); it is
# not part of this module and is kept only as a comment.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass

from .gpt2_bpe_utils import get_encoder
# Default locations of the GPT-2 BPE resource files; these are the default
# values of the GPT2BPEConfig fields.
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@dataclass
class GPT2BPEConfig(FairseqDataclass):
    """Configuration for the GPT-2 BPE: where to find its two resource files.

    Both fields default to the module-level ``DEFAULT_*`` URLs. ``GPT2BPE``
    resolves each value with ``file_utils.cached_path`` before use.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    declarations stay raw ``dataclasses.Field`` objects on the class instead
    of becoming string-valued attributes with defaults.
    """

    # Location (path or URL) of the encoder.json file.
    gpt2_encoder_json: str = field(
        default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
    )
    # Location (path or URL) of the vocab.bpe file.
    gpt2_vocab_bpe: str = field(
        default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
    )
class GPT2BPE(object):
    """String-in/string-out wrapper around the encoder built by ``get_encoder``.

    ``encode`` turns text into a space-separated string of token ids and
    ``decode`` reverses that, so callers only ever handle ``str`` values.

    NOTE(review): ``register_bpe`` is imported at file level but no decorator
    is visible on this class -- confirm whether a registration decorator was
    lost and needs to be restored.
    """

    # Tokens that decode() hands to the encoder as-is instead of converting
    # them with int(). Built once here rather than once per token.
    _PASSTHROUGH_TOKENS = frozenset({"<unk>", "<mask>"})

    def __init__(self, cfg):
        """Build the underlying encoder from the two files named in ``cfg``.

        Args:
            cfg: object exposing ``gpt2_encoder_json`` and ``gpt2_vocab_bpe``
                (see ``GPT2BPEConfig``).
        """
        # Each location goes through cached_path first; presumably this maps
        # a remote URL to a local cached copy -- TODO confirm in file_utils.
        encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
        vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
        self.bpe = get_encoder(encoder_json, vocab_bpe)

    def encode(self, x: str) -> str:
        """Return the ids produced by ``self.bpe.encode(x)`` joined by spaces."""
        return " ".join(map(str, self.bpe.encode(x)))

    def decode(self, x: str) -> str:
        """Decode a whitespace-separated token string back into text.

        Every token is converted with ``int`` except ``"<unk>"`` and
        ``"<mask>"``, which are forwarded to ``self.bpe.decode`` unchanged.

        Raises:
            ValueError: if any other token is not a valid integer literal.
        """
        passthrough = self._PASSTHROUGH_TOKENS
        return self.bpe.decode(
            [tok if tok in passthrough else int(tok) for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True when the decoded form of ``x`` starts with a space."""
        return self.decode(x).startswith(" ")