import torch.nn as nn

from .net_utils import (
    PosEnSine,
    dotproduct_attention,
    long_range_attention,
    patch_attention,
    short_range_attention,
    softmax_attention,
)


class OurMultiheadAttention(nn.Module):
    def __init__(self, q_feat_dim, k_feat_dim, out_feat_dim, n_head, d_k=None, d_v=None):
        super(OurMultiheadAttention, self).__init__()
        if d_k is None:
            d_k = out_feat_dim // n_head
        if d_v is None:
            d_v = out_feat_dim // n_head

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # pre-attention projection
        self.w_qs = nn.Conv2d(q_feat_dim, n_head * d_k, 1, bias=False)
        self.w_ks = nn.Conv2d(k_feat_dim, n_head * d_k, 1, bias=False)
        self.w_vs = nn.Conv2d(out_feat_dim, n_head * d_v, 1, bias=False)

        # after-attention combine heads
        self.fc = nn.Conv2d(n_head * d_v, out_feat_dim, 1, bias=False)

    def forward(self, q, k, v, attn_type='softmax', **kwargs):
        # input: b x d x h x w
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        # Pass through the pre-attention projection: b x (nhead*dk) x h x w
        # Separate different heads: b x nhead x dk x h x w
        q = self.w_qs(q).view(q.shape[0], n_head, d_k, q.shape[2], q.shape[3])
        k = self.w_ks(k).view(k.shape[0], n_head, d_k, k.shape[2], k.shape[3])
        v = self.w_vs(v).view(v.shape[0], n_head, d_v, v.shape[2], v.shape[3])

        # -------------- Attention -----------------
        if attn_type == 'softmax':
            q, attn = softmax_attention(q, k, v)    # b x n x dk x h x w --> b x n x dv x h x w
        elif attn_type == 'dotproduct':
            q, attn = dotproduct_attention(q, k, v)
        elif attn_type == 'patch':
            q, attn = patch_attention(q, k, v, P=kwargs['P'])
        elif attn_type == 'sparse_long':
            q, attn = long_range_attention(q, k, v, P_h=kwargs['ah'], P_w=kwargs['aw'])
        elif attn_type == 'sparse_short':
            q, attn = short_range_attention(q, k, v, Q_h=kwargs['ah'], Q_w=kwargs['aw'])
        else:
            raise NotImplementedError(f'Unknown attention type {attn_type}')
        # ------------ end Attention ---------------

        # Concatenate all the heads together: b x (n*dv) x h x w
        q = q.reshape(q.shape[0], -1, q.shape[3], q.shape[4])
        q = self.fc(q)    # b x d x h x w

        return q, attn


class TransformerEncoderUnit(nn.Module):
    def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
        super(TransformerEncoderUnit, self).__init__()
        self.feat_dim = feat_dim
        self.attn_type = attn_type
        self.pos_en_flag = pos_en_flag
        self.P = P

        self.pos_en = PosEnSine(self.feat_dim // 2)
        self.attn = OurMultiheadAttention(feat_dim, n_head)

        self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.activation = nn.ReLU(inplace=True)

        self.norm1 = nn.BatchNorm2d(self.feat_dim)
        self.norm2 = nn.BatchNorm2d(self.feat_dim)

    def forward(self, src):
        if self.pos_en_flag:
            pos_embed = self.pos_en(src)
        else:
            pos_embed = 0

        # multi-head attention
        src2 = self.attn(
            q=src + pos_embed, k=src + pos_embed, v=src, attn_type=self.attn_type, P=self.P
        )[0]
        src = src + src2
        src = self.norm1(src)

        # feed forward
        src2 = self.linear2(self.activation(self.linear1(src)))
        src = src + src2
        src = self.norm2(src)

        return src


class TransformerEncoderUnitSparse(nn.Module):
    def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None):
        super(TransformerEncoderUnitSparse, self).__init__()
        self.feat_dim = feat_dim
        self.pos_en_flag = pos_en_flag
        self.ahw = ahw    # [Ph, Pw, Qh, Qw]

        self.pos_en = PosEnSine(self.feat_dim // 2)
        self.attn1 = OurMultiheadAttention(feat_dim, n_head)    # long range
        self.attn2 = OurMultiheadAttention(feat_dim, n_head)    # short range

        self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.activation = nn.ReLU(inplace=True)

        self.norm1 = nn.BatchNorm2d(self.feat_dim)
        self.norm2 = nn.BatchNorm2d(self.feat_dim)

    def forward(self, src):
        if self.pos_en_flag:
            pos_embed = self.pos_en(src)
        else:
            pos_embed = 0

        # multi-head long-range attention
        src2 = self.attn1(
            q=src + pos_embed,
            k=src + pos_embed,
            v=src,
            attn_type='sparse_long',
            ah=self.ahw[0],
            aw=self.ahw[1]
        )[0]
        src = src + src2    # ? this might be ok to remove

        # multi-head short-range attention
        src2 = self.attn2(
            q=src + pos_embed,
            k=src + pos_embed,
            v=src,
            attn_type='sparse_short',
            ah=self.ahw[2],
            aw=self.ahw[3]
        )[0]
        src = src + src2
        src = self.norm1(src)

        # feed forward
        src2 = self.linear2(self.activation(self.linear1(src)))
        src = src + src2
        src = self.norm2(src)

        return src


class TransformerDecoderUnit(nn.Module):
    def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
        super(TransformerDecoderUnit, self).__init__()
        self.feat_dim = feat_dim
        self.attn_type = attn_type
        self.pos_en_flag = pos_en_flag
        self.P = P

        self.pos_en = PosEnSine(self.feat_dim // 2)
        self.attn1 = OurMultiheadAttention(feat_dim, n_head)    # self-attention
        self.attn2 = OurMultiheadAttention(feat_dim, n_head)    # cross-attention

        self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.activation = nn.ReLU(inplace=True)

        self.norm1 = nn.BatchNorm2d(self.feat_dim)
        self.norm2 = nn.BatchNorm2d(self.feat_dim)
        self.norm3 = nn.BatchNorm2d(self.feat_dim)

    def forward(self, tgt, src):
        if self.pos_en_flag:
            src_pos_embed = self.pos_en(src)
            tgt_pos_embed = self.pos_en(tgt)
        else:
            src_pos_embed = 0
            tgt_pos_embed = 0

        # self-multi-head attention
        tgt2 = self.attn1(
            q=tgt + tgt_pos_embed, k=tgt + tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P
        )[0]
        tgt = tgt + tgt2
        tgt = self.norm1(tgt)

        # cross-multi-head attention
        tgt2 = self.attn2(
            q=tgt + tgt_pos_embed, k=src + src_pos_embed, v=src, attn_type=self.attn_type, P=self.P
        )[0]
        tgt = tgt + tgt2
        tgt = self.norm2(tgt)

        # feed forward
        tgt2 = self.linear2(self.activation(self.linear1(tgt)))
        tgt = tgt + tgt2
        tgt = self.norm3(tgt)

        return tgt


class TransformerDecoderUnitSparse(nn.Module):
    def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None):
        super(TransformerDecoderUnitSparse, self).__init__()
        self.feat_dim = feat_dim
        self.ahw = ahw    # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, Qw_tgt]
        self.pos_en_flag = pos_en_flag

        self.pos_en = PosEnSine(self.feat_dim // 2)
        self.attn1_1 = OurMultiheadAttention(feat_dim, n_head)    # self-attention: long
        self.attn1_2 = OurMultiheadAttention(feat_dim, n_head)    # self-attention: short

        self.attn2_1 = OurMultiheadAttention(
            feat_dim, n_head
        )    # cross-attention: self-attention-long + cross-attention-short
        self.attn2_2 = OurMultiheadAttention(feat_dim, n_head)

        self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.activation = nn.ReLU(inplace=True)

        self.norm1 = nn.BatchNorm2d(self.feat_dim)
        self.norm2 = nn.BatchNorm2d(self.feat_dim)
        self.norm3 = nn.BatchNorm2d(self.feat_dim)

    def forward(self, tgt, src):
        if self.pos_en_flag:
            src_pos_embed = self.pos_en(src)
            tgt_pos_embed = self.pos_en(tgt)
        else:
            src_pos_embed = 0
            tgt_pos_embed = 0

        # self-multi-head attention: sparse long
        tgt2 = self.attn1_1(
            q=tgt + tgt_pos_embed,
            k=tgt + tgt_pos_embed,
            v=tgt,
            attn_type='sparse_long',
            ah=self.ahw[0],
            aw=self.ahw[1]
        )[0]
        tgt = tgt + tgt2
        # self-multi-head attention: sparse short
        tgt2 = self.attn1_2(
            q=tgt + tgt_pos_embed,
            k=tgt + tgt_pos_embed,
            v=tgt,
            attn_type='sparse_short',
            ah=self.ahw[2],
            aw=self.ahw[3]
        )[0]
        tgt = tgt + tgt2
        tgt = self.norm1(tgt)

        # self-multi-head attention: sparse long
        src2 = self.attn2_1(
            q=src + src_pos_embed,
            k=src + src_pos_embed,
            v=src,
            attn_type='sparse_long',
            ah=self.ahw[4],
            aw=self.ahw[5]
        )[0]
        src = src + src2
        # cross-multi-head attention: sparse short
        tgt2 = self.attn2_2(
            q=tgt + tgt_pos_embed,
            k=src + src_pos_embed,
            v=src,
            attn_type='sparse_short',
            ah=self.ahw[6],
            aw=self.ahw[7]
        )[0]
        tgt = tgt + tgt2
        tgt = self.norm2(tgt)

        # feed forward
        tgt2 = self.linear2(self.activation(self.linear1(tgt)))
        tgt = tgt + tgt2
        tgt = self.norm3(tgt)

        return tgt