Update src/models/summarization.py
Browse files
src/models/summarization.py
CHANGED
@@ -16,12 +16,11 @@ class Summarizer:
|
|
16 |
def load_model(self):
|
17 |
"""Load the fine-tuned BART summarization model."""
|
18 |
try:
|
|
|
|
|
19 |
# Load the tokenizer
|
20 |
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
21 |
|
22 |
-
# Load the fine-tuned model
|
23 |
-
self.model = BartForConditionalGeneration.from_pretrained("bart_ami_finetuned.pkl")
|
24 |
-
|
25 |
# Move model to appropriate device (GPU if available)
|
26 |
self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
27 |
return self.model
|
|
|
16 |
def load_model(self):
|
17 |
"""Load the fine-tuned BART summarization model."""
|
18 |
try:
|
19 |
+
with open('bart_ami_finetuned.pkl','rb') as f:
|
20 |
+
self.model = pickle.load(f)
|
21 |
# Load the tokenizer
|
22 |
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
23 |
|
|
|
|
|
|
|
24 |
# Move model to appropriate device (GPU if available)
|
25 |
self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
26 |
return self.model
|