import os
import torch
import numpy as np
import librosa
import soundfile as sf
import streamlit as st
from tqdm import tqdm
# SpeechBrain >= 1.0 moved the pretrained interfaces to speechbrain.inference;
# fall back to the legacy import for older releases.
try:
    from speechbrain.inference.TTS import Tacotron2
    from speechbrain.inference.vocoders import HIFIGAN
except ImportError:
    from speechbrain.pretrained import Tacotron2, HIFIGAN

# Paths
output_path = "./processed_data/"
os.makedirs(output_path, exist_ok=True)

# Preprocessing Function
def preprocess_audio(audio_path, max_length=1000):
    """
    Preprocess an audio file into a log-mel spectrogram with a fixed number
    of frames (truncated or zero-padded to `max_length`).
    """
    # Parameters follow the LJSpeech Tacotron2 recipe (22050 Hz, 80 mel bins);
    # adjust them if preparing data for a different model.
    wav, sr = librosa.load(audio_path, sr=22050)
    mel_spectrogram = librosa.feature.melspectrogram(
        y=wav, sr=sr, n_fft=1024, hop_length=256, n_mels=80
    )
    mel_spectrogram = np.log(np.maximum(mel_spectrogram, 1e-5))  # Log compression; floor avoids log(0)

    # Ensure all mel spectrograms have the same time dimension
    if mel_spectrogram.shape[1] > max_length:  # Truncate
        mel_spectrogram = mel_spectrogram[:, :max_length]
    else:  # Pad
        padding = max_length - mel_spectrogram.shape[1]
        mel_spectrogram = np.pad(mel_spectrogram, ((0, 0), (0, padding)), mode="constant")

    return mel_spectrogram
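
# Example usage (hypothetical path; assumes a mono audio file on disk):
#   mel = preprocess_audio("./raw_data/clip_001.wav")
#   mel.shape  # -> (80, 1000): 80 mel bins, padded/truncated to 1000 frames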

# Function to Split Long Text into Chunks
def split_text_into_chunks(text, max_chunk_length=200):
    """
    Splits the input text into smaller chunks, each of up to `max_chunk_length` characters.
    """
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        # Flush the current chunk before it would exceed the limit. The
        # `current_chunk` guard avoids emitting an empty chunk when a single
        # word is longer than max_chunk_length.
        if current_chunk and current_length + len(word) + 1 > max_chunk_length:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += len(word) + 1  # Account for space

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
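
# Example: chunks break at word boundaries, never mid-word:
#   split_text_into_chunks("one two three", max_chunk_length=8)
#   -> ["one two", "three"]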

# Generate Speech for Long Text
def generate_speech(text, tacotron2, hifi_gan, output_file="long_speech.wav", sample_rate=22050):
    """
    Generates long-form speech by splitting the text into chunks, synthesizing
    audio for each chunk, and concatenating the waveforms. The LJSpeech models
    produce 22050 Hz audio, so the file is written at that rate.
    """
    chunks = split_text_into_chunks(text)
    waveforms = []

    for chunk in tqdm(chunks, desc="Generating speech"):
        # Tacotron2 predicts a mel spectrogram; HiFi-GAN vocodes it to a waveform.
        mel_output, mel_length, alignment = tacotron2.encode_batch([chunk])
        waveform = hifi_gan.decode_batch(mel_output)
        waveforms.append(waveform.squeeze().cpu().numpy())

    # Concatenate waveforms
    long_waveform = np.concatenate(waveforms, axis=0)

    # Save the concatenated audio
    sf.write(output_file, long_waveform, sample_rate)
    print(f"Audio has been synthesized and saved as '{output_file}'.")

# Load pretrained Tacotron2 and HiFi-GAN. Cached with st.cache_resource so
# Streamlit reruns do not re-instantiate the models on every interaction.
@st.cache_resource
def load_models():
    tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tacotron2")
    hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="tmpdir_hifigan")

    # Fine-tuned model (if available)
    if os.path.exists("indic_accent_tacotron2.pth"):
        tacotron2.load_state_dict(torch.load("indic_accent_tacotron2.pth", map_location="cpu"))
        print("Fine-tuned Tacotron2 model loaded successfully.")

    return tacotron2, hifi_gan

tacotron2, hifi_gan = load_models()

# Streamlit UI
st.title("Text to Speech Generator")

# Text input for the user
text_input = st.text_area("Enter the text you want to convert to speech:", 
                          "Good morning, lovely listeners! This is your favorite RJ, Sapna...")

# Button to generate speech
if st.button("Generate Speech"):
    if text_input.strip():
        output_file = "output_long_speech.wav"
        
        # Generate speech for the provided text
        with st.spinner("Generating speech..."):
            generate_speech(text_input, tacotron2, hifi_gan, output_file)

        # Provide download link
        st.success("Speech generation complete!")
        st.audio(output_file, format="audio/wav")
        with open(output_file, "rb") as f:
            st.download_button(label="Download Speech", data=f.read(), file_name=output_file, mime="audio/wav")
    else:
        st.warning("Please enter some text to generate speech.")
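
# Launch with (assuming this file is saved as app.py):
#   streamlit run app.py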