Abdelrhman Ashraf
		
	commited on
		
		
					Commit 
							
							·
						
						d80a8bb
	
1
								Parent(s):
							
							88fa37a
								
Add translation application with greedy and beam search methods
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,83 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from Models.ModelArgs import ModelArgs
         | 
| 3 | 
            +
            from Models.AutoModel import get_model
         | 
| 4 | 
            +
            from gradio_utils import Callable_tokenizer, greedy_decode
         | 
| 5 | 
            +
            import gradio as gr
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
         | 
| 8 | 
            +
                return "future work"
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
         | 
| 12 | 
            +
                source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
         | 
| 13 | 
            +
                target_tokens = greedy_decode(model, source_tensor,
         | 
| 14 | 
            +
                                              tokenizer.get_tokenId('<s>'),
         | 
| 15 | 
            +
                                              tokenizer.get_tokenId('</s>'),
         | 
| 16 | 
            +
                                              tokenizer.get_tokenId('<pad>'), max_tries)
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                return tokenizer.decode(target_tokens)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
         | 
| 22 | 
            +
            tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            model_state_dict = torch.load("./assets/models/en-ar_s2sAttention.pth", map_location=device, weights_only=True)['model_state_dict']
         | 
| 25 | 
            +
            model_args = ModelArgs('s2sattention', "./Configurations/s2sattention_model_config.json")
         | 
| 26 | 
            +
            s2sattention = get_model(model_args, len(tokenizer))
         | 
| 27 | 
            +
            s2sattention.load_state_dict(model_state_dict)
         | 
| 28 | 
            +
            s2sattention.to(device)
         | 
| 29 | 
            +
            s2sattention.eval()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            model_state_dict = torch.load("./assets/models/en-ar_s2s.pth", map_location=device, weights_only=True)['model_state_dict']
         | 
| 32 | 
            +
            model_args = ModelArgs('s2s', "./Configurations/s2s_model_config.json")
         | 
| 33 | 
            +
            s2s = get_model(model_args, len(tokenizer))
         | 
| 34 | 
            +
            s2s.load_state_dict(model_state_dict)
         | 
| 35 | 
            +
            s2s.to(device)
         | 
| 36 | 
            +
            s2s.eval()
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            model_state_dict = torch.load("./assets/models/en-ar_transformer.pth", map_location=device, weights_only=True)['model_state_dict']
         | 
| 39 | 
            +
            model_args = ModelArgs('transformer', "./Configurations/transformer_model_config.json")
         | 
| 40 | 
            +
            transformer = get_model(model_args, len(tokenizer))
         | 
| 41 | 
            +
            transformer.load_state_dict(model_state_dict)
         | 
| 42 | 
            +
            transformer.to(device)
         | 
| 43 | 
            +
            transformer.eval()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def launch_translation_greedy(raw_input, maxtries=50):
         | 
| 47 | 
            +
                transformer_out = en_translate_ar_greedy(raw_input, transformer, tokenizer, maxtries)
         | 
| 48 | 
            +
                s2sattention_out = en_translate_ar_greedy(raw_input, s2sattention, tokenizer, maxtries)
         | 
| 49 | 
            +
                s2s_out = en_translate_ar_greedy(raw_input, s2s, tokenizer, maxtries)
         | 
| 50 | 
            +
                return transformer_out, s2sattention_out, s2s_out, 
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def launch_translation_beam(raw_input, maxtries=50):
         | 
| 54 | 
            +
                transformer_out = en_translate_ar_beam(raw_input, transformer, tokenizer, maxtries)
         | 
| 55 | 
            +
                s2sattention_out = en_translate_ar_beam(raw_input, s2sattention, tokenizer, maxtries)
         | 
| 56 | 
            +
                s2s_out = en_translate_ar_beam(raw_input, s2s, tokenizer, maxtries)
         | 
| 57 | 
            +
                return transformer_out, s2sattention_out, s2s_out
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            custom_css ='.gr-button {background-color: #bf4b04; color: white;}'
         | 
| 61 | 
            +
            with gr.Blocks(css=custom_css) as demo:
         | 
| 62 | 
            +
                with gr.Row():
         | 
| 63 | 
            +
                    with gr.Column():
         | 
| 64 | 
            +
                        input_text = gr.Textbox(label='English Sentence')
         | 
| 65 | 
            +
                        gr.Examples(['How are you?',
         | 
| 66 | 
            +
                                     'She is a good girl.',
         | 
| 67 | 
            +
                                     'Who is better than me?!',
         | 
| 68 | 
            +
                                     'is tom looking at me?',
         | 
| 69 | 
            +
                                     'when was the last time we met?'],
         | 
| 70 | 
            +
                                    inputs=input_text, label="Examples: ")
         | 
| 71 | 
            +
                    with gr.Column():
         | 
| 72 | 
            +
                        output1 = gr.Textbox(label="Arabic Transformer Translation")
         | 
| 73 | 
            +
                        output2 = gr.Textbox(label="Arabic seq2seq with Attention Translation")
         | 
| 74 | 
            +
                        output3 = gr.Textbox(label="Arabic seq2seq No Attention Translation")
         | 
| 75 | 
            +
                        
         | 
| 76 | 
            +
                        start_greedy_btn = gr.Button(value='Arabic Translation (Greedy search)', elem_classes=["gr-button"])
         | 
| 77 | 
            +
                        start_beam_btn = gr.Button(value='Arabic Translation (Beam search)', elem_classes=["gr-button"])
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                start_greedy_btn.click(fn=launch_translation_greedy, inputs=input_text, outputs=[output1, output2, output3])
         | 
| 80 | 
            +
                start_beam_btn.click(fn=launch_translation_beam, inputs=input_text, outputs=[output1, output2, output3])
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            demo.launch()
         |