import timeit
import numpy as np
import os
import os.path as osp
import shutil
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
from .cfg_holder import cfg_unique_holder as cfguh
from . import sync

def print_log(*console_info):
    """Print a message on local rank 0 and append it to the log file on global rank 0.

    Arguments are converted with ``str`` and joined by single spaces. The log
    file path is read from ``cfg.train.log_file`` and, if that lookup fails,
    from ``cfg.eval.log_file``. Nothing is written to disk when neither lookup
    succeeds or when the path found is None.
    """
    grank, lrank, _ = sync.get_rank('all')
    if lrank != 0:
        return

    message = ' '.join(str(i) for i in console_info)
    print(message)

    if grank != 0:
        return

    for section in ('train', 'eval'):
        try:
            log_file = getattr(cfguh().cfg, section).log_file
        except Exception:
            # Config not loaded or this section/key is missing: try the next one.
            continue
        break
    else:
        # Neither section provides a log file; console output only.
        return

    if log_file is not None:
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(message + '\n')

class distributed_log_manager(object):
    """Accumulate sample-weighted training statistics and report them.

    Values are accumulated locally with :meth:`accumulate`.
    :meth:`get_mean_value_dict` reduces them across processes when DDP is
    active and returns the count-weighted mean of each item. If enabled in the
    train config, global rank 0 also writes the values to TensorBoard through
    ``tensorboardX``.
    """

    def __init__(self):
        # Running weighted sums and sample counts, keyed by item name.
        self.sum = {}
        self.cnt = {}
        # Start of the current logging window; reset by clear().
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        self.ddp = sync.is_ddp()
        self.grank, self.lrank, _ = sync.get_rank('all')
        self.gwsize = sync.get_world_size('global')

        # Only the global rank-0 process writes TensorBoard events.
        use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank == 0)

        self.tb = None
        if use_tensorboard:
            import tensorboardX
            monitoring_dir = osp.join(cfgt.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``n`` samples whose mean value is given for each item.

        Args:
            n: number of samples the values were averaged over; must be >= 0.
            **data: item name -> mean value over those ``n`` samples.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError('accumulate() requires n >= 0, got {}'.format(n))

        for itemn, di in data.items():
            self.sum[itemn] = self.sum.get(itemn, 0) + di * n
            self.cnt[itemn] = self.cnt.get(itemn, 0) + n

    def get_mean_value_dict(self):
        """Return the count-weighted mean of every accumulated item.

        Under DDP, sums and counts are all-reduced across processes before the
        division. Ranks that saw more samples therefore weigh proportionally
        more. The reduction is positional over the sorted item names, so every
        rank must have accumulated the same names (this is a collective call).
        Items whose total count is zero report 0.0.
        """
        keys = sorted(self.sum.keys())
        # Row 0 holds the weighted sums, row 1 the sample counts. float() also
        # accepts 0-d tensors (on any device) and numpy scalars.
        stats = torch.tensor(
            [[float(self.sum[k]) for k in keys],
             [float(self.cnt[k]) for k in keys]],
            dtype=torch.float64)

        if self.ddp:
            # The NCCL backend only reduces CUDA tensors. Fall back to CPU
            # (e.g. the gloo backend) when no GPU is present.
            if torch.cuda.is_available():
                device = torch.device('cuda', self.lrank)
            else:
                device = torch.device('cpu')
            stats = stats.to(device)
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            stats = stats.cpu()

        sums, counts = stats[0].tolist(), stats[1].tolist()
        return {
            k: (s / c if c > 0 else 0.0)
            for k, s, c in zip(keys, sums, counts)}

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write ``data`` to TensorBoard at ``step``; no-op when TensorBoard is off.

        In ``'train'`` mode, ``data`` maps names to scalars:
        - names starting with ``'loss'`` go under ``loss/``;
        - ``'Loss'`` is logged at the top level;
        - everything else goes under ``other/``.

        ``extra`` may carry ``epochn`` and ``lr``; each is logged under
        ``other/`` when present and not None. In ``'eval'`` mode, ``data`` is
        either a dict (logged under ``eval/``) or a single scalar.
        """
        if self.tb is None:
            return
        if mode == 'train':
            for name in ('epochn', 'lr'):
                if extra.get(name) is not None:
                    self.tb.add_scalar('other/' + name, extra[name], step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    tag = 'loss/' + itemn
                elif itemn == 'Loss':
                    tag = 'Loss'
                else:
                    tag = 'other/' + itemn
                self.tb.add_scalar(tag, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Build a one-line training report and log the means to TensorBoard.

        Args:
            itern: iteration counter to print.
            epochn: epoch counter to print.
            samplen: sample counter to print.
            lr: learning rate; printed and logged only when not None.
            tbstep: TensorBoard step; defaults to ``itern``.

        Returns:
            A ``' , '``-joined report containing, in order:
            - the counters;
            - ``Loss``, if it was accumulated;
            - every item whose name starts with ``'loss'``, sorted;
            - elapsed seconds since the last :meth:`clear`.
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]
        if lr is not None:
            console_info.append('LR:{:.4E}'.format(lr))

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # 'Loss' is the headline total. Print it first when it was accumulated
        # rather than raising KeyError when it was not.
        if 'Loss' in mean:
            console_info.append('Loss:{:.4f}'.format(mean['Loss']))
        console_info += [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean) if itemn.startswith('loss')]
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Drop all accumulated statistics and restart the logging timer."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer, if any; later logging becomes a no-op."""
        if self.tb is not None:
            self.tb.close()
            self.tb = None

# ----- also include some small utils -----

def torch_to_numpy(*argv):
    """Recursively convert torch tensors to numpy arrays.

    With exactly one argument, that object is converted. With any other number
    of arguments, the arguments are treated as a list: ``torch_to_numpy()``
    returns ``[]`` and ``torch_to_numpy(a, b)`` returns ``[a', b']``.

    Conversion rules:
    - tensors become detached CPU numpy arrays;
    - lists and tuples become lists;
    - dicts keep their keys, with converted values;
    - anything else is returned unchanged.
    """
    data = argv[0] if len(argv) == 1 else list(argv)

    if isinstance(data, torch.Tensor):
        # Detach before the device copy so autograd does not record the transfer.
        return data.detach().cpu().numpy()
    if isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]
    if isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}
    return data