File size: 2,301 Bytes
b1426fb
 
f6e66c5
b1426fb
 
f6e66c5
b1426fb
 
 
 
08d05f4
f6e66c5
08d05f4
 
b1426fb
08d05f4
f6e66c5
08d05f4
3befa67
 
f6e66c5
 
 
 
 
08d05f4
 
f6e66c5
08d05f4
b1426fb
f6e66c5
 
 
 
 
 
 
 
 
 
 
08d05f4
f6e66c5
 
 
 
 
 
 
08d05f4
 
 
 
f6e66c5
 
08d05f4
f6e66c5
 
08d05f4
f6e66c5
08d05f4
 
f6e66c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Summarization Model Handler
Manages the fine-tuned BART model for text summarization.
"""

import pickle

import streamlit as st
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

class Summarizer:
    """Wrap a pickled fine-tuned BART model and its tokenizer for summarization.

    Errors are reported to the Streamlit UI via ``st.error`` and signalled to
    the caller by a ``None`` return value rather than by raising.
    """

    def __init__(self):
        """Initialize the summarizer with no model or tokenizer loaded yet."""
        self.model = None
        self.tokenizer = None

    def load_model(self, model_path: str = "bart_ami_finetuned.pkl",
                   tokenizer_name: str = "facebook/bart-base"):
        """Load the fine-tuned BART summarization model and its tokenizer.

        Args:
            model_path (str): Path to the pickled fine-tuned model. A relative
                path is resolved against the current working directory.
            tokenizer_name (str): Identifier or local path passed to
                ``BartTokenizer.from_pretrained``.

        Returns:
            The loaded model on success, or None if loading failed (the error
            is shown via ``st.error`` and any previously loaded model/tokenizer
            is kept unchanged).
        """
        try:
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
            # the file. Only load model files from a trusted source; consider
            # save_pretrained()/from_pretrained() instead of pickling.
            with open(model_path, "rb") as f:
                model = pickle.load(f)
            tokenizer = BartTokenizer.from_pretrained(tokenizer_name)

            # Move model to appropriate device (GPU if available).
            model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            # A pickled model keeps whatever train/eval mode it was saved in and
            # generate() does not change it; force eval() so dropout is disabled.
            model.eval()
        except Exception as e:
            st.error(f"Error loading fine-tuned summarization model: {str(e)}")
            return None

        # Commit both only after everything succeeded, so a failure part-way
        # through never leaves a model without a matching tokenizer.
        self.model = model
        self.tokenizer = tokenizer
        return self.model

    def process(self, text: str, max_length: int = 130, min_length: int = 30):
        """Process text for summarization.

        Args:
            text (str): Text to summarize.
            max_length (int): Maximum length of the summary, in tokens.
            min_length (int): Minimum length of the summary, in tokens.

        Returns:
            str: The summarized text; "" for empty/whitespace-only input.
            None if the model is not loaded or generation failed (the error is
            shown via ``st.error``).
        """
        if self.model is None or self.tokenizer is None:
            st.error("Error in summarization: model is not loaded; call load_model() first.")
            return None

        if not text or not text.strip():
            # Nothing to summarize; don't force the model to produce
            # min_length tokens from an empty prompt.
            return ""

        try:
            # Truncate to 1024 tokens. A single sequence needs no padding;
            # padding to max_length only wastes encoder compute.
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)

            # Move inputs to the same device as the model.
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

            # Pass the attention mask explicitly so generation never has to
            # guess which input positions are real tokens.
            summary_ids = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_length=max_length,
                min_length=min_length,
                num_beams=4,  # Beam search for better quality
                early_stopping=True,
            )

            # Decode the first (and only) generated sequence back to text.
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        except Exception as e:
            st.error(f"Error in summarization: {str(e)}")
            return None