import torch   
import gradio as gr 
import librosa 
import tempfile
from typing import Optional
from TTS.config import load_config
from transformers import AutoFeatureExtractor, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer


# Selects which generation settings the request handler uses: True keeps the
# plain deterministic decoding, False switches to the beam-search settings.
first_generation = True

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def load_and_fix_data(input_file, model_sampling_rate):
    """Load an audio file as a mono waveform at the model's sampling rate.

    Args:
        input_file: Path (or file-like object) accepted by ``librosa.load``.
        model_sampling_rate: Sampling rate, in Hz, the returned waveform
            should have.

    Returns:
        The waveform returned by ``librosa.load``: a mono signal resampled to
        ``model_sampling_rate``.
    """
    # Let librosa downmix and resample in a single step. The previous version
    # loaded at librosa's default rate and then called librosa.resample with
    # positional rates, which newer librosa releases reject (the rates are
    # keyword-only), and its manual channel mixing could never run because
    # librosa.load already returns mono audio by default.
    speech, _ = librosa.load(input_file, sr=model_sampling_rate, mono=True)
    return speech


# Spanish speech recognition. The feature extractor is loaded only to read the
# sampling rate that the audio handed to the ASR pipeline is resampled to.
feature_extractor = AutoFeatureExtractor.from_pretrained("jonatasgrosman/wav2vec2-xls-r-1b-spanish")
sampling_rate = feature_extractor.sampling_rate

asr = pipeline("automatic-speech-recognition", model="jonatasgrosman/wav2vec2-xls-r-1b-spanish")

# Gender-neutralization seq2seq model. `prefix` is prepended to every
# transcript before tokenization (currently empty).
prefix = ''
model_checkpoint = "hackathon-pln-es/es_text_neutralizer"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Place the model on `device` so that its weights and the tensors fed to
# generate() share a device; leaving it on the CPU while the inputs are moved
# to CUDA makes generation fail with a device-mismatch error.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)


# Coqui TTS model manager, used to download the TTS model and its vocoder.
manager = ModelManager()
MODEL_NAMES = manager.list_tts_models()


def postproc(input_sentence, preds):
    """Tidy the model output and restore capitalised words from the input.

    The clean-up is best effort and never raises for unexpected input:

    * ``de el`` / ``De el`` are contracted to ``del`` / ``Del`` and double
      spaces are collapsed;
    * a lowercase first character triggers ``str.capitalize()`` (which also
      lowercases the rest of the text);
    * spaces before ``.`` and ``,`` are removed;
    * every word of ``input_sentence`` that starts with an uppercase letter
      (except the first word) is written back with its original casing.

    Args:
        input_sentence: The original sentence whose capitalisation is used as
            the reference. Non-string values skip the capitalisation step.
        preds: The generated text to clean. Non-string values are returned
            unchanged.

    Returns:
        The cleaned prediction with surrounding whitespace stripped.
    """
    # Previously a bare ``except: pass`` hid every failure; the only expected
    # ones (non-text arguments, empty prediction) are now handled explicitly.
    if not isinstance(preds, str):
        return preds

    preds = preds.replace('De el', 'Del').replace('de el', 'del').replace('  ', ' ')
    # An empty prediction has no first character to inspect.
    if preds and preds[0].islower():
        preds = preds.capitalize()
    preds = preds.replace(' . ', '. ').replace(' , ', ', ')

    if isinstance(input_sentence, str):
        words = input_sentence.split(' ')
        # The first word is capitalised by position, not because it is a name.
        first_word = words[0]
        prev_letter = ''
        for word in words:
            if not word:
                continue
            if word[0].isupper() and word != first_word and word.lower() in preds:
                lowered = word.lower()
                if prev_letter == '.':
                    # Sentence start: only restore it right after a period.
                    preds = preds.replace('. ' + lowered + ' ', '. ' + word + ' ')
                elif word[-1] == '.':
                    preds = preds.replace(lowered, word)
                else:
                    # Require the trailing space so prefixes of longer words
                    # are left alone.
                    preds = preds.replace(lowered + ' ', word + ' ')
            prev_letter = word[-1]

    return preds.strip()
    
# Upper bound on the number of characters handed to the speech synthesizer.
MAX_TXT_LEN = 100
# Coqui TTS model identifier; it is looked up as "tts_models/<model_name>".
model_name = 'es/mai/tacotron2-DDC'

def predict_and_ctc_lm_decode(input_file, speaker_idx: Optional[str] = None):
    """Run the speech -> neutralized text -> speech pipeline for one recording.

    The audio is loaded at the ASR sampling rate and transcribed with ``asr``.
    The transcript is rewritten by the seq2seq ``model``, cleaned with
    ``postproc`` and truncated to ``MAX_TXT_LEN`` characters. The result is
    synthesized with the Coqui TTS model named by ``model_name``.

    Args:
        input_file: Path of the audio file to process.
        speaker_idx: Speaker identifier forwarded to ``Synthesizer.tts``.

    Returns:
        Path of a temporary ``.wav`` file containing the synthesized audio.
        The file is created with ``delete=False`` and is not removed here.
    """
    speech = load_and_fix_data(input_file, sampling_rate)
    transcribed_text = asr(speech, chunk_length_s=10, stride_length_s=1)["text"]

    inputs = tokenizer([prefix + transcribed_text], return_tensors="pt", padding=True)
    # Send the inputs to wherever the model's weights actually are; using an
    # unrelated global device breaks generation when the two differ.
    model_device = model.device

    # Deterministic decoding in both modes; the non-first-generation mode adds
    # beam search with a repetition penalty.
    generation_kwargs = {"do_sample": False}
    if not first_generation:
        generation_kwargs.update(
            num_beams=2,
            repetition_penalty=2.5,
            early_stopping=True,
        )
    with torch.no_grad():
        output_sequence = model.generate(
            input_ids=inputs["input_ids"].to(model_device),
            attention_mask=inputs["attention_mask"].to(model_device),
            **generation_kwargs,
        )

    decoded = tokenizer.decode(
        output_sequence[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    text = postproc(transcribed_text, preds=decoded)
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
    print(text, model_name)

    # Download the TTS model and, when it declares one, its default vocoder.
    model_path, config_path, model_item = manager.download_model(f"tts_models/{model_name}")
    vocoder_name: Optional[str] = model_item["default_vocoder"]
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

    # A constructor call cannot yield None, so no "model not found" check is
    # needed after this; failures surface as exceptions from Synthesizer.
    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path,
    )
    wavs = synthesizer.tts(text, speaker_idx)

    # delete=False keeps the file on disk after the handle closes so that the
    # caller can read it back by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name

# Markdown shown under the demo title. The emoji in the TTS line were stored
# as mis-decoded UTF-8 bytes and are restored here.
description = """This is a Gradio demo for generating gender-neutralized audios. To use it, simply provide an audio input (via microphone or audio recording), which will then be transcribed and gender-neutralized using pre-trained models. Finally, with the help of Coqui's TTS model, gender neutralized audio is generated.

Pre-trained model used for Spanish ASR: [jonatasgrosman/wav2vec2-xls-r-1b-spanish](https://huggingface.co/jonatasgrosman/wav2vec2-xls-r-1b-spanish)

Pre-trained model used for Gender Neutralization: [hackathon-pln-es/es_text_neutralizer](https://huggingface.co/hackathon-pln-es/es_text_neutralizer)

Pre-trained model used for TTS: 🐸💬 CoquiTTS => model_name = "es/mai/tacotron2-DDC"  

"""


# Markdown acknowledgements, passed to gr.Interface as its `article` text.
article = """ **ACKNOWLEDGEMENT:**

**This project is based on the following Spaces:**

[CoquiTTS](https://huggingface.co/spaces/coqui/CoquiTTS)

[es_nlp_gender_neutralizer](https://huggingface.co/spaces/hackathon-pln-es/es_nlp_gender_neutralizer)

[Hindi_ASR](https://huggingface.co/spaces/anuragshas/Hindi_ASR)

"""


# Wire the pipeline into a Gradio interface: one recorded audio clip in, one
# synthesized audio clip out.
demo = gr.Interface(
    fn=predict_and_ctc_lm_decode,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record your audio"),
    ],
    outputs=gr.outputs.Audio(label="Output"),
    examples=[
        ["Example1.wav"],
        ["Example2.wav"],
        ["Example3.wav"],
    ],
    title="Generate-Gender-Neutralized-Audios",
    description=description,
    article=article,
    layout="horizontal",
    theme="huggingface",
)

# Start serving with request queuing and cached example outputs enabled.
demo.launch(enable_queue=True, cache_examples=True)