File size: 2,348 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import random
import sys
import os.path
from PIL import Image
from swapae.data.base_dataset import BaseDataset, get_transform
import cv2
import numpy as np
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle
import torchvision.transforms as transforms


class LMDBDataset(BaseDataset):
    """Image dataset read from an LMDB database of encoded image bytes.

    Every LMDB value is decoded with ``cv2.imdecode`` (colour mode) and
    returned as a transformed PIL image. The full key list is gathered once
    with a cursor and pickled to ``<dataroot>/_cache_`` so that later runs can
    skip the scan. Keys are shuffled with a fixed seed (0), so the mapping
    from index to key is the same on every run.
    """

    # Name of the pickled key-list cache stored inside the dataroot.
    CACHE_FILENAME = '_cache_'

    def __init__(self, opt):
        """Open the LMDB environment at ``opt.dataroot`` and load its keys.

        Args:
            opt: options object; ``opt.dataroot`` is the LMDB directory and
                ``opt`` is also forwarded to ``get_transform``.
        """
        import lmdb
        self.opt = opt
        write_cache = True
        # Expand "~" once and use the expanded path for every filesystem
        # access (lmdb.open and the cache file), not only for __repr__.
        self.root = os.path.expanduser(opt.dataroot)
        # NOTE(review): the environment is opened in the constructing process;
        # LMDB advises against using an env across fork() -- confirm this is
        # safe with multi-worker data loading.
        self.env = lmdb.open(self.root, readonly=True, lock=False)
        with self.env.begin(write=False) as txn:
            # Record count reported by LMDB. __len__ uses len(self.keys)
            # instead, because indexing is done into self.keys.
            self.length = txn.stat()['entries']
        print('lmdb file at %s opened.' % self.root)

        self.keys = self._load_keys(write_cache=write_cache)
        # Private seeded RNG: deterministic order, global RNG left untouched.
        random.Random(0).shuffle(self.keys)

        self.transform = get_transform(self.opt, grayscale=False)
        # Decided once here instead of on every item.
        # NOTE(review): cv2.imdecode always yields BGR; non-LSUN databases are
        # passed through unconverted, presumably because they were written with
        # channels already swapped -- confirm.
        self._bgr_to_rgb = "lsun" in self.opt.dataroot.lower()
        if self._bgr_to_rgb:
            print("Seems like a LSUN dataset, so we will apply BGR->RGB conversion")

    def _load_keys(self, write_cache=True):
        """Return the list of LMDB keys, reading the on-disk cache if present.

        If no cache exists and ``write_cache`` is true, all keys are read with
        a cursor and a best-effort attempt is made to pickle them to the cache
        file. A failure to write (e.g. a read-only dataroot) is reported but
        is not fatal. If ``write_cache`` is false and no cache exists, an
        empty list is returned.
        """
        cache_file = os.path.join(self.root, self.CACHE_FILENAME)
        if os.path.isfile(cache_file):
            # NOTE: pickle.load can execute arbitrary code; only use dataroots
            # whose _cache_ file you trust.
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        if not write_cache:
            return []

        print('generating keys')
        with self.env.begin(write=False) as txn:
            keys = [key for key, _ in txn.cursor()]

        tmp_file = cache_file + '.tmp'
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(keys, f)
            # Rename into place so an interrupted write never leaves a
            # truncated cache that would break the next run.
            os.replace(tmp_file, cache_file)
        except OSError as e:
            print('could not write cache file %s: %s' % (cache_file, e))
        else:
            print('cache file generated at %s' % cache_file)
        return keys

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th (shuffled) key."""
        path = self.keys[index]
        return self.getitem_by_path(path)

    def getitem_by_path(self, path):
        """Load, decode and transform the image stored under LMDB key ``path``.

        Args:
            path: LMDB key (``bytes``).

        Returns:
            dict with ``"real_A"`` (result of ``self.transform``) and
            ``"path_A"`` (the key decoded as UTF-8).

        If the key is missing or its value cannot be decoded as an image, a
        message is printed and a randomly chosen other sample is returned.
        """
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(path)

        img = None
        if imgbuf is None:
            print(path, 'key not found in lmdb')
        else:
            try:
                # frombuffer avoids a copy; np.fromstring's binary mode is
                # deprecated.
                img = cv2.imdecode(
                    np.frombuffer(imgbuf, dtype=np.uint8), cv2.IMREAD_COLOR)
            except cv2.error as e:
                print(path, e)
            else:
                # imdecode signals undecodable data by returning None.
                if img is None:
                    print(path, 'could not be decoded as an image')

        if img is None:
            # Substitute a random sample; index against self.keys so the
            # replacement index is always valid.
            return self.__getitem__(random.randint(0, len(self.keys) - 1))

        if self._bgr_to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)

        return {"real_A": self.transform(img), "path_A": path.decode("utf-8")}

    def set_phase(self, phase):
        """Forward the phase change to ``BaseDataset``; nothing extra here."""
        super().set_phase(phase)

    def __len__(self):
        """Number of indexable samples (the loaded/cached key list)."""
        return len(self.keys)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.root + ')'