import json

import torch


class UstaTokenizer:
  def __init__(self, vocab_file):
    # Load the token -> id vocabulary and build the reverse id -> token map.
    with open(vocab_file, "r") as f:
      self.vocab = json.load(f)
    self.reverse_vocab = {v: k for k, v in self.vocab.items()}

  def encode(self, text):
    tokens = []

    for word in text.split():
      i = 0
      # Greedy longest-match, e.g. "states":
      # state => 4
      # s => 58
      while i < len(word):
        found_match = False
        # Try the longest remaining substring first, shrinking until a match.
        for j in range(len(word), i, -1):
          sub_word = word[i:j]
          if sub_word in self.vocab:
            tokens.append(self.vocab[sub_word])
            i = j
            found_match = True
            break
        if not found_match:
          # Nothing in the vocab starts at this position; emit <unk> and move on.
          tokens.append(self.vocab["<unk>"])
          i += 1
      # Re-insert the space that text.split() stripped between words.
      tokens.append(self.vocab[" "])

    if tokens:
      # Drop the trailing space token appended after the last word.
      tokens.pop()
    return torch.tensor(tokens)

  def tokenize(self, text):
    # Encode, then convert the id tensor to a plain Python list.
    token_ids = self.encode(text).tolist()
    return [self.reverse_vocab[token_id] for token_id in token_ids]

  def decode(self, ids):
    # Map each id back to its token string and concatenate the pieces.
    text = ""
    for token_id in ids:
      text += self.reverse_vocab[int(token_id)]
    return text
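

# A minimal usage sketch, not part of the original class. It assumes a
# hypothetical "vocab.json" that maps token strings to integer ids and
# contains the "<unk>" and " " entries that encode() relies on.
if __name__ == "__main__":
  tokenizer = UstaTokenizer("vocab.json")

  ids = tokenizer.encode("the united states")
  print(ids)                                      # tensor of token ids
  print(tokenizer.tokenize("the united states"))  # the matched token strings
  print(tokenizer.decode(ids))                    # concatenates tokens back into a string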