"""Gradio demo comparing three English-to-Arabic translation models.

Loads a Transformer, a seq2seq-with-attention and a plain seq2seq model that
share one tokenizer. It then serves a small UI that runs the same English
sentence through all three models.
"""
import torch
from Models.ModelArgs import ModelArgs
from Models.AutoModel import get_model
from gradio_utils import Callable_tokenizer, greedy_decode
import gradio as gr


def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
    """Placeholder for beam-search translation; not implemented yet.

    Always returns the string "future work", whatever the arguments are.
    """
    return "future work"


device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.no_grad()  # pure inference: don't build autograd graphs per request
def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
    """Translate ``text`` from English to Arabic using ``greedy_decode``.

    Args:
        text: English input sentence.
        model: A loaded translation model (moved to ``device`` and put in
            eval mode by ``_load_model``).
        tokenizer: ``tokenizer(text)`` must return a sequence of token ids,
            and ``tokenizer.decode`` must map decoded ids back to text.
        max_tries: Forwarded as the last argument of ``greedy_decode`` —
            presumably the maximum number of decoding steps (unverified).

    Returns:
        The decoded Arabic string, or ``""`` when ``text`` is empty or
        whitespace-only.
    """
    # An empty textbox has nothing to translate; don't feed an empty or
    # special-tokens-only sequence to the model.
    if not text or not text.strip():
        return ""
    # unsqueeze(0) adds a leading batch dimension of size 1.
    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0).to(device)
    # NOTE(review): all three special-token names below are the empty string
    # ''. They look like '<...>' markers that were stripped somewhere
    # upstream. Confirm the intended tokens (and greedy_decode's parameter
    # order) against gradio_utils.
    target_tokens = greedy_decode(model, source_tensor,
                                  tokenizer.get_tokenId(''),
                                  tokenizer.get_tokenId(''),
                                  tokenizer.get_tokenId(''),
                                  max_tries)
    return tokenizer.decode(target_tokens)


def _load_model(model_type, checkpoint_path, config_path, vocab_size):
    """Build a model from its config, load checkpoint weights, and ready it for inference.

    Args:
        model_type: Model name passed to ``ModelArgs`` (e.g. 'transformer').
        checkpoint_path: ``.pth`` file containing a ``'model_state_dict'`` entry.
        config_path: JSON model configuration passed to ``ModelArgs``.
        vocab_size: Vocabulary size passed to ``get_model``.

    Returns:
        The model with weights loaded, moved to ``device`` and in eval mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(checkpoint_path, map_location=device,
                            weights_only=True)['model_state_dict']
    model_args = ModelArgs(model_type, config_path)
    model = get_model(model_args, vocab_size)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')

s2sattention = _load_model('s2sattention',
                           "./assets/models/en-ar_s2sAttention.pth",
                           "./Configurations/s2sattention_model_config.json",
                           len(tokenizer))
s2s = _load_model('s2s',
                  "./assets/models/en-ar_s2s.pth",
                  "./Configurations/s2s_model_config.json",
                  len(tokenizer))
transformer = _load_model('transformer',
                          "./assets/models/en-ar_transformer.pth",
                          "./Configurations/transformer_model_config.json",
                          len(tokenizer))


def launch_translation_greedy(raw_input, maxtries=50):
    """Translate ``raw_input`` greedily with all three models.

    Returns:
        ``(transformer_out, s2sattention_out, s2s_out)``, in the order of the
        UI output textboxes.
    """
    transformer_out = en_translate_ar_greedy(raw_input, transformer, tokenizer, maxtries)
    s2sattention_out = en_translate_ar_greedy(raw_input, s2sattention, tokenizer, maxtries)
    s2s_out = en_translate_ar_greedy(raw_input, s2s, tokenizer, maxtries)
    return transformer_out, s2sattention_out, s2s_out


def launch_translation_beam(raw_input, maxtries=50):
    """Beam-search counterpart of ``launch_translation_greedy``.

    Currently every entry is the "future work" placeholder.
    """
    transformer_out = en_translate_ar_beam(raw_input, transformer, tokenizer, maxtries)
    s2sattention_out = en_translate_ar_beam(raw_input, s2sattention, tokenizer, maxtries)
    s2s_out = en_translate_ar_beam(raw_input, s2s, tokenizer, maxtries)
    return transformer_out, s2sattention_out, s2s_out


custom_css = '.gr-button {background-color: #bf4b04; color: white;}'

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label='English Sentence')
            gr.Examples(['How are you?',
                         'She is a good girl.',
                         'Who is better than me?!',
                         'is tom looking at me?',
                         'when was the last time we met?'],
                        inputs=input_text, label="Examples: ")
        with gr.Column():
            output1 = gr.Textbox(label="Arabic Transformer Translation")
            output2 = gr.Textbox(label="Arabic seq2seq with Attention Translation")
            output3 = gr.Textbox(label="Arabic seq2seq No Attention Translation")
    start_greedy_btn = gr.Button(value='Arabic Translation (Greedy search)', elem_classes=["gr-button"])
    start_beam_btn = gr.Button(value='Arabic Translation (Beam search)', elem_classes=["gr-button"])
    # Both buttons fill the same three output boxes, in (transformer, s2sattention, s2s) order.
    start_greedy_btn.click(fn=launch_translation_greedy, inputs=input_text, outputs=[output1, output2, output3])
    start_beam_btn.click(fn=launch_translation_beam, inputs=input_text, outputs=[output1, output2, output3])

demo.launch()