Manyue-DataScientist committed on
Commit
be5ea7d
·
verified ·
1 Parent(s): 2d5dff2

Update src/models/summarization.py

Browse files
Files changed (1) hide show
  1. src/models/summarization.py +23 -20
src/models/summarization.py CHANGED
@@ -1,41 +1,44 @@
1
  """
2
  Summarization Model Handler
3
- Manages the fine-tuned BART model for text summarization.
4
  """
5
 
6
- from transformers import BartTokenizer, BartForConditionalGeneration
7
  import torch
8
  import streamlit as st
9
 
10
  class Summarizer:
11
  def __init__(self):
 
12
  self.model = None
13
- self.tokenizer = None
14
 
15
  def load_model(self):
 
16
  try:
17
- self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
18
- self.model = torch.load('bart_ami_finetuned.pkl')
19
- self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
 
 
20
  return self.model
21
  except Exception as e:
22
  st.error(f"Error loading summarization model: {str(e)}")
23
  return None
24
 
25
- def process(self, text: str, max_length: int = 150, min_length: int = 40):
 
 
 
 
 
 
 
 
 
 
26
  try:
27
- inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
28
- inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
29
- summary_ids = self.model.generate(
30
- inputs["input_ids"],
31
- max_length=max_length,
32
- min_length=min_length,
33
- num_beams=4,
34
- length_penalty=2.0
35
- )
36
- summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
37
- # Return in the expected format
38
- return [{"summary_text": summary}]
39
  except Exception as e:
40
  st.error(f"Error in summarization: {str(e)}")
41
- return None
 
1
  """
2
  Summarization Model Handler
3
+ Manages the BART model for text summarization.
4
  """
5
 
6
+ from transformers import pipeline
7
  import torch
8
  import streamlit as st
9
 
10
  class Summarizer:
11
  def __init__(self):
12
+ """Initialize the summarization model."""
13
  self.model = None
 
14
 
15
  def load_model(self):
16
+ """Load the BART summarization model."""
17
  try:
18
+ self.model = pipeline(
19
+ "summarization",
20
+ model="facebook/bart-large-cnn",
21
+ device=0 if torch.cuda.is_available() else -1
22
+ )
23
  return self.model
24
  except Exception as e:
25
  st.error(f"Error loading summarization model: {str(e)}")
26
  return None
27
 
28
+ def process(self, text: str, max_length: int = 130, min_length: int = 30):
29
+ """Process text for summarization.
30
+
31
+ Args:
32
+ text (str): Text to summarize
33
+ max_length (int): Maximum length of summary
34
+ min_length (int): Minimum length of summary
35
+
36
+ Returns:
37
+ str: Summarized text
38
+ """
39
  try:
40
+ summary = self.model(text, max_length=max_length, min_length=min_length)
41
+ return summary
 
 
 
 
 
 
 
 
 
 
42
  except Exception as e:
43
  st.error(f"Error in summarization: {str(e)}")
44
+ return None