import datetime
import logging
import os
import re

import torch

logger = logging.getLogger(__name__)


class CheckpointSaver:
    """Class that handles saving and loading checkpoints during training."""
    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        """
        Args:
            save_dir: directory in which checkpoints are written.
            save_steps: step interval used by callers that save by step count.
            overwrite: if True, always write to the same 'model_latest.pt' file.
        """
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check whether a checkpoint exists in the save directory."""
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Save checkpoint."""
        timestamp = datetime.datetime.now()
        if self.overwrite:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt'))
        elif save_by_step:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, f'{total_step_count:08d}.pt')
            )
        else:
            if epoch % interval == 0:
                checkpoint_filename = os.path.abspath(
                    os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
                )
            else:
                checkpoint_filename = None

        checkpoint = {}
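        # Drop parameters belonging to any '.smpl.' sub-module before saving.
        # These are presumably fixed body-model buffers that are rebuilt from
        # the model definition, so persisting them would only bloat the file.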
        for model_name, model in models.items():
            model_dict = model.state_dict()
            for k in list(model_dict.keys()):
                if '.smpl.' in k:
                    del model_dict[k]
            checkpoint[model_name] = model_dict
        if with_optimizer:
            for opt_name, optimizer in optimizers.items():
                checkpoint[opt_name] = optimizer.state_dict()
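        # Training progress metadata needed to resume from this checkpoint.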
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        logger.info('%s Epoch: %d, Iteration: %d', timestamp, epoch, batch_idx)

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            logger.info('Saved checkpoint file [%s]', checkpoint_filename)
        if is_best:    # additionally save a copy as the best model so far
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            torch.save(checkpoint, checkpoint_filename)
            logger.info('Saved checkpoint file [%s]', checkpoint_filename)

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint."""
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise RuntimeError(f'No checkpoint found in [{self.save_dir}]')
            logger.info('Loading latest checkpoint [%s]', self.latest_checkpoint)
            checkpoint_file = self.latest_checkpoint
        checkpoint = torch.load(checkpoint_file)
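        # Load only the keys present in the current model, so checkpoints saved
        # without the '.smpl.' buffers (see save_checkpoint) still load cleanly.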
        for model_name, model in models.items():
            if model_name in checkpoint:
                model_dict = model.state_dict()
                pretrained_dict = {
                    k: v
                    for k, v in checkpoint[model_name].items() if k in model_dict
                }
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
        for opt_name, optimizer in optimizers.items():
            if opt_name in checkpoint:
                optimizer.load_state_dict(checkpoint[opt_name])
        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }

    def get_latest_checkpoint(self):
        """Get filename of latest checkpoint if it exists."""
        checkpoint_list = []
        for dirpath, dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))
        # Sort with natural (human) ordering so numeric filenames compare by
        # value rather than lexicographically.

        def atof(text):
            try:
                retval = float(text)
            except ValueError:
                retval = text
            return retval

        def natural_keys(text):
            '''
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            float regex comes from https://stackoverflow.com/a/12643073/190597
            '''
            return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None
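

# A minimal usage sketch. The dict keys ('model', 'optimizer') and the tiny
# network below are hypothetical, chosen only to illustrate the dict layout
# the saver expects; real training code would pass its own modules.
if __name__ == '__main__':
    import torch.nn as nn
    import torch.optim as optim

    logging.basicConfig(level=logging.INFO)

    net = nn.Linear(4, 2)                     # stand-in for a real model
    opt = optim.Adam(net.parameters())
    saver = CheckpointSaver(save_dir='checkpoints', overwrite=True)

    # Save once; overwrite=True keeps a single 'model_latest.pt'.
    saver.save_checkpoint(
        models={'model': net},
        optimizers={'optimizer': opt},
        epoch=0,
        batch_idx=0,
        batch_size=32,
        total_step_count=0,
    )

    # The latest-checkpoint path is only scanned at construction time, so
    # refresh it before resuming within the same process.
    saver.get_latest_checkpoint()
    if saver.exists_checkpoint():
        state = saver.load_checkpoint({'model': net}, {'optimizer': opt})
        print('Resuming from epoch', state['epoch'])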