Spaces:
Runtime error
Runtime error
import torch | |
from typing import Dict | |
class NextSentencePredictionTokenizer:
    """Wrap a tokenizer callable to encode (context, response) sentence pairs.

    The wrapped tokenizer is called as ``tokenizer(ctx_words, res_words, **args)``
    where both positional arguments are lists of whitespace-split words.
    NOTE(review): passing word lists presumably requires the caller to supply
    ``is_split_into_words=True`` (Hugging Face convention) -- confirm at the
    call site.
    """

    def __init__(self, _tokenizer, special_token, **_tokenizer_args):
        """Store the tokenizer and derive the combined ``max_length``.

        Args:
            _tokenizer: Callable invoked with two word lists plus the
                remaining keyword arguments.
            special_token: String substituted for every ``"||"`` found in a
                context passed to :meth:`get_item`.
            **_tokenizer_args: Forwarded to the tokenizer on every call.
                Must contain ``max_length_ctx`` and ``max_length_res``; both
                are removed and replaced by ``max_length`` (their sum).

        Raises:
            KeyError: If ``max_length_ctx`` or ``max_length_res`` is missing.
        """
        self.tokenizer = _tokenizer
        self.tokenizer_args = _tokenizer_args
        # pop() reads and removes in one step; the two limits are not valid
        # arguments for the wrapped tokenizer, only their sum is forwarded.
        self.max_length_ctx = self.tokenizer_args.pop("max_length_ctx")
        self.max_length_res = self.tokenizer_args.pop("max_length_res")
        self.tokenizer_args["max_length"] = self.max_length_ctx + self.max_length_res
        self.special_token = special_token

    def get_item(self, context: str, actual_sentence: str):
        """Tokenize one (context, sentence) pair into a batch of size one.

        Every ``"||"`` in ``context`` is replaced by ``self.special_token``
        before tokenization. Each entry of the tokenizer output's ``data``
        mapping is converted to a ``torch`` tensor and reshaped to ``(1, -1)``.

        Returns:
            The object returned by the wrapped tokenizer, with its ``data``
            values replaced in place by the reshaped tensors.
        """
        actual_item = {
            "ctx": context.replace("||", self.special_token),
            "res": actual_sentence,
        }
        tokenized = self._tokenize_row(actual_item)
        # Iterate over a snapshot of the keys because values are reassigned.
        for key in list(tokenized.data):
            # as_tensor matches from_numpy for NumPy arrays and additionally
            # accepts plain lists or tensors, so the tokenizer need not be
            # configured with return_tensors="np".
            tokenized.data[key] = torch.reshape(
                torch.as_tensor(tokenized.data[key]), (1, -1)
            )
        return tokenized

    def _tokenize_row(self, row: Dict):
        """Split, truncate and tokenize the ``"ctx"`` / ``"res"`` texts of ``row``.

        Both texts are split on single spaces and only the trailing
        ``max_length_ctx`` / ``max_length_res`` words are kept.
        """
        # Truncation is word-level and reserves no room for special tokens
        # such as [CLS]/[SEP]; the tokenizer's own ``max_length`` handling
        # (if enabled through the forwarded arguments) has to cover those.
        ctx_tokens = self._keep_last(row["ctx"].split(" "), self.max_length_ctx)
        res_tokens = self._keep_last(row["res"].split(" "), self.max_length_res)
        return self.tokenizer(ctx_tokens, res_tokens, **self.tokenizer_args)

    @staticmethod
    def _keep_last(tokens, limit):
        """Return the last ``limit`` items of ``tokens`` (empty if ``limit <= 0``)."""
        # ``tokens[-0:]`` would return the whole list, so a zero limit needs
        # an explicit guard.
        if limit <= 0:
            return []
        return tokens[-limit:]