import gradio as gr
import torch
#import python_multipart
import os
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset, Audio
import numpy as np
from speechbrain.inference import EncoderClassifier

# SpeechT5 synthesis stack: the tokenizer/feature processor, the Hindi
# fine-tuned acoustic model, and the HiFi-GAN vocoder that turns the
# predicted spectrogram into a waveform.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("Solo448/Speect5-common-voice-Hindi")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Speaker encoder used to derive a speaker embedding from reference audio.
# It runs on the GPU when one is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

_speaker_encoder_id = "speechbrain/spkrec-xvect-voxceleb"
speaker_model = EncoderClassifier.from_hparams(
    source=_speaker_encoder_id,
    # Downloaded checkpoint files are cached under /tmp.
    savedir=os.path.join("/tmp", _speaker_encoder_id),
    run_opts={"device": device},
)

def create_speaker_embedding(waveform):
    """Compute a normalized speaker embedding for one audio waveform.

    Args:
        waveform: Audio samples accepted by ``torch.tensor`` (the caller
            below passes the decoded ``array`` of a 16 kHz dataset sample).

    Returns:
        A NumPy array holding the embedding produced by ``speaker_model``,
        L2-normalized along dimension 2 and with singleton dimensions
        squeezed out.
    """
    with torch.no_grad():
        speaker_embeddings = speaker_model.encode_batch(torch.tensor(waveform))
        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=2)
        speaker_embeddings = speaker_embeddings.squeeze().cpu().numpy()
    return speaker_embeddings


# Build the speaker embedding used for every synthesis request from the
# first validated Hindi Common Voice sample. This must run AFTER
# create_speaker_embedding is defined: calling it earlier raised NameError,
# which the broad handler below swallowed, so the random fallback was
# always used.
try:
    dataset = load_dataset("mozilla-foundation/common_voice_17_0", "hi", split="validated", trust_remote_code=True)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    sample = dataset[0]
    # The helper returns a squeezed NumPy vector; generate_speech is given a
    # torch tensor with a leading batch dimension, matching the (1, 512)
    # shape of the fallback below.
    speaker_embedding = torch.tensor(
        create_speaker_embedding(sample['audio']['array'])
    ).unsqueeze(0)
except Exception as e:
    # Deliberate best-effort: any download/decoding failure must not stop
    # the app from starting.
    print(f"Error loading dataset: {e}")
    # Use a random speaker embedding as fallback
    speaker_embedding = torch.randn(1, 512)

# Devanagari -> Latin transliteration table, applied entry by entry with
# str.replace before tokenization. ORDER MATTERS: an entry whose source is
# longer than one code point must come before the entries for its parts,
# otherwise the parts are rewritten first and the longer entry never matches.
_DEVANAGARI_TO_LATIN = (
    # Nukta consonants, in both the precomposed and the decomposed
    # (base consonant + U+093C nukta) spelling. They previously sat at the
    # end of the list, after the base consonants and the bare nukta sign, so
    # decomposed input such as "क" + nukta came out as "kng" instead of "q".
    # NOTE(review): if the checkpoint was fine-tuned on text cleaned with the
    # old ordering, confirm output quality for words containing these letters.
    ("\u0958", "q"), ("\u0915\u093c", "q"),
    ("\u095b", "z"), ("\u091c\u093c", "z"),
    ("\u095c", "r"), ("\u0921\u093c", "r"),
    ("\u095d", "rh"), ("\u0922\u093c", "rh"),
    ("\u095e", "f"), ("\u092b\u093c", "f"),
    # Independent vowels.
    ("अ", "a"),
    ("आ", "aa"),
    ("इ", "i"),
    ("ई", "ee"),
    ("उ", "u"),
    ("ऋ", "ri"),
    ("ए", "ae"),
    ("ऐ", "ai"),
    ("ऑ", "au"),
    ("ओ", "o"),
    ("औ", "au"),
    # Consonants.
    ("क", "k"),
    ("ख", "kh"),
    ("ग", "g"),
    ("घ", "gh"),
    ("च", "ch"),
    ("छ", "chh"),
    ("ज", "j"),
    ("झ", "jh"),
    ("ञ", "gna"),
    ("ट", "t"),
    ("ठ", "th"),
    ("ड", "d"),
    ("ढ", "dh"),
    ("ण", "nr"),
    ("त", "t"),
    ("थ", "th"),
    ("द", "d"),
    ("ध", "dh"),
    ("न", "n"),
    ("प", "p"),
    ("फ", "ph"),
    ("ब", "b"),
    ("भ", "bh"),
    ("म", "m"),
    ("य", "ya"),
    ("र", "r"),
    ("ल", "l"),
    ("व", "w"),
    ("श", "sha"),
    ("ष", "sh"),
    ("स", "s"),
    ("ह", "ha"),
    # Signs and dependent vowels. A nukta that is not part of one of the
    # pairs handled above still maps to "ng", as before.
    # NOTE(review): "ng" for a bare nukta looks unusual — confirm intent.
    ("़", "ng"),
    ("्", ""),
    ("ऽ", ""),
    ("ा", "a"),
    ("ि", "i"),
    ("ी", "ee"),
    ("ु", "u"),
    ("ॅ", "n"),
    ("े", "e"),
    ("ै", "oi"),
    ("ो", "o"),
    ("ौ", "ou"),
    ("ॉ", "r"),
    ("ू", "uh"),
    ("ृ", "ri"),
    ("ं", "n"),
    # Sentence terminators: the Devanagari danda as well as the ASCII pipe
    # that is often typed in its place.
    ("\u0964", "."),
    ("|", "."),
)


def _transliterate_hindi(text):
    """Return *text* with Devanagari characters replaced by Latin spellings."""
    for src, dst in _DEVANAGARI_TO_LATIN:
        text = text.replace(src, dst)
    return text


def text_to_speech(text):
    """Synthesize speech for Hindi text.

    The text is transliterated to Latin characters, tokenized with
    ``processor`` and passed to ``model.generate_speech`` together with the
    module-level ``speaker_embedding`` and ``vocoder``.

    Args:
        text: Hindi text entered by the user.

    Returns:
        A ``(16000, samples)`` tuple: the sample rate and the generated
        waveform as a NumPy array.

    Raises:
        gr.Error: If *text* is empty or contains only whitespace.
    """
    # Reject empty input up front instead of sending it to the model.
    if not text or not text.strip():
        raise gr.Error("Please enter some Hindi text to convert to speech.")

    inputs = processor(text=_transliterate_hindi(text), return_tensors="pt")
    speech = model.generate_speech(inputs["input_ids"], speaker_embedding, vocoder=vocoder)
    return (16000, speech.numpy())

# Gradio front end: one text box in, one audio player out, backed by
# text_to_speech.
iface = gr.Interface(
    text_to_speech,
    "text",
    "audio",
    title="Hindi Text-to-Speech",
    description="Enter hindi text to convert to speech",
)

# share=True additionally requests a public share link for the app.
iface.launch(share=True)