import copy
import io
import json
import logging
import re
from typing import List, Union

import numpy as np
from box import Box
from pydub import AudioSegment
from scipy.io import wavfile

from modules import generate_audio
from modules.api.utils import calc_spk_style
from modules.normalization import text_normalize
from modules.SentenceSplitter import SentenceSplitter
from modules.speaker import Speaker
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
from modules.utils import rng
from modules.utils.audio import apply_prosody_to_audio_segment

logger = logging.getLogger(__name__)


def audio_data_to_segment_slow(audio_data, sr):
    """Convert raw samples into a pydub ``AudioSegment`` via an in-memory WAV.

    The samples are written unchanged (no normalisation or clipping) with
    ``scipy.io.wavfile.write`` and then decoded again by pydub; this
    serialise/parse round trip is what makes it the slow path.

    Args:
        audio_data: sample array accepted by ``wavfile.write``.
        sr: sample rate in Hz.

    Returns:
        The decoded ``AudioSegment``.
    """
    with io.BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, rate=sr, data=audio_data)
        # Rewind so pydub reads the WAV from its header.
        wav_buffer.seek(0)
        return AudioSegment.from_file(wav_buffer, format="wav")


def clip_audio(audio_data: np.ndarray, threshold: float = 0.99):
    """Hard-limit every sample to the closed range ``[-threshold, threshold]``.

    Args:
        audio_data: sample array.
        threshold: absolute limit applied symmetrically to both signs.

    Returns:
        A new array with out-of-range samples replaced by the limit.
    """
    lower, upper = -threshold, threshold
    return np.clip(audio_data, a_min=lower, a_max=upper)


def normalize_audio(audio_data: np.ndarray, norm_factor: float = 0.8):
    """Peak-normalise samples so the largest absolute value equals ``norm_factor``.

    Args:
        audio_data: sample array (any numeric dtype).
        norm_factor: target peak amplitude after scaling.

    Returns:
        The scaled samples. Empty input and all-zero (silent) input are returned
        unchanged: there is no peak to scale by.
    """
    # np.max on a zero-size array raises ValueError, so guard empty input.
    if np.size(audio_data) == 0:
        return audio_data
    max_amplitude = np.max(np.abs(audio_data))
    if max_amplitude > 0:
        audio_data = audio_data / max_amplitude * norm_factor
    return audio_data


def audio_data_to_segment(audio_data: np.ndarray, sr: int):
    """Pack samples straight into a mono 16-bit PCM pydub ``AudioSegment``.

    The samples are peak-normalised, clipped, scaled to the int16 range and
    handed to ``AudioSegment`` as raw bytes, instead of being round-tripped
    through a WAV container.

    optimize: https://github.com/lenML/ChatTTS-Forge/issues/57

    Args:
        audio_data: sample array.
        sr: sample rate in Hz.

    Returns:
        A single-channel ``AudioSegment`` with 2-byte samples.
    """
    conditioned = clip_audio(normalize_audio(audio_data))
    pcm16 = (conditioned * 32767).astype(np.int16)
    return AudioSegment(
        data=pcm16.tobytes(),
        sample_width=pcm16.dtype.itemsize,
        frame_rate=sr,
        channels=1,
    )


def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    """Concatenate ``audio_segments`` in order into one ``AudioSegment``.

    An empty list yields ``AudioSegment.empty()``.
    """
    # ``sum`` with an explicit start folds left with ``+`` (pydub's append).
    return sum(audio_segments, AudioSegment.empty())


def to_number(value, t, default=0):
    """Convert ``value`` with the callable ``t`` (e.g. ``int``/``float``).

    Args:
        value: raw value to convert (commonly a string or ``None``).
        t: conversion callable.
        default: value returned when the conversion fails.

    Returns:
        ``t(value)``, or ``default`` if the conversion raises ``ValueError``,
        ``TypeError`` (e.g. ``value`` is ``None``) or ``OverflowError``
        (e.g. ``int(float("inf"))``).
    """
    try:
        return t(value)
    except (ValueError, TypeError, OverflowError):
        return default


class TTSAudioSegment(Box):
    """Box-backed bag of generation parameters for one synthesis unit.

    After construction these fields are always present, with these defaults
    when not supplied:
    ``_type`` ("voice"), ``text`` (""), ``temperature`` (0.3), ``top_P`` (0.5),
    ``top_K`` (20), ``spk`` (-1), ``infer_seed`` (-1), ``prompt1`` (""),
    ``prompt2`` (""), ``prefix`` ("").
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Read back from the Box itself rather than from ``kwargs`` alone, so
        # values supplied through a positional mapping (``TTSAudioSegment(d)``)
        # are kept instead of being overwritten by the defaults below.
        self._type = self.get("_type", "voice")
        self.text = self.get("text", "")
        self.temperature = self.get("temperature", 0.3)
        self.top_P = self.get("top_P", 0.5)
        self.top_K = self.get("top_K", 20)
        self.spk = self.get("spk", -1)
        self.infer_seed = self.get("infer_seed", -1)
        self.prompt1 = self.get("prompt1", "")
        self.prompt2 = self.get("prompt2", "")
        self.prefix = self.get("prefix", "")


class SynthesizeSegments:
    """Turn a sequence of SSML voice segments and breaks into per-segment audio.

    Voice segments are sentence-split, grouped ("bucketed") by identical
    generation parameters, and synthesised in batches of ``batch_size``.
    Breaks become silent segments. The returned list is index-aligned with the
    split segment list produced by ``split_segments``.
    """

    # Control tokens that make ``eos`` unnecessary; a text containing any of
    # them (anywhere) gets no ``eos`` suffix.
    _EOS_TOKENS = ("[uv_break]", "[v_break]", "[lbreak]", "[llbreak]")

    def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
        """
        Args:
            batch_size: maximum number of texts per ``generate_audio_batch`` call.
            eos: suffix appended to texts that contain none of ``_EOS_TOKENS``.
            spliter_thr: ``threshold`` passed to ``SentenceSplitter``.
        """
        self.batch_size = batch_size
        # Drawn once per instance (presumably random seeds from rng.np_rng) and
        # used as the fallback speaker / inference seed for every segment that
        # does not set one, so those segments stay consistent with each other.
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()
        self.eos = eos
        self.spliter_thr = spliter_thr

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        """Resolve a segment into the concrete parameters used for generation.

        - ``SSMLBreak`` -> a ``_type="break"`` placeholder.
        - A segment carrying explicit ``params`` -> those params with ``text``
          taken from the segment (no normalisation, no default seeds).
        - Otherwise parameters come from ``segment.attrs`` merged with the
          speaker/style presets from ``calc_spk_style``. The text is normalised
          unless ``attrs.normalize == "False"``, and an unset speaker or seed
          (-1) falls back to this instance's batch defaults.
        """
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        text = segment.get("text", None) or segment.text or ""

        params = segment.get("params", None)
        if params is not None:
            # The segment's text wins. Merging into one mapping avoids a
            # "multiple values for keyword 'text'" TypeError when ``params``
            # already has a "text" key.
            return TTSAudioSegment(**{**params, "text": text})

        is_end = segment.get("is_end", False)

        text = str(text).strip()

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        ss_params = calc_spk_style(spk, style)

        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        disable_normalize = attrs.get("normalize", "") == "False"

        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # NOTE: the per-batch default seeds keep output consistent across
        # segments even when no spk is set.
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    def process_break_segments(
        self,
        src_segments: List[SSMLBreak],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        """Store a silent ``AudioSegment`` for each break in ``bucket_segments``.

        Each result is written to ``audio_segments`` at the break's index in
        ``src_segments``. The length is ``int(segment.attrs.duration)``, passed
        as pydub's ``silent(duration=...)`` (milliseconds).
        """
        for segment in bucket_segments:
            index = src_segments.index(segment)
            audio_segments[index] = AudioSegment.silent(
                duration=int(segment.attrs.duration)
            )

    def _append_eos(self, text: str) -> str:
        """Strip ``text`` and append ``self.eos`` unless it already holds a break token."""
        text = text.strip()
        if any(token in text for token in self._EOS_TOKENS):
            return text
        return text + self.eos

    @staticmethod
    def _prosody_value(segment: SSMLSegment, key: str, default: float) -> float:
        """Read a numeric prosody setting (``rate`` / ``volume`` / ``pitch``).

        A top-level value on the segment takes precedence. Otherwise
        ``segment.attrs`` is consulted, because segments built by
        ``split_segments`` only carry ``text`` / ``attrs`` / ``params`` at the
        top level. Missing values give ``default``; non-numeric values are
        logged and also give ``default``.
        """
        value = segment.get(key, None)
        if value is None:
            attrs = segment.get("attrs", None)
            if attrs is not None:
                value = attrs.get(key, None)
        if value is None:
            return default
        number = to_number(value, float, None)
        if number is None:
            logger.warning("Ignoring non-numeric prosody %s=%r", key, value)
            return default
        return number

    def process_voice_segments(
        self,
        src_segments: List[SSMLSegment],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        """Synthesise one bucket of voice segments in batches of ``batch_size``.

        Each generated clip gets prosody (rate/volume/pitch) applied. It is then
        stored in ``audio_segments`` at the segment's index in ``src_segments``.
        """
        for start in range(0, len(bucket), self.batch_size):
            batch = bucket[start : start + self.batch_size]
            param_arr = [self.segment_to_generate_params(segment) for segment in batch]

            # Append end_of_text to each text where needed.
            texts = [self._append_eos(p.text) for p in param_arr]

            # All segments in a bucket share generation parameters (see
            # ``bucket_segments``), so the first one stands in for the batch.
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )
            for idx, segment in enumerate(batch):
                sr, audio_data = audio_datas[idx]
                rate = self._prosody_value(segment, "rate", 1.0)
                volume = self._prosody_value(segment, "volume", 0.0)
                pitch = self._prosody_value(segment, "pitch", 0.0)

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody_to_audio_segment(
                    audio_segment, rate=rate, volume=volume, pitch=pitch
                )
                # ``list.index`` compares with ``==`` (Box equality for segments).
                original_index = src_segments.index(segment)
                audio_segments[original_index] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by their generation parameters, ignoring text.

        Breaks go under the ``"<break>"`` key. Every other key is the sorted
        JSON dump of the segment's ``TTSAudioSegment`` without ``text``; a
        ``Speaker`` is replaced by its id so it can be serialised.
        """
        buckets = {"<break>": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets["<break>"].append(segment)
                continue

            params = self.segment_to_generate_params(segment)

            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)

            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            if key not in buckets:
                buckets[key] = []
            buckets[key].append(segment)

        return buckets

    def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
        """
        Split each voice segment's text with ``SentenceSplitter`` into one
        segment per sentence. Breaks are passed through untouched, and voice
        segments with empty text are dropped.
        """
        spliter = SentenceSplitter(threshold=self.spliter_thr)
        ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []

        for segment in segments:
            if isinstance(segment, SSMLBreak):
                ret_segments.append(segment)
                continue

            text = segment.text
            if not text:
                continue

            sentences = spliter.parse(text)
            for sentence in sentences:
                seg = SSMLSegment(
                    text=sentence,
                    attrs=segment.attrs.copy(),
                    params=copy.copy(segment.params),
                )
                ret_segments.append(seg)
                # NOTE(review): if SSMLSegment is a Box (as the "compare by Box
                # object" lookup suggests), this unique value also keeps
                # otherwise-identical segments from comparing equal in
                # ``list.index``.
                setattr(seg, "_idx", len(ret_segments) - 1)

        def is_none_speak_segment(segment: SSMLSegment):
            # True when nothing remains after removing ``[...]`` control tokens.
            text = segment.text.strip()
            regexp = r"\[[^\]]+?\]"
            text = re.sub(regexp, "", text)
            text = text.strip()
            if not text:
                return True
            return False

        # Merge segments that carry only control tokens into the most recent
        # surviving voice segment, so they are not synthesised on their own.
        # Breaks are not voice segments: they are skipped and reset the merge
        # target, so text is never moved across a pause.
        merge_target = None
        for seg in ret_segments:
            if isinstance(seg, SSMLBreak):
                merge_target = None
                continue
            if merge_target is not None and is_none_speak_segment(seg):
                merge_target.text += seg.text
                seg.text = ""
            else:
                merge_target = seg

        # Drop voice segments that ended up empty; breaks are always kept.
        ret_segments = [
            seg
            for seg in ret_segments
            if isinstance(seg, SSMLBreak) or seg.text.strip()
        ]

        return ret_segments

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        """Synthesise ``segments`` and return one ``AudioSegment`` per split segment.

        The result is ordered like ``split_segments(segments)``, not like the
        input list, because sentence splitting can change the count.
        """
        segments = self.split_segments(segments)
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        break_segments = buckets.pop("<break>")
        self.process_break_segments(segments, break_segments, audio_segments)

        buckets = list(buckets.values())

        for bucket in buckets:
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments


# Example usage
if __name__ == "__main__":
    # Two speaker contexts that differ only in ``spk``.
    contexts = []
    for spk_id in (1, 2):
        ctx = SSMLContext()
        ctx.spk = spk_id
        ctx.seed = 42
        ctx.temp = 0.1
        contexts.append(ctx)
    ctx_a, ctx_b = contexts

    ssml_segments = [
        SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx_a.copy()),
        SSMLBreak(duration_ms=1000),
        SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx_a.copy()),
        SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx_b.copy()),
    ]

    results = SynthesizeSegments(batch_size=2).synthesize_segments(ssml_segments)
    print(results)
    combine_audio_segments(results).export("output.wav", format="wav")