import sentencepiece as spm
import torch

## Tokenizer
class Callable_tokenizer:
    """Thin callable wrapper around a trained SentencePiece model."""

    def __init__(self, tokenizer_path):
        self.path = tokenizer_path
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(tokenizer_path)

    def __call__(self, text):
        # Encode raw text into a list of token ids.
        return self.tokenizer.Encode(text)

    def get_tokenId(self, token_name):
        return self.tokenizer.piece_to_id(token_name)

    def get_tokenName(self, token_id):
        return self.tokenizer.id_to_piece(token_id)

    def decode(self, tokens_list):
        return self.tokenizer.Decode(tokens_list)

    def __len__(self):
        return len(self.tokenizer)

    def user_tokenization(self, text):
        # Encode text and append the </s> (end-of-sentence) id.
        return self(text) + [self.get_tokenId('</s>')]
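
# Usage sketch (the path and special pieces below are assumptions: a trained
# SentencePiece model at 'tokenizer.model' with <s>, </s>, and <pad> defined
# as control symbols):
#
#   tokenizer = Callable_tokenizer('tokenizer.model')
#   ids = tokenizer('How are you?')                   # list of token ids
#   ids_eos = tokenizer.user_tokenization('How are you?')  # ids + </s>
#   text = tokenizer.decode(ids)                      # back to a string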

@torch.no_grad()
def greedy_decode(model: torch.nn.Module, source_tensor: torch.Tensor,
                  sos_tokenId: int, eos_tokenId: int, pad_tokenId: int,
                  max_tries=50):
    """Autoregressively decode by always taking the most probable next token."""
    model.eval()
    device = source_tensor.device
    # Start the target sequence with the <SOS> token, shape (1, 1).
    target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)
    for _ in range(max_tries):
        logits, _ = model(source_tensor, target_tensor, pad_tokenId)
        # Greedy decoding: pick the highest-scoring token at the last position.
        top1 = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Append the predicted token to the target sequence.
        target_tensor = torch.cat([target_tensor, top1], dim=1)
        # Stop as soon as the model predicts <EOS>.
        if top1.item() == eos_tokenId:
            break
    return target_tensor.squeeze(0).tolist()
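
# Self-check sketch for greedy_decode. The real model is assumed to expose
# forward(source, target, pad_id) -> (logits, extra) with logits shaped
# (batch, target_len, vocab_size); _EchoEOS below is a hypothetical stand-in
# that always predicts <EOS>, so decoding halts after one step.
class _EchoEOS(torch.nn.Module):
    def __init__(self, vocab_size=8, eos_id=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.eos_id = eos_id

    def forward(self, source, target, pad_id):
        # Put all probability mass on <EOS> at every position.
        logits = torch.full((target.size(0), target.size(1), self.vocab_size), -1e9)
        logits[:, :, self.eos_id] = 0.0
        return logits, None

# greedy_decode(_EchoEOS(), torch.tensor([[3, 4]]), 1, 2, 0)  ->  [1, 2]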

def en_translate_ar(text, model, tokenizer):
    """Translate English text to Arabic with greedy decoding (30-token cap)."""
    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
    target_tokens = greedy_decode(model, source_tensor,
                                  tokenizer.get_tokenId('<s>'),
                                  tokenizer.get_tokenId('</s>'),
                                  tokenizer.get_tokenId('<pad>'), 30)
    return tokenizer.decode(target_tokens)
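
# End-to-end sketch (every name here is hypothetical: the tokenizer path,
# checkpoint file, and TransformerMT model class are not defined in this
# file and stand in for whatever the surrounding project provides):
#
#   tokenizer = Callable_tokenizer('tokenizer.model')
#   model = TransformerMT(vocab_size=len(tokenizer))
#   model.load_state_dict(torch.load('checkpoint.pt', map_location='cpu'))
#   print(en_translate_ar('How are you?', model, tokenizer))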