Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	test colab dev
Browse files- app.py +5 -0
 - src/vanilla_summarizer.py +0 -83
 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import streamlit as st
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 5 | 
         
            +
                st.header("Streamlit 🤝 Colab")
         
     | 
    	
        src/vanilla_summarizer.py
    CHANGED
    
    | 
         @@ -1,83 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            import torch
         
     | 
| 2 | 
         
            -
            import streamlit as st
         
     | 
| 3 | 
         
            -
            from transformers import BartTokenizer, BartForConditionalGeneration
         
     | 
| 4 | 
         
            -
            from transformers import T5Tokenizer, T5ForConditionalGeneration
         
     | 
| 5 | 
         
            -
             
     | 
| 6 | 
         
            -
            st.title('Text Summarization Demo')
         
     | 
| 7 | 
         
            -
            st.markdown('Using BART and T5 transformer model')
         
     | 
| 8 | 
         
            -
             
     | 
| 9 | 
         
            -
            model = st.selectbox('Select the model', ('BART', 'T5'))
         
     | 
| 10 | 
         
            -
             
     | 
| 11 | 
         
            -
            if model == 'BART':
         
     | 
| 12 | 
         
            -
                _num_beams = 4
         
     | 
| 13 | 
         
            -
                _no_repeat_ngram_size = 3
         
     | 
| 14 | 
         
            -
                _length_penalty = 1
         
     | 
| 15 | 
         
            -
                _min_length = 12
         
     | 
| 16 | 
         
            -
                _max_length = 128
         
     | 
| 17 | 
         
            -
                _early_stopping = True
         
     | 
| 18 | 
         
            -
            else:
         
     | 
| 19 | 
         
            -
                _num_beams = 4
         
     | 
| 20 | 
         
            -
                _no_repeat_ngram_size = 3
         
     | 
| 21 | 
         
            -
                _length_penalty = 2
         
     | 
| 22 | 
         
            -
                _min_length = 30
         
     | 
| 23 | 
         
            -
                _max_length = 200
         
     | 
| 24 | 
         
            -
                _early_stopping = True
         
     | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
            col1, col2, col3 = st.beta_columns(3)
         
     | 
| 27 | 
         
            -
            _num_beams = col1.number_input("num_beams", value=_num_beams)
         
     | 
| 28 | 
         
            -
            _no_repeat_ngram_size = col2.number_input("no_repeat_ngram_size", value=_no_repeat_ngram_size)
         
     | 
| 29 | 
         
            -
            _length_penalty = col3.number_input("length_penalty", value=_length_penalty)
         
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         
            -
            col1, col2, col3 = st.beta_columns(3)
         
     | 
| 32 | 
         
            -
            _min_length = col1.number_input("min_length", value=_min_length)
         
     | 
| 33 | 
         
            -
            _max_length = col2.number_input("max_length", value=_max_length)
         
     | 
| 34 | 
         
            -
            _early_stopping = col3.number_input("early_stopping", value=_early_stopping)
         
     | 
| 35 | 
         
            -
             
     | 
| 36 | 
         
            -
            text = st.text_area('Text Input')
         
     | 
| 37 | 
         
            -
             
     | 
| 38 | 
         
            -
             
     | 
| 39 | 
         
            -
            def run_model(input_text):
         
     | 
| 40 | 
         
            -
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         
     | 
| 41 | 
         
            -
             
     | 
| 42 | 
         
            -
                if model == "BART":
         
     | 
| 43 | 
         
            -
                    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
         
     | 
| 44 | 
         
            -
                    bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
         
     | 
| 45 | 
         
            -
                    input_text = str(input_text)
         
     | 
| 46 | 
         
            -
                    input_text = ' '.join(input_text.split())
         
     | 
| 47 | 
         
            -
                    input_tokenized = bart_tokenizer.encode(input_text, return_tensors='pt').to(device)
         
     | 
| 48 | 
         
            -
                    summary_ids = bart_model.generate(input_tokenized,
         
     | 
| 49 | 
         
            -
                                                      num_beams=_num_beams,
         
     | 
| 50 | 
         
            -
                                                      no_repeat_ngram_size=_no_repeat_ngram_size,
         
     | 
| 51 | 
         
            -
                                                      length_penalty=_length_penalty,
         
     | 
| 52 | 
         
            -
                                                      min_length=_min_length,
         
     | 
| 53 | 
         
            -
                                                      max_length=_max_length,
         
     | 
| 54 | 
         
            -
                                                      early_stopping=_early_stopping)
         
     | 
| 55 | 
         
            -
             
     | 
| 56 | 
         
            -
                    output = [bart_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in
         
     | 
| 57 | 
         
            -
                              summary_ids]
         
     | 
| 58 | 
         
            -
                    st.write('Summary')
         
     | 
| 59 | 
         
            -
                    st.success(output[0])
         
     | 
| 60 | 
         
            -
             
     | 
| 61 | 
         
            -
                else:
         
     | 
| 62 | 
         
            -
                    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
         
     | 
| 63 | 
         
            -
                    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
         
     | 
| 64 | 
         
            -
                    input_text = str(input_text).replace('\n', '')
         
     | 
| 65 | 
         
            -
                    input_text = ' '.join(input_text.split())
         
     | 
| 66 | 
         
            -
                    input_tokenized = t5_tokenizer.encode(input_text, return_tensors="pt").to(device)
         
     | 
| 67 | 
         
            -
                    summary_task = torch.tensor([[21603, 10]]).to(device)
         
     | 
| 68 | 
         
            -
                    input_tokenized = torch.cat([summary_task, input_tokenized], dim=-1).to(device)
         
     | 
| 69 | 
         
            -
                    summary_ids = t5_model.generate(input_tokenized,
         
     | 
| 70 | 
         
            -
                                                    num_beams=_num_beams,
         
     | 
| 71 | 
         
            -
                                                    no_repeat_ngram_size=_no_repeat_ngram_size,
         
     | 
| 72 | 
         
            -
                                                    length_penalty=_length_penalty,
         
     | 
| 73 | 
         
            -
                                                    min_length=_min_length,
         
     | 
| 74 | 
         
            -
                                                    max_length=_max_length,
         
     | 
| 75 | 
         
            -
                                                    early_stopping=_early_stopping)
         
     | 
| 76 | 
         
            -
                    output = [t5_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in
         
     | 
| 77 | 
         
            -
                              summary_ids]
         
     | 
| 78 | 
         
            -
                    st.write('Summary')
         
     | 
| 79 | 
         
            -
                    st.success(output[0])
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
            -
             
     | 
| 82 | 
         
            -
            if st.button('Submit'):
         
     | 
| 83 | 
         
            -
                run_model(text)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         |