import torch
from typing import Dict, List

class NextSentencePredictionTokenizer:
    """Wraps a tokenizer to build (context, response) pairs for
    next-sentence prediction. Wrapper-specific options are read from the
    kwargs; everything else is forwarded to the underlying tokenizer."""

    def __init__(self, tokenizer, **tokenizer_args):
        self.tokenizer = tokenizer
        self.tokenizer_args = tokenizer_args
        self.max_length_ctx = self.tokenizer_args.get("max_length_ctx")
        self.max_length_res = self.tokenizer_args.get("max_length_res")
        self.special_token = self.tokenizer_args.get("special_token")
        # Both length budgets are required; their sum caps the joint pair length.
        self.tokenizer_args["max_length"] = self.max_length_ctx + self.max_length_res

        # Drop wrapper-only options so the remaining kwargs can be passed
        # straight through to the tokenizer call.
        for key in ("special_token", "naive_approach", "max_length_ctx", "max_length_res", "approach"):
            self.tokenizer_args.pop(key, None)

    def get_item(self, context: List[str], actual_sentence: str):
        # Join the context turns with the configured separator token
        # (or with plain whitespace when the separator is just a space).
        if self.special_token != " ":
            context_str = f" {self.special_token} ".join(context)
        else:
            context_str = " ".join(context)
        tokenized = self._tokenize_row({"ctx": context_str, "res": actual_sentence})

        # Assumes the tokenizer returned NumPy arrays (e.g. return_tensors="np");
        # convert each field to a torch tensor shaped as a batch of size 1.
        for key in tokenized.data.keys():
            tokenized.data[key] = torch.reshape(torch.from_numpy(tokenized.data[key]), (1, -1))
        return tokenized

    def _tokenize_row(self, row: Dict):
        ctx_tokens = row["ctx"].split(" ")
        res_tokens = row["res"].split(" ")
        # Keep only the most recent words of context and response; the
        # tokenizer's own truncation then absorbs any remaining overflow from
        # subword splitting and added special tokens such as [CLS] and [SEP].
        ctx_tokens = ctx_tokens[-self.max_length_ctx:]
        res_tokens = res_tokens[-self.max_length_res:]
        # Passing pre-split word lists assumes the tokenizer was configured
        # with is_split_into_words=True.
        return self.tokenizer(ctx_tokens, res_tokens, **self.tokenizer_args)
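

# Minimal usage sketch (an assumption, not part of the original file): it
# presumes a Hugging Face tokenizer and that the caller passes
# return_tensors="np" (so get_item can call torch.from_numpy) and
# is_split_into_words=True (so the pre-split word lists above are accepted).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    nsp_tokenizer = NextSentencePredictionTokenizer(
        AutoTokenizer.from_pretrained("bert-base-uncased"),
        max_length_ctx=64,
        max_length_res=32,
        special_token="[SEP]",
        truncation=True,
        padding="max_length",
        return_tensors="np",
        is_split_into_words=True,
    )
    batch = nsp_tokenizer.get_item(
        context=["How are you?", "I am fine, thanks."],
        actual_sentence="Glad to hear that!",
    )
    print(batch["input_ids"].shape)  # torch.Size([1, 96])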