TheDemond committed on
Commit 1e11c07 (verified)
Parent: c412427

Upload 4 files

Files changed (4)
  1. __init__.py +0 -0
  2. gradio_utils.py +54 -0
  3. requirements.txt +2 -0
  4. runtime.txt +1 -0
__init__.py ADDED
File without changes
gradio_utils.py ADDED
@@ -0,0 +1,61 @@
+ import sentencepiece as spm
+ import torch
+
+ ## Tokenizer
+ class Callable_tokenizer():
+     """Thin callable wrapper around a SentencePiece tokenizer."""
+     def __init__(self, tokenizer_path):
+         self.path = tokenizer_path
+         self.tokenizer = spm.SentencePieceProcessor()
+         self.tokenizer.load(tokenizer_path)
+
+     def __call__(self, text):
+         # Encode raw text into a list of token ids
+         return self.tokenizer.Encode(text)
+
+     def get_tokenId(self, token_name):
+         return self.tokenizer.piece_to_id(token_name)
+
+     def get_tokenName(self, id):
+         return self.tokenizer.id_to_piece(id)
+
+     def decode(self, tokens_list):
+         return self.tokenizer.Decode(tokens_list)
+
+     def __len__(self):
+         return len(self.tokenizer)
+
+     def user_tokenization(self, text):
+         # Encode text and append the end-of-sentence token id
+         return self(text) + [self.get_tokenId('</s>')]
+
+
+ @torch.no_grad()
+ def greedy_decode(model: torch.nn.Module, source_tensor: torch.Tensor, sos_tokenId: int, eos_tokenId: int, pad_tokenId, max_tries=50):
+     model.eval()
+     device = source_tensor.device
+     # Start the target sequence with the start-of-sentence token
+     target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)
+
+     for _ in range(max_tries):
+         logits, _ = model(source_tensor, target_tensor, pad_tokenId)
+         # Greedy decoding: pick the most likely next token
+         top1 = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+         # Append the predicted token to the target sequence
+         target_tensor = torch.cat([target_tensor, top1], dim=1)
+         # Stop if <EOS> is predicted
+         if top1.item() == eos_tokenId:
+             break
+     return target_tensor.squeeze(0).tolist()
+
+
+ def en_translate_ar(text, model, tokenizer):
+     # Tokenize the English source text and add a batch dimension
+     source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
+     target_tokens = greedy_decode(model, source_tensor,
+                                   tokenizer.get_tokenId('<s>'),
+                                   tokenizer.get_tokenId('</s>'),
+                                   tokenizer.get_tokenId('<pad>'), 30)
+
+     # Convert the generated token ids back into Arabic text
+     return tokenizer.decode(target_tokens)
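
A minimal usage sketch of these helpers (not part of the commit): the tokenizer file and model checkpoint paths below are hypothetical, and the model is assumed to be a trained seq2seq module saved with torch.save(model), whose forward call matches model(source, target, pad_id) and returns a (logits, attention) pair as greedy_decode expects.

import torch
from gradio_utils import Callable_tokenizer, en_translate_ar

tokenizer = Callable_tokenizer("tokenizer.model")                # hypothetical SentencePiece model file
ids = tokenizer("How are you?")                                  # encode text to token ids
print(tokenizer.decode(ids))                                     # round-trip back to text

model = torch.load("translation_model.pt", map_location="cpu")   # hypothetical full-model checkpoint
print(en_translate_ar("How are you?", model, tokenizer))         # greedy-decoded Arabic output
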
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch
+ sentencepiece
runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.11.9
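
Since the module is named gradio_utils.py, it is presumably consumed by a Gradio app. Gradio itself is not listed in requirements.txt above, so the wiring below is only a sketch assuming it is available in the runtime; the paths are the same hypothetical ones used in the earlier sketch.

import gradio as gr
import torch
from gradio_utils import Callable_tokenizer, en_translate_ar

tokenizer = Callable_tokenizer("tokenizer.model")                # hypothetical path, as above
model = torch.load("translation_model.pt", map_location="cpu")   # hypothetical checkpoint, as above

# Expose the translation helper as a simple text-in, text-out interface
demo = gr.Interface(fn=lambda text: en_translate_ar(text, model, tokenizer),
                    inputs="text", outputs="text",
                    title="English to Arabic translation")
demo.launch()
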