Manyue-DataScientist committed on
Commit
e168ebd
·
verified ·
1 Parent(s): 6d577a0

Update src/models/summarization.py

Browse files
Files changed (1) hide show
  1. src/models/summarization.py +20 -22
src/models/summarization.py CHANGED
@@ -8,24 +8,18 @@ import torch
8
  import streamlit as st
9
 
10
  class Summarizer:
11
- def __init__(self):
12
- """Initialize the summarization model."""
13
- self.model = None
 
14
 
15
- def load_model(self):
16
- """Load the BART summarization model."""
17
- try:
18
- self.model = pipeline(
19
- "summarization",
20
- model="facebook/bart-large-cnn",
21
- device=0 if torch.cuda.is_available() else -1
22
- )
23
- return self.model
24
- except Exception as e:
25
- st.error(f"Error loading summarization model: {str(e)}")
26
- return None
27
 
28
- def process(self, text: str, max_length: int = 130, min_length: int = 30):
 
29
  """Process text for summarization.
30
 
31
  Args:
@@ -36,9 +30,13 @@ class Summarizer:
36
  Returns:
37
  str: Summarized text
38
  """
39
- try:
40
- summary = self.model(text, max_length=max_length, min_length=min_length)
41
- return summary
42
- except Exception as e:
43
- st.error(f"Error in summarization: {str(e)}")
44
- return None
 
 
 
 
 
8
  import streamlit as st
9
 
10
  class Summarizer:
11
+ def __init__(self, model_path='bart_ami_finetuned.pkl'):
12
+ self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
13
+ with open(model_path, 'rb') as f:
14
+ self.model = pickle.load(f)
15
 
16
+ def process(self, text):
17
+ inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
18
+ summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40)
19
+ return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
20
 
21
+
22
+ def process_audio(audio_file):
23
  """Process text for summarization.
24
 
25
  Args:
 
30
  Returns:
31
  str: Summarized text
32
  """
33
+ try:
34
+ text = transcriber.process(audio_file)
35
+ summary = summarizer.process(text)
36
+ return {
37
+ "transcription": text,
38
+ "summary": summary
39
+ }
40
+ except Exception as e:
41
+ st.error(f"Error: {str(e)}")
42
+ return None