import os
import logging

from modules.api.api_setup import (
    process_api_args,
    process_model_args,
    setup_api_args,
    setup_model_args,
)
from modules.ffmpeg_env import setup_ffmpeg_path
from modules.utils.env import get_and_update_env
from modules.api.app_config import app_description, app_title, app_version
from modules.utils.torch_opt import configure_torch_optimizations

# Set up the ffmpeg path (modules.ffmpeg_env) before the remaining project
# modules are imported below.
# NOTE(review): presumably some of those modules need ffmpeg to be resolvable
# at import time — confirm.
setup_ffmpeg_path()
# Configure root logging (level from LOG_LEVEL, default "INFO") before the
# webui modules below are imported, presumably so that records emitted while
# importing them already use this format.
# NOTE(review): this runs at import time, before the `__main__` block loads
# ENV_FILE (default ".env.webui"), so only a LOG_LEVEL already present in the
# process environment is seen here.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

from modules.webui import webui_config
from modules import config
from modules.webui.app import webui_init, create_interface
import argparse
from modules.gradio_dcls_fix import dcls_patch

# Apply the project's gradio patch (modules.gradio_dcls_fix) at import time,
# i.e. before webui_init()/create_interface() run in process_webui_args.
# NOTE(review): judging by the module name this fixes a gradio dataclass
# issue — confirm.
dcls_patch()


def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the WebUI command-line options on *parser*.

    None of the options declares an explicit default, so value options parse
    to ``None`` and flag options to ``False`` when omitted. The real fallback
    values are applied later by ``process_webui_args``.

    Args:
        parser: The argument parser to extend in place.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        ("--server_name", dict(type=str, help="server name")),
        ("--server_port", dict(type=int, help="server port")),
        ("--share", dict(action="store_true", help="share the gradio interface")),
        ("--debug", dict(action="store_true", help="enable debug mode")),
        (
            "--auth",
            dict(type=str, help="username:password for authentication"),
        ),
        ("--tts_max_len", dict(type=int, help="Max length of text for TTS")),
        ("--ssml_max_len", dict(type=int, help="Max length of text for SSML")),
        ("--max_batch_size", dict(type=int, help="Max batch size for TTS")),
        # Experimental WebUI features.
        (
            "--webui_experimental",
            dict(action="store_true", help="Enable webui_experimental features"),
        ),
        (
            "--language",
            dict(type=str, help="Set the default language for the webui"),
        ),
        (
            "--api",
            dict(
                action="store_true",
                help="use api=True to launch the API together with the webui (run launch.py for only API server)",
            ),
        ),
    )

    for flag, options in option_specs:
        parser.add_argument(flag, **options)


def _parse_auth(auth):
    """Convert an ``auth`` setting into the credentials tuple passed to gradio.

    Args:
        auth: A ``"username:password"`` string, or a falsy value for no auth.

    Returns:
        ``(username, password)``, or ``None`` when *auth* is empty/``None``.

    Raises:
        ValueError: if *auth* is non-empty but contains no ``:`` separator.
    """
    if not auth:
        return None
    # Split on the first colon only, so passwords may themselves contain ':'.
    username, sep, password = auth.partition(":")
    if not sep:
        raise ValueError("auth must be in the form username:password")
    return username, password


def _disable_gradio_cors(app):
    """Remove gradio's permissive CORS middleware from *app*.

    gradio uses a very open CORS policy via app.user_middleware, which makes it
    possible for an attacker to trick the user into opening a malicious HTML
    page, which makes a request to the running web ui and do whatever the
    attacker wants, including installing an extension and running its code.
    We disable this here. Suggested by RyotaK.

    Starlette builds the middleware stack from ``user_middleware`` lazily and
    caches it in ``middleware_stack``. gradio's launch may already have served
    a request, so the cache is reset to force a rebuild without the CORS
    middleware on the next request.
    """
    app.user_middleware = [
        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
    ]
    app.middleware_stack = None


def process_webui_args(args):
    """Resolve WebUI settings, launch the gradio interface and block until exit.

    Each setting is read through ``get_and_update_env`` with the fallback
    default shown inline. The interface is launched with
    ``prevent_thread_lock=True`` so that, before the main thread blocks on
    ``demo.block_thread()``, gradio's CORS middleware can be stripped and, when
    ``api`` is enabled, the API can be attached via ``process_api_args``.

    Args:
        args: Parsed argument namespace from a parser configured with
            ``setup_webui_args`` (plus the model/API argument groups).

    Raises:
        ValueError: if ``auth`` is set but not of the form
            ``username:password``.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    # The value is not used in this function; the call is kept for its
    # env-updating side effect.
    # NOTE(review): presumably the webui picks the language up from there —
    # confirm.
    get_and_update_env(args, "language", "zh-CN", str)
    api = get_and_update_env(args, "api", False, bool)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    # Validate credentials before the expensive initialisation below.
    gradio_auth = _parse_auth(auth)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    # Interactive API docs are only exposed when the API is mounted and docs
    # are not disabled via runtime_env_vars.no_docs.
    docs_enabled = bool(api) and not config.runtime_env_vars.no_docs

    app, _local_url, _share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=gradio_auth,
        show_api=False,
        prevent_thread_lock=True,
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": "/redoc" if docs_enabled else None,
            "docs_url": "/docs" if docs_enabled else None,
        },
    )
    _disable_gradio_cors(app)

    if api:
        process_api_args(args, app)

    demo.block_thread()


if __name__ == "__main__":
    import dotenv

    # Load settings from ENV_FILE (default ".env.webui"). By default
    # load_dotenv does not override variables already set in the environment.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )
    # The module-level logging.basicConfig ran at import time, before the env
    # file was loaded. Re-apply LOG_LEVEL so a value set only in that file
    # takes effect.
    logging.getLogger().setLevel(os.getenv("LOG_LEVEL", "INFO"))

    parser = argparse.ArgumentParser(description="Gradio App")

    # Option groups: WebUI, model, then API.
    setup_webui_args(parser)
    setup_model_args(parser)
    setup_api_args(parser)

    args = parser.parse_args()

    process_model_args(args)
    process_webui_args(args)