File size: 1,785 Bytes
41508f8
 
 
 
 
 
 
 
5925e5f
 
41508f8
 
5925e5f
41508f8
 
5925e5f
1aab2b0
41508f8
 
 
 
 
 
 
 
 
 
4874293
41508f8
5925e5f
41508f8
1aab2b0
41508f8
 
1aab2b0
5925e5f
4874293
41508f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5925e5f
 
41508f8
 
 
5925e5f
4874293
 
1aab2b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
 Predict the summary of a given input text
"""
import torch
import re
import string
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def clean_text(texts: str) -> str:
    """
    Normalize a text before it is summarized
    --------
    Parameter
        texts: str
            the raw text to normalize
    Return
        str
            The text in lower case, with every character of
            string.punctuation removed and every newline replaced
            by a single space
    """
    # Mapping each punctuation character to None makes translate() drop it.
    drop_punctuation = str.maketrans(dict.fromkeys(string.punctuation))
    return texts.lower().translate(drop_punctuation).replace("\n", " ")


_T5_CHECKPOINT = "Linggg/t5_summary"
# Maps str(device) -> (tokenizer, model); entries live as long as the process.
_T5_CACHE = {}


def _load_t5(device):
    """
    Return the (tokenizer, model) pair for a device, loading it on first use
    --------
    Parameter
        device: torch.device
            the device the model must be placed on
    Return
        tuple
            The tokenizer and the seq2seq model moved to `device`
    """
    cache_key = str(device)
    if cache_key not in _T5_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(
            _T5_CHECKPOINT, use_auth_token=True
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            _T5_CHECKPOINT, use_auth_token=True
        ).to(device)
        _T5_CACHE[cache_key] = (tokenizer, model)
    return _T5_CACHE[cache_key]


def inferenceAPI_T5(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text
    """

    # Prepare the model inputs: normalized text, target device, tokenizer/model.
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Reuse the already-loaded tokenizer/model instead of calling
    # from_pretrained on every request.
    tokenizer, model = _load_t5(device)

    # Pad or truncate the input to exactly 1024 tokens.
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    # The tokenizer returns CPU tensors; they must be moved to the model's
    # device, otherwise generate() fails when the model sits on CUDA.
    generated_ids = model.generate(
        input_ids=text_encoding['input_ids'].to(device),
        attention_mask=text_encoding['attention_mask'].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]
    return "".join(preds)


# if __name__ == "__main__":
#     text = input('Entrez votre phrase à résumer : ')
#     print('summary:', inferenceAPI_T5(text))