"""
 Predict the summary of an input text with the local seq2seq model stored in ./summarization_t5.
"""
import functools
import re
import string

import contractions
import nltk
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# NOTE(review): these NLTK resources appear unused in this module — confirm before removing.
nltk.download('stopwords')
nltk.download('punkt')

def clean_data(texts):
    """
    Normalise a raw input string before tokenisation
    --------
    Steps, in order: lowercase, expand contractions with
    ``contractions.fix``, delete every character of
    ``string.punctuation``, then replace newlines with spaces.
    --------
    Parameter
        texts: str
            a single input string (despite the plural name)
    Return
        str
            The normalised string
    """
    punctuation_table = str.maketrans("", "", string.punctuation)
    normalised = contractions.fix(texts.lower())
    normalised = normalised.translate(punctuation_table)
    return re.sub(r'\n', ' ', normalised)

@functools.lru_cache(maxsize=None)
def _load_model(model_dir="./summarization_t5"):
    """
    Load the tokenizer and model from ``model_dir`` once per process
    --------
    Cached so that repeated calls to ``inferenceAPI`` reuse the same
    objects instead of re-reading the checkpoint from disk every time.
    --------
    Parameter
        model_dir: str
            directory containing the pretrained tokenizer and model
    Return
        tuple
            ``(tokenizer, model, device)``, with the model already moved
            to ``device`` (CUDA when available, otherwise CPU)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)
    return tokenizer, model, device


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to summarize; it is first normalised with
            ``clean_data`` (lowercased, contractions expanded,
            punctuation removed, newlines replaced by spaces)
    Return
        str
            The summary for the input text
    """
    text = clean_data(text)
    # Tokenizer and model are loaded once and then reused across calls.
    tokenizer, model, device = _load_model()
    # Pad/truncate every input to a fixed length of 1024 tokens.
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    # The input tensors must be on the same device as the model; otherwise
    # generate() raises a device-mismatch error whenever CUDA is in use.
    generated_ids = model.generate(
        input_ids=text_encoding['input_ids'].to(device),
        attention_mask=text_encoding['attention_mask'].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]
    return "".join(preds)

if __name__ == "__main__":
    # Interactive entry point: read one text from stdin and print its summary.
    user_text = input('Entrez votre phrase à résumer : ')
    summary = inferenceAPI(user_text)
    print('summary:', summary)