"""
 Predict the summary of a given input text.
"""
import torch
import contractions
import re
import string
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def clean_text(texts: str) -> str:
    """
    Normalize a raw text before it is handed to the tokenizer
    --------
    Parameter
        texts: str
            the raw text to normalize
    Return
        str
            The text lowercased, with contractions expanded by
            ``contractions.fix``, every character of ``string.punctuation``
            removed and each newline replaced by a single space
    """
    # Translation table that deletes every ASCII punctuation character.
    punctuation_stripper = str.maketrans("", "", string.punctuation)

    # Contractions are expanded before punctuation is stripped, so the
    # apostrophes are still present when contractions.fix sees the text.
    expanded = contractions.fix(texts.lower())
    without_punctuation = expanded.translate(punctuation_stripper)
    return re.sub(r'\n', ' ', without_punctuation)


# Hugging Face Hub identifier shared by the tokenizer and the model.
_MODEL_NAME = "Linggg/t5_summary"

# (tokenizer, model) pairs keyed by device string, so the weights are loaded
# once per process instead of on every call to inferenceAPI.
_SUMMARIZER_CACHE = {}


def _load_summarizer(device):
    """Return the (tokenizer, model) pair for ``device``, loading it on first use."""
    cache_key = str(device)
    if cache_key not in _SUMMARIZER_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_NAME).to(device)
        _SUMMARIZER_CACHE[cache_key] = (tokenizer, model)
    return _SUMMARIZER_CACHE[cache_key]


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text
    """
    # Normalize the input text before tokenization.
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and model are fetched by their Hub identifier and cached, so
    # only the first call pays the loading cost.
    tokenizer, model = _load_summarizer(device)

    # Encode the text as PyTorch tensors, padded/truncated to 1024 tokens.
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )

    # The input tensors must live on the same device as the model; leaving
    # them on the CPU makes generate() fail whenever CUDA is selected.
    generated_ids = model.generate(
        input_ids=text_encoding['input_ids'].to(device),
        attention_mask=text_encoding['attention_mask'].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True
    )

    # Decode every generated sequence back to text, dropping special tokens.
    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)


if __name__ == "__main__":
    # Script mode: read one text from standard input and print its summary.
    user_text = input('Entrez votre phrase à résumer : ')
    summary = inferenceAPI(user_text)
    print('summary:', summary)