import asyncio
import logging
import os
import time
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Callable

import discord
from discord import app_commands
from discord.ext import commands
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import hf_hub_download
from gradio_client import Client

# Root logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class LoopMode(str, Enum):
    """How playback continues after a track ends.

    Members subclass ``str`` so they compare equal to their raw values; the
    member order is also the order offered as slash-command choices.
    """

    NONE = "none"  # finished track is dropped
    SINGLE = "single"  # finished track is replayed
    QUEUE = "queue"  # finished track goes to the back of the queue


@dataclass
class BotConfig:
    """Static settings for the bot: asset locations, remote sources and limits."""

    ASSETS_DIR: str = "assets"  # local folder holding the bundled track
    SONG_FILE: str = "lofi.mp3"  # bundled track's filename inside ASSETS_DIR
    HF_REPO: str = "not-lain/assets"  # HF dataset repo the track is fetched from
    HF_SPACE: str = "https://not-lain-ytdlp.hf.space/"  # Gradio Space used to resolve URLs
    MAX_RETRY_ATTEMPTS: int = 3
    MAX_CACHE_SIZE: int = 100  # Maximum number of songs to cache

    @property
    def song_path(self) -> str:
        """Relative path of the bundled track, always joined with '/'."""
        return "/".join((self.ASSETS_DIR, self.SONG_FILE))


@dataclass
class QueueItem:
    """One song request; title and file path are filled in once it is resolved."""

    url: str  # source URL (or local path for the bundled track)
    title: Optional[str] = None  # display title, None until resolved
    file_path: Optional[str] = None  # local audio file, None until downloaded


class VoiceStateError(Exception):
    """Raised when the invoking user's voice state does not allow the command.

    The message is meant to be shown to the user verbatim.
    """


class MusicBot:
    """State and playback logic for the voice session this bot manages.

    Holds the voice connection, the current song and pending queue, the loop
    mode, and a cache of already-resolved songs keyed by URL.
    """

    def __init__(self, config: BotConfig):
        self.config = config
        # Set by play_next *before* its first await (including the download)
        # and cleared by the playback `after` callback, so concurrent callers
        # of play_next back off instead of starting a second track.
        self.is_playing: bool = False
        self.voice_client: Optional[discord.VoiceClient] = None
        self.last_context: Optional[commands.Context] = None
        self.loop_mode: LoopMode = LoopMode.NONE
        self.current_song: Optional[QueueItem] = None
        self.queue: list[QueueItem] = []
        # Gradio Space client used by download_song to resolve URLs.
        self.hf_client = Client(config.HF_SPACE, hf_token=None)
        # url -> (title, file_path). Dicts keep insertion order, so the first
        # keys are the oldest entries; cache hits do not refresh the order.
        self.song_cache: dict[str, tuple[str, str]] = {}

    def _manage_cache(self) -> None:
        """Evict the oldest cache entries until at most MAX_CACHE_SIZE remain.

        Only the in-memory mapping is pruned; files are not deleted here.
        """
        excess = len(self.song_cache) - self.config.MAX_CACHE_SIZE
        if excess <= 0:
            return
        # Materialise the keys first: deleting while iterating a dict view raises.
        for url in list(self.song_cache)[:excess]:
            del self.song_cache[url]

    async def ensure_voice_state(self, ctx: commands.Context) -> None:
        """Raise VoiceStateError unless ctx.author may control the voice session.

        Raises:
            VoiceStateError: the author is not in a voice channel, or is in a
                different channel than the bot.
        """
        # `voice` only exists on guild members; in DMs ctx.author is a plain
        # User, which previously surfaced as an AttributeError.
        author_voice = getattr(ctx.author, "voice", None)
        if not author_voice:
            raise VoiceStateError("You need to be in a voice channel!")

        if self.voice_client and author_voice.channel != self.voice_client.channel:
            raise VoiceStateError("You must be in the same voice channel as the bot!")

    async def download_song(self, queue_item: QueueItem) -> None:
        """Resolve ``queue_item.url`` and fill in its ``title`` and ``file_path``.

        Results are memoised in ``song_cache``; a cache hit skips the remote call.

        Raises:
            Exception: anything raised while calling the Space is logged and
                re-raised.
        """
        try:
            cached = self.song_cache.get(queue_item.url)
            if cached is not None:
                queue_item.title, queue_item.file_path = cached
                logger.info(f"Found song in cache: {queue_item.title}")
                return

            job = self.hf_client.submit(url=queue_item.url, api_name="/predict")
            # Poll cooperatively: a blocking time.sleep() here would freeze the
            # whole event loop (gateway heartbeats, other commands, song-end
            # callbacks) for the entire download.
            while not job.done():
                await asyncio.sleep(0.1)
            title, file_path = job.outputs()[0]

            self.song_cache[queue_item.url] = (title, file_path)
            self._manage_cache()

            queue_item.title = title
            queue_item.file_path = file_path
            logger.info(f"Downloaded and cached new song: {title}")
        except Exception as e:
            logger.error(f"Error downloading song: {e}")
            raise

    async def play_next(self, ctx: commands.Context) -> None:
        """Play ``current_song``, or the next queued song if there is none.

        Does nothing if something is already playing or nothing is left to
        play. Songs without a ``file_path`` are downloaded first.

        Raises:
            Exception: download/playback errors are logged and re-raised; a
                song whose download failed is cleared from ``current_song``.
        """
        if self.is_playing:
            return

        if not self.current_song:
            if not self.queue:
                return
            self.current_song = self.queue.pop(0)
        # Work on a local reference: other commands may reassign
        # self.current_song while we are awaiting the download.
        song = self.current_song

        # Claim the player before the first await so that a command or song-end
        # callback running during the download returns early above.
        self.is_playing = True
        try:
            if not song.file_path:
                await self.download_song(song)

            if self.voice_client is None:
                # The session was torn down (e.g. stop_playing) mid-download.
                self.is_playing = False
                return

            audio_source = discord.FFmpegPCMAudio(song.file_path)
            # The `after` callback is invoked outside the event loop (hence
            # run_coroutine_threadsafe), so capture the running loop now
            # instead of depending on a module-level global.
            loop = asyncio.get_running_loop()

            def after_playing(error):
                if error:
                    logger.error(f"Error in playback: {error}")
                self.is_playing = False
                asyncio.run_coroutine_threadsafe(self.handle_song_end(ctx), loop)

            self.voice_client.play(audio_source, after=after_playing)

        except Exception as e:
            logger.error(f"Error playing file: {e}")
            self.is_playing = False
            if self.current_song is song and not song.file_path:
                # Don't leave an unplayable track as the current song, or every
                # later play_next would keep tripping over it.
                self.current_song = None
            raise

    async def handle_song_end(self, ctx: commands.Context) -> None:
        """Apply the loop mode after a track ends, then start the next track.

        NONE drops the finished song, QUEUE moves it to the back of the queue,
        and SINGLE keeps it as ``current_song`` so it plays again.
        """
        if self.loop_mode == LoopMode.NONE:
            self.current_song = None
        elif self.loop_mode == LoopMode.QUEUE and self.current_song:
            self.queue.append(self.current_song)
            self.current_song = None

        if not self.is_playing:
            await self.play_next(ctx)

    async def join_voice(self, ctx: commands.Context) -> None:
        """Connect to the author's voice channel, or move there if already connected."""
        if not ctx.author.voice:
            await ctx.send("You need to be in a voice channel!")
            return

        channel = ctx.author.voice.channel
        if self.voice_client is None:
            self.voice_client = await channel.connect()
            self.last_context = ctx
        else:
            await self.voice_client.move_to(channel)

    async def stop_playing(self, ctx: commands.Context) -> bool:
        """Stop playback, disconnect from voice and reset all session state.

        Returns:
            True if a voice client existed and was shut down cleanly; False if
            there was none or cleanup raised (state is reset in both cases).
        """
        voice_client = self.voice_client
        if not voice_client:
            return False

        # Reset first: stopping playback may fire the `after` callback set up in
        # play_next, which schedules handle_song_end. With the queue already
        # cleared, it has nothing to start on the connection being closed.
        self._reset_state()
        try:
            if voice_client.is_playing():
                voice_client.stop()
            if voice_client.is_connected():
                await voice_client.disconnect(force=False)
            return True
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")
            return False

    def add_to_queue(self, url: str) -> int:
        """Append ``url`` to the queue and return its 1-based position."""
        queue_item = QueueItem(url=url)
        self.queue.append(queue_item)
        return len(self.queue)

    def _reset_state(self) -> None:
        """Forget the voice session, current song, queue and loop mode."""
        self.is_playing = False
        self.voice_client = None
        self.last_context = None
        self.loop_mode = LoopMode.NONE
        self.current_song = None
        self.queue.clear()
        # Note: We don't clear the cache when resetting state


async def handle_voice_command(
    interaction: discord.Interaction, action: Callable, defer: bool = True
) -> None:
    """Run a voice-related slash command with shared validation and error replies.

    Args:
        interaction: the slash-command interaction being handled.
        action: ``async (ctx, interaction) -> None`` doing the command's work.
        defer: defer the interaction response first, so ``action`` must answer
            through ``interaction.followup``.

    VoiceStateError messages are sent to the user as-is. Any other exception
    is logged with its traceback and answered with a generic error message.
    """

    async def reply(message: str) -> None:
        # Once the response was deferred/sent, only followups are allowed.
        if interaction.response.is_done():
            await interaction.followup.send(message)
        else:
            await interaction.response.send_message(message)

    try:
        if defer:
            await interaction.response.defer()
        ctx = await bot.get_context(interaction)
        await music_bot.ensure_voice_state(ctx)
        await action(ctx, interaction)
    except VoiceStateError as e:
        await reply(str(e))
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Command error: {e}")
        await reply("An error occurred!")


# Module-level singletons shared by every slash-command handler below.
config = BotConfig()
music_bot = MusicBot(config)

# NOTE(review): message_content is requested, but no prefix commands are
# defined in this file; confirm the intent is actually needed.
intents = discord.Intents.default()
intents.message_content = True
bot = commands.Bot(command_prefix="!", intents=intents)


@bot.event
async def on_ready():
    """Log the login and sync the slash-command tree globally.

    Uses the module logger, consistent with the rest of the file, instead of
    print().
    """
    # NOTE(review): discord.py may dispatch on_ready again after reconnects,
    # which would re-sync each time; consider syncing once (e.g. setup_hook).
    logger.info(f"Bot is ready! Logged in as {bot.user}")
    logger.info("Syncing commands...")
    try:
        await bot.tree.sync(guild=None)  # guild=None -> global sync
    except discord.app_commands.errors.CommandSyncFailure as e:
        logger.error(f"Failed to sync commands: {e}")
    except Exception as e:
        logger.exception(f"An error occurred while syncing commands: {e}")
    else:
        logger.info("Successfully synced commands globally!")


@bot.tree.command(name="lofi", description="Play lofi music")
async def lofi(interaction: discord.Interaction):
    """Join the caller's voice channel and play the bundled lofi track."""

    async def play_lofi(ctx, interaction: discord.Interaction):
        await music_bot.join_voice(ctx)
        if music_bot.is_playing:
            # Check before touching current_song: overwriting it mid-playback
            # made loop/queue handling act on the lofi item instead of the
            # track that is actually playing.
            await interaction.followup.send("Already playing!")
            return

        music_bot.current_song = QueueItem(
            url=config.song_path, title="Lofi Music", file_path=config.song_path
        )
        await music_bot.play_next(ctx)
        await interaction.followup.send("Playing lofi music! 🎵")

    await handle_voice_command(interaction, play_lofi)


@bot.tree.command(name="play", description="Play a youtube song")
async def play(interaction: discord.Interaction, url: str):
    """Queue ``url``; start playback right away if the player is idle."""

    async def play_song(ctx, interaction: discord.Interaction):
        await music_bot.join_voice(ctx)

        # Decide before enqueueing whether this request has to wait its turn.
        busy = music_bot.is_playing or bool(music_bot.queue)
        position = music_bot.add_to_queue(url)

        if busy:
            await interaction.followup.send(
                f"Added song to queue at position {position}! 🎵"
            )
            return

        await music_bot.play_next(ctx)
        song = music_bot.current_song
        message = (
            f"Playing {song.title}! 🎵" if song and song.title else "Playing song! 🎵"
        )
        await interaction.followup.send(message)

    await handle_voice_command(interaction, play_song)


@bot.tree.command(name="skip", description="Skip the current song")
async def skip(interaction: discord.Interaction):
    """Stop the current track so the player moves on to what comes next.

    Only a member in the same voice channel as the bot may skip.
    """
    # `voice` only exists on guild members; in DMs interaction.user is a plain
    # User and the old direct attribute access raised AttributeError.
    user_voice = getattr(interaction.user, "voice", None)
    if not user_voice:
        await interaction.response.send_message(
            "You must be in a voice channel to use this command!"
        )
        return

    voice_client = music_bot.voice_client
    if not voice_client:
        await interaction.response.send_message("No song is currently playing!")
        return

    if user_voice.channel != voice_client.channel:
        await interaction.response.send_message(
            "You must be in the same voice channel as the bot!"
        )
        return

    if not music_bot.is_playing:
        await interaction.response.send_message("No song is currently playing!")
        return

    # Clear the flag right away; advancing to the next song is left to the
    # playback `after` callback registered in MusicBot.play_next.
    music_bot.is_playing = False
    voice_client.stop()
    await interaction.response.send_message("Skipped current song!")


@bot.tree.command(name="leave", description="Disconnect bot from voice channel")
async def leave(interaction: discord.Interaction):
    """Stop playback, clear the session state and disconnect from voice.

    Only a member in the same voice channel as the bot may use it.
    """
    # `voice` only exists on guild members; in DMs interaction.user is a plain
    # User and the old direct attribute access raised AttributeError.
    user_voice = getattr(interaction.user, "voice", None)
    if not user_voice:
        await interaction.response.send_message(
            "You must be in a voice channel to use this command!"
        )
        return

    if not music_bot.voice_client:
        await interaction.response.send_message("I'm not in any voice channel!")
        return

    if user_voice.channel != music_bot.voice_client.channel:
        await interaction.response.send_message(
            "You must be in the same voice channel as the bot!"
        )
        return

    # Disconnecting can take a while; defer so the interaction doesn't time out.
    await interaction.response.defer()
    ctx = await bot.get_context(interaction)

    try:
        success = await music_bot.stop_playing(ctx)
        if success:
            await interaction.followup.send("Successfully disconnected! 👋")
        else:
            await interaction.followup.send(
                "Failed to disconnect properly. Please try again."
            )
    except Exception as e:
        logger.exception(f"Error during leave command: {e}")
        await interaction.followup.send("An error occurred while trying to disconnect.")


@bot.tree.command(name="loop", description="Set loop mode")
@app_commands.choices(
    mode=[
        app_commands.Choice(name=choice.value, value=choice.value)
        for choice in LoopMode
    ]
)
async def loop(interaction: discord.Interaction, mode: str):
    """Switch the player's loop mode to one of the LoopMode values."""
    try:
        new_mode = LoopMode(mode)
    except ValueError:
        await interaction.response.send_message("Invalid loop mode!")
        return

    music_bot.loop_mode = new_mode
    await interaction.response.send_message(f"Loop mode set to: {mode}")


@bot.tree.command(name="queue", description="Show current queue")
async def queue(interaction: discord.Interaction):
    """Reply with the song now playing followed by the numbered pending queue.

    Discord rejects message content over 2000 characters, so a long queue is
    cut short and summarised with an "...and N more" line instead of failing.
    """
    if not music_bot.queue and not music_bot.current_song:
        await interaction.response.send_message("Queue is empty!")
        return

    queue_list = []
    if music_bot.current_song:
        status = "🎵 Now playing"
        title = music_bot.current_song.title or "Loading..."
        queue_list.append(f"{status}: {title}")

    for i, item in enumerate(music_bot.queue, 1):
        title = item.title or "Loading..."
        queue_list.append(f"{i}. {title}")

    message = "\n".join(queue_list)
    message_limit = 2000
    if len(message) > message_limit:
        # Leave headroom for the summary line appended when we stop.
        budget = message_limit - 50
        shown = []
        used = 0
        for index, line in enumerate(queue_list):
            cost = len(line) + 1  # +1 for the joining newline
            if used + cost > budget:
                shown.append(f"...and {len(queue_list) - index} more")
                break
            shown.append(line)
            used += cost
        message = "\n".join(shown)

    await interaction.response.send_message(message)


def initialize_assets() -> None:
    """Make sure the bundled lofi track exists locally, downloading it if needed.

    Checks the file itself, not just the directory, so an existing ASSETS_DIR
    that lacks the track (e.g. after an interrupted download) is repaired.
    """
    if os.path.exists(config.song_path):
        return
    os.makedirs(config.ASSETS_DIR, exist_ok=True)
    hf_hub_download(
        config.HF_REPO,
        config.SONG_FILE,
        repo_type="dataset",
        local_dir=config.ASSETS_DIR,
    )


def run_discord_bot() -> None:
    """Run the Discord bot with the token from environment variables.

    Loads a ``.env`` file first, then blocks inside ``bot.run`` until the bot
    shuts down.

    Raises:
        RuntimeError: if DISCORD_TOKEN is unset or empty (previously this was
            passed straight to bot.run as None, failing with an opaque error).
    """
    load_dotenv()
    token = os.getenv("DISCORD_TOKEN")
    if not token:
        raise RuntimeError("DISCORD_TOKEN environment variable is not set")
    bot.run(token)


if __name__ == "__main__":
    # Fetch the bundled lofi track before the bot can be asked to play it.
    initialize_assets()

    # run_discord_bot blocks, so it gets its own thread. As a daemon thread it
    # ends when the main thread (running the Gradio app below) exits.
    bot_thread = threading.Thread(daemon=True, target=run_discord_bot)
    bot_thread.start()

    with gr.Blocks() as iface:
        for markdown_text in (
            "# Discord Music Bot Control Panel",
            "Bot is running in background",
        ):
            gr.Markdown(markdown_text)

    iface.launch(debug=True)