Manyue-DataScientist's picture
Update src/models/summarization.py
08d05f4 verified
raw
history blame
1.33 kB
"""
Summarization Model Handler
Manages the BART model for text summarization.
"""
from transformers import BartTokenizer
import torch
import streamlit as st
import pickle
class Summarizer:
    """Load a pickled BART summarization model and run it on text.

    Call ``load_model()`` once before ``process()``.  Both methods report
    failures to the Streamlit UI via ``st.error`` and return ``None``
    instead of raising, so callers only need to check for ``None``.
    """

    # Location of the pickled model and the tokenizer checkpoint name.
    MODEL_PATH = 'bart_ami_finetuned.pkl'
    TOKENIZER_NAME = 'facebook/bart-base'

    # Maximum number of input tokens; longer inputs are truncated.
    MAX_INPUT_TOKENS = 1024

    def __init__(self):
        # Both stay None until load_model() succeeds.
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the pickled model and its tokenizer.

        The instance is only updated when BOTH the model and the tokenizer
        load successfully, so it is never left half-initialized.

        Returns:
            The loaded model, or ``None`` if loading failed (the error is
            shown with ``st.error``).
        """
        try:
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserializing.  Only load model files from a trusted source.
            with open(self.MODEL_PATH, 'rb') as f:
                model = pickle.load(f)
            tokenizer = BartTokenizer.from_pretrained(self.TOKENIZER_NAME)
        except Exception as e:
            st.error(f"Error loading summarization model: {e}")
            return None

        self.model = model
        self.tokenizer = tokenizer
        return self.model

    def process(self, text: str, max_length: int = 150, min_length: int = 40):
        """Summarize ``text`` using beam search.

        Args:
            text: The text to summarize.  Inputs longer than
                ``MAX_INPUT_TOKENS`` tokens are truncated.
            max_length: Upper bound passed to ``generate`` for the summary.
            min_length: Lower bound passed to ``generate`` for the summary.

        Returns:
            ``[{"summary_text": <summary>}]`` on success, or ``None`` when
            the model is not loaded, the input is empty, or generation
            fails (the error is shown with ``st.error``).
        """
        # Fail with a clear message instead of an opaque
        # "'NoneType' object is not callable" from the except below.
        if self.model is None or self.tokenizer is None:
            st.error(
                "Error in summarization: model is not loaded; "
                "call load_model() first"
            )
            return None

        # An empty input would make the model produce an ungrounded summary.
        if not text or not text.strip():
            st.error("Error in summarization: input text is empty")
            return None

        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=self.MAX_INPUT_TOKENS,
                truncation=True,
            )
            summary_ids = self.model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly rather than letting generate()
                # infer it from the pad token.
                attention_mask=inputs.get("attention_mask"),
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                length_penalty=2.0,
            )
            summary = self.tokenizer.decode(
                summary_ids[0], skip_special_tokens=True
            )
        except Exception as e:
            st.error(f"Error in summarization: {e}")
            return None

        return [{"summary_text": summary}]