Upload 4 files
- __init__.py +0 -0
- gradio_utils.py +54 -0
- requirements.txt +2 -0
- runtime.txt +1 -0
__init__.py
ADDED
(empty file)
gradio_utils.py
ADDED
@@ -0,0 +1,54 @@
+import sentencepiece as spm
+import torch
+
+## Tokenizer
+class Callable_tokenizer():
+    def __init__(self, tokenizer_path):
+        self.path = tokenizer_path
+        self.tokenizer = spm.SentencePieceProcessor()
+        self.tokenizer.load(tokenizer_path)
+    def __call__(self, text):
+        return self.tokenizer.Encode(text)
+
+    def get_tokenId(self, token_name):
+        return self.tokenizer.piece_to_id(token_name)
+
+    def get_tokenName(self, id):
+        return self.tokenizer.id_to_piece(id)
+
+    def decode(self, tokens_list):
+        return self.tokenizer.Decode(tokens_list)
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def user_tokenization(self, text):
+        return self(text) + [self.get_tokenId('</s>')]
+
+
+@torch.no_grad()
+def greedy_decode(model: torch.nn.Module, source_tensor: torch.Tensor, sos_tokenId: int, eos_tokenId: int, pad_tokenId: int, max_tries=50):
+    model.eval()
+    device = source_tensor.device
+    target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)
+
+    for i in range(max_tries):
+        logits, _ = model(source_tensor, target_tensor, pad_tokenId)
+        # Greedy decoding: pick the most likely next token
+        top1 = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+        # Append predicted token
+        target_tensor = torch.cat([target_tensor, top1], dim=1)
+        # Stop if <EOS> is predicted
+        if top1.item() == eos_tokenId:
+            break
+    return target_tensor.squeeze(0).tolist()
+
+
+def en_translate_ar(text, model, tokenizer):
+    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
+    target_tokens = greedy_decode(model, source_tensor,
+                                  tokenizer.get_tokenId('<s>'),
+                                  tokenizer.get_tokenId('</s>'),
+                                  tokenizer.get_tokenId('<pad>'), 30)
+
+    return tokenizer.decode(target_tokens)
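A minimal smoke-test sketch of the helpers above, for orientation only: the trained SentencePiece model and the translation checkpoint are not part of this upload, so the path 'tokenizer.model' and the commented-out model loading are placeholders, not files from this repository.

    # Hypothetical usage of gradio_utils.py (assumed paths, not part of this commit).
    from gradio_utils import Callable_tokenizer, en_translate_ar

    tokenizer = Callable_tokenizer('tokenizer.model')    # assumed SentencePiece model path

    print(len(tokenizer))                                # vocabulary size via __len__
    print(tokenizer('How are you?'))                     # token ids from __call__
    print(tokenizer.user_tokenization('How are you?'))   # same ids with </s> appended

    # The seq2seq checkpoint is not included in this commit; loading is left as a stub.
    # model = ...
    # print(en_translate_ar('How are you?', model, tokenizer))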
requirements.txt
ADDED
@@ -0,0 +1,2 @@
+torch
+sentencepiece
runtime.txt
ADDED
@@ -0,0 +1 @@
+python-3.11.9
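The requirements above do not pin gradio itself, which is presumably supplied by the Space's Gradio SDK, and the app entry point is not part of this upload. A minimal sketch of how such an entry point might wire en_translate_ar into a Gradio interface, with the file names 'app.py', 'tokenizer.model', and 'translation_model.pt' all as assumptions:

    # Hypothetical app.py (not part of this upload; paths and checkpoint format are assumptions).
    import gradio as gr
    import torch
    from gradio_utils import Callable_tokenizer, en_translate_ar

    tokenizer = Callable_tokenizer('tokenizer.model')   # assumed tokenizer path
    # Assumes the checkpoint is a pickled nn.Module saved with torch.save(model, ...).
    model = torch.load('translation_model.pt', map_location='cpu', weights_only=False)

    def translate(text):
        # Gradio passes only the textbox value; close over model and tokenizer here.
        return en_translate_ar(text, model, tokenizer)

    demo = gr.Interface(fn=translate,
                        inputs=gr.Textbox(label='English'),
                        outputs=gr.Textbox(label='Arabic'),
                        title='English to Arabic translation')

    if __name__ == '__main__':
        demo.launch()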