import os
import logging

# Configure root logging before the project imports below, so that any logging
# those modules do at import time already uses this level/format.
# LOG_LEVEL is normalised to upper case because logging only recognises the
# registered upper-case level names ("debug" would raise ValueError).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").strip().upper(),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

from modules.devices import devices
import argparse
import uvicorn

import torch
from modules import config
from modules.utils import env
from modules import generate_audio as generate
from modules.api.Api import APIManager

from modules.api.impl import (
    style_api,
    tts_api,
    ssml_api,
    google_api,
    openai_api,
    refiner_api,
    speaker_api,
    ping_api,
    models_api,
)

logger = logging.getLogger(__name__)

torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")


def create_api(app, no_docs=False, exclude=None):
    """Build an APIManager for *app* and register every API module on it.

    Args:
        app: Application the routes are registered on (a FastAPI instance
            when launched from this file).
        no_docs: Forwarded to APIManager as ``no_docs``.
        exclude: Path patterns forwarded to APIManager as
            ``exclude_patterns``. ``None`` (the default) means no exclusions.

    Returns:
        The configured APIManager.
    """
    # Use a fresh list per call instead of a shared mutable default, so no
    # state can leak between calls if the manager mutates what it receives.
    if exclude is None:
        exclude = []
    app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)

    # Register the API modules in a fixed order.
    for api_module in (
        ping_api,
        models_api,
        style_api,
        speaker_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    ):
        api_module.setup(app_mgr)

    return app_mgr


def get_and_update_env(*args):
    """Resolve a setting and record it in ``config.runtime_env_vars``.

    All positional arguments are forwarded unchanged to
    ``env.get_env_or_arg``. Callers in this file pass
    ``(args, name, default, type)``. The second argument (``args[1]``) is used
    as the key under which the resolved value is stored.

    Returns:
        The value produced by ``env.get_env_or_arg``.
    """
    resolved_value = env.get_env_or_arg(*args)
    setting_name = args[1]
    config.runtime_env_vars[setting_name] = resolved_value
    return resolved_value


def setup_model_args(parser: argparse.ArgumentParser):
    """Register the model/runtime command-line options on *parser*.

    Adds, in order: --compile, --half, --off_tqdm (boolean switches),
    --device_id, --use_cpu and --lru_size.
    """
    # Plain on/off switches that default to False.
    for flag, help_text in (
        ("--compile", "Enable model compile"),
        ("--half", "Enable half precision for model inference"),
        ("--off_tqdm", "Disable tqdm progress bar"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--device_id",
        type=str,
        default=None,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
    )
    # One or more module names; each value is lower-cased on parse.
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        type=str.lower,
        default=[],
        help="use CPU as torch device for specified modules",
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )


def setup_api_args(parser: argparse.ArgumentParser):
    """Register the HTTP server / API command-line options on *parser*.

    Adds, in order: --api_host, --api_port, --reload, --cors_origin,
    --no_playground, --no_docs and --exclude.
    """
    option_specs = [
        ("--api_host", dict(type=str, help="Host to run the server on")),
        ("--api_port", dict(type=int, help="Port to run the server on")),
        (
            "--reload",
            dict(action="store_true", help="Enable auto-reload for development"),
        ),
        (
            "--cors_origin",
            dict(
                type=str,
                help="Allowed CORS origins. Use '*' to allow all origins.",
            ),
        ),
        (
            "--no_playground",
            dict(action="store_true", help="Disable the playground entry"),
        ),
        (
            "--no_docs",
            dict(action="store_true", help="Disable the documentation entry"),
        ),
        # Which APIs to skip, as a comma-separated list of path patterns,
        # e.g. --exclude "/v1/speakers/*,/v1/tts/*"
        (
            "--exclude",
            dict(type=str, help="Exclude the specified API from the server"),
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)


def process_model_args(args):
    """Resolve the model/runtime settings, then initialise model-side state.

    Each setting is resolved through ``get_and_update_env``, which records it
    in ``config.runtime_env_vars``. Presumably the generation and device
    modules read their configuration from there; this function does not pass
    the values to them directly. Afterwards the LRU cache and device state are
    (re)initialised.
    """
    # Only the recording side effect is needed here. The values are not bound
    # to locals, which also avoids shadowing the builtin ``compile``.
    for name, default, converter in (
        ("lru_size", 64, int),
        ("compile", False, bool),
        ("device_id", None, str),
        ("use_cpu", [], list),
        ("half", False, bool),
        ("off_tqdm", False, bool),
    ):
        get_and_update_env(args, name, default, converter)

    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()


def process_api_args(args, app):
    """Resolve the API settings and build the API manager for *app*.

    Side effects: records the settings in ``config.runtime_env_vars``, stores
    the created manager in ``config.api``, applies CORS when an origin is set,
    and mounts the playground unless it is disabled.
    """
    cors_origin = get_and_update_env(args, "cors_origin", "*", str)
    no_playground = get_and_update_env(args, "no_playground", False, bool)
    no_docs = get_and_update_env(args, "no_docs", False, bool)
    exclude = get_and_update_env(args, "exclude", "", str)

    # Split the comma-separated pattern list and drop blank entries, so an
    # unset --exclude ("") yields no patterns rather than one empty pattern.
    exclude_patterns = [
        pattern.strip()
        for pattern in (exclude or "").split(",")
        if pattern.strip()
    ]

    api = create_api(app=app, no_docs=no_docs, exclude=exclude_patterns)
    config.api = api

    if cors_origin:
        api.set_cors(allow_origins=[cors_origin])

    if not no_playground:
        api.setup_playground()

    # Use the flag recorded by process_model_args. A bare name ``compile``
    # here would resolve to the builtin function, which is always truthy.
    if getattr(config.runtime_env_vars, "compile", False):
        logger.info("Model compile is enabled")


# Markdown passed to FastAPI as the OpenAPI description (shown on the docs
# page), together with the title and version below.
app_description = """
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich, long-form audio from text using an SSML-like syntax, and provides a comprehensive API service suitable for a wide range of scenarios.

项目地址 / Project repository: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)

> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
> None of the audio-generation POST APIs can be debugged on this page; please use the playground for debugging instead.

> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
> If you are new to this system, we recommend starting with this one-click script and trying it out in Colab:<br/>
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
"""
app_title = "ChatTTS Forge API"
app_version = "0.1.0"

if __name__ == "__main__":
    import dotenv
    from fastapi import FastAPI

    # Load the env file (".env.api" unless ENV_FILE names another) into
    # os.environ before any setting is resolved.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.api"),
    )

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    setup_api_args(parser)
    setup_model_args(parser)

    args = parser.parse_args()

    # FastAPI fixes docs_url/redoc_url at construction time. Within this file,
    # config.runtime_env_vars.no_docs is only populated later by
    # process_api_args, so resolve the flag here, before building the app.
    no_docs = get_and_update_env(args, "no_docs", False, bool)

    app = FastAPI(
        title=app_title,
        description=app_description,
        version=app_version,
        redoc_url=None if no_docs else "/redoc",
        docs_url=None if no_docs else "/docs",
    )

    process_model_args(args)
    process_api_args(args, app)

    host = get_and_update_env(args, "api_host", "0.0.0.0", str)
    port = get_and_update_env(args, "api_port", 7870, int)
    reload = get_and_update_env(args, "reload", False, bool)

    if reload:
        # uvicorn can only auto-reload an app given as an import string
        # ("module:attr"); with an app object it refuses to start. The app
        # here is built at runtime, so reload cannot be honoured.
        logger.warning(
            "--reload is not supported by this launcher and will be ignored"
        )
        reload = False

    uvicorn.run(app, host=host, port=port, reload=reload)