Manyue-DataScientist committed on
Commit
8998fb8
·
verified ·
1 Parent(s): 15b7502

Update src/models/summarization.py

Browse files
Files changed (1) hide show
  1. src/models/summarization.py +10 -30
src/models/summarization.py CHANGED
@@ -9,53 +9,33 @@ import streamlit as st
9
 
10
class Summarizer:
    """Summarizes text with a BART model fine-tuned on the AMI meeting corpus.

    Errors are reported to the Streamlit UI via ``st.error`` rather than
    raised, and the failing call returns ``None``.
    """

    def __init__(self):
        """Initialize with no model/tokenizer loaded; call ``load_model()`` first."""
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the fine-tuned BART summarization model and its tokenizer.

        Returns:
            The loaded model, or ``None`` if loading failed (the error is
            surfaced through ``st.error``).
        """
        try:
            # SECURITY: pickle.load executes arbitrary code from the file —
            # only ever load a checkpoint from a trusted source.
            with open('bart_ami_finetuned.pkl', 'rb') as f:
                self.model = pickle.load(f)
            # Tokenizer matches the base checkpoint the model was fine-tuned from.
            self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
            # Inference mode: disables dropout so outputs are deterministic.
            self.model.eval()
            return self.model
        except Exception as e:
            st.error(f"Error loading fine-tuned summarization model: {str(e)}")
            return None

    def process(self, text: str, max_length: int = 130, min_length: int = 30):
        """Summarize *text* with beam search.

        Args:
            text: Text to summarize.
            max_length: Maximum token length of the generated summary.
            min_length: Minimum token length of the generated summary.

        Returns:
            The summary string, or ``None`` on error.
        """
        try:
            # Truncate to BART's 1024-token encoder limit. No padding: a
            # single sequence needs none, and padding to max_length only
            # wastes compute (generate masks pad tokens automatically).
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=1024
            )

            # Move inputs to the same device as the model.
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

            # Beam search trades speed for higher-quality summaries;
            # no_grad skips autograd bookkeeping during inference.
            with torch.no_grad():
                summary_ids = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=4,
                    early_stopping=True,
                )

            # Decode summary tokens back to text.
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary
        except Exception as e:
            st.error(f"Error in summarization: {str(e)}")
            return None
 
9
 
10
class Summarizer:
    """Summarizes text with a BART model fine-tuned on the AMI meeting corpus.

    Errors are reported to the Streamlit UI via ``st.error`` rather than
    raised, and the failing call returns ``None``.
    """

    def __init__(self):
        """Initialize with no model/tokenizer loaded; call ``load_model()`` first."""
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the tokenizer and the fine-tuned BART summarization model.

        Returns:
            The loaded model, or ``None`` if loading failed (the error is
            surfaced through ``st.error``).
        """
        try:
            self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # SECURITY: torch.load unpickles arbitrary code — only ever load
            # a checkpoint from a trusted source.
            # map_location prevents a crash on CPU-only hosts when the
            # checkpoint was serialized with CUDA tensors.
            self.model = torch.load('bart_ami_finetuned.pkl', map_location=device)
            self.model.to(device)
            # Inference mode: disables dropout so outputs are deterministic.
            self.model.eval()
            return self.model
        except Exception as e:
            st.error(f"Error loading summarization model: {str(e)}")
            return None

    def process(self, text: str, max_length: int = 150, min_length: int = 40):
        """Summarize *text* with beam search.

        Args:
            text: Text to summarize.
            max_length: Maximum token length of the generated summary.
            min_length: Minimum token length of the generated summary.

        Returns:
            ``[{"summary_text": <summary>}]`` — the Hugging Face pipeline
            output format callers expect — or ``None`` on error.
        """
        try:
            # Truncate to BART's 1024-token encoder limit.
            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
            # Move inputs to the same device as the model.
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

            # Beam search with a length penalty favors longer, more complete
            # summaries; no_grad skips autograd bookkeeping during inference.
            with torch.no_grad():
                summary_ids = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=4,
                    length_penalty=2.0,
                )
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            # Return in the expected pipeline-compatible format.
            return [{"summary_text": summary}]
        except Exception as e:
            st.error(f"Error in summarization: {str(e)}")
            return None