"""
Allows to predict the summary for a given entry text
"""
import re
import string
import contractions
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def clean_text(texts: str) -> str:
    """Normalize the text: lowercase it and strip punctuation and newlines."""
    # Expand contractions ("don't" -> "do not") while apostrophes are still
    # present; plausibly the intent behind the otherwise unused `contractions` import.
    texts = contractions.fix(texts)
    texts = texts.lower()
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r"\n", " ", texts)
    return texts


def inference_t5(text: str) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text: str
        The text to summarize.

    Returns
    -------
    str
        The summary of the input text.
    """
    # Prepare the model inputs
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the tokenizer and model from the Hugging Face Hub
    # (use_auth_token=True is needed when the repo is private or gated)
    tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=True)
    model = (
        AutoModelForSeq2SeqLM
        .from_pretrained("Linggg/t5_summary", use_auth_token=True)
        .to(device)
    )
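    # Encode the input, padding/truncating it to the model's 1024-token window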
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
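    # Generate with beam search; length_penalty < 1.0 nudges the beams toward
    # shorter summaries, and early_stopping halts once all beams are finished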
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )
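    # Decode the generated ids back to text, dropping special tokens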
    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)


# if __name__ == "__main__":
#     text = input("Enter the text to summarize: ")
#     print("summary:", inference_t5(text))