File size: 2,443 Bytes
346eea8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
from transformers import T5ForConditionalGeneration,T5Tokenizer
from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import pipeline

@st.cache_resource
def _load_headline_model():
    """Load the Michau T5 headline model and its tokenizer.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every interaction; without caching the weights would be
    reloaded on each rerun.
    """
    name = "Michau/t5-base-en-generate-headline"
    return T5ForConditionalGeneration.from_pretrained(name), T5Tokenizer.from_pretrained(name)


@st.cache_resource
def _load_news_summary_model():
    """Load the mrm8488 T5 news-summarization tokenizer and model (cached across reruns).

    NOTE(review): in the visible code these objects are only referenced by
    commented-out code; consider dropping them to save memory and startup time.
    """
    name = "mrm8488/t5-base-finetuned-summarize-news"
    # AutoModelForSeq2SeqLM replaces the deprecated AutoModelWithLMHead for
    # encoder-decoder checkpoints such as T5.
    return AutoTokenizer.from_pretrained(name), AutoModelForSeq2SeqLM.from_pretrained(name)


# Module-level names kept for existing callers (generate_title uses model/tokenizer).
model, tokenizer = _load_headline_model()
mrm_tokenizer, mrm_model = _load_news_summary_model()


def generate_title(article, *, max_input_length=2048, max_output_length=50):
    """Generate a headline for *article* with the module-level T5 headline model.

    Args:
        article: Raw article text; it is prefixed with the ``"headline: "``
            task prompt before tokenization.
        max_input_length: Token limit applied (with truncation) to the prompt.
        max_output_length: Maximum length, in tokens, of the generated headline.

    Returns:
        The decoded headline with special tokens (``<pad>``, ``</s>``) removed
        and surrounding whitespace stripped. Because ``do_sample=True`` is
        used, repeated calls on the same text may return different headlines.
    """
    text = "headline: " + article
    encoding = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )

    beam_outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=max_output_length,
        num_beams=3,
        do_sample=True,  # beam search combined with top-k sampling
        top_k=10,
        early_stopping=False,
    )

    # skip_special_tokens removes <pad>/</s> here instead of forcing callers
    # to strip hard-coded token strings.
    return tokenizer.decode(beam_outputs[0], skip_special_tokens=True).strip()

# def generate_summary(article):
#   input_ids = mrm_tokenizer.encode(article, return_tensors="pt", add_special_tokens=True)

#   generated_ids = mrm_model.generate(input_ids=input_ids, num_beams=3, max_length=200,  repetition_penalty=2.5, length_penalty=1.0, early_stopping=False, truncation=True)

#   preds = [mrm_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]

#   return preds[0]
@st.cache_resource
def _get_summarizer():
    """Build the facebook/bart-large-cnn summarization pipeline once and reuse it.

    Cached with ``st.cache_resource`` so the model is not reloaded on every
    call or every Streamlit rerun.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")


def generate_summary(article, *, max_length=130, min_length=30):
    """Summarize *article* with facebook/bart-large-cnn.

    Args:
        article: Text to summarize. Inputs longer than the model's token
            limit are truncated by the tokenizer.
        max_length: Maximum summary length in tokens.
        min_length: Minimum summary length in tokens.

    Returns:
        The raw pipeline output: a list whose first element is a dict holding
        the summary under ``"summary_text"``.
    """
    summarizer = _get_summarizer()
    # truncation=True cuts the input at the model's *token* limit; the previous
    # article[:1024] cut at 1024 characters, discarding most of the input.
    return summarizer(
        article,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        truncation=True,
    )
def main():
    """Render the Streamlit page: a text box plus a button that shows a generated title and summary."""
    st.title("Text Summarization")
    text = st.text_area("Enter your text here:", "")

    if not st.button("Generate Summary"):
        return
    if not text.strip():
        st.error("Please enter some text.")
        return

    # Model inference can take several seconds; show progress to the user.
    with st.spinner("Generating..."):
        title = generate_title(text)
        summary = generate_summary(text)

    st.subheader("Generated Title:")
    # Defensive: strip T5 special tokens in case the title still contains them.
    st.write(title.replace('<pad>', '').replace('</s>', ''))

    st.subheader("Generated Description:")
    # generate_summary returns the pipeline output: a list holding one dict.
    st.write(summary[0]['summary_text'])


# Script entry point: build the UI when this file is executed directly.
# NOTE(review): presumably launched via `streamlit run <file>`, which runs the
# script as __main__ — confirm.
if __name__ == "__main__":
    main()