import os
import sys

import torch as t
import torch.distributed as dist
from tqdm import tqdm

# Make the repo root importable so `utils` resolves when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from utils.torch_utils import parse_args


args = parse_args()
# args.gpu may be parsed as an int or a string; format it into the device id.
mydevice = t.device(f'cuda:{args.gpu}')

def def_tqdm(x):
    # Progress bar that writes to stdout and keeps the bar after completion.
    return tqdm(x, leave=True, file=sys.stdout,
                bar_format="{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")

def get_range(x):
    # Only rank 0 shows a progress bar; other ranks iterate silently.
    if dist.get_rank() == 0:
        return def_tqdm(x)
    return x

def init_logging(hps, local_rank, rank):
    logdir = f"{hps.local_logdir}/{hps.name}"
    if local_rank == 0:
        os.makedirs(logdir, exist_ok=True)
        # Record the command line for this run inside the log directory.
        with open(os.path.join(logdir, 'argv.txt'), 'w') as f:
            f.write(hps.argv + '\n')
        print("Logging to", logdir)
    logger = Logger(logdir, rank)
    metrics = Metrics()
    logger.add_text('hps', str(hps))
    return logger, metrics

def get_name(hps):
    name = ""
    for key, value in hps.items():
        name += f"{key}_{value}_"
    return name
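
# A worked example (hypothetical hps dict); note the trailing underscore:
#   get_name({'lr': 0.0003, 'bs': 32})  ->  'lr_0.0003_bs_32_'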

def average_metrics(_metrics):
    # Collect values per key across a list of metric dicts, then mean them.
    metrics = {}
    for _metric in _metrics:
        for key, val in _metric.items():
            metrics.setdefault(key, []).append(val)
    return {key: sum(vals) / len(vals) for key, vals in metrics.items()}
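
# A worked example with hypothetical per-worker metric dicts:
#   average_metrics([{'loss': 1.0}, {'loss': 3.0}])  ->  {'loss': 2.0}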

class Metrics:
    def __init__(self):
        self.sum = {}
        self.n = {}

    def update(self, tag, val, batch):
        # `val` is the average value over the local batch; accumulate the
        # global running total and count, and return the distributed average
        # for this step.
        total = t.tensor(val * batch).float().to(mydevice)
        n = t.tensor(batch).float().to(mydevice)
        dist.all_reduce(total)
        dist.all_reduce(n)
        total = total.item()
        n = n.item()
        self.sum[tag] = self.sum.get(tag, 0.0) + total
        self.n[tag] = self.n.get(tag, 0.0) + n
        return total / n

    def avg(self, tag):
        if tag in self.sum:
            return self.sum[tag] / self.n[tag]
        else:
            return 0.0

    def reset(self):
        self.sum = {}
        self.n = {}
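
# A minimal usage sketch, assuming torch.distributed is initialized and
# `train_step` is a hypothetical helper returning (mean loss, batch size):
#
#   metrics = Metrics()
#   for x in loader:
#       loss, bs = train_step(x)
#       step_avg = metrics.update('loss', loss, bs)   # all-reduced average
#   epoch_avg = metrics.avg('loss')
#   metrics.reset()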

class Logger:
    # Rank 0 owns the SummaryWriter; every logging method below no-ops on
    # other ranks, so callers can log unconditionally from any worker.
    def __init__(self, logdir, rank):
        if rank == 0:
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")
        self.iters = 0
        self.rank = rank
        self.works = []
        self.logdir = logdir

    def step(self):
        self.iters += 1

    def flush(self):
        if self.rank == 0:
            self.sw.flush()

    def add_text(self, tag, text):
        if self.rank == 0:
            self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        if self.rank == 0:
            for i in range(min(len(auds), max_log)):
                if max_len:
                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
                else:
                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank == 0:
            self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank == 0:
            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank == 0:
            self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank == 0:
            self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        if self.rank == 0:
            self.trange = def_tqdm(loader)
        else:
            self.trange = loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank == 0:
            self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank == 0:
            self.trange.set_postfix(*args, **kwargs)

    # For logging summaries of various graph ops. Kicks off an async reduce
    # here; results are collected later in finish_reduce().
    def add_reduce_scalar(self, tag, layer, val):
        if self.iters % 100 == 0:
            with t.no_grad():
                # Summarize the tensor as its L2 norm per element.
                val = val.float().norm() / float(val.numel())
            work = dist.reduce(val, 0, async_op=True)
            self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        for tag, layer, val, work in self.works:
            work.wait()
            if self.rank == 0:
                # reduce() sums across ranks; divide to get the mean.
                val = val.item() / dist.get_world_size()
                self.sw.add_scalar(f"{layer}/{tag}", val, self.iters)
        self.works = []
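
# A minimal end-to-end sketch, assuming a one-process "distributed" run, a GPU
# available for `mydevice`, and a hypothetical `hps` namespace with the fields
# init_logging() reads (local_logdir, name, argv):
#
#   dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
#                           rank=0, world_size=1)
#   logger, metrics = init_logging(hps, local_rank=0, rank=0)
#   for i, batch in logger.get_range(loader):
#       loss = train_step(batch)                     # hypothetical helper
#       avg = metrics.update('loss', loss, batch=32)
#       logger.add_scalar('train/loss', avg)
#       logger.step()
#   logger.close_range()
#   logger.flush()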