import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import models
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
from .configuration_bytegpt import ByteGPTConfig

try:
    from flash_attn.flash_attention import FlashAttention
except ImportError:
    # flash_attn is an optional dependency; fall back to standard attention.
    FLASH_ATTENTION_AVAILABLE = False
else:
    # Importable, but only usable when a CUDA device is present.
    FLASH_ATTENTION_AVAILABLE = torch.cuda.is_available()


class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        head_size (int): Output feature size of this head.
        n_embd (int): The embedding dimension of the input.
        block_size (int): Maximum sequence length; sizes the causal mask.
        dropout (float): Dropout rate applied to the attention weights.
        use_flash_attention (bool): Whether to use Flash Attention. Only
            honoured when ``FLASH_ATTENTION_AVAILABLE`` is true.

    Attributes:
        key (nn.Linear): The linear layer for computing the keys.
        query (nn.Linear): The linear layer for computing the queries.
        value (nn.Linear): The linear layer for computing the values.
        tril (torch.Tensor): Lower-triangular (block_size, block_size) causal
            mask; only present on the standard-attention path. Registered as a
            non-persistent buffer, so it follows ``.to()``/``.cuda()``/dtype
            casts but is not part of the state dict.
        dropout (nn.Dropout): The dropout layer (standard-attention path).
        use_flash_attention (bool): Whether Flash Attention is actually used.
        flash_attention (FlashAttention): The FlashAttention module (flash path only).
    """

    def __init__(
        self,
        head_size: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Only enable flash attention if it was requested AND is available
        # (importable and running on CUDA).
        self.use_flash_attention = use_flash_attention and FLASH_ATTENTION_AVAILABLE
        if self.use_flash_attention:
            print("Using Flash Attention")
            # The fused kernel applies attention dropout itself.
            self.flash_attention = FlashAttention(attention_dropout=dropout)
        else:
            if use_flash_attention:
                print(
                    "Flash Attention requested but not available. Using standard attention."
                )
            # A buffer (rather than a plain attribute) moves with the module on
            # .to()/.cuda(); persistent=False keeps it out of the state dict so
            # existing checkpoints still load unchanged.
            self.register_buffer(
                "tril", torch.tril(torch.ones(block_size, block_size)), persistent=False
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform forward pass through the attention head.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, n_embd).
                sequence_length must not exceed ``block_size``.

        Returns:
            torch.Tensor: Output of shape (batch_size, sequence_length, head_size).
        """
        T = x.shape[1]
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        if self.use_flash_attention:
            # flash_attn's FlashAttention takes one packed tensor of shape
            # (B, T, 3, n_heads, head_size) and returns (output, None), with
            # output shaped (B, T, n_heads, head_size). This module is a single
            # head, so n_heads == 1. causal=True is required for a causal LM.
            # NOTE(review): flash_attn only accepts fp16/bf16 CUDA tensors.
            qkv = torch.stack((q, k, v), dim=2).unsqueeze(3)  # (B, T, 3, 1, hs)
            out = self.flash_attention(qkv, causal=True)[0]  # (B, T, 1, hs)
            return out.squeeze(2)

        # Standard scaled dot-product attention with a causal mask.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v  # (B, T, head_size)


class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel.

    Args:
        num_heads (int): The number of heads.
        head_size (int): The size of each head.
        n_embd (int): The embedding dimension (input and output width).
        block_size (int): The block size (maximum sequence length).
        dropout (float): The dropout rate.
        use_flash_attention (bool): Whether to use Flash Attention.

    Attributes:
        heads (nn.ModuleList): The list of attention heads.
        proj (nn.Linear): Projects the concatenated heads
            (num_heads * head_size features) back to n_embd.
        dropout (nn.Dropout): The dropout layer.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.heads = nn.ModuleList(
            [
                Head(
                    head_size,
                    n_embd,
                    block_size,
                    dropout,
                    use_flash_attention=use_flash_attention,
                )
                for _ in range(num_heads)
            ]
        )
        # The heads are concatenated along the feature axis, so the projection
        # input width is num_heads * head_size. This equals n_embd when n_embd
        # is divisible by num_heads, and stays correct when it is not.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform forward pass through the multi-head attention layer.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, n_embd).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, sequence_length, n_embd).
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x, apply ReLU, project back, then dropout.

    Args:
        n_embd (int): Input and output feature size.
        dropout (float): Dropout probability applied to the output.

    Attributes:
        net (nn.Sequential): Linear(n_embd, 4*n_embd) -> ReLU ->
            Linear(4*n_embd, n_embd) -> Dropout.
    """

    def __init__(self, n_embd: int, dropout: float) -> None:
        super().__init__()
        hidden_dim = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP independently at every position.

        Args:
            x (torch.Tensor): Tensor whose last dimension is n_embd, e.g.
                (batch_size, sequence_length, n_embd).

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        result = self.net(x)
        return result


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention followed by a feed-forward MLP.

    Each sublayer is applied to a LayerNorm of its input and added back
    through a residual connection.

    Args:
        n_embd (int): The embedding dimension.
        n_head (int): The number of attention heads.
        block_size (int): The block size (maximum sequence length).
        dropout (float): The dropout rate.
        use_flash_attention (bool): Whether to use Flash Attention.

    Attributes:
        sa (MultiHeadAttention): The multi-head attention layer.
        ffwd (FeedForward): The feedforward layer.
        ln1 (nn.LayerNorm): The layer normalization applied before attention.
        ln2 (nn.LayerNorm): The layer normalization applied before the feedforward layer.
        use_flash_attention (bool): Whether Flash Attention is actually in use
            (requested AND available).
    """

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        block_size: int,
        dropout: float,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(
            n_head,
            head_size,
            n_embd,
            block_size,
            dropout,
            use_flash_attention=use_flash_attention,
        )
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        # Mirrors the decision made inside each attention head. The heads
        # already print whether Flash Attention is used, so nothing is
        # printed here.
        self.use_flash_attention = use_flash_attention and FLASH_ATTENTION_AVAILABLE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform forward pass through the transformer block.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, n_embd).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, sequence_length, n_embd).
        """
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class ByteGPTForCausalLM(PreTrainedModel):
    """GPT-style decoder-only language model exposed as a Hugging Face ``PreTrainedModel``.

    The network is: token embedding + learned position embedding, a stack of
    ``config.n_layer`` transformer ``Block``s, a final LayerNorm, and a linear
    head producing ``config.vocab_size`` logits per position.

    Attributes:
        block_size (int): Maximum context length; also the number of learned
            position embeddings.
        token_embedding_table (nn.Embedding): Token id -> n_embd vector.
        position_embedding_table (nn.Embedding): Position index -> n_embd vector.
        blocks (nn.Sequential): The transformer blocks.
        ln_f (nn.LayerNorm): Final layer normalization.
        lm_head (nn.Linear): Projection from n_embd to vocabulary logits.
    """

    config_class = ByteGPTConfig

    def __init__(
        self,
        config: ByteGPTConfig,
    ):
        super().__init__(config)
        self.block_size = config.block_size
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(
            *[
                Block(
                    config.n_embd,
                    config.n_head,
                    config.block_size,
                    config.dropout,
                    config.use_flash_attention,
                )
                for _ in range(config.n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
        return_dict: bool = True,
        labels: "torch.Tensor | None" = None,
        **kwargs,
    ) -> "CausalLMOutput | tuple[torch.Tensor, torch.Tensor | None]":
        """Forward pass of the model.

        Args:
            input_ids: Token ids of shape (B, T), with T <= block_size.
            attention_mask: Accepted for Hugging Face API compatibility but
                ignored; no padding mask is applied (attention is causal only).
            return_dict: If True, return a ``CausalLMOutput``; otherwise return
                the tuple ``(logits, loss)``.
            labels: Optional target ids of shape (B, T). They are compared
                position-by-position with ``logits``; no shift is applied here.
                Entries equal to -100 are ignored (``F.cross_entropy`` default).
            **kwargs: Ignored; absorbs extra arguments passed by Hugging Face
                utilities such as ``generate()``.

        Returns:
            ``CausalLMOutput(logits, loss)`` or ``(logits, loss)``, where logits
            always have shape (B, T, vocab_size) and loss is None when
            ``labels`` is None.

        Raises:
            ValueError: If T exceeds ``block_size`` (there is no position
                embedding for later positions).
        """
        _, T = input_ids.shape
        if T > self.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.block_size}; "
                "crop the input to the last block_size tokens."
            )

        # Token and position embeddings
        tok_emb = self.token_embedding_table(input_ids)  # (B, T, C)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=input_ids.device)
        )  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C), position embeddings broadcast over batch

        # Transformer blocks
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)  # (B, T, C)

        # Language model head
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if labels is not None:
            # Flatten into temporaries only, so the returned logits keep their
            # (B, T, vocab_size) shape whether or not labels are given.
            # reshape (not view) also accepts non-contiguous labels.
            # NOTE(review): labels are NOT shifted here. Presumably the data
            # pipeline passes labels already aligned to next-token targets;
            # with the common HF convention (labels == input_ids) this would
            # train the model to copy its input. Confirm against the trainer.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )

        if not return_dict:
            return (logits, loss)

        return CausalLMOutput(logits=logits, loss=loss)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the model inputs for one step of Hugging Face ``generate()``.

        Only the last ``block_size`` tokens are fed to the model, because the
        learned position embeddings cover exactly ``block_size`` positions.
        Any attention mask passed in ``kwargs`` is ignored; an all-ones mask
        matching the cropped ids is returned instead.
        """
        input_ids = input_ids[:, -self.block_size :]
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
        }