Manyue-DataScientist committed
Commit 08d05f4 · verified · 1 Parent(s): 7ee06b0

Update src/models/summarization.py

Files changed (1): src/models/summarization.py (+29 −31)
src/models/summarization.py CHANGED
@@ -3,40 +3,38 @@ Summarization Model Handler
 Manages the BART model for text summarization.
 """
 
-from transformers import pipeline
+from transformers import BartTokenizer
 import torch
 import streamlit as st
+import pickle
 
 class Summarizer:
-    def __init__(self, model_path='bart_ami_finetuned.pkl'):
-        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
-        with open(model_path, 'rb') as f:
-            self.model = pickle.load(f)
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
 
-    def process(self, text):
-        inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
-        summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40)
-        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-
-def process_audio(audio_file):
-    """Process text for summarization.
-
-    Args:
-        text (str): Text to summarize
-        max_length (int): Maximum length of summary
-        min_length (int): Minimum length of summary
-
-    Returns:
-        str: Summarized text
-    """
-    try:
-        text = transcriber.process(audio_file)
-        summary = summarizer.process(text)
-        return {
-            "transcription": text,
-            "summary": summary
-        }
-    except Exception as e:
-        st.error(f"Error: {str(e)}")
-        return None
+    def load_model(self):
+        try:
+            with open('bart_ami_finetuned.pkl', 'rb') as f:
+                self.model = pickle.load(f)
+            self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+            return self.model
+        except Exception as e:
+            st.error(f"Error loading summarization model: {str(e)}")
+            return None
+
+    def process(self, text: str, max_length: int = 150, min_length: int = 40):
+        try:
+            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
+            summary_ids = self.model.generate(
+                inputs["input_ids"],
+                max_length=max_length,
+                min_length=min_length,
+                num_beams=4,
+                length_penalty=2.0
+            )
+            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+            return [{"summary_text": summary}]
+        except Exception as e:
+            st.error(f"Error in summarization: {str(e)}")
+            return None
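
After this change, loading and inference are separate steps: load_model() unpickles the fine-tuned BART checkpoint and returns it (or None on failure), and process() returns a pipeline-style [{"summary_text": ...}] list rather than a bare string. A minimal usage sketch of the new interface follows; the calling module and the sample text are hypothetical, and it assumes bart_ami_finetuned.pkl is present in the working directory:

from src.models.summarization import Summarizer

summarizer = Summarizer()

# load_model() must be called before process(); it returns None on failure
# and surfaces the error via st.error in the Streamlit UI.
if summarizer.load_model() is not None:
    result = summarizer.process(
        "The team reviewed the prototype and agreed to ship the demo in May.",
        max_length=60,
        min_length=10,
    )
    if result is not None:  # process() returns [{"summary_text": ...}] or None
        print(result[0]["summary_text"])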