Manyue-DataScientist committed on
Commit
f6e66c5
·
verified ·
1 Parent(s): 0d5109a

Update src/models/summarization.py

Browse files
Files changed (1) hide show
  1. src/models/summarization.py +37 -13
src/models/summarization.py CHANGED
@@ -1,40 +1,64 @@
1
  """
2
  Summarization Model Handler
3
- Manages the BART model for text summarization.
4
  """
5
 
6
- from transformers import BartTokenizer
7
  import torch
8
  import streamlit as st
9
- import pickle
10
 
11
  class Summarizer:
12
  def __init__(self):
 
13
  self.model = None
14
  self.tokenizer = None
15
 
16
  def load_model(self):
 
17
  try:
18
- with open('bart_ami_finetuned.pkl', 'rb') as f:
19
- self.model = pickle.load(f)
20
- self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
 
 
 
 
 
21
  return self.model
22
  except Exception as e:
23
- st.error(f"Error loading summarization model: {str(e)}")
24
  return None
25
 
26
- def process(self, text: str, max_length: int = 150, min_length: int = 40):
 
 
 
 
 
 
 
 
 
 
27
  try:
28
- inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
 
 
 
 
 
 
29
  summary_ids = self.model.generate(
30
  inputs["input_ids"],
31
  max_length=max_length,
32
  min_length=min_length,
33
- num_beams=4,
34
- length_penalty=2.0
35
  )
 
 
36
  summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
37
- return [{"summary_text": summary}]
38
  except Exception as e:
39
  st.error(f"Error in summarization: {str(e)}")
40
- return None
 
1
  """
2
  Summarization Model Handler
3
+ Manages the fine-tuned BART model for text summarization.
4
  """
5
 
6
+ from transformers import BartTokenizer, BartForConditionalGeneration
7
  import torch
8
  import streamlit as st
 
9
 
10
  class Summarizer:
11
  def __init__(self):
12
+ """Initialize the summarization model."""
13
  self.model = None
14
  self.tokenizer = None
15
 
16
  def load_model(self):
17
+ """Load the fine-tuned BART summarization model."""
18
  try:
19
+ # Load the tokenizer
20
+ self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
21
+
22
+ # Load the fine-tuned model
23
+ self.model = BartForConditionalGeneration.from_pretrained("bart_ami_finetuned.pkl")
24
+
25
+ # Move model to appropriate device (GPU if available)
26
+ self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
27
  return self.model
28
  except Exception as e:
29
+ st.error(f"Error loading fine-tuned summarization model: {str(e)}")
30
  return None
31
 
32
+ def process(self, text: str, max_length: int = 130, min_length: int = 30):
33
+ """Process text for summarization.
34
+
35
+ Args:
36
+ text (str): Text to summarize
37
+ max_length (int): Maximum length of summary
38
+ min_length (int): Minimum length of summary
39
+
40
+ Returns:
41
+ str: Summarized text
42
+ """
43
  try:
44
+ # Tokenize input text
45
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024, padding="max_length")
46
+
47
+ # Move inputs to the same device as the model
48
+ inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
49
+
50
+ # Generate summary
51
  summary_ids = self.model.generate(
52
  inputs["input_ids"],
53
  max_length=max_length,
54
  min_length=min_length,
55
+ num_beams=4, # Beam search for better quality
56
+ early_stopping=True
57
  )
58
+
59
+ # Decode summary tokens to text
60
  summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
61
+ return summary
62
  except Exception as e:
63
  st.error(f"Error in summarization: {str(e)}")
64
+ return None