"""
Summarization Model Handler

Manages the BART model for text summarization.
"""
|
|
|
import pickle

import streamlit as st
import torch
from transformers import BartTokenizer
from transformers import pipeline
|
|
|
class Summarizer:
    """Summarize text with a pickled model and the ``facebook/bart-base`` tokenizer.

    The tokenizer is always loaded from the ``facebook/bart-base`` checkpoint,
    while the model object itself is unpickled from ``model_path``.
    """

    def __init__(self, model_path='bart_ami_finetuned.pkl'):
        """Load the tokenizer and unpickle the model.

        Args:
            model_path (str): Path to the pickled model file.

        Raises:
            OSError: If ``model_path`` cannot be opened.
        """
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        # SECURITY: pickle.load runs arbitrary code embedded in the file.
        # Only point model_path at a file from a trusted source.
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)

    def process(self, text, max_length=150, min_length=40):
        """Generate a summary for ``text``.

        Args:
            text (str): Text to summarize. The tokenizer is called with
                ``max_length=1024`` and ``truncation=True``, so longer input
                is truncated before generation.
            max_length (int): Passed to ``model.generate`` as ``max_length``.
            min_length (int): Passed to ``model.generate`` as ``min_length``.

        Returns:
            str: The decoded first generated sequence with special tokens
            removed, or ``""`` when ``text`` is ``None`` or a blank string.
        """
        # Nothing to summarize: skip the model rather than letting it
        # generate min_length tokens from empty input.
        if text is None or (isinstance(text, str) and not text.strip()):
            return ""

        inputs = self.tokenizer(
            text, return_tensors="pt", max_length=1024, truncation=True
        )
        summary_ids = self.model.generate(
            inputs["input_ids"], max_length=max_length, min_length=min_length
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
def process_audio(audio_file):
    """Transcribe an audio file and summarize the resulting transcript.

    Relies on ``transcriber`` and ``summarizer`` objects that are expected to
    exist at module level; they are not defined in the visible part of this
    file.

    Args:
        audio_file: Audio input handed unchanged to ``transcriber.process``.

    Returns:
        dict | None: ``{"transcription": ..., "summary": ...}`` on success.
        If any step raises, the error is reported through ``st.error`` and
        ``None`` is returned.
    """
    try:
        transcript = transcriber.process(audio_file)
        outcome = {
            "transcription": transcript,
            "summary": summarizer.process(transcript),
        }
    except Exception as exc:
        # UI boundary: show the failure in the Streamlit page instead of
        # propagating it to the caller.
        st.error(f"Error: {str(exc)}")
        outcome = None
    return outcome