Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -5,6 +5,7 @@ import os 
     | 
|
| 5 | 
         
             
            import httpx
         
     | 
| 6 | 
         | 
| 7 | 
         
             
            logging.set_verbosity_error()
         
     | 
| 
         | 
|
| 8 | 
         | 
| 9 | 
         
             
            def download_argos_model(from_code, to_code):
         
     | 
| 10 | 
         
             
                import argostranslate.package
         
     | 
| 
         @@ -53,7 +54,8 @@ if model_name == 'Helsinki-NLP': 
     | 
|
| 53 | 
         
             
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
         
     | 
| 54 | 
         
             
            if model_name.startswith('t5'):
         
     | 
| 55 | 
         
             
                tokenizer = T5Tokenizer.from_pretrained(model_name)
         
     | 
| 56 | 
         
            -
                model = T5ForConditionalGeneration.from_pretrained(model_name)
         
     | 
| 
         | 
|
| 57 | 
         | 
| 58 | 
         
             
            st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
         
     | 
| 59 | 
         
             
            submit_button = st.button("Translate")
         
     | 
| 
         @@ -77,7 +79,7 @@ if submit_button: 
     | 
|
| 77 | 
         
             
                elif model_name.startswith('t5'):
         
     | 
| 78 | 
         
             
                    prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
         
     | 
| 79 | 
         
             
                    print(prompt)
         
     | 
| 80 | 
         
            -
                    input_ids = tokenizer.encode(prompt, return_tensors='pt')
         
     | 
| 81 | 
         
             
                    # Perform translation
         
     | 
| 82 | 
         
             
                    output_ids = model.generate(input_ids)
         
     | 
| 83 | 
         
             
                    # Decode the translated text
         
     | 
| 
         @@ -104,8 +106,6 @@ if submit_button: 
     | 
|
| 104 | 
         
             
                        translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
         
     | 
| 105 | 
         
             
                    except Exception as error:
         
     | 
| 106 | 
         
             
                        translated_text = error
         
     | 
| 107 | 
         
            -
                        # download_argos_model(sl, tl)
         
     | 
| 108 | 
         
            -
                        # translated_text = argostranslate.translate.translate(input_text, sl, tl)
         
     | 
| 109 | 
         
             
                # Display the translated text
         
     | 
| 110 | 
         
             
                print(translated_text)
         
     | 
| 111 | 
         
             
                st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
         
     | 
| 
         | 
|
| 5 | 
         
             
            import httpx
         
     | 
| 6 | 
         | 
| 7 | 
         
             
            logging.set_verbosity_error()
         
     | 
| 8 | 
         
            +
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         
     | 
| 9 | 
         | 
| 10 | 
         
             
            def download_argos_model(from_code, to_code):
         
     | 
| 11 | 
         
             
                import argostranslate.package
         
     | 
| 
         | 
|
| 54 | 
         
             
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
         
     | 
| 55 | 
         
             
            if model_name.startswith('t5'):
         
     | 
| 56 | 
         
             
                tokenizer = T5Tokenizer.from_pretrained(model_name)
         
     | 
| 57 | 
         
            +
                model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         | 
| 60 | 
         
             
            st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
         
     | 
| 61 | 
         
             
            submit_button = st.button("Translate")
         
     | 
| 
         | 
|
| 79 | 
         
             
                elif model_name.startswith('t5'):
         
     | 
| 80 | 
         
             
                    prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
         
     | 
| 81 | 
         
             
                    print(prompt)
         
     | 
| 82 | 
         
            +
                    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
         
     | 
| 83 | 
         
             
                    # Perform translation
         
     | 
| 84 | 
         
             
                    output_ids = model.generate(input_ids)
         
     | 
| 85 | 
         
             
                    # Decode the translated text
         
     | 
| 
         | 
|
| 106 | 
         
             
                        translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
         
     | 
| 107 | 
         
             
                    except Exception as error:
         
     | 
| 108 | 
         
             
                        translated_text = error
         
     | 
| 
         | 
|
| 
         | 
|
| 109 | 
         
             
                # Display the translated text
         
     | 
| 110 | 
         
             
                print(translated_text)
         
     | 
| 111 | 
         
             
                st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
         
     |