Manyue-DataScientist committed
Commit 1ae5349 · verified · 1 Parent(s): 45b780c

Update src/models/summarization.py

Files changed (1)
  1. src/models/summarization.py +9 -24
src/models/summarization.py CHANGED
@@ -1,44 +1,29 @@
-"""
-Summarization Model Handler
-Manages the BART model for text summarization.
-"""
 
-from transformers import pipeline
+from transformers import BartTokenizer
 import torch
 import streamlit as st
+import pickle
 
 class Summarizer:
     def __init__(self):
-        """Initialize the summarization model."""
         self.model = None
+        self.tokenizer = None
 
     def load_model(self):
-        """Load the BART summarization model."""
         try:
-            self.model = pipeline(
-                "summarization",
-                model="facebook/bart-large-cnn",
-                device=0 if torch.cuda.is_available() else -1
-            )
+            self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+            with open('bart_ami_finetuned.pkl', 'rb') as f:
+                self.model = pickle.load(f)
             return self.model
         except Exception as e:
             st.error(f"Error loading summarization model: {str(e)}")
             return None
 
     def process(self, text: str, max_length: int = 130, min_length: int = 30):
-        """Process text for summarization.
-
-        Args:
-            text (str): Text to summarize
-            max_length (int): Maximum length of summary
-            min_length (int): Minimum length of summary
-
-        Returns:
-            str: Summarized text
-        """
         try:
-            summary = self.model(text, max_length=max_length, min_length=min_length)
-            return summary
+            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
+            summary_ids = self.model.generate(inputs["input_ids"], max_length=max_length, min_length=min_length)
+            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         except Exception as e:
             st.error(f"Error in summarization: {str(e)}")
             return None
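
For reference, a minimal usage sketch of the updated class, assuming the app runs from the repository root and that bart_ami_finetuned.pkl is present in the working directory; the script name app.py and the UI labels are hypothetical, not part of this commit:

# app.py (hypothetical driver for the updated Summarizer)
import streamlit as st

from src.models.summarization import Summarizer

st.title("Summarizer")

summarizer = Summarizer()
# load_model() returns None on failure and reports the error via st.error
if summarizer.load_model() is not None:
    text = st.text_area("Text to summarize")
    if st.button("Summarize") and text:
        # process() truncates input to 1024 tokens, generates, and decodes
        summary = summarizer.process(text, max_length=130, min_length=30)
        if summary is not None:
            st.write(summary)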