'''
Adapted from
https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
'''


import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from typing import Callable, Iterable, Iterator

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ReduceOp
from SAE.dataset_iterator import ActivationsDataloader
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
from SAE.sae_utils import SAETrainingConfig, Config

from types import SimpleNamespace
from typing import Optional, List
import json

import tqdm

def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:

    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore

    return (norms * weights).sum()


def compute_geometric_median(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    :param points: ``torch.Tensor`` of shape ``(n, d)``
    :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
    :param eps: Smallest allowed value of denominator, to avoid divide by zero.
        Equivalently, this is a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100
    :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
    :param do_log: If true will return a log of function values encountered through the course of the algorithm
    :return: SimpleNamespace object with fields
        - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
        - `termination`: string explaining how the algorithm terminated.
        - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
    """
    with torch.no_grad():

        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)
        # initialize median estimate at mean
        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        if do_log:
            logs = [objective_value]
        else:
            logs = None

        # Weiszfeld iterations
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value

            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)

            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break

            pbar.set_description(f"Objective value: {objective_value:.4f}")

    median = weighted_average(points, new_weights)  # allow autodiff to track it
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )

def maybe_transpose(x):
    return x.T if not x.is_contiguous() and x.T.is_contiguous() else x

import wandb

RANK = 0

class Logger:
    def __init__(self, sae_name, **kws):
        self.vals = {}
        self.enabled = (RANK == 0) and not kws.pop("dummy", False)
        self.sae_name = sae_name

    def logkv(self, k, v):
        if self.enabled:
            self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
        return v

    def dumpkvs(self, step):
        if self.enabled:
            wandb.log(self.vals, step=step)
            self.vals = {}
    

class FeaturesStats:
    def __init__(self, dim, logger):
        self.dim = dim
        self.logger = logger
        self.reinit()

    def reinit(self):
        self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
        self.n = 0
    
    def update(self, inds):
        self.n += inds.shape[0]
        inds = inds.flatten().detach()
        self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))

    def log(self):
        self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())

def training_loop_(
    aes, 
    train_acts_iter, 
    loss_fn, 
    log_interval, 
    save_interval,
    loggers,
    sae_cfgs,
):
    sae_packs = []
    for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
        pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
        fstats = FeaturesStats(ae.n_dirs, logger)
        opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
        sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
    
    for i, flat_acts_train_batch in enumerate(train_acts_iter):
        flat_acts_train_batch = flat_acts_train_batch.cuda()

        for ae, cfg, logger, pbar, fstats, opt in sae_packs:
            recons, info = ae(flat_acts_train_batch)
            loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)

            fstats.update(info['inds'])
            
            bs = flat_acts_train_batch.shape[0]
            logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())

            logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
            logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())

            if (i + 1) % log_interval == 0:
                fstats.log()
                fstats.reinit()
            
            if (i + 1) % save_interval == 0:
                ae.save_to_disk(f"{cfg.save_path}/{i + 1}")

            loss.backward()

            unit_norm_decoder_(ae)
            unit_norm_decoder_grad_adjustment_(ae)

            opt.step()
            opt.zero_grad()
            logger.dumpkvs(i)

            pbar.set_description(f"Training Loss {loss.item():.4f}")
            pbar.update(1)


    for ae, cfg, logger, pbar, fstats, opt in sae_packs:
        pbar.close()
        ae.save_to_disk(f"{cfg.save_path}/final")


def init_from_data_(ae, stats_acts_sample):
    ae.pre_bias.data = (
        compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
    )


def mse(recons, x):
    # return ((recons - x) ** 2).sum(dim=-1).mean()
    return ((recons - x) ** 2).mean()

def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # only used for auxk
    xs_mu = xs.mean(dim=0)

    loss = mse(recon, xs) / mse(
        xs_mu[None, :].broadcast_to(xs.shape), xs
    )

    return loss

def explained_variance(recons, x):
    # Compute the variance of the difference
    diff = x - recons
    diff_var = torch.var(diff, dim=0, unbiased=False)

    # Compute the variance of the original tensor
    x_var = torch.var(x, dim=0, unbiased=False)

    # Avoid division by zero
    explained_var = 1 - diff_var / (x_var + 1e-8)

    return explained_var.mean()


def main():
    cfg = Config(json.load(open('SAE/config.json')))

    dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)

    acts_iter = dataloader.iterate()
    stats_acts_sample = torch.cat([
        next(acts_iter).cpu() for _ in range(10)
    ], dim=0)

    aes = [
        SparseAutoencoder(
            n_dirs_local=sae.n_dirs,
            d_model=sae.d_model,
            k=sae.k,
            auxk=sae.auxk,
            dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
        ).cuda()
        for sae in cfg.saes
    ]
    
    for ae in aes:
        init_from_data_(ae, stats_acts_sample)
    
    mse_scale = (
        1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
    )
    mse_scale = mse_scale.item()
    del stats_acts_sample

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
    )

    loggers = [Logger(
        sae_name=cfg_sae.sae_name,
        dummy=False,
    ) for cfg_sae in cfg.saes]

    training_loop_(
        aes,
        acts_iter,
        lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
            # MSE
            logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
            # AuxK
            + logger.logkv(
                "train_maxk_recons",
                cfg_sae.auxk_coef
                * normalized_mse(
                    ae.decode_sparse(
                        info["auxk_inds"],
                        info["auxk_vals"],
                    ),
                    flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
                ).nan_to_num(0),
            )
        ),
        sae_cfgs = cfg.saes,
        loggers=loggers,
        log_interval=cfg.log_interval,
        save_interval=cfg.save_interval,
    )


if __name__ == "__main__":
    main()