import logging
import os
import sys

from modules.ffmpeg_env import setup_ffmpeg_path

# Both startup steps below are best-effort: a failure is reported but must not
# stop the web UI from starting. Only ``Exception`` is caught so that
# KeyboardInterrupt / SystemExit still propagate, and each step has its own
# ``try`` so that a failure in one does not silently skip the other.
try:
    setup_ffmpeg_path()
except Exception:
    logging.getLogger(__name__).warning(
        "setup_ffmpeg_path() failed; continuing without it", exc_info=True
    )

# NOTE: loggers are initialized inside the modules, so this config must run
# before the remaining imports below.
try:
    logging.basicConfig(
        level=os.getenv("LOG_LEVEL", "INFO"),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
except Exception:
    # Most likely an unrecognised LOG_LEVEL value: fall back to the default
    # level instead of leaving the root logger at its built-in threshold.
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger(__name__).warning(
        "logging.basicConfig() failed (LOG_LEVEL=%r); falling back to INFO",
        os.getenv("LOG_LEVEL"),
        exc_info=True,
    )

import argparse

from modules import config
from modules.api.api_setup import process_api_args, setup_api_args
from modules.api.app_config import app_description, app_title, app_version
from modules.gradio_dcls_fix import dcls_patch
from modules.models_setup import process_model_args, setup_model_args
from modules.utils.env import get_and_update_env
from modules.utils.ignore_warn import ignore_useless_warnings
from modules.utils.torch_opt import configure_torch_optimizations
from modules.webui import webui_config
from modules.webui.app import create_interface, webui_init

# Both calls run once at import time, before any of the functions below are
# invoked (and therefore before the interface is created and launched).
# NOTE(review): presumably dcls_patch() monkey-patches gradio and
# ignore_useless_warnings() installs warning filters — confirm in
# modules.gradio_dcls_fix and modules.utils.ignore_warn.
dcls_patch()
ignore_useless_warnings()

import importlib.util
import subprocess


def _flash_attn_installed() -> bool:
    """Return True when the ``flash_attn`` package can already be located."""
    try:
        return importlib.util.find_spec("flash_attn") is not None
    except (ImportError, ValueError):
        # find_spec can raise for half-initialised modules; treat as missing.
        return False


def _install_flash_attn() -> None:
    """Best-effort ``pip install flash-attn --no-build-isolation``.

    The install runs with ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` added on top
    of the current environment. The variable must be *merged* into
    ``os.environ``: passing it alone as ``env`` would drop PATH, HOME, proxy
    and virtualenv settings for the child process. ``sys.executable -m pip``
    is used so the package lands in the interpreter that runs this script.
    Failures are logged and never raised, so the web UI still starts.
    """
    install_env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    command = [
        sys.executable or "python",
        "-m",
        "pip",
        "install",
        "flash-attn",
        "--no-build-isolation",
    ]
    log = logging.getLogger(__name__)
    try:
        result = subprocess.run(command, env=install_env, check=False)
    except OSError:
        log.warning("Could not start pip to install flash-attn", exc_info=True)
        return
    if result.returncode != 0:
        log.warning(
            "pip install flash-attn exited with code %s; continuing without it",
            result.returncode,
        )


# Skip the pip round-trip entirely when the package is already importable.
if not _flash_attn_installed():
    _install_flash_attn()


def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the web UI command-line options on *parser*.

    Options are declared in a table and added in order, so the generated
    ``--help`` output lists them in the same sequence as the table. No
    defaults are set here: value options parse to ``None`` and flags to
    ``False`` when omitted.
    """
    option_table = (
        ("--server_name", {"type": str, "help": "server name"}),
        ("--server_port", {"type": int, "help": "server port"}),
        (
            "--share",
            {"action": "store_true", "help": "share the gradio interface"},
        ),
        ("--debug", {"action": "store_true", "help": "enable debug mode"}),
        (
            "--auth",
            {"type": str, "help": "username:password for authentication"},
        ),
        ("--tts_max_len", {"type": int, "help": "Max length of text for TTS"}),
        ("--ssml_max_len", {"type": int, "help": "Max length of text for SSML"}),
        ("--max_batch_size", {"type": int, "help": "Max batch size for TTS"}),
        # experimental web UI features
        (
            "--webui_experimental",
            {
                "action": "store_true",
                "help": "Enable webui_experimental features",
            },
        ),
        (
            "--language",
            {"type": str, "help": "Set the default language for the webui"},
        ),
        (
            "--api",
            {
                "action": "store_true",
                "help": "use api=True to launch the API together with the webui (run launch.py for only API server)",
            },
        ),
    )

    for option_name, option_kwargs in option_table:
        parser.add_argument(option_name, **option_kwargs)


def process_webui_args(args):
    """Resolve the web UI settings, launch the Gradio app and block until it exits.

    Every setting is read through ``get_and_update_env(args, name, default, type)``.
    NOTE(review): presumably this resolves CLI value -> environment -> default
    and writes the result back to the environment; confirm in modules.utils.env.

    Args:
        args: Parsed ``argparse`` namespace containing the options registered
            by ``setup_webui_args`` (and, when ``--api`` is set, the API
            options consumed by ``process_api_args``).

    Raises:
        ValueError: If ``auth`` is set but is not of the form
            ``username:password``.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    language = get_and_update_env(args, "language", "zh-CN", str)
    api = get_and_update_env(args, "api", False, bool)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    if auth:
        # Split on the FIRST colon only so passwords may themselves contain
        # ":"; validate before the expensive init so a typo fails fast.
        username, separator, password = auth.partition(":")
        if not separator:
            raise ValueError("auth must be given as 'username:password'")
        auth = (username, password)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    # The OpenAPI pages are only exposed when the API is mounted and docs
    # have not been disabled via the runtime config.
    hide_docs = api is False or config.runtime_env_vars.no_docs

    app, local_url, share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
        # Do not block here: the API routes are attached to `app` below and
        # the thread is blocked explicitly at the end.
        prevent_thread_lock=True,
        inbrowser=sys.platform == "win32",
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": None if hide_docs else "/redoc",
            "docs_url": None if hide_docs else "/docs",
        },
    )
    # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
    # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
    # running web ui and do whatever the attacker wants, including installing an extension and
    # running its code. We disable this here. Suggested by RyotaK.
    app.user_middleware = [
        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
    ]

    if api:
        process_api_args(args, app)

    demo.block_thread()


if __name__ == "__main__":
    import dotenv

    # Load environment overrides first; the file path can be redirected
    # through the ENV_FILE variable and defaults to ".env.webui".
    env_file = os.getenv("ENV_FILE", ".env.webui")
    dotenv.load_dotenv(dotenv_path=env_file)

    parser = argparse.ArgumentParser(description="Gradio App")

    # Each registrar adds its own group of options to the shared parser.
    for register_args in (setup_webui_args, setup_model_args, setup_api_args):
        register_args(parser)

    args = parser.parse_args()

    # Model options are processed before the web UI is configured and launched.
    process_model_args(args)
    process_webui_args(args)