Spaces:
Runtime error
Runtime error
File size: 1,663 Bytes
c186b27 2b6660e c186b27 2b6660e c186b27 2b6660e 822e1b3 c186b27 2b6660e c186b27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from typing import Dict, List
class NextSentencePredictionTokenizer:
    """Prepare (context, response) pairs for a next-sentence-prediction model.

    Wraps an arbitrary tokenizer callable.  The wrapper consumes its own
    configuration keys (``max_length_ctx``, ``max_length_res``,
    ``special_token``, plus unused ``naive_approach``/``approach``) from
    ``_tokenizer_args``, derives ``max_length`` from them, and forwards the
    remaining kwargs untouched to the underlying tokenizer on every call.
    """

    # Config keys consumed by this wrapper; stripped before forwarding the
    # remaining kwargs to the wrapped tokenizer.
    _CONSUMED_KEYS = ("special_token", "naive_approach",
                      "max_length_ctx", "max_length_res", "approach")

    def __init__(self, _tokenizer, **_tokenizer_args):
        self.tokenizer = _tokenizer
        self.tokenizer_args = _tokenizer_args
        self.max_length_ctx = self.tokenizer_args.get("max_length_ctx")
        self.max_length_res = self.tokenizer_args.get("max_length_res")
        self.special_token = self.tokenizer_args.get("special_token")
        # Fail fast with a clear message instead of an opaque
        # "NoneType + NoneType" TypeError on the addition below.
        if self.max_length_ctx is None or self.max_length_res is None:
            raise ValueError(
                "tokenizer_args must provide 'max_length_ctx' and 'max_length_res'"
            )
        self.tokenizer_args["max_length"] = self.max_length_ctx + self.max_length_res
        # Cleaning: drop wrapper-only keys so they are not forwarded.
        for key_to_delete in self._CONSUMED_KEYS:
            self.tokenizer_args.pop(key_to_delete, None)

    def get_item(self, context: List[str], actual_sentence: str):
        """Tokenize one sample and return it as a batch of size 1.

        :param context: list of preceding sentences, joined with
            ``special_token`` (or a plain space when the token is " "/absent).
        :param actual_sentence: the candidate next sentence.
        :return: the tokenizer's encoding object with every entry of its
            ``.data`` dict converted to a torch tensor of shape (1, seq_len).
        """
        # A missing special_token would otherwise be interpolated as the
        # literal string " None " between sentences — use a plain space then.
        if self.special_token is None or self.special_token == " ":
            separator = " "
        else:
            separator = f" {self.special_token} "
        context_str = separator.join(context)
        tokenized = self._tokenize_row({"ctx": context_str, "res": actual_sentence})
        # Tokenizer output values are assumed to be numpy arrays —
        # TODO confirm against the wrapped tokenizer's return_tensors setting.
        for key, value in tokenized.data.items():
            tokenized.data[key] = torch.from_numpy(value).reshape(1, -1)
        return tokenized

    def _tokenize_row(self, row: Dict):
        """Truncate both sides on whitespace and run the wrapped tokenizer.

        Keeps the *last* ``max_length_ctx``/``max_length_res`` whitespace
        tokens of each side (most recent context wins).
        NOTE(review): an earlier comment mentioned reserving 5 slots for
        special tokens like [CLS]/[SEP], but no such margin is applied here —
        confirm whether the limits should subtract it.
        """
        ctx_tokens = row["ctx"].split(" ")[-self.max_length_ctx:]
        res_tokens = row["res"].split(" ")[-self.max_length_res:]
        return self.tokenizer(ctx_tokens, res_tokens, **self.tokenizer_args)
|