File size: 2,283 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import random
import numpy as np
import os.path
from swapae.data.base_dataset import BaseDataset, get_transform
import torchvision


class CIFAR100Dataset(BaseDataset):
    """CIFAR-100 dataset that yields pairs of images sharing the same class.

    Each item holds the image at the requested index ("real_A") together with
    a randomly chosen image of the same class ("real_B"). To pick the partner
    image quickly, a per-class table of dataset indices is built once and the
    underlying label array is cached on disk under ``opt.dataroot``.
    """

    # Number of buckets in the class-index table; label ids are expected to
    # fall in range(NUM_CLASSES).
    NUM_CLASSES = 100

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Install the option defaults this dataset relies on.

        Args:
            parser: argument parser whose defaults are overridden.
            is_train: unused by this dataset.

        Returns:
            The same parser, with the defaults applied.

        Raises:
            AssertionError: if the effective options do not use the 'crop'
                preprocess mode with load_size and crop_size of 32.
        """
        parser.set_defaults(load_size=32, crop_size=32, preprocess_crop_padding=0,
                            preprocess='crop',
                            num_classes=CIFAR100Dataset.NUM_CLASSES,
                            use_class_labels=True)
        opt, _ = parser.parse_known_args()
        assert opt.preprocess == 'crop' and opt.load_size == 32 and opt.crop_size == 32, \
            "CIFAR100Dataset requires preprocess='crop', load_size=32 and crop_size=32"
        return parser

    def __init__(self, opt):
        """Open (downloading if needed) CIFAR-100 and build the class index.

        Args:
            opt: options object; this class reads ``dataroot``, ``isTrain``
                and ``phase`` from it and hands it to ``get_transform``.
        """
        self.opt = opt
        self.torch_dataset = torchvision.datasets.CIFAR100(
            opt.dataroot, train=opt.isTrain, download=True
        )
        self.transform = get_transform(self.opt, grayscale=False)
        self.class_list = self.create_class_list()

    def create_class_list(self):
        """Return a dict mapping each class id to the dataset indices of that class.

        The per-image label array is loaded from
        ``<dataroot>/<phase>_classlist.npy`` when that file exists and matches
        the current dataset; otherwise it is rebuilt by reading every item's
        label and written back to that path.

        Returns:
            dict[int, list[int]] with one (possibly empty) list for every id
            in ``range(NUM_CLASSES)``.
        """
        cache_path = os.path.join(self.opt.dataroot, "%s_classlist.npy" % self.opt.phase)
        num_images = len(self.torch_dataset)

        classes = self._load_cached_classes(cache_path, num_images)
        if classes is None:
            classes = self._scan_classes(num_images)
            np.save(cache_path, classes)
            print("cache saved at %s" % cache_path)

        # Group indices straight from the in-memory array; there is no need to
        # read back the file that was just written.
        classlist = {i: [] for i in range(self.NUM_CLASSES)}
        for index, class_id in enumerate(classes):
            classlist[int(class_id)].append(index)
        return classlist

    @staticmethod
    def _load_cached_classes(cache_path, expected_len):
        """Load the cached label array, or return None if it cannot be trusted.

        A cache is accepted only when it is a 1-D integer array with exactly
        ``expected_len`` entries. Anything else (missing file, unreadable file,
        wrong length/shape/dtype) would yield indices or class ids that do not
        match the current dataset, so the caller rebuilds it instead.
        """
        if not os.path.exists(cache_path):
            return None
        try:
            cached = np.load(cache_path)
        except (OSError, ValueError, EOFError):
            # Truncated or otherwise unreadable file: fall through and rebuild.
            cached = None
        if (isinstance(cached, np.ndarray)
                and cached.ndim == 1
                and len(cached) == expected_len
                and np.issubdtype(cached.dtype, np.integer)):
            return cached
        print("ignoring unusable class list cache at %s" % cache_path)
        return None

    def _scan_classes(self, num_images):
        """Read the label of every dataset item into an integer array."""
        print("creating cache list of classes...")
        classes = np.zeros(num_images, dtype=int)
        for i in range(num_images):
            _, class_id = self.torch_dataset[i]
            classes[i] = class_id
            if i % 100 == 0:
                # Carriage return keeps the progress counter on a single line.
                print("%d/%d\r" % (i, num_images), end="", flush=True)
        return classes

    def __getitem__(self, index):
        """Return the image at ``index`` paired with a random same-class image.

        ``index`` is wrapped modulo the dataset length, so values past the end
        are accepted.

        Returns:
            dict with the transformed images under "real_A" and "real_B" and
            their shared class id under both "class_A" and "class_B".
        """
        index = index % len(self.torch_dataset)
        image, class_id = self.torch_dataset[index]

        # The partner is drawn from all images of this class, which includes
        # the image at ``index`` itself.
        another_image_index = random.choice(self.class_list[class_id])
        another_image, another_class_id = self.torch_dataset[another_image_index]
        assert class_id == another_class_id
        return {"real_A": self.transform(image),
                "real_B": self.transform(another_image),
                "class_A": class_id, "class_B": class_id}

    def __len__(self):
        """Return the number of images in the underlying torchvision dataset."""
        return len(self.torch_dataset)