""" Summarization Model Handler Manages the fine-tuned BART model for text summarization. """ from transformers import BartTokenizer, BartForConditionalGeneration import torch import streamlit as st class Summarizer: def __init__(self): """Initialize the summarization model.""" self.model = None self.tokenizer = None def load_model(self): """Load the fine-tuned BART summarization model.""" try: # Load the tokenizer self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") # Load the fine-tuned model self.model = BartForConditionalGeneration.from_pretrained("bart_ami_finetuned.pkl") # Move model to appropriate device (GPU if available) self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) return self.model except Exception as e: st.error(f"Error loading fine-tuned summarization model: {str(e)}") return None def process(self, text: str, max_length: int = 130, min_length: int = 30): """Process text for summarization. Args: text (str): Text to summarize max_length (int): Maximum length of summary min_length (int): Minimum length of summary Returns: str: Summarized text """ try: # Tokenize input text inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024, padding="max_length") # Move inputs to the same device as the model inputs = {key: value.to(self.model.device) for key, value in inputs.items()} # Generate summary summary_ids = self.model.generate( inputs["input_ids"], max_length=max_length, min_length=min_length, num_beams=4, # Beam search for better quality early_stopping=True ) # Decode summary tokens to text summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True) return summary except Exception as e: st.error(f"Error in summarization: {str(e)}") return None