import os
import re
import datetime
import logging

import torch

logger = logging.getLogger(__name__)

class CheckpointSaver:
    """Class that handles saving and loading checkpoints during training."""

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check if a checkpoint exists in the save directory."""
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)
    
    def save_checkpoint(self, models, optimizers, epoch, batch_idx, batch_size,
                        total_step_count, is_best=False, save_by_step=False, interval=5, with_optimizer=True):
        """Save a checkpoint, optionally keeping a separate copy of the best model."""
        timestamp = datetime.datetime.now()
        if self.overwrite:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt'))
        elif save_by_step:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, f'{total_step_count:08d}.pt'))
        elif epoch % interval == 0:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt'))
        else:
            checkpoint_filename = None

        checkpoint = {}
        for model_name, model in models.items():
            model_dict = model.state_dict()
            # Drop SMPL buffers: they are constant, so storing them only bloats the file.
            for k in list(model_dict.keys()):
                if '.smpl.' in k:
                    del model_dict[k]
            checkpoint[model_name] = model_dict
        if with_optimizer:
            for optimizer_name, optimizer in optimizers.items():
                checkpoint[optimizer_name] = optimizer.state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            print(f'Saved checkpoint file [{checkpoint_filename}]')
        if is_best:  # keep a separate copy of the best model so far
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            torch.save(checkpoint, checkpoint_filename)
            print(f'Saved checkpoint file [{checkpoint_filename}]')


    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint and return the saved training state."""
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise FileNotFoundError(f'No checkpoint found in [{self.save_dir}]')
            logger.info('Loading latest checkpoint [%s]', self.latest_checkpoint)
            checkpoint_file = self.latest_checkpoint
        checkpoint = torch.load(checkpoint_file)
        for model_name, model in models.items():
            if model_name in checkpoint:
                # Load only the keys present in the current model, so checkpoints
                # saved without the SMPL buffers (see save_checkpoint) still load.
                model_dict = model.state_dict()
                pretrained_dict = {k: v for k, v in checkpoint[model_name].items()
                                   if k in model_dict}
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
        for optimizer_name, optimizer in optimizers.items():
            if optimizer_name in checkpoint:
                optimizer.load_state_dict(checkpoint[optimizer_name])
        return {'epoch': checkpoint['epoch'],
                'batch_idx': checkpoint['batch_idx'],
                'batch_size': checkpoint['batch_size'],
                'total_step_count': checkpoint['total_step_count']}

    def get_latest_checkpoint(self):
        """Get filename of the latest checkpoint in save_dir, if one exists."""
        checkpoint_list = []
        for dirpath, dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))

        def atof(text):
            try:
                return float(text)
            except ValueError:
                return text

        def natural_keys(text):
            """
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            float regex comes from https://stackoverflow.com/a/12643073/190597
            """
            return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

        # Natural sort so e.g. '00000002.pt' comes before '00000010.pt'.
        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = None if not checkpoint_list else checkpoint_list[-1]
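

# Minimal usage sketch (not part of the original file). The model/optimizer
# names and the 'checkpoints' directory are hypothetical; any nn.Module and
# torch optimizer wrapped in the expected name->object dicts would work.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in for a real network
    optimizer = torch.optim.Adam(model.parameters())
    saver = CheckpointSaver(save_dir='checkpoints', overwrite=True)

    models = {'model': model}
    optimizers = {'optimizer': optimizer}
    if saver.exists_checkpoint():
        # Resume from the latest checkpoint found in save_dir.
        state = saver.load_checkpoint(models, optimizers)
        start_epoch = state['epoch']
    else:
        start_epoch = 0

    # ... training steps would go here ...
    saver.save_checkpoint(models, optimizers, epoch=start_epoch, batch_idx=0,
                          batch_size=32, total_step_count=0, is_best=True)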