# NOTE(review): removed non-Python web-scrape residue (Hugging Face Space
# header: "Spaces: / Runtime error / File size / line-number gutter") that
# made this file syntactically invalid.
"""
Allows to predict the summary for a given entry text
"""
import torch
import nltk
import contractions
import re
import string
nltk.download('stopwords')
nltk.download('punkt')
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def clean_data(texts):
    """Normalize raw text before tokenization.

    Lowercases the input, expands contractions, strips all punctuation,
    and collapses newlines into spaces.

    Parameters
    ----------
    texts: str
        Raw input text.

    Returns
    -------
    str
        The cleaned text.
    """
    normalized = contractions.fix(texts.lower())
    normalized = normalized.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r'\n', ' ', normalized)
def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text: str
        the text to summarize

    Returns
    -------
    str
        The summary for the input text
    """
    # Normalize the input the same way the training data was cleaned.
    text = clean_data(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the locally fine-tuned tokenizer and model.
    # NOTE(review): these are reloaded on every call; consider caching them
    # at module level if this function is called repeatedly.
    tokenizer = AutoTokenizer.from_pretrained("./summarization_t5")
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("./summarization_t5")
             .to(device))
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    # BUG FIX: the encoded tensors must live on the same device as the model;
    # previously they stayed on CPU, so generation crashed with a device
    # mismatch whenever CUDA was available.
    input_ids = text_encoding['input_ids'].to(device)
    attention_mask = text_encoding['attention_mask'].to(device)
    # Inference only: disable gradient tracking to save memory.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=8,
            length_penalty=0.8,
            early_stopping=True
        )
    # Decode every generated sequence and concatenate them.
    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]
    return "".join(preds)
if __name__ == "__main__":
    # Simple CLI entry point: read a sentence and print its summary.
    user_text = input('Entrez votre phrase à résumer : ')
    summary = inferenceAPI(user_text)
    print('summary:', summary)