File size: 3,875 Bytes
d80a8bb
 
 
 
 
 
 
 
 
 
6b6e8d3
d80a8bb
6b6e8d3
d80a8bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from Models.ModelArgs import ModelArgs
from Models.AutoModel import get_model
from gradio_utils import Callable_tokenizer, greedy_decode
import gradio as gr

def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
    """Placeholder for beam-search translation.

    Beam search is not implemented here: every argument is ignored and a
    fixed placeholder string is returned instead of a translation.
    """
    placeholder = "future work"
    return placeholder


# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
    """Translate ``text`` with ``model`` using greedy decoding.

    Args:
        text: Sentence to translate.
        model: Model handed to ``greedy_decode``.
        tokenizer: Callable that turns ``text`` into token ids; it must also
            provide ``get_tokenId`` and ``decode``.
        max_tries: Forwarded to ``greedy_decode`` as its last argument
            (presumably the decoding step limit -- confirm in gradio_utils).

    Returns:
        The decoded output string, or ``""`` when ``text`` is ``None``,
        empty or whitespace-only.
    """
    # Nothing to translate: skip tokenisation and the model call entirely
    # rather than building a tensor from an empty input.
    if not text or not text.strip():
        return ""

    # Add a leading batch dimension of size 1 and move the ids to the device.
    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0).to(device)

    # Inference only: no gradients are needed while decoding.
    with torch.no_grad():
        target_tokens = greedy_decode(model, source_tensor,
                                      tokenizer.get_tokenId('<s>'),
                                      tokenizer.get_tokenId('</s>'),
                                      tokenizer.get_tokenId('<pad>'), max_tries)

    return tokenizer.decode(target_tokens)


def _load_translation_model(model_name, weights_path, config_path, vocab_size):
    """Build the ``model_name`` model, load its checkpoint and return it in eval mode.

    Args:
        model_name: Name passed to ``ModelArgs``.
        weights_path: Checkpoint file; must contain a ``'model_state_dict'`` entry.
        config_path: JSON configuration file passed to ``ModelArgs``.
        vocab_size: Second argument handed to ``get_model``.

    Returns:
        The model with weights loaded, moved to ``device`` and set to eval mode.
    """
    # The checkpoint dict is local, so it is released as soon as the weights
    # have been copied into the model instead of lingering in a module global.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
    model = get_model(ModelArgs(model_name, config_path), vocab_size)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model


# One tokenizer is shared by all three models.
tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')

s2sattention = _load_translation_model(
    's2sattention',
    "./assets/models/en-ar_s2sAttention.pth",
    "./Configurations/s2sattention_model_config.json",
    len(tokenizer),
)

s2s = _load_translation_model(
    's2s',
    "./assets/models/en-ar_s2s.pth",
    "./Configurations/s2s_model_config.json",
    len(tokenizer),
)

transformer = _load_translation_model(
    'transformer',
    "./assets/models/en-ar_transformer.pth",
    "./Configurations/transformer_model_config.json",
    len(tokenizer),
)


def launch_translation_greedy(raw_input, maxtries=50):
    """Greedy-translate ``raw_input`` with each of the three loaded models.

    Returns:
        A 3-tuple of translations in the order
        ``(transformer, s2sattention, s2s)``.
    """
    models_in_output_order = (transformer, s2sattention, s2s)
    return tuple(
        en_translate_ar_greedy(raw_input, candidate, tokenizer, maxtries)
        for candidate in models_in_output_order
    )


def launch_translation_beam(raw_input, maxtries=50):
    """Run the beam-search translator on ``raw_input`` for all three models.

    Returns:
        A 3-tuple of outputs in the order
        ``(transformer, s2sattention, s2s)``.
    """
    outputs = [
        en_translate_ar_beam(raw_input, candidate, tokenizer, maxtries)
        for candidate in (transformer, s2sattention, s2s)
    ]
    return tuple(outputs)


# Custom colours for every component tagged with the "gr-button" class.
custom_css = '.gr-button {background-color: #bf4b04; color: white;}'

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        # First column: the English input box plus clickable sample sentences.
        with gr.Column():
            input_text = gr.Textbox(label='English Sentence')
            example_sentences = [
                'How are you?',
                'She is a good girl.',
                'Who is better than me?!',
                'is tom looking at me?',
                'when was the last time we met?',
            ]
            gr.Examples(example_sentences, inputs=input_text, label="Examples: ")

        # Second column: one output box per model, then the two trigger buttons.
        with gr.Column():
            output1 = gr.Textbox(label="Arabic Transformer Translation")
            output2 = gr.Textbox(label="Arabic seq2seq with Attention Translation")
            output3 = gr.Textbox(label="Arabic seq2seq No Attention Translation")

            start_greedy_btn = gr.Button(
                value='Arabic Translation (Greedy search)',
                elem_classes=["gr-button"],
            )
            start_beam_btn = gr.Button(
                value='Arabic Translation (Beam search)',
                elem_classes=["gr-button"],
            )

    # Only the text box is wired as an input, so each handler runs with its
    # default ``maxtries``.
    translation_handlers = (
        (start_greedy_btn, launch_translation_greedy),
        (start_beam_btn, launch_translation_beam),
    )
    for button, handler in translation_handlers:
        button.click(fn=handler, inputs=input_text, outputs=[output1, output2, output3])


demo.launch()