"""Streamlit app that fetches a news article from a URL and generates an
abstractive summary with a fine-tuned T5 model."""

import streamlit as st
import newspaper
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from urllib.parse import urlparse
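
# Assumed dependencies (not pinned in the original source):
#   pip install streamlit newspaper3k torch transformers sentencepiece
# newspaper3k provides the `newspaper` module; sentencepiece backs the slow
# T5Tokenizer used below.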

if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None

@st.cache_resource
def load_model():
    """Load t5-base, apply the fine-tuned weights, and return (model, tokenizer, device)."""
    try:
        # Use the GPU when one is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Start from pretrained t5-base, then overwrite with the fine-tuned checkpoint.
        model = T5ForConditionalGeneration.from_pretrained('t5-base')
        checkpoint = torch.load('abstractive-model-sihanas.pth', map_location=device)
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()  # inference only; disables dropout

        tokenizer = T5Tokenizer.from_pretrained('t5-base')

        return model, tokenizer, device

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None

def clean_text(text):
    """Clean and preprocess the input text."""
    # Collapse all runs of whitespace into single spaces.
    text = ' '.join(text.split())
    # Drop pathological "words" of 100+ characters (e.g. long URLs or junk tokens).
    text = ' '.join(word for word in text.split() if len(word) < 100)
    return text
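
# For example (hypothetical input):
#   clean_text("Breaking:   markets\n\nrally")  ->  'Breaking: markets rally'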

def summarize_text(text, model, tokenizer, device):
    """Generate an abstractive summary of `text`."""
    try:
        cleaned_text = clean_text(text)

        # T5 expects a task prefix; inputs beyond 512 tokens are truncated.
        inputs = tokenizer.encode("summarize: " + cleaned_text,
                                  return_tensors='pt',
                                  max_length=512,
                                  truncation=True).to(device)

        # Beam search; length_penalty > 1 favors longer summaries.
        summary_ids = model.generate(
            inputs,
            max_length=150,
            min_length=40,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary

    except Exception as e:
        st.error(f"Error in summarization: {str(e)}")
        return None

def fetch_article(url):
    """Fetch article content and metadata from a URL using newspaper3k."""
    try:
        article = newspaper.Article(url)
        article.download()
        article.parse()

        # Fall back to placeholder strings when metadata is missing.
        title = article.title or 'No title found'
        authors = ', '.join(article.authors) if article.authors else 'No author information'
        publish_date = article.publish_date or 'No publish date found'

        # Derive the publisher from the domain, e.g. 'www.reuters.com' -> 'Reuters.com'.
        publisher = urlparse(url).netloc.replace('www.', '').capitalize() or 'No publisher information'

        text = article.text or ''

        return title, authors, str(publish_date), publisher, text

    except Exception as e:
        st.error(f"Error fetching the article: {str(e)}")
        return None, None, None, None, None

def main():
    st.title("News Article Summarizer")
    st.write("Enter a news article URL to get a summary.")

    model, tokenizer, device = load_model()

    if model is None or tokenizer is None:
        st.error("Failed to load the model. Please check your model file and dependencies.")
        return

    url = st.text_input("News Article URL")

    if st.button("Summarize"):
        if not url:
            st.warning("Please enter a URL")
            return

        with st.spinner("Fetching article and generating summary..."):
            title, authors, publish_date, publisher, article_text = fetch_article(url)

            if article_text:
                st.write(f"**Title**: {title}")
                st.write(f"**Authors**: {authors}")
                st.write(f"**Publish Date**: {publish_date}")
                st.write(f"**Publisher**: {publisher}")

                summary = summarize_text(article_text, model, tokenizer, device)

                if summary:
                    st.success("Summary generated successfully!")
                    st.write("### Summary")
                    st.write(summary)

                    with st.expander("Show original article"):
                        st.write(article_text)
            else:
                st.error("Failed to fetch the article. Please check the URL and try again.")

if __name__ == "__main__":
    main()
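
# To try the app locally (assuming this file is saved as app.py and the
# fine-tuned checkpoint abstractive-model-sihanas.pth sits next to it):
#
#   streamlit run app.py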