Manyue-DataScientist's picture
Update src/models/summarization.py
e168ebd verified
raw
history blame
1.29 kB
"""
Summarization Model Handler
Manages the BART model for text summarization.
"""
import pickle

import streamlit as st
import torch
from transformers import BartTokenizer
from transformers import pipeline
class Summarizer:
    """Summarize text with a pickled BART model.

    The tokenizer is the stock ``facebook/bart-base`` tokenizer. The model is
    whatever object is unpickled from ``model_path``; this class only relies
    on it exposing a ``generate`` method.
    """

    def __init__(self, model_path='bart_ami_finetuned.pkl'):
        """Load the tokenizer and unpickle the model.

        Args:
            model_path (str): Path to the pickled model file.

        Raises:
            OSError: If ``model_path`` cannot be opened.
        """
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point model_path at a file from a trusted source.
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)

    def process(self, text, max_length=150, min_length=40):
        """Return a summary of ``text``.

        Args:
            text (str): Text to summarize. Input is truncated to 1024 tokens.
            max_length (int): Maximum length passed to ``generate``.
            min_length (int): Minimum length passed to ``generate``.

        Returns:
            str: The decoded summary, or ``""`` when ``text`` is ``None``,
            empty, or whitespace-only.
        """
        # Without this guard, min_length would force the model to produce a
        # summary for input that has no content at all.
        if text is None or (isinstance(text, str) and not text.strip()):
            return ""

        inputs = self.tokenizer(
            text, return_tensors="pt", max_length=1024, truncation=True
        )
        summary_ids = self.model.generate(
            inputs["input_ids"], max_length=max_length, min_length=min_length
        )
        # generate returns a batch; a single input yields a single sequence.
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def process_audio(audio_file):
    """Transcribe an audio file and summarize the transcription.

    Args:
        audio_file: Audio input, passed unchanged to ``transcriber.process``.

    Returns:
        dict | None: ``{"transcription": <text>, "summary": <summary>}`` on
        success. ``None`` if transcription or summarization raised; the
        error is shown to the user via ``st.error`` first.
    """
    # NOTE(review): ``transcriber`` and ``summarizer`` are module-level names
    # that are not defined in the visible part of this file — confirm they
    # are created before this function is called. If they are missing, the
    # resulting NameError is caught below and reported like any other error.
    try:
        text = transcriber.process(audio_file)
        summary = summarizer.process(text)
    except Exception as e:
        # UI boundary: report the failure in the Streamlit page instead of
        # letting the exception propagate.
        st.error(f"Error: {e}")
        return None

    return {
        "transcription": text,
        "summary": summary,
    }