Abdelrhman Ashraf commited on
Commit
d80a8bb
·
1 Parent(s): 88fa37a

Add translation application with greedy and beam search methods

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from Models.ModelArgs import ModelArgs
3
+ from Models.AutoModel import get_model
4
+ from gradio_utils import Callable_tokenizer, greedy_decode
5
+ import gradio as gr
6
+
7
+ def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
8
+ return "future work"
9
+
10
+
11
+ def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
12
+ source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
13
+ target_tokens = greedy_decode(model, source_tensor,
14
+ tokenizer.get_tokenId('<s>'),
15
+ tokenizer.get_tokenId('</s>'),
16
+ tokenizer.get_tokenId('<pad>'), max_tries)
17
+
18
+ return tokenizer.decode(target_tokens)
19
+
20
+
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')
23
+
24
+ model_state_dict = torch.load("./assets/models/en-ar_s2sAttention.pth", map_location=device, weights_only=True)['model_state_dict']
25
+ model_args = ModelArgs('s2sattention', "./Configurations/s2sattention_model_config.json")
26
+ s2sattention = get_model(model_args, len(tokenizer))
27
+ s2sattention.load_state_dict(model_state_dict)
28
+ s2sattention.to(device)
29
+ s2sattention.eval()
30
+
31
+ model_state_dict = torch.load("./assets/models/en-ar_s2s.pth", map_location=device, weights_only=True)['model_state_dict']
32
+ model_args = ModelArgs('s2s', "./Configurations/s2s_model_config.json")
33
+ s2s = get_model(model_args, len(tokenizer))
34
+ s2s.load_state_dict(model_state_dict)
35
+ s2s.to(device)
36
+ s2s.eval()
37
+
38
+ model_state_dict = torch.load("./assets/models/en-ar_transformer.pth", map_location=device, weights_only=True)['model_state_dict']
39
+ model_args = ModelArgs('transformer', "./Configurations/transformer_model_config.json")
40
+ transformer = get_model(model_args, len(tokenizer))
41
+ transformer.load_state_dict(model_state_dict)
42
+ transformer.to(device)
43
+ transformer.eval()
44
+
45
+
46
+ def launch_translation_greedy(raw_input, maxtries=50):
47
+ transformer_out = en_translate_ar_greedy(raw_input, transformer, tokenizer, maxtries)
48
+ s2sattention_out = en_translate_ar_greedy(raw_input, s2sattention, tokenizer, maxtries)
49
+ s2s_out = en_translate_ar_greedy(raw_input, s2s, tokenizer, maxtries)
50
+ return transformer_out, s2sattention_out, s2s_out,
51
+
52
+
53
+ def launch_translation_beam(raw_input, maxtries=50):
54
+ transformer_out = en_translate_ar_beam(raw_input, transformer, tokenizer, maxtries)
55
+ s2sattention_out = en_translate_ar_beam(raw_input, s2sattention, tokenizer, maxtries)
56
+ s2s_out = en_translate_ar_beam(raw_input, s2s, tokenizer, maxtries)
57
+ return transformer_out, s2sattention_out, s2s_out
58
+
59
+
60
+ custom_css ='.gr-button {background-color: #bf4b04; color: white;}'
61
+ with gr.Blocks(css=custom_css) as demo:
62
+ with gr.Row():
63
+ with gr.Column():
64
+ input_text = gr.Textbox(label='English Sentence')
65
+ gr.Examples(['How are you?',
66
+ 'She is a good girl.',
67
+ 'Who is better than me?!',
68
+ 'is tom looking at me?',
69
+ 'when was the last time we met?'],
70
+ inputs=input_text, label="Examples: ")
71
+ with gr.Column():
72
+ output1 = gr.Textbox(label="Arabic Transformer Translation")
73
+ output2 = gr.Textbox(label="Arabic seq2seq with Attention Translation")
74
+ output3 = gr.Textbox(label="Arabic seq2seq No Attention Translation")
75
+
76
+ start_greedy_btn = gr.Button(value='Arabic Translation (Greedy search)', elem_classes=["gr-button"])
77
+ start_beam_btn = gr.Button(value='Arabic Translation (Beam search)', elem_classes=["gr-button"])
78
+
79
+ start_greedy_btn.click(fn=launch_translation_greedy, inputs=input_text, outputs=[output1, output2, output3])
80
+ start_beam_btn.click(fn=launch_translation_beam, inputs=input_text, outputs=[output1, output2, output3])
81
+
82
+
83
+ demo.launch()