MohamedRashad committed
Commit bcc0c7f · 1 Parent(s): 7efa162

chore: Add requirements for shakkala and kaldialign
Files changed (13)
  1. app.py +45 -22
  2. bw2ar.py +126 -0
  3. ed.py +63 -0
  4. ed_pl.py +164 -0
  5. eo.py +50 -0
  6. eo_pl.py +151 -0
  7. models/best_ed_mlm_ns_epoch_178.pt +3 -0
  8. models/best_eo_mlm_ns_epoch_193.pt +3 -0
  9. requirements.txt +2 -1
  10. tashkeel_tokenizer.py +234 -0
  11. transformer.py +559 -0
  12. utils.py +139 -0
  13. xer.py +76 -0
app.py CHANGED
@@ -1,23 +1,18 @@
-import requests
 import gradio as gr
 from shakkala import Shakkala
+from pathlib import Path
+import torch
+from eo_pl import TashkeelModel as TashkeelModelEO
+from ed_pl import TashkeelModel as TashkeelModelED
+from tashkeel_tokenizer import TashkeelTokenizer
+from utils import remove_non_arabic
+import spaces

-tashkel_url = "http://www.7koko.com/api/tashkil/index.php"
-
-
-def add_tashkeel1(text):
-    data = {"textArabic": text}
-    response = requests.post(tashkel_url, data=data)
-    response.encoding = "utf-8"
-    text = response.text.strip()
-    return text
-
-
+# Initialize the Shakkala model
 sh = Shakkala(version=3)
 model, graph = sh.get_model()

-
-def add_tashkeel2(input_text):
+def infer_shakkala(input_text):
     input_int = sh.prepare_input(input_text)
     logits = model.predict(input_int)[0]
     predicted_harakat = sh.logits_to_text(logits)
@@ -25,28 +20,56 @@ def add_tashkeel2(input_text):
     print(final_output)
     return final_output

+# Initialize the CaTT model and tokenizer
+tokenizer = TashkeelTokenizer()
+eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('device:', device)
+
+max_seq_len = 1024
+print('Creating Model...')
+eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
+ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
+
+eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device)).eval().to(device)
+ed_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device)).eval().to(device)
+
+@spaces.GPU()
+def infer_catt(input_text, choose_model):
+    input_text = remove_non_arabic(input_text)
+    batch_size = 16
+    verbose = True
+    if choose_model == 'Encoder-Only':
+        output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
+    else:
+        output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)
+
+    return output_text[0]

 with gr.Blocks(title="Arabic Tashkeel") as demo:
     gr.HTML("<center><h1>Arabic Tashkeel</h1></center>")
     gr.HTML(
         "<center><p>Compare different methods for adding tashkeel to Arabic text.</p></center>"
     )
-    with gr.Tab(label="Tashkil"):
+    with gr.Tab(label="CATT"):
+        gr.Markdown("[CATT](https://github.com/abjadai/catt) is a new deep learning model for Arabic diacritization.")
         with gr.Row():
             with gr.Column():
-                text_input1 = gr.Textbox(
-                    lines=1, label="Input Text", rtl=True, text_align="right"
+                text_input1 = gr.Textbox(label="Input Text", rtl=True, text_align="right")
+                choose_model = gr.Radio(
+                    label="Choose Model",
+                    choices=["Encoder-Only", "Encoder-Decoder"],
+                    default="Encoder-Decoder",
                 )
                 with gr.Row():
                     clear_button1 = gr.Button(value="Clear", variant="secondary")
                     submit_button1 = gr.Button(value="Add Tashkeel", variant="primary")

             with gr.Column():
-                text_output1 = gr.Textbox(
-                    lines=1, label="Output Text", rtl=True, text_align="right"
-                )
+                text_output1 = gr.Textbox(label="Output Text", rtl=True, text_align="right")

-        submit_button1.click(add_tashkeel1, inputs=text_input1, outputs=text_output1)
+        submit_button1.click(infer_catt, inputs=[text_input1, choose_model], outputs=text_output1)
         clear_button1.click(lambda: text_input1.update(""))

     with gr.Tab(label="Shakkala"):
@@ -66,7 +89,7 @@ with gr.Blocks(title="Arabic Tashkeel") as demo:
                     lines=1, label="Output Text", rtl=True, text_align="right"
                 )

-        submit_button2.click(add_tashkeel2, inputs=text_input2, outputs=text_output2)
+        submit_button2.click(infer_shakkala, inputs=text_input2, outputs=text_output2)
         clear_button2.click(lambda: text_input2.update(""))

 demo.queue().launch()
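Note on the new loading code: load_state_dict returns a NamedTuple of missing/unexpected keys rather than the module, so chaining .eval().to(device) onto its result will fail, and both calls read the encoder-only checkpoint. A minimal corrected sketch, assuming the encoder-decoder weights are the models/best_ed_mlm_ns_epoch_178.pt file added in this commit:

# Sketch (uses eo_model, ed_model, eo_ckpt_path and device defined in the new app.py above):
# load each checkpoint into its own model, then call eval()/to() on the modules themselves.
from pathlib import Path
import torch

ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'  # assumed ED checkpoint

eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
eo_model = eo_model.eval().to(device)
ed_model = ed_model.eval().to(device)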
bw2ar.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ functions to convert Arabic words/text into buckwalter encoding and vice versa
3
+ """
4
+
5
+ import sys
6
+ import re
7
+ import utils
8
+
9
+ buck2uni = {
10
+ "'": u"\u0621", # hamza-on-the-line
11
+ "|": u"\u0622", # madda
12
+ ">": u"\u0623", # hamza-on-'alif
13
+ "&": u"\u0624", # hamza-on-waaw
14
+ "<": u"\u0625", # hamza-under-'alif
15
+ "}": u"\u0626", # hamza-on-yaa'
16
+ "A": u"\u0627", # bare 'alif
17
+ "b": u"\u0628", # baa'
18
+ "p": u"\u0629", # taa' marbuuTa
19
+ "t": u"\u062A", # taa'
20
+ "v": u"\u062B", # thaa'
21
+ "j": u"\u062C", # jiim
22
+ "H": u"\u062D", # Haa'
23
+ "x": u"\u062E", # khaa'
24
+ "d": u"\u062F", # daal
25
+ "*": u"\u0630", # dhaal
26
+ "r": u"\u0631", # raa'
27
+ "z": u"\u0632", # zaay
28
+ "s": u"\u0633", # siin
29
+ "$": u"\u0634", # shiin
30
+ "S": u"\u0635", # Saad
31
+ "D": u"\u0636", # Daad
32
+ "T": u"\u0637", # Taa'
33
+ "Z": u"\u0638", # Zaa' (DHaa')
34
+ "E": u"\u0639", # cayn
35
+ "g": u"\u063A", # ghayn
36
+ "_": u"\u0640", # taTwiil
37
+ "f": u"\u0641", # faa'
38
+ "q": u"\u0642", # qaaf
39
+ "k": u"\u0643", # kaaf
40
+ "l": u"\u0644", # laam
41
+ "m": u"\u0645", # miim
42
+ "n": u"\u0646", # nuun
43
+ "h": u"\u0647", # haa'
44
+ "w": u"\u0648", # waaw
45
+ "Y": u"\u0649", # 'alif maqSuura
46
+ "y": u"\u064A", # yaa'
47
+ "F": u"\u064B", # fatHatayn
48
+ "N": u"\u064C", # Dammatayn
49
+ "K": u"\u064D", # kasratayn
50
+ "a": u"\u064E", # fatHa
51
+ "u": u"\u064F", # Damma
52
+ "i": u"\u0650", # kasra
53
+ "~": u"\u0651", # shaddah
54
+ "o": u"\u0652", # sukuun
55
+ "`": u"\u0670", # dagger 'alif
56
+ "{": u"\u0671", # waSla
57
+ }
58
+
59
+ # For a reverse transliteration (Unicode -> Buckwalter), a dictionary
60
+ # which is the reverse of the above buck2uni is essential.
61
+ uni2buck = {}
62
+
63
+ # Iterate through all the items in the buck2uni dict.
64
+ for (key, value) in buck2uni.items():
65
+ # The value from buck2uni becomes a key in uni2buck, and vice
66
+ # versa for the keys.
67
+ uni2buck[value] = key
68
+
69
+ # add special characters
70
+ uni2buck[u"\ufefb"] = "lA"
71
+ uni2buck[u"\ufef7"] = "l>"
72
+ uni2buck[u"\ufef5"] = "l|"
73
+ uni2buck[u"\ufef9"] = "l<"
74
+
75
+ # clean the Arabic text of unwanted characters that may cause problems while building the language model
76
+ def clean_text(text):
77
+ text = re.sub(u"[\ufeff]", "", text, flags=re.UNICODE) # strip Unicode Character 'ZERO WIDTH NO-BREAK SPACE' (U+FEFF). For more info, check http://www.fileformat.info/info/unicode/char/feff/index.htm
78
+ text = utils.remove_non_arabic(text)
79
+ text = utils.strip_tashkeel(text)
80
+ text = utils.strip_tatweel(text)
81
+ return text
82
+
83
+ # convert a single word into buckwalter and vice versa
84
+ def transliterate_word(input_word, direction='bw2ar'):
85
+ output_word = ''
86
+ # Loop over each character in the string, input_word.
87
+ for char in input_word:
88
+ # Look up current char in the dictionary to get its
89
+ # respective value. If there is no match, e.g., chars like
90
+ # spaces, then just stick with the current char without any
91
+ # conversion.
92
+ # if type(char) == bytes:
93
+ # char = char.decode('ascii')
94
+ if direction == 'bw2ar':
95
+ #print('in bw2ar')
96
+ output_word += buck2uni.get(char, char)
97
+ elif direction == 'ar2bw':
98
+ #print('in ar2bw')
99
+ output_word += uni2buck.get(char, char)
100
+ else:
101
+ sys.stderr.write('Error: invalid direction!')
102
+ sys.exit()
103
+ return output_word
104
+
105
+
106
+ # convert a text into buckwalter and vice versa
107
+ def transliterate_text(input_text, direction='bw2ar'):
108
+ output_text = ''
109
+ for input_word in input_text.split(' '):
110
+ output_text += transliterate_word(input_word, direction) + ' '
111
+
112
+ return output_text[:-1] # remove the last space ONLY
113
+
114
+
115
+ if __name__ == '__main__':
116
+ if len(sys.argv) < 2:
117
+ sys.stderr.write('Usage: INPUT TEXT | python {} DIRECTION(bw2ar|ar2bw)'.format(sys.argv[0]))
118
+ exit(1)
119
+ for line in sys.stdin:
120
+ line = line if sys.argv[1] == 'bw2ar' else clean_text(line)
121
+ output_text = transliterate_text(line, direction=str(sys.argv[1]))
122
+ if output_text.strip() != '':
123
+ sys.stdout.write('{}\n'.format(output_text.strip()))
124
+
125
+
126
+
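For reference, a small usage sketch of the two transliteration directions defined above (the sample string is illustrative only):

# Sketch: Buckwalter <-> Arabic round trip with the functions above.
import bw2ar

bw = "Aalsalamu Ealaykum"                                # hypothetical Buckwalter-style sample
ar = bw2ar.transliterate_text(bw, direction='bw2ar')     # Buckwalter -> Arabic
back = bw2ar.transliterate_text(ar, direction='ar2bw')   # Arabic -> Buckwalter
print(ar)
print(back)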
ed.py ADDED
@@ -0,0 +1,63 @@
1
+
2
+ import torch.nn as nn
3
+ from transformer import *
4
+
5
+ class Transformer(nn.Module):
6
+
7
+ def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
8
+ ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
9
+ super().__init__()
10
+ self.src_pad_idx = src_pad_idx
11
+ self.trg_pad_idx = trg_pad_idx
12
+ self.encoder = Encoder(d_model=d_model,
13
+ n_head=n_head,
14
+ max_len=max_len,
15
+ ffn_hidden=ffn_hidden,
16
+ enc_voc_size=enc_voc_size,
17
+ drop_prob=drop_prob,
18
+ n_layers=n_layers,
19
+ padding_idx=src_pad_idx,
20
+ learnable_pos_emb=learnable_pos_emb)
21
+
22
+ self.decoder = Decoder(d_model=d_model,
23
+ n_head=n_head,
24
+ max_len=max_len,
25
+ ffn_hidden=ffn_hidden,
26
+ dec_voc_size=dec_voc_size,
27
+ drop_prob=drop_prob,
28
+ n_layers=n_layers,
29
+ padding_idx=trg_pad_idx,
30
+ learnable_pos_emb=learnable_pos_emb)
31
+
32
+ def get_device(self):
33
+ return next(self.parameters()).device
34
+
35
+ def forward(self, src, trg):
36
+ device = self.get_device()
37
+ src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
38
+ src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
39
+ trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
40
+ self.make_no_peak_mask(trg, trg).to(device)
41
+ enc_src = self.encoder(src, src_mask)
42
+ output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
43
+ return output
44
+
45
+ def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
46
+ len_q, len_k = q.size(1), k.size(1)
47
+ # batch_size x 1 x 1 x len_k
48
+ k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
49
+ # batch_size x 1 x len_q x len_k
50
+ k = k.repeat(1, 1, len_q, 1)
51
+ # batch_size x 1 x len_q x 1
52
+ q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
53
+ # batch_size x 1 x len_q x len_k
54
+ q = q.repeat(1, 1, 1, len_k)
55
+ mask = k & q
56
+ return mask
57
+
58
+ def make_no_peak_mask(self, q, k):
59
+ len_q, len_k = q.size(1), k.size(1)
60
+ # len_q x len_k
61
+ mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
62
+ return mask
63
+
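A quick shape sketch for the encoder-decoder wrapper above, using toy vocabulary sizes and a padded toy batch (all numbers are illustrative):

# Sketch: padding mask is (batch, 1, len_q, len_k); the no-peak mask is lower-triangular (len_q, len_k).
import torch
from ed import Transformer

model = Transformer(src_pad_idx=0, trg_pad_idx=0, enc_voc_size=40, dec_voc_size=18,
                    d_model=64, n_head=4, max_len=32, ffn_hidden=256, n_layers=2, drop_prob=0.1)
src = torch.tensor([[5, 6, 7, 0, 0]])   # toy ids, 0 = <PAD>
trg = torch.tensor([[1, 8, 9, 0, 0]])
print(model.make_pad_mask(src, src, 0, 0).shape)   # torch.Size([1, 1, 5, 5])
print(model.make_no_peak_mask(trg, trg).shape)     # torch.Size([5, 5])
out = model(src, trg)                              # per-position logits: (1, 5, dec_voc_size)
print(out.shape)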
ed_pl.py ADDED
@@ -0,0 +1,164 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import pytorch_lightning as pl
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader
7
+ from ed import Transformer
8
+ from tqdm import tqdm
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from torch.nn.utils.rnn import pad_sequence
14
+ # sequences is a list of tensors of shape TxH where T is the seqlen and H is the feats dim
15
+ def pad_seq(sequences, batch_first=True, padding_value=0.0, prepadding=True):
16
+ lens = [i.shape[0] for i in sequences]
17
+ padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value) # NxTxH
18
+ if prepadding:
19
+ for i in range(len(lens)):
20
+ padded_sequences[i] = padded_sequences[i].roll(-lens[i])
21
+ if not batch_first:
22
+ padded_sequences = padded_sequences.transpose(0, 1) # TxNxH
23
+ return padded_sequences
24
+
25
+
26
+
27
+ def get_batches(X, batch_size=16):
28
+ num_batches = math.ceil(len(X) / batch_size)
29
+ for i in range(num_batches):
30
+ x = X[i*batch_size : (i+1)*batch_size]
31
+ yield x
32
+
33
+
34
+ class TashkeelModel(pl.LightningModule):
35
+ def __init__(self, tokenizer, max_seq_len, d_model=512, n_layers=3, n_heads=16, drop_prob=0.1, learnable_pos_emb=True):
36
+
37
+ super(TashkeelModel, self).__init__()
38
+
39
+ ffn_hidden = 4 * d_model
40
+ src_pad_idx = tokenizer.letters_map['<PAD>']
41
+ trg_pad_idx = tokenizer.tashkeel_map['<PAD>']
42
+ enc_voc_size = len(tokenizer.letters_map) # 37 + 3
43
+ dec_voc_size = len(tokenizer.tashkeel_map) # 15 + 3
44
+ self.transformer = Transformer(src_pad_idx=src_pad_idx,
45
+ trg_pad_idx=trg_pad_idx,
46
+ d_model=d_model,
47
+ enc_voc_size=enc_voc_size,
48
+ dec_voc_size=dec_voc_size,
49
+ max_len=max_seq_len,
50
+ ffn_hidden=ffn_hidden,
51
+ n_head=n_heads,
52
+ n_layers=n_layers,
53
+ drop_prob=drop_prob,
54
+ learnable_pos_emb=learnable_pos_emb
55
+ )
56
+
57
+ self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.tashkeel_map['<PAD>'])
58
+ self.tokenizer = tokenizer
59
+
60
+
61
+ def forward(self, x, y=None):
62
+ y_pred = self.transformer(x, y)
63
+ return y_pred
64
+
65
+
66
+ def training_step(self, batch, batch_idx):
67
+ input_ids, target_ids = batch
68
+ input_ids = input_ids[:, :-1]
69
+ y_in = target_ids[:, :-1]
70
+ y_out = target_ids[:, 1:]
71
+ y_pred = self(input_ids, y_in)
72
+ loss = self.criterion(y_pred.transpose(1, 2), y_out)
73
+
74
+ self.log('train_loss', loss, prog_bar=True)
75
+ sch = self.lr_schedulers()
76
+ sch.step()
77
+ self.log('lr', sch.get_last_lr()[0], prog_bar=True)
78
+ return loss
79
+
80
+
81
+ def validation_step(self, batch, batch_idx):
82
+ input_ids, target_ids = batch
83
+ input_ids = input_ids[:, :-1]
84
+ y_in = target_ids[:, :-1]
85
+ y_out = target_ids[:, 1:]
86
+ y_pred = self(input_ids, y_in)
87
+ loss = self.criterion(y_pred.transpose(1, 2), y_out)
88
+
89
+ pred_text_with_tashkeels = self.tokenizer.decode(input_ids, y_pred.argmax(2).squeeze())
90
+ true_text_with_tashkeels = self.tokenizer.decode(input_ids, y_out)
91
+ total_val_der_distance = 0
92
+ total_val_der_ref_length = 0
93
+ for i in range(len(true_text_with_tashkeels)):
94
+ pred_text_with_tashkeel = pred_text_with_tashkeels[i]
95
+ true_text_with_tashkeel = true_text_with_tashkeels[i]
96
+ val_der = self.tokenizer.compute_der(true_text_with_tashkeel, pred_text_with_tashkeel)
97
+ total_val_der_distance += val_der['distance']
98
+ total_val_der_ref_length += val_der['ref_length']
99
+
100
+ total_der_error = total_val_der_distance / total_val_der_ref_length
101
+ self.log('val_loss', loss)
102
+ self.log('val_der', torch.FloatTensor([total_der_error]))
103
+ self.log('val_der_distance', torch.FloatTensor([total_val_der_distance]))
104
+ self.log('val_der_ref_length', torch.FloatTensor([total_val_der_ref_length]))
105
+
106
+
107
+ def test_step(self, batch, batch_idx):
108
+ input_ids, target_ids = batch
109
+ y_pred = self(input_ids, None)
110
+ loss = self.criterion(y_pred.transpose(1, 2), target_ids)
111
+ self.log('test_loss', loss)
112
+
113
+
114
+ def configure_optimizers(self):
115
+ optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
116
+ #max_iters = 10000
117
+ #lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters, eta_min=3e-6)
118
+ gamma = 1 / 1.000001
119
+ #gamma = 1 / 1.0001
120
+ lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
121
+ opts = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
122
+ return opts
123
+
124
+
125
+ @torch.no_grad()
126
+ def do_tashkeel_batch(self, texts, batch_size=16, verbose=True):
127
+ self.eval()
128
+ device = next(self.parameters()).device
129
+ text_with_tashkeel = []
130
+ data_iter = get_batches(texts, batch_size)
131
+ if verbose:
132
+ num_batches = math.ceil(len(texts) / batch_size)
133
+ data_iter = tqdm(data_iter, total=num_batches)
134
+ for texts_mini in data_iter:
135
+ input_ids_list = []
136
+ for text in texts_mini:
137
+ input_ids, _ = self.tokenizer.encode(text, test_match=False)
138
+ input_ids_list.append(input_ids)
139
+ batch_input_ids = pad_seq(input_ids_list, batch_first=True, padding_value=self.tokenizer.letters_map['<PAD>'], prepadding=False)
140
+ target_ids = torch.LongTensor([[self.tokenizer.tashkeel_map['<BOS>']]] * len(texts_mini)).to(device)
141
+ src = batch_input_ids.to(device)
142
+
143
+ src_mask = self.transformer.make_pad_mask(src, src, self.transformer.src_pad_idx, self.transformer.src_pad_idx).to(device)
144
+ enc_src = self.transformer.encoder(src, src_mask)
145
+
146
+ for i in range(src.shape[1] - 1):
147
+ trg = target_ids
148
+ src_trg_mask = self.transformer.make_pad_mask(trg, src, self.transformer.trg_pad_idx, self.transformer.src_pad_idx).to(device)
149
+ trg_mask = self.transformer.make_pad_mask(trg, trg, self.transformer.trg_pad_idx, self.transformer.trg_pad_idx).to(device) * \
150
+ self.transformer.make_no_peak_mask(trg, trg).to(device)
151
+
152
+ preds = self.transformer.decoder(trg, enc_src, trg_mask, src_trg_mask)
153
+ # IMPORTANT NOTE: the following code snippet is to FORCE the prediction of the input space char to output no_tashkeel tag '<NT>'
154
+ target_ids = torch.cat([target_ids, preds[:, -1].argmax(1).unsqueeze(1)], axis=1)
155
+ target_ids[self.tokenizer.letters_map[' '] == src[:, :target_ids.shape[1]]] = self.tokenizer.tashkeel_map[self.tokenizer.no_tashkeel_tag]
156
+ # target_ids = torch.cat([target_ids, preds[:, -1].argmax(1).unsqueeze(1)], axis=1)
157
+ text_with_tashkeel_mini = self.tokenizer.decode(src, target_ids)
158
+ text_with_tashkeel += text_with_tashkeel_mini
159
+ return text_with_tashkeel
160
+
161
+
162
+ @torch.no_grad()
163
+ def do_tashkeel(self, text):
164
+ return self.do_tashkeel_batch([text])[0]
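pad_seq's prepadding option relies on roll(-lens[i]) to move the pad tokens to the front of each row; a small sketch of the difference on toy 1-D sequences (assuming the module's dependencies, e.g. pytorch_lightning, are installed):

# Sketch: with prepadding=True, each row is rolled so the padding ends up in front.
import torch
from ed_pl import pad_seq

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad_seq(seqs, padding_value=0, prepadding=False))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
print(pad_seq(seqs, padding_value=0, prepadding=True))
# tensor([[1, 2, 3],
#         [0, 4, 5]])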
eo.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch.nn as nn
2
+ from transformer import *
3
+
4
+ class Transformer(nn.Module):
5
+ def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
6
+ ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
7
+ super().__init__()
8
+ self.src_pad_idx = src_pad_idx
9
+ self.trg_pad_idx = trg_pad_idx
10
+ self.encoder = Encoder(d_model=d_model,
11
+ n_head=n_head,
12
+ max_len=max_len,
13
+ ffn_hidden=ffn_hidden,
14
+ enc_voc_size=enc_voc_size,
15
+ drop_prob=drop_prob,
16
+ n_layers=n_layers,
17
+ padding_idx=src_pad_idx,
18
+ learnable_pos_emb=learnable_pos_emb)
19
+
20
+ self.decoder = nn.Linear(d_model, dec_voc_size)
21
+
22
+ def get_device(self):
23
+ return next(self.parameters()).device
24
+
25
+ def forward(self, src):
26
+ device = self.get_device()
27
+ src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
28
+ enc_src = self.encoder(src, src_mask)
29
+ output = self.decoder(enc_src)
30
+ return output
31
+
32
+ def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
33
+ len_q, len_k = q.size(1), k.size(1)
34
+ # batch_size x 1 x 1 x len_k
35
+ k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
36
+ # batch_size x 1 x len_q x len_k
37
+ k = k.repeat(1, 1, len_q, 1)
38
+ # batch_size x 1 x len_q x 1
39
+ q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
40
+ # batch_size x 1 x len_q x len_k
41
+ q = q.repeat(1, 1, 1, len_k)
42
+ mask = k & q
43
+ return mask
44
+
45
+ def make_no_peak_mask(self, q, k):
46
+ len_q, len_k = q.size(1), k.size(1)
47
+ # len_q x len_k
48
+ mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
49
+ return mask
50
+
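For comparison with the encoder-decoder variant, a minimal shape sketch of this encoder-only model, which maps every input position directly to tashkeel logits (toy sizes, illustrative only):

# Sketch: the EO model returns one logit vector per input position.
import torch
from eo import Transformer

model = Transformer(src_pad_idx=0, trg_pad_idx=0, enc_voc_size=40, dec_voc_size=18,
                    d_model=64, n_head=4, max_len=32, ffn_hidden=256, n_layers=2, drop_prob=0.1)
src = torch.tensor([[5, 6, 7, 0, 0]])  # toy ids, 0 = <PAD>
logits = model(src)                    # shape: (1, 5, dec_voc_size)
print(logits.shape)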
eo_pl.py ADDED
@@ -0,0 +1,151 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import pytorch_lightning as pl
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader
7
+ from eo import Transformer
8
+ from tqdm import tqdm
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from torch.nn.utils.rnn import pad_sequence
14
+ # sequences is a list of tensors of shape TxH where T is the seqlen and H is the feats dim
15
+ def pad_seq(sequences, batch_first=True, padding_value=0.0, prepadding=True):
16
+ lens = [i.shape[0] for i in sequences]
17
+ padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value) # NxTxH
18
+ if prepadding:
19
+ for i in range(len(lens)):
20
+ padded_sequences[i] = padded_sequences[i].roll(-lens[i])
21
+ if not batch_first:
22
+ padded_sequences = padded_sequences.transpose(0, 1) # TxNxH
23
+ return padded_sequences
24
+
25
+
26
+
27
+ def get_batches(X, batch_size=16):
28
+ num_batches = math.ceil(len(X) / batch_size)
29
+ for i in range(num_batches):
30
+ x = X[i*batch_size : (i+1)*batch_size]
31
+ yield x
32
+
33
+
34
+ class TashkeelModel(pl.LightningModule):
35
+ def __init__(self, tokenizer, max_seq_len, d_model=512, n_layers=3, n_heads=16, drop_prob=0.1, learnable_pos_emb=True):
36
+
37
+ super(TashkeelModel, self).__init__()
38
+
39
+ ffn_hidden = 4 * d_model
40
+ src_pad_idx = tokenizer.letters_map['<PAD>']
41
+ trg_pad_idx = tokenizer.tashkeel_map['<PAD>']
42
+ enc_voc_size = len(tokenizer.letters_map) # 37 + 3
43
+ dec_voc_size = len(tokenizer.tashkeel_map) # 15 + 3
44
+ self.transformer = Transformer(src_pad_idx=src_pad_idx,
45
+ trg_pad_idx=trg_pad_idx,
46
+ d_model=d_model,
47
+ enc_voc_size=enc_voc_size,
48
+ dec_voc_size=dec_voc_size,
49
+ max_len=max_seq_len,
50
+ ffn_hidden=ffn_hidden,
51
+ n_head=n_heads,
52
+ n_layers=n_layers,
53
+ drop_prob=drop_prob,
54
+ learnable_pos_emb=learnable_pos_emb
55
+ )
56
+
57
+ self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.tashkeel_map['<PAD>'])
58
+ self.tokenizer = tokenizer
59
+
60
+
61
+ def forward(self, x):
62
+ y_pred = self.transformer(x)
63
+ return y_pred
64
+
65
+
66
+ def training_step(self, batch, batch_idx):
67
+ input_ids, target_ids = batch
68
+ input_ids = input_ids[:, 1:-1]
69
+ y_out = target_ids[:, 1:-1]
70
+ y_pred = self(input_ids)
71
+ loss = self.criterion(y_pred.transpose(1, 2), y_out)
72
+
73
+ self.log('train_loss', loss, prog_bar=True)
74
+ # sch = self.lr_schedulers()
75
+ # sch.step()
76
+ # self.log('lr', sch.get_last_lr()[0], prog_bar=True)
77
+ return loss
78
+
79
+
80
+ def validation_step(self, batch, batch_idx):
81
+ input_ids, target_ids = batch
82
+ input_ids = input_ids[:, 1:-1]
83
+ y_out = target_ids[:, 1:-1]
84
+ y_pred = self(input_ids)
85
+ loss = self.criterion(y_pred.transpose(1, 2), y_out)
86
+
87
+ pred_text_with_tashkeels = self.tokenizer.decode(input_ids, y_pred.argmax(2).squeeze())
88
+ true_text_with_tashkeels = self.tokenizer.decode(input_ids, y_out)
89
+ total_val_der_distance = 0
90
+ total_val_der_ref_length = 0
91
+ for i in range(len(true_text_with_tashkeels)):
92
+ pred_text_with_tashkeel = pred_text_with_tashkeels[i]
93
+ true_text_with_tashkeel = true_text_with_tashkeels[i]
94
+ val_der = self.tokenizer.compute_der(true_text_with_tashkeel, pred_text_with_tashkeel)
95
+ total_val_der_distance += val_der['distance']
96
+ total_val_der_ref_length += val_der['ref_length']
97
+
98
+ total_der_error = total_val_der_distance / total_val_der_ref_length
99
+ self.log('val_loss', loss)
100
+ self.log('val_der', torch.FloatTensor([total_der_error]))
101
+ self.log('val_der_distance', torch.FloatTensor([total_val_der_distance]))
102
+ self.log('val_der_ref_length', torch.FloatTensor([total_val_der_ref_length]))
103
+
104
+
105
+ def test_step(self, batch, batch_idx):
106
+ input_ids, target_ids = batch
107
+ y_pred = self(input_ids)
108
+ loss = self.criterion(y_pred.transpose(1, 2), target_ids)
109
+ self.log('test_loss', loss)
110
+
111
+
112
+ def configure_optimizers(self):
113
+ optimizer = torch.optim.AdamW(self.parameters(), lr=3e-5)
114
+ #max_iters = 10000
115
+ #lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters, eta_min=3e-6)
116
+ gamma = 1 / 1.000001
117
+ #gamma = 1 / 1.0001
118
+ #lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
119
+ opts = {"optimizer": optimizer} #, "lr_scheduler": lr_scheduler}
120
+ return opts
121
+
122
+
123
+ @torch.no_grad()
124
+ def do_tashkeel_batch(self, texts, batch_size=16, verbose=True):
125
+ self.eval()
126
+ device = next(self.parameters()).device
127
+ text_with_tashkeel = []
128
+ data_iter = get_batches(texts, batch_size)
129
+ if verbose:
130
+ num_batches = math.ceil(len(texts) / batch_size)
131
+ data_iter = tqdm(data_iter, total=num_batches)
132
+ for texts_mini in data_iter:
133
+ input_ids_list = []
134
+ for text in texts_mini:
135
+ input_ids, _ = self.tokenizer.encode(text, test_match=False)
136
+ input_ids_list.append(input_ids)
137
+ batch_input_ids = pad_seq(input_ids_list, batch_first=True, padding_value=self.tokenizer.letters_map['<PAD>'], prepadding=False)
138
+ batch_input_ids = batch_input_ids[:, 1:-1].to(device)
139
+ y_pred = self(batch_input_ids)
140
+ y_pred = y_pred.argmax(-1)
141
+ # IMPORTANT NOTE: the following code snippet is to FORCE the prediction of the input space char to output no_tashkeel tag '<NT>'
142
+ y_pred[self.tokenizer.letters_map[' '] == batch_input_ids] = self.tokenizer.tashkeel_map[self.tokenizer.no_tashkeel_tag]
143
+ text_with_tashkeel_mini = self.tokenizer.decode(batch_input_ids, y_pred)
144
+ text_with_tashkeel += text_with_tashkeel_mini
145
+
146
+ return text_with_tashkeel
147
+
148
+
149
+ @torch.no_grad()
150
+ def do_tashkeel(self, text):
151
+ return self.do_tashkeel_batch([text])[0]
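The space-forcing line in do_tashkeel_batch overwrites the prediction at every position whose input id is the space character via boolean indexing; the same pattern in isolation, with hypothetical ids:

# Sketch: force a fixed class id (e.g. the <NT> tag) wherever the input id equals space_id.
import torch

space_id, nt_id = 3, 0                            # hypothetical ids for ' ' and '<NT>'
batch_input_ids = torch.tensor([[7, 3, 9, 3, 11]])
y_pred = torch.tensor([[4, 5, 6, 7, 8]])          # argmax'ed tashkeel predictions
y_pred[batch_input_ids == space_id] = nt_id
print(y_pred)                                     # tensor([[4, 0, 6, 0, 8]])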
models/best_ed_mlm_ns_epoch_178.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6297ec54f9af518a59d4b9e5831d3ad4da304435b8c7fb969f6e3a16816ac23
3
+ size 88403510
models/best_eo_mlm_ns_epoch_193.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ad3e811b1bf7ecda252dc9284d2bf3087c1795add47be5ff34cb13688413322
3
+ size 75762232
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 requests
 gradio
-shakkala
+shakkala
+kaldialign
tashkeel_tokenizer.py ADDED
@@ -0,0 +1,234 @@
1
+
2
+
3
+ import re
4
+ import bw2ar
5
+ import torch
6
+ import xer
7
+
8
+ # Diacritics
9
+ FATHATAN = u'\u064b'
10
+ DAMMATAN = u'\u064c'
11
+ KASRATAN = u'\u064d'
12
+ FATHA = u'\u064e'
13
+ DAMMA = u'\u064f'
14
+ KASRA = u'\u0650'
15
+ SHADDA = u'\u0651'
16
+ SUKUN = u'\u0652'
17
+ TATWEEL = u'\u0640'
18
+
19
+ HARAKAT_PAT = re.compile(u"["+u"".join([FATHATAN, DAMMATAN, KASRATAN,
20
+ FATHA, DAMMA, KASRA, SUKUN,
21
+ SHADDA])+u"]")
22
+
23
+
24
+ class TashkeelTokenizer:
25
+
26
+ def __init__(self):
27
+ self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
28
+ 'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
29
+ 'v', 'w', 'x', 'y', 'z', '|', '}'
30
+ ]
31
+ self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']
32
+
33
+ self.no_tashkeel_tag = '<NT>'
34
+ self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
35
+ '<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']
36
+
37
+ self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list
38
+
39
+ self.tashkeel_map = {c:i for i,c in enumerate(self.tashkeel_list)}
40
+ self.letters_map = {c:i for i,c in enumerate(self.letters)}
41
+ self.inverse_tags = {
42
+ '~a': '<SF>', # shaddah and fatHa
43
+ '~u': '<SD>', # shaddah and Damma
44
+ '~i': '<SK>', # shaddah and kasra
45
+ '~F': '<SFF>', # shaddah and fatHatayn
46
+ '~N': '<SDD>', # shaddah and Dammatayn
47
+ '~K': '<SKK>' # shaddah and kasratayn
48
+ }
49
+ self.tags = {v:k for k,v in self.inverse_tags.items()}
50
+ self.shaddah_last = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
51
+ self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
52
+ self.tahkeel_chars = ['F','N','K','a', 'u', 'i', '~', 'o']
53
+
54
+
55
+ def clean_text(self, text):
56
+ text = re.sub(u'[%s]' % u'\u0640', '', text) # strip tatweel
57
+ text = text.replace('ٱ', 'ا')
58
+ return ' '.join(re.sub(u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]", " ", text, flags=re.UNICODE).split())
59
+
60
+
61
+ def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
62
+ text_with_tashkeel = text_with_tashkeel.strip()
63
+ # test if the reconstructed text with tashkeel is the same as the original one
64
+ syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
65
+ return syn_text == text_with_tashkeel or syn_text == self.unify_shaddah_position(text_with_tashkeel)
66
+
67
+
68
+ def unify_shaddah_position(self, text_with_tashkeel):
69
+ # unify the order of shaddah and the harakah to make shaddah always at the beginning
70
+ for i in range(len(self.shaddah_first)):
71
+ text_with_tashkeel = text_with_tashkeel.replace(self.shaddah_last[i], self.shaddah_first[i])
72
+ return text_with_tashkeel
73
+
74
+
75
+ def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
76
+ text_with_tashkeel = self.clean_text(text_with_tashkeel)
77
+ text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
78
+ text_with_tashkeel = text_with_tashkeel.replace('`', '') # remove dagger 'alif
79
+
80
+ # unify the order of shaddah and the harakah to make shaddah always at the beginning
81
+ text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)
82
+
83
+ # remove duplicated harakat
84
+ for i in range(len(self.tahkeel_chars)):
85
+ text_with_tashkeel = text_with_tashkeel.replace(self.tahkeel_chars[i]*2, self.tahkeel_chars[i])
86
+
87
+ letter_n_tashkeel_pairs = []
88
+ for i in range(len(text_with_tashkeel)): # go over the whole text
89
+ # check if the first character is a normal letter and the second character is a tashkeel
90
+ if i < (len(text_with_tashkeel) - 1) and not text_with_tashkeel[i] in self.tashkeel_list and text_with_tashkeel[i+1] in self.tashkeel_list:
91
+ # IMPORTANT: check if tashkeel is Shaddah, then there might be another Tashkeel char associated with it. If so,
92
+ # replace both Shaddah and the Tashkeel chars with the appropriate tag
93
+ if text_with_tashkeel[i+1] == '~':
94
+ # IMPORTANT: the following if statement depends on the concept of short circuit!!
95
+ # The first condition checks if there are still more chars before it access position i+2
96
+ # "text_with_tashkeel[i+2]" since it causes "index out of range" exception. Notice that
97
+ # Shaddah here is put in the first position before the Harakah.
98
+ if i+2 < len(text_with_tashkeel) and f'~{text_with_tashkeel[i+2]}' in self.inverse_tags:
99
+ letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.inverse_tags[f'~{text_with_tashkeel[i+2]}']))
100
+ else:
101
+ # if it is only Shaddah, just add it to the list
102
+ letter_n_tashkeel_pairs.append((text_with_tashkeel[i], '~'))
103
+ else:
104
+ letter_n_tashkeel_pairs.append((text_with_tashkeel[i], text_with_tashkeel[i+1]))
105
+ # if the character at position i is a normal letter and has no Tashkeel, then add
106
+ # it with the tag "self.no_tashkeel_tag"
107
+ # IMPORTANT: this elif block ensures also that there is no two or more consecutive tashkeel other than shaddah
108
+ elif not text_with_tashkeel[i] in self.tashkeel_list:
109
+ letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.no_tashkeel_tag))
110
+
111
+ if test_match:
112
+ # test if the split is done correctly by ensuring that we can retrieve back the original text
113
+ assert self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs)
114
+ return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]
115
+
116
+
117
+ def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
118
+ combined_with_tashkeel = []
119
+ for letter, tashkeel in letter_n_tashkeel_pairs:
120
+ combined_with_tashkeel.append(letter)
121
+ if tashkeel in self.tags:
122
+ combined_with_tashkeel.append(self.tags[tashkeel])
123
+ elif tashkeel != self.no_tashkeel_tag:
124
+ combined_with_tashkeel.append(tashkeel)
125
+ text = ''.join(combined_with_tashkeel)
126
+ return text
127
+
128
+
129
+ def encode(self, text_with_tashkeel, test_match=True):
130
+ letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
131
+ text, tashkeel = zip(*letter_n_tashkeel_pairs)
132
+ input_ids = [self.letters_map[c] for c in text]
133
+ target_ids = [self.tashkeel_map[c] for c in tashkeel]
134
+ return torch.LongTensor(input_ids), torch.LongTensor(target_ids)
135
+
136
+
137
+ def filter_tashkeel(self, tashkeel):
138
+ tmp = []
139
+ for i, t in enumerate(tashkeel):
140
+ if i != 0 and t == '<BOS>':
141
+ t = self.no_tashkeel_tag
142
+ elif i != (len(tashkeel) - 1) and t == '<EOS>':
143
+ t = self.no_tashkeel_tag
144
+ tmp.append(t)
145
+ tashkeel = tmp
146
+ return tashkeel
147
+
148
+
149
+ def decode(self, input_ids, target_ids):
150
+ # print('input_ids.shape:', input_ids.shape)
151
+ # print('target_ids.shape:', target_ids.shape)
152
+ input_ids = input_ids.cpu().tolist()
153
+ target_ids = target_ids.cpu().tolist()
154
+ ar_texts = []
155
+ for j in range(len(input_ids)):
156
+ letters = [self.letters[i] for i in input_ids[j]]
157
+ tashkeel = [self.tashkeel_list[i] for i in target_ids[j]]
158
+
159
+ letters = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', letters))
160
+ tashkeel = self.filter_tashkeel(tashkeel)
161
+ tashkeel = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', tashkeel))
162
+
163
+ # VERY IMPORTANT NOTE: zip takes min(len(letters), len(tashkeel)) and discards the rest of the letters / tashkeel
164
+ letter_n_tashkeel_pairs = list(zip(letters, tashkeel))
165
+ bw_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
166
+ ar_text = bw2ar.transliterate_text(bw_text, 'bw2ar')
167
+ ar_texts.append(ar_text)
168
+ return ar_texts
169
+
170
+ def get_tashkeel_with_case_ending(self, text, case_ending=True):
171
+ text_split = self.split_tashkeel_from_text(text, test_match=False)
172
+ text_spaces_indecies = [i for i, el in enumerate(text_split) if el == (' ', '<NT>')]
173
+ new_text_split = []
174
+ for i, el in enumerate(text_split):
175
+ if not case_ending and (i+1) in text_spaces_indecies:
176
+ el = (el[0], '<NT>') # no case ending
177
+ new_text_split.append(el)
178
+ letters, tashkeel = zip(*new_text_split)
179
+ return letters, tashkeel
180
+
181
+
182
+ def compute_der(self, ref, hyp, case_ending=True):
183
+ _, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
184
+ _, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
185
+ ref_tashkeel = ' '.join(ref_tashkeel)
186
+ hyp_tashkeel = ' '.join(hyp_tashkeel)
187
+ return xer.wer(ref_tashkeel, hyp_tashkeel)
188
+
189
+ def compute_wer(self, ref, hyp, case_ending=True):
190
+ ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
191
+ hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
192
+ ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
193
+ hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
194
+ return xer.wer(ref_text_combined, hyp_text_combined)
195
+
196
+ def remove_tashkeel(self, text):
197
+ text = HARAKAT_PAT.sub('', text)
198
+ text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE) # fattha
199
+ text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE) # waSla
200
+ return text
201
+
202
+
203
+
204
+ if __name__ == '__main__':
205
+ import utils
206
+ from tqdm import tqdm
207
+ tokenizer = TashkeelTokenizer()
208
+
209
+ txt_folder_path = 'dataset/train'
210
+ prepared_lines = []
211
+ for filepath in utils.get_files(txt_folder_path, '*.txt'):
212
+ print(f'Reading file: {filepath}')
213
+ with open(filepath) as f1:
214
+ for line in f1:
215
+ clean_line = tokenizer.clean_text(line)
216
+ if clean_line != '':
217
+ prepared_lines.append(clean_line)
218
+ print(f'completed file: {filepath}')
219
+
220
+ good_sentences = []
221
+ bad_sentences = []
222
+ tokenized_sentences = []
223
+ for line in tqdm(prepared_lines):
224
+ try:
225
+ letter_n_tashkeel_pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
226
+ tokenized_sentences.append(letter_n_tashkeel_pairs)
227
+ good_sentences.append(line)
228
+ except AssertionError as e:
229
+ bad_sentences.append(line)
230
+
231
+ print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))
232
+
233
+
234
+
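A short round-trip sketch of the tokenizer API above: encode splits a diacritized string into letter ids and tashkeel-tag ids, decode recombines them, and compute_der scores a hypothesis against a reference (the Arabic sample is illustrative; per validation_step above, the result is a dict with 'distance' and 'ref_length' keys):

# Sketch: encode -> decode round trip and diacritic error rate on a toy sentence.
from tashkeel_tokenizer import TashkeelTokenizer

tokenizer = TashkeelTokenizer()
ref = 'ذَهَبَ الوَلَدُ'                      # hypothetical diacritized sample
input_ids, target_ids = tokenizer.encode(ref, test_match=False)
restored = tokenizer.decode(input_ids.unsqueeze(0), target_ids.unsqueeze(0))[0]
print(restored)

hyp = restored                               # pretend this came from a model
der = tokenizer.compute_der(ref, hyp)        # dict with 'distance' and 'ref_length' (as used in validation_step)
print(der)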
transformer.py ADDED
@@ -0,0 +1,559 @@
1
+ """
2
+ @author : Hyunwoong
3
+ @when : 2019-12-18
4
+ @homepage : https://github.com/gusdnd852
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class EncoderLayer(nn.Module):
13
+
14
+ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
15
+ super(EncoderLayer, self).__init__()
16
+ self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
17
+ self.norm1 = LayerNorm(d_model=d_model)
18
+ self.dropout1 = nn.Dropout(p=drop_prob)
19
+
20
+ self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
21
+ self.norm2 = LayerNorm(d_model=d_model)
22
+ self.dropout2 = nn.Dropout(p=drop_prob)
23
+
24
+ def forward(self, x, s_mask):
25
+ # 1. compute self attention
26
+ _x = x
27
+ x = self.attention(q=x, k=x, v=x, mask=s_mask)
28
+
29
+ # 2. add and norm
30
+ x = self.dropout1(x)
31
+ x = self.norm1(x + _x)
32
+
33
+ # 3. positionwise feed forward network
34
+ _x = x
35
+ x = self.ffn(x)
36
+
37
+ # 4. add and norm
38
+ x = self.dropout2(x)
39
+ x = self.norm2(x + _x)
40
+ return x
41
+
42
+
43
+ class DecoderLayer(nn.Module):
44
+
45
+ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
46
+ super(DecoderLayer, self).__init__()
47
+ self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
48
+ self.norm1 = LayerNorm(d_model=d_model)
49
+ self.dropout1 = nn.Dropout(p=drop_prob)
50
+
51
+ self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
52
+ self.norm2 = LayerNorm(d_model=d_model)
53
+ self.dropout2 = nn.Dropout(p=drop_prob)
54
+
55
+ self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
56
+ self.norm3 = LayerNorm(d_model=d_model)
57
+ self.dropout3 = nn.Dropout(p=drop_prob)
58
+
59
+ def forward(self, dec, enc, t_mask, s_mask):
60
+ # 1. compute self attention
61
+ _x = dec
62
+ x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)
63
+
64
+ # 2. add and norm
65
+ x = self.dropout1(x)
66
+ x = self.norm1(x + _x)
67
+
68
+ if enc is not None:
69
+ # 3. compute encoder - decoder attention
70
+ _x = x
71
+ x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)
72
+
73
+ # 4. add and norm
74
+ x = self.dropout2(x)
75
+ x = self.norm2(x + _x)
76
+
77
+ # 5. positionwise feed forward network
78
+ _x = x
79
+ x = self.ffn(x)
80
+
81
+ # 6. add and norm
82
+ x = self.dropout3(x)
83
+ x = self.norm3(x + _x)
84
+ return x
85
+
86
+
87
+ class ScaleDotProductAttention(nn.Module):
88
+ """
89
+ compute scale dot product attention
90
+
91
+ Query : given sentence that we focused on (decoder)
92
+ Key : every sentence to check relationship with Query (encoder)
93
+ Value : every sentence same with Key (encoder)
94
+ """
95
+
96
+ def __init__(self):
97
+ super(ScaleDotProductAttention, self).__init__()
98
+ self.softmax = nn.Softmax(dim=-1)
99
+
100
+ def forward(self, q, k, v, mask=None, e=1e-12):
101
+ # input is 4 dimension tensor
102
+ # [batch_size, head, length, d_tensor]
103
+ batch_size, head, length, d_tensor = k.size()
104
+
105
+ # 1. dot product Query with Key^T to compute similarity
106
+ k_t = k.transpose(2, 3) # transpose
107
+ score = (q @ k_t) / math.sqrt(d_tensor) # scaled dot product
108
+
109
+ # 2. apply masking (opt)
110
+ if mask is not None:
111
+ score = score.masked_fill(mask == 0, -10000)
112
+
113
+ # 3. pass them softmax to make [0, 1] range
114
+ score = self.softmax(score)
115
+
116
+ # 4. multiply with Value
117
+ v = score @ v
118
+
119
+ return v, score
120
+
121
+
122
+ class PositionwiseFeedForward(nn.Module):
123
+
124
+ def __init__(self, d_model, hidden, drop_prob=0.1):
125
+ super(PositionwiseFeedForward, self).__init__()
126
+ self.linear1 = nn.Linear(d_model, hidden)
127
+ self.linear2 = nn.Linear(hidden, d_model)
128
+ self.relu = nn.ReLU()
129
+ self.dropout = nn.Dropout(p=drop_prob)
130
+
131
+ def forward(self, x):
132
+ x = self.linear1(x)
133
+ x = self.relu(x)
134
+ x = self.dropout(x)
135
+ x = self.linear2(x)
136
+ return x
137
+
138
+
139
+ class MultiHeadAttention(nn.Module):
140
+
141
+ def __init__(self, d_model, n_head):
142
+ super(MultiHeadAttention, self).__init__()
143
+ self.n_head = n_head
144
+ self.attention = ScaleDotProductAttention()
145
+ self.w_q = nn.Linear(d_model, d_model, bias=False)
146
+ self.w_k = nn.Linear(d_model, d_model, bias=False)
147
+ self.w_v = nn.Linear(d_model, d_model, bias=False)
148
+ self.w_concat = nn.Linear(d_model, d_model, bias=False)
149
+
150
+ def forward(self, q, k, v, mask=None):
151
+ # 1. dot product with weight matrices
152
+ q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
153
+
154
+ # 2. split tensor by number of heads
155
+ q, k, v = self.split(q), self.split(k), self.split(v)
156
+
157
+ # 3. do scale dot product to compute similarity
158
+ out, attention = self.attention(q, k, v, mask=mask)
159
+
160
+ # 4. concat and pass to linear layer
161
+ out = self.concat(out)
162
+ out = self.w_concat(out)
163
+
164
+ # 5. visualize attention map
165
+ # TODO : we should implement visualization
166
+
167
+ return out
168
+
169
+ def split(self, tensor):
170
+ """
171
+ split tensor by number of head
172
+
173
+ :param tensor: [batch_size, length, d_model]
174
+ :return: [batch_size, head, length, d_tensor]
175
+ """
176
+ batch_size, length, d_model = tensor.size()
177
+
178
+ d_tensor = d_model // self.n_head
179
+ tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
180
+ # it is similar with group convolution (split by number of heads)
181
+
182
+ return tensor
183
+
184
+ def concat(self, tensor):
185
+ """
186
+ inverse function of self.split(tensor : torch.Tensor)
187
+
188
+ :param tensor: [batch_size, head, length, d_tensor]
189
+ :return: [batch_size, length, d_model]
190
+ """
191
+ batch_size, head, length, d_tensor = tensor.size()
192
+ d_model = head * d_tensor
193
+
194
+ tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
195
+ return tensor
196
+
197
+
198
+ class LayerNorm(nn.Module):
199
+ def __init__(self, d_model, eps=1e-12):
200
+ super(LayerNorm, self).__init__()
201
+ self.gamma = nn.Parameter(torch.ones(d_model))
202
+ self.beta = nn.Parameter(torch.zeros(d_model))
203
+ self.eps = eps
204
+
205
+ def forward(self, x):
206
+ mean = x.mean(-1, keepdim=True)
207
+ var = x.var(-1, unbiased=False, keepdim=True)
208
+ # '-1' means last dimension.
209
+
210
+ out = (x - mean) / torch.sqrt(var + self.eps)
211
+ out = self.gamma * out + self.beta
212
+ return out
213
+
214
+
215
+ class TransformerEmbedding(nn.Module):
216
+ """
217
+ token embedding + positional encoding (sinusoid)
218
+ positional encoding can give positional information to network
219
+ """
220
+
221
+ def __init__(self, vocab_size, d_model, max_len, drop_prob, padding_idx, learnable_pos_emb=True):
222
+ """
223
+ class for word embedding that included positional information
224
+
225
+ :param vocab_size: size of vocabulary
226
+ :param d_model: dimensions of model
227
+ """
228
+ super(TransformerEmbedding, self).__init__()
229
+ self.tok_emb = TokenEmbedding(vocab_size, d_model, padding_idx)
230
+ if learnable_pos_emb:
231
+ self.pos_emb = LearnablePositionalEncoding(d_model, max_len)
232
+ else:
233
+ self.pos_emb = SinusoidalPositionalEncoding(d_model, max_len)
234
+ self.drop_out = nn.Dropout(p=drop_prob)
235
+
236
+ def forward(self, x):
237
+ tok_emb = self.tok_emb(x)
238
+ pos_emb = self.pos_emb(x).to(tok_emb.device)
239
+ return self.drop_out(tok_emb + pos_emb)
240
+
241
+
242
+ class TokenEmbedding(nn.Embedding):
243
+ """
244
+ Token Embedding using torch.nn
245
+ maps token ids to dense representations using a learned weight matrix
246
+ """
247
+
248
+ def __init__(self, vocab_size, d_model, padding_idx):
249
+ """
250
+ class for token embedding that included positional information
251
+
252
+ :param vocab_size: size of vocabulary
253
+ :param d_model: dimensions of model
254
+ """
255
+ super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
256
+
257
+
258
+ class SinusoidalPositionalEncoding(nn.Module):
259
+ """
260
+ compute sinusoid encoding.
261
+ """
262
+
263
+ def __init__(self, d_model, max_len):
264
+ """
265
+ constructor of sinusoid encoding class
266
+
267
+ :param d_model: dimension of model
268
+ :param max_len: max sequence length
269
+
270
+ """
271
+ super(SinusoidalPositionalEncoding, self).__init__()
272
+
273
+ # same size with input matrix (for adding with input matrix)
274
+ self.encoding = torch.zeros(max_len, d_model)
275
+ self.encoding.requires_grad = False # we don't need to compute gradient
276
+
277
+ pos = torch.arange(0, max_len)
278
+ pos = pos.float().unsqueeze(dim=1)
279
+ # 1D => 2D unsqueeze to represent word's position
280
+
281
+ _2i = torch.arange(0, d_model, step=2).float()
282
+ # 'i' means index of d_model (e.g. embedding size = 50, 'i' = [0,50])
283
+ # "step=2" means 'i' multiplied with two (same with 2 * i)
284
+
285
+ self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
286
+ self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
287
+ # compute positional encoding to consider positional information of words
288
+
289
+ def forward(self, x):
290
+ # self.encoding
291
+ # [max_len = 512, d_model = 512]
292
+
293
+ batch_size, seq_len = x.size()
294
+ # [batch_size = 128, seq_len = 30]
295
+
296
+ return self.encoding[:seq_len, :]
297
+ # [seq_len = 30, d_model = 512]
298
+ # it will add with tok_emb : [128, 30, 512]
299
+
300
+
301
+ class LearnablePositionalEncoding(nn.Module):
302
+ """
303
+ compute learnable positional encoding.
304
+ """
305
+
306
+ def __init__(self, d_model, max_seq_len):
307
+ """
308
+ constructor of learnable positional encoding class
309
+
310
+ :param d_model: dimension of model
311
+ :param max_seq_len: max sequence length
312
+
313
+ """
314
+ super(LearnablePositionalEncoding, self).__init__()
315
+ self.max_seq_len = max_seq_len
316
+ self.wpe = nn.Embedding(max_seq_len, d_model)
317
+
318
+ def forward(self, x):
319
+ # self.encoding
320
+ # [max_len = 512, d_model = 512]
321
+ device = x.device
322
+ batch_size, seq_len = x.size()
323
+ assert seq_len <= self.max_seq_len, f"Cannot forward sequence of length {seq_len}, max_seq_len is {self.max_seq_len}"
324
+ pos = torch.arange(0, seq_len, dtype=torch.long, device=device) # shape (seq_len)
325
+ pos_emb = self.wpe(pos) # position embeddings of shape (seq_len, d_model)
326
+
327
+ return pos_emb
328
+ # [seq_len = 30, d_model = 512]
329
+ # it will add with tok_emb : [128, 30, 512]
330
+
331
+
332
+ class Encoder(nn.Module):
333
+
334
+ def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
335
+ super().__init__()
336
+ self.emb = TransformerEmbedding(d_model=d_model,
337
+ max_len=max_len,
338
+ vocab_size=enc_voc_size,
339
+ drop_prob=drop_prob,
340
+ padding_idx=padding_idx,
341
+ learnable_pos_emb=learnable_pos_emb
342
+ )
343
+
344
+ self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
345
+ ffn_hidden=ffn_hidden,
346
+ n_head=n_head,
347
+ drop_prob=drop_prob)
348
+ for _ in range(n_layers)])
349
+
350
+ def forward(self, x, s_mask):
351
+ x = self.emb(x)
352
+
353
+ for layer in self.layers:
354
+ x = layer(x, s_mask)
355
+
356
+ return x
357
+
358
+ class Decoder(nn.Module):
359
+ def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
360
+ super().__init__()
361
+ self.emb = TransformerEmbedding(d_model=d_model,
362
+ drop_prob=drop_prob,
363
+ max_len=max_len,
364
+ vocab_size=dec_voc_size,
365
+ padding_idx=padding_idx,
366
+ learnable_pos_emb=learnable_pos_emb
367
+ )
368
+
369
+ self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,
370
+ ffn_hidden=ffn_hidden,
371
+ n_head=n_head,
372
+ drop_prob=drop_prob)
373
+ for _ in range(n_layers)])
374
+
375
+ self.linear = nn.Linear(d_model, dec_voc_size)
376
+
377
+ def forward(self, trg, enc_src, trg_mask, src_mask):
378
+ trg = self.emb(trg)
379
+
380
+ for layer in self.layers:
381
+ trg = layer(trg, enc_src, trg_mask, src_mask)
382
+
383
+ # pass to LM head
384
+ output = self.linear(trg)
385
+ return output
386
+
387
+ class Transformer(nn.Module):
388
+
389
+ def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
390
+ ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
391
+ super().__init__()
392
+ self.src_pad_idx = src_pad_idx
393
+ self.trg_pad_idx = trg_pad_idx
394
+ self.encoder = Encoder(d_model=d_model,
395
+ n_head=n_head,
396
+ max_len=max_len,
397
+ ffn_hidden=ffn_hidden,
398
+ enc_voc_size=enc_voc_size,
399
+ drop_prob=drop_prob,
400
+ n_layers=n_layers,
401
+ padding_idx=src_pad_idx,
402
+ learnable_pos_emb=learnable_pos_emb)
403
+
404
+ self.decoder = Decoder(d_model=d_model,
405
+ n_head=n_head,
406
+ max_len=max_len,
407
+ ffn_hidden=ffn_hidden,
408
+ dec_voc_size=dec_voc_size,
409
+ drop_prob=drop_prob,
410
+ n_layers=n_layers,
411
+ padding_idx=trg_pad_idx,
412
+ learnable_pos_emb=learnable_pos_emb)
413
+
414
+ def get_device(self):
415
+ return next(self.parameters()).device
416
+
417
+ def forward(self, src, trg):
418
+ device = self.get_device()
419
+ src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
420
+ src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
421
+ trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
422
+ self.make_no_peak_mask(trg, trg).to(device)
423
+
424
+ #print(src_mask)
425
+ #print('-'*100)
426
+ #print(trg_mask)
427
+ enc_src = self.encoder(src, src_mask)
428
+ output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
429
+ return output
430
+
431
+ def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
432
+ len_q, len_k = q.size(1), k.size(1)
433
+
434
+ # batch_size x 1 x 1 x len_k
435
+ k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
436
+ # batch_size x 1 x len_q x len_k
437
+ k = k.repeat(1, 1, len_q, 1)
438
+
439
+ # batch_size x 1 x len_q x 1
440
+ q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
441
+ # batch_size x 1 x len_q x len_k
442
+ q = q.repeat(1, 1, 1, len_k)
443
+
444
+ mask = k & q
445
+ return mask
446
+
447
+ def make_no_peak_mask(self, q, k):
448
+ len_q, len_k = q.size(1), k.size(1)
449
+
450
+ # len_q x len_k
451
+ mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
452
+
453
+ return mask
454
+
455
+
456
+ def make_pad_mask(x, pad_idx):
457
+ q = k = x
458
+ q_pad_idx = k_pad_idx = pad_idx
459
+ len_q, len_k = q.size(1), k.size(1)
460
+
461
+ # batch_size x 1 x 1 x len_k
462
+ k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
463
+ # batch_size x 1 x len_q x len_k
464
+ k = k.repeat(1, 1, len_q, 1)
465
+
466
+ # batch_size x 1 x len_q x 1
467
+ q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
468
+ # batch_size x 1 x len_q x len_k
469
+ q = q.repeat(1, 1, 1, len_k)
470
+
471
+ mask = k & q
472
+ return mask
473
+
474
+
475
+ from torch.nn.utils.rnn import pad_sequence
476
+ # x_list is a list of tensors of shape TxH where T is the seqlen and H is the feats dim
477
+ def pad_seq_v2(sequences, batch_first=True, padding_value=0.0, prepadding=True):
478
+ lens = [i.shape[0] for i in sequences]
479
+ padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value) # NxTxH
480
+ if prepadding:
481
+ for i in range(len(lens)):
482
+ padded_sequences[i] = padded_sequences[i].roll(-lens[i])
483
+ if not batch_first:
484
+ padded_sequences = padded_sequences.transpose(0, 1) # TxNxH
485
+ return padded_sequences
486
+
487
+
488
+
489
+ if __name__ == '__main__':
490
+ import torch
491
+ import random
492
+ import numpy as np
493
+
494
+ rand_seed = 10
495
+
496
+ device = 'cpu'
497
+
498
+ # model parameter setting
499
+ batch_size = 128
500
+ max_len = 256
501
+ d_model = 512
502
+ n_layers = 3
503
+ n_heads = 16
504
+ ffn_hidden = 2048
505
+ drop_prob = 0.1
506
+
507
+ # optimizer parameter setting
508
+ init_lr = 1e-5
509
+ factor = 0.9
510
+ adam_eps = 5e-9
511
+ patience = 10
512
+ warmup = 100
513
+ epoch = 1000
514
+ clip = 1.0
515
+ weight_decay = 5e-4
516
+ inf = float('inf')
517
+
518
+ src_pad_idx = 2
519
+ trg_pad_idx = 3
520
+
521
+ enc_voc_size = 37
522
+ dec_voc_size = 15
523
+ model = Transformer(src_pad_idx=src_pad_idx,
524
+ trg_pad_idx=trg_pad_idx,
525
+ d_model=d_model,
526
+ enc_voc_size=enc_voc_size,
527
+ dec_voc_size=dec_voc_size,
528
+ max_len=max_len,
529
+ ffn_hidden=ffn_hidden,
530
+ n_head=n_heads,
531
+ n_layers=n_layers,
532
+ drop_prob=drop_prob
533
+ ).to(device)
534
+
535
+ random.seed(rand_seed)
536
+ # Set the seed to 0 for reproducible results
537
+ np.random.seed(rand_seed)
538
+ torch.manual_seed(rand_seed)
539
+
540
+ x_list = [
541
+ torch.tensor([[1, 1]]).transpose(0, 1), # 2
542
+ torch.tensor([[1, 1, 1, 1, 1, 1, 1]]).transpose(0, 1), # 7
543
+ torch.tensor([[1, 1, 1]]).transpose(0, 1) # 3
544
+ ]
545
+
546
+
547
+ src_pad_idx = model.src_pad_idx
548
+ trg_pad_idx = model.trg_pad_idx
549
+
550
+ src = pad_seq_v2(x_list, padding_value=src_pad_idx, prepadding=False).squeeze(2)
551
+ trg = pad_seq_v2(x_list, padding_value=trg_pad_idx, prepadding=False).squeeze(2)
552
+ out = model(src, trg)
553
+
554
+
555
+
556
+
557
+
558
+
559
+
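For readers tracing the mask logic above, here is a minimal standalone sketch (not part of the committed file) of how the padding mask and the no-peak mask combine on the target side. It uses broadcasting where make_pad_mask uses repeat; the toy batch and pad index are made up for illustration.

import torch

# hypothetical toy batch: two target sequences, the second one padded with trg_pad_idx = 3
trg = torch.tensor([[5, 6, 7],
                    [8, 9, 3]])
trg_pad_idx = 3

# padding mask: batch_size x 1 x len_q x len_k, True where query AND key are real tokens
real = trg.ne(trg_pad_idx)
pad_mask = real.unsqueeze(1).unsqueeze(2) & real.unsqueeze(1).unsqueeze(3)

# no-peak (causal) mask: len_q x len_k lower-triangular
causal = torch.tril(torch.ones(trg.size(1), trg.size(1), dtype=torch.bool))

# combined target mask, as in Transformer.forward (there via elementwise multiplication)
trg_mask = pad_mask & causal
print(trg_mask.shape)   # torch.Size([2, 1, 3, 3])
print(trg_mask[1, 0])   # rows and columns touching the padded position are all False
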
utils.py ADDED
@@ -0,0 +1,139 @@
+ """
+ note: this code is used in the bw2ar.py file
+ """
+
+ #!/usr/bin/python
+ # -*- coding=utf-8 -*-
+ #---
+ # $Id: arabic.py,v 1.6 2003/04/22 17:18:22 elzubeir Exp $
+ #
+ # ------------
+ # Description:
+ # ------------
+ #
+ #  Arabic codes
+ #
+ # (C) Copyright 2003, Arabeyes, Mohammed Elzubeir
+ # (C) Copyright 2019, Faris Abdullah Alasmary
+ # -----------------
+ # Revision Details:    (Updated by Revision Control System)
+ # -----------------
+ #  $Date: 2003/04/22 17:18:22 $
+ #  $Author: elzubeir $
+ #  $Revision: 1.6 $
+ #  $Source: /home/arabeyes/cvs/projects/duali/pyduali/pyduali/arabic.py,v $
+ #
+ #  This program is written under the BSD License.
+ #---
+ """ Constants for Arabic """
+ import re
+
+ COMMA = u'\u060C'
+ SEMICOLON = u'\u061B'
+ QUESTION = u'\u061F'
+ HAMZA = u'\u0621'
+ ALEF_MADDA = u'\u0622'
+ ALEF_HAMZA_ABOVE = u'\u0623'
+ WAW_HAMZA = u'\u0624'
+ ALEF_HAMZA_BELOW = u'\u0625'
+ YEH_HAMZA = u'\u0626'
+ ALEF = u'\u0627'
+ BEH = u'\u0628'
+ TEH_MARBUTA = u'\u0629'
+ TEH = u'\u062a'
+ THEH = u'\u062b'
+ JEEM = u'\u062c'
+ HAH = u'\u062d'
+ KHAH = u'\u062e'
+ DAL = u'\u062f'
+ THAL = u'\u0630'
+ REH = u'\u0631'
+ ZAIN = u'\u0632'
+ SEEN = u'\u0633'
+ SHEEN = u'\u0634'
+ SAD = u'\u0635'
+ DAD = u'\u0636'
+ TAH = u'\u0637'
+ ZAH = u'\u0638'
+ AIN = u'\u0639'
+ GHAIN = u'\u063a'
+ TATWEEL = u'\u0640'
+ FEH = u'\u0641'
+ QAF = u'\u0642'
+ KAF = u'\u0643'
+ LAM = u'\u0644'
+ MEEM = u'\u0645'
+ NOON = u'\u0646'
+ HEH = u'\u0647'
+ WAW = u'\u0648'
+ ALEF_MAKSURA = u'\u0649'
+ YEH = u'\u064a'
+ MADDA_ABOVE = u'\u0653'
+ HAMZA_ABOVE = u'\u0654'
+ HAMZA_BELOW = u'\u0655'
+ ZERO = u'\u0660'
+ ONE = u'\u0661'
+ TWO = u'\u0662'
+ THREE = u'\u0663'
+ FOUR = u'\u0664'
+ FIVE = u'\u0665'
+ SIX = u'\u0666'
+ SEVEN = u'\u0667'
+ EIGHT = u'\u0668'
+ NINE = u'\u0669'
+ PERCENT = u'\u066a'
+ DECIMAL = u'\u066b'
+ THOUSANDS = u'\u066c'
+ STAR = u'\u066d'
+ MINI_ALEF = u'\u0670'
+ ALEF_WASLA = u'\u0671'
+ FULL_STOP = u'\u06d4'
+ BYTE_ORDER_MARK = u'\ufeff'
+
+ # Diacritics
+ FATHATAN = u'\u064b'
+ DAMMATAN = u'\u064c'
+ KASRATAN = u'\u064d'
+ FATHA = u'\u064e'
+ DAMMA = u'\u064f'
+ KASRA = u'\u0650'
+ SHADDA = u'\u0651'
+ SUKUN = u'\u0652'
+
+ # Ligatures
+ LAM_ALEF = u'\ufefb'
+ LAM_ALEF_HAMZA_ABOVE = u'\ufef7'
+ LAM_ALEF_HAMZA_BELOW = u'\ufef9'
+ LAM_ALEF_MADDA_ABOVE = u'\ufef5'
+ SIMPLE_LAM_ALEF = u'\u0644\u0627'
+ SIMPLE_LAM_ALEF_HAMZA_ABOVE = u'\u0644\u0623'
+ SIMPLE_LAM_ALEF_HAMZA_BELOW = u'\u0644\u0625'
+ SIMPLE_LAM_ALEF_MADDA_ABOVE = u'\u0644\u0622'
+
+
+ HARAKAT_PAT = re.compile(u"[" + u"".join([FATHATAN, DAMMATAN, KASRATAN,
+                                           FATHA, DAMMA, KASRA, SUKUN,
+                                           SHADDA]) + u"]")
+ HAMZAT_PAT = re.compile(u"[" + u"".join([WAW_HAMZA, YEH_HAMZA]) + u"]")
+ ALEFAT_PAT = re.compile(u"[" + u"".join([ALEF_MADDA, ALEF_HAMZA_ABOVE,
+                                          ALEF_HAMZA_BELOW, HAMZA_ABOVE,
+                                          HAMZA_BELOW]) + u"]")
+ LAMALEFAT_PAT = re.compile(u"[" + u"".join([LAM_ALEF,
+                                             LAM_ALEF_HAMZA_ABOVE,
+                                             LAM_ALEF_HAMZA_BELOW,
+                                             LAM_ALEF_MADDA_ABOVE]) + u"]")
+
+ def strip_tashkeel(text):
+     text = HARAKAT_PAT.sub('', text)
+     text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE)  # fatha
+     text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE)  # alef wasla
+     return text
+
+ def strip_tatweel(text):
+     return re.sub(u'[%s]' % TATWEEL, '', text)
+
+ # remove tashkeel, tatweel, and non-Arabic characters, collapsing whitespace
+ def remove_non_arabic(text):
+     text = strip_tashkeel(text)
+     text = strip_tatweel(text)
+     return ' '.join(re.sub(u"[^\u0621-\u063A\u0641-\u064A ]", " ", text, flags=re.UNICODE).split())
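A short usage sketch for these helpers (the sample sentence is arbitrary): strip_tashkeel removes the diacritics only, while remove_non_arabic also drops digits, Latin letters, and punctuation and collapses the remaining whitespace.

from utils import strip_tashkeel, remove_non_arabic

text = 'العَرَبِيَّةُ لُغَةٌ جَمِيلَةٌ 123 abc!'
print(strip_tashkeel(text))     # 'العربية لغة جميلة 123 abc!'
print(remove_non_arabic(text))  # 'العربية لغة جميلة'
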
xer.py ADDED
@@ -0,0 +1,76 @@
+ """
+ @author: FarisAlasmary
+ Date: Mar 15, 2022
+ """
+
+ # pip install git+https://github.com/pzelasko/kaldialign.git
+
+ from kaldialign import edit_distance
+
+
+ def cer(ref, hyp):
+     """
+     Computes the Character Error Rate (CER): the character-level edit distance
+     between the reference and the hypothesis (with spaces removed), normalized
+     by the reference length and reported as a percentage.
+
+     Arguments:
+         ref (string): a space-separated ground truth string
+         hyp (string): a space-separated hypothesis
+     """
+     ref, hyp = ref.replace(' ', '').strip(), hyp.replace(' ', '').strip()
+     info = edit_distance(ref, hyp)
+     distance = info['total']
+     ref_length = float(len(ref))
+
+     data = {
+         'insertions': info['ins'],
+         'deletions': info['del'],
+         'substitutions': info['sub'],
+         'distance': distance,
+         'ref_length': ref_length,
+         'Error Rate': (distance / ref_length) * 100
+     }
+
+     return data
+
+
+ def wer(ref, hyp):
+     """
+     Computes the Word Error Rate (WER): the edit distance between the two
+     provided sentences after tokenizing them into words, normalized by the
+     reference length in words and reported as a percentage.
+
+     Arguments:
+         ref (string): a space-separated ground truth string
+         hyp (string): a space-separated hypothesis
+     """
+     # build a mapping of words to integers
+     b = set(ref.split() + hyp.split())
+     word2char = dict(zip(b, range(len(b))))
+
+     # map the words to a char array (edit-distance packages only accept strings)
+     w1 = [chr(word2char[w]) for w in ref.split()]
+     w2 = [chr(word2char[w]) for w in hyp.split()]
+
+     info = edit_distance(''.join(w1), ''.join(w2))
+     distance = info['total']
+     ref_length = float(len(w1))
+
+     data = {
+         'insertions': info['ins'],
+         'deletions': info['del'],
+         'substitutions': info['sub'],
+         'distance': distance,
+         'ref_length': ref_length,
+         'Error Rate': (distance / ref_length) * 100
+     }
+
+     return data
+
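A short usage sketch for these metrics, assuming kaldialign is installed; the example strings are made up. cer compares characters after stripping spaces, wer compares whole words, and both return the counts plus the rate as a percentage of the reference length.

from xer import cer, wer

ref = 'ذَهَبَ الوَلَدُ إِلَى المَدْرَسَةِ'
hyp = 'ذَهَبَ الوَلَدُ إِلَى المَدرَسَةِ'   # one missing sukun on the last word

print(cer(ref, hyp))                 # full breakdown: insertions, deletions, substitutions, ...
print(wer(ref, hyp)['Error Rate'])   # 25.0 -- one of the four words differs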