File size: 5,052 Bytes
0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d 0d925f9 190e77d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import streamlit as st
import newspaper
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from urllib.parse import urlparse
# Initialize session state for model and tokenizer
# NOTE(review): these session_state slots appear unused — main() obtains the
# model/tokenizer from the @st.cache_resource-cached load_model() and never
# reads st.session_state.model/tokenizer; confirm before removing.
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None
@st.cache_resource
def load_model():
    """Load the T5 summarization model, fine-tuned weights, and tokenizer.

    Cached by Streamlit so the heavy loading happens once per process.

    Returns:
        (model, tokenizer, device) on success, or (None, None, None) if
        anything fails (the error is surfaced via st.error).
    """
    try:
        # Prefer GPU when one is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = T5ForConditionalGeneration.from_pretrained('t5-base')
        # map_location keeps CPU-only machines working with GPU-saved weights.
        state_dict = torch.load('abstractive-model-sihanas.pth', map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        tokenizer = T5Tokenizer.from_pretrained('t5-base')
    except Exception as exc:
        st.error(f"Error loading model: {str(exc)}")
        return None, None, None
    return model, tokenizer, device
def clean_text(text):
    """Clean and preprocess the input text.

    Collapses all runs of whitespace to single spaces and drops very long
    "words" (100+ chars), which are almost certainly scraping garbage.

    Args:
        text: Raw article text.

    Returns:
        The cleaned, single-spaced text.
    """
    # One pass does both jobs: str.split() already collapses any whitespace,
    # so the original separate normalization join was redundant.
    return ' '.join(word for word in text.split() if len(word) < 100)
def summarize_text(text, model, tokenizer, device):
    """Generate an abstractive summary of *text* with the T5 model.

    Args:
        text: Raw article text (cleaned internally via clean_text).
        model: Loaded T5ForConditionalGeneration instance.
        tokenizer: Matching T5 tokenizer.
        device: torch.device the model lives on.

    Returns:
        The decoded summary string, or None if generation fails
        (the error is surfaced via st.error).
    """
    try:
        prompt = "summarize: " + clean_text(text)
        # Truncate to the model's 512-token input window.
        input_ids = tokenizer.encode(
            prompt,
            return_tensors='pt',
            max_length=512,
            truncation=True,
        ).to(device)
        # Beam search with a length penalty biases toward fuller summaries.
        output_ids = model.generate(
            input_ids,
            max_length=150,
            min_length=40,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as exc:
        st.error(f"Error in summarization: {str(exc)}")
        return None
def fetch_article(url):
    """Fetch article content and metadata from URL using newspaper3k.

    Args:
        url: The article URL to download and parse.

    Returns:
        (title, authors, publish_date, publisher, text) as strings, with
        fallback placeholders for missing fields; five Nones on failure
        (the error is surfaced via st.error).
    """
    try:
        article = newspaper.Article(url)
        article.download()
        article.parse()

        title = article.title or 'No title found'
        if article.authors:
            authors = ', '.join(article.authors)
        else:
            authors = 'No author information'
        publish_date = article.publish_date or 'No publish date found'
        # Derive a publisher name from the URL's domain.
        domain = urlparse(url).netloc.replace('www.', '')
        publisher = domain.capitalize() or 'No publisher information'
        text = article.text or ''

        return title, authors, str(publish_date), publisher, text
    except Exception as exc:
        st.error(f"Error fetching the article: {str(exc)}")
        return None, None, None, None, None
def main():
    """Streamlit entry point: URL input -> article fetch -> summary display."""
    st.title("News Article Summarizer")
    st.write("Enter a news article URL to get a summary.")

    # Model load is cached; bail out early if it failed.
    model, tokenizer, device = load_model()
    if model is None or tokenizer is None:
        st.error("Failed to load the model. Please check your model file and dependencies.")
        return

    url = st.text_input("News Article URL")
    if not st.button("Summarize"):
        return
    if not url:
        st.warning("Please enter a URL")
        return

    with st.spinner("Fetching article and generating summary..."):
        title, authors, publish_date, publisher, article_text = fetch_article(url)
        if not article_text:
            st.error("Failed to fetch the article. Please check the URL and try again.")
            return

        # Metadata first, then the generated summary.
        st.write(f"**Title**: {title}")
        st.write(f"**Authors**: {authors}")
        st.write(f"**Publish Date**: {publish_date}")
        st.write(f"**Publisher**: {publisher}")

        summary = summarize_text(article_text, model, tokenizer, device)
        if summary:
            st.success("Summary generated successfully!")
            st.write("### Summary")
            st.write(summary)
            # Original text stays collapsed unless the reader expands it.
            with st.expander("Show original article"):
                st.write(article_text)
# Run the app when executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()