|
import torch |
|
from Models.ModelArgs import ModelArgs |
|
from Models.AutoModel import get_model |
|
from gradio_utils import Callable_tokenizer, greedy_decode |
|
import gradio as gr |
|
|
|
def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
    """Placeholder for beam-search English-to-Arabic translation.

    Beam search is not implemented yet: all arguments are ignored and a fixed
    placeholder string is returned so the UI still has something to display.
    """
    placeholder = "future work"
    return placeholder
|
|
|
|
|
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
    """Translate English *text* to Arabic using greedy decoding.

    Args:
        text: English source sentence.
        model: loaded translation model, passed straight to ``greedy_decode``.
        tokenizer: callable that maps *text* to a sequence of token ids; it
            must also provide ``get_tokenId`` (for ``'<s>'``, ``'</s>'`` and
            ``'<pad>'``) and ``decode``.
        max_tries: forwarded unchanged to ``greedy_decode`` (presumably its
            decoding-step limit — confirm against ``gradio_utils``).

    Returns:
        The decoded Arabic string, or ``""`` when *text* is empty or only
        whitespace.
    """
    # Nothing to translate: skip tokenization and the model entirely.
    if not text or not text.strip():
        return ""

    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and move to the model's device.
        source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0).to(device)
        target_tokens = greedy_decode(
            model,
            source_tensor,
            tokenizer.get_tokenId('<s>'),
            tokenizer.get_tokenId('</s>'),
            tokenizer.get_tokenId('<pad>'),
            max_tries,
        )

    return tokenizer.decode(target_tokens)
|
|
|
|
|
tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')


def _load_model(arch, checkpoint_path, config_path, vocab_size):
    """Build one translation model, load its checkpoint and prepare it for inference.

    Args:
        arch: architecture name passed to ``ModelArgs`` (e.g. ``'transformer'``).
        checkpoint_path: checkpoint file whose ``'model_state_dict'`` entry
            holds the weights.
        config_path: model configuration file consumed by ``ModelArgs``.
        vocab_size: vocabulary size handed to ``get_model``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)['model_state_dict']
    model = get_model(ModelArgs(arch, config_path), vocab_size)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # state_dict is local, so the loaded checkpoint tensors can be released once
    # their values are copied into the model. A module-level global would keep
    # them alive for the whole app.
    return model


s2sattention = _load_model(
    's2sattention',
    "./assets/models/en-ar_s2sAttention.pth",
    "./Configurations/s2sattention_model_config.json",
    len(tokenizer),
)

s2s = _load_model(
    's2s',
    "./assets/models/en-ar_s2s.pth",
    "./Configurations/s2s_model_config.json",
    len(tokenizer),
)

transformer = _load_model(
    'transformer',
    "./assets/models/en-ar_transformer.pth",
    "./Configurations/transformer_model_config.json",
    len(tokenizer),
)
|
|
|
|
|
def launch_translation_greedy(raw_input, maxtries=50):
    """Translate *raw_input* with every loaded model using greedy decoding.

    Returns a 3-tuple ordered (transformer, seq2seq with attention, seq2seq),
    which matches the order of the three output textboxes in the UI.
    """
    ordered_models = (transformer, s2sattention, s2s)
    return tuple(
        en_translate_ar_greedy(raw_input, net, tokenizer, maxtries)
        for net in ordered_models
    )
|
|
|
|
|
def launch_translation_beam(raw_input, maxtries=50):
    """Translate *raw_input* with every loaded model using the beam-search path.

    Returns a 3-tuple ordered (transformer, seq2seq with attention, seq2seq).
    """
    translations = [
        en_translate_ar_beam(raw_input, candidate, tokenizer, maxtries)
        for candidate in (transformer, s2sattention, s2s)
    ]
    return tuple(translations)
|
|
|
|
|
custom_css = '.gr-button {background-color: #bf4b04; color: white;}'

with gr.Blocks(css=custom_css) as demo:
    # Top row: English input plus examples on the left, one Arabic output per model on the right.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label='English Sentence')
            gr.Examples(
                [
                    'How are you?',
                    'She is a good girl.',
                    'Who is better than me?!',
                    'is tom looking at me?',
                    'when was the last time we met?',
                ],
                inputs=input_text,
                label="Examples: ",
            )

        with gr.Column():
            output1 = gr.Textbox(label="Arabic Transformer Translation")
            output2 = gr.Textbox(label="Arabic seq2seq with Attention Translation")
            output3 = gr.Textbox(label="Arabic seq2seq No Attention Translation")

    # Action buttons beneath the row, coloured through the "gr-button" class in custom_css.
    start_greedy_btn = gr.Button(value='Arabic Translation (Greedy search)', elem_classes=["gr-button"])
    start_beam_btn = gr.Button(value='Arabic Translation (Beam search)', elem_classes=["gr-button"])

    # Both buttons read the same input and fill the same three outputs;
    # only the decoding strategy behind each handler differs.
    translation_outputs = [output1, output2, output3]
    for button, handler in (
        (start_greedy_btn, launch_translation_greedy),
        (start_beam_btn, launch_translation_beam),
    ):
        button.click(fn=handler, inputs=input_text, outputs=translation_outputs)
|
|
|
|
|
# Start the web server only when this file is executed as a script, so that
# importing the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|