from fastapi import File, Form, HTTPException, Body, UploadFile

from numpy import clip
from pydantic import BaseModel, Field
from fastapi.responses import StreamingResponse


from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig


from typing import List, Optional

from modules.api import utils as api_utils
from modules.api.Api import APIManager

from modules.speaker import Speaker, speaker_mgr
from modules.data import styles_mgr


class AudioSpeechRequest(BaseModel):
    # JSON request body accepted by `openai_speech_api`.
    # Documented with comments rather than a docstring so the generated
    # schema description stays unchanged.

    input: str  # text to synthesize
    # Read by no code in this file; the route description states that any
    # value may be supplied.
    model: str = "chattts-4w"
    voice: str = "female2"
    # Default to the enum member instead of the raw string "mp3": pydantic
    # does not validate default values, so a string default would leave this
    # field holding a `str` whenever the client omits it.
    response_format: AudioFormat = AudioFormat.mp3
    speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
    seed: int = 42

    # Sampling parameters forwarded to ChatTTSConfig by the handler.
    temperature: float = 0.3
    top_k: int = 20
    top_p: float = 0.7

    style: str = ""
    batch_size: int = Field(1, ge=1, le=20, description="Batch size")
    spliter_threshold: float = Field(
        100, ge=10, le=1024, description="Threshold for sentence spliter"
    )
    # end of sentence
    eos: str = "[uv_break]"

    # Either flag enables the enhancer in the handler; `denoise` additionally
    # selects lambd=0.9 instead of 0.1.
    enhance: bool = False
    denoise: bool = False


async def openai_speech_api(
    request: AudioSpeechRequest = Body(
        ..., description="JSON body with model, input text, and voice"
    )
):
    """Synthesize speech for an OpenAI-style speech request.

    Validates the input text, voice and style, builds the TTS, inference,
    adjustment and enhancer configs from the request, and streams the
    encoded audio back to the client.

    Args:
        request: Parsed JSON body. ``request.model`` is accepted but not
            used by this handler.

    Returns:
        A ``StreamingResponse`` over the buffer produced by
        ``TTSHandler.enqueue_to_buffer``. The media type is ``audio/mpeg``
        for mp3 and ``audio/<format value>`` otherwise.

    Raises:
        HTTPException: 400 when the input text is empty, the voice cannot be
            resolved to a ``Speaker``, or the style lookup raises; 500 (with
            the error text as ``detail``) for any other failure while
            building or running the handler.
    """
    import logging

    input_text = request.input
    voice = request.voice
    style = request.style

    # The field can still hold a raw string (pydantic does not validate
    # default values), so coerce it to the enum before using `.value`.
    response_format = request.response_format
    if not isinstance(response_format, AudioFormat) and isinstance(
        response_format, str
    ):
        response_format = AudioFormat(response_format)

    # `clip` returns a NumPy scalar; hand a built-in float to the config.
    speed = float(clip(request.speed, 0.1, 10))

    if not input_text:
        raise HTTPException(status_code=400, detail="Input text is required.")
    if speaker_mgr.get_speaker(voice) is None:
        raise HTTPException(status_code=400, detail="Invalid voice.")
    if style:
        # NOTE(review): this only rejects styles for which the lookup raises;
        # confirm find_item_by_name does not signal "not found" by returning
        # None instead.
        try:
            styles_mgr.find_item_by_name(style)
        except Exception as exc:
            # `Exception`, not a bare except, so task cancellation and
            # interpreter-exit signals are not turned into a 400.
            raise HTTPException(status_code=400, detail="Invalid style.") from exc

    ctx_params = api_utils.calc_spk_style(spk=voice, style=style)

    speaker = ctx_params.get("spk")
    if not isinstance(speaker, Speaker):
        raise HTTPException(status_code=400, detail="Invalid voice.")

    tts_config = ChatTTSConfig(
        style=style,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
    )
    infer_config = InferConfig(
        batch_size=request.batch_size,
        spliter_threshold=request.spliter_threshold,
        eos=request.eos,
        seed=request.seed,
    )
    adjust_config = AdjustConfig(speaking_rate=speed)
    enhancer_config = EnhancerConfig(
        # Either flag switches the enhancer on; denoise uses the larger lambd.
        enabled=request.enhance or request.denoise,
        lambd=0.9 if request.denoise else 0.1,
    )

    try:
        handler = TTSHandler(
            text_content=input_text,
            spk=speaker,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

        buffer = handler.enqueue_to_buffer(response_format)

        if response_format == AudioFormat.mp3:
            mime_type = "audio/mpeg"
        else:
            mime_type = f"audio/{response_format.value}"
        return StreamingResponse(buffer, media_type=mime_type)
    except HTTPException:
        # Deliberate HTTP errors raised further down pass through unchanged.
        raise
    except Exception as exc:
        logging.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc


class TranscribeSegment(BaseModel):
    # Element type of `TranscriptionsVerboseResponse.segments`.
    # NOTE(review): the field names look like the OpenAI verbose
    # transcription segment schema — confirm against the API reference
    # before relying on units or ranges.
    id: int
    seek: float
    start: float
    end: float
    text: str
    tokens: list[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float


class TranscriptionsVerboseResponse(BaseModel):
    # Declared as the response_model of POST /v1/audio/transcriptions,
    # which is registered in `setup`.
    task: str
    language: str
    duration: float
    text: str
    segments: list[TranscribeSegment]


def setup(app: APIManager):
    """Register the OpenAI-compatible audio routes on ``app``.

    Adds ``POST /v1/audio/speech`` (served by ``openai_speech_api``) and
    ``POST /v1/audio/transcriptions`` (currently a placeholder).
    """
    # Route description shown in the API docs; kept verbatim.
    speech_description = """
openai api document: 
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)

以下属性为本系统自定义属性,不在openai文档中:
- batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
- spliter_threshold: 开启batch合成时,句子分割的阈值
- style: 风格

> model 可填任意值
        """
    register_speech = app.post(
        "/v1/audio/speech",
        description=speech_description,
    )
    register_speech(openai_speech_api)

    @app.post(
        "/v1/audio/transcriptions",
        response_model=TranscriptionsVerboseResponse,
        description="Transcribes audio into the input language.",
    )
    async def transcribe(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: Optional[str] = Form(None),
        prompt: Optional[str] = Form(None),
        response_format: str = Form("json"),
        temperature: float = Form(0),
        timestamp_granularities: List[str] = Form(["segment"]),
    ):
        # Placeholder endpoint: none of the form fields are read yet.
        # NOTE(review): the placeholder payload may not satisfy the declared
        # response_model — confirm what success_response returns.
        # TODO: Implement transcribe
        return api_utils.success_response("not implemented yet")