Spaces: Running on Zero

Commit · bcc0c7f
Parent(s): 7efa162

chore: Add requirements for shakkala and kaldialign

Files changed:
- app.py +45 -22
- bw2ar.py +126 -0
- ed.py +63 -0
- ed_pl.py +164 -0
- eo.py +50 -0
- eo_pl.py +151 -0
- models/best_ed_mlm_ns_epoch_178.pt +3 -0
- models/best_eo_mlm_ns_epoch_193.pt +3 -0
- requirements.txt +2 -1
- tashkeel_tokenizer.py +234 -0
- transformer.py +559 -0
- utils.py +139 -0
- xer.py +76 -0
app.py
CHANGED
@@ -1,23 +1,18 @@
-import requests
 import gradio as gr
 from shakkala import Shakkala
+from pathlib import Path
+import torch
+from eo_pl import TashkeelModel as TashkeelModelEO
+from ed_pl import TashkeelModel as TashkeelModelED
+from tashkeel_tokenizer import TashkeelTokenizer
+from utils import remove_non_arabic
+import spaces
 
-
-
-
-def add_tashkeel1(text):
-    data = {"textArabic": text}
-    response = requests.post(tashkel_url, data=data)
-    response.encoding = "utf-8"
-    text = response.text.strip()
-    return text
-
-
+# Initialize the Shakkala model
 sh = Shakkala(version=3)
 model, graph = sh.get_model()
 
-
-def add_tashkeel2(input_text):
+def infer_shakkala(input_text):
     input_int = sh.prepare_input(input_text)
     logits = model.predict(input_int)[0]
     predicted_harakat = sh.logits_to_text(logits)
@@ -25,28 +20,56 @@ def add_tashkeel2(input_text):
     print(final_output)
     return final_output
 
+# Initialize the CaTT model and tokenizer
+tokenizer = TashkeelTokenizer()
+eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('device:', device)
+
+max_seq_len = 1024
+print('Creating Model...')
+eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
+ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
+
+eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device)).eval().to(device)
+ed_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device)).eval().to(device)
+
+@spaces.GPU()
+def infer_catt(input_text, choose_model):
+    input_text = remove_non_arabic(input_text)
+    batch_size = 16
+    verbose = True
+    if choose_model == 'Encoder-Only':
+        output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
+    else:
+        output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)
+
+    return output_text[0]
 
 with gr.Blocks(title="Arabic Tashkeel") as demo:
     gr.HTML("<center><h1>Arabic Tashkeel</h1></center>")
     gr.HTML(
         "<center><p>Compare different methods for adding tashkeel to Arabic text.</p></center>"
     )
-    with gr.Tab(label="
+    with gr.Tab(label="CATT"):
+        gr.Markdown("[CATT](https://github.com/abjadai/catt) is a new deep learning model for Arabic diacritization.")
         with gr.Row():
             with gr.Column():
-                text_input1 = gr.Textbox(
-
+                text_input1 = gr.Textbox(label="Input Text", rtl=True, text_align="right")
+                choose_model = gr.Radio(
+                    label="Choose Model",
+                    choices=["Encoder-Only", "Encoder-Decoder"],
+                    default="Encoder-Decoder",
                 )
                 with gr.Row():
                     clear_button1 = gr.Button(value="Clear", variant="secondary")
                     submit_button1 = gr.Button(value="Add Tashkeel", variant="primary")
 
             with gr.Column():
-                text_output1 = gr.Textbox(
-                    lines=1, label="Output Text", rtl=True, text_align="right"
-                )
+                text_output1 = gr.Textbox(label="Output Text", rtl=True, text_align="right")
 
-        submit_button1.click(
+        submit_button1.click(infer_catt, inputs=[text_input1, choose_model], outputs=text_output1)
         clear_button1.click(lambda: text_input1.update(""))
 
     with gr.Tab(label="Shakkala"):
@@ -66,7 +89,7 @@ with gr.Blocks(title="Arabic Tashkeel") as demo:
                     lines=1, label="Output Text", rtl=True, text_align="right"
                 )
 
-        submit_button2.click(
+        submit_button2.click(infer_shakkala, inputs=text_input2, outputs=text_output2)
         clear_button2.click(lambda: text_input2.update(""))
 
 demo.queue().launch()
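For local testing outside the Space, a minimal sketch (not part of the commit) of exercising the same encoder-only path that infer_catt uses, assuming the checkpoint shipped in models/ and the new modules above are importable; variable names here are illustrative:

# Minimal local sketch: load the encoder-only CATT model and diacritize one sentence
# without Gradio or the ZeroGPU decorator.
from pathlib import Path
import torch
from eo_pl import TashkeelModel as TashkeelModelEO
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

tokenizer = TashkeelTokenizer()
ckpt = Path('models/best_eo_mlm_ns_epoch_193.pt')   # checkpoint added in this commit
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = TashkeelModelEO(tokenizer, max_seq_len=1024, n_layers=6, learnable_pos_emb=False)
model.load_state_dict(torch.load(ckpt, map_location=device))  # load weights in place
model = model.eval().to(device)

text = remove_non_arabic('اللغة العربية')            # any undiacritized Arabic string
print(model.do_tashkeel_batch([text], batch_size=16, verbose=False)[0])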
bw2ar.py
ADDED
@@ -0,0 +1,126 @@
"""
functions to convert Arabic words/text into buckwalter encoding and vice versa
"""

import sys
import re
import utils

buck2uni = {
    "'": u"\u0621",  # hamza-on-the-line
    "|": u"\u0622",  # madda
    ">": u"\u0623",  # hamza-on-'alif
    "&": u"\u0624",  # hamza-on-waaw
    "<": u"\u0625",  # hamza-under-'alif
    "}": u"\u0626",  # hamza-on-yaa'
    "A": u"\u0627",  # bare 'alif
    "b": u"\u0628",  # baa'
    "p": u"\u0629",  # taa' marbuuTa
    "t": u"\u062A",  # taa'
    "v": u"\u062B",  # thaa'
    "j": u"\u062C",  # jiim
    "H": u"\u062D",  # Haa'
    "x": u"\u062E",  # khaa'
    "d": u"\u062F",  # daal
    "*": u"\u0630",  # dhaal
    "r": u"\u0631",  # raa'
    "z": u"\u0632",  # zaay
    "s": u"\u0633",  # siin
    "$": u"\u0634",  # shiin
    "S": u"\u0635",  # Saad
    "D": u"\u0636",  # Daad
    "T": u"\u0637",  # Taa'
    "Z": u"\u0638",  # Zaa' (DHaa')
    "E": u"\u0639",  # cayn
    "g": u"\u063A",  # ghayn
    "_": u"\u0640",  # taTwiil
    "f": u"\u0641",  # faa'
    "q": u"\u0642",  # qaaf
    "k": u"\u0643",  # kaaf
    "l": u"\u0644",  # laam
    "m": u"\u0645",  # miim
    "n": u"\u0646",  # nuun
    "h": u"\u0647",  # haa'
    "w": u"\u0648",  # waaw
    "Y": u"\u0649",  # 'alif maqSuura
    "y": u"\u064A",  # yaa'
    "F": u"\u064B",  # fatHatayn
    "N": u"\u064C",  # Dammatayn
    "K": u"\u064D",  # kasratayn
    "a": u"\u064E",  # fatHa
    "u": u"\u064F",  # Damma
    "i": u"\u0650",  # kasra
    "~": u"\u0651",  # shaddah
    "o": u"\u0652",  # sukuun
    "`": u"\u0670",  # dagger 'alif
    "{": u"\u0671",  # waSla
}

# For a reverse transliteration (Unicode -> Buckwalter), a dictionary
# which is the reverse of the above buck2uni is essential.
uni2buck = {}

# Iterate through all the items in the buck2uni dict.
for (key, value) in buck2uni.items():
    # The value from buck2uni becomes a key in uni2buck, and vice
    # versa for the keys.
    uni2buck[value] = key

# add special characters
uni2buck[u"\ufefb"] = "lA"
uni2buck[u"\ufef7"] = "l>"
uni2buck[u"\ufef5"] = "l|"
uni2buck[u"\ufef9"] = "l<"

# clean the arabic text from unwanted characters that may cause problem while building the language model
def clean_text(text):
    text = re.sub(u"[\ufeff]", "", text, flags=re.UNICODE)  # strip Unicode Character 'ZERO WIDTH NO-BREAK SPACE' (U+FEFF). For more info, check http://www.fileformat.info/info/unicode/char/feff/index.htm
    text = utils.remove_non_arabic(text)
    text = utils.strip_tashkeel(text)
    text = utils.strip_tatweel(text)
    return text

# convert a single word into buckwalter and vice versa
def transliterate_word(input_word, direction='bw2ar'):
    output_word = ''
    # Loop over each character in the string, bw_word.
    for char in input_word:
        # Look up current char in the dictionary to get its
        # respective value. If there is no match, e.g., chars like
        # spaces, then just stick with the current char without any
        # conversion.
        # if type(char) == bytes:
        #     char = char.decode('ascii')
        if direction == 'bw2ar':
            #print('in bw2ar')
            output_word += buck2uni.get(char, char)
        elif direction == 'ar2bw':
            #print('in ar2bw')
            output_word += uni2buck.get(char, char)
        else:
            sys.stderr.write('Error: invalid direction!')
            sys.exit()
    return output_word


# convert a text into buckwalter and vice versa
def transliterate_text(input_text, direction='bw2ar'):
    output_text = ''
    for input_word in input_text.split(' '):
        output_text += transliterate_word(input_word, direction) + ' '

    return output_text[:-1]  # remove the last space ONLY


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.stderr.write('Usage: INPUT TEXT | python {} DIRECTION(bw2ar|ar2bw)'.format(sys.argv[1]))
        exit(1)
    for line in sys.stdin:
        line = line if sys.argv[1] == 'bw2ar' else clean_text(line)
        output_text = transliterate_text(line, direction=str(sys.argv[1]))
        if output_text.strip() != '':
            sys.stdout.write('{}\n'.format(output_text.strip()))
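A small usage sketch of the transliteration helpers added above (illustrative only; it assumes bw2ar.py is on the import path and only calls functions defined in this file):

# Illustrative round trip between Buckwalter and Arabic script using bw2ar.py.
import bw2ar

bw = 'ktb'                                               # Buckwalter for kaaf, taa', baa'
ar = bw2ar.transliterate_text(bw, direction='bw2ar')
print(ar)                                                # كتب
print(bw2ar.transliterate_text(ar, direction='ar2bw'))   # back to 'ktb'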
ed.py
ADDED
@@ -0,0 +1,63 @@

import torch.nn as nn
from transformer import *

class Transformer(nn.Module):

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

    def get_device(self):
        return next(self.parameters()).device

    def forward(self, src, trg):
        device = self.get_device()
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
                   self.make_no_peak_mask(trg, trg).to(device)
        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)
        # batch_size x 1 x 1 x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
        k = k.repeat(1, 1, len_q, 1)
        # batch_size x 1 x len_q x 1
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
        q = q.repeat(1, 1, 1, len_k)
        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)
        # len_q x len_k
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
        return mask
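To make the mask shapes in the comments above concrete, a standalone sketch (not part of the commit) that re-implements the make_pad_mask / make_no_peak_mask logic on toy token ids, with the pad index assumed to be 0:

# Standalone illustration of the masking logic used in ed.py.
import torch

pad_idx = 0
src = torch.tensor([[5, 7, 9, 0, 0]])        # batch of 1, length 5, two PAD tokens

# pad mask: batch_size x 1 x len_q x len_k, True where both query and key are non-PAD
k = src.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, src.size(1), 1)
q = src.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, src.size(1))
pad_mask = k & q
print(pad_mask.shape)                        # torch.Size([1, 1, 5, 5])

# no-peak mask: lower-triangular len_q x len_k, blocks attention to future positions
no_peak = torch.tril(torch.ones(src.size(1), src.size(1))).bool()
print((pad_mask * no_peak).shape)            # combined causal + padding mask, [1, 1, 5, 5]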
ed_pl.py
ADDED
@@ -0,0 +1,164 @@

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
from ed import Transformer
from tqdm import tqdm
import math
import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_sequence
# sequences is a list of tensors of shape TxH where T is the seqlen and H is the feats dim
def pad_seq(sequences, batch_first=True, padding_value=0.0, prepadding=True):
    lens = [i.shape[0] for i in sequences]
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value)  # NxTxH
    if prepadding:
        for i in range(len(lens)):
            padded_sequences[i] = padded_sequences[i].roll(-lens[i])
    if not batch_first:
        padded_sequences = padded_sequences.transpose(0, 1)  # TxNxH
    return padded_sequences


def get_batches(X, batch_size=16):
    num_batches = math.ceil(len(X) / batch_size)
    for i in range(num_batches):
        x = X[i*batch_size : (i+1)*batch_size]
        yield x


class TashkeelModel(pl.LightningModule):
    def __init__(self, tokenizer, max_seq_len, d_model=512, n_layers=3, n_heads=16, drop_prob=0.1, learnable_pos_emb=True):

        super(TashkeelModel, self).__init__()

        ffn_hidden = 4 * d_model
        src_pad_idx = tokenizer.letters_map['<PAD>']
        trg_pad_idx = tokenizer.tashkeel_map['<PAD>']
        enc_voc_size = len(tokenizer.letters_map)  # 37 + 3
        dec_voc_size = len(tokenizer.tashkeel_map)  # 15 + 3
        self.transformer = Transformer(src_pad_idx=src_pad_idx,
                                       trg_pad_idx=trg_pad_idx,
                                       d_model=d_model,
                                       enc_voc_size=enc_voc_size,
                                       dec_voc_size=dec_voc_size,
                                       max_len=max_seq_len,
                                       ffn_hidden=ffn_hidden,
                                       n_head=n_heads,
                                       n_layers=n_layers,
                                       drop_prob=drop_prob,
                                       learnable_pos_emb=learnable_pos_emb
                                       )

        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.tashkeel_map['<PAD>'])
        self.tokenizer = tokenizer


    def forward(self, x, y=None):
        y_pred = self.transformer(x, y)
        return y_pred


    def training_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        input_ids = input_ids[:, :-1]
        y_in = target_ids[:, :-1]
        y_out = target_ids[:, 1:]
        y_pred = self(input_ids, y_in)
        loss = self.criterion(y_pred.transpose(1, 2), y_out)

        self.log('train_loss', loss, prog_bar=True)
        sch = self.lr_schedulers()
        sch.step()
        self.log('lr', sch.get_last_lr()[0], prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        input_ids = input_ids[:, :-1]
        y_in = target_ids[:, :-1]
        y_out = target_ids[:, 1:]
        y_pred = self(input_ids, y_in)
        loss = self.criterion(y_pred.transpose(1, 2), y_out)

        pred_text_with_tashkeels = self.tokenizer.decode(input_ids, y_pred.argmax(2).squeeze())
        true_text_with_tashkeels = self.tokenizer.decode(input_ids, y_out)
        total_val_der_distance = 0
        total_val_der_ref_length = 0
        for i in range(len(true_text_with_tashkeels)):
            pred_text_with_tashkeel = pred_text_with_tashkeels[i]
            true_text_with_tashkeel = true_text_with_tashkeels[i]
            val_der = self.tokenizer.compute_der(true_text_with_tashkeel, pred_text_with_tashkeel)
            total_val_der_distance += val_der['distance']
            total_val_der_ref_length += val_der['ref_length']

        total_der_error = total_val_der_distance / total_val_der_ref_length
        self.log('val_loss', loss)
        self.log('val_der', torch.FloatTensor([total_der_error]))
        self.log('val_der_distance', torch.FloatTensor([total_val_der_distance]))
        self.log('val_der_ref_length', torch.FloatTensor([total_val_der_ref_length]))


    def test_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        y_pred = self(input_ids, None)
        loss = self.criterion(y_pred.transpose(1, 2), target_ids)
        self.log('test_loss', loss)


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
        #max_iters = 10000
        #lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters, eta_min=3e-6)
        gamma = 1 / 1.000001
        #gamma = 1 / 1.0001
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        opts = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        return opts


    @torch.no_grad()
    def do_tashkeel_batch(self, texts, batch_size=16, verbose=True):
        self.eval()
        device = next(self.parameters()).device
        text_with_tashkeel = []
        data_iter = get_batches(texts, batch_size)
        if verbose:
            num_batches = math.ceil(len(texts) / batch_size)
            data_iter = tqdm(data_iter, total=num_batches)
        for texts_mini in data_iter:
            input_ids_list = []
            for text in texts_mini:
                input_ids, _ = self.tokenizer.encode(text, test_match=False)
                input_ids_list.append(input_ids)
            batch_input_ids = pad_seq(input_ids_list, batch_first=True, padding_value=self.tokenizer.letters_map['<PAD>'], prepadding=False)
            target_ids = torch.LongTensor([[self.tokenizer.tashkeel_map['<BOS>']]] * len(texts_mini)).to(device)
            src = batch_input_ids.to(device)

            src_mask = self.transformer.make_pad_mask(src, src, self.transformer.src_pad_idx, self.transformer.src_pad_idx).to(device)
            enc_src = self.transformer.encoder(src, src_mask)

            for i in range(src.shape[1] - 1):
                trg = target_ids
                src_trg_mask = self.transformer.make_pad_mask(trg, src, self.transformer.trg_pad_idx, self.transformer.src_pad_idx).to(device)
                trg_mask = self.transformer.make_pad_mask(trg, trg, self.transformer.trg_pad_idx, self.transformer.trg_pad_idx).to(device) * \
                           self.transformer.make_no_peak_mask(trg, trg).to(device)

                preds = self.transformer.decoder(trg, enc_src, trg_mask, src_trg_mask)
                # IMPORTANT NOTE: the following code snippet is to FORCE the prediction of the input space char to output no_tashkeel tag '<NT>'
                target_ids = torch.cat([target_ids, preds[:, -1].argmax(1).unsqueeze(1)], axis=1)
                target_ids[self.tokenizer.letters_map[' '] == src[:, :target_ids.shape[1]]] = self.tokenizer.tashkeel_map[self.tokenizer.no_tashkeel_tag]
                # target_ids = torch.cat([target_ids, preds[:, -1].argmax(1).unsqueeze(1)], axis=1)
            text_with_tashkeel_mini = self.tokenizer.decode(src, target_ids)
            text_with_tashkeel += text_with_tashkeel_mini
        return text_with_tashkeel


    @torch.no_grad()
    def do_tashkeel(self, text):
        return self.do_tashkeel_batch([text])[0]
eo.py
ADDED
@@ -0,0 +1,50 @@
import torch.nn as nn
from transformer import *

class Transformer(nn.Module):
    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

        self.decoder = nn.Linear(d_model, dec_voc_size)

    def get_device(self):
        return next(self.parameters()).device

    def forward(self, src):
        device = self.get_device()
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        enc_src = self.encoder(src, src_mask)
        output = self.decoder(enc_src)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)
        # batch_size x 1 x 1 x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
        k = k.repeat(1, 1, len_q, 1)
        # batch_size x 1 x len_q x 1
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
        q = q.repeat(1, 1, 1, len_k)
        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)
        # len_q x len_k
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
        return mask
eo_pl.py
ADDED
@@ -0,0 +1,151 @@

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
from eo import Transformer
from tqdm import tqdm
import math
import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_sequence
# sequences is a list of tensors of shape TxH where T is the seqlen and H is the feats dim
def pad_seq(sequences, batch_first=True, padding_value=0.0, prepadding=True):
    lens = [i.shape[0] for i in sequences]
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value)  # NxTxH
    if prepadding:
        for i in range(len(lens)):
            padded_sequences[i] = padded_sequences[i].roll(-lens[i])
    if not batch_first:
        padded_sequences = padded_sequences.transpose(0, 1)  # TxNxH
    return padded_sequences


def get_batches(X, batch_size=16):
    num_batches = math.ceil(len(X) / batch_size)
    for i in range(num_batches):
        x = X[i*batch_size : (i+1)*batch_size]
        yield x


class TashkeelModel(pl.LightningModule):
    def __init__(self, tokenizer, max_seq_len, d_model=512, n_layers=3, n_heads=16, drop_prob=0.1, learnable_pos_emb=True):

        super(TashkeelModel, self).__init__()

        ffn_hidden = 4 * d_model
        src_pad_idx = tokenizer.letters_map['<PAD>']
        trg_pad_idx = tokenizer.tashkeel_map['<PAD>']
        enc_voc_size = len(tokenizer.letters_map)  # 37 + 3
        dec_voc_size = len(tokenizer.tashkeel_map)  # 15 + 3
        self.transformer = Transformer(src_pad_idx=src_pad_idx,
                                       trg_pad_idx=trg_pad_idx,
                                       d_model=d_model,
                                       enc_voc_size=enc_voc_size,
                                       dec_voc_size=dec_voc_size,
                                       max_len=max_seq_len,
                                       ffn_hidden=ffn_hidden,
                                       n_head=n_heads,
                                       n_layers=n_layers,
                                       drop_prob=drop_prob,
                                       learnable_pos_emb=learnable_pos_emb
                                       )

        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.tashkeel_map['<PAD>'])
        self.tokenizer = tokenizer


    def forward(self, x):
        y_pred = self.transformer(x)
        return y_pred


    def training_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        input_ids = input_ids[:, 1:-1]
        y_out = target_ids[:, 1:-1]
        y_pred = self(input_ids)
        loss = self.criterion(y_pred.transpose(1, 2), y_out)

        self.log('train_loss', loss, prog_bar=True)
        # sch = self.lr_schedulers()
        # sch.step()
        # self.log('lr', sch.get_last_lr()[0], prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        input_ids = input_ids[:, 1:-1]
        y_out = target_ids[:, 1:-1]
        y_pred = self(input_ids)
        loss = self.criterion(y_pred.transpose(1, 2), y_out)

        pred_text_with_tashkeels = self.tokenizer.decode(input_ids, y_pred.argmax(2).squeeze())
        true_text_with_tashkeels = self.tokenizer.decode(input_ids, y_out)
        total_val_der_distance = 0
        total_val_der_ref_length = 0
        for i in range(len(true_text_with_tashkeels)):
            pred_text_with_tashkeel = pred_text_with_tashkeels[i]
            true_text_with_tashkeel = true_text_with_tashkeels[i]
            val_der = self.tokenizer.compute_der(true_text_with_tashkeel, pred_text_with_tashkeel)
            total_val_der_distance += val_der['distance']
            total_val_der_ref_length += val_der['ref_length']

        total_der_error = total_val_der_distance / total_val_der_ref_length
        self.log('val_loss', loss)
        self.log('val_der', torch.FloatTensor([total_der_error]))
        self.log('val_der_distance', torch.FloatTensor([total_val_der_distance]))
        self.log('val_der_ref_length', torch.FloatTensor([total_val_der_ref_length]))


    def test_step(self, batch, batch_idx):
        input_ids, target_ids = batch
        y_pred = self(input_ids, None)
        loss = self.criterion(y_pred.transpose(1, 2), target_ids)
        self.log('test_loss', loss)


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-5)
        #max_iters = 10000
        #lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters, eta_min=3e-6)
        gamma = 1 / 1.000001
        #gamma = 1 / 1.0001
        #lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        opts = {"optimizer": optimizer}  #, "lr_scheduler": lr_scheduler}
        return opts


    @torch.no_grad()
    def do_tashkeel_batch(self, texts, batch_size=16, verbose=True):
        self.eval()
        device = next(self.parameters()).device
        text_with_tashkeel = []
        data_iter = get_batches(texts, batch_size)
        if verbose:
            num_batches = math.ceil(len(texts) / batch_size)
            data_iter = tqdm(data_iter, total=num_batches)
        for texts_mini in data_iter:
            input_ids_list = []
            for text in texts_mini:
                input_ids, _ = self.tokenizer.encode(text, test_match=False)
                input_ids_list.append(input_ids)
            batch_input_ids = pad_seq(input_ids_list, batch_first=True, padding_value=self.tokenizer.letters_map['<PAD>'], prepadding=False)
            batch_input_ids = batch_input_ids[:, 1:-1].to(device)
            y_pred = self(batch_input_ids)
            y_pred = y_pred.argmax(-1)
            # IMPORTANT NOTE: the following code snippet is to FORCE the prediction of the input space char to output no_tashkeel tag '<NT>'
            y_pred[self.tokenizer.letters_map[' '] == batch_input_ids] = self.tokenizer.tashkeel_map[self.tokenizer.no_tashkeel_tag]
            text_with_tashkeel_mini = self.tokenizer.decode(batch_input_ids, y_pred)
            text_with_tashkeel += text_with_tashkeel_mini

        return text_with_tashkeel


    @torch.no_grad()
    def do_tashkeel(self, text):
        return self.do_tashkeel_batch([text])[0]
models/best_ed_mlm_ns_epoch_178.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6297ec54f9af518a59d4b9e5831d3ad4da304435b8c7fb969f6e3a16816ac23
size 88403510
models/best_eo_mlm_ns_epoch_193.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ad3e811b1bf7ecda252dc9284d2bf3087c1795add47be5ff34cb13688413322
size 75762232
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
 requests
 gradio
-shakkala
+shakkala
+kaldialign
tashkeel_tokenizer.py
ADDED
@@ -0,0 +1,234 @@

import re
import bw2ar
import torch
import xer

# Diacritics
FATHATAN = u'\u064b'
DAMMATAN = u'\u064c'
KASRATAN = u'\u064d'
FATHA = u'\u064e'
DAMMA = u'\u064f'
KASRA = u'\u0650'
SHADDA = u'\u0651'
SUKUN = u'\u0652'
TATWEEL = u'\u0640'

HARAKAT_PAT = re.compile(u"["+u"".join([FATHATAN, DAMMATAN, KASRATAN,
                                        FATHA, DAMMA, KASRA, SUKUN,
                                        SHADDA])+u"]")


class TashkeelTokenizer:

    def __init__(self):
        self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
                        'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
                        'v', 'w', 'x', 'y', 'z', '|', '}'
                        ]
        self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']

        self.no_tashkeel_tag = '<NT>'
        self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
                              '<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']

        self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list

        self.tashkeel_map = {c: i for i, c in enumerate(self.tashkeel_list)}
        self.letters_map = {c: i for i, c in enumerate(self.letters)}
        self.inverse_tags = {
            '~a': '<SF>',   # shaddah and fatHa
            '~u': '<SD>',   # shaddah and Damma
            '~i': '<SK>',   # shaddah and kasra
            '~F': '<SFF>',  # shaddah and fatHatayn
            '~N': '<SDD>',  # shaddah and Dammatayn
            '~K': '<SKK>'   # shaddah and kasratayn
        }
        self.tags = {v: k for k, v in self.inverse_tags.items()}
        self.shaddah_last = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
        self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
        self.tahkeel_chars = ['F', 'N', 'K', 'a', 'u', 'i', '~', 'o']


    def clean_text(self, text):
        text = re.sub(u'[%s]' % u'\u0640', '', text)  # strip tatweel
        text = text.replace('ٱ', 'ا')
        return ' '.join(re.sub(u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]", " ", text, flags=re.UNICODE).split())


    def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
        text_with_tashkeel = text_with_tashkeel.strip()
        # test if the reconstructed text with tashkeel is the same as the original one
        syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
        return syn_text == text_with_tashkeel or syn_text == self.unify_shaddah_position(text_with_tashkeel)


    def unify_shaddah_position(self, text_with_tashkeel):
        # unify the order of shaddah and the harakah to make shaddah always at the beginning
        for i in range(len(self.shaddah_first)):
            text_with_tashkeel = text_with_tashkeel.replace(self.shaddah_last[i], self.shaddah_first[i])
        return text_with_tashkeel


    def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
        text_with_tashkeel = self.clean_text(text_with_tashkeel)
        text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
        text_with_tashkeel = text_with_tashkeel.replace('`', '')  # remove dagger 'alif

        # unify the order of shaddah and the harakah to make shaddah always at the beginning
        text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)

        # remove duplicated harakat
        for i in range(len(self.tahkeel_chars)):
            text_with_tashkeel = text_with_tashkeel.replace(self.tahkeel_chars[i]*2, self.tahkeel_chars[i])

        letter_n_tashkeel_pairs = []
        for i in range(len(text_with_tashkeel)):  # go over the whole text
            # check if the first character is a normal letter and the second character is a tashkeel
            if i < (len(text_with_tashkeel) - 1) and not text_with_tashkeel[i] in self.tashkeel_list and text_with_tashkeel[i+1] in self.tashkeel_list:
                # IMPORTANT: check if tashkeel is Shaddah, then there might be another Tashkeel char associated with it. If so,
                # replace both Shaddah and the Tashkeel chars with the appropriate tag
                if text_with_tashkeel[i+1] == '~':
                    # IMPORTANT: the following if statement depends on the concept of short circuit!!
                    # The first condition checks if there are still more chars before it access position i+2
                    # "text_with_tashkeel[i+2]" since it causes "index out of range" exception. Notice that
                    # Shaddah here is put in the first position before the Harakah.
                    if i+2 < len(text_with_tashkeel) and f'~{text_with_tashkeel[i+2]}' in self.inverse_tags:
                        letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.inverse_tags[f'~{text_with_tashkeel[i+2]}']))
                    else:
                        # if it is only Shaddah, just add it to the list
                        letter_n_tashkeel_pairs.append((text_with_tashkeel[i], '~'))
                else:
                    letter_n_tashkeel_pairs.append((text_with_tashkeel[i], text_with_tashkeel[i+1]))
            # if the character at position i is a normal letter and has no Tashkeel, then add
            # it with the tag "self.no_tashkeel_tag"
            # IMPORTANT: this elif block ensures also that there is no two or more consecutive tashkeel other than shaddah
            elif not text_with_tashkeel[i] in self.tashkeel_list:
                letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.no_tashkeel_tag))

        if test_match:
            # test if the split is done correctly by ensuring that we can retrieve back the original text
            assert self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs)
        return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]


    def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
        combined_with_tashkeel = []
        for letter, tashkeel in letter_n_tashkeel_pairs:
            combined_with_tashkeel.append(letter)
            if tashkeel in self.tags:
                combined_with_tashkeel.append(self.tags[tashkeel])
            elif tashkeel != self.no_tashkeel_tag:
                combined_with_tashkeel.append(tashkeel)
        text = ''.join(combined_with_tashkeel)
        return text


    def encode(self, text_with_tashkeel, test_match=True):
        letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
        text, tashkeel = zip(*letter_n_tashkeel_pairs)
        input_ids = [self.letters_map[c] for c in text]
        target_ids = [self.tashkeel_map[c] for c in tashkeel]
        return torch.LongTensor(input_ids), torch.LongTensor(target_ids)


    def filter_tashkeel(self, tashkeel):
        tmp = []
        for i, t in enumerate(tashkeel):
            if i != 0 and t == '<BOS>':
                t = self.no_tashkeel_tag
            elif i != (len(tashkeel) - 1) and t == '<EOS>':
                t = self.no_tashkeel_tag
            tmp.append(t)
        tashkeel = tmp
        return tashkeel


    def decode(self, input_ids, target_ids):
        # print('input_ids.shape:', input_ids.shape)
        # print('target_ids.shape:', target_ids.shape)
        input_ids = input_ids.cpu().tolist()
        target_ids = target_ids.cpu().tolist()
        ar_texts = []
        for j in range(len(input_ids)):
            letters = [self.letters[i] for i in input_ids[j]]
            tashkeel = [self.tashkeel_list[i] for i in target_ids[j]]

            letters = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', letters))
            tashkeel = self.filter_tashkeel(tashkeel)
            tashkeel = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', tashkeel))

            # VERY IMPORTANT NOTE: zip takes min(len(letters), len(tashkeel)) and discard the reset of letters / tashkeels
            letter_n_tashkeel_pairs = list(zip(letters, tashkeel))
            bw_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
            ar_text = bw2ar.transliterate_text(bw_text, 'bw2ar')
            ar_texts.append(ar_text)
        return ar_texts

    def get_tashkeel_with_case_ending(self, text, case_ending=True):
        text_split = self.split_tashkeel_from_text(text, test_match=False)
        text_spaces_indecies = [i for i, el in enumerate(text_split) if el == (' ', '<NT>')]
        new_text_split = []
        for i, el in enumerate(text_split):
            if not case_ending and (i+1) in text_spaces_indecies:
                el = (el[0], '<NT>')  # no case ending
            new_text_split.append(el)
        letters, tashkeel = zip(*new_text_split)
        return letters, tashkeel


    def compute_der(self, ref, hyp, case_ending=True):
        _, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        _, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_tashkeel = ' '.join(ref_tashkeel)
        hyp_tashkeel = ' '.join(hyp_tashkeel)
        return xer.wer(ref_tashkeel, hyp_tashkeel)

    def compute_wer(self, ref, hyp, case_ending=True):
        ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
        hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
        return xer.wer(ref_text_combined, hyp_text_combined)

    def remove_tashkeel(self, text):
        text = HARAKAT_PAT.sub('', text)
        text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE)  # fattha
        text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE)  # waSla
        return text


if __name__ == '__main__':
    import utils
    from tqdm import tqdm
    tokenizer = TashkeelTokenizer()

    txt_folder_path = 'dataset/train'
    prepared_lines = []
    for filepath in utils.get_files(txt_folder_path, '*.txt'):
        print(f'Reading file: {filepath}')
        with open(filepath) as f1:
            for line in f1:
                clean_line = tokenizer.clean_text(line)
                if clean_line != '':
                    prepared_lines.append(clean_line)
        print(f'completed file: {filepath}')

    good_sentences = []
    bad_sentences = []
    tokenized_sentences = []
    for line in tqdm(prepared_lines):
        try:
            letter_n_tashkeel_pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
            tokenized_sentences.append(letter_n_tashkeel_pairs)
            good_sentences.append(line)
        except AssertionError as e:
            bad_sentences.append(line)

    print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))
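A short sketch of how the tokenizer above is meant to be used (illustrative, not part of the commit; it assumes bw2ar.py and xer.py are importable and uses an arbitrary diacritized sample word):

# Illustrative encode/decode round trip with the new TashkeelTokenizer.
import torch
from tashkeel_tokenizer import TashkeelTokenizer

tokenizer = TashkeelTokenizer()
sample = 'كَتَبَ'                                  # a short diacritized word

input_ids, target_ids = tokenizer.encode(sample, test_match=True)
print(input_ids.shape, target_ids.shape)         # one id per letter / per tashkeel tag (plus BOS/EOS)

# decode expects batched tensors: letter ids + tashkeel ids -> Arabic text with tashkeel
decoded = tokenizer.decode(input_ids.unsqueeze(0), target_ids.unsqueeze(0))
print(decoded[0])                                # reproduces the diacritized input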
transformer.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@author : Hyunwoong
|
3 |
+
@when : 2019-12-18
|
4 |
+
@homepage : https://github.com/gusdnd852
|
5 |
+
"""
|
6 |
+
|
7 |
+
import math
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
class EncoderLayer(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
|
15 |
+
super(EncoderLayer, self).__init__()
|
16 |
+
self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
|
17 |
+
self.norm1 = LayerNorm(d_model=d_model)
|
18 |
+
self.dropout1 = nn.Dropout(p=drop_prob)
|
19 |
+
|
20 |
+
self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
|
21 |
+
self.norm2 = LayerNorm(d_model=d_model)
|
22 |
+
self.dropout2 = nn.Dropout(p=drop_prob)
|
23 |
+
|
24 |
+
def forward(self, x, s_mask):
|
25 |
+
# 1. compute self attention
|
26 |
+
_x = x
|
27 |
+
x = self.attention(q=x, k=x, v=x, mask=s_mask)
|
28 |
+
|
29 |
+
# 2. add and norm
|
30 |
+
x = self.dropout1(x)
|
31 |
+
x = self.norm1(x + _x)
|
32 |
+
|
33 |
+
# 3. positionwise feed forward network
|
34 |
+
_x = x
|
35 |
+
x = self.ffn(x)
|
36 |
+
|
37 |
+
# 4. add and norm
|
38 |
+
x = self.dropout2(x)
|
39 |
+
x = self.norm2(x + _x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class DecoderLayer(nn.Module):
|
44 |
+
|
45 |
+
def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
|
46 |
+
super(DecoderLayer, self).__init__()
|
47 |
+
self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
|
48 |
+
self.norm1 = LayerNorm(d_model=d_model)
|
49 |
+
self.dropout1 = nn.Dropout(p=drop_prob)
|
50 |
+
|
51 |
+
self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
|
52 |
+
self.norm2 = LayerNorm(d_model=d_model)
|
53 |
+
self.dropout2 = nn.Dropout(p=drop_prob)
|
54 |
+
|
55 |
+
self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
|
56 |
+
self.norm3 = LayerNorm(d_model=d_model)
|
57 |
+
self.dropout3 = nn.Dropout(p=drop_prob)
|
58 |
+
|
59 |
+
def forward(self, dec, enc, t_mask, s_mask):
|
60 |
+
# 1. compute self attention
|
61 |
+
_x = dec
|
62 |
+
x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)
|
63 |
+
|
64 |
+
# 2. add and norm
|
65 |
+
x = self.dropout1(x)
|
66 |
+
x = self.norm1(x + _x)
|
67 |
+
|
68 |
+
if enc is not None:
|
69 |
+
# 3. compute encoder - decoder attention
|
70 |
+
_x = x
|
71 |
+
x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)
|
72 |
+
|
73 |
+
# 4. add and norm
|
74 |
+
x = self.dropout2(x)
|
75 |
+
x = self.norm2(x + _x)
|
76 |
+
|
77 |
+
# 5. positionwise feed forward network
|
78 |
+
_x = x
|
79 |
+
x = self.ffn(x)
|
80 |
+
|
81 |
+
# 6. add and norm
|
82 |
+
x = self.dropout3(x)
|
83 |
+
x = self.norm3(x + _x)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class ScaleDotProductAttention(nn.Module):
|
88 |
+
"""
|
89 |
+
compute scale dot product attention
|
90 |
+
|
91 |
+
Query : given sentence that we focused on (decoder)
|
92 |
+
Key : every sentence to check relationship with Qeury(encoder)
|
93 |
+
Value : every sentence same with Key (encoder)
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self):
|
97 |
+
super(ScaleDotProductAttention, self).__init__()
|
98 |
+
self.softmax = nn.Softmax(dim=-1)
|
99 |
+
|
100 |
+
def forward(self, q, k, v, mask=None, e=1e-12):
|
101 |
+
# input is 4 dimension tensor
|
102 |
+
# [batch_size, head, length, d_tensor]
|
103 |
+
batch_size, head, length, d_tensor = k.size()
|
104 |
+
|
105 |
+
# 1. dot product Query with Key^T to compute similarity
|
106 |
+
k_t = k.transpose(2, 3) # transpose
|
107 |
+
score = (q @ k_t) / math.sqrt(d_tensor) # scaled dot product
|
108 |
+
|
109 |
+
# 2. apply masking (opt)
|
110 |
+
if mask is not None:
|
111 |
+
score = score.masked_fill(mask == 0, -10000)
|
112 |
+
|
113 |
+
# 3. pass them softmax to make [0, 1] range
|
114 |
+
score = self.softmax(score)
|
115 |
+
|
116 |
+
# 4. multiply with Value
|
117 |
+
v = score @ v
|
118 |
+
|
119 |
+
return v, score
|
120 |
+
|
121 |
+
|
122 |
+
class PositionwiseFeedForward(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self, d_model, hidden, drop_prob=0.1):
|
125 |
+
super(PositionwiseFeedForward, self).__init__()
|
126 |
+
self.linear1 = nn.Linear(d_model, hidden)
|
127 |
+
self.linear2 = nn.Linear(hidden, d_model)
|
128 |
+
self.relu = nn.ReLU()
|
129 |
+
self.dropout = nn.Dropout(p=drop_prob)
|
130 |
+
|
131 |
+
def forward(self, x):
|
132 |
+
x = self.linear1(x)
|
133 |
+
x = self.relu(x)
|
134 |
+
x = self.dropout(x)
|
135 |
+
x = self.linear2(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class MultiHeadAttention(nn.Module):
|
140 |
+
|
141 |
+
def __init__(self, d_model, n_head):
|
142 |
+
super(MultiHeadAttention, self).__init__()
|
143 |
+
self.n_head = n_head
|
144 |
+
self.attention = ScaleDotProductAttention()
|
145 |
+
self.w_q = nn.Linear(d_model, d_model, bias=False)
|
146 |
+
self.w_k = nn.Linear(d_model, d_model, bias=False)
|
147 |
+
self.w_v = nn.Linear(d_model, d_model, bias=False)
|
148 |
+
self.w_concat = nn.Linear(d_model, d_model, bias=False)
|
149 |
+
|
150 |
+
def forward(self, q, k, v, mask=None):
|
151 |
+
# 1. dot product with weight matrices
|
152 |
+
q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
|
153 |
+
|
154 |
+
# 2. split tensor by number of heads
|
155 |
+
q, k, v = self.split(q), self.split(k), self.split(v)
|
156 |
+
|
157 |
+
# 3. do scale dot product to compute similarity
|
158 |
+
out, attention = self.attention(q, k, v, mask=mask)
|
159 |
+
|
160 |
+
# 4. concat and pass to linear layer
|
161 |
+
out = self.concat(out)
|
162 |
+
out = self.w_concat(out)
|
163 |
+
|
164 |
+
# 5. visualize attention map
|
165 |
+
# TODO : we should implement visualization
|
166 |
+
|
167 |
+
return out
|
168 |
+
|
169 |
+
def split(self, tensor):
|
170 |
+
"""
|
171 |
+
split tensor by number of head
|
172 |
+
|
173 |
+
:param tensor: [batch_size, length, d_model]
|
174 |
+
:return: [batch_size, head, length, d_tensor]
|
175 |
+
"""
|
176 |
+
batch_size, length, d_model = tensor.size()
|
177 |
+
|
178 |
+
d_tensor = d_model // self.n_head
|
179 |
+
tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
|
180 |
+
# it is similar with group convolution (split by number of heads)
|
181 |
+
|
182 |
+
return tensor
|
183 |
+
|
184 |
+
def concat(self, tensor):
|
185 |
+
"""
|
186 |
+
inverse function of self.split(tensor : torch.Tensor)
|
187 |
+
|
188 |
+
:param tensor: [batch_size, head, length, d_tensor]
|
189 |
+
:return: [batch_size, length, d_model]
|
190 |
+
"""
|
191 |
+
batch_size, head, length, d_tensor = tensor.size()
|
192 |
+
d_model = head * d_tensor
|
193 |
+
|
194 |
+
tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
|
195 |
+
return tensor
|
196 |
+
|
197 |
+
|
198 |
+
class LayerNorm(nn.Module):
|
199 |
+
def __init__(self, d_model, eps=1e-12):
|
200 |
+
super(LayerNorm, self).__init__()
|
201 |
+
self.gamma = nn.Parameter(torch.ones(d_model))
|
202 |
+
self.beta = nn.Parameter(torch.zeros(d_model))
|
203 |
+
self.eps = eps
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
mean = x.mean(-1, keepdim=True)
|
207 |
+
var = x.var(-1, unbiased=False, keepdim=True)
|
208 |
+
# '-1' means last dimension.
|
209 |
+
|
210 |
+
out = (x - mean) / torch.sqrt(var + self.eps)
|
211 |
+
out = self.gamma * out + self.beta
|
212 |
+
return out
|
213 |
+
|
214 |
+
|
215 |
+
class TransformerEmbedding(nn.Module):
|
216 |
+
"""
|
217 |
+
token embedding + positional encoding (sinusoid)
|
218 |
+
positional encoding can give positional information to network
|
219 |
+
"""
|
220 |
+
|
221 |
+
def __init__(self, vocab_size, d_model, max_len, drop_prob, padding_idx, learnable_pos_emb=True):
|
222 |
+
"""
|
223 |
+
class for word embedding that included positional information
|
224 |
+
|
225 |
+
:param vocab_size: size of vocabulary
|
226 |
+
:param d_model: dimensions of model
|
227 |
+
"""
|
228 |
+
super(TransformerEmbedding, self).__init__()
|
229 |
+
self.tok_emb = TokenEmbedding(vocab_size, d_model, padding_idx)
|
230 |
+
if learnable_pos_emb:
|
231 |
+
self.pos_emb = LearnablePositionalEncoding(d_model, max_len)
|
232 |
+
else:
|
233 |
+
self.pos_emb = SinusoidalPositionalEncoding(d_model, max_len)
|
234 |
+
self.drop_out = nn.Dropout(p=drop_prob)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
tok_emb = self.tok_emb(x)
|
238 |
+
pos_emb = self.pos_emb(x).to(tok_emb.device)
|
239 |
+
return self.drop_out(tok_emb + pos_emb)
|
240 |
+
|
241 |
+
|
242 |
+
class TokenEmbedding(nn.Embedding):
|
243 |
+
"""
|
244 |
+
Token Embedding using torch.nn
|
245 |
+
they will dense representation of word using weighted matrix
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(self, vocab_size, d_model, padding_idx):
|
249 |
+
"""
|
250 |
+
class for token embedding that included positional information
|
251 |
+
|
252 |
+
:param vocab_size: size of vocabulary
|
253 |
+
:param d_model: dimensions of model
|
254 |
+
"""
|
255 |
+
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
|
256 |
+
|
257 |
+
|
class SinusoidalPositionalEncoding(nn.Module):
    """
    compute sinusoidal positional encoding.
    """

    def __init__(self, d_model, max_len):
        """
        constructor of the sinusoidal encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        """
        super(SinusoidalPositionalEncoding, self).__init__()

        # same size as the input matrix (so it can be added to the input)
        self.encoding = torch.zeros(max_len, d_model)
        self.encoding.requires_grad = False  # we don't need to compute gradients

        pos = torch.arange(0, max_len)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent each word's position

        _2i = torch.arange(0, d_model, step=2).float()
        # '_2i' indexes the even dimensions of d_model (e.g. d_model = 512, _2i = [0, 2, ..., 510])
        # "step=2" means the index is multiplied by two (same as 2 * i)

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # compute the positional encoding to encode word positions

    def forward(self, x):
        # self.encoding
        # [max_len = 512, d_model = 512]

        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30]

        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512]
        # it will be added to tok_emb : [128, 30, 512]


class LearnablePositionalEncoding(nn.Module):
    """
    compute learnable positional encoding.
    """

    def __init__(self, d_model, max_seq_len):
        """
        constructor of the learnable positional encoding class

        :param d_model: dimension of model
        :param max_seq_len: max sequence length
        """
        super(LearnablePositionalEncoding, self).__init__()
        self.max_seq_len = max_seq_len
        self.wpe = nn.Embedding(max_seq_len, d_model)

    def forward(self, x):
        device = x.device
        batch_size, seq_len = x.size()
        assert seq_len <= self.max_seq_len, f"Cannot forward sequence of length {seq_len}, max_seq_len is {self.max_seq_len}"
        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)  # shape (seq_len)
        pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, d_model)

        return pos_emb
        # [seq_len = 30, d_model = 512]
        # it will be added to tok_emb : [128, 30, 512]

class Encoder(nn.Module):

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        max_len=max_len,
                                        vocab_size=enc_voc_size,
                                        drop_prob=drop_prob,
                                        padding_idx=padding_idx,
                                        learnable_pos_emb=learnable_pos_emb)

        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

    def forward(self, x, s_mask):
        x = self.emb(x)

        for layer in self.layers:
            x = layer(x, s_mask)

        return x


class Decoder(nn.Module):

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        drop_prob=drop_prob,
                                        max_len=max_len,
                                        vocab_size=dec_voc_size,
                                        padding_idx=padding_idx,
                                        learnable_pos_emb=learnable_pos_emb)

        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

        self.linear = nn.Linear(d_model, dec_voc_size)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        trg = self.emb(trg)

        for layer in self.layers:
            trg = layer(trg, enc_src, trg_mask, src_mask)

        # pass to the LM head
        output = self.linear(trg)
        return output

class Transformer(nn.Module):

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

    def get_device(self):
        return next(self.parameters()).device

    def forward(self, src, trg):
        device = self.get_device()
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
                   self.make_no_peak_mask(trg, trg).to(device)

        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)

        # batch_size x 1 x 1 x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
        k = k.repeat(1, 1, len_q, 1)

        # batch_size x 1 x len_q x 1
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
        q = q.repeat(1, 1, 1, len_k)

        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)

        # len_q x len_k lower-triangular (causal) mask
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)

        return mask

def make_pad_mask(x, pad_idx):
    q = k = x
    q_pad_idx = k_pad_idx = pad_idx
    len_q, len_k = q.size(1), k.size(1)

    # batch_size x 1 x 1 x len_k
    k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
    # batch_size x 1 x len_q x len_k
    k = k.repeat(1, 1, len_q, 1)

    # batch_size x 1 x len_q x 1
    q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
    # batch_size x 1 x len_q x len_k
    q = q.repeat(1, 1, 1, len_k)

    mask = k & q
    return mask


from torch.nn.utils.rnn import pad_sequence

# 'sequences' is a list of tensors of shape TxH, where T is the sequence length and H is the feature dim
def pad_seq_v2(sequences, batch_first=True, padding_value=0.0, prepadding=True):
    lens = [i.shape[0] for i in sequences]
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value)  # NxTxH
    if prepadding:
        for i in range(len(lens)):
            padded_sequences[i] = padded_sequences[i].roll(-lens[i])
    if not batch_first:
        padded_sequences = padded_sequences.transpose(0, 1)  # TxNxH
    return padded_sequences


if __name__ == '__main__':
    import torch
    import random
    import numpy as np

    rand_seed = 10

    device = 'cpu'

    # model parameter setting
    batch_size = 128
    max_len = 256
    d_model = 512
    n_layers = 3
    n_heads = 16
    ffn_hidden = 2048
    drop_prob = 0.1

    # optimizer parameter setting
    init_lr = 1e-5
    factor = 0.9
    adam_eps = 5e-9
    patience = 10
    warmup = 100
    epoch = 1000
    clip = 1.0
    weight_decay = 5e-4
    inf = float('inf')

    src_pad_idx = 2
    trg_pad_idx = 3

    enc_voc_size = 37
    dec_voc_size = 15
    model = Transformer(src_pad_idx=src_pad_idx,
                        trg_pad_idx=trg_pad_idx,
                        d_model=d_model,
                        enc_voc_size=enc_voc_size,
                        dec_voc_size=dec_voc_size,
                        max_len=max_len,
                        ffn_hidden=ffn_hidden,
                        n_head=n_heads,
                        n_layers=n_layers,
                        drop_prob=drop_prob).to(device)

    # set the seeds for reproducible results
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)

    x_list = [
        torch.tensor([[1, 1]]).transpose(0, 1),                 # length 2
        torch.tensor([[1, 1, 1, 1, 1, 1, 1]]).transpose(0, 1),  # length 7
        torch.tensor([[1, 1, 1]]).transpose(0, 1)               # length 3
    ]

    src_pad_idx = model.src_pad_idx
    trg_pad_idx = model.trg_pad_idx

    src = pad_seq_v2(x_list, padding_value=src_pad_idx, prepadding=False).squeeze(2)
    trg = pad_seq_v2(x_list, padding_value=trg_pad_idx, prepadding=False).squeeze(2)
    out = model(src, trg)
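For readers skimming the diff, the mask helpers are the least obvious part of transformer.py. The short sketch below is not part of the commit; the pad index and token values are invented for illustration. It reproduces the shapes that make_pad_mask and make_no_peak_mask build, and how Transformer.forward combines them into the target mask:

import torch

pad_idx = 0
x = torch.tensor([[5, 7, 9, 0],     # toy batch: 2 sequences of length 4, padded with pad_idx
                  [3, 0, 0, 0]])

# padding mask, as in Transformer.make_pad_mask: batch_size x 1 x len_q x len_k
k = x.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, x.size(1), 1)
q = x.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, x.size(1))
pad_mask = k & q                    # True only where both positions hold real tokens

# causal (no-peek) mask, as in Transformer.make_no_peak_mask: len_q x len_k
causal = torch.tril(torch.ones(x.size(1), x.size(1))).bool()

trg_mask = pad_mask * causal        # broadcasts to batch_size x 1 x len_q x len_k
print(trg_mask.shape)               # torch.Size([2, 1, 4, 4])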
utils.py
ADDED
@@ -0,0 +1,139 @@
"""
note: this code is used in the bw2ar.py file
"""

#!/usr/bin/python
# -*- coding=utf-8 -*-
#---
# $Id: arabic.py,v 1.6 2003/04/22 17:18:22 elzubeir Exp $
#
# ------------
# Description:
# ------------
#
# Arabic codes
#
# (C) Copyright 2003, Arabeyes, Mohammed Elzubeir
# (C) Copyright 2019, Faris Abdullah Alasmary
# -----------------
# Revision Details: (Updated by Revision Control System)
# -----------------
# $Date: 2003/04/22 17:18:22 $
# $Author: elzubeir $
# $Revision: 1.6 $
# $Source: /home/arabeyes/cvs/projects/duali/pyduali/pyduali/arabic.py,v $
#
# This program is written under the BSD License.
#---
""" Constants for Arabic """
import re

COMMA = u'\u060C'
SEMICOLON = u'\u061B'
QUESTION = u'\u061F'
HAMZA = u'\u0621'
ALEF_MADDA = u'\u0622'
ALEF_HAMZA_ABOVE = u'\u0623'
WAW_HAMZA = u'\u0624'
ALEF_HAMZA_BELOW = u'\u0625'
YEH_HAMZA = u'\u0626'
ALEF = u'\u0627'
BEH = u'\u0628'
TEH_MARBUTA = u'\u0629'
TEH = u'\u062a'
THEH = u'\u062b'
JEEM = u'\u062c'
HAH = u'\u062d'
KHAH = u'\u062e'
DAL = u'\u062f'
THAL = u'\u0630'
REH = u'\u0631'
ZAIN = u'\u0632'
SEEN = u'\u0633'
SHEEN = u'\u0634'
SAD = u'\u0635'
DAD = u'\u0636'
TAH = u'\u0637'
ZAH = u'\u0638'
AIN = u'\u0639'
GHAIN = u'\u063a'
TATWEEL = u'\u0640'
FEH = u'\u0641'
QAF = u'\u0642'
KAF = u'\u0643'
LAM = u'\u0644'
MEEM = u'\u0645'
NOON = u'\u0646'
HEH = u'\u0647'
WAW = u'\u0648'
ALEF_MAKSURA = u'\u0649'
YEH = u'\u064a'
MADDA_ABOVE = u'\u0653'
HAMZA_ABOVE = u'\u0654'
HAMZA_BELOW = u'\u0655'
ZERO = u'\u0660'
ONE = u'\u0661'
TWO = u'\u0662'
THREE = u'\u0663'
FOUR = u'\u0664'
FIVE = u'\u0665'
SIX = u'\u0666'
SEVEN = u'\u0667'
EIGHT = u'\u0668'
NINE = u'\u0669'
PERCENT = u'\u066a'
DECIMAL = u'\u066b'
THOUSANDS = u'\u066c'
STAR = u'\u066d'
MINI_ALEF = u'\u0670'
ALEF_WASLA = u'\u0671'
FULL_STOP = u'\u06d4'
BYTE_ORDER_MARK = u'\ufeff'

# Diacritics
FATHATAN = u'\u064b'
DAMMATAN = u'\u064c'
KASRATAN = u'\u064d'
FATHA = u'\u064e'
DAMMA = u'\u064f'
KASRA = u'\u0650'
SHADDA = u'\u0651'
SUKUN = u'\u0652'

# Ligatures
LAM_ALEF = u'\ufefb'
LAM_ALEF_HAMZA_ABOVE = u'\ufef7'
LAM_ALEF_HAMZA_BELOW = u'\ufef9'
LAM_ALEF_MADDA_ABOVE = u'\ufef5'
SIMPLE_LAM_ALEF = u'\u0644\u0627'
SIMPLE_LAM_ALEF_HAMZA_ABOVE = u'\u0644\u0623'
SIMPLE_LAM_ALEF_HAMZA_BELOW = u'\u0644\u0625'
SIMPLE_LAM_ALEF_MADDA_ABOVE = u'\u0644\u0622'


HARAKAT_PAT = re.compile(u"[" + u"".join([FATHATAN, DAMMATAN, KASRATAN,
                                          FATHA, DAMMA, KASRA, SUKUN,
                                          SHADDA]) + u"]")
HAMZAT_PAT = re.compile(u"[" + u"".join([WAW_HAMZA, YEH_HAMZA]) + u"]")
ALEFAT_PAT = re.compile(u"[" + u"".join([ALEF_MADDA, ALEF_HAMZA_ABOVE,
                                         ALEF_HAMZA_BELOW, HAMZA_ABOVE,
                                         HAMZA_BELOW]) + u"]")
LAMALEFAT_PAT = re.compile(u"[" + u"".join([LAM_ALEF,
                                            LAM_ALEF_HAMZA_ABOVE,
                                            LAM_ALEF_HAMZA_BELOW,
                                            LAM_ALEF_MADDA_ABOVE]) + u"]")


def strip_tashkeel(text):
    text = HARAKAT_PAT.sub('', text)
    text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE)  # fatha
    text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE)  # alef wasla
    return text


def strip_tatweel(text):
    return re.sub(u'[%s]' % TATWEEL, '', text)


# removes tashkeel, tatweel, and any non-Arabic characters
def remove_non_arabic(text):
    text = strip_tashkeel(text)
    text = strip_tatweel(text)
    return ' '.join(re.sub(u"[^\u0621-\u063A\u0641-\u064A ]", " ", text, flags=re.UNICODE).split())
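A quick sanity check of the helpers above. This is a minimal sketch, not part of the commit; the sample string is arbitrary and the expected output is what the regexes above should produce:

from utils import strip_tashkeel, remove_non_arabic

text = u'أَهْلاً وَسَهْلاً - Hello 123'
print(strip_tashkeel(text))     # diacritics stripped, other characters kept
print(remove_non_arabic(text))  # only bare Arabic letters remain: 'أهلا وسهلا'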
xer.py
ADDED
@@ -0,0 +1,76 @@
"""
@author
 ______ _ _
| ____| (_) /\ | |
| |__ __ _ _ __ _ ___ / \ | | __ _ ___ _ __ ___ __ _ _ __ _ _
| __/ _` | '__| / __| / /\ \ | |/ _` / __| '_ ` _ \ / _` | '__| | | |
| | | (_| | | | \__ \ / ____ \| | (_| \__ \ | | | | | (_| | | | |_| |
|_| \__,_|_| |_|___/ /_/ \_\_|\__,_|___/_| |_| |_|\__,_|_| \__, |
 __/ |
|___/
Email: [email protected]
Date: Mar 15, 2022
"""

# pip install git+https://github.com/pzelasko/kaldialign.git

from kaldialign import edit_distance


def cer(ref, hyp):
    """
    Computes the Character Error Rate (CER), defined as the edit distance
    between the two strings at the character level, with spaces removed.

    Arguments:
        ref (string): a space-separated ground truth string
        hyp (string): a space-separated hypothesis
    """
    ref, hyp = ref.replace(' ', '').strip(), hyp.replace(' ', '').strip()
    info = edit_distance(ref, hyp)
    distance = info['total']
    ref_length = float(len(ref))

    data = {
        'insertions': info['ins'],
        'deletions': info['del'],
        'substitutions': info['sub'],
        'distance': distance,
        'ref_length': ref_length,
        'Error Rate': (distance / ref_length) * 100
    }

    return data


def wer(ref, hyp):
    """
    Computes the Word Error Rate (WER), defined as the edit distance between
    the two provided sentences after tokenizing into words.

    Arguments:
        ref (string): a space-separated ground truth string
        hyp (string): a space-separated hypothesis
    """

    # build a mapping of words to integers
    b = set(ref.split() + hyp.split())
    word2char = dict(zip(b, range(len(b))))

    # map the words to a char array (Levenshtein packages only accept strings)
    w1 = [chr(word2char[w]) for w in ref.split()]
    w2 = [chr(word2char[w]) for w in hyp.split()]

    info = edit_distance(''.join(w1), ''.join(w2))
    distance = info['total']
    ref_length = float(len(w1))

    data = {
        'insertions': info['ins'],
        'deletions': info['del'],
        'substitutions': info['sub'],
        'distance': distance,
        'ref_length': ref_length,
        'Error Rate': (distance / ref_length) * 100
    }

    return data
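A minimal usage sketch for xer.py (not part of the commit; the reference/hypothesis pair is invented for illustration):

from xer import cer, wer

ref = 'ذهب الولد الى المدرسة'
hyp = 'ذهب الولد إلى المدرسه'

print(wer(ref, hyp)['Error Rate'])  # word-level error rate, in percent
print(cer(ref, hyp)['Error Rate'])  # character-level error rate, in percent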