Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -10,7 +10,8 @@ dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egypt | |
| 10 | 
             
            # translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
         | 
| 11 | 
             
            translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
         | 
| 12 | 
             
            tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
         | 
| 13 | 
            -
            translator_ar2en =  | 
|  | |
| 14 | 
             
            transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
         | 
| 15 |  | 
| 16 | 
             
            speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
         | 
| @@ -28,7 +29,7 @@ def generate_diverging_colors(num_colors, palette='Set3'): # courtesy of ChatGPT | |
| 28 | 
             
                return colors_hex
         | 
| 29 |  | 
| 30 |  | 
| 31 | 
            -
            def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4):
         | 
| 32 | 
             
                alignment = []
         | 
| 33 | 
             
                for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
         | 
| 34 | 
             
                    alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
         | 
| @@ -93,7 +94,7 @@ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, thresh | |
| 93 |  | 
| 94 | 
             
                srchtml = []
         | 
| 95 | 
             
                for i, token in enumerate(encoder_input_ids[0]):
         | 
| 96 | 
            -
                    if i == 0:
         | 
| 97 | 
             
                        continue
         | 
| 98 | 
             
                    if f"trg_{i}" in colordict:
         | 
| 99 | 
             
                        label = f"trg_{i}"
         | 
| @@ -158,13 +159,42 @@ def translate_english(input_text, include): | |
| 158 |  | 
| 159 | 
             
                return palhtml, pal_out, sy_out, lb_out, eg_out
         | 
| 160 |  | 
| 161 | 
            -
            def translate_arabic(input_text):
         | 
| 162 | 
             
                if not input_text:
         | 
| 163 | 
             
                    return ""
         | 
| 164 |  | 
| 165 | 
            -
                 | 
| 166 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 167 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 168 |  | 
| 169 | 
             
            def get_audio(input_text):
         | 
| 170 | 
             
                audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
         | 
| @@ -244,6 +274,7 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default") | |
| 244 | 
             
                    input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
         | 
| 245 | 
             
                    pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
         | 
| 246 | 
             
                    include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
         | 
|  | |
| 247 | 
             
                with gr.Tab('Ar > En'):
         | 
| 248 | 
             
                    with gr.Row():
         | 
| 249 | 
             
                        with gr.Column():
         | 
| @@ -252,8 +283,12 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default") | |
| 252 | 
             
                            btn = gr.Button("Translate", label="Translate")
         | 
| 253 | 
             
                            gr.Markdown("Built by [Guy Mor-Lan](mailto:[email protected]).")
         | 
| 254 | 
             
                        with gr.Column():
         | 
| 255 | 
            -
                             | 
|  | |
|  | |
|  | |
| 256 | 
             
                    btn.click(translate_arabic,inputs=input_text, outputs=[eng])
         | 
|  | |
| 257 | 
             
                with gr.Tab("Transliterate"):
         | 
| 258 | 
             
                    with gr.Row():
         | 
| 259 | 
             
                        with gr.Column():
         | 
|  | |
| 10 | 
             
            # translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
         | 
| 11 | 
             
            translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
         | 
| 12 | 
             
            tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
         | 
| 13 | 
            +
            translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
         | 
| 14 | 
            +
            tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
         | 
| 15 | 
             
            transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
         | 
| 16 |  | 
| 17 | 
             
            speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
         | 
|  | |
| 29 | 
             
                return colors_hex
         | 
| 30 |  | 
| 31 |  | 
| 32 | 
            +
            def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4, skip_first_src=True):
         | 
| 33 | 
             
                alignment = []
         | 
| 34 | 
             
                for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
         | 
| 35 | 
             
                    alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
         | 
|  | |
| 94 |  | 
| 95 | 
             
                srchtml = []
         | 
| 96 | 
             
                for i, token in enumerate(encoder_input_ids[0]):
         | 
| 97 | 
            +
                    if skip_first_src and i == 0:
         | 
| 98 | 
             
                        continue
         | 
| 99 | 
             
                    if f"trg_{i}" in colordict:
         | 
| 100 | 
             
                        label = f"trg_{i}"
         | 
|  | |
| 159 |  | 
| 160 | 
             
                return palhtml, pal_out, sy_out, lb_out, eg_out
         | 
| 161 |  | 
| 162 | 
            +
            def translate_arabic(input_text, include=("Colorize",)):
                """Translate Arabic text to English and return the result as an HTML string.

                Args:
                    input_text: Text to translate. Any falsy value (empty string,
                        ``None``) short-circuits and yields an empty string.
                    include: Collection of option labels. When it contains
                        ``"Colorize"`` the result shows source and translation as
                        aligned HTML produced by ``align_words``; otherwise only
                        the plain translation is returned.

                Returns:
                    An HTML fragment, or ``""`` when ``input_text`` is falsy.
                """
                # Local import keeps this block self-contained; the file's import
                # section is not visible from here.
                import html

                if not input_text:
                    return ""

                input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
                outputs = translator_ar2en.generate(input_tokens)

                if "Colorize" not in include:
                    # Decode with the tokenizer that belongs to the ar->en model; the
                    # generated ids come from that model's vocabulary.
                    decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)
                    # The translation is model output derived from user input, so escape
                    # it before embedding it in markup.
                    return f"<div style='font-size: 30px;'>{html.escape(decoded[0])}</div>"

                # Only the first sequence is used, re-wrapped as a batch of one.
                encoder_input_ids = input_tokens[0].unsqueeze(0)
                decoder_input_ids = outputs[0].unsqueeze(0)

                # Second forward pass over the finished translation; its outputs are
                # what align_words consumes to link source and target tokens.
                html_outputs = translator_ar2en(
                    input_ids=encoder_input_ids,
                    decoder_input_ids=decoder_input_ids,
                )

                # Dynamic alignment threshold keyed on source length in tokens.
                # NOTE(review): the schedule is not monotonic (0.1 -> 0.01 -> 0.05);
                # kept exactly as written, confirm this is intentional tuning.
                source_len = input_tokens.shape[1]
                if source_len < 20:
                    threshold = 0.1
                elif source_len < 30:
                    threshold = 0.01
                else:
                    threshold = 0.05

                srchtml, tgthtml = align_words(
                    html_outputs,
                    tokenizer_ar2en,
                    encoder_input_ids,
                    decoder_input_ids,
                    threshold,
                    skip_first_src=False,
                )
                return f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
         | 
| 198 |  | 
| 199 | 
             
            def get_audio(input_text):
         | 
| 200 | 
             
                audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
         | 
|  | |
| 274 | 
             
                    input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
         | 
| 275 | 
             
                    pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
         | 
| 276 | 
             
                    include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
         | 
| 277 | 
            +
             | 
| 278 | 
             
                with gr.Tab('Ar > En'):
         | 
| 279 | 
             
                    with gr.Row():
         | 
| 280 | 
             
                        with gr.Column():
         | 
|  | |
| 283 | 
             
                            btn = gr.Button("Translate", label="Translate")
         | 
| 284 | 
             
                            gr.Markdown("Built by [Guy Mor-Lan](mailto:[email protected]).")
         | 
| 285 | 
             
                        with gr.Column():
         | 
| 286 | 
            +
                            with gr.Box(label = "English"):
         | 
| 287 | 
            +
                                gr.Markdown("English")
         | 
| 288 | 
            +
                                with gr.Box():
         | 
| 289 | 
            +
                                    eng = gr.HTML("<br>", label="English", elem_id="main")
         | 
| 290 | 
             
                    btn.click(translate_arabic,inputs=input_text, outputs=[eng])
         | 
| 291 | 
            +
             | 
| 292 | 
             
                with gr.Tab("Transliterate"):
         | 
| 293 | 
             
                    with gr.Row():
         | 
| 294 | 
             
                        with gr.Column():
         |