Manyue-DataScientist commited on
Commit
9b2efc6
·
verified ·
1 Parent(s): 1ae5349

Update src/models/summarization.py

Browse files
Files changed (1) hide show
  1. src/models/summarization.py +15 -9
src/models/summarization.py CHANGED
@@ -1,8 +1,6 @@
1
-
2
- from transformers import BartTokenizer
3
  import torch
4
  import streamlit as st
5
- import pickle
6
 
7
  class Summarizer:
8
  def __init__(self):
@@ -12,18 +10,26 @@ class Summarizer:
12
  def load_model(self):
13
  try:
14
  self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
15
- with open('bart_ami_finetuned.pkl', 'rb') as f:
16
- self.model = pickle.load(f)
17
  return self.model
18
  except Exception as e:
19
  st.error(f"Error loading summarization model: {str(e)}")
20
  return None
21
 
22
- def process(self, text: str, max_length: int = 130, min_length: int = 30):
23
  try:
24
  inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
25
- summary_ids = self.model.generate(inputs["input_ids"], max_length=max_length, min_length=min_length)
26
- return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
27
  except Exception as e:
28
  st.error(f"Error in summarization: {str(e)}")
29
- return None
 
1
+ from transformers import BartTokenizer, BartForConditionalGeneration
 
2
  import torch
3
  import streamlit as st
 
4
 
5
  class Summarizer:
6
  def __init__(self):
 
10
  def load_model(self):
11
  try:
12
  self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
13
+ self.model = torch.load('bart_ami_finetuned.pkl')
14
+ self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
15
  return self.model
16
  except Exception as e:
17
  st.error(f"Error loading summarization model: {str(e)}")
18
  return None
19
 
20
+ def process(self, text: str, max_length: int = 150, min_length: int = 40):
21
  try:
22
  inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
23
+ inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
24
+ summary_ids = self.model.generate(
25
+ inputs["input_ids"],
26
+ max_length=max_length,
27
+ min_length=min_length,
28
+ num_beams=4,
29
+ length_penalty=2.0
30
+ )
31
+ summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
32
+ return summary
33
  except Exception as e:
34
  st.error(f"Error in summarization: {str(e)}")
35
+ return None