# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import einops
import torch
from jaxtyping import Float
from typeguard import typechecked


@torch.no_grad()
@typechecked
def get_contributions(
    parts: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> torch.Tensor:
    """
    Compute contributions of the `parts` vectors into the `whole` vector.

    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:               v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n

    Here
    * `p_1 ... p_k`: dimensions for enumerating the parts (k may be 0, in which case
      `parts` is a single part with exactly the shape of `whole`),
    * `v_1 ... v_n`: dimensions listing the independent cases (batching),
    * `d` is the dimension to compute the distances on.

    The unnormalized score of a part is `max(||whole|| - ||whole - part||, EPS)`,
    with both norms taken over `d` using the `distance_norm`-norm. The resulting
    contributions are normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.
    """
    EPS = 1e-5

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0, "`parts` must have at least as many dimensions as `whole`"
    assert parts.shape[k:] == whole.shape, "trailing dims of `parts` must match `whole`"
    bc_whole = whole.expand(parts.shape)  # new dims p_1 ... p_k are added to the front

    distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)

    whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
    # A part scores higher the closer it is to `whole`. Parts that are no closer than
    # the zero vector get the floor EPS, which also keeps the normalizer positive.
    proximity = (whole_norm - distance).clip(min=EPS)

    if k == 0:
        # torch reductions treat `dim=()` as "reduce over all dims", which would mix
        # the independent cases. With a single part, each case normalizes by itself.
        total = proximity
    else:
        total = proximity.sum(dim=tuple(range(k)), keepdim=True)

    return proximity / total


@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
    parts: torch.Tensor,
    one_off: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Same as computing the contributions, but there is one additional part. That's useful
    because we always have the residual stream as one of the parts.

    See `get_contributions` documentation about `parts` and `whole` dimensions. The
    `one_off` should have the same dimensions as `whole`.

    Returns a pair consisting of
    1. contributions tensor for the `parts`, shaped `p_1 ... p_k, v_1 ... v_n`
    2. contributions tensor for the `one_off` vector, shaped `v_1 ... v_n`
    Both are normalized jointly: for each v_, the contributions of all parts plus the
    one_off contribution sum up to 1.
    """
    assert one_off.shape == whole.shape
    k = len(parts.shape) - len(whole.shape)
    assert k >= 0
    assert parts.shape[k:] == whole.shape

    # Collapse the p_ dimensions into one leading dimension and append `one_off` as the
    # last part. Unlike `flatten(0, k - 1)`, `reshape` also handles k == 0: the single
    # part just gets a leading dimension of size 1.
    flat = torch.cat([parts.reshape(-1, *whole.shape), one_off.unsqueeze(0)])
    contributions = get_contributions(flat, whole, distance_norm)
    return (
        # Restore the original p_ dimensions for the regular parts.
        contributions[:-1].reshape(parts.shape[:k] + contributions.shape[1:]),
        contributions[-1],
    )


@torch.no_grad()
@typechecked
def get_attention_contributions(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
    distance_norm: int = 1,
) -> Tuple[
    Float[torch.Tensor, "batch pos key_pos head"],
    Float[torch.Tensor, "batch pos"],
]:
    """
    Split each attention-layer output representation into contributions.

    `resid_mid` is treated as the whole. Its parts are the per-(key_pos, head) vectors
    of `decomposed_attn`, plus `resid_pre` as the one-off residual part.

    Returns a pair of
    - contributions of each token via each head, shaped (batch, pos, key_pos, head),
    - the contribution of the residual stream, shaped (batch, pos).
    """
    # The helper wants part dimensions first:
    #   parts: key_pos, head | batching: batch, pos | vector: d_model
    head_parts = einops.rearrange(
        decomposed_attn,
        "batch pos key_pos head d_model -> key_pos head batch pos d_model",
    )
    per_head, from_residual = get_contributions_with_one_off_part(
        parts=head_parts,
        one_off=resid_pre,
        whole=resid_mid,
        distance_norm=distance_norm,
    )
    # Move the batching dimensions back in front of the part dimensions.
    per_head = einops.rearrange(
        per_head, "key_pos head batch pos -> batch pos key_pos head"
    )
    return per_head, from_residual


@torch.no_grad()
@typechecked
def get_mlp_contributions(
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    resid_post: Float[torch.Tensor, "batch pos d_model"],
    mlp_out: Float[torch.Tensor, "batch pos d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
    """
    Split each MLP output representation into (mlp, residual) contributions.

    `resid_post` is the whole. `mlp_out` and `resid_mid` are its two parts, stacked
    along a new leading part dimension. Both returned tensors are shaped (batch, pos).
    """
    two_parts = torch.stack((mlp_out, resid_mid))  # part dim of size 2 in front
    mlp_share, residual_share = get_contributions(two_parts, resid_post, distance_norm)
    return mlp_share, residual_share


@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
    resid_mid: Float[torch.Tensor, "d_model"],
    resid_post: Float[torch.Tensor, "d_model"],
    decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
    """
    Per-neuron variant of `get_mlp_contributions` for a single representation.

    `resid_post` is the whole. Its parts are the rows of `decomposed_mlp_out` (one
    MLP output vector per hidden neuron), plus `resid_mid` as the one-off residual
    part.

    Returns the neuron contributions, shaped (hidden,), and the residual contribution
    as a Python float.

    Batch and token dimensions are omitted to save memory; they may be added later.
    """
    per_neuron, residual_share = get_contributions_with_one_off_part(
        parts=decomposed_mlp_out,
        one_off=resid_mid,
        whole=resid_post,
        distance_norm=distance_norm,
    )
    return per_neuron, float(residual_share)


@torch.no_grad()
def apply_threshold_and_renormalize(
    threshold: float,
    c_blocks: torch.Tensor,
    c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Thresholding mechanism used in the original graphs paper. Contributions that are
    not strictly greater than `threshold` are zeroed. The remaining contributions are
    renormalized in order to sum up to 1 for each representation.

    threshold: The threshold.
    c_blocks: Contributions of the blocks. Could be 1 block per representation, like
        ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
        must be a prefix of the shape of this tensor. The remaining dimensions are for
        listing the blocks; there may be none, when there is exactly one block per
        representation.
    c_residual: Contribution of the residual stream for each representation. This
        tensor should contain 1 element per representation, i.e., its dimensions are
        all batch dimensions.

    Returns the thresholded and renormalized `(c_blocks, c_residual)`, with their
    shapes unchanged. A representation whose contributions were all cut off gets
    zeros everywhere instead of NaN.
    """

    block_dims = len(c_blocks.shape)
    resid_dims = len(c_residual.shape)
    bound_dims = block_dims - resid_dims
    assert bound_dims >= 0
    assert c_blocks.shape[0:resid_dims] == c_residual.shape

    c_blocks = c_blocks * (c_blocks > threshold)
    c_residual = c_residual * (c_residual > threshold)

    if bound_dims > 0:
        blocks_total = c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
    else:
        # torch treats `dim=()` as "reduce over all dims", which would add up the
        # blocks of every representation. With one block each, no sum is needed.
        blocks_total = c_blocks
    denom = c_residual + blocks_total
    # If everything was cut off for a representation, its numerators are all zero.
    # Dividing by 1 then yields zeros rather than NaN from 0/0.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    return (
        c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
        c_residual / denom,
    )