import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from utils.pooling import HomogeneousAggregator


class RelationalTransformer(nn.Module):
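    """Relational transformer over a graph representation of the input.

    A Hydra-instantiated graph constructor produces node and edge features,
    which are refined by a stack of ``RTLayer`` blocks. A graph-level embedding
    is then read from a CLS token or pooled with ``HomogeneousAggregator`` and
    projected to ``d_out`` by a small MLP head.
    """
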
    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        d_out_hid,
        d_out,
        n_layers,
        n_heads,
        layer_layout,
        graph_constructor,
        dropout=0.0,
        node_update_type="rt",
        disable_edge_updates=False,
        use_cls_token=False,
        pooling_method="cat",
        pooling_layer_idx="last",
        rev_edge_features=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
    ):
        super().__init__()
        assert use_cls_token == (
            pooling_method == "cls_token"
        ), "use_cls_token must be set if and only if pooling_method == 'cls_token'"
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.construct_graph = hydra.utils.instantiate(
            graph_constructor,
            d_node=d_node,
            d_edge=d_edge,
            layer_layout=layer_layout,
            rev_edge_features=rev_edge_features,
        )

        self.use_cls_token = use_cls_token
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.randn(d_node))

        self.layers = nn.ModuleList(
            [
                torch.jit.script(
                    RTLayer(
                        d_node,
                        d_edge,
                        d_attn_hid,
                        d_node_hid,
                        d_edge_hid,
                        n_heads,
                        dropout,
                        node_update_type=node_update_type,
                        disable_edge_updates=(
                            (disable_edge_updates or (i == n_layers - 1))
                            and pooling_method != "mean_edge"
                            and pooling_layer_idx != "all"
                        ),
                        modulate_v=modulate_v,
                        use_ln=use_ln,
                        tfixit_init=tfixit_init,
                        n_layers=n_layers,
                    )
                )
                for i in range(n_layers)
            ]
        )

        if pooling_method != "cls_token":
            self.pool = HomogeneousAggregator(
                pooling_method,
                pooling_layer_idx,
                layer_layout,
            )

        if pooling_method == "cat" and pooling_layer_idx == "last":
            self.num_graph_features = layer_layout[-1] * d_node
        elif pooling_method in ("mean_edge", "max_edge"):
            self.num_graph_features = d_edge
        else:
            self.num_graph_features = d_node
        self.proj_out = nn.Sequential(
            nn.Linear(self.num_graph_features, d_out_hid),
            nn.ReLU(),
            # nn.Linear(d_out_hid, d_out_hid),
            # nn.ReLU(),
            nn.Linear(d_out_hid, d_out),
        )

        self.final_features = (None, None, None, None)

    def forward(self, inputs):
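        """Build the graph, run the RT layers, pool, and project.

        The pooled graph features, final node/edge features, and the attention
        weights of the last layer are cached in ``self.final_features``.
        """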
        attn_weights = None
        node_features, edge_features, mask = self.construct_graph(inputs)
        if self.use_cls_token:
            node_features = torch.cat(
                [
                    # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
                    self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
                    node_features,
                ],
                dim=1,
            )
            edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
        for layer in self.layers:
            node_features, edge_features, attn_weights = layer(
                node_features, edge_features, mask
            )

        if self.pooling_method == "cls_token":
            graph_features = node_features[:, 0]
        else:
            graph_features = self.pool(node_features, edge_features)
        self.final_features = (graph_features, node_features, edge_features, attn_weights)
        return self.proj_out(graph_features)


class RTLayer(nn.Module):
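    """A single relational transformer layer.

    Nodes are updated with edge-conditioned multi-head self-attention followed
    by an MLP, each with a residual connection and optional LayerNorm; edges
    are updated by an ``EdgeLayer`` unless edge updates are disabled.
    """
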
    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        n_heads,
        dropout,
        node_update_type="rt",
        disable_edge_updates=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
        n_layers=None,
    ):
        super().__init__()
        self.node_update_type = node_update_type
        self.disable_edge_updates = disable_edge_updates
        self.use_ln = use_ln
        self.n_layers = n_layers

        self.self_attn = torch.jit.script(
            RTAttention(
                d_node,
                d_edge,
                d_attn_hid,
                n_heads,
                modulate_v=modulate_v,
                use_ln=use_ln,
            )
        )
        # self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
        self.lin0 = Linear(d_node, d_node)
        self.dropout0 = nn.Dropout(dropout)
        if use_ln:
            self.node_ln0 = nn.LayerNorm(d_node)
            self.node_ln1 = nn.LayerNorm(d_node)
        else:
            self.node_ln0 = nn.Identity()
            self.node_ln1 = nn.Identity()

        act_fn = nn.GELU

        self.node_mlp = nn.Sequential(
            Linear(d_node, d_node_hid, bias=False),
            act_fn(),
            Linear(d_node_hid, d_node),
            nn.Dropout(dropout),
        )

        if not self.disable_edge_updates:
            self.edge_updates = EdgeLayer(
                d_node=d_node,
                d_edge=d_edge,
                d_edge_hid=d_edge_hid,
                dropout=dropout,
                act_fn=act_fn,
                use_ln=use_ln,
            )
        else:
            self.edge_updates = NoEdgeLayer()

        if tfixit_init:
            self.fixit_init()

    def fixit_init(self):
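        """Rescale weights with a T-Fixup-style scheme.

        MLP and attention weights are scaled by ``0.67 * n_layers ** (-1 / 4)``,
        with an extra ``sqrt(2)`` factor applied to the attention weights.
        """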
        temp_state_dict = self.state_dict()
        n_layers = self.n_layers
        for name, param in self.named_parameters():
            if "weight" in name:
                if name.split(".")[0] in ["node_mlp", "edge_mlp0", "edge_mlp1"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
                elif name.split(".")[0] in ["self_attn"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
                        param * (2**0.5)
                    )

        self.load_state_dict(temp_state_dict)

    def node_updates(self, node_features, edge_features, mask):
        attn_out, attn_weights = self.self_attn(node_features, edge_features, mask)
        node_features = self.node_ln0(
            node_features + self.dropout0(self.lin0(attn_out))
        )
        node_features = self.node_ln1(node_features + self.node_mlp(node_features))

        return node_features, attn_weights

    def forward(self, node_features, edge_features, mask):
        node_features, attn_weights = self.node_updates(node_features, edge_features, mask)
        edge_features = self.edge_updates(node_features, edge_features, mask)

        return node_features, edge_features, attn_weights


class EdgeLayer(nn.Module):
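    """Edge update block: a node-conditioned ``EdgeMLP`` followed by a
    feed-forward MLP, each with a residual connection and optional LayerNorm."""
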
    def __init__(
        self,
        *,
        d_node,
        d_edge,
        d_edge_hid,
        dropout,
        act_fn,
        use_ln=True,
    ) -> None:
        super().__init__()
        self.edge_mlp0 = EdgeMLP(
            d_edge=d_edge,
            d_node=d_node,
            d_edge_hid=d_edge_hid,
            act_fn=act_fn,
            dropout=dropout,
        )
        self.edge_mlp1 = nn.Sequential(
            Linear(d_edge, d_edge_hid, bias=False),
            act_fn(),
            Linear(d_edge_hid, d_edge),
            nn.Dropout(dropout),
        )
        if use_ln:
            self.eln0 = nn.LayerNorm(d_edge)
            self.eln1 = nn.LayerNorm(d_edge)
        else:
            self.eln0 = nn.Identity()
            self.eln1 = nn.Identity()

    def forward(self, node_features, edge_features, mask):
        edge_features = self.eln0(
            edge_features + self.edge_mlp0(node_features, edge_features)
        )
        edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
        return edge_features


class NoEdgeLayer(nn.Module):
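    """No-op edge update; passes the edge features through unchanged."""
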
    def forward(self, node_features, edge_features, mask):
        return edge_features


class EdgeMLP(nn.Module):
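    """Updates each edge from its own and reversed features together with
    projections of its source and target node features."""
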
    def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
        super().__init__()
        self.reverse_edge = Rearrange("b n m d -> b m n d")
        self.lin0_e = Linear(2 * d_edge, d_edge_hid)
        self.lin0_s = Linear(d_node, d_edge_hid)
        self.lin0_t = Linear(d_node, d_edge_hid)
        self.act = act_fn()
        self.lin1 = Linear(d_edge_hid, d_edge)
        self.drop = nn.Dropout(dropout)

    def forward(self, node_features, edge_features):
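        """Project source/target node features, broadcast them over the (n, m)
        edge grid, and combine them with the forward and reversed edge features."""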
        source_nodes = (
            self.lin0_s(node_features)
            .unsqueeze(-2)
            .expand(-1, -1, node_features.size(-2), -1)
        )
        target_nodes = (
            self.lin0_t(node_features)
            .unsqueeze(-3)
            .expand(-1, node_features.size(-2), -1, -1)
        )

        # reversed_edge_features = self.reverse_edge(edge_features)
        edge_features = self.lin0_e(
            torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
        )
        edge_features = edge_features + source_nodes + target_nodes
        edge_features = self.act(edge_features)
        edge_features = self.lin1(edge_features)
        edge_features = self.drop(edge_features)

        return edge_features


class RTAttention(nn.Module):
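    """Multi-head self-attention with edge conditioning.

    Edge features additively bias the per-pair queries, keys, and values; when
    ``modulate_v`` is set, an extra edge projection also scales the values
    elementwise.
    """
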
    def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
        super().__init__()
        self.n_heads = n_heads
        self.d_node = d_node
        self.d_edge = d_edge
        self.d_hid = d_hid
        self.use_ln = use_ln
        self.modulate_v = modulate_v
        self.scale = 1 / (d_hid**0.5)
        self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
        self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
        self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)

        self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
        self.edge_factor = 4 if modulate_v else 3
        self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
        self.proj_out = Linear(d_hid, d_node)

    def forward(self, node_features, edge_features, mask):
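        """Compute per-pair scores from edge-conditioned queries and keys, then
        aggregate edge-conditioned values over the neighbor axis.

        Returns the projected attention output and the attention weights.
        """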
        qkv_node = self.qkv_node(node_features)
        # qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
        qkv_node = self.split_head_node(qkv_node)
        q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)

        qkv_edge = self.qkv_edge(edge_features)
        # qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
        qkv_edge = self.split_head_edge(qkv_edge)
        qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
        # q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
        #     qkv_edge, 6, dim=-1
        # )
        # qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]

        q = q_node.unsqueeze(-2) + qkv_edge[0]  # + q_edge_b
        k = k_node.unsqueeze(-3) + qkv_edge[1]  # + k_edge_b
        if self.modulate_v:
            v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
        else:
            v = v_node.unsqueeze(-3) + qkv_edge[2]
        dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
        # dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)

        attn = F.softmax(dots, dim=-1)
        out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
        out = self.cat_head_node(out)
        return self.proj_out(out), attn


def Linear(in_features, out_features, bias=True):
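    """``nn.Linear`` with Xavier-uniform weight init and zero-initialized bias."""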
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)  # , gain=1 / math.sqrt(2))
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m