taming_transformer
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- =1.0.8 +0 -0
- =2.0.0 +0 -0
- License.txt +19 -0
- __pycache__/main.cpython-312.pyc +0 -0
- environment.yaml +25 -0
- main.py +585 -0
- scripts/extract_depth.py +112 -0
- scripts/extract_segmentation.py +130 -0
- scripts/extract_submodel.py +17 -0
- scripts/make_samples.py +292 -0
- scripts/make_scene_samples.py +198 -0
- scripts/sample_conditional.py +355 -0
- scripts/sample_fast.py +260 -0
- scripts/taming-transformers.ipynb +0 -0
- setup.py +13 -0
- taming/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- taming/__pycache__/util.cpython-312.pyc +0 -0
- taming/data/.ipynb_checkpoints/utils-checkpoint.py +171 -0
- taming/data/__pycache__/helper_types.cpython-312.pyc +0 -0
- taming/data/__pycache__/utils.cpython-312.pyc +0 -0
- taming/data/ade20k.py +124 -0
- taming/data/annotated_objects_coco.py +139 -0
- taming/data/annotated_objects_dataset.py +218 -0
- taming/data/annotated_objects_open_images.py +137 -0
- taming/data/base.py +70 -0
- taming/data/coco.py +176 -0
- taming/data/conditional_builder/objects_bbox.py +60 -0
- taming/data/conditional_builder/objects_center_points.py +168 -0
- taming/data/conditional_builder/utils.py +105 -0
- taming/data/custom.py +38 -0
- taming/data/faceshq.py +134 -0
- taming/data/helper_types.py +49 -0
- taming/data/image_transforms.py +132 -0
- taming/data/imagenet.py +558 -0
- taming/data/open_images_helper.py +379 -0
- taming/data/sflckr.py +91 -0
- taming/data/utils.py +171 -0
- taming/lr_scheduler.py +34 -0
- taming/models/__pycache__/vqgan.cpython-312.pyc +0 -0
- taming/models/cond_transformer.py +352 -0
- taming/models/dummy_cond_stage.py +22 -0
- taming/models/vqgan.py +404 -0
- taming/modules/__pycache__/util.cpython-312.pyc +0 -0
- taming/modules/diffusionmodules/__pycache__/model.cpython-312.pyc +0 -0
- taming/modules/diffusionmodules/model.py +776 -0
- taming/modules/discriminator/__pycache__/model.cpython-312.pyc +0 -0
- taming/modules/discriminator/model.py +67 -0
- taming/modules/losses/__init__.py +2 -0
- taming/modules/losses/__pycache__/__init__.cpython-312.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
scripts/reconstruction_usage.ipynb filter=lfs diff=lfs merge=lfs -text
|
=1.0.8
ADDED
The diff for this file is too large to render.
See raw diff
|
|
=2.0.0
ADDED
File without changes
|
License.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
of this software and associated documentation files (the "Software"), to deal
|
5 |
+
in the Software without restriction, including without limitation the rights
|
6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
copies of the Software, and to permit persons to whom the Software is
|
8 |
+
furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in all
|
11 |
+
copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
14 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
15 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
16 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
17 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
18 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
19 |
+
OR OTHER DEALINGS IN THE SOFTWARE./
|
__pycache__/main.cpython-312.pyc
ADDED
Binary file (27.2 kB). View file
|
|
environment.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: taming
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.8.5
|
7 |
+
- pip=20.3
|
8 |
+
- cudatoolkit=10.2
|
9 |
+
- pytorch=1.7.0
|
10 |
+
- torchvision=0.8.1
|
11 |
+
- numpy=1.19.2
|
12 |
+
- pip:
|
13 |
+
- albumentations==0.4.3
|
14 |
+
- opencv-python==4.1.2.30
|
15 |
+
- pudb==2019.2
|
16 |
+
- imageio==2.9.0
|
17 |
+
- imageio-ffmpeg==0.4.2
|
18 |
+
- pytorch-lightning==1.0.8
|
19 |
+
- omegaconf==2.0.0
|
20 |
+
- test-tube>=0.7.5
|
21 |
+
- streamlit>=0.73.1
|
22 |
+
- einops==0.3.0
|
23 |
+
- more-itertools>=8.0.0
|
24 |
+
- transformers==4.3.1
|
25 |
+
- -e .
|
main.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, datetime, glob, importlib
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from torch.utils.data import random_split, DataLoader, Dataset
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from pytorch_lightning import seed_everything
|
10 |
+
from pytorch_lightning.trainer import Trainer
|
11 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
12 |
+
from pytorch_lightning.utilities import rank_zero_only
|
13 |
+
|
14 |
+
from taming.data.utils import custom_collate
|
15 |
+
|
16 |
+
|
17 |
+
def get_obj_from_str(string, reload=False):
|
18 |
+
module, cls = string.rsplit(".", 1)
|
19 |
+
if reload:
|
20 |
+
module_imp = importlib.import_module(module)
|
21 |
+
importlib.reload(module_imp)
|
22 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
23 |
+
|
24 |
+
|
25 |
+
def get_parser(**parser_kwargs):
|
26 |
+
def str2bool(v):
|
27 |
+
if isinstance(v, bool):
|
28 |
+
return v
|
29 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
30 |
+
return True
|
31 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
32 |
+
return False
|
33 |
+
else:
|
34 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
35 |
+
|
36 |
+
parser = argparse.ArgumentParser(**parser_kwargs)
|
37 |
+
parser.add_argument(
|
38 |
+
"-n",
|
39 |
+
"--name",
|
40 |
+
type=str,
|
41 |
+
const=True,
|
42 |
+
default="",
|
43 |
+
nargs="?",
|
44 |
+
help="postfix for logdir",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"-r",
|
48 |
+
"--resume",
|
49 |
+
type=str,
|
50 |
+
const=True,
|
51 |
+
default="",
|
52 |
+
nargs="?",
|
53 |
+
help="resume from logdir or checkpoint in logdir",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"-b",
|
57 |
+
"--base",
|
58 |
+
nargs="*",
|
59 |
+
metavar="base_config.yaml",
|
60 |
+
help="paths to base configs. Loaded from left-to-right. "
|
61 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
62 |
+
default=list(),
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"-t",
|
66 |
+
"--train",
|
67 |
+
type=str2bool,
|
68 |
+
const=True,
|
69 |
+
default=False,
|
70 |
+
nargs="?",
|
71 |
+
help="train",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--no-test",
|
75 |
+
type=str2bool,
|
76 |
+
const=True,
|
77 |
+
default=False,
|
78 |
+
nargs="?",
|
79 |
+
help="disable test",
|
80 |
+
)
|
81 |
+
parser.add_argument("-p", "--project", help="name of new or path to existing project")
|
82 |
+
parser.add_argument(
|
83 |
+
"-d",
|
84 |
+
"--debug",
|
85 |
+
type=str2bool,
|
86 |
+
nargs="?",
|
87 |
+
const=True,
|
88 |
+
default=False,
|
89 |
+
help="enable post-mortem debugging",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"-s",
|
93 |
+
"--seed",
|
94 |
+
type=int,
|
95 |
+
default=23,
|
96 |
+
help="seed for seed_everything",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"-f",
|
100 |
+
"--postfix",
|
101 |
+
type=str,
|
102 |
+
default="",
|
103 |
+
help="post-postfix for default name",
|
104 |
+
)
|
105 |
+
|
106 |
+
return parser
|
107 |
+
|
108 |
+
|
109 |
+
def nondefault_trainer_args(opt):
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser = Trainer.add_argparse_args(parser)
|
112 |
+
args = parser.parse_args([])
|
113 |
+
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
114 |
+
|
115 |
+
|
116 |
+
def instantiate_from_config(config):
|
117 |
+
if not "target" in config:
|
118 |
+
raise KeyError("Expected key `target` to instantiate.")
|
119 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
120 |
+
|
121 |
+
|
122 |
+
class WrappedDataset(Dataset):
|
123 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
124 |
+
def __init__(self, dataset):
|
125 |
+
self.data = dataset
|
126 |
+
|
127 |
+
def __len__(self):
|
128 |
+
return len(self.data)
|
129 |
+
|
130 |
+
def __getitem__(self, idx):
|
131 |
+
return self.data[idx]
|
132 |
+
|
133 |
+
|
134 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
135 |
+
def __init__(self, batch_size, train=None, validation=None, test=None,
|
136 |
+
wrap=False, num_workers=None):
|
137 |
+
super().__init__()
|
138 |
+
self.batch_size = batch_size
|
139 |
+
self.dataset_configs = dict()
|
140 |
+
self.num_workers = num_workers if num_workers is not None else batch_size*2
|
141 |
+
if train is not None:
|
142 |
+
self.dataset_configs["train"] = train
|
143 |
+
self.train_dataloader = self._train_dataloader
|
144 |
+
if validation is not None:
|
145 |
+
self.dataset_configs["validation"] = validation
|
146 |
+
self.val_dataloader = self._val_dataloader
|
147 |
+
if test is not None:
|
148 |
+
self.dataset_configs["test"] = test
|
149 |
+
self.test_dataloader = self._test_dataloader
|
150 |
+
self.wrap = wrap
|
151 |
+
|
152 |
+
def prepare_data(self):
|
153 |
+
for data_cfg in self.dataset_configs.values():
|
154 |
+
instantiate_from_config(data_cfg)
|
155 |
+
|
156 |
+
def setup(self, stage=None):
|
157 |
+
self.datasets = dict(
|
158 |
+
(k, instantiate_from_config(self.dataset_configs[k]))
|
159 |
+
for k in self.dataset_configs)
|
160 |
+
if self.wrap:
|
161 |
+
for k in self.datasets:
|
162 |
+
self.datasets[k] = WrappedDataset(self.datasets[k])
|
163 |
+
|
164 |
+
def _train_dataloader(self):
|
165 |
+
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
166 |
+
num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
|
167 |
+
|
168 |
+
def _val_dataloader(self):
|
169 |
+
return DataLoader(self.datasets["validation"],
|
170 |
+
batch_size=self.batch_size,
|
171 |
+
num_workers=self.num_workers, collate_fn=custom_collate)
|
172 |
+
|
173 |
+
def _test_dataloader(self):
|
174 |
+
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
|
175 |
+
num_workers=self.num_workers, collate_fn=custom_collate)
|
176 |
+
|
177 |
+
|
178 |
+
class SetupCallback(Callback):
|
179 |
+
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
180 |
+
super().__init__()
|
181 |
+
self.resume = resume
|
182 |
+
self.now = now
|
183 |
+
self.logdir = logdir
|
184 |
+
self.ckptdir = ckptdir
|
185 |
+
self.cfgdir = cfgdir
|
186 |
+
self.config = config
|
187 |
+
self.lightning_config = lightning_config
|
188 |
+
|
189 |
+
def on_pretrain_routine_start(self, trainer, pl_module):
|
190 |
+
if trainer.global_rank == 0:
|
191 |
+
# Create logdirs and save configs
|
192 |
+
os.makedirs(self.logdir, exist_ok=True)
|
193 |
+
os.makedirs(self.ckptdir, exist_ok=True)
|
194 |
+
os.makedirs(self.cfgdir, exist_ok=True)
|
195 |
+
|
196 |
+
print("Project config")
|
197 |
+
print(self.config.pretty())
|
198 |
+
OmegaConf.save(self.config,
|
199 |
+
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
200 |
+
|
201 |
+
print("Lightning config")
|
202 |
+
print(self.lightning_config.pretty())
|
203 |
+
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
204 |
+
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
205 |
+
|
206 |
+
else:
|
207 |
+
# ModelCheckpoint callback created log directory --- remove it
|
208 |
+
if not self.resume and os.path.exists(self.logdir):
|
209 |
+
dst, name = os.path.split(self.logdir)
|
210 |
+
dst = os.path.join(dst, "child_runs", name)
|
211 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
212 |
+
try:
|
213 |
+
os.rename(self.logdir, dst)
|
214 |
+
except FileNotFoundError:
|
215 |
+
pass
|
216 |
+
|
217 |
+
|
218 |
+
class ImageLogger(Callback):
|
219 |
+
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
|
220 |
+
super().__init__()
|
221 |
+
self.batch_freq = batch_frequency
|
222 |
+
self.max_images = max_images
|
223 |
+
self.logger_log_images = {
|
224 |
+
pl.loggers.WandbLogger: self._wandb,
|
225 |
+
pl.loggers.TestTubeLogger: self._testtube,
|
226 |
+
}
|
227 |
+
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
228 |
+
if not increase_log_steps:
|
229 |
+
self.log_steps = [self.batch_freq]
|
230 |
+
self.clamp = clamp
|
231 |
+
|
232 |
+
@rank_zero_only
|
233 |
+
def _wandb(self, pl_module, images, batch_idx, split):
|
234 |
+
raise ValueError("No way wandb")
|
235 |
+
grids = dict()
|
236 |
+
for k in images:
|
237 |
+
grid = torchvision.utils.make_grid(images[k])
|
238 |
+
grids[f"{split}/{k}"] = wandb.Image(grid)
|
239 |
+
pl_module.logger.experiment.log(grids)
|
240 |
+
|
241 |
+
@rank_zero_only
|
242 |
+
def _testtube(self, pl_module, images, batch_idx, split):
|
243 |
+
for k in images:
|
244 |
+
grid = torchvision.utils.make_grid(images[k])
|
245 |
+
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
246 |
+
|
247 |
+
tag = f"{split}/{k}"
|
248 |
+
pl_module.logger.experiment.add_image(
|
249 |
+
tag, grid,
|
250 |
+
global_step=pl_module.global_step)
|
251 |
+
|
252 |
+
@rank_zero_only
|
253 |
+
def log_local(self, save_dir, split, images,
|
254 |
+
global_step, current_epoch, batch_idx):
|
255 |
+
root = os.path.join(save_dir, "images", split)
|
256 |
+
for k in images:
|
257 |
+
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
258 |
+
|
259 |
+
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
260 |
+
grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
|
261 |
+
grid = grid.numpy()
|
262 |
+
grid = (grid*255).astype(np.uint8)
|
263 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
264 |
+
k,
|
265 |
+
global_step,
|
266 |
+
current_epoch,
|
267 |
+
batch_idx)
|
268 |
+
path = os.path.join(root, filename)
|
269 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
270 |
+
Image.fromarray(grid).save(path)
|
271 |
+
|
272 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
273 |
+
if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
|
274 |
+
hasattr(pl_module, "log_images") and
|
275 |
+
callable(pl_module.log_images) and
|
276 |
+
self.max_images > 0):
|
277 |
+
logger = type(pl_module.logger)
|
278 |
+
|
279 |
+
is_train = pl_module.training
|
280 |
+
if is_train:
|
281 |
+
pl_module.eval()
|
282 |
+
|
283 |
+
with torch.no_grad():
|
284 |
+
images = pl_module.log_images(batch, split=split, pl_module=pl_module)
|
285 |
+
|
286 |
+
for k in images:
|
287 |
+
N = min(images[k].shape[0], self.max_images)
|
288 |
+
images[k] = images[k][:N]
|
289 |
+
if isinstance(images[k], torch.Tensor):
|
290 |
+
images[k] = images[k].detach().cpu()
|
291 |
+
if self.clamp:
|
292 |
+
images[k] = torch.clamp(images[k], -1., 1.)
|
293 |
+
|
294 |
+
self.log_local(pl_module.logger.save_dir, split, images,
|
295 |
+
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
296 |
+
|
297 |
+
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
298 |
+
logger_log_images(pl_module, images, pl_module.global_step, split)
|
299 |
+
|
300 |
+
if is_train:
|
301 |
+
pl_module.train()
|
302 |
+
|
303 |
+
def check_frequency(self, batch_idx):
|
304 |
+
if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
|
305 |
+
try:
|
306 |
+
self.log_steps.pop(0)
|
307 |
+
except IndexError:
|
308 |
+
pass
|
309 |
+
return True
|
310 |
+
return False
|
311 |
+
|
312 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
313 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
314 |
+
|
315 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
316 |
+
self.log_img(pl_module, batch, batch_idx, split="val")
|
317 |
+
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
# custom parser to specify config files, train, test and debug mode,
|
322 |
+
# postfix, resume.
|
323 |
+
# `--key value` arguments are interpreted as arguments to the trainer.
|
324 |
+
# `nested.key=value` arguments are interpreted as config parameters.
|
325 |
+
# configs are merged from left-to-right followed by command line parameters.
|
326 |
+
|
327 |
+
# model:
|
328 |
+
# base_learning_rate: float
|
329 |
+
# target: path to lightning module
|
330 |
+
# params:
|
331 |
+
# key: value
|
332 |
+
# data:
|
333 |
+
# target: main.DataModuleFromConfig
|
334 |
+
# params:
|
335 |
+
# batch_size: int
|
336 |
+
# wrap: bool
|
337 |
+
# train:
|
338 |
+
# target: path to train dataset
|
339 |
+
# params:
|
340 |
+
# key: value
|
341 |
+
# validation:
|
342 |
+
# target: path to validation dataset
|
343 |
+
# params:
|
344 |
+
# key: value
|
345 |
+
# test:
|
346 |
+
# target: path to test dataset
|
347 |
+
# params:
|
348 |
+
# key: value
|
349 |
+
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
350 |
+
# trainer:
|
351 |
+
# additional arguments to trainer
|
352 |
+
# logger:
|
353 |
+
# logger to instantiate
|
354 |
+
# modelcheckpoint:
|
355 |
+
# modelcheckpoint to instantiate
|
356 |
+
# callbacks:
|
357 |
+
# callback1:
|
358 |
+
# target: importpath
|
359 |
+
# params:
|
360 |
+
# key: value
|
361 |
+
|
362 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
363 |
+
|
364 |
+
# add cwd for convenience and to make classes in this file available when
|
365 |
+
# running as `python main.py`
|
366 |
+
# (in particular `main.DataModuleFromConfig`)
|
367 |
+
sys.path.append(os.getcwd())
|
368 |
+
|
369 |
+
parser = get_parser()
|
370 |
+
parser = Trainer.add_argparse_args(parser)
|
371 |
+
|
372 |
+
opt, unknown = parser.parse_known_args()
|
373 |
+
if opt.name and opt.resume:
|
374 |
+
raise ValueError(
|
375 |
+
"-n/--name and -r/--resume cannot be specified both."
|
376 |
+
"If you want to resume training in a new log folder, "
|
377 |
+
"use -n/--name in combination with --resume_from_checkpoint"
|
378 |
+
)
|
379 |
+
if opt.resume:
|
380 |
+
if not os.path.exists(opt.resume):
|
381 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
382 |
+
if os.path.isfile(opt.resume):
|
383 |
+
paths = opt.resume.split("/")
|
384 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
385 |
+
logdir = "/".join(paths[:idx])
|
386 |
+
ckpt = opt.resume
|
387 |
+
else:
|
388 |
+
assert os.path.isdir(opt.resume), opt.resume
|
389 |
+
logdir = opt.resume.rstrip("/")
|
390 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
391 |
+
|
392 |
+
opt.resume_from_checkpoint = ckpt
|
393 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
394 |
+
opt.base = base_configs+opt.base
|
395 |
+
_tmp = logdir.split("/")
|
396 |
+
nowname = _tmp[_tmp.index("logs")+1]
|
397 |
+
else:
|
398 |
+
if opt.name:
|
399 |
+
name = "_"+opt.name
|
400 |
+
elif opt.base:
|
401 |
+
cfg_fname = os.path.split(opt.base[0])[-1]
|
402 |
+
cfg_name = os.path.splitext(cfg_fname)[0]
|
403 |
+
name = "_"+cfg_name
|
404 |
+
else:
|
405 |
+
name = ""
|
406 |
+
nowname = now+name+opt.postfix
|
407 |
+
logdir = os.path.join("logs", nowname)
|
408 |
+
|
409 |
+
ckptdir = os.path.join(logdir, "checkpoints")
|
410 |
+
cfgdir = os.path.join(logdir, "configs")
|
411 |
+
seed_everything(opt.seed)
|
412 |
+
|
413 |
+
try:
|
414 |
+
# init and save configs
|
415 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
416 |
+
cli = OmegaConf.from_dotlist(unknown)
|
417 |
+
config = OmegaConf.merge(*configs, cli)
|
418 |
+
lightning_config = config.pop("lightning", OmegaConf.create())
|
419 |
+
# merge trainer cli with config
|
420 |
+
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
421 |
+
# default to ddp
|
422 |
+
trainer_config["distributed_backend"] = "ddp"
|
423 |
+
for k in nondefault_trainer_args(opt):
|
424 |
+
trainer_config[k] = getattr(opt, k)
|
425 |
+
if not "gpus" in trainer_config:
|
426 |
+
del trainer_config["distributed_backend"]
|
427 |
+
cpu = True
|
428 |
+
else:
|
429 |
+
gpuinfo = trainer_config["gpus"]
|
430 |
+
print(f"Running on GPUs {gpuinfo}")
|
431 |
+
cpu = False
|
432 |
+
trainer_opt = argparse.Namespace(**trainer_config)
|
433 |
+
lightning_config.trainer = trainer_config
|
434 |
+
|
435 |
+
# model
|
436 |
+
model = instantiate_from_config(config.model)
|
437 |
+
|
438 |
+
# trainer and callbacks
|
439 |
+
trainer_kwargs = dict()
|
440 |
+
|
441 |
+
# default logger configs
|
442 |
+
# NOTE wandb < 0.10.0 interferes with shutdown
|
443 |
+
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
|
444 |
+
# debugging (wrongly sized pudb ui)
|
445 |
+
# thus prefer testtube for now
|
446 |
+
default_logger_cfgs = {
|
447 |
+
"wandb": {
|
448 |
+
"target": "pytorch_lightning.loggers.WandbLogger",
|
449 |
+
"params": {
|
450 |
+
"name": nowname,
|
451 |
+
"save_dir": logdir,
|
452 |
+
"offline": opt.debug,
|
453 |
+
"id": nowname,
|
454 |
+
}
|
455 |
+
},
|
456 |
+
"testtube": {
|
457 |
+
"target": "pytorch_lightning.loggers.TestTubeLogger",
|
458 |
+
"params": {
|
459 |
+
"name": "testtube",
|
460 |
+
"save_dir": logdir,
|
461 |
+
}
|
462 |
+
},
|
463 |
+
}
|
464 |
+
default_logger_cfg = default_logger_cfgs["testtube"]
|
465 |
+
logger_cfg = lightning_config.logger or OmegaConf.create()
|
466 |
+
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
467 |
+
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
468 |
+
|
469 |
+
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
470 |
+
# specify which metric is used to determine best models
|
471 |
+
default_modelckpt_cfg = {
|
472 |
+
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
473 |
+
"params": {
|
474 |
+
"dirpath": ckptdir,
|
475 |
+
"filename": "{epoch:06}",
|
476 |
+
"verbose": True,
|
477 |
+
"save_last": True,
|
478 |
+
}
|
479 |
+
}
|
480 |
+
if hasattr(model, "monitor"):
|
481 |
+
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
482 |
+
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
483 |
+
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
484 |
+
|
485 |
+
modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
|
486 |
+
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
487 |
+
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
488 |
+
|
489 |
+
# add callback which sets up log directory
|
490 |
+
default_callbacks_cfg = {
|
491 |
+
"setup_callback": {
|
492 |
+
"target": "main.SetupCallback",
|
493 |
+
"params": {
|
494 |
+
"resume": opt.resume,
|
495 |
+
"now": now,
|
496 |
+
"logdir": logdir,
|
497 |
+
"ckptdir": ckptdir,
|
498 |
+
"cfgdir": cfgdir,
|
499 |
+
"config": config,
|
500 |
+
"lightning_config": lightning_config,
|
501 |
+
}
|
502 |
+
},
|
503 |
+
"image_logger": {
|
504 |
+
"target": "main.ImageLogger",
|
505 |
+
"params": {
|
506 |
+
"batch_frequency": 750,
|
507 |
+
"max_images": 4,
|
508 |
+
"clamp": True
|
509 |
+
}
|
510 |
+
},
|
511 |
+
"learning_rate_logger": {
|
512 |
+
"target": "main.LearningRateMonitor",
|
513 |
+
"params": {
|
514 |
+
"logging_interval": "step",
|
515 |
+
#"log_momentum": True
|
516 |
+
}
|
517 |
+
},
|
518 |
+
}
|
519 |
+
callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
|
520 |
+
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
521 |
+
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
522 |
+
|
523 |
+
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
524 |
+
|
525 |
+
# data
|
526 |
+
data = instantiate_from_config(config.data)
|
527 |
+
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
528 |
+
# calling these ourselves should not be necessary but it is.
|
529 |
+
# lightning still takes care of proper multiprocessing though
|
530 |
+
data.prepare_data()
|
531 |
+
data.setup()
|
532 |
+
|
533 |
+
# configure learning rate
|
534 |
+
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
535 |
+
if not cpu:
|
536 |
+
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
|
537 |
+
else:
|
538 |
+
ngpu = 1
|
539 |
+
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
|
540 |
+
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
541 |
+
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
542 |
+
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
543 |
+
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
544 |
+
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
545 |
+
|
546 |
+
# allow checkpointing via USR1
|
547 |
+
def melk(*args, **kwargs):
|
548 |
+
# run all checkpoint hooks
|
549 |
+
if trainer.global_rank == 0:
|
550 |
+
print("Summoning checkpoint.")
|
551 |
+
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
552 |
+
trainer.save_checkpoint(ckpt_path)
|
553 |
+
|
554 |
+
def divein(*args, **kwargs):
|
555 |
+
if trainer.global_rank == 0:
|
556 |
+
import pudb; pudb.set_trace()
|
557 |
+
|
558 |
+
import signal
|
559 |
+
signal.signal(signal.SIGUSR1, melk)
|
560 |
+
signal.signal(signal.SIGUSR2, divein)
|
561 |
+
|
562 |
+
# run
|
563 |
+
if opt.train:
|
564 |
+
try:
|
565 |
+
trainer.fit(model, data)
|
566 |
+
except Exception:
|
567 |
+
melk()
|
568 |
+
raise
|
569 |
+
if not opt.no_test and not trainer.interrupted:
|
570 |
+
trainer.test(model, data)
|
571 |
+
except Exception:
|
572 |
+
if opt.debug and trainer.global_rank==0:
|
573 |
+
try:
|
574 |
+
import pudb as debugger
|
575 |
+
except ImportError:
|
576 |
+
import pdb as debugger
|
577 |
+
debugger.post_mortem()
|
578 |
+
raise
|
579 |
+
finally:
|
580 |
+
# move newly created debug project to debug_runs
|
581 |
+
if opt.debug and not opt.resume and trainer.global_rank==0:
|
582 |
+
dst, name = os.path.split(logdir)
|
583 |
+
dst = os.path.join(dst, "debug_runs", name)
|
584 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
585 |
+
os.rename(logdir, dst)
|
scripts/extract_depth.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import trange
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def get_state(gpu):
|
9 |
+
import torch
|
10 |
+
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
|
11 |
+
if gpu:
|
12 |
+
midas.cuda()
|
13 |
+
midas.eval()
|
14 |
+
|
15 |
+
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
16 |
+
transform = midas_transforms.default_transform
|
17 |
+
|
18 |
+
state = {"model": midas,
|
19 |
+
"transform": transform}
|
20 |
+
return state
|
21 |
+
|
22 |
+
|
23 |
+
def depth_to_rgba(x):
|
24 |
+
assert x.dtype == np.float32
|
25 |
+
assert len(x.shape) == 2
|
26 |
+
y = x.copy()
|
27 |
+
y.dtype = np.uint8
|
28 |
+
y = y.reshape(x.shape+(4,))
|
29 |
+
return np.ascontiguousarray(y)
|
30 |
+
|
31 |
+
|
32 |
+
def rgba_to_depth(x):
|
33 |
+
assert x.dtype == np.uint8
|
34 |
+
assert len(x.shape) == 3 and x.shape[2] == 4
|
35 |
+
y = x.copy()
|
36 |
+
y.dtype = np.float32
|
37 |
+
y = y.reshape(x.shape[:2])
|
38 |
+
return np.ascontiguousarray(y)
|
39 |
+
|
40 |
+
|
41 |
+
def run(x, state):
|
42 |
+
model = state["model"]
|
43 |
+
transform = state["transform"]
|
44 |
+
hw = x.shape[:2]
|
45 |
+
with torch.no_grad():
|
46 |
+
prediction = model(transform((x + 1.0) * 127.5).cuda())
|
47 |
+
prediction = torch.nn.functional.interpolate(
|
48 |
+
prediction.unsqueeze(1),
|
49 |
+
size=hw,
|
50 |
+
mode="bicubic",
|
51 |
+
align_corners=False,
|
52 |
+
).squeeze()
|
53 |
+
output = prediction.cpu().numpy()
|
54 |
+
return output
|
55 |
+
|
56 |
+
|
57 |
+
def get_filename(relpath, level=-2):
|
58 |
+
# save class folder structure and filename:
|
59 |
+
fn = relpath.split(os.sep)[level:]
|
60 |
+
folder = fn[-2]
|
61 |
+
file = fn[-1].split('.')[0]
|
62 |
+
return folder, file
|
63 |
+
|
64 |
+
|
65 |
+
def save_depth(dataset, path, debug=False):
|
66 |
+
os.makedirs(path)
|
67 |
+
N = len(dset)
|
68 |
+
if debug:
|
69 |
+
N = 10
|
70 |
+
state = get_state(gpu=True)
|
71 |
+
for idx in trange(N, desc="Data"):
|
72 |
+
ex = dataset[idx]
|
73 |
+
image, relpath = ex["image"], ex["relpath"]
|
74 |
+
folder, filename = get_filename(relpath)
|
75 |
+
# prepare
|
76 |
+
folderabspath = os.path.join(path, folder)
|
77 |
+
os.makedirs(folderabspath, exist_ok=True)
|
78 |
+
savepath = os.path.join(folderabspath, filename)
|
79 |
+
# run model
|
80 |
+
xout = run(image, state)
|
81 |
+
I = depth_to_rgba(xout)
|
82 |
+
Image.fromarray(I).save("{}.png".format(savepath))
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
from taming.data.imagenet import ImageNetTrain, ImageNetValidation
|
87 |
+
out = "data/imagenet_depth"
|
88 |
+
if not os.path.exists(out):
|
89 |
+
print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
|
90 |
+
"(be prepared that the output size will be larger than ImageNet itself).")
|
91 |
+
exit(1)
|
92 |
+
|
93 |
+
# go
|
94 |
+
dset = ImageNetValidation()
|
95 |
+
abspath = os.path.join(out, "val")
|
96 |
+
if os.path.exists(abspath):
|
97 |
+
print("{} exists - not doing anything.".format(abspath))
|
98 |
+
else:
|
99 |
+
print("preparing {}".format(abspath))
|
100 |
+
save_depth(dset, abspath)
|
101 |
+
print("done with validation split")
|
102 |
+
|
103 |
+
dset = ImageNetTrain()
|
104 |
+
abspath = os.path.join(out, "train")
|
105 |
+
if os.path.exists(abspath):
|
106 |
+
print("{} exists - not doing anything.".format(abspath))
|
107 |
+
else:
|
108 |
+
print("preparing {}".format(abspath))
|
109 |
+
save_depth(dset, abspath)
|
110 |
+
print("done with train split")
|
111 |
+
|
112 |
+
print("done done.")
|
scripts/extract_segmentation.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import numpy as np
|
3 |
+
import scipy
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from scipy import ndimage
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from PIL import Image
|
9 |
+
import torch.hub
|
10 |
+
import torchvision
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
|
14 |
+
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
|
15 |
+
# and put the path here
|
16 |
+
CKPT_PATH = "TODO"
|
17 |
+
|
18 |
+
rescale = lambda x: (x + 1.) / 2.
|
19 |
+
|
20 |
+
def rescale_bgr(x):
|
21 |
+
x = (x+1)*127.5
|
22 |
+
x = torch.flip(x, dims=[0])
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
class COCOStuffSegmenter(nn.Module):
|
27 |
+
def __init__(self, config):
|
28 |
+
super().__init__()
|
29 |
+
self.config = config
|
30 |
+
self.n_labels = 182
|
31 |
+
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
|
32 |
+
ckpt_path = CKPT_PATH
|
33 |
+
model.load_state_dict(torch.load(ckpt_path))
|
34 |
+
self.model = model
|
35 |
+
|
36 |
+
normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
|
37 |
+
self.image_transform = torchvision.transforms.Compose([
|
38 |
+
torchvision.transforms.Lambda(lambda image: torch.stack(
|
39 |
+
[normalize(rescale_bgr(x)) for x in image]))
|
40 |
+
])
|
41 |
+
|
42 |
+
def forward(self, x, upsample=None):
|
43 |
+
x = self._pre_process(x)
|
44 |
+
x = self.model(x)
|
45 |
+
if upsample is not None:
|
46 |
+
x = torch.nn.functional.upsample_bilinear(x, size=upsample)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def _pre_process(self, x):
|
50 |
+
x = self.image_transform(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
@property
|
54 |
+
def mean(self):
|
55 |
+
# bgr
|
56 |
+
return [104.008, 116.669, 122.675]
|
57 |
+
|
58 |
+
@property
|
59 |
+
def std(self):
|
60 |
+
return [1.0, 1.0, 1.0]
|
61 |
+
|
62 |
+
@property
|
63 |
+
def input_size(self):
|
64 |
+
return [3, 224, 224]
|
65 |
+
|
66 |
+
|
67 |
+
def run_model(img, model):
|
68 |
+
model = model.eval()
|
69 |
+
with torch.no_grad():
|
70 |
+
segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
|
71 |
+
segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
|
72 |
+
return segmentation.detach().cpu()
|
73 |
+
|
74 |
+
|
75 |
+
def get_input(batch, k):
|
76 |
+
x = batch[k]
|
77 |
+
if len(x.shape) == 3:
|
78 |
+
x = x[..., None]
|
79 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
80 |
+
return x.float()
|
81 |
+
|
82 |
+
|
83 |
+
def save_segmentation(segmentation, path):
|
84 |
+
# --> class label to uint8, save as png
|
85 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
86 |
+
assert len(segmentation.shape)==4
|
87 |
+
assert segmentation.shape[0]==1
|
88 |
+
for seg in segmentation:
|
89 |
+
seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
|
90 |
+
seg = Image.fromarray(seg)
|
91 |
+
seg.save(path)
|
92 |
+
|
93 |
+
|
94 |
+
def iterate_dataset(dataloader, destpath, model):
|
95 |
+
os.makedirs(destpath, exist_ok=True)
|
96 |
+
num_processed = 0
|
97 |
+
for i, batch in tqdm(enumerate(dataloader), desc="Data"):
|
98 |
+
try:
|
99 |
+
img = get_input(batch, "image")
|
100 |
+
img = img.cuda()
|
101 |
+
seg = run_model(img, model)
|
102 |
+
|
103 |
+
path = batch["relative_file_path_"][0]
|
104 |
+
path = os.path.splitext(path)[0]
|
105 |
+
|
106 |
+
path = os.path.join(destpath, path + ".png")
|
107 |
+
save_segmentation(seg, path)
|
108 |
+
num_processed += 1
|
109 |
+
except Exception as e:
|
110 |
+
print(e)
|
111 |
+
print("but anyhow..")
|
112 |
+
|
113 |
+
print("Processed {} files. Bye.".format(num_processed))
|
114 |
+
|
115 |
+
|
116 |
+
from taming.data.sflckr import Examples
|
117 |
+
from torch.utils.data import DataLoader
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
dest = sys.argv[1]
|
121 |
+
batchsize = 1
|
122 |
+
print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
|
123 |
+
|
124 |
+
model = COCOStuffSegmenter({}).cuda()
|
125 |
+
print("Instantiated model.")
|
126 |
+
|
127 |
+
dataset = Examples()
|
128 |
+
dloader = DataLoader(dataset, batch_size=batchsize)
|
129 |
+
iterate_dataset(dataloader=dloader, destpath=dest, model=model)
|
130 |
+
print("done.")
|
scripts/extract_submodel.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import sys
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
inpath = sys.argv[1]
|
6 |
+
outpath = sys.argv[2]
|
7 |
+
submodel = "cond_stage_model"
|
8 |
+
if len(sys.argv) > 3:
|
9 |
+
submodel = sys.argv[3]
|
10 |
+
|
11 |
+
print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
|
12 |
+
|
13 |
+
sd = torch.load(inpath, map_location="cpu")
|
14 |
+
new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
|
15 |
+
for k,v in sd["state_dict"].items()
|
16 |
+
if k.startswith("cond_stage_model"))}
|
17 |
+
torch.save(new_sd, outpath)
|
scripts/make_samples.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob, math, time
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from PIL import Image
|
6 |
+
from main import instantiate_from_config, DataModuleFromConfig
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch.utils.data.dataloader import default_collate
|
9 |
+
from tqdm import trange
|
10 |
+
|
11 |
+
|
12 |
+
def save_image(x, path):
|
13 |
+
c,h,w = x.shape
|
14 |
+
assert c==3
|
15 |
+
x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
|
16 |
+
Image.fromarray(x).save(path)
|
17 |
+
|
18 |
+
|
19 |
+
@torch.no_grad()
|
20 |
+
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
|
21 |
+
if len(dsets.datasets) > 1:
|
22 |
+
split = sorted(dsets.datasets.keys())[0]
|
23 |
+
dset = dsets.datasets[split]
|
24 |
+
else:
|
25 |
+
dset = next(iter(dsets.datasets.values()))
|
26 |
+
print("Dataset: ", dset.__class__.__name__)
|
27 |
+
for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
|
28 |
+
indices = list(range(start_idx, start_idx+batch_size))
|
29 |
+
example = default_collate([dset[i] for i in indices])
|
30 |
+
|
31 |
+
x = model.get_input("image", example).to(model.device)
|
32 |
+
for i in range(x.shape[0]):
|
33 |
+
save_image(x[i], os.path.join(outdir, "originals",
|
34 |
+
"{:06}.png".format(indices[i])))
|
35 |
+
|
36 |
+
cond_key = model.cond_stage_key
|
37 |
+
c = model.get_input(cond_key, example).to(model.device)
|
38 |
+
|
39 |
+
scale_factor = 1.0
|
40 |
+
quant_z, z_indices = model.encode_to_z(x)
|
41 |
+
quant_c, c_indices = model.encode_to_c(c)
|
42 |
+
|
43 |
+
cshape = quant_z.shape
|
44 |
+
|
45 |
+
xrec = model.first_stage_model.decode(quant_z)
|
46 |
+
for i in range(xrec.shape[0]):
|
47 |
+
save_image(xrec[i], os.path.join(outdir, "reconstructions",
|
48 |
+
"{:06}.png".format(indices[i])))
|
49 |
+
|
50 |
+
if cond_key == "segmentation":
|
51 |
+
# get image from segmentation mask
|
52 |
+
num_classes = c.shape[1]
|
53 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
54 |
+
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
|
55 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
56 |
+
c = model.cond_stage_model.to_rgb(c)
|
57 |
+
|
58 |
+
idx = z_indices
|
59 |
+
|
60 |
+
half_sample = False
|
61 |
+
if half_sample:
|
62 |
+
start = idx.shape[1]//2
|
63 |
+
else:
|
64 |
+
start = 0
|
65 |
+
|
66 |
+
idx[:,start:] = 0
|
67 |
+
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
|
68 |
+
start_i = start//cshape[3]
|
69 |
+
start_j = start %cshape[3]
|
70 |
+
|
71 |
+
cidx = c_indices
|
72 |
+
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
|
73 |
+
|
74 |
+
sample = True
|
75 |
+
|
76 |
+
for i in range(start_i,cshape[2]-0):
|
77 |
+
if i <= 8:
|
78 |
+
local_i = i
|
79 |
+
elif cshape[2]-i < 8:
|
80 |
+
local_i = 16-(cshape[2]-i)
|
81 |
+
else:
|
82 |
+
local_i = 8
|
83 |
+
for j in range(start_j,cshape[3]-0):
|
84 |
+
if j <= 8:
|
85 |
+
local_j = j
|
86 |
+
elif cshape[3]-j < 8:
|
87 |
+
local_j = 16-(cshape[3]-j)
|
88 |
+
else:
|
89 |
+
local_j = 8
|
90 |
+
|
91 |
+
i_start = i-local_i
|
92 |
+
i_end = i_start+16
|
93 |
+
j_start = j-local_j
|
94 |
+
j_end = j_start+16
|
95 |
+
patch = idx[:,i_start:i_end,j_start:j_end]
|
96 |
+
patch = patch.reshape(patch.shape[0],-1)
|
97 |
+
cpatch = cidx[:, i_start:i_end, j_start:j_end]
|
98 |
+
cpatch = cpatch.reshape(cpatch.shape[0], -1)
|
99 |
+
patch = torch.cat((cpatch, patch), dim=1)
|
100 |
+
logits,_ = model.transformer(patch[:,:-1])
|
101 |
+
logits = logits[:, -256:, :]
|
102 |
+
logits = logits.reshape(cshape[0],16,16,-1)
|
103 |
+
logits = logits[:,local_i,local_j,:]
|
104 |
+
|
105 |
+
logits = logits/temperature
|
106 |
+
|
107 |
+
if top_k is not None:
|
108 |
+
logits = model.top_k_logits(logits, top_k)
|
109 |
+
# apply softmax to convert to probabilities
|
110 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
111 |
+
# sample from the distribution or take the most likely
|
112 |
+
if sample:
|
113 |
+
ix = torch.multinomial(probs, num_samples=1)
|
114 |
+
else:
|
115 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
116 |
+
idx[:,i,j] = ix
|
117 |
+
|
118 |
+
xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
119 |
+
for i in range(xsample.shape[0]):
|
120 |
+
save_image(xsample[i], os.path.join(outdir, "samples",
|
121 |
+
"{:06}.png".format(indices[i])))
|
122 |
+
|
123 |
+
|
124 |
+
def get_parser():
|
125 |
+
parser = argparse.ArgumentParser()
|
126 |
+
parser.add_argument(
|
127 |
+
"-r",
|
128 |
+
"--resume",
|
129 |
+
type=str,
|
130 |
+
nargs="?",
|
131 |
+
help="load from logdir or checkpoint in logdir",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"-b",
|
135 |
+
"--base",
|
136 |
+
nargs="*",
|
137 |
+
metavar="base_config.yaml",
|
138 |
+
help="paths to base configs. Loaded from left-to-right. "
|
139 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
140 |
+
default=list(),
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"-c",
|
144 |
+
"--config",
|
145 |
+
nargs="?",
|
146 |
+
metavar="single_config.yaml",
|
147 |
+
help="path to single config. If specified, base configs will be ignored "
|
148 |
+
"(except for the last one if left unspecified).",
|
149 |
+
const=True,
|
150 |
+
default="",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--ignore_base_data",
|
154 |
+
action="store_true",
|
155 |
+
help="Ignore data specification from base configs. Useful if you want "
|
156 |
+
"to specify a custom datasets on the command line.",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--outdir",
|
160 |
+
required=True,
|
161 |
+
type=str,
|
162 |
+
help="Where to write outputs to.",
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--top_k",
|
166 |
+
type=int,
|
167 |
+
default=100,
|
168 |
+
help="Sample from among top-k predictions.",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--temperature",
|
172 |
+
type=float,
|
173 |
+
default=1.0,
|
174 |
+
help="Sampling temperature.",
|
175 |
+
)
|
176 |
+
return parser
|
177 |
+
|
178 |
+
|
179 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
180 |
+
if "ckpt_path" in config.params:
|
181 |
+
print("Deleting the restore-ckpt path from the config...")
|
182 |
+
config.params.ckpt_path = None
|
183 |
+
if "downsample_cond_size" in config.params:
|
184 |
+
print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
|
185 |
+
config.params.downsample_cond_size = -1
|
186 |
+
config.params["downsample_cond_factor"] = 0.5
|
187 |
+
try:
|
188 |
+
if "ckpt_path" in config.params.first_stage_config.params:
|
189 |
+
config.params.first_stage_config.params.ckpt_path = None
|
190 |
+
print("Deleting the first-stage restore-ckpt path from the config...")
|
191 |
+
if "ckpt_path" in config.params.cond_stage_config.params:
|
192 |
+
config.params.cond_stage_config.params.ckpt_path = None
|
193 |
+
print("Deleting the cond-stage restore-ckpt path from the config...")
|
194 |
+
except:
|
195 |
+
pass
|
196 |
+
|
197 |
+
model = instantiate_from_config(config)
|
198 |
+
if sd is not None:
|
199 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
200 |
+
print(f"Missing Keys in State Dict: {missing}")
|
201 |
+
print(f"Unexpected Keys in State Dict: {unexpected}")
|
202 |
+
if gpu:
|
203 |
+
model.cuda()
|
204 |
+
if eval_mode:
|
205 |
+
model.eval()
|
206 |
+
return {"model": model}
|
207 |
+
|
208 |
+
|
209 |
+
def get_data(config):
|
210 |
+
# get data
|
211 |
+
data = instantiate_from_config(config.data)
|
212 |
+
data.prepare_data()
|
213 |
+
data.setup()
|
214 |
+
return data
|
215 |
+
|
216 |
+
|
217 |
+
def load_model_and_dset(config, ckpt, gpu, eval_mode):
|
218 |
+
# get data
|
219 |
+
dsets = get_data(config) # calls data.config ...
|
220 |
+
|
221 |
+
# now load the specified checkpoint
|
222 |
+
if ckpt:
|
223 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
224 |
+
global_step = pl_sd["global_step"]
|
225 |
+
else:
|
226 |
+
pl_sd = {"state_dict": None}
|
227 |
+
global_step = None
|
228 |
+
model = load_model_from_config(config.model,
|
229 |
+
pl_sd["state_dict"],
|
230 |
+
gpu=gpu,
|
231 |
+
eval_mode=eval_mode)["model"]
|
232 |
+
return dsets, model, global_step
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
sys.path.append(os.getcwd())
|
237 |
+
|
238 |
+
parser = get_parser()
|
239 |
+
|
240 |
+
opt, unknown = parser.parse_known_args()
|
241 |
+
|
242 |
+
ckpt = None
|
243 |
+
if opt.resume:
|
244 |
+
if not os.path.exists(opt.resume):
|
245 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
246 |
+
if os.path.isfile(opt.resume):
|
247 |
+
paths = opt.resume.split("/")
|
248 |
+
try:
|
249 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
250 |
+
except ValueError:
|
251 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
252 |
+
logdir = "/".join(paths[:idx])
|
253 |
+
ckpt = opt.resume
|
254 |
+
else:
|
255 |
+
assert os.path.isdir(opt.resume), opt.resume
|
256 |
+
logdir = opt.resume.rstrip("/")
|
257 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
258 |
+
print(f"logdir:{logdir}")
|
259 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
260 |
+
opt.base = base_configs+opt.base
|
261 |
+
|
262 |
+
if opt.config:
|
263 |
+
if type(opt.config) == str:
|
264 |
+
opt.base = [opt.config]
|
265 |
+
else:
|
266 |
+
opt.base = [opt.base[-1]]
|
267 |
+
|
268 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
269 |
+
cli = OmegaConf.from_dotlist(unknown)
|
270 |
+
if opt.ignore_base_data:
|
271 |
+
for config in configs:
|
272 |
+
if hasattr(config, "data"): del config["data"]
|
273 |
+
config = OmegaConf.merge(*configs, cli)
|
274 |
+
|
275 |
+
print(ckpt)
|
276 |
+
gpu = True
|
277 |
+
eval_mode = True
|
278 |
+
show_config = False
|
279 |
+
if show_config:
|
280 |
+
print(OmegaConf.to_container(config))
|
281 |
+
|
282 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
283 |
+
print(f"Global step: {global_step}")
|
284 |
+
|
285 |
+
outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
|
286 |
+
opt.top_k,
|
287 |
+
opt.temperature))
|
288 |
+
os.makedirs(outdir, exist_ok=True)
|
289 |
+
print("Writing samples to ", outdir)
|
290 |
+
for k in ["originals", "reconstructions", "samples"]:
|
291 |
+
os.makedirs(os.path.join(outdir, k), exist_ok=True)
|
292 |
+
run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
|
scripts/make_scene_samples.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from itertools import product
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Literal, List, Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from pytorch_lightning import seed_everything
|
12 |
+
from torch import Tensor
|
13 |
+
from torchvision.utils import save_image
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from scripts.make_samples import get_parser, load_model_and_dset
|
17 |
+
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
|
18 |
+
from taming.data.helper_types import BoundingBox, Annotation
|
19 |
+
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
|
20 |
+
from taming.models.cond_transformer import Net2NetTransformer
|
21 |
+
|
22 |
+
seed_everything(42424242)
|
23 |
+
device: Literal['cuda', 'cpu'] = 'cuda'
|
24 |
+
first_stage_factor = 16
|
25 |
+
trained_on_res = 256
|
26 |
+
|
27 |
+
|
28 |
+
def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
|
29 |
+
assert 0 <= coord < coord_max
|
30 |
+
coord_desired_center = (coord_window - 1) // 2
|
31 |
+
return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
|
32 |
+
|
33 |
+
|
34 |
+
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
|
35 |
+
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
36 |
+
x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
|
37 |
+
y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
|
38 |
+
w = first_stage_factor / WIDTH
|
39 |
+
h = first_stage_factor / HEIGHT
|
40 |
+
return x0, y0, w, h
|
41 |
+
|
42 |
+
|
43 |
+
def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
|
44 |
+
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
45 |
+
x0 = _helper(predict_x, WIDTH, first_stage_factor)
|
46 |
+
y0 = _helper(predict_y, HEIGHT, first_stage_factor)
|
47 |
+
no_images = z_indices.shape[0]
|
48 |
+
cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
|
49 |
+
cut_out_2 = z_indices[:, predict_y, x0:predict_x]
|
50 |
+
return torch.cat((cut_out_1, cut_out_2), dim=1)
|
51 |
+
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
|
55 |
+
conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
|
56 |
+
temperature: float, top_k: int) -> Tensor:
|
57 |
+
x_max, y_max = desired_z_shape[1], desired_z_shape[0]
|
58 |
+
|
59 |
+
annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
|
60 |
+
|
61 |
+
recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
|
62 |
+
if not recompute_conditional:
|
63 |
+
crop_coordinates = get_crop_coordinates(0, 0)
|
64 |
+
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
65 |
+
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
66 |
+
z_indices = torch.zeros((no_samples, 0), device=device).long()
|
67 |
+
output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
|
68 |
+
sample=True, top_k=top_k)
|
69 |
+
else:
|
70 |
+
output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
|
71 |
+
for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
|
72 |
+
crop_coordinates = get_crop_coordinates(predict_x, predict_y)
|
73 |
+
z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
|
74 |
+
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
75 |
+
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
76 |
+
new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
|
77 |
+
output_indices[:, predict_y, predict_x] = new_index[:, -1]
|
78 |
+
z_shape = (
|
79 |
+
no_samples,
|
80 |
+
model.first_stage_model.quantize.e_dim, # codebook embed_dim
|
81 |
+
desired_z_shape[0], # z_height
|
82 |
+
desired_z_shape[1] # z_width
|
83 |
+
)
|
84 |
+
x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
|
85 |
+
x_sample = x_sample.to('cpu')
|
86 |
+
|
87 |
+
plotter = conditional_builder.plot
|
88 |
+
figure_size = (x_sample.shape[2], x_sample.shape[3])
|
89 |
+
scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
|
90 |
+
plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
|
91 |
+
return torch.cat((x_sample, plot.unsqueeze(0)))
|
92 |
+
|
93 |
+
|
94 |
+
def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
|
95 |
+
if not resolution_str.count(',') == 1:
|
96 |
+
raise ValueError("Give resolution as in 'height,width'")
|
97 |
+
res_h, res_w = resolution_str.split(',')
|
98 |
+
res_h = max(int(res_h), trained_on_res)
|
99 |
+
res_w = max(int(res_w), trained_on_res)
|
100 |
+
z_h = int(round(res_h/first_stage_factor))
|
101 |
+
z_w = int(round(res_w/first_stage_factor))
|
102 |
+
return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
|
103 |
+
|
104 |
+
|
105 |
+
def add_arg_to_parser(parser):
|
106 |
+
parser.add_argument(
|
107 |
+
"-R",
|
108 |
+
"--resolution",
|
109 |
+
type=str,
|
110 |
+
default='256,256',
|
111 |
+
help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"-C",
|
115 |
+
"--conditional",
|
116 |
+
type=str,
|
117 |
+
default='objects_bbox',
|
118 |
+
help=f"objects_bbox or objects_center_points",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"-N",
|
122 |
+
"--n_samples_per_layout",
|
123 |
+
type=int,
|
124 |
+
default=4,
|
125 |
+
help=f"how many samples to generate per layout",
|
126 |
+
)
|
127 |
+
return parser
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
sys.path.append(os.getcwd())
|
132 |
+
|
133 |
+
parser = get_parser()
|
134 |
+
parser = add_arg_to_parser(parser)
|
135 |
+
|
136 |
+
opt, unknown = parser.parse_known_args()
|
137 |
+
|
138 |
+
ckpt = None
|
139 |
+
if opt.resume:
|
140 |
+
if not os.path.exists(opt.resume):
|
141 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
142 |
+
if os.path.isfile(opt.resume):
|
143 |
+
paths = opt.resume.split("/")
|
144 |
+
try:
|
145 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
146 |
+
except ValueError:
|
147 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
148 |
+
logdir = "/".join(paths[:idx])
|
149 |
+
ckpt = opt.resume
|
150 |
+
else:
|
151 |
+
assert os.path.isdir(opt.resume), opt.resume
|
152 |
+
logdir = opt.resume.rstrip("/")
|
153 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
154 |
+
print(f"logdir:{logdir}")
|
155 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
156 |
+
opt.base = base_configs+opt.base
|
157 |
+
|
158 |
+
if opt.config:
|
159 |
+
if type(opt.config) == str:
|
160 |
+
opt.base = [opt.config]
|
161 |
+
else:
|
162 |
+
opt.base = [opt.base[-1]]
|
163 |
+
|
164 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
165 |
+
cli = OmegaConf.from_dotlist(unknown)
|
166 |
+
if opt.ignore_base_data:
|
167 |
+
for config in configs:
|
168 |
+
if hasattr(config, "data"):
|
169 |
+
del config["data"]
|
170 |
+
config = OmegaConf.merge(*configs, cli)
|
171 |
+
desired_z_shape, desired_resolution = get_resolution(opt.resolution)
|
172 |
+
conditional = opt.conditional
|
173 |
+
|
174 |
+
print(ckpt)
|
175 |
+
gpu = True
|
176 |
+
eval_mode = True
|
177 |
+
show_config = False
|
178 |
+
if show_config:
|
179 |
+
print(OmegaConf.to_container(config))
|
180 |
+
|
181 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
182 |
+
print(f"Global step: {global_step}")
|
183 |
+
|
184 |
+
data_loader = dsets.val_dataloader()
|
185 |
+
print(dsets.datasets["validation"].conditional_builders)
|
186 |
+
conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
|
187 |
+
|
188 |
+
outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
|
189 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
190 |
+
print("Writing samples to ", outdir)
|
191 |
+
|
192 |
+
p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
|
193 |
+
for batch_no, batch in p_bar_1:
|
194 |
+
save_img: Optional[Tensor] = None
|
195 |
+
for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
|
196 |
+
imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
|
197 |
+
opt.n_samples_per_layout, opt.temperature, opt.top_k)
|
198 |
+
save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
|
scripts/sample_conditional.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob, math, time
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit import caching
|
7 |
+
from PIL import Image
|
8 |
+
from main import instantiate_from_config, DataModuleFromConfig
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from torch.utils.data.dataloader import default_collate
|
11 |
+
|
12 |
+
|
13 |
+
rescale = lambda x: (x + 1.) / 2.
|
14 |
+
|
15 |
+
|
16 |
+
def bchw_to_st(x):
|
17 |
+
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
|
18 |
+
|
19 |
+
def save_img(xstart, fname):
|
20 |
+
I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
|
21 |
+
Image.fromarray(I).save(fname)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def get_interactive_image(resize=False):
|
26 |
+
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
27 |
+
if image is not None:
|
28 |
+
image = Image.open(image)
|
29 |
+
if not image.mode == "RGB":
|
30 |
+
image = image.convert("RGB")
|
31 |
+
image = np.array(image).astype(np.uint8)
|
32 |
+
print("upload image shape: {}".format(image.shape))
|
33 |
+
img = Image.fromarray(image)
|
34 |
+
if resize:
|
35 |
+
img = img.resize((256, 256))
|
36 |
+
image = np.array(img)
|
37 |
+
return image
|
38 |
+
|
39 |
+
|
40 |
+
def single_image_to_torch(x, permute=True):
|
41 |
+
assert x is not None, "Please provide an image through the upload function"
|
42 |
+
x = np.array(x)
|
43 |
+
x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
|
44 |
+
if permute:
|
45 |
+
x = x.permute(0, 3, 1, 2)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
def pad_to_M(x, M):
|
50 |
+
hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
|
51 |
+
wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
|
52 |
+
x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
|
53 |
+
return x
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
def run_conditional(model, dsets):
|
57 |
+
if len(dsets.datasets) > 1:
|
58 |
+
split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
|
59 |
+
dset = dsets.datasets[split]
|
60 |
+
else:
|
61 |
+
dset = next(iter(dsets.datasets.values()))
|
62 |
+
batch_size = 1
|
63 |
+
start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
|
64 |
+
min_value=0,
|
65 |
+
max_value=len(dset)-batch_size)
|
66 |
+
indices = list(range(start_index, start_index+batch_size))
|
67 |
+
|
68 |
+
example = default_collate([dset[i] for i in indices])
|
69 |
+
|
70 |
+
x = model.get_input("image", example).to(model.device)
|
71 |
+
|
72 |
+
cond_key = model.cond_stage_key
|
73 |
+
c = model.get_input(cond_key, example).to(model.device)
|
74 |
+
|
75 |
+
scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
|
76 |
+
if scale_factor != 1.0:
|
77 |
+
x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
|
78 |
+
c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
|
79 |
+
|
80 |
+
quant_z, z_indices = model.encode_to_z(x)
|
81 |
+
quant_c, c_indices = model.encode_to_c(c)
|
82 |
+
|
83 |
+
cshape = quant_z.shape
|
84 |
+
|
85 |
+
xrec = model.first_stage_model.decode(quant_z)
|
86 |
+
st.write("image: {}".format(x.shape))
|
87 |
+
st.image(bchw_to_st(x), clamp=True, output_format="PNG")
|
88 |
+
st.write("image reconstruction: {}".format(xrec.shape))
|
89 |
+
st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
|
90 |
+
|
91 |
+
if cond_key == "segmentation":
|
92 |
+
# get image from segmentation mask
|
93 |
+
num_classes = c.shape[1]
|
94 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
95 |
+
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
|
96 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
97 |
+
c = model.cond_stage_model.to_rgb(c)
|
98 |
+
|
99 |
+
st.write(f"{cond_key}: {tuple(c.shape)}")
|
100 |
+
st.image(bchw_to_st(c), clamp=True, output_format="PNG")
|
101 |
+
|
102 |
+
idx = z_indices
|
103 |
+
|
104 |
+
half_sample = st.sidebar.checkbox("Image Completion", value=False)
|
105 |
+
if half_sample:
|
106 |
+
start = idx.shape[1]//2
|
107 |
+
else:
|
108 |
+
start = 0
|
109 |
+
|
110 |
+
idx[:,start:] = 0
|
111 |
+
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
|
112 |
+
start_i = start//cshape[3]
|
113 |
+
start_j = start %cshape[3]
|
114 |
+
|
115 |
+
if not half_sample and quant_z.shape == quant_c.shape:
|
116 |
+
st.info("Setting idx to c_indices")
|
117 |
+
idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
|
118 |
+
|
119 |
+
cidx = c_indices
|
120 |
+
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
|
121 |
+
|
122 |
+
xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
123 |
+
st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
|
124 |
+
|
125 |
+
temperature = st.number_input("Temperature", value=1.0)
|
126 |
+
top_k = st.number_input("Top k", value=100)
|
127 |
+
sample = st.checkbox("Sample", value=True)
|
128 |
+
update_every = st.number_input("Update every", value=75)
|
129 |
+
|
130 |
+
st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
|
131 |
+
|
132 |
+
animate = st.checkbox("animate")
|
133 |
+
if animate:
|
134 |
+
import imageio
|
135 |
+
outvid = "sampling.mp4"
|
136 |
+
writer = imageio.get_writer(outvid, fps=25)
|
137 |
+
elapsed_t = st.empty()
|
138 |
+
info = st.empty()
|
139 |
+
st.text("Sampled")
|
140 |
+
if st.button("Sample"):
|
141 |
+
output = st.empty()
|
142 |
+
start_t = time.time()
|
143 |
+
for i in range(start_i,cshape[2]-0):
|
144 |
+
if i <= 8:
|
145 |
+
local_i = i
|
146 |
+
elif cshape[2]-i < 8:
|
147 |
+
local_i = 16-(cshape[2]-i)
|
148 |
+
else:
|
149 |
+
local_i = 8
|
150 |
+
for j in range(start_j,cshape[3]-0):
|
151 |
+
if j <= 8:
|
152 |
+
local_j = j
|
153 |
+
elif cshape[3]-j < 8:
|
154 |
+
local_j = 16-(cshape[3]-j)
|
155 |
+
else:
|
156 |
+
local_j = 8
|
157 |
+
|
158 |
+
i_start = i-local_i
|
159 |
+
i_end = i_start+16
|
160 |
+
j_start = j-local_j
|
161 |
+
j_end = j_start+16
|
162 |
+
elapsed_t.text(f"Time: {time.time() - start_t} seconds")
|
163 |
+
info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
|
164 |
+
patch = idx[:,i_start:i_end,j_start:j_end]
|
165 |
+
patch = patch.reshape(patch.shape[0],-1)
|
166 |
+
cpatch = cidx[:, i_start:i_end, j_start:j_end]
|
167 |
+
cpatch = cpatch.reshape(cpatch.shape[0], -1)
|
168 |
+
patch = torch.cat((cpatch, patch), dim=1)
|
169 |
+
logits,_ = model.transformer(patch[:,:-1])
|
170 |
+
logits = logits[:, -256:, :]
|
171 |
+
logits = logits.reshape(cshape[0],16,16,-1)
|
172 |
+
logits = logits[:,local_i,local_j,:]
|
173 |
+
|
174 |
+
logits = logits/temperature
|
175 |
+
|
176 |
+
if top_k is not None:
|
177 |
+
logits = model.top_k_logits(logits, top_k)
|
178 |
+
# apply softmax to convert to probabilities
|
179 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
180 |
+
# sample from the distribution or take the most likely
|
181 |
+
if sample:
|
182 |
+
ix = torch.multinomial(probs, num_samples=1)
|
183 |
+
else:
|
184 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
185 |
+
idx[:,i,j] = ix
|
186 |
+
|
187 |
+
if (i*cshape[3]+j)%update_every==0:
|
188 |
+
xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
|
189 |
+
|
190 |
+
xstart = bchw_to_st(xstart)
|
191 |
+
output.image(xstart, clamp=True, output_format="PNG")
|
192 |
+
|
193 |
+
if animate:
|
194 |
+
writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
|
195 |
+
|
196 |
+
xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
197 |
+
xstart = bchw_to_st(xstart)
|
198 |
+
output.image(xstart, clamp=True, output_format="PNG")
|
199 |
+
#save_img(xstart, "full_res_sample.png")
|
200 |
+
if animate:
|
201 |
+
writer.close()
|
202 |
+
st.video(outvid)
|
203 |
+
|
204 |
+
|
205 |
+
def get_parser():
|
206 |
+
parser = argparse.ArgumentParser()
|
207 |
+
parser.add_argument(
|
208 |
+
"-r",
|
209 |
+
"--resume",
|
210 |
+
type=str,
|
211 |
+
nargs="?",
|
212 |
+
help="load from logdir or checkpoint in logdir",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"-b",
|
216 |
+
"--base",
|
217 |
+
nargs="*",
|
218 |
+
metavar="base_config.yaml",
|
219 |
+
help="paths to base configs. Loaded from left-to-right. "
|
220 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
221 |
+
default=list(),
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"-c",
|
225 |
+
"--config",
|
226 |
+
nargs="?",
|
227 |
+
metavar="single_config.yaml",
|
228 |
+
help="path to single config. If specified, base configs will be ignored "
|
229 |
+
"(except for the last one if left unspecified).",
|
230 |
+
const=True,
|
231 |
+
default="",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--ignore_base_data",
|
235 |
+
action="store_true",
|
236 |
+
help="Ignore data specification from base configs. Useful if you want "
|
237 |
+
"to specify a custom datasets on the command line.",
|
238 |
+
)
|
239 |
+
return parser
|
240 |
+
|
241 |
+
|
242 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
243 |
+
if "ckpt_path" in config.params:
|
244 |
+
st.warning("Deleting the restore-ckpt path from the config...")
|
245 |
+
config.params.ckpt_path = None
|
246 |
+
if "downsample_cond_size" in config.params:
|
247 |
+
st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
|
248 |
+
config.params.downsample_cond_size = -1
|
249 |
+
config.params["downsample_cond_factor"] = 0.5
|
250 |
+
try:
|
251 |
+
if "ckpt_path" in config.params.first_stage_config.params:
|
252 |
+
config.params.first_stage_config.params.ckpt_path = None
|
253 |
+
st.warning("Deleting the first-stage restore-ckpt path from the config...")
|
254 |
+
if "ckpt_path" in config.params.cond_stage_config.params:
|
255 |
+
config.params.cond_stage_config.params.ckpt_path = None
|
256 |
+
st.warning("Deleting the cond-stage restore-ckpt path from the config...")
|
257 |
+
except:
|
258 |
+
pass
|
259 |
+
|
260 |
+
model = instantiate_from_config(config)
|
261 |
+
if sd is not None:
|
262 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
263 |
+
st.info(f"Missing Keys in State Dict: {missing}")
|
264 |
+
st.info(f"Unexpected Keys in State Dict: {unexpected}")
|
265 |
+
if gpu:
|
266 |
+
model.cuda()
|
267 |
+
if eval_mode:
|
268 |
+
model.eval()
|
269 |
+
return {"model": model}
|
270 |
+
|
271 |
+
|
272 |
+
def get_data(config):
|
273 |
+
# get data
|
274 |
+
data = instantiate_from_config(config.data)
|
275 |
+
data.prepare_data()
|
276 |
+
data.setup()
|
277 |
+
return data
|
278 |
+
|
279 |
+
|
280 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
|
281 |
+
def load_model_and_dset(config, ckpt, gpu, eval_mode):
|
282 |
+
# get data
|
283 |
+
dsets = get_data(config) # calls data.config ...
|
284 |
+
|
285 |
+
# now load the specified checkpoint
|
286 |
+
if ckpt:
|
287 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
288 |
+
global_step = pl_sd["global_step"]
|
289 |
+
else:
|
290 |
+
pl_sd = {"state_dict": None}
|
291 |
+
global_step = None
|
292 |
+
model = load_model_from_config(config.model,
|
293 |
+
pl_sd["state_dict"],
|
294 |
+
gpu=gpu,
|
295 |
+
eval_mode=eval_mode)["model"]
|
296 |
+
return dsets, model, global_step
|
297 |
+
|
298 |
+
|
299 |
+
if __name__ == "__main__":
|
300 |
+
sys.path.append(os.getcwd())
|
301 |
+
|
302 |
+
parser = get_parser()
|
303 |
+
|
304 |
+
opt, unknown = parser.parse_known_args()
|
305 |
+
|
306 |
+
ckpt = None
|
307 |
+
if opt.resume:
|
308 |
+
if not os.path.exists(opt.resume):
|
309 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
310 |
+
if os.path.isfile(opt.resume):
|
311 |
+
paths = opt.resume.split("/")
|
312 |
+
try:
|
313 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
314 |
+
except ValueError:
|
315 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
316 |
+
logdir = "/".join(paths[:idx])
|
317 |
+
ckpt = opt.resume
|
318 |
+
else:
|
319 |
+
assert os.path.isdir(opt.resume), opt.resume
|
320 |
+
logdir = opt.resume.rstrip("/")
|
321 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
322 |
+
print(f"logdir:{logdir}")
|
323 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
324 |
+
opt.base = base_configs+opt.base
|
325 |
+
|
326 |
+
if opt.config:
|
327 |
+
if type(opt.config) == str:
|
328 |
+
opt.base = [opt.config]
|
329 |
+
else:
|
330 |
+
opt.base = [opt.base[-1]]
|
331 |
+
|
332 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
333 |
+
cli = OmegaConf.from_dotlist(unknown)
|
334 |
+
if opt.ignore_base_data:
|
335 |
+
for config in configs:
|
336 |
+
if hasattr(config, "data"): del config["data"]
|
337 |
+
config = OmegaConf.merge(*configs, cli)
|
338 |
+
|
339 |
+
st.sidebar.text(ckpt)
|
340 |
+
gs = st.sidebar.empty()
|
341 |
+
gs.text(f"Global step: ?")
|
342 |
+
st.sidebar.text("Options")
|
343 |
+
#gpu = st.sidebar.checkbox("GPU", value=True)
|
344 |
+
gpu = True
|
345 |
+
#eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
|
346 |
+
eval_mode = True
|
347 |
+
#show_config = st.sidebar.checkbox("Show Config", value=False)
|
348 |
+
show_config = False
|
349 |
+
if show_config:
|
350 |
+
st.info("Checkpoint: {}".format(ckpt))
|
351 |
+
st.json(OmegaConf.to_container(config))
|
352 |
+
|
353 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
354 |
+
gs.text(f"Global step: {global_step}")
|
355 |
+
run_conditional(model, dsets)
|
scripts/sample_fast.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob
|
2 |
+
import torch
|
3 |
+
import time
|
4 |
+
import numpy as np
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from einops import repeat
|
9 |
+
|
10 |
+
from main import instantiate_from_config
|
11 |
+
from taming.modules.transformer.mingpt import sample_with_past
|
12 |
+
|
13 |
+
|
14 |
+
rescale = lambda x: (x + 1.) / 2.
|
15 |
+
|
16 |
+
|
17 |
+
def chw_to_pillow(x):
|
18 |
+
return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
|
23 |
+
dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
|
24 |
+
log = dict()
|
25 |
+
assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
|
26 |
+
qzshape = [batch_size, dim_z, h, w]
|
27 |
+
assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
|
28 |
+
c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
|
29 |
+
t1 = time.time()
|
30 |
+
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
31 |
+
sample_logits=True, top_k=top_k, callback=callback,
|
32 |
+
temperature=temperature, top_p=top_p)
|
33 |
+
if verbose_time:
|
34 |
+
sampling_time = time.time() - t1
|
35 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
36 |
+
x_sample = model.decode_to_img(index_sample, qzshape)
|
37 |
+
log["samples"] = x_sample
|
38 |
+
log["class_label"] = c_indices
|
39 |
+
return log
|
40 |
+
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
|
44 |
+
dim_z=256, h=16, w=16, verbose_time=False):
|
45 |
+
log = dict()
|
46 |
+
qzshape = [batch_size, dim_z, h, w]
|
47 |
+
assert model.be_unconditional, 'Expecting an unconditional model.'
|
48 |
+
c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
|
49 |
+
t1 = time.time()
|
50 |
+
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
51 |
+
sample_logits=True, top_k=top_k, callback=callback,
|
52 |
+
temperature=temperature, top_p=top_p)
|
53 |
+
if verbose_time:
|
54 |
+
sampling_time = time.time() - t1
|
55 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
56 |
+
x_sample = model.decode_to_img(index_sample, qzshape)
|
57 |
+
log["samples"] = x_sample
|
58 |
+
return log
|
59 |
+
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
|
63 |
+
given_classes=None, top_p=None):
|
64 |
+
batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
|
65 |
+
if not unconditional:
|
66 |
+
assert given_classes is not None
|
67 |
+
print("Running in pure class-conditional sampling mode. I will produce "
|
68 |
+
f"{num_samples} samples for each of the {len(given_classes)} classes, "
|
69 |
+
f"i.e. {num_samples*len(given_classes)} in total.")
|
70 |
+
for class_label in tqdm(given_classes, desc="Classes"):
|
71 |
+
for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
|
72 |
+
if bs == 0: break
|
73 |
+
logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
|
74 |
+
temperature=temperature, top_k=top_k, top_p=top_p)
|
75 |
+
save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
|
76 |
+
else:
|
77 |
+
print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
|
78 |
+
for n, bs in tqdm(enumerate(batches), desc="Sampling"):
|
79 |
+
if bs == 0: break
|
80 |
+
logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
|
81 |
+
save_from_logs(logs, logdir, base_count=n * batch_size)
|
82 |
+
|
83 |
+
|
84 |
+
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
|
85 |
+
xx = logs[key]
|
86 |
+
for i, x in enumerate(xx):
|
87 |
+
x = chw_to_pillow(x)
|
88 |
+
count = base_count + i
|
89 |
+
if cond_key is None:
|
90 |
+
x.save(os.path.join(logdir, f"{count:06}.png"))
|
91 |
+
else:
|
92 |
+
condlabel = cond_key[i]
|
93 |
+
if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
|
94 |
+
os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
|
95 |
+
x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
|
96 |
+
|
97 |
+
|
98 |
+
def get_parser():
|
99 |
+
def str2bool(v):
|
100 |
+
if isinstance(v, bool):
|
101 |
+
return v
|
102 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
103 |
+
return True
|
104 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
105 |
+
return False
|
106 |
+
else:
|
107 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
108 |
+
|
109 |
+
parser = argparse.ArgumentParser()
|
110 |
+
parser.add_argument(
|
111 |
+
"-r",
|
112 |
+
"--resume",
|
113 |
+
type=str,
|
114 |
+
nargs="?",
|
115 |
+
help="load from logdir or checkpoint in logdir",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"-o",
|
119 |
+
"--outdir",
|
120 |
+
type=str,
|
121 |
+
nargs="?",
|
122 |
+
help="path where the samples will be logged to.",
|
123 |
+
default=""
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"-b",
|
127 |
+
"--base",
|
128 |
+
nargs="*",
|
129 |
+
metavar="base_config.yaml",
|
130 |
+
help="paths to base configs. Loaded from left-to-right. "
|
131 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
132 |
+
default=list(),
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"-n",
|
136 |
+
"--num_samples",
|
137 |
+
type=int,
|
138 |
+
nargs="?",
|
139 |
+
help="num_samples to draw",
|
140 |
+
default=50000
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--batch_size",
|
144 |
+
type=int,
|
145 |
+
nargs="?",
|
146 |
+
help="the batch size",
|
147 |
+
default=25
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"-k",
|
151 |
+
"--top_k",
|
152 |
+
type=int,
|
153 |
+
nargs="?",
|
154 |
+
help="top-k value to sample with",
|
155 |
+
default=250,
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"-t",
|
159 |
+
"--temperature",
|
160 |
+
type=float,
|
161 |
+
nargs="?",
|
162 |
+
help="temperature value to sample with",
|
163 |
+
default=1.0
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"-p",
|
167 |
+
"--top_p",
|
168 |
+
type=float,
|
169 |
+
nargs="?",
|
170 |
+
help="top-p value to sample with",
|
171 |
+
default=1.0
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--classes",
|
175 |
+
type=str,
|
176 |
+
nargs="?",
|
177 |
+
help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
|
178 |
+
default="imagenet"
|
179 |
+
)
|
180 |
+
return parser
|
181 |
+
|
182 |
+
|
183 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
184 |
+
model = instantiate_from_config(config)
|
185 |
+
if sd is not None:
|
186 |
+
model.load_state_dict(sd)
|
187 |
+
if gpu:
|
188 |
+
model.cuda()
|
189 |
+
if eval_mode:
|
190 |
+
model.eval()
|
191 |
+
return {"model": model}
|
192 |
+
|
193 |
+
|
194 |
+
def load_model(config, ckpt, gpu, eval_mode):
|
195 |
+
# load the specified checkpoint
|
196 |
+
if ckpt:
|
197 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
198 |
+
global_step = pl_sd["global_step"]
|
199 |
+
print(f"loaded model from global step {global_step}.")
|
200 |
+
else:
|
201 |
+
pl_sd = {"state_dict": None}
|
202 |
+
global_step = None
|
203 |
+
model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
|
204 |
+
return model, global_step
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
sys.path.append(os.getcwd())
|
209 |
+
parser = get_parser()
|
210 |
+
|
211 |
+
opt, unknown = parser.parse_known_args()
|
212 |
+
assert opt.resume
|
213 |
+
|
214 |
+
ckpt = None
|
215 |
+
|
216 |
+
if not os.path.exists(opt.resume):
|
217 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
218 |
+
if os.path.isfile(opt.resume):
|
219 |
+
paths = opt.resume.split("/")
|
220 |
+
try:
|
221 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
222 |
+
except ValueError:
|
223 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
224 |
+
logdir = "/".join(paths[:idx])
|
225 |
+
ckpt = opt.resume
|
226 |
+
else:
|
227 |
+
assert os.path.isdir(opt.resume), opt.resume
|
228 |
+
logdir = opt.resume.rstrip("/")
|
229 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
230 |
+
|
231 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
232 |
+
opt.base = base_configs+opt.base
|
233 |
+
|
234 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
235 |
+
cli = OmegaConf.from_dotlist(unknown)
|
236 |
+
config = OmegaConf.merge(*configs, cli)
|
237 |
+
|
238 |
+
model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
|
239 |
+
|
240 |
+
if opt.outdir:
|
241 |
+
print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
|
242 |
+
logdir = opt.outdir
|
243 |
+
|
244 |
+
if opt.classes == "imagenet":
|
245 |
+
given_classes = [i for i in range(1000)]
|
246 |
+
else:
|
247 |
+
cls_str = opt.classes
|
248 |
+
assert not cls_str.endswith(","), 'class string should not end with a ","'
|
249 |
+
given_classes = [int(c) for c in cls_str.split(",")]
|
250 |
+
|
251 |
+
logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
|
252 |
+
f"{global_step}")
|
253 |
+
|
254 |
+
print(f"Logging to {logdir}")
|
255 |
+
os.makedirs(logdir, exist_ok=True)
|
256 |
+
|
257 |
+
run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
|
258 |
+
given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
|
259 |
+
|
260 |
+
print("done.")
|
scripts/taming-transformers.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
setup.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='taming-transformers',
|
5 |
+
version='0.0.1',
|
6 |
+
description='Taming Transformers for High-Resolution Image Synthesis',
|
7 |
+
packages=find_packages(),
|
8 |
+
install_requires=[
|
9 |
+
'torch',
|
10 |
+
'numpy',
|
11 |
+
'tqdm',
|
12 |
+
],
|
13 |
+
)
|
taming/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
Binary file (2.19 kB). View file
|
|
taming/__pycache__/util.cpython-312.pyc
ADDED
Binary file (6.33 kB). View file
|
|
taming/data/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import urllib
|
5 |
+
import zipfile
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from taming.data.helper_types import Annotation
|
11 |
+
#from torch._six import string_classes
|
12 |
+
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
string_classes = (str,bytes)
|
16 |
+
|
17 |
+
|
18 |
+
def unpack(path):
|
19 |
+
if path.endswith("tar.gz"):
|
20 |
+
with tarfile.open(path, "r:gz") as tar:
|
21 |
+
tar.extractall(path=os.path.split(path)[0])
|
22 |
+
elif path.endswith("tar"):
|
23 |
+
with tarfile.open(path, "r:") as tar:
|
24 |
+
tar.extractall(path=os.path.split(path)[0])
|
25 |
+
elif path.endswith("zip"):
|
26 |
+
with zipfile.ZipFile(path, "r") as f:
|
27 |
+
f.extractall(path=os.path.split(path)[0])
|
28 |
+
else:
|
29 |
+
raise NotImplementedError(
|
30 |
+
"Unknown file extension: {}".format(os.path.splitext(path)[1])
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def reporthook(bar):
|
35 |
+
"""tqdm progress bar for downloads."""
|
36 |
+
|
37 |
+
def hook(b=1, bsize=1, tsize=None):
|
38 |
+
if tsize is not None:
|
39 |
+
bar.total = tsize
|
40 |
+
bar.update(b * bsize - bar.n)
|
41 |
+
|
42 |
+
return hook
|
43 |
+
|
44 |
+
|
45 |
+
def get_root(name):
|
46 |
+
base = "data/"
|
47 |
+
root = os.path.join(base, name)
|
48 |
+
os.makedirs(root, exist_ok=True)
|
49 |
+
return root
|
50 |
+
|
51 |
+
|
52 |
+
def is_prepared(root):
|
53 |
+
return Path(root).joinpath(".ready").exists()
|
54 |
+
|
55 |
+
|
56 |
+
def mark_prepared(root):
|
57 |
+
Path(root).joinpath(".ready").touch()
|
58 |
+
|
59 |
+
|
60 |
+
def prompt_download(file_, source, target_dir, content_dir=None):
|
61 |
+
targetpath = os.path.join(target_dir, file_)
|
62 |
+
while not os.path.exists(targetpath):
|
63 |
+
if content_dir is not None and os.path.exists(
|
64 |
+
os.path.join(target_dir, content_dir)
|
65 |
+
):
|
66 |
+
break
|
67 |
+
print(
|
68 |
+
"Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
|
69 |
+
)
|
70 |
+
if content_dir is not None:
|
71 |
+
print(
|
72 |
+
"Or place its content into '{}'.".format(
|
73 |
+
os.path.join(target_dir, content_dir)
|
74 |
+
)
|
75 |
+
)
|
76 |
+
input("Press Enter when done...")
|
77 |
+
return targetpath
|
78 |
+
|
79 |
+
|
80 |
+
def download_url(file_, url, target_dir):
|
81 |
+
targetpath = os.path.join(target_dir, file_)
|
82 |
+
os.makedirs(target_dir, exist_ok=True)
|
83 |
+
with tqdm(
|
84 |
+
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
|
85 |
+
) as bar:
|
86 |
+
urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
|
87 |
+
return targetpath
|
88 |
+
|
89 |
+
|
90 |
+
def download_urls(urls, target_dir):
|
91 |
+
paths = dict()
|
92 |
+
for fname, url in urls.items():
|
93 |
+
outpath = download_url(fname, url, target_dir)
|
94 |
+
paths[fname] = outpath
|
95 |
+
return paths
|
96 |
+
|
97 |
+
|
98 |
+
def quadratic_crop(x, bbox, alpha=1.0):
|
99 |
+
"""bbox is xmin, ymin, xmax, ymax"""
|
100 |
+
im_h, im_w = x.shape[:2]
|
101 |
+
bbox = np.array(bbox, dtype=np.float32)
|
102 |
+
bbox = np.clip(bbox, 0, max(im_h, im_w))
|
103 |
+
center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
|
104 |
+
w = bbox[2] - bbox[0]
|
105 |
+
h = bbox[3] - bbox[1]
|
106 |
+
l = int(alpha * max(w, h))
|
107 |
+
l = max(l, 2)
|
108 |
+
|
109 |
+
required_padding = -1 * min(
|
110 |
+
center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
|
111 |
+
)
|
112 |
+
required_padding = int(np.ceil(required_padding))
|
113 |
+
if required_padding > 0:
|
114 |
+
padding = [
|
115 |
+
[required_padding, required_padding],
|
116 |
+
[required_padding, required_padding],
|
117 |
+
]
|
118 |
+
padding += [[0, 0]] * (len(x.shape) - 2)
|
119 |
+
x = np.pad(x, padding, "reflect")
|
120 |
+
center = center[0] + required_padding, center[1] + required_padding
|
121 |
+
xmin = int(center[0] - l / 2)
|
122 |
+
ymin = int(center[1] - l / 2)
|
123 |
+
return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
|
124 |
+
|
125 |
+
|
126 |
+
def custom_collate(batch):
|
127 |
+
r"""source: pytorch 1.9.0, only one modification to original code """
|
128 |
+
|
129 |
+
elem = batch[0]
|
130 |
+
elem_type = type(elem)
|
131 |
+
if isinstance(elem, torch.Tensor):
|
132 |
+
out = None
|
133 |
+
if torch.utils.data.get_worker_info() is not None:
|
134 |
+
# If we're in a background process, concatenate directly into a
|
135 |
+
# shared memory tensor to avoid an extra copy
|
136 |
+
numel = sum([x.numel() for x in batch])
|
137 |
+
storage = elem.storage()._new_shared(numel)
|
138 |
+
out = elem.new(storage)
|
139 |
+
return torch.stack(batch, 0, out=out)
|
140 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
141 |
+
and elem_type.__name__ != 'string_':
|
142 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
143 |
+
# array of string classes and object
|
144 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
145 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
146 |
+
|
147 |
+
return custom_collate([torch.as_tensor(b) for b in batch])
|
148 |
+
elif elem.shape == (): # scalars
|
149 |
+
return torch.as_tensor(batch)
|
150 |
+
elif isinstance(elem, float):
|
151 |
+
return torch.tensor(batch, dtype=torch.float64)
|
152 |
+
elif isinstance(elem, int):
|
153 |
+
return torch.tensor(batch)
|
154 |
+
elif isinstance(elem, string_classes):
|
155 |
+
return batch
|
156 |
+
elif isinstance(elem, collections.abc.Mapping):
|
157 |
+
return {key: custom_collate([d[key] for d in batch]) for key in elem}
|
158 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
159 |
+
return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
|
160 |
+
if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
|
161 |
+
return batch # added
|
162 |
+
elif isinstance(elem, collections.abc.Sequence):
|
163 |
+
# check to make sure that the elements in batch have consistent size
|
164 |
+
it = iter(batch)
|
165 |
+
elem_size = len(next(it))
|
166 |
+
if not all(len(elem) == elem_size for elem in it):
|
167 |
+
raise RuntimeError('each element in list of batch should be of equal size')
|
168 |
+
transposed = zip(*batch)
|
169 |
+
return [custom_collate(samples) for samples in transposed]
|
170 |
+
|
171 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
taming/data/__pycache__/helper_types.cpython-312.pyc
ADDED
Binary file (2.43 kB). View file
|
|
taming/data/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (10.6 kB). View file
|
|
taming/data/ade20k.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import albumentations
|
5 |
+
from PIL import Image
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
|
8 |
+
from taming.data.sflckr import SegmentationBase # for examples included in repo
|
9 |
+
|
10 |
+
|
11 |
+
class Examples(SegmentationBase):
|
12 |
+
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
|
13 |
+
super().__init__(data_csv="data/ade20k_examples.txt",
|
14 |
+
data_root="data/ade20k_images",
|
15 |
+
segmentation_root="data/ade20k_segmentations",
|
16 |
+
size=size, random_crop=random_crop,
|
17 |
+
interpolation=interpolation,
|
18 |
+
n_labels=151, shift_segmentation=False)
|
19 |
+
|
20 |
+
|
21 |
+
# With semantic map and scene label
|
22 |
+
class ADE20kBase(Dataset):
|
23 |
+
def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
|
24 |
+
self.split = self.get_split()
|
25 |
+
self.n_labels = 151 # unknown + 150
|
26 |
+
self.data_csv = {"train": "data/ade20k_train.txt",
|
27 |
+
"validation": "data/ade20k_test.txt"}[self.split]
|
28 |
+
self.data_root = "data/ade20k_root"
|
29 |
+
with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
|
30 |
+
self.scene_categories = f.read().splitlines()
|
31 |
+
self.scene_categories = dict(line.split() for line in self.scene_categories)
|
32 |
+
with open(self.data_csv, "r") as f:
|
33 |
+
self.image_paths = f.read().splitlines()
|
34 |
+
self._length = len(self.image_paths)
|
35 |
+
self.labels = {
|
36 |
+
"relative_file_path_": [l for l in self.image_paths],
|
37 |
+
"file_path_": [os.path.join(self.data_root, "images", l)
|
38 |
+
for l in self.image_paths],
|
39 |
+
"relative_segmentation_path_": [l.replace(".jpg", ".png")
|
40 |
+
for l in self.image_paths],
|
41 |
+
"segmentation_path_": [os.path.join(self.data_root, "annotations",
|
42 |
+
l.replace(".jpg", ".png"))
|
43 |
+
for l in self.image_paths],
|
44 |
+
"scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
|
45 |
+
for l in self.image_paths],
|
46 |
+
}
|
47 |
+
|
48 |
+
size = None if size is not None and size<=0 else size
|
49 |
+
self.size = size
|
50 |
+
if crop_size is None:
|
51 |
+
self.crop_size = size if size is not None else None
|
52 |
+
else:
|
53 |
+
self.crop_size = crop_size
|
54 |
+
if self.size is not None:
|
55 |
+
self.interpolation = interpolation
|
56 |
+
self.interpolation = {
|
57 |
+
"nearest": cv2.INTER_NEAREST,
|
58 |
+
"bilinear": cv2.INTER_LINEAR,
|
59 |
+
"bicubic": cv2.INTER_CUBIC,
|
60 |
+
"area": cv2.INTER_AREA,
|
61 |
+
"lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
|
62 |
+
self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
|
63 |
+
interpolation=self.interpolation)
|
64 |
+
self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
|
65 |
+
interpolation=cv2.INTER_NEAREST)
|
66 |
+
|
67 |
+
if crop_size is not None:
|
68 |
+
self.center_crop = not random_crop
|
69 |
+
if self.center_crop:
|
70 |
+
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
|
71 |
+
else:
|
72 |
+
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
|
73 |
+
self.preprocessor = self.cropper
|
74 |
+
|
75 |
+
def __len__(self):
|
76 |
+
return self._length
|
77 |
+
|
78 |
+
def __getitem__(self, i):
|
79 |
+
example = dict((k, self.labels[k][i]) for k in self.labels)
|
80 |
+
image = Image.open(example["file_path_"])
|
81 |
+
if not image.mode == "RGB":
|
82 |
+
image = image.convert("RGB")
|
83 |
+
image = np.array(image).astype(np.uint8)
|
84 |
+
if self.size is not None:
|
85 |
+
image = self.image_rescaler(image=image)["image"]
|
86 |
+
segmentation = Image.open(example["segmentation_path_"])
|
87 |
+
segmentation = np.array(segmentation).astype(np.uint8)
|
88 |
+
if self.size is not None:
|
89 |
+
segmentation = self.segmentation_rescaler(image=segmentation)["image"]
|
90 |
+
if self.size is not None:
|
91 |
+
processed = self.preprocessor(image=image, mask=segmentation)
|
92 |
+
else:
|
93 |
+
processed = {"image": image, "mask": segmentation}
|
94 |
+
example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
|
95 |
+
segmentation = processed["mask"]
|
96 |
+
onehot = np.eye(self.n_labels)[segmentation]
|
97 |
+
example["segmentation"] = onehot
|
98 |
+
return example
|
99 |
+
|
100 |
+
|
101 |
+
class ADE20kTrain(ADE20kBase):
|
102 |
+
# default to random_crop=True
|
103 |
+
def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
|
104 |
+
super().__init__(config=config, size=size, random_crop=random_crop,
|
105 |
+
interpolation=interpolation, crop_size=crop_size)
|
106 |
+
|
107 |
+
def get_split(self):
|
108 |
+
return "train"
|
109 |
+
|
110 |
+
|
111 |
+
class ADE20kValidation(ADE20kBase):
|
112 |
+
def get_split(self):
|
113 |
+
return "validation"
|
114 |
+
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
dset = ADE20kValidation()
|
118 |
+
ex = dset[0]
|
119 |
+
for k in ["image", "scene_category", "segmentation"]:
|
120 |
+
print(type(ex[k]))
|
121 |
+
try:
|
122 |
+
print(ex[k].shape)
|
123 |
+
except:
|
124 |
+
print(ex[k])
|
taming/data/annotated_objects_coco.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from itertools import chain
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Iterable, Dict, List, Callable, Any
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
|
10 |
+
from taming.data.helper_types import Annotation, ImageDescription, Category
|
11 |
+
|
12 |
+
COCO_PATH_STRUCTURE = {
|
13 |
+
'train': {
|
14 |
+
'top_level': '',
|
15 |
+
'instances_annotations': 'annotations/instances_train2017.json',
|
16 |
+
'stuff_annotations': 'annotations/stuff_train2017.json',
|
17 |
+
'files': 'train2017'
|
18 |
+
},
|
19 |
+
'validation': {
|
20 |
+
'top_level': '',
|
21 |
+
'instances_annotations': 'annotations/instances_val2017.json',
|
22 |
+
'stuff_annotations': 'annotations/stuff_val2017.json',
|
23 |
+
'files': 'val2017'
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
|
29 |
+
return {
|
30 |
+
str(img['id']): ImageDescription(
|
31 |
+
id=img['id'],
|
32 |
+
license=img.get('license'),
|
33 |
+
file_name=img['file_name'],
|
34 |
+
coco_url=img['coco_url'],
|
35 |
+
original_size=(img['width'], img['height']),
|
36 |
+
date_captured=img.get('date_captured'),
|
37 |
+
flickr_url=img.get('flickr_url')
|
38 |
+
)
|
39 |
+
for img in description_json
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def load_categories(category_json: Iterable) -> Dict[str, Category]:
|
44 |
+
return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
|
45 |
+
for cat in category_json if cat['name'] != 'other'}
|
46 |
+
|
47 |
+
|
48 |
+
def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
|
49 |
+
category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
|
50 |
+
annotations = defaultdict(list)
|
51 |
+
total = sum(len(a) for a in annotations_json)
|
52 |
+
for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
|
53 |
+
image_id = str(ann['image_id'])
|
54 |
+
if image_id not in image_descriptions:
|
55 |
+
raise ValueError(f'image_id [{image_id}] has no image description.')
|
56 |
+
category_id = ann['category_id']
|
57 |
+
try:
|
58 |
+
category_no = category_no_for_id(str(category_id))
|
59 |
+
except KeyError:
|
60 |
+
continue
|
61 |
+
|
62 |
+
width, height = image_descriptions[image_id].original_size
|
63 |
+
bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
|
64 |
+
|
65 |
+
annotations[image_id].append(
|
66 |
+
Annotation(
|
67 |
+
id=ann['id'],
|
68 |
+
area=bbox[2]*bbox[3], # use bbox area
|
69 |
+
is_group_of=ann['iscrowd'],
|
70 |
+
image_id=ann['image_id'],
|
71 |
+
bbox=bbox,
|
72 |
+
category_id=str(category_id),
|
73 |
+
category_no=category_no
|
74 |
+
)
|
75 |
+
)
|
76 |
+
return dict(annotations)
|
77 |
+
|
78 |
+
|
79 |
+
class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
|
80 |
+
def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
|
81 |
+
"""
|
82 |
+
@param data_path: is the path to the following folder structure:
|
83 |
+
coco/
|
84 |
+
├── annotations
|
85 |
+
│ ├── instances_train2017.json
|
86 |
+
│ ├── instances_val2017.json
|
87 |
+
│ ├── stuff_train2017.json
|
88 |
+
│ └── stuff_val2017.json
|
89 |
+
├── train2017
|
90 |
+
│ ├── 000000000009.jpg
|
91 |
+
│ ├── 000000000025.jpg
|
92 |
+
│ └── ...
|
93 |
+
├── val2017
|
94 |
+
│ ├── 000000000139.jpg
|
95 |
+
│ ├── 000000000285.jpg
|
96 |
+
│ └── ...
|
97 |
+
@param: split: one of 'train' or 'validation'
|
98 |
+
@param: desired image size (give square images)
|
99 |
+
"""
|
100 |
+
super().__init__(**kwargs)
|
101 |
+
self.use_things = use_things
|
102 |
+
self.use_stuff = use_stuff
|
103 |
+
|
104 |
+
with open(self.paths['instances_annotations']) as f:
|
105 |
+
inst_data_json = json.load(f)
|
106 |
+
with open(self.paths['stuff_annotations']) as f:
|
107 |
+
stuff_data_json = json.load(f)
|
108 |
+
|
109 |
+
category_jsons = []
|
110 |
+
annotation_jsons = []
|
111 |
+
if self.use_things:
|
112 |
+
category_jsons.append(inst_data_json['categories'])
|
113 |
+
annotation_jsons.append(inst_data_json['annotations'])
|
114 |
+
if self.use_stuff:
|
115 |
+
category_jsons.append(stuff_data_json['categories'])
|
116 |
+
annotation_jsons.append(stuff_data_json['annotations'])
|
117 |
+
|
118 |
+
self.categories = load_categories(chain(*category_jsons))
|
119 |
+
self.filter_categories()
|
120 |
+
self.setup_category_id_and_number()
|
121 |
+
|
122 |
+
self.image_descriptions = load_image_descriptions(inst_data_json['images'])
|
123 |
+
annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
|
124 |
+
self.annotations = self.filter_object_number(annotations, self.min_object_area,
|
125 |
+
self.min_objects_per_image, self.max_objects_per_image)
|
126 |
+
self.image_ids = list(self.annotations.keys())
|
127 |
+
self.clean_up_annotations_and_image_descriptions()
|
128 |
+
|
129 |
+
def get_path_structure(self) -> Dict[str, str]:
|
130 |
+
if self.split not in COCO_PATH_STRUCTURE:
|
131 |
+
raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
|
132 |
+
return COCO_PATH_STRUCTURE[self.split]
|
133 |
+
|
134 |
+
def get_image_path(self, image_id: str) -> Path:
|
135 |
+
return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
|
136 |
+
|
137 |
+
def get_image_description(self, image_id: str) -> Dict[str, Any]:
|
138 |
+
# noinspection PyProtectedMember
|
139 |
+
return self.image_descriptions[image_id]._asdict()
|
taming/data/annotated_objects_dataset.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Optional, List, Callable, Dict, Any, Union
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import PIL.Image as pil_image
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
|
11 |
+
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
|
12 |
+
from taming.data.conditional_builder.utils import load_object_from_string
|
13 |
+
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
|
14 |
+
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
|
15 |
+
Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
|
16 |
+
|
17 |
+
|
18 |
+
class AnnotatedObjectsDataset(Dataset):
|
19 |
+
def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
|
20 |
+
min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
|
21 |
+
crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
|
22 |
+
encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
|
23 |
+
no_object_classes: Optional[int] = None):
|
24 |
+
self.data_path = data_path
|
25 |
+
self.split = split
|
26 |
+
self.keys = keys
|
27 |
+
self.target_image_size = target_image_size
|
28 |
+
self.min_object_area = min_object_area
|
29 |
+
self.min_objects_per_image = min_objects_per_image
|
30 |
+
self.max_objects_per_image = max_objects_per_image
|
31 |
+
self.crop_method = crop_method
|
32 |
+
self.random_flip = random_flip
|
33 |
+
self.no_tokens = no_tokens
|
34 |
+
self.use_group_parameter = use_group_parameter
|
35 |
+
self.encode_crop = encode_crop
|
36 |
+
|
37 |
+
self.annotations = None
|
38 |
+
self.image_descriptions = None
|
39 |
+
self.categories = None
|
40 |
+
self.category_ids = None
|
41 |
+
self.category_number = None
|
42 |
+
self.image_ids = None
|
43 |
+
self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
|
44 |
+
self.paths = self.build_paths(self.data_path)
|
45 |
+
self._conditional_builders = None
|
46 |
+
self.category_allow_list = None
|
47 |
+
if category_allow_list_target:
|
48 |
+
allow_list = load_object_from_string(category_allow_list_target)
|
49 |
+
self.category_allow_list = {name for name, _ in allow_list}
|
50 |
+
self.category_mapping = {}
|
51 |
+
if category_mapping_target:
|
52 |
+
self.category_mapping = load_object_from_string(category_mapping_target)
|
53 |
+
self.no_object_classes = no_object_classes
|
54 |
+
|
55 |
+
def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
|
56 |
+
top_level = Path(top_level)
|
57 |
+
sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
|
58 |
+
for path in sub_paths.values():
|
59 |
+
if not path.exists():
|
60 |
+
raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
|
61 |
+
return sub_paths
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def load_image_from_disk(path: Path) -> Image:
|
65 |
+
return pil_image.open(path).convert('RGB')
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
|
69 |
+
transform_functions = []
|
70 |
+
if crop_method == 'none':
|
71 |
+
transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
|
72 |
+
elif crop_method == 'center':
|
73 |
+
transform_functions.extend([
|
74 |
+
transforms.Resize(target_image_size),
|
75 |
+
CenterCropReturnCoordinates(target_image_size)
|
76 |
+
])
|
77 |
+
elif crop_method == 'random-1d':
|
78 |
+
transform_functions.extend([
|
79 |
+
transforms.Resize(target_image_size),
|
80 |
+
RandomCrop1dReturnCoordinates(target_image_size)
|
81 |
+
])
|
82 |
+
elif crop_method == 'random-2d':
|
83 |
+
transform_functions.extend([
|
84 |
+
Random2dCropReturnCoordinates(target_image_size),
|
85 |
+
transforms.Resize(target_image_size)
|
86 |
+
])
|
87 |
+
elif crop_method is None:
|
88 |
+
return None
|
89 |
+
else:
|
90 |
+
raise ValueError(f'Received invalid crop method [{crop_method}].')
|
91 |
+
if random_flip:
|
92 |
+
transform_functions.append(RandomHorizontalFlipReturn())
|
93 |
+
transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
|
94 |
+
return transform_functions
|
95 |
+
|
96 |
+
def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
|
97 |
+
crop_bbox = None
|
98 |
+
flipped = None
|
99 |
+
for t in self.transform_functions:
|
100 |
+
if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
|
101 |
+
crop_bbox, x = t(x)
|
102 |
+
elif isinstance(t, RandomHorizontalFlipReturn):
|
103 |
+
flipped, x = t(x)
|
104 |
+
else:
|
105 |
+
x = t(x)
|
106 |
+
return crop_bbox, flipped, x
|
107 |
+
|
108 |
+
@property
|
109 |
+
def no_classes(self) -> int:
|
110 |
+
return self.no_object_classes if self.no_object_classes else len(self.categories)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
|
114 |
+
# cannot set this up in init because no_classes is only known after loading data in init of superclass
|
115 |
+
if self._conditional_builders is None:
|
116 |
+
self._conditional_builders = {
|
117 |
+
'objects_center_points': ObjectsCenterPointsConditionalBuilder(
|
118 |
+
self.no_classes,
|
119 |
+
self.max_objects_per_image,
|
120 |
+
self.no_tokens,
|
121 |
+
self.encode_crop,
|
122 |
+
self.use_group_parameter,
|
123 |
+
getattr(self, 'use_additional_parameters', False)
|
124 |
+
),
|
125 |
+
'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
|
126 |
+
self.no_classes,
|
127 |
+
self.max_objects_per_image,
|
128 |
+
self.no_tokens,
|
129 |
+
self.encode_crop,
|
130 |
+
self.use_group_parameter,
|
131 |
+
getattr(self, 'use_additional_parameters', False)
|
132 |
+
)
|
133 |
+
}
|
134 |
+
return self._conditional_builders
|
135 |
+
|
136 |
+
def filter_categories(self) -> None:
|
137 |
+
if self.category_allow_list:
|
138 |
+
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
|
139 |
+
if self.category_mapping:
|
140 |
+
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
|
141 |
+
|
142 |
+
def setup_category_id_and_number(self) -> None:
|
143 |
+
self.category_ids = list(self.categories.keys())
|
144 |
+
self.category_ids.sort()
|
145 |
+
if '/m/01s55n' in self.category_ids:
|
146 |
+
self.category_ids.remove('/m/01s55n')
|
147 |
+
self.category_ids.append('/m/01s55n')
|
148 |
+
self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
|
149 |
+
if self.category_allow_list is not None and self.category_mapping is None \
|
150 |
+
and len(self.category_ids) != len(self.category_allow_list):
|
151 |
+
warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
|
152 |
+
'Make sure all names in category_allow_list exist.')
|
153 |
+
|
154 |
+
def clean_up_annotations_and_image_descriptions(self) -> None:
|
155 |
+
image_id_set = set(self.image_ids)
|
156 |
+
self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
|
157 |
+
self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
|
158 |
+
|
159 |
+
@staticmethod
|
160 |
+
def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
|
161 |
+
min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
|
162 |
+
filtered = {}
|
163 |
+
for image_id, annotations in all_annotations.items():
|
164 |
+
annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
|
165 |
+
if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
|
166 |
+
filtered[image_id] = annotations_with_min_area
|
167 |
+
return filtered
|
168 |
+
|
169 |
+
def __len__(self):
|
170 |
+
return len(self.image_ids)
|
171 |
+
|
172 |
+
def __getitem__(self, n: int) -> Dict[str, Any]:
|
173 |
+
image_id = self.get_image_id(n)
|
174 |
+
sample = self.get_image_description(image_id)
|
175 |
+
sample['annotations'] = self.get_annotation(image_id)
|
176 |
+
|
177 |
+
if 'image' in self.keys:
|
178 |
+
sample['image_path'] = str(self.get_image_path(image_id))
|
179 |
+
sample['image'] = self.load_image_from_disk(sample['image_path'])
|
180 |
+
sample['image'] = convert_pil_to_tensor(sample['image'])
|
181 |
+
sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
|
182 |
+
sample['image'] = sample['image'].permute(1, 2, 0)
|
183 |
+
|
184 |
+
for conditional, builder in self.conditional_builders.items():
|
185 |
+
if conditional in self.keys:
|
186 |
+
sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
|
187 |
+
|
188 |
+
if self.keys:
|
189 |
+
# only return specified keys
|
190 |
+
sample = {key: sample[key] for key in self.keys}
|
191 |
+
return sample
|
192 |
+
|
193 |
+
def get_image_id(self, no: int) -> str:
|
194 |
+
return self.image_ids[no]
|
195 |
+
|
196 |
+
def get_annotation(self, image_id: str) -> str:
|
197 |
+
return self.annotations[image_id]
|
198 |
+
|
199 |
+
def get_textual_label_for_category_id(self, category_id: str) -> str:
|
200 |
+
return self.categories[category_id].name
|
201 |
+
|
202 |
+
def get_textual_label_for_category_no(self, category_no: int) -> str:
|
203 |
+
return self.categories[self.get_category_id(category_no)].name
|
204 |
+
|
205 |
+
def get_category_number(self, category_id: str) -> int:
|
206 |
+
return self.category_number[category_id]
|
207 |
+
|
208 |
+
def get_category_id(self, category_no: int) -> str:
|
209 |
+
return self.category_ids[category_no]
|
210 |
+
|
211 |
+
def get_image_description(self, image_id: str) -> Dict[str, Any]:
|
212 |
+
raise NotImplementedError()
|
213 |
+
|
214 |
+
def get_path_structure(self):
|
215 |
+
raise NotImplementedError
|
216 |
+
|
217 |
+
def get_image_path(self, image_id: str) -> Path:
|
218 |
+
raise NotImplementedError
|
taming/data/annotated_objects_open_images.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from csv import DictReader, reader as TupleReader
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Dict, List, Any
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
|
8 |
+
from taming.data.helper_types import Annotation, Category
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
OPEN_IMAGES_STRUCTURE = {
|
12 |
+
'train': {
|
13 |
+
'top_level': '',
|
14 |
+
'class_descriptions': 'class-descriptions-boxable.csv',
|
15 |
+
'annotations': 'oidv6-train-annotations-bbox.csv',
|
16 |
+
'file_list': 'train-images-boxable.csv',
|
17 |
+
'files': 'train'
|
18 |
+
},
|
19 |
+
'validation': {
|
20 |
+
'top_level': '',
|
21 |
+
'class_descriptions': 'class-descriptions-boxable.csv',
|
22 |
+
'annotations': 'validation-annotations-bbox.csv',
|
23 |
+
'file_list': 'validation-images.csv',
|
24 |
+
'files': 'validation'
|
25 |
+
},
|
26 |
+
'test': {
|
27 |
+
'top_level': '',
|
28 |
+
'class_descriptions': 'class-descriptions-boxable.csv',
|
29 |
+
'annotations': 'test-annotations-bbox.csv',
|
30 |
+
'file_list': 'test-images.csv',
|
31 |
+
'files': 'test'
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
|
37 |
+
category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
|
38 |
+
annotations: Dict[str, List[Annotation]] = defaultdict(list)
|
39 |
+
with open(descriptor_path) as file:
|
40 |
+
reader = DictReader(file)
|
41 |
+
for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
|
42 |
+
width = float(row['XMax']) - float(row['XMin'])
|
43 |
+
height = float(row['YMax']) - float(row['YMin'])
|
44 |
+
area = width * height
|
45 |
+
category_id = row['LabelName']
|
46 |
+
if category_id in category_mapping:
|
47 |
+
category_id = category_mapping[category_id]
|
48 |
+
if area >= min_object_area and category_id in category_no_for_id:
|
49 |
+
annotations[row['ImageID']].append(
|
50 |
+
Annotation(
|
51 |
+
id=i,
|
52 |
+
image_id=row['ImageID'],
|
53 |
+
source=row['Source'],
|
54 |
+
category_id=category_id,
|
55 |
+
category_no=category_no_for_id[category_id],
|
56 |
+
confidence=float(row['Confidence']),
|
57 |
+
bbox=(float(row['XMin']), float(row['YMin']), width, height),
|
58 |
+
area=area,
|
59 |
+
is_occluded=bool(int(row['IsOccluded'])),
|
60 |
+
is_truncated=bool(int(row['IsTruncated'])),
|
61 |
+
is_group_of=bool(int(row['IsGroupOf'])),
|
62 |
+
is_depiction=bool(int(row['IsDepiction'])),
|
63 |
+
is_inside=bool(int(row['IsInside']))
|
64 |
+
)
|
65 |
+
)
|
66 |
+
if 'train' in str(descriptor_path) and i < 14000000:
|
67 |
+
warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
|
68 |
+
return dict(annotations)
|
69 |
+
|
70 |
+
|
71 |
+
def load_image_ids(csv_path: Path) -> List[str]:
|
72 |
+
with open(csv_path) as file:
|
73 |
+
reader = DictReader(file)
|
74 |
+
return [row['image_name'] for row in reader]
|
75 |
+
|
76 |
+
|
77 |
+
def load_categories(csv_path: Path) -> Dict[str, Category]:
|
78 |
+
with open(csv_path) as file:
|
79 |
+
reader = TupleReader(file)
|
80 |
+
return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
|
81 |
+
|
82 |
+
|
83 |
+
class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
|
84 |
+
def __init__(self, use_additional_parameters: bool, **kwargs):
|
85 |
+
"""
|
86 |
+
@param data_path: is the path to the following folder structure:
|
87 |
+
open_images/
|
88 |
+
│ oidv6-train-annotations-bbox.csv
|
89 |
+
├── class-descriptions-boxable.csv
|
90 |
+
├── oidv6-train-annotations-bbox.csv
|
91 |
+
├── test
|
92 |
+
│ ├── 000026e7ee790996.jpg
|
93 |
+
│ ├── 000062a39995e348.jpg
|
94 |
+
│ └── ...
|
95 |
+
├── test-annotations-bbox.csv
|
96 |
+
├── test-images.csv
|
97 |
+
├── train
|
98 |
+
│ ├── 000002b66c9c498e.jpg
|
99 |
+
│ ├── 000002b97e5471a0.jpg
|
100 |
+
│ └── ...
|
101 |
+
├── train-images-boxable.csv
|
102 |
+
├── validation
|
103 |
+
│ ├── 0001eeaf4aed83f9.jpg
|
104 |
+
│ ├── 0004886b7d043cfd.jpg
|
105 |
+
│ └── ...
|
106 |
+
├── validation-annotations-bbox.csv
|
107 |
+
└── validation-images.csv
|
108 |
+
@param: split: one of 'train', 'validation' or 'test'
|
109 |
+
@param: desired image size (returns square images)
|
110 |
+
"""
|
111 |
+
|
112 |
+
super().__init__(**kwargs)
|
113 |
+
self.use_additional_parameters = use_additional_parameters
|
114 |
+
|
115 |
+
self.categories = load_categories(self.paths['class_descriptions'])
|
116 |
+
self.filter_categories()
|
117 |
+
self.setup_category_id_and_number()
|
118 |
+
|
119 |
+
self.image_descriptions = {}
|
120 |
+
annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
|
121 |
+
self.category_number)
|
122 |
+
self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
|
123 |
+
self.max_objects_per_image)
|
124 |
+
self.image_ids = list(self.annotations.keys())
|
125 |
+
self.clean_up_annotations_and_image_descriptions()
|
126 |
+
|
127 |
+
def get_path_structure(self) -> Dict[str, str]:
|
128 |
+
if self.split not in OPEN_IMAGES_STRUCTURE:
|
129 |
+
raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
|
130 |
+
return OPEN_IMAGES_STRUCTURE[self.split]
|
131 |
+
|
132 |
+
def get_image_path(self, image_id: str) -> Path:
|
133 |
+
return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
|
134 |
+
|
135 |
+
def get_image_description(self, image_id: str) -> Dict[str, Any]:
|
136 |
+
image_path = self.get_image_path(image_id)
|
137 |
+
return {'file_path': str(image_path), 'file_name': image_path.name}
|
taming/data/base.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
import numpy as np
|
3 |
+
import albumentations
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import Dataset, ConcatDataset
|
6 |
+
|
7 |
+
|
8 |
+
class ConcatDatasetWithIndex(ConcatDataset):
|
9 |
+
"""Modified from original pytorch code to return dataset idx"""
|
10 |
+
def __getitem__(self, idx):
|
11 |
+
if idx < 0:
|
12 |
+
if -idx > len(self):
|
13 |
+
raise ValueError("absolute value of index should not exceed dataset length")
|
14 |
+
idx = len(self) + idx
|
15 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
16 |
+
if dataset_idx == 0:
|
17 |
+
sample_idx = idx
|
18 |
+
else:
|
19 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
20 |
+
return self.datasets[dataset_idx][sample_idx], dataset_idx
|
21 |
+
|
22 |
+
|
23 |
+
class ImagePaths(Dataset):
|
24 |
+
def __init__(self, paths, size=None, random_crop=False, labels=None):
|
25 |
+
self.size = size
|
26 |
+
self.random_crop = random_crop
|
27 |
+
|
28 |
+
self.labels = dict() if labels is None else labels
|
29 |
+
self.labels["file_path_"] = paths
|
30 |
+
self._length = len(paths)
|
31 |
+
|
32 |
+
if self.size is not None and self.size > 0:
|
33 |
+
self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
|
34 |
+
if not self.random_crop:
|
35 |
+
self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
|
36 |
+
else:
|
37 |
+
self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
|
38 |
+
self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
|
39 |
+
else:
|
40 |
+
self.preprocessor = lambda **kwargs: kwargs
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return self._length
|
44 |
+
|
45 |
+
def preprocess_image(self, image_path):
|
46 |
+
image = Image.open(image_path)
|
47 |
+
if not image.mode == "RGB":
|
48 |
+
image = image.convert("RGB")
|
49 |
+
image = np.array(image).astype(np.uint8)
|
50 |
+
image = self.preprocessor(image=image)["image"]
|
51 |
+
image = (image/127.5 - 1.0).astype(np.float32)
|
52 |
+
return image
|
53 |
+
|
54 |
+
def __getitem__(self, i):
|
55 |
+
example = dict()
|
56 |
+
example["image"] = self.preprocess_image(self.labels["file_path_"][i])
|
57 |
+
for k in self.labels:
|
58 |
+
example[k] = self.labels[k][i]
|
59 |
+
return example
|
60 |
+
|
61 |
+
|
62 |
+
class NumpyPaths(ImagePaths):
|
63 |
+
def preprocess_image(self, image_path):
|
64 |
+
image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
|
65 |
+
image = np.transpose(image, (1,2,0))
|
66 |
+
image = Image.fromarray(image, mode="RGB")
|
67 |
+
image = np.array(image).astype(np.uint8)
|
68 |
+
image = self.preprocessor(image=image)["image"]
|
69 |
+
image = (image/127.5 - 1.0).astype(np.float32)
|
70 |
+
return image
|
taming/data/coco.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import albumentations
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
|
9 |
+
from taming.data.sflckr import SegmentationBase # for examples included in repo
|
10 |
+
|
11 |
+
|
12 |
+
class Examples(SegmentationBase):
|
13 |
+
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
|
14 |
+
super().__init__(data_csv="data/coco_examples.txt",
|
15 |
+
data_root="data/coco_images",
|
16 |
+
segmentation_root="data/coco_segmentations",
|
17 |
+
size=size, random_crop=random_crop,
|
18 |
+
interpolation=interpolation,
|
19 |
+
n_labels=183, shift_segmentation=True)
|
20 |
+
|
21 |
+
|
22 |
+
class CocoBase(Dataset):
|
23 |
+
"""needed for (image, caption, segmentation) pairs"""
|
24 |
+
def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
|
25 |
+
crop_size=None, force_no_crop=False, given_files=None):
|
26 |
+
self.split = self.get_split()
|
27 |
+
self.size = size
|
28 |
+
if crop_size is None:
|
29 |
+
self.crop_size = size
|
30 |
+
else:
|
31 |
+
self.crop_size = crop_size
|
32 |
+
|
33 |
+
self.onehot = onehot_segmentation # return segmentation as rgb or one hot
|
34 |
+
self.stuffthing = use_stuffthing # include thing in segmentation
|
35 |
+
if self.onehot and not self.stuffthing:
|
36 |
+
raise NotImplemented("One hot mode is only supported for the "
|
37 |
+
"stuffthings version because labels are stored "
|
38 |
+
"a bit different.")
|
39 |
+
|
40 |
+
data_json = datajson
|
41 |
+
with open(data_json) as json_file:
|
42 |
+
self.json_data = json.load(json_file)
|
43 |
+
self.img_id_to_captions = dict()
|
44 |
+
self.img_id_to_filepath = dict()
|
45 |
+
self.img_id_to_segmentation_filepath = dict()
|
46 |
+
|
47 |
+
assert data_json.split("/")[-1] in ["captions_train2017.json",
|
48 |
+
"captions_val2017.json"]
|
49 |
+
if self.stuffthing:
|
50 |
+
self.segmentation_prefix = (
|
51 |
+
"data/cocostuffthings/val2017" if
|
52 |
+
data_json.endswith("captions_val2017.json") else
|
53 |
+
"data/cocostuffthings/train2017")
|
54 |
+
else:
|
55 |
+
self.segmentation_prefix = (
|
56 |
+
"data/coco/annotations/stuff_val2017_pixelmaps" if
|
57 |
+
data_json.endswith("captions_val2017.json") else
|
58 |
+
"data/coco/annotations/stuff_train2017_pixelmaps")
|
59 |
+
|
60 |
+
imagedirs = self.json_data["images"]
|
61 |
+
self.labels = {"image_ids": list()}
|
62 |
+
for imgdir in tqdm(imagedirs, desc="ImgToPath"):
|
63 |
+
self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
|
64 |
+
self.img_id_to_captions[imgdir["id"]] = list()
|
65 |
+
pngfilename = imgdir["file_name"].replace("jpg", "png")
|
66 |
+
self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
|
67 |
+
self.segmentation_prefix, pngfilename)
|
68 |
+
if given_files is not None:
|
69 |
+
if pngfilename in given_files:
|
70 |
+
self.labels["image_ids"].append(imgdir["id"])
|
71 |
+
else:
|
72 |
+
self.labels["image_ids"].append(imgdir["id"])
|
73 |
+
|
74 |
+
capdirs = self.json_data["annotations"]
|
75 |
+
for capdir in tqdm(capdirs, desc="ImgToCaptions"):
|
76 |
+
# there are in average 5 captions per image
|
77 |
+
self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
|
78 |
+
|
79 |
+
self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
|
80 |
+
if self.split=="validation":
|
81 |
+
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
|
82 |
+
else:
|
83 |
+
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
|
84 |
+
self.preprocessor = albumentations.Compose(
|
85 |
+
[self.rescaler, self.cropper],
|
86 |
+
additional_targets={"segmentation": "image"})
|
87 |
+
if force_no_crop:
|
88 |
+
self.rescaler = albumentations.Resize(height=self.size, width=self.size)
|
89 |
+
self.preprocessor = albumentations.Compose(
|
90 |
+
[self.rescaler],
|
91 |
+
additional_targets={"segmentation": "image"})
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.labels["image_ids"])
|
95 |
+
|
96 |
+
def preprocess_image(self, image_path, segmentation_path):
|
97 |
+
image = Image.open(image_path)
|
98 |
+
if not image.mode == "RGB":
|
99 |
+
image = image.convert("RGB")
|
100 |
+
image = np.array(image).astype(np.uint8)
|
101 |
+
|
102 |
+
segmentation = Image.open(segmentation_path)
|
103 |
+
if not self.onehot and not segmentation.mode == "RGB":
|
104 |
+
segmentation = segmentation.convert("RGB")
|
105 |
+
segmentation = np.array(segmentation).astype(np.uint8)
|
106 |
+
if self.onehot:
|
107 |
+
assert self.stuffthing
|
108 |
+
# stored in caffe format: unlabeled==255. stuff and thing from
|
109 |
+
# 0-181. to be compatible with the labels in
|
110 |
+
# https://github.com/nightrome/cocostuff/blob/master/labels.txt
|
111 |
+
# we shift stuffthing one to the right and put unlabeled in zero
|
112 |
+
# as long as segmentation is uint8 shifting to right handles the
|
113 |
+
# latter too
|
114 |
+
assert segmentation.dtype == np.uint8
|
115 |
+
segmentation = segmentation + 1
|
116 |
+
|
117 |
+
processed = self.preprocessor(image=image, segmentation=segmentation)
|
118 |
+
image, segmentation = processed["image"], processed["segmentation"]
|
119 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
120 |
+
|
121 |
+
if self.onehot:
|
122 |
+
assert segmentation.dtype == np.uint8
|
123 |
+
# make it one hot
|
124 |
+
n_labels = 183
|
125 |
+
flatseg = np.ravel(segmentation)
|
126 |
+
onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
|
127 |
+
onehot[np.arange(flatseg.size), flatseg] = True
|
128 |
+
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
|
129 |
+
segmentation = onehot
|
130 |
+
else:
|
131 |
+
segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
|
132 |
+
return image, segmentation
|
133 |
+
|
134 |
+
def __getitem__(self, i):
|
135 |
+
img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
|
136 |
+
seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
|
137 |
+
image, segmentation = self.preprocess_image(img_path, seg_path)
|
138 |
+
captions = self.img_id_to_captions[self.labels["image_ids"][i]]
|
139 |
+
# randomly draw one of all available captions per image
|
140 |
+
caption = captions[np.random.randint(0, len(captions))]
|
141 |
+
example = {"image": image,
|
142 |
+
"caption": [str(caption[0])],
|
143 |
+
"segmentation": segmentation,
|
144 |
+
"img_path": img_path,
|
145 |
+
"seg_path": seg_path,
|
146 |
+
"filename_": img_path.split(os.sep)[-1]
|
147 |
+
}
|
148 |
+
return example
|
149 |
+
|
150 |
+
|
151 |
+
class CocoImagesAndCaptionsTrain(CocoBase):
|
152 |
+
"""returns a pair of (image, caption)"""
|
153 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
|
154 |
+
super().__init__(size=size,
|
155 |
+
dataroot="data/coco/train2017",
|
156 |
+
datajson="data/coco/annotations/captions_train2017.json",
|
157 |
+
onehot_segmentation=onehot_segmentation,
|
158 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
|
159 |
+
|
160 |
+
def get_split(self):
|
161 |
+
return "train"
|
162 |
+
|
163 |
+
|
164 |
+
class CocoImagesAndCaptionsValidation(CocoBase):
|
165 |
+
"""returns a pair of (image, caption)"""
|
166 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
|
167 |
+
given_files=None):
|
168 |
+
super().__init__(size=size,
|
169 |
+
dataroot="data/coco/val2017",
|
170 |
+
datajson="data/coco/annotations/captions_val2017.json",
|
171 |
+
onehot_segmentation=onehot_segmentation,
|
172 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
173 |
+
given_files=given_files)
|
174 |
+
|
175 |
+
def get_split(self):
|
176 |
+
return "validation"
|
taming/data/conditional_builder/objects_bbox.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import cycle
|
2 |
+
from typing import List, Tuple, Callable, Optional
|
3 |
+
|
4 |
+
from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
|
5 |
+
from more_itertools.recipes import grouper
|
6 |
+
from taming.data.image_transforms import convert_pil_to_tensor
|
7 |
+
from torch import LongTensor, Tensor
|
8 |
+
|
9 |
+
from taming.data.helper_types import BoundingBox, Annotation
|
10 |
+
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
|
11 |
+
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
|
12 |
+
pad_list, get_plot_font_size, absolute_bbox
|
13 |
+
|
14 |
+
|
15 |
+
class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
|
16 |
+
@property
|
17 |
+
def object_descriptor_length(self) -> int:
|
18 |
+
return 3
|
19 |
+
|
20 |
+
def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
|
21 |
+
object_triples = [
|
22 |
+
(self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
|
23 |
+
for ann in annotations
|
24 |
+
]
|
25 |
+
empty_triple = (self.none, self.none, self.none)
|
26 |
+
object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
|
27 |
+
return object_triples
|
28 |
+
|
29 |
+
def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
|
30 |
+
conditional_list = conditional.tolist()
|
31 |
+
crop_coordinates = None
|
32 |
+
if self.encode_crop:
|
33 |
+
crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
|
34 |
+
conditional_list = conditional_list[:-2]
|
35 |
+
object_triples = grouper(conditional_list, 3)
|
36 |
+
assert conditional.shape[0] == self.embedding_dim
|
37 |
+
return [
|
38 |
+
(object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
|
39 |
+
for object_triple in object_triples if object_triple[0] != self.none
|
40 |
+
], crop_coordinates
|
41 |
+
|
42 |
+
def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
|
43 |
+
line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
|
44 |
+
plot = pil_image.new('RGB', figure_size, WHITE)
|
45 |
+
draw = pil_img_draw.Draw(plot)
|
46 |
+
font = ImageFont.truetype(
|
47 |
+
"/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
|
48 |
+
size=get_plot_font_size(font_size, figure_size)
|
49 |
+
)
|
50 |
+
width, height = plot.size
|
51 |
+
description, crop_coordinates = self.inverse_build(conditional)
|
52 |
+
for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
|
53 |
+
annotation = self.representation_to_annotation(representation)
|
54 |
+
class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
|
55 |
+
bbox = absolute_bbox(bbox, width, height)
|
56 |
+
draw.rectangle(bbox, outline=color, width=line_width)
|
57 |
+
draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
|
58 |
+
if crop_coordinates is not None:
|
59 |
+
draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
|
60 |
+
return convert_pil_to_tensor(plot) / 127.5 - 1.
|
taming/data/conditional_builder/objects_center_points.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import warnings
|
4 |
+
from itertools import cycle
|
5 |
+
from typing import List, Optional, Tuple, Callable
|
6 |
+
|
7 |
+
from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
|
8 |
+
from more_itertools.recipes import grouper
|
9 |
+
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
|
10 |
+
additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
|
11 |
+
absolute_bbox, rescale_annotations
|
12 |
+
from taming.data.helper_types import BoundingBox, Annotation
|
13 |
+
from taming.data.image_transforms import convert_pil_to_tensor
|
14 |
+
from torch import LongTensor, Tensor
|
15 |
+
|
16 |
+
|
17 |
+
class ObjectsCenterPointsConditionalBuilder:
|
18 |
+
def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
|
19 |
+
use_group_parameter: bool, use_additional_parameters: bool):
|
20 |
+
self.no_object_classes = no_object_classes
|
21 |
+
self.no_max_objects = no_max_objects
|
22 |
+
self.no_tokens = no_tokens
|
23 |
+
self.encode_crop = encode_crop
|
24 |
+
self.no_sections = int(math.sqrt(self.no_tokens))
|
25 |
+
self.use_group_parameter = use_group_parameter
|
26 |
+
self.use_additional_parameters = use_additional_parameters
|
27 |
+
|
28 |
+
@property
|
29 |
+
def none(self) -> int:
|
30 |
+
return self.no_tokens - 1
|
31 |
+
|
32 |
+
@property
|
33 |
+
def object_descriptor_length(self) -> int:
|
34 |
+
return 2
|
35 |
+
|
36 |
+
@property
|
37 |
+
def embedding_dim(self) -> int:
|
38 |
+
extra_length = 2 if self.encode_crop else 0
|
39 |
+
return self.no_max_objects * self.object_descriptor_length + extra_length
|
40 |
+
|
41 |
+
def tokenize_coordinates(self, x: float, y: float) -> int:
|
42 |
+
"""
|
43 |
+
Express 2d coordinates with one number.
|
44 |
+
Example: assume self.no_tokens = 16, then no_sections = 4:
|
45 |
+
0 0 0 0
|
46 |
+
0 0 # 0
|
47 |
+
0 0 0 0
|
48 |
+
0 0 0 x
|
49 |
+
Then the # position corresponds to token 6, the x position to token 15.
|
50 |
+
@param x: float in [0, 1]
|
51 |
+
@param y: float in [0, 1]
|
52 |
+
@return: discrete tokenized coordinate
|
53 |
+
"""
|
54 |
+
x_discrete = int(round(x * (self.no_sections - 1)))
|
55 |
+
y_discrete = int(round(y * (self.no_sections - 1)))
|
56 |
+
return y_discrete * self.no_sections + x_discrete
|
57 |
+
|
58 |
+
def coordinates_from_token(self, token: int) -> (float, float):
|
59 |
+
x = token % self.no_sections
|
60 |
+
y = token // self.no_sections
|
61 |
+
return x / (self.no_sections - 1), y / (self.no_sections - 1)
|
62 |
+
|
63 |
+
def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
|
64 |
+
x0, y0 = self.coordinates_from_token(token1)
|
65 |
+
x1, y1 = self.coordinates_from_token(token2)
|
66 |
+
return x0, y0, x1 - x0, y1 - y0
|
67 |
+
|
68 |
+
def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
|
69 |
+
return self.tokenize_coordinates(bbox[0], bbox[1]), \
|
70 |
+
self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
|
71 |
+
|
72 |
+
def inverse_build(self, conditional: LongTensor) \
|
73 |
+
-> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
|
74 |
+
conditional_list = conditional.tolist()
|
75 |
+
crop_coordinates = None
|
76 |
+
if self.encode_crop:
|
77 |
+
crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
|
78 |
+
conditional_list = conditional_list[:-2]
|
79 |
+
table_of_content = grouper(conditional_list, self.object_descriptor_length)
|
80 |
+
assert conditional.shape[0] == self.embedding_dim
|
81 |
+
return [
|
82 |
+
(object_tuple[0], self.coordinates_from_token(object_tuple[1]))
|
83 |
+
for object_tuple in table_of_content if object_tuple[0] != self.none
|
84 |
+
], crop_coordinates
|
85 |
+
|
86 |
+
def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
|
87 |
+
line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
|
88 |
+
plot = pil_image.new('RGB', figure_size, WHITE)
|
89 |
+
draw = pil_img_draw.Draw(plot)
|
90 |
+
circle_size = get_circle_size(figure_size)
|
91 |
+
font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
|
92 |
+
size=get_plot_font_size(font_size, figure_size))
|
93 |
+
width, height = plot.size
|
94 |
+
description, crop_coordinates = self.inverse_build(conditional)
|
95 |
+
for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
|
96 |
+
x_abs, y_abs = x * width, y * height
|
97 |
+
ann = self.representation_to_annotation(representation)
|
98 |
+
label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
|
99 |
+
ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
|
100 |
+
draw.ellipse(ellipse_bbox, fill=color, width=0)
|
101 |
+
draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
|
102 |
+
if crop_coordinates is not None:
|
103 |
+
draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
|
104 |
+
return convert_pil_to_tensor(plot) / 127.5 - 1.
|
105 |
+
|
106 |
+
def object_representation(self, annotation: Annotation) -> int:
|
107 |
+
modifier = 0
|
108 |
+
if self.use_group_parameter:
|
109 |
+
modifier |= 1 * (annotation.is_group_of is True)
|
110 |
+
if self.use_additional_parameters:
|
111 |
+
modifier |= 2 * (annotation.is_occluded is True)
|
112 |
+
modifier |= 4 * (annotation.is_depiction is True)
|
113 |
+
modifier |= 8 * (annotation.is_inside is True)
|
114 |
+
return annotation.category_no + self.no_object_classes * modifier
|
115 |
+
|
116 |
+
def representation_to_annotation(self, representation: int) -> Annotation:
|
117 |
+
category_no = representation % self.no_object_classes
|
118 |
+
modifier = representation // self.no_object_classes
|
119 |
+
# noinspection PyTypeChecker
|
120 |
+
return Annotation(
|
121 |
+
area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
|
122 |
+
category_no=category_no,
|
123 |
+
is_group_of=bool((modifier & 1) * self.use_group_parameter),
|
124 |
+
is_occluded=bool((modifier & 2) * self.use_additional_parameters),
|
125 |
+
is_depiction=bool((modifier & 4) * self.use_additional_parameters),
|
126 |
+
is_inside=bool((modifier & 8) * self.use_additional_parameters)
|
127 |
+
)
|
128 |
+
|
129 |
+
def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
|
130 |
+
return list(self.token_pair_from_bbox(crop_coordinates))
|
131 |
+
|
132 |
+
def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
|
133 |
+
object_tuples = [
|
134 |
+
(self.object_representation(a),
|
135 |
+
self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
|
136 |
+
for a in annotations
|
137 |
+
]
|
138 |
+
empty_tuple = (self.none, self.none)
|
139 |
+
object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
|
140 |
+
return object_tuples
|
141 |
+
|
142 |
+
def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
|
143 |
+
-> LongTensor:
|
144 |
+
if len(annotations) == 0:
|
145 |
+
warnings.warn('Did not receive any annotations.')
|
146 |
+
if len(annotations) > self.no_max_objects:
|
147 |
+
warnings.warn('Received more annotations than allowed.')
|
148 |
+
annotations = annotations[:self.no_max_objects]
|
149 |
+
|
150 |
+
if not crop_coordinates:
|
151 |
+
crop_coordinates = FULL_CROP
|
152 |
+
|
153 |
+
random.shuffle(annotations)
|
154 |
+
annotations = filter_annotations(annotations, crop_coordinates)
|
155 |
+
if self.encode_crop:
|
156 |
+
annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
|
157 |
+
if horizontal_flip:
|
158 |
+
crop_coordinates = horizontally_flip_bbox(crop_coordinates)
|
159 |
+
extra = self._crop_encoder(crop_coordinates)
|
160 |
+
else:
|
161 |
+
annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
|
162 |
+
extra = []
|
163 |
+
|
164 |
+
object_tuples = self._make_object_descriptors(annotations)
|
165 |
+
flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
|
166 |
+
assert len(flattened) == self.embedding_dim
|
167 |
+
assert all(0 <= value < self.no_tokens for value in flattened)
|
168 |
+
return LongTensor(flattened)
|
taming/data/conditional_builder/utils.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from typing import List, Any, Tuple, Optional
|
3 |
+
|
4 |
+
from taming.data.helper_types import BoundingBox, Annotation
|
5 |
+
|
6 |
+
# source: seaborn, color palette tab10
|
7 |
+
COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
|
8 |
+
(139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
|
9 |
+
BLACK = (0, 0, 0)
|
10 |
+
GRAY_75 = (63, 63, 63)
|
11 |
+
GRAY_50 = (127, 127, 127)
|
12 |
+
GRAY_25 = (191, 191, 191)
|
13 |
+
WHITE = (255, 255, 255)
|
14 |
+
FULL_CROP = (0., 0., 1., 1.)
|
15 |
+
|
16 |
+
|
17 |
+
def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
|
18 |
+
"""
|
19 |
+
Give intersection area of two rectangles.
|
20 |
+
@param rectangle1: (x0, y0, w, h) of first rectangle
|
21 |
+
@param rectangle2: (x0, y0, w, h) of second rectangle
|
22 |
+
"""
|
23 |
+
rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
|
24 |
+
rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
|
25 |
+
x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
|
26 |
+
y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
|
27 |
+
return x_overlap * y_overlap
|
28 |
+
|
29 |
+
|
30 |
+
def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
|
31 |
+
return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
|
32 |
+
|
33 |
+
|
34 |
+
def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
|
35 |
+
bbox = relative_bbox
|
36 |
+
bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
|
37 |
+
return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
|
38 |
+
|
39 |
+
|
40 |
+
def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
|
41 |
+
return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
|
42 |
+
|
43 |
+
|
44 |
+
def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
|
45 |
+
List[Annotation]:
|
46 |
+
def clamp(x: float):
|
47 |
+
return max(min(x, 1.), 0.)
|
48 |
+
|
49 |
+
def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
|
50 |
+
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
51 |
+
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
52 |
+
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
53 |
+
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
54 |
+
if flip:
|
55 |
+
x0 = 1 - (x0 + w)
|
56 |
+
return x0, y0, w, h
|
57 |
+
|
58 |
+
return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
|
59 |
+
|
60 |
+
|
61 |
+
def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
|
62 |
+
return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
|
63 |
+
|
64 |
+
|
65 |
+
def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
|
66 |
+
sl = slice(1) if short else slice(None)
|
67 |
+
string = ''
|
68 |
+
if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
|
69 |
+
return string
|
70 |
+
if annotation.is_group_of:
|
71 |
+
string += 'group'[sl] + ','
|
72 |
+
if annotation.is_occluded:
|
73 |
+
string += 'occluded'[sl] + ','
|
74 |
+
if annotation.is_depiction:
|
75 |
+
string += 'depiction'[sl] + ','
|
76 |
+
if annotation.is_inside:
|
77 |
+
string += 'inside'[sl]
|
78 |
+
return '(' + string.strip(",") + ')'
|
79 |
+
|
80 |
+
|
81 |
+
def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
|
82 |
+
if font_size is None:
|
83 |
+
font_size = 10
|
84 |
+
if max(figure_size) >= 256:
|
85 |
+
font_size = 12
|
86 |
+
if max(figure_size) >= 512:
|
87 |
+
font_size = 15
|
88 |
+
return font_size
|
89 |
+
|
90 |
+
|
91 |
+
def get_circle_size(figure_size: Tuple[int, int]) -> int:
|
92 |
+
circle_size = 2
|
93 |
+
if max(figure_size) >= 256:
|
94 |
+
circle_size = 3
|
95 |
+
if max(figure_size) >= 512:
|
96 |
+
circle_size = 4
|
97 |
+
return circle_size
|
98 |
+
|
99 |
+
|
100 |
+
def load_object_from_string(object_string: str) -> Any:
|
101 |
+
"""
|
102 |
+
Source: https://stackoverflow.com/a/10773699
|
103 |
+
"""
|
104 |
+
module_name, class_name = object_string.rsplit(".", 1)
|
105 |
+
return getattr(importlib.import_module(module_name), class_name)
|
taming/data/custom.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import albumentations
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
|
6 |
+
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
|
7 |
+
|
8 |
+
|
9 |
+
class CustomBase(Dataset):
|
10 |
+
def __init__(self, *args, **kwargs):
|
11 |
+
super().__init__()
|
12 |
+
self.data = None
|
13 |
+
|
14 |
+
def __len__(self):
|
15 |
+
return len(self.data)
|
16 |
+
|
17 |
+
def __getitem__(self, i):
|
18 |
+
example = self.data[i]
|
19 |
+
return example
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
class CustomTrain(CustomBase):
|
24 |
+
def __init__(self, size, training_images_list_file):
|
25 |
+
super().__init__()
|
26 |
+
with open(training_images_list_file, "r") as f:
|
27 |
+
paths = f.read().splitlines()
|
28 |
+
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
29 |
+
|
30 |
+
|
31 |
+
class CustomTest(CustomBase):
|
32 |
+
def __init__(self, size, test_images_list_file):
|
33 |
+
super().__init__()
|
34 |
+
with open(test_images_list_file, "r") as f:
|
35 |
+
paths = f.read().splitlines()
|
36 |
+
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
37 |
+
|
38 |
+
|
taming/data/faceshq.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import albumentations
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
|
6 |
+
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
|
7 |
+
|
8 |
+
|
9 |
+
class FacesBase(Dataset):
|
10 |
+
def __init__(self, *args, **kwargs):
|
11 |
+
super().__init__()
|
12 |
+
self.data = None
|
13 |
+
self.keys = None
|
14 |
+
|
15 |
+
def __len__(self):
|
16 |
+
return len(self.data)
|
17 |
+
|
18 |
+
def __getitem__(self, i):
|
19 |
+
example = self.data[i]
|
20 |
+
ex = {}
|
21 |
+
if self.keys is not None:
|
22 |
+
for k in self.keys:
|
23 |
+
ex[k] = example[k]
|
24 |
+
else:
|
25 |
+
ex = example
|
26 |
+
return ex
|
27 |
+
|
28 |
+
|
29 |
+
class CelebAHQTrain(FacesBase):
|
30 |
+
def __init__(self, size, keys=None):
|
31 |
+
super().__init__()
|
32 |
+
root = "data/celebahq"
|
33 |
+
with open("data/celebahqtrain.txt", "r") as f:
|
34 |
+
relpaths = f.read().splitlines()
|
35 |
+
paths = [os.path.join(root, relpath) for relpath in relpaths]
|
36 |
+
self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
|
37 |
+
self.keys = keys
|
38 |
+
|
39 |
+
|
40 |
+
class CelebAHQValidation(FacesBase):
|
41 |
+
def __init__(self, size, keys=None):
|
42 |
+
super().__init__()
|
43 |
+
root = "data/celebahq"
|
44 |
+
with open("data/celebahqvalidation.txt", "r") as f:
|
45 |
+
relpaths = f.read().splitlines()
|
46 |
+
paths = [os.path.join(root, relpath) for relpath in relpaths]
|
47 |
+
self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
|
48 |
+
self.keys = keys
|
49 |
+
|
50 |
+
|
51 |
+
class FFHQTrain(FacesBase):
|
52 |
+
def __init__(self, size, keys=None):
|
53 |
+
super().__init__()
|
54 |
+
root = "data/ffhq"
|
55 |
+
with open("data/ffhqtrain.txt", "r") as f:
|
56 |
+
relpaths = f.read().splitlines()
|
57 |
+
paths = [os.path.join(root, relpath) for relpath in relpaths]
|
58 |
+
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
59 |
+
self.keys = keys
|
60 |
+
|
61 |
+
|
62 |
+
class FFHQValidation(FacesBase):
|
63 |
+
def __init__(self, size, keys=None):
|
64 |
+
super().__init__()
|
65 |
+
root = "data/ffhq"
|
66 |
+
with open("data/ffhqvalidation.txt", "r") as f:
|
67 |
+
relpaths = f.read().splitlines()
|
68 |
+
paths = [os.path.join(root, relpath) for relpath in relpaths]
|
69 |
+
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
70 |
+
self.keys = keys
|
71 |
+
|
72 |
+
|
73 |
+
class FacesHQTrain(Dataset):
|
74 |
+
# CelebAHQ [0] + FFHQ [1]
|
75 |
+
def __init__(self, size, keys=None, crop_size=None, coord=False):
|
76 |
+
d1 = CelebAHQTrain(size=size, keys=keys)
|
77 |
+
d2 = FFHQTrain(size=size, keys=keys)
|
78 |
+
self.data = ConcatDatasetWithIndex([d1, d2])
|
79 |
+
self.coord = coord
|
80 |
+
if crop_size is not None:
|
81 |
+
self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
|
82 |
+
if self.coord:
|
83 |
+
self.cropper = albumentations.Compose([self.cropper],
|
84 |
+
additional_targets={"coord": "image"})
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return len(self.data)
|
88 |
+
|
89 |
+
def __getitem__(self, i):
|
90 |
+
ex, y = self.data[i]
|
91 |
+
if hasattr(self, "cropper"):
|
92 |
+
if not self.coord:
|
93 |
+
out = self.cropper(image=ex["image"])
|
94 |
+
ex["image"] = out["image"]
|
95 |
+
else:
|
96 |
+
h,w,_ = ex["image"].shape
|
97 |
+
coord = np.arange(h*w).reshape(h,w,1)/(h*w)
|
98 |
+
out = self.cropper(image=ex["image"], coord=coord)
|
99 |
+
ex["image"] = out["image"]
|
100 |
+
ex["coord"] = out["coord"]
|
101 |
+
ex["class"] = y
|
102 |
+
return ex
|
103 |
+
|
104 |
+
|
105 |
+
class FacesHQValidation(Dataset):
|
106 |
+
# CelebAHQ [0] + FFHQ [1]
|
107 |
+
def __init__(self, size, keys=None, crop_size=None, coord=False):
|
108 |
+
d1 = CelebAHQValidation(size=size, keys=keys)
|
109 |
+
d2 = FFHQValidation(size=size, keys=keys)
|
110 |
+
self.data = ConcatDatasetWithIndex([d1, d2])
|
111 |
+
self.coord = coord
|
112 |
+
if crop_size is not None:
|
113 |
+
self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
|
114 |
+
if self.coord:
|
115 |
+
self.cropper = albumentations.Compose([self.cropper],
|
116 |
+
additional_targets={"coord": "image"})
|
117 |
+
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.data)
|
120 |
+
|
121 |
+
def __getitem__(self, i):
|
122 |
+
ex, y = self.data[i]
|
123 |
+
if hasattr(self, "cropper"):
|
124 |
+
if not self.coord:
|
125 |
+
out = self.cropper(image=ex["image"])
|
126 |
+
ex["image"] = out["image"]
|
127 |
+
else:
|
128 |
+
h,w,_ = ex["image"].shape
|
129 |
+
coord = np.arange(h*w).reshape(h,w,1)/(h*w)
|
130 |
+
out = self.cropper(image=ex["image"], coord=coord)
|
131 |
+
ex["image"] = out["image"]
|
132 |
+
ex["coord"] = out["coord"]
|
133 |
+
ex["class"] = y
|
134 |
+
return ex
|
taming/data/helper_types.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Tuple, Optional, NamedTuple, Union
|
2 |
+
from PIL.Image import Image as pil_image
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
try:
|
6 |
+
from typing import Literal
|
7 |
+
except ImportError:
|
8 |
+
from typing_extensions import Literal
|
9 |
+
|
10 |
+
Image = Union[Tensor, pil_image]
|
11 |
+
BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
|
12 |
+
CropMethodType = Literal['none', 'random', 'center', 'random-2d']
|
13 |
+
SplitType = Literal['train', 'validation', 'test']
|
14 |
+
|
15 |
+
|
16 |
+
class ImageDescription(NamedTuple):
|
17 |
+
id: int
|
18 |
+
file_name: str
|
19 |
+
original_size: Tuple[int, int] # w, h
|
20 |
+
url: Optional[str] = None
|
21 |
+
license: Optional[int] = None
|
22 |
+
coco_url: Optional[str] = None
|
23 |
+
date_captured: Optional[str] = None
|
24 |
+
flickr_url: Optional[str] = None
|
25 |
+
flickr_id: Optional[str] = None
|
26 |
+
coco_id: Optional[str] = None
|
27 |
+
|
28 |
+
|
29 |
+
class Category(NamedTuple):
|
30 |
+
id: str
|
31 |
+
super_category: Optional[str]
|
32 |
+
name: str
|
33 |
+
|
34 |
+
|
35 |
+
class Annotation(NamedTuple):
|
36 |
+
area: float
|
37 |
+
image_id: str
|
38 |
+
bbox: BoundingBox
|
39 |
+
category_no: int
|
40 |
+
category_id: str
|
41 |
+
id: Optional[int] = None
|
42 |
+
source: Optional[str] = None
|
43 |
+
confidence: Optional[float] = None
|
44 |
+
is_group_of: Optional[bool] = None
|
45 |
+
is_truncated: Optional[bool] = None
|
46 |
+
is_occluded: Optional[bool] = None
|
47 |
+
is_depiction: Optional[bool] = None
|
48 |
+
is_inside: Optional[bool] = None
|
49 |
+
segmentation: Optional[Dict] = None
|
taming/data/image_transforms.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import warnings
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
|
8 |
+
from torchvision.transforms.functional import _get_image_size as get_image_size
|
9 |
+
|
10 |
+
from taming.data.helper_types import BoundingBox, Image
|
11 |
+
|
12 |
+
pil_to_tensor = PILToTensor()
|
13 |
+
|
14 |
+
|
15 |
+
def convert_pil_to_tensor(image: Image) -> Tensor:
|
16 |
+
with warnings.catch_warnings():
|
17 |
+
# to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
|
18 |
+
warnings.simplefilter("ignore")
|
19 |
+
return pil_to_tensor(image)
|
20 |
+
|
21 |
+
|
22 |
+
class RandomCrop1dReturnCoordinates(RandomCrop):
|
23 |
+
def forward(self, img: Image) -> (BoundingBox, Image):
|
24 |
+
"""
|
25 |
+
Additionally to cropping, returns the relative coordinates of the crop bounding box.
|
26 |
+
Args:
|
27 |
+
img (PIL Image or Tensor): Image to be cropped.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
Bounding box: x0, y0, w, h
|
31 |
+
PIL Image or Tensor: Cropped image.
|
32 |
+
|
33 |
+
Based on:
|
34 |
+
torchvision.transforms.RandomCrop, torchvision 1.7.0
|
35 |
+
"""
|
36 |
+
if self.padding is not None:
|
37 |
+
img = F.pad(img, self.padding, self.fill, self.padding_mode)
|
38 |
+
|
39 |
+
width, height = get_image_size(img)
|
40 |
+
# pad the width if needed
|
41 |
+
if self.pad_if_needed and width < self.size[1]:
|
42 |
+
padding = [self.size[1] - width, 0]
|
43 |
+
img = F.pad(img, padding, self.fill, self.padding_mode)
|
44 |
+
# pad the height if needed
|
45 |
+
if self.pad_if_needed and height < self.size[0]:
|
46 |
+
padding = [0, self.size[0] - height]
|
47 |
+
img = F.pad(img, padding, self.fill, self.padding_mode)
|
48 |
+
|
49 |
+
i, j, h, w = self.get_params(img, self.size)
|
50 |
+
bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
|
51 |
+
return bbox, F.crop(img, i, j, h, w)
|
52 |
+
|
53 |
+
|
54 |
+
class Random2dCropReturnCoordinates(torch.nn.Module):
|
55 |
+
"""
|
56 |
+
Additionally to cropping, returns the relative coordinates of the crop bounding box.
|
57 |
+
Args:
|
58 |
+
img (PIL Image or Tensor): Image to be cropped.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Bounding box: x0, y0, w, h
|
62 |
+
PIL Image or Tensor: Cropped image.
|
63 |
+
|
64 |
+
Based on:
|
65 |
+
torchvision.transforms.RandomCrop, torchvision 1.7.0
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, min_size: int):
|
69 |
+
super().__init__()
|
70 |
+
self.min_size = min_size
|
71 |
+
|
72 |
+
def forward(self, img: Image) -> (BoundingBox, Image):
|
73 |
+
width, height = get_image_size(img)
|
74 |
+
max_size = min(width, height)
|
75 |
+
if max_size <= self.min_size:
|
76 |
+
size = max_size
|
77 |
+
else:
|
78 |
+
size = random.randint(self.min_size, max_size)
|
79 |
+
top = random.randint(0, height - size)
|
80 |
+
left = random.randint(0, width - size)
|
81 |
+
bbox = left / width, top / height, size / width, size / height
|
82 |
+
return bbox, F.crop(img, top, left, size, size)
|
83 |
+
|
84 |
+
|
85 |
+
class CenterCropReturnCoordinates(CenterCrop):
|
86 |
+
@staticmethod
|
87 |
+
def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
|
88 |
+
if width > height:
|
89 |
+
w = height / width
|
90 |
+
h = 1.0
|
91 |
+
x0 = 0.5 - w / 2
|
92 |
+
y0 = 0.
|
93 |
+
else:
|
94 |
+
w = 1.0
|
95 |
+
h = width / height
|
96 |
+
x0 = 0.
|
97 |
+
y0 = 0.5 - h / 2
|
98 |
+
return x0, y0, w, h
|
99 |
+
|
100 |
+
def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
|
101 |
+
"""
|
102 |
+
Additionally to cropping, returns the relative coordinates of the crop bounding box.
|
103 |
+
Args:
|
104 |
+
img (PIL Image or Tensor): Image to be cropped.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Bounding box: x0, y0, w, h
|
108 |
+
PIL Image or Tensor: Cropped image.
|
109 |
+
Based on:
|
110 |
+
torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
|
111 |
+
"""
|
112 |
+
width, height = get_image_size(img)
|
113 |
+
return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
|
114 |
+
|
115 |
+
|
116 |
+
class RandomHorizontalFlipReturn(RandomHorizontalFlip):
|
117 |
+
def forward(self, img: Image) -> (bool, Image):
|
118 |
+
"""
|
119 |
+
Additionally to flipping, returns a boolean whether it was flipped or not.
|
120 |
+
Args:
|
121 |
+
img (PIL Image or Tensor): Image to be flipped.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
flipped: whether the image was flipped or not
|
125 |
+
PIL Image or Tensor: Randomly flipped image.
|
126 |
+
|
127 |
+
Based on:
|
128 |
+
torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
|
129 |
+
"""
|
130 |
+
if torch.rand(1) < self.p:
|
131 |
+
return True, F.hflip(img)
|
132 |
+
return False, img
|
taming/data/imagenet.py
ADDED
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, tarfile, glob, shutil
|
2 |
+
import yaml
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from PIL import Image
|
6 |
+
import albumentations
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from taming.data.base import ImagePaths
|
11 |
+
from taming.util import download, retrieve
|
12 |
+
import taming.data.utils as bdu
|
13 |
+
|
14 |
+
|
15 |
+
def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
|
16 |
+
synsets = []
|
17 |
+
with open(path_to_yaml) as f:
|
18 |
+
di2s = yaml.load(f)
|
19 |
+
for idx in indices:
|
20 |
+
synsets.append(str(di2s[idx]))
|
21 |
+
print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
|
22 |
+
return synsets
|
23 |
+
|
24 |
+
|
25 |
+
def str_to_indices(string):
|
26 |
+
"""Expects a string in the format '32-123, 256, 280-321'"""
|
27 |
+
assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
|
28 |
+
subs = string.split(",")
|
29 |
+
indices = []
|
30 |
+
for sub in subs:
|
31 |
+
subsubs = sub.split("-")
|
32 |
+
assert len(subsubs) > 0
|
33 |
+
if len(subsubs) == 1:
|
34 |
+
indices.append(int(subsubs[0]))
|
35 |
+
else:
|
36 |
+
rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
|
37 |
+
indices.extend(rang)
|
38 |
+
return sorted(indices)
|
39 |
+
|
40 |
+
|
41 |
+
class ImageNetBase(Dataset):
|
42 |
+
def __init__(self, config=None):
|
43 |
+
self.config = config or OmegaConf.create()
|
44 |
+
if not type(self.config)==dict:
|
45 |
+
self.config = OmegaConf.to_container(self.config)
|
46 |
+
self._prepare()
|
47 |
+
self._prepare_synset_to_human()
|
48 |
+
self._prepare_idx_to_synset()
|
49 |
+
self._load()
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return len(self.data)
|
53 |
+
|
54 |
+
def __getitem__(self, i):
|
55 |
+
return self.data[i]
|
56 |
+
|
57 |
+
def _prepare(self):
|
58 |
+
raise NotImplementedError()
|
59 |
+
|
60 |
+
def _filter_relpaths(self, relpaths):
|
61 |
+
ignore = set([
|
62 |
+
"n06596364_9591.JPEG",
|
63 |
+
])
|
64 |
+
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
|
65 |
+
if "sub_indices" in self.config:
|
66 |
+
indices = str_to_indices(self.config["sub_indices"])
|
67 |
+
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
|
68 |
+
files = []
|
69 |
+
for rpath in relpaths:
|
70 |
+
syn = rpath.split("/")[0]
|
71 |
+
if syn in synsets:
|
72 |
+
files.append(rpath)
|
73 |
+
return files
|
74 |
+
else:
|
75 |
+
return relpaths
|
76 |
+
|
77 |
+
def _prepare_synset_to_human(self):
|
78 |
+
SIZE = 2655750
|
79 |
+
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
|
80 |
+
self.human_dict = os.path.join(self.root, "synset_human.txt")
|
81 |
+
if (not os.path.exists(self.human_dict) or
|
82 |
+
not os.path.getsize(self.human_dict)==SIZE):
|
83 |
+
download(URL, self.human_dict)
|
84 |
+
|
85 |
+
def _prepare_idx_to_synset(self):
|
86 |
+
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
|
87 |
+
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
|
88 |
+
if (not os.path.exists(self.idx2syn)):
|
89 |
+
download(URL, self.idx2syn)
|
90 |
+
|
91 |
+
def _load(self):
|
92 |
+
with open(self.txt_filelist, "r") as f:
|
93 |
+
self.relpaths = f.read().splitlines()
|
94 |
+
l1 = len(self.relpaths)
|
95 |
+
self.relpaths = self._filter_relpaths(self.relpaths)
|
96 |
+
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
|
97 |
+
|
98 |
+
self.synsets = [p.split("/")[0] for p in self.relpaths]
|
99 |
+
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
|
100 |
+
|
101 |
+
unique_synsets = np.unique(self.synsets)
|
102 |
+
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
|
103 |
+
self.class_labels = [class_dict[s] for s in self.synsets]
|
104 |
+
|
105 |
+
with open(self.human_dict, "r") as f:
|
106 |
+
human_dict = f.read().splitlines()
|
107 |
+
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
|
108 |
+
|
109 |
+
self.human_labels = [human_dict[s] for s in self.synsets]
|
110 |
+
|
111 |
+
labels = {
|
112 |
+
"relpath": np.array(self.relpaths),
|
113 |
+
"synsets": np.array(self.synsets),
|
114 |
+
"class_label": np.array(self.class_labels),
|
115 |
+
"human_label": np.array(self.human_labels),
|
116 |
+
}
|
117 |
+
self.data = ImagePaths(self.abspaths,
|
118 |
+
labels=labels,
|
119 |
+
size=retrieve(self.config, "size", default=0),
|
120 |
+
random_crop=self.random_crop)
|
121 |
+
|
122 |
+
|
123 |
+
class ImageNetTrain(ImageNetBase):
|
124 |
+
NAME = "ILSVRC2012_train"
|
125 |
+
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
126 |
+
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
|
127 |
+
FILES = [
|
128 |
+
"ILSVRC2012_img_train.tar",
|
129 |
+
]
|
130 |
+
SIZES = [
|
131 |
+
147897477120,
|
132 |
+
]
|
133 |
+
|
134 |
+
def _prepare(self):
|
135 |
+
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
|
136 |
+
default=True)
|
137 |
+
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
138 |
+
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
139 |
+
self.datadir = os.path.join(self.root, "data")
|
140 |
+
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
141 |
+
self.expected_length = 1281167
|
142 |
+
if not bdu.is_prepared(self.root):
|
143 |
+
# prep
|
144 |
+
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
145 |
+
|
146 |
+
datadir = self.datadir
|
147 |
+
if not os.path.exists(datadir):
|
148 |
+
path = os.path.join(self.root, self.FILES[0])
|
149 |
+
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
150 |
+
import academictorrents as at
|
151 |
+
atpath = at.get(self.AT_HASH, datastore=self.root)
|
152 |
+
assert atpath == path
|
153 |
+
|
154 |
+
print("Extracting {} to {}".format(path, datadir))
|
155 |
+
os.makedirs(datadir, exist_ok=True)
|
156 |
+
with tarfile.open(path, "r:") as tar:
|
157 |
+
tar.extractall(path=datadir)
|
158 |
+
|
159 |
+
print("Extracting sub-tars.")
|
160 |
+
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
|
161 |
+
for subpath in tqdm(subpaths):
|
162 |
+
subdir = subpath[:-len(".tar")]
|
163 |
+
os.makedirs(subdir, exist_ok=True)
|
164 |
+
with tarfile.open(subpath, "r:") as tar:
|
165 |
+
tar.extractall(path=subdir)
|
166 |
+
|
167 |
+
|
168 |
+
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
169 |
+
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
170 |
+
filelist = sorted(filelist)
|
171 |
+
filelist = "\n".join(filelist)+"\n"
|
172 |
+
with open(self.txt_filelist, "w") as f:
|
173 |
+
f.write(filelist)
|
174 |
+
|
175 |
+
bdu.mark_prepared(self.root)
|
176 |
+
|
177 |
+
|
178 |
+
class ImageNetValidation(ImageNetBase):
|
179 |
+
NAME = "ILSVRC2012_validation"
|
180 |
+
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
181 |
+
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
|
182 |
+
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
|
183 |
+
FILES = [
|
184 |
+
"ILSVRC2012_img_val.tar",
|
185 |
+
"validation_synset.txt",
|
186 |
+
]
|
187 |
+
SIZES = [
|
188 |
+
6744924160,
|
189 |
+
1950000,
|
190 |
+
]
|
191 |
+
|
192 |
+
def _prepare(self):
|
193 |
+
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
|
194 |
+
default=False)
|
195 |
+
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
196 |
+
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
197 |
+
self.datadir = os.path.join(self.root, "data")
|
198 |
+
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
199 |
+
self.expected_length = 50000
|
200 |
+
if not bdu.is_prepared(self.root):
|
201 |
+
# prep
|
202 |
+
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
203 |
+
|
204 |
+
datadir = self.datadir
|
205 |
+
if not os.path.exists(datadir):
|
206 |
+
path = os.path.join(self.root, self.FILES[0])
|
207 |
+
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
208 |
+
import academictorrents as at
|
209 |
+
atpath = at.get(self.AT_HASH, datastore=self.root)
|
210 |
+
assert atpath == path
|
211 |
+
|
212 |
+
print("Extracting {} to {}".format(path, datadir))
|
213 |
+
os.makedirs(datadir, exist_ok=True)
|
214 |
+
with tarfile.open(path, "r:") as tar:
|
215 |
+
tar.extractall(path=datadir)
|
216 |
+
|
217 |
+
vspath = os.path.join(self.root, self.FILES[1])
|
218 |
+
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
|
219 |
+
download(self.VS_URL, vspath)
|
220 |
+
|
221 |
+
with open(vspath, "r") as f:
|
222 |
+
synset_dict = f.read().splitlines()
|
223 |
+
synset_dict = dict(line.split() for line in synset_dict)
|
224 |
+
|
225 |
+
print("Reorganizing into synset folders")
|
226 |
+
synsets = np.unique(list(synset_dict.values()))
|
227 |
+
for s in synsets:
|
228 |
+
os.makedirs(os.path.join(datadir, s), exist_ok=True)
|
229 |
+
for k, v in synset_dict.items():
|
230 |
+
src = os.path.join(datadir, k)
|
231 |
+
dst = os.path.join(datadir, v)
|
232 |
+
shutil.move(src, dst)
|
233 |
+
|
234 |
+
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
235 |
+
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
236 |
+
filelist = sorted(filelist)
|
237 |
+
filelist = "\n".join(filelist)+"\n"
|
238 |
+
with open(self.txt_filelist, "w") as f:
|
239 |
+
f.write(filelist)
|
240 |
+
|
241 |
+
bdu.mark_prepared(self.root)
|
242 |
+
|
243 |
+
|
244 |
+
def get_preprocessor(size=None, random_crop=False, additional_targets=None,
|
245 |
+
crop_size=None):
|
246 |
+
if size is not None and size > 0:
|
247 |
+
transforms = list()
|
248 |
+
rescaler = albumentations.SmallestMaxSize(max_size = size)
|
249 |
+
transforms.append(rescaler)
|
250 |
+
if not random_crop:
|
251 |
+
cropper = albumentations.CenterCrop(height=size,width=size)
|
252 |
+
transforms.append(cropper)
|
253 |
+
else:
|
254 |
+
cropper = albumentations.RandomCrop(height=size,width=size)
|
255 |
+
transforms.append(cropper)
|
256 |
+
flipper = albumentations.HorizontalFlip()
|
257 |
+
transforms.append(flipper)
|
258 |
+
preprocessor = albumentations.Compose(transforms,
|
259 |
+
additional_targets=additional_targets)
|
260 |
+
elif crop_size is not None and crop_size > 0:
|
261 |
+
if not random_crop:
|
262 |
+
cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
|
263 |
+
else:
|
264 |
+
cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
|
265 |
+
transforms = [cropper]
|
266 |
+
preprocessor = albumentations.Compose(transforms,
|
267 |
+
additional_targets=additional_targets)
|
268 |
+
else:
|
269 |
+
preprocessor = lambda **kwargs: kwargs
|
270 |
+
return preprocessor
|
271 |
+
|
272 |
+
|
273 |
+
def rgba_to_depth(x):
|
274 |
+
assert x.dtype == np.uint8
|
275 |
+
assert len(x.shape) == 3 and x.shape[2] == 4
|
276 |
+
y = x.copy()
|
277 |
+
y.dtype = np.float32
|
278 |
+
y = y.reshape(x.shape[:2])
|
279 |
+
return np.ascontiguousarray(y)
|
280 |
+
|
281 |
+
|
282 |
+
class BaseWithDepth(Dataset):
|
283 |
+
DEFAULT_DEPTH_ROOT="data/imagenet_depth"
|
284 |
+
|
285 |
+
def __init__(self, config=None, size=None, random_crop=False,
|
286 |
+
crop_size=None, root=None):
|
287 |
+
self.config = config
|
288 |
+
self.base_dset = self.get_base_dset()
|
289 |
+
self.preprocessor = get_preprocessor(
|
290 |
+
size=size,
|
291 |
+
crop_size=crop_size,
|
292 |
+
random_crop=random_crop,
|
293 |
+
additional_targets={"depth": "image"})
|
294 |
+
self.crop_size = crop_size
|
295 |
+
if self.crop_size is not None:
|
296 |
+
self.rescaler = albumentations.Compose(
|
297 |
+
[albumentations.SmallestMaxSize(max_size = self.crop_size)],
|
298 |
+
additional_targets={"depth": "image"})
|
299 |
+
if root is not None:
|
300 |
+
self.DEFAULT_DEPTH_ROOT = root
|
301 |
+
|
302 |
+
def __len__(self):
|
303 |
+
return len(self.base_dset)
|
304 |
+
|
305 |
+
def preprocess_depth(self, path):
|
306 |
+
rgba = np.array(Image.open(path))
|
307 |
+
depth = rgba_to_depth(rgba)
|
308 |
+
depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
|
309 |
+
depth = 2.0*depth-1.0
|
310 |
+
return depth
|
311 |
+
|
312 |
+
def __getitem__(self, i):
|
313 |
+
e = self.base_dset[i]
|
314 |
+
e["depth"] = self.preprocess_depth(self.get_depth_path(e))
|
315 |
+
# up if necessary
|
316 |
+
h,w,c = e["image"].shape
|
317 |
+
if self.crop_size and min(h,w) < self.crop_size:
|
318 |
+
# have to upscale to be able to crop - this just uses bilinear
|
319 |
+
out = self.rescaler(image=e["image"], depth=e["depth"])
|
320 |
+
e["image"] = out["image"]
|
321 |
+
e["depth"] = out["depth"]
|
322 |
+
transformed = self.preprocessor(image=e["image"], depth=e["depth"])
|
323 |
+
e["image"] = transformed["image"]
|
324 |
+
e["depth"] = transformed["depth"]
|
325 |
+
return e
|
326 |
+
|
327 |
+
|
328 |
+
class ImageNetTrainWithDepth(BaseWithDepth):
|
329 |
+
# default to random_crop=True
|
330 |
+
def __init__(self, random_crop=True, sub_indices=None, **kwargs):
|
331 |
+
self.sub_indices = sub_indices
|
332 |
+
super().__init__(random_crop=random_crop, **kwargs)
|
333 |
+
|
334 |
+
def get_base_dset(self):
|
335 |
+
if self.sub_indices is None:
|
336 |
+
return ImageNetTrain()
|
337 |
+
else:
|
338 |
+
return ImageNetTrain({"sub_indices": self.sub_indices})
|
339 |
+
|
340 |
+
def get_depth_path(self, e):
|
341 |
+
fid = os.path.splitext(e["relpath"])[0]+".png"
|
342 |
+
fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
|
343 |
+
return fid
|
344 |
+
|
345 |
+
|
346 |
+
class ImageNetValidationWithDepth(BaseWithDepth):
|
347 |
+
def __init__(self, sub_indices=None, **kwargs):
|
348 |
+
self.sub_indices = sub_indices
|
349 |
+
super().__init__(**kwargs)
|
350 |
+
|
351 |
+
def get_base_dset(self):
|
352 |
+
if self.sub_indices is None:
|
353 |
+
return ImageNetValidation()
|
354 |
+
else:
|
355 |
+
return ImageNetValidation({"sub_indices": self.sub_indices})
|
356 |
+
|
357 |
+
def get_depth_path(self, e):
|
358 |
+
fid = os.path.splitext(e["relpath"])[0]+".png"
|
359 |
+
fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
|
360 |
+
return fid
|
361 |
+
|
362 |
+
|
363 |
+
class RINTrainWithDepth(ImageNetTrainWithDepth):
|
364 |
+
def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
|
365 |
+
sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
|
366 |
+
super().__init__(config=config, size=size, random_crop=random_crop,
|
367 |
+
sub_indices=sub_indices, crop_size=crop_size)
|
368 |
+
|
369 |
+
|
370 |
+
class RINValidationWithDepth(ImageNetValidationWithDepth):
|
371 |
+
def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
|
372 |
+
sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
|
373 |
+
super().__init__(config=config, size=size, random_crop=random_crop,
|
374 |
+
sub_indices=sub_indices, crop_size=crop_size)
|
375 |
+
|
376 |
+
|
377 |
+
class DRINExamples(Dataset):
|
378 |
+
def __init__(self):
|
379 |
+
self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
|
380 |
+
with open("data/drin_examples.txt", "r") as f:
|
381 |
+
relpaths = f.read().splitlines()
|
382 |
+
self.image_paths = [os.path.join("data/drin_images",
|
383 |
+
relpath) for relpath in relpaths]
|
384 |
+
self.depth_paths = [os.path.join("data/drin_depth",
|
385 |
+
relpath.replace(".JPEG", ".png")) for relpath in relpaths]
|
386 |
+
|
387 |
+
def __len__(self):
|
388 |
+
return len(self.image_paths)
|
389 |
+
|
390 |
+
def preprocess_image(self, image_path):
|
391 |
+
image = Image.open(image_path)
|
392 |
+
if not image.mode == "RGB":
|
393 |
+
image = image.convert("RGB")
|
394 |
+
image = np.array(image).astype(np.uint8)
|
395 |
+
image = self.preprocessor(image=image)["image"]
|
396 |
+
image = (image/127.5 - 1.0).astype(np.float32)
|
397 |
+
return image
|
398 |
+
|
399 |
+
def preprocess_depth(self, path):
|
400 |
+
rgba = np.array(Image.open(path))
|
401 |
+
depth = rgba_to_depth(rgba)
|
402 |
+
depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
|
403 |
+
depth = 2.0*depth-1.0
|
404 |
+
return depth
|
405 |
+
|
406 |
+
def __getitem__(self, i):
|
407 |
+
e = dict()
|
408 |
+
e["image"] = self.preprocess_image(self.image_paths[i])
|
409 |
+
e["depth"] = self.preprocess_depth(self.depth_paths[i])
|
410 |
+
transformed = self.preprocessor(image=e["image"], depth=e["depth"])
|
411 |
+
e["image"] = transformed["image"]
|
412 |
+
e["depth"] = transformed["depth"]
|
413 |
+
return e
|
414 |
+
|
415 |
+
|
416 |
+
def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
|
417 |
+
if factor is None or factor==1:
|
418 |
+
return x
|
419 |
+
|
420 |
+
dtype = x.dtype
|
421 |
+
assert dtype in [np.float32, np.float64]
|
422 |
+
assert x.min() >= -1
|
423 |
+
assert x.max() <= 1
|
424 |
+
|
425 |
+
keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
|
426 |
+
"bicubic": Image.BICUBIC}[keepmode]
|
427 |
+
|
428 |
+
lr = (x+1.0)*127.5
|
429 |
+
lr = lr.clip(0,255).astype(np.uint8)
|
430 |
+
lr = Image.fromarray(lr)
|
431 |
+
|
432 |
+
h, w, _ = x.shape
|
433 |
+
nh = h//factor
|
434 |
+
nw = w//factor
|
435 |
+
assert nh > 0 and nw > 0, (nh, nw)
|
436 |
+
|
437 |
+
lr = lr.resize((nw,nh), Image.BICUBIC)
|
438 |
+
if keepshapes:
|
439 |
+
lr = lr.resize((w,h), keepmode)
|
440 |
+
lr = np.array(lr)/127.5-1.0
|
441 |
+
lr = lr.astype(dtype)
|
442 |
+
|
443 |
+
return lr
|
444 |
+
|
445 |
+
|
446 |
+
class ImageNetScale(Dataset):
|
447 |
+
def __init__(self, size=None, crop_size=None, random_crop=False,
|
448 |
+
up_factor=None, hr_factor=None, keep_mode="bicubic"):
|
449 |
+
self.base = self.get_base()
|
450 |
+
|
451 |
+
self.size = size
|
452 |
+
self.crop_size = crop_size if crop_size is not None else self.size
|
453 |
+
self.random_crop = random_crop
|
454 |
+
self.up_factor = up_factor
|
455 |
+
self.hr_factor = hr_factor
|
456 |
+
self.keep_mode = keep_mode
|
457 |
+
|
458 |
+
transforms = list()
|
459 |
+
|
460 |
+
if self.size is not None and self.size > 0:
|
461 |
+
rescaler = albumentations.SmallestMaxSize(max_size = self.size)
|
462 |
+
self.rescaler = rescaler
|
463 |
+
transforms.append(rescaler)
|
464 |
+
|
465 |
+
if self.crop_size is not None and self.crop_size > 0:
|
466 |
+
if len(transforms) == 0:
|
467 |
+
self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
|
468 |
+
|
469 |
+
if not self.random_crop:
|
470 |
+
cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
|
471 |
+
else:
|
472 |
+
cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
|
473 |
+
transforms.append(cropper)
|
474 |
+
|
475 |
+
if len(transforms) > 0:
|
476 |
+
if self.up_factor is not None:
|
477 |
+
additional_targets = {"lr": "image"}
|
478 |
+
else:
|
479 |
+
additional_targets = None
|
480 |
+
self.preprocessor = albumentations.Compose(transforms,
|
481 |
+
additional_targets=additional_targets)
|
482 |
+
else:
|
483 |
+
self.preprocessor = lambda **kwargs: kwargs
|
484 |
+
|
485 |
+
def __len__(self):
|
486 |
+
return len(self.base)
|
487 |
+
|
488 |
+
def __getitem__(self, i):
|
489 |
+
example = self.base[i]
|
490 |
+
image = example["image"]
|
491 |
+
# adjust resolution
|
492 |
+
image = imscale(image, self.hr_factor, keepshapes=False)
|
493 |
+
h,w,c = image.shape
|
494 |
+
if self.crop_size and min(h,w) < self.crop_size:
|
495 |
+
# have to upscale to be able to crop - this just uses bilinear
|
496 |
+
image = self.rescaler(image=image)["image"]
|
497 |
+
if self.up_factor is None:
|
498 |
+
image = self.preprocessor(image=image)["image"]
|
499 |
+
example["image"] = image
|
500 |
+
else:
|
501 |
+
lr = imscale(image, self.up_factor, keepshapes=True,
|
502 |
+
keepmode=self.keep_mode)
|
503 |
+
|
504 |
+
out = self.preprocessor(image=image, lr=lr)
|
505 |
+
example["image"] = out["image"]
|
506 |
+
example["lr"] = out["lr"]
|
507 |
+
|
508 |
+
return example
|
509 |
+
|
510 |
+
class ImageNetScaleTrain(ImageNetScale):
|
511 |
+
def __init__(self, random_crop=True, **kwargs):
|
512 |
+
super().__init__(random_crop=random_crop, **kwargs)
|
513 |
+
|
514 |
+
def get_base(self):
|
515 |
+
return ImageNetTrain()
|
516 |
+
|
517 |
+
class ImageNetScaleValidation(ImageNetScale):
|
518 |
+
def get_base(self):
|
519 |
+
return ImageNetValidation()
|
520 |
+
|
521 |
+
|
522 |
+
from skimage.feature import canny
|
523 |
+
from skimage.color import rgb2gray
|
524 |
+
|
525 |
+
|
526 |
+
class ImageNetEdges(ImageNetScale):
|
527 |
+
def __init__(self, up_factor=1, **kwargs):
|
528 |
+
super().__init__(up_factor=1, **kwargs)
|
529 |
+
|
530 |
+
def __getitem__(self, i):
|
531 |
+
example = self.base[i]
|
532 |
+
image = example["image"]
|
533 |
+
h,w,c = image.shape
|
534 |
+
if self.crop_size and min(h,w) < self.crop_size:
|
535 |
+
# have to upscale to be able to crop - this just uses bilinear
|
536 |
+
image = self.rescaler(image=image)["image"]
|
537 |
+
|
538 |
+
lr = canny(rgb2gray(image), sigma=2)
|
539 |
+
lr = lr.astype(np.float32)
|
540 |
+
lr = lr[:,:,None][:,:,[0,0,0]]
|
541 |
+
|
542 |
+
out = self.preprocessor(image=image, lr=lr)
|
543 |
+
example["image"] = out["image"]
|
544 |
+
example["lr"] = out["lr"]
|
545 |
+
|
546 |
+
return example
|
547 |
+
|
548 |
+
|
549 |
+
class ImageNetEdgesTrain(ImageNetEdges):
|
550 |
+
def __init__(self, random_crop=True, **kwargs):
|
551 |
+
super().__init__(random_crop=random_crop, **kwargs)
|
552 |
+
|
553 |
+
def get_base(self):
|
554 |
+
return ImageNetTrain()
|
555 |
+
|
556 |
+
class ImageNetEdgesValidation(ImageNetEdges):
|
557 |
+
def get_base(self):
|
558 |
+
return ImageNetValidation()
|
taming/data/open_images_helper.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
open_images_unify_categories_for_coco = {
|
2 |
+
'/m/03bt1vf': '/m/01g317',
|
3 |
+
'/m/04yx4': '/m/01g317',
|
4 |
+
'/m/05r655': '/m/01g317',
|
5 |
+
'/m/01bl7v': '/m/01g317',
|
6 |
+
'/m/0cnyhnx': '/m/01xq0k1',
|
7 |
+
'/m/01226z': '/m/018xm',
|
8 |
+
'/m/05ctyq': '/m/018xm',
|
9 |
+
'/m/058qzx': '/m/04ctx',
|
10 |
+
'/m/06pcq': '/m/0l515',
|
11 |
+
'/m/03m3pdh': '/m/02crq1',
|
12 |
+
'/m/046dlr': '/m/01x3z',
|
13 |
+
'/m/0h8mzrc': '/m/01x3z',
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
top_300_classes_plus_coco_compatibility = [
|
18 |
+
('Man', 1060962),
|
19 |
+
('Clothing', 986610),
|
20 |
+
('Tree', 748162),
|
21 |
+
('Woman', 611896),
|
22 |
+
('Person', 610294),
|
23 |
+
('Human face', 442948),
|
24 |
+
('Girl', 175399),
|
25 |
+
('Building', 162147),
|
26 |
+
('Car', 159135),
|
27 |
+
('Plant', 155704),
|
28 |
+
('Human body', 137073),
|
29 |
+
('Flower', 133128),
|
30 |
+
('Window', 127485),
|
31 |
+
('Human arm', 118380),
|
32 |
+
('House', 114365),
|
33 |
+
('Wheel', 111684),
|
34 |
+
('Suit', 99054),
|
35 |
+
('Human hair', 98089),
|
36 |
+
('Human head', 92763),
|
37 |
+
('Chair', 88624),
|
38 |
+
('Boy', 79849),
|
39 |
+
('Table', 73699),
|
40 |
+
('Jeans', 57200),
|
41 |
+
('Tire', 55725),
|
42 |
+
('Skyscraper', 53321),
|
43 |
+
('Food', 52400),
|
44 |
+
('Footwear', 50335),
|
45 |
+
('Dress', 50236),
|
46 |
+
('Human leg', 47124),
|
47 |
+
('Toy', 46636),
|
48 |
+
('Tower', 45605),
|
49 |
+
('Boat', 43486),
|
50 |
+
('Land vehicle', 40541),
|
51 |
+
('Bicycle wheel', 34646),
|
52 |
+
('Palm tree', 33729),
|
53 |
+
('Fashion accessory', 32914),
|
54 |
+
('Glasses', 31940),
|
55 |
+
('Bicycle', 31409),
|
56 |
+
('Furniture', 30656),
|
57 |
+
('Sculpture', 29643),
|
58 |
+
('Bottle', 27558),
|
59 |
+
('Dog', 26980),
|
60 |
+
('Snack', 26796),
|
61 |
+
('Human hand', 26664),
|
62 |
+
('Bird', 25791),
|
63 |
+
('Book', 25415),
|
64 |
+
('Guitar', 24386),
|
65 |
+
('Jacket', 23998),
|
66 |
+
('Poster', 22192),
|
67 |
+
('Dessert', 21284),
|
68 |
+
('Baked goods', 20657),
|
69 |
+
('Drink', 19754),
|
70 |
+
('Flag', 18588),
|
71 |
+
('Houseplant', 18205),
|
72 |
+
('Tableware', 17613),
|
73 |
+
('Airplane', 17218),
|
74 |
+
('Door', 17195),
|
75 |
+
('Sports uniform', 17068),
|
76 |
+
('Shelf', 16865),
|
77 |
+
('Drum', 16612),
|
78 |
+
('Vehicle', 16542),
|
79 |
+
('Microphone', 15269),
|
80 |
+
('Street light', 14957),
|
81 |
+
('Cat', 14879),
|
82 |
+
('Fruit', 13684),
|
83 |
+
('Fast food', 13536),
|
84 |
+
('Animal', 12932),
|
85 |
+
('Vegetable', 12534),
|
86 |
+
('Train', 12358),
|
87 |
+
('Horse', 11948),
|
88 |
+
('Flowerpot', 11728),
|
89 |
+
('Motorcycle', 11621),
|
90 |
+
('Fish', 11517),
|
91 |
+
('Desk', 11405),
|
92 |
+
('Helmet', 10996),
|
93 |
+
('Truck', 10915),
|
94 |
+
('Bus', 10695),
|
95 |
+
('Hat', 10532),
|
96 |
+
('Auto part', 10488),
|
97 |
+
('Musical instrument', 10303),
|
98 |
+
('Sunglasses', 10207),
|
99 |
+
('Picture frame', 10096),
|
100 |
+
('Sports equipment', 10015),
|
101 |
+
('Shorts', 9999),
|
102 |
+
('Wine glass', 9632),
|
103 |
+
('Duck', 9242),
|
104 |
+
('Wine', 9032),
|
105 |
+
('Rose', 8781),
|
106 |
+
('Tie', 8693),
|
107 |
+
('Butterfly', 8436),
|
108 |
+
('Beer', 7978),
|
109 |
+
('Cabinetry', 7956),
|
110 |
+
('Laptop', 7907),
|
111 |
+
('Insect', 7497),
|
112 |
+
('Goggles', 7363),
|
113 |
+
('Shirt', 7098),
|
114 |
+
('Dairy Product', 7021),
|
115 |
+
('Marine invertebrates', 7014),
|
116 |
+
('Cattle', 7006),
|
117 |
+
('Trousers', 6903),
|
118 |
+
('Van', 6843),
|
119 |
+
('Billboard', 6777),
|
120 |
+
('Balloon', 6367),
|
121 |
+
('Human nose', 6103),
|
122 |
+
('Tent', 6073),
|
123 |
+
('Camera', 6014),
|
124 |
+
('Doll', 6002),
|
125 |
+
('Coat', 5951),
|
126 |
+
('Mobile phone', 5758),
|
127 |
+
('Swimwear', 5729),
|
128 |
+
('Strawberry', 5691),
|
129 |
+
('Stairs', 5643),
|
130 |
+
('Goose', 5599),
|
131 |
+
('Umbrella', 5536),
|
132 |
+
('Cake', 5508),
|
133 |
+
('Sun hat', 5475),
|
134 |
+
('Bench', 5310),
|
135 |
+
('Bookcase', 5163),
|
136 |
+
('Bee', 5140),
|
137 |
+
('Computer monitor', 5078),
|
138 |
+
('Hiking equipment', 4983),
|
139 |
+
('Office building', 4981),
|
140 |
+
('Coffee cup', 4748),
|
141 |
+
('Curtain', 4685),
|
142 |
+
('Plate', 4651),
|
143 |
+
('Box', 4621),
|
144 |
+
('Tomato', 4595),
|
145 |
+
('Coffee table', 4529),
|
146 |
+
('Office supplies', 4473),
|
147 |
+
('Maple', 4416),
|
148 |
+
('Muffin', 4365),
|
149 |
+
('Cocktail', 4234),
|
150 |
+
('Castle', 4197),
|
151 |
+
('Couch', 4134),
|
152 |
+
('Pumpkin', 3983),
|
153 |
+
('Computer keyboard', 3960),
|
154 |
+
('Human mouth', 3926),
|
155 |
+
('Christmas tree', 3893),
|
156 |
+
('Mushroom', 3883),
|
157 |
+
('Swimming pool', 3809),
|
158 |
+
('Pastry', 3799),
|
159 |
+
('Lavender (Plant)', 3769),
|
160 |
+
('Football helmet', 3732),
|
161 |
+
('Bread', 3648),
|
162 |
+
('Traffic sign', 3628),
|
163 |
+
('Common sunflower', 3597),
|
164 |
+
('Television', 3550),
|
165 |
+
('Bed', 3525),
|
166 |
+
('Cookie', 3485),
|
167 |
+
('Fountain', 3484),
|
168 |
+
('Paddle', 3447),
|
169 |
+
('Bicycle helmet', 3429),
|
170 |
+
('Porch', 3420),
|
171 |
+
('Deer', 3387),
|
172 |
+
('Fedora', 3339),
|
173 |
+
('Canoe', 3338),
|
174 |
+
('Carnivore', 3266),
|
175 |
+
('Bowl', 3202),
|
176 |
+
('Human eye', 3166),
|
177 |
+
('Ball', 3118),
|
178 |
+
('Pillow', 3077),
|
179 |
+
('Salad', 3061),
|
180 |
+
('Beetle', 3060),
|
181 |
+
('Orange', 3050),
|
182 |
+
('Drawer', 2958),
|
183 |
+
('Platter', 2937),
|
184 |
+
('Elephant', 2921),
|
185 |
+
('Seafood', 2921),
|
186 |
+
('Monkey', 2915),
|
187 |
+
('Countertop', 2879),
|
188 |
+
('Watercraft', 2831),
|
189 |
+
('Helicopter', 2805),
|
190 |
+
('Kitchen appliance', 2797),
|
191 |
+
('Personal flotation device', 2781),
|
192 |
+
('Swan', 2739),
|
193 |
+
('Lamp', 2711),
|
194 |
+
('Boot', 2695),
|
195 |
+
('Bronze sculpture', 2693),
|
196 |
+
('Chicken', 2677),
|
197 |
+
('Taxi', 2643),
|
198 |
+
('Juice', 2615),
|
199 |
+
('Cowboy hat', 2604),
|
200 |
+
('Apple', 2600),
|
201 |
+
('Tin can', 2590),
|
202 |
+
('Necklace', 2564),
|
203 |
+
('Ice cream', 2560),
|
204 |
+
('Human beard', 2539),
|
205 |
+
('Coin', 2536),
|
206 |
+
('Candle', 2515),
|
207 |
+
('Cart', 2512),
|
208 |
+
('High heels', 2441),
|
209 |
+
('Weapon', 2433),
|
210 |
+
('Handbag', 2406),
|
211 |
+
('Penguin', 2396),
|
212 |
+
('Rifle', 2352),
|
213 |
+
('Violin', 2336),
|
214 |
+
('Skull', 2304),
|
215 |
+
('Lantern', 2285),
|
216 |
+
('Scarf', 2269),
|
217 |
+
('Saucer', 2225),
|
218 |
+
('Sheep', 2215),
|
219 |
+
('Vase', 2189),
|
220 |
+
('Lily', 2180),
|
221 |
+
('Mug', 2154),
|
222 |
+
('Parrot', 2140),
|
223 |
+
('Human ear', 2137),
|
224 |
+
('Sandal', 2115),
|
225 |
+
('Lizard', 2100),
|
226 |
+
('Kitchen & dining room table', 2063),
|
227 |
+
('Spider', 1977),
|
228 |
+
('Coffee', 1974),
|
229 |
+
('Goat', 1926),
|
230 |
+
('Squirrel', 1922),
|
231 |
+
('Cello', 1913),
|
232 |
+
('Sushi', 1881),
|
233 |
+
('Tortoise', 1876),
|
234 |
+
('Pizza', 1870),
|
235 |
+
('Studio couch', 1864),
|
236 |
+
('Barrel', 1862),
|
237 |
+
('Cosmetics', 1841),
|
238 |
+
('Moths and butterflies', 1841),
|
239 |
+
('Convenience store', 1817),
|
240 |
+
('Watch', 1792),
|
241 |
+
('Home appliance', 1786),
|
242 |
+
('Harbor seal', 1780),
|
243 |
+
('Luggage and bags', 1756),
|
244 |
+
('Vehicle registration plate', 1754),
|
245 |
+
('Shrimp', 1751),
|
246 |
+
('Jellyfish', 1730),
|
247 |
+
('French fries', 1723),
|
248 |
+
('Egg (Food)', 1698),
|
249 |
+
('Football', 1697),
|
250 |
+
('Musical keyboard', 1683),
|
251 |
+
('Falcon', 1674),
|
252 |
+
('Candy', 1660),
|
253 |
+
('Medical equipment', 1654),
|
254 |
+
('Eagle', 1651),
|
255 |
+
('Dinosaur', 1634),
|
256 |
+
('Surfboard', 1630),
|
257 |
+
('Tank', 1628),
|
258 |
+
('Grape', 1624),
|
259 |
+
('Lion', 1624),
|
260 |
+
('Owl', 1622),
|
261 |
+
('Ski', 1613),
|
262 |
+
('Waste container', 1606),
|
263 |
+
('Frog', 1591),
|
264 |
+
('Sparrow', 1585),
|
265 |
+
('Rabbit', 1581),
|
266 |
+
('Pen', 1546),
|
267 |
+
('Sea lion', 1537),
|
268 |
+
('Spoon', 1521),
|
269 |
+
('Sink', 1512),
|
270 |
+
('Teddy bear', 1507),
|
271 |
+
('Bull', 1495),
|
272 |
+
('Sofa bed', 1490),
|
273 |
+
('Dragonfly', 1479),
|
274 |
+
('Brassiere', 1478),
|
275 |
+
('Chest of drawers', 1472),
|
276 |
+
('Aircraft', 1466),
|
277 |
+
('Human foot', 1463),
|
278 |
+
('Pig', 1455),
|
279 |
+
('Fork', 1454),
|
280 |
+
('Antelope', 1438),
|
281 |
+
('Tripod', 1427),
|
282 |
+
('Tool', 1424),
|
283 |
+
('Cheese', 1422),
|
284 |
+
('Lemon', 1397),
|
285 |
+
('Hamburger', 1393),
|
286 |
+
('Dolphin', 1390),
|
287 |
+
('Mirror', 1390),
|
288 |
+
('Marine mammal', 1387),
|
289 |
+
('Giraffe', 1385),
|
290 |
+
('Snake', 1368),
|
291 |
+
('Gondola', 1364),
|
292 |
+
('Wheelchair', 1360),
|
293 |
+
('Piano', 1358),
|
294 |
+
('Cupboard', 1348),
|
295 |
+
('Banana', 1345),
|
296 |
+
('Trumpet', 1335),
|
297 |
+
('Lighthouse', 1333),
|
298 |
+
('Invertebrate', 1317),
|
299 |
+
('Carrot', 1268),
|
300 |
+
('Sock', 1260),
|
301 |
+
('Tiger', 1241),
|
302 |
+
('Camel', 1224),
|
303 |
+
('Parachute', 1224),
|
304 |
+
('Bathroom accessory', 1223),
|
305 |
+
('Earrings', 1221),
|
306 |
+
('Headphones', 1218),
|
307 |
+
('Skirt', 1198),
|
308 |
+
('Skateboard', 1190),
|
309 |
+
('Sandwich', 1148),
|
310 |
+
('Saxophone', 1141),
|
311 |
+
('Goldfish', 1136),
|
312 |
+
('Stool', 1104),
|
313 |
+
('Traffic light', 1097),
|
314 |
+
('Shellfish', 1081),
|
315 |
+
('Backpack', 1079),
|
316 |
+
('Sea turtle', 1078),
|
317 |
+
('Cucumber', 1075),
|
318 |
+
('Tea', 1051),
|
319 |
+
('Toilet', 1047),
|
320 |
+
('Roller skates', 1040),
|
321 |
+
('Mule', 1039),
|
322 |
+
('Bust', 1031),
|
323 |
+
('Broccoli', 1030),
|
324 |
+
('Crab', 1020),
|
325 |
+
('Oyster', 1019),
|
326 |
+
('Cannon', 1012),
|
327 |
+
('Zebra', 1012),
|
328 |
+
('French horn', 1008),
|
329 |
+
('Grapefruit', 998),
|
330 |
+
('Whiteboard', 997),
|
331 |
+
('Zucchini', 997),
|
332 |
+
('Crocodile', 992),
|
333 |
+
|
334 |
+
('Clock', 960),
|
335 |
+
('Wall clock', 958),
|
336 |
+
|
337 |
+
('Doughnut', 869),
|
338 |
+
('Snail', 868),
|
339 |
+
|
340 |
+
('Baseball glove', 859),
|
341 |
+
|
342 |
+
('Panda', 830),
|
343 |
+
('Tennis racket', 830),
|
344 |
+
|
345 |
+
('Pear', 652),
|
346 |
+
|
347 |
+
('Bagel', 617),
|
348 |
+
('Oven', 616),
|
349 |
+
('Ladybug', 615),
|
350 |
+
('Shark', 615),
|
351 |
+
('Polar bear', 614),
|
352 |
+
('Ostrich', 609),
|
353 |
+
|
354 |
+
('Hot dog', 473),
|
355 |
+
('Microwave oven', 467),
|
356 |
+
('Fire hydrant', 20),
|
357 |
+
('Stop sign', 20),
|
358 |
+
('Parking meter', 20),
|
359 |
+
('Bear', 20),
|
360 |
+
('Flying disc', 20),
|
361 |
+
('Snowboard', 20),
|
362 |
+
('Tennis ball', 20),
|
363 |
+
('Kite', 20),
|
364 |
+
('Baseball bat', 20),
|
365 |
+
('Kitchen knife', 20),
|
366 |
+
('Knife', 20),
|
367 |
+
('Submarine sandwich', 20),
|
368 |
+
('Computer mouse', 20),
|
369 |
+
('Remote control', 20),
|
370 |
+
('Toaster', 20),
|
371 |
+
('Sink', 20),
|
372 |
+
('Refrigerator', 20),
|
373 |
+
('Alarm clock', 20),
|
374 |
+
('Wall clock', 20),
|
375 |
+
('Scissors', 20),
|
376 |
+
('Hair dryer', 20),
|
377 |
+
('Toothbrush', 20),
|
378 |
+
('Suitcase', 20)
|
379 |
+
]
|
taming/data/sflckr.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import albumentations
|
5 |
+
from PIL import Image
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
|
8 |
+
|
9 |
+
class SegmentationBase(Dataset):
|
10 |
+
def __init__(self,
|
11 |
+
data_csv, data_root, segmentation_root,
|
12 |
+
size=None, random_crop=False, interpolation="bicubic",
|
13 |
+
n_labels=182, shift_segmentation=False,
|
14 |
+
):
|
15 |
+
self.n_labels = n_labels
|
16 |
+
self.shift_segmentation = shift_segmentation
|
17 |
+
self.data_csv = data_csv
|
18 |
+
self.data_root = data_root
|
19 |
+
self.segmentation_root = segmentation_root
|
20 |
+
with open(self.data_csv, "r") as f:
|
21 |
+
self.image_paths = f.read().splitlines()
|
22 |
+
self._length = len(self.image_paths)
|
23 |
+
self.labels = {
|
24 |
+
"relative_file_path_": [l for l in self.image_paths],
|
25 |
+
"file_path_": [os.path.join(self.data_root, l)
|
26 |
+
for l in self.image_paths],
|
27 |
+
"segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
|
28 |
+
for l in self.image_paths]
|
29 |
+
}
|
30 |
+
|
31 |
+
size = None if size is not None and size<=0 else size
|
32 |
+
self.size = size
|
33 |
+
if self.size is not None:
|
34 |
+
self.interpolation = interpolation
|
35 |
+
self.interpolation = {
|
36 |
+
"nearest": cv2.INTER_NEAREST,
|
37 |
+
"bilinear": cv2.INTER_LINEAR,
|
38 |
+
"bicubic": cv2.INTER_CUBIC,
|
39 |
+
"area": cv2.INTER_AREA,
|
40 |
+
"lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
|
41 |
+
self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
|
42 |
+
interpolation=self.interpolation)
|
43 |
+
self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
|
44 |
+
interpolation=cv2.INTER_NEAREST)
|
45 |
+
self.center_crop = not random_crop
|
46 |
+
if self.center_crop:
|
47 |
+
self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
|
48 |
+
else:
|
49 |
+
self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
|
50 |
+
self.preprocessor = self.cropper
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return self._length
|
54 |
+
|
55 |
+
def __getitem__(self, i):
|
56 |
+
example = dict((k, self.labels[k][i]) for k in self.labels)
|
57 |
+
image = Image.open(example["file_path_"])
|
58 |
+
if not image.mode == "RGB":
|
59 |
+
image = image.convert("RGB")
|
60 |
+
image = np.array(image).astype(np.uint8)
|
61 |
+
if self.size is not None:
|
62 |
+
image = self.image_rescaler(image=image)["image"]
|
63 |
+
segmentation = Image.open(example["segmentation_path_"])
|
64 |
+
assert segmentation.mode == "L", segmentation.mode
|
65 |
+
segmentation = np.array(segmentation).astype(np.uint8)
|
66 |
+
if self.shift_segmentation:
|
67 |
+
# used to support segmentations containing unlabeled==255 label
|
68 |
+
segmentation = segmentation+1
|
69 |
+
if self.size is not None:
|
70 |
+
segmentation = self.segmentation_rescaler(image=segmentation)["image"]
|
71 |
+
if self.size is not None:
|
72 |
+
processed = self.preprocessor(image=image,
|
73 |
+
mask=segmentation
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
processed = {"image": image,
|
77 |
+
"mask": segmentation
|
78 |
+
}
|
79 |
+
example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
|
80 |
+
segmentation = processed["mask"]
|
81 |
+
onehot = np.eye(self.n_labels)[segmentation]
|
82 |
+
example["segmentation"] = onehot
|
83 |
+
return example
|
84 |
+
|
85 |
+
|
86 |
+
class Examples(SegmentationBase):
|
87 |
+
def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
|
88 |
+
super().__init__(data_csv="data/sflckr_examples.txt",
|
89 |
+
data_root="data/sflckr_images",
|
90 |
+
segmentation_root="data/sflckr_segmentations",
|
91 |
+
size=size, random_crop=random_crop, interpolation=interpolation)
|
taming/data/utils.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import urllib
|
5 |
+
import zipfile
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from taming.data.helper_types import Annotation
|
11 |
+
#from torch._six import string_classes
|
12 |
+
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
string_classes = (str,bytes)
|
16 |
+
|
17 |
+
|
18 |
+
def unpack(path):
|
19 |
+
if path.endswith("tar.gz"):
|
20 |
+
with tarfile.open(path, "r:gz") as tar:
|
21 |
+
tar.extractall(path=os.path.split(path)[0])
|
22 |
+
elif path.endswith("tar"):
|
23 |
+
with tarfile.open(path, "r:") as tar:
|
24 |
+
tar.extractall(path=os.path.split(path)[0])
|
25 |
+
elif path.endswith("zip"):
|
26 |
+
with zipfile.ZipFile(path, "r") as f:
|
27 |
+
f.extractall(path=os.path.split(path)[0])
|
28 |
+
else:
|
29 |
+
raise NotImplementedError(
|
30 |
+
"Unknown file extension: {}".format(os.path.splitext(path)[1])
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def reporthook(bar):
|
35 |
+
"""tqdm progress bar for downloads."""
|
36 |
+
|
37 |
+
def hook(b=1, bsize=1, tsize=None):
|
38 |
+
if tsize is not None:
|
39 |
+
bar.total = tsize
|
40 |
+
bar.update(b * bsize - bar.n)
|
41 |
+
|
42 |
+
return hook
|
43 |
+
|
44 |
+
|
45 |
+
def get_root(name):
|
46 |
+
base = "data/"
|
47 |
+
root = os.path.join(base, name)
|
48 |
+
os.makedirs(root, exist_ok=True)
|
49 |
+
return root
|
50 |
+
|
51 |
+
|
52 |
+
def is_prepared(root):
|
53 |
+
return Path(root).joinpath(".ready").exists()
|
54 |
+
|
55 |
+
|
56 |
+
def mark_prepared(root):
|
57 |
+
Path(root).joinpath(".ready").touch()
|
58 |
+
|
59 |
+
|
60 |
+
def prompt_download(file_, source, target_dir, content_dir=None):
|
61 |
+
targetpath = os.path.join(target_dir, file_)
|
62 |
+
while not os.path.exists(targetpath):
|
63 |
+
if content_dir is not None and os.path.exists(
|
64 |
+
os.path.join(target_dir, content_dir)
|
65 |
+
):
|
66 |
+
break
|
67 |
+
print(
|
68 |
+
"Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
|
69 |
+
)
|
70 |
+
if content_dir is not None:
|
71 |
+
print(
|
72 |
+
"Or place its content into '{}'.".format(
|
73 |
+
os.path.join(target_dir, content_dir)
|
74 |
+
)
|
75 |
+
)
|
76 |
+
input("Press Enter when done...")
|
77 |
+
return targetpath
|
78 |
+
|
79 |
+
|
80 |
+
def download_url(file_, url, target_dir):
|
81 |
+
targetpath = os.path.join(target_dir, file_)
|
82 |
+
os.makedirs(target_dir, exist_ok=True)
|
83 |
+
with tqdm(
|
84 |
+
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
|
85 |
+
) as bar:
|
86 |
+
urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
|
87 |
+
return targetpath
|
88 |
+
|
89 |
+
|
90 |
+
def download_urls(urls, target_dir):
|
91 |
+
paths = dict()
|
92 |
+
for fname, url in urls.items():
|
93 |
+
outpath = download_url(fname, url, target_dir)
|
94 |
+
paths[fname] = outpath
|
95 |
+
return paths
|
96 |
+
|
97 |
+
|
98 |
+
def quadratic_crop(x, bbox, alpha=1.0):
|
99 |
+
"""bbox is xmin, ymin, xmax, ymax"""
|
100 |
+
im_h, im_w = x.shape[:2]
|
101 |
+
bbox = np.array(bbox, dtype=np.float32)
|
102 |
+
bbox = np.clip(bbox, 0, max(im_h, im_w))
|
103 |
+
center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
|
104 |
+
w = bbox[2] - bbox[0]
|
105 |
+
h = bbox[3] - bbox[1]
|
106 |
+
l = int(alpha * max(w, h))
|
107 |
+
l = max(l, 2)
|
108 |
+
|
109 |
+
required_padding = -1 * min(
|
110 |
+
center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
|
111 |
+
)
|
112 |
+
required_padding = int(np.ceil(required_padding))
|
113 |
+
if required_padding > 0:
|
114 |
+
padding = [
|
115 |
+
[required_padding, required_padding],
|
116 |
+
[required_padding, required_padding],
|
117 |
+
]
|
118 |
+
padding += [[0, 0]] * (len(x.shape) - 2)
|
119 |
+
x = np.pad(x, padding, "reflect")
|
120 |
+
center = center[0] + required_padding, center[1] + required_padding
|
121 |
+
xmin = int(center[0] - l / 2)
|
122 |
+
ymin = int(center[1] - l / 2)
|
123 |
+
return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
|
124 |
+
|
125 |
+
|
126 |
+
def custom_collate(batch):
|
127 |
+
r"""source: pytorch 1.9.0, only one modification to original code """
|
128 |
+
|
129 |
+
elem = batch[0]
|
130 |
+
elem_type = type(elem)
|
131 |
+
if isinstance(elem, torch.Tensor):
|
132 |
+
out = None
|
133 |
+
if torch.utils.data.get_worker_info() is not None:
|
134 |
+
# If we're in a background process, concatenate directly into a
|
135 |
+
# shared memory tensor to avoid an extra copy
|
136 |
+
numel = sum([x.numel() for x in batch])
|
137 |
+
storage = elem.storage()._new_shared(numel)
|
138 |
+
out = elem.new(storage)
|
139 |
+
return torch.stack(batch, 0, out=out)
|
140 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
141 |
+
and elem_type.__name__ != 'string_':
|
142 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
143 |
+
# array of string classes and object
|
144 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
145 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
146 |
+
|
147 |
+
return custom_collate([torch.as_tensor(b) for b in batch])
|
148 |
+
elif elem.shape == (): # scalars
|
149 |
+
return torch.as_tensor(batch)
|
150 |
+
elif isinstance(elem, float):
|
151 |
+
return torch.tensor(batch, dtype=torch.float64)
|
152 |
+
elif isinstance(elem, int):
|
153 |
+
return torch.tensor(batch)
|
154 |
+
elif isinstance(elem, string_classes):
|
155 |
+
return batch
|
156 |
+
elif isinstance(elem, collections.abc.Mapping):
|
157 |
+
return {key: custom_collate([d[key] for d in batch]) for key in elem}
|
158 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
159 |
+
return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
|
160 |
+
if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
|
161 |
+
return batch # added
|
162 |
+
elif isinstance(elem, collections.abc.Sequence):
|
163 |
+
# check to make sure that the elements in batch have consistent size
|
164 |
+
it = iter(batch)
|
165 |
+
elem_size = len(next(it))
|
166 |
+
if not all(len(elem) == elem_size for elem in it):
|
167 |
+
raise RuntimeError('each element in list of batch should be of equal size')
|
168 |
+
transposed = zip(*batch)
|
169 |
+
return [custom_collate(samples) for samples in transposed]
|
170 |
+
|
171 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
taming/lr_scheduler.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
9 |
+
self.lr_warm_up_steps = warm_up_steps
|
10 |
+
self.lr_start = lr_start
|
11 |
+
self.lr_min = lr_min
|
12 |
+
self.lr_max = lr_max
|
13 |
+
self.lr_max_decay_steps = max_decay_steps
|
14 |
+
self.last_lr = 0.
|
15 |
+
self.verbosity_interval = verbosity_interval
|
16 |
+
|
17 |
+
def schedule(self, n):
|
18 |
+
if self.verbosity_interval > 0:
|
19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
20 |
+
if n < self.lr_warm_up_steps:
|
21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
22 |
+
self.last_lr = lr
|
23 |
+
return lr
|
24 |
+
else:
|
25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
26 |
+
t = min(t, 1.0)
|
27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
28 |
+
1 + np.cos(t * np.pi))
|
29 |
+
self.last_lr = lr
|
30 |
+
return lr
|
31 |
+
|
32 |
+
def __call__(self, n):
|
33 |
+
return self.schedule(n)
|
34 |
+
|
taming/models/__pycache__/vqgan.cpython-312.pyc
ADDED
Binary file (21.7 kB). View file
|
|
taming/models/cond_transformer.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
from main import instantiate_from_config
|
7 |
+
from taming.modules.util import SOSProvider
|
8 |
+
|
9 |
+
|
10 |
+
def disabled_train(self, mode=True):
|
11 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
12 |
+
does not change anymore."""
|
13 |
+
return self
|
14 |
+
|
15 |
+
|
16 |
+
class Net2NetTransformer(pl.LightningModule):
|
17 |
+
def __init__(self,
|
18 |
+
transformer_config,
|
19 |
+
first_stage_config,
|
20 |
+
cond_stage_config,
|
21 |
+
permuter_config=None,
|
22 |
+
ckpt_path=None,
|
23 |
+
ignore_keys=[],
|
24 |
+
first_stage_key="image",
|
25 |
+
cond_stage_key="depth",
|
26 |
+
downsample_cond_size=-1,
|
27 |
+
pkeep=1.0,
|
28 |
+
sos_token=0,
|
29 |
+
unconditional=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.be_unconditional = unconditional
|
33 |
+
self.sos_token = sos_token
|
34 |
+
self.first_stage_key = first_stage_key
|
35 |
+
self.cond_stage_key = cond_stage_key
|
36 |
+
self.init_first_stage_from_ckpt(first_stage_config)
|
37 |
+
self.init_cond_stage_from_ckpt(cond_stage_config)
|
38 |
+
if permuter_config is None:
|
39 |
+
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
40 |
+
self.permuter = instantiate_from_config(config=permuter_config)
|
41 |
+
self.transformer = instantiate_from_config(config=transformer_config)
|
42 |
+
|
43 |
+
if ckpt_path is not None:
|
44 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
45 |
+
self.downsample_cond_size = downsample_cond_size
|
46 |
+
self.pkeep = pkeep
|
47 |
+
|
48 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
49 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
50 |
+
for k in sd.keys():
|
51 |
+
for ik in ignore_keys:
|
52 |
+
if k.startswith(ik):
|
53 |
+
self.print("Deleting key {} from state_dict.".format(k))
|
54 |
+
del sd[k]
|
55 |
+
self.load_state_dict(sd, strict=False)
|
56 |
+
print(f"Restored from {path}")
|
57 |
+
|
58 |
+
def init_first_stage_from_ckpt(self, config):
|
59 |
+
model = instantiate_from_config(config)
|
60 |
+
model = model.eval()
|
61 |
+
model.train = disabled_train
|
62 |
+
self.first_stage_model = model
|
63 |
+
|
64 |
+
def init_cond_stage_from_ckpt(self, config):
|
65 |
+
if config == "__is_first_stage__":
|
66 |
+
print("Using first stage also as cond stage.")
|
67 |
+
self.cond_stage_model = self.first_stage_model
|
68 |
+
elif config == "__is_unconditional__" or self.be_unconditional:
|
69 |
+
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
70 |
+
f"Prepending {self.sos_token} as a sos token.")
|
71 |
+
self.be_unconditional = True
|
72 |
+
self.cond_stage_key = self.first_stage_key
|
73 |
+
self.cond_stage_model = SOSProvider(self.sos_token)
|
74 |
+
else:
|
75 |
+
model = instantiate_from_config(config)
|
76 |
+
model = model.eval()
|
77 |
+
model.train = disabled_train
|
78 |
+
self.cond_stage_model = model
|
79 |
+
|
80 |
+
def forward(self, x, c):
|
81 |
+
# one step to produce the logits
|
82 |
+
_, z_indices = self.encode_to_z(x)
|
83 |
+
_, c_indices = self.encode_to_c(c)
|
84 |
+
|
85 |
+
if self.training and self.pkeep < 1.0:
|
86 |
+
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
87 |
+
device=z_indices.device))
|
88 |
+
mask = mask.round().to(dtype=torch.int64)
|
89 |
+
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
90 |
+
a_indices = mask*z_indices+(1-mask)*r_indices
|
91 |
+
else:
|
92 |
+
a_indices = z_indices
|
93 |
+
|
94 |
+
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
95 |
+
|
96 |
+
# target includes all sequence elements (no need to handle first one
|
97 |
+
# differently because we are conditioning)
|
98 |
+
target = z_indices
|
99 |
+
# make the prediction
|
100 |
+
logits, _ = self.transformer(cz_indices[:, :-1])
|
101 |
+
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
102 |
+
logits = logits[:, c_indices.shape[1]-1:]
|
103 |
+
|
104 |
+
return logits, target
|
105 |
+
|
106 |
+
def top_k_logits(self, logits, k):
|
107 |
+
v, ix = torch.topk(logits, k)
|
108 |
+
out = logits.clone()
|
109 |
+
out[out < v[..., [-1]]] = -float('Inf')
|
110 |
+
return out
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
114 |
+
callback=lambda k: None):
|
115 |
+
x = torch.cat((c,x),dim=1)
|
116 |
+
block_size = self.transformer.get_block_size()
|
117 |
+
assert not self.transformer.training
|
118 |
+
if self.pkeep <= 0.0:
|
119 |
+
# one pass suffices since input is pure noise anyway
|
120 |
+
assert len(x.shape)==2
|
121 |
+
noise_shape = (x.shape[0], steps-1)
|
122 |
+
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
123 |
+
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
124 |
+
x = torch.cat((x,noise),dim=1)
|
125 |
+
logits, _ = self.transformer(x)
|
126 |
+
# take all logits for now and scale by temp
|
127 |
+
logits = logits / temperature
|
128 |
+
# optionally crop probabilities to only the top k options
|
129 |
+
if top_k is not None:
|
130 |
+
logits = self.top_k_logits(logits, top_k)
|
131 |
+
# apply softmax to convert to probabilities
|
132 |
+
probs = F.softmax(logits, dim=-1)
|
133 |
+
# sample from the distribution or take the most likely
|
134 |
+
if sample:
|
135 |
+
shape = probs.shape
|
136 |
+
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
137 |
+
ix = torch.multinomial(probs, num_samples=1)
|
138 |
+
probs = probs.reshape(shape[0],shape[1],shape[2])
|
139 |
+
ix = ix.reshape(shape[0],shape[1])
|
140 |
+
else:
|
141 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
142 |
+
# cut off conditioning
|
143 |
+
x = ix[:, c.shape[1]-1:]
|
144 |
+
else:
|
145 |
+
for k in range(steps):
|
146 |
+
callback(k)
|
147 |
+
assert x.size(1) <= block_size # make sure model can see conditioning
|
148 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
149 |
+
logits, _ = self.transformer(x_cond)
|
150 |
+
# pluck the logits at the final step and scale by temperature
|
151 |
+
logits = logits[:, -1, :] / temperature
|
152 |
+
# optionally crop probabilities to only the top k options
|
153 |
+
if top_k is not None:
|
154 |
+
logits = self.top_k_logits(logits, top_k)
|
155 |
+
# apply softmax to convert to probabilities
|
156 |
+
probs = F.softmax(logits, dim=-1)
|
157 |
+
# sample from the distribution or take the most likely
|
158 |
+
if sample:
|
159 |
+
ix = torch.multinomial(probs, num_samples=1)
|
160 |
+
else:
|
161 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
162 |
+
# append to the sequence and continue
|
163 |
+
x = torch.cat((x, ix), dim=1)
|
164 |
+
# cut off conditioning
|
165 |
+
x = x[:, c.shape[1]:]
|
166 |
+
return x
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def encode_to_z(self, x):
|
170 |
+
quant_z, _, info = self.first_stage_model.encode(x)
|
171 |
+
indices = info[2].view(quant_z.shape[0], -1)
|
172 |
+
indices = self.permuter(indices)
|
173 |
+
return quant_z, indices
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
def encode_to_c(self, c):
|
177 |
+
if self.downsample_cond_size > -1:
|
178 |
+
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
179 |
+
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
180 |
+
if len(indices.shape) > 2:
|
181 |
+
indices = indices.view(c.shape[0], -1)
|
182 |
+
return quant_c, indices
|
183 |
+
|
184 |
+
@torch.no_grad()
|
185 |
+
def decode_to_img(self, index, zshape):
|
186 |
+
index = self.permuter(index, reverse=True)
|
187 |
+
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
188 |
+
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
189 |
+
index.reshape(-1), shape=bhwc)
|
190 |
+
x = self.first_stage_model.decode(quant_z)
|
191 |
+
return x
|
192 |
+
|
193 |
+
@torch.no_grad()
|
194 |
+
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
|
195 |
+
log = dict()
|
196 |
+
|
197 |
+
N = 4
|
198 |
+
if lr_interface:
|
199 |
+
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
|
200 |
+
else:
|
201 |
+
x, c = self.get_xc(batch, N)
|
202 |
+
x = x.to(device=self.device)
|
203 |
+
c = c.to(device=self.device)
|
204 |
+
|
205 |
+
quant_z, z_indices = self.encode_to_z(x)
|
206 |
+
quant_c, c_indices = self.encode_to_c(c)
|
207 |
+
|
208 |
+
# create a "half"" sample
|
209 |
+
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
|
210 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
211 |
+
steps=z_indices.shape[1]-z_start_indices.shape[1],
|
212 |
+
temperature=temperature if temperature is not None else 1.0,
|
213 |
+
sample=True,
|
214 |
+
top_k=top_k if top_k is not None else 100,
|
215 |
+
callback=callback if callback is not None else lambda k: None)
|
216 |
+
x_sample = self.decode_to_img(index_sample, quant_z.shape)
|
217 |
+
|
218 |
+
# sample
|
219 |
+
z_start_indices = z_indices[:, :0]
|
220 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
221 |
+
steps=z_indices.shape[1],
|
222 |
+
temperature=temperature if temperature is not None else 1.0,
|
223 |
+
sample=True,
|
224 |
+
top_k=top_k if top_k is not None else 100,
|
225 |
+
callback=callback if callback is not None else lambda k: None)
|
226 |
+
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
|
227 |
+
|
228 |
+
# det sample
|
229 |
+
z_start_indices = z_indices[:, :0]
|
230 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
231 |
+
steps=z_indices.shape[1],
|
232 |
+
sample=False,
|
233 |
+
callback=callback if callback is not None else lambda k: None)
|
234 |
+
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
|
235 |
+
|
236 |
+
# reconstruction
|
237 |
+
x_rec = self.decode_to_img(z_indices, quant_z.shape)
|
238 |
+
|
239 |
+
log["inputs"] = x
|
240 |
+
log["reconstructions"] = x_rec
|
241 |
+
|
242 |
+
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
|
243 |
+
figure_size = (x_rec.shape[2], x_rec.shape[3])
|
244 |
+
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
|
245 |
+
label_for_category_no = dataset.get_textual_label_for_category_no
|
246 |
+
plotter = dataset.conditional_builders[self.cond_stage_key].plot
|
247 |
+
log["conditioning"] = torch.zeros_like(log["reconstructions"])
|
248 |
+
for i in range(quant_c.shape[0]):
|
249 |
+
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
|
250 |
+
log["conditioning_rec"] = log["conditioning"]
|
251 |
+
elif self.cond_stage_key != "image":
|
252 |
+
cond_rec = self.cond_stage_model.decode(quant_c)
|
253 |
+
if self.cond_stage_key == "segmentation":
|
254 |
+
# get image from segmentation mask
|
255 |
+
num_classes = cond_rec.shape[1]
|
256 |
+
|
257 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
258 |
+
c = F.one_hot(c, num_classes=num_classes)
|
259 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
260 |
+
c = self.cond_stage_model.to_rgb(c)
|
261 |
+
|
262 |
+
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
|
263 |
+
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
|
264 |
+
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
|
265 |
+
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
|
266 |
+
log["conditioning_rec"] = cond_rec
|
267 |
+
log["conditioning"] = c
|
268 |
+
|
269 |
+
log["samples_half"] = x_sample
|
270 |
+
log["samples_nopix"] = x_sample_nopix
|
271 |
+
log["samples_det"] = x_sample_det
|
272 |
+
return log
|
273 |
+
|
274 |
+
def get_input(self, key, batch):
|
275 |
+
x = batch[key]
|
276 |
+
if len(x.shape) == 3:
|
277 |
+
x = x[..., None]
|
278 |
+
if len(x.shape) == 4:
|
279 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
280 |
+
if x.dtype == torch.double:
|
281 |
+
x = x.float()
|
282 |
+
return x
|
283 |
+
|
284 |
+
def get_xc(self, batch, N=None):
|
285 |
+
x = self.get_input(self.first_stage_key, batch)
|
286 |
+
c = self.get_input(self.cond_stage_key, batch)
|
287 |
+
if N is not None:
|
288 |
+
x = x[:N]
|
289 |
+
c = c[:N]
|
290 |
+
return x, c
|
291 |
+
|
292 |
+
def shared_step(self, batch, batch_idx):
|
293 |
+
x, c = self.get_xc(batch)
|
294 |
+
logits, target = self(x, c)
|
295 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
|
296 |
+
return loss
|
297 |
+
|
298 |
+
def training_step(self, batch, batch_idx):
|
299 |
+
loss = self.shared_step(batch, batch_idx)
|
300 |
+
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
301 |
+
return loss
|
302 |
+
|
303 |
+
def validation_step(self, batch, batch_idx):
|
304 |
+
loss = self.shared_step(batch, batch_idx)
|
305 |
+
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
306 |
+
return loss
|
307 |
+
|
308 |
+
def configure_optimizers(self):
|
309 |
+
"""
|
310 |
+
Following minGPT:
|
311 |
+
This long function is unfortunately doing something very simple and is being very defensive:
|
312 |
+
We are separating out all parameters of the model into two buckets: those that will experience
|
313 |
+
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
314 |
+
We are then returning the PyTorch optimizer object.
|
315 |
+
"""
|
316 |
+
# separate out all parameters to those that will and won't experience regularizing weight decay
|
317 |
+
decay = set()
|
318 |
+
no_decay = set()
|
319 |
+
whitelist_weight_modules = (torch.nn.Linear, )
|
320 |
+
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
321 |
+
for mn, m in self.transformer.named_modules():
|
322 |
+
for pn, p in m.named_parameters():
|
323 |
+
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
324 |
+
|
325 |
+
if pn.endswith('bias'):
|
326 |
+
# all biases will not be decayed
|
327 |
+
no_decay.add(fpn)
|
328 |
+
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
329 |
+
# weights of whitelist modules will be weight decayed
|
330 |
+
decay.add(fpn)
|
331 |
+
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
332 |
+
# weights of blacklist modules will NOT be weight decayed
|
333 |
+
no_decay.add(fpn)
|
334 |
+
|
335 |
+
# special case the position embedding parameter in the root GPT module as not decayed
|
336 |
+
no_decay.add('pos_emb')
|
337 |
+
|
338 |
+
# validate that we considered every parameter
|
339 |
+
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
|
340 |
+
inter_params = decay & no_decay
|
341 |
+
union_params = decay | no_decay
|
342 |
+
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
343 |
+
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
344 |
+
% (str(param_dict.keys() - union_params), )
|
345 |
+
|
346 |
+
# create the pytorch optimizer object
|
347 |
+
optim_groups = [
|
348 |
+
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
|
349 |
+
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
350 |
+
]
|
351 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
|
352 |
+
return optimizer
|
taming/models/dummy_cond_stage.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
|
4 |
+
class DummyCondStage:
|
5 |
+
def __init__(self, conditional_key):
|
6 |
+
self.conditional_key = conditional_key
|
7 |
+
self.train = None
|
8 |
+
|
9 |
+
def eval(self):
|
10 |
+
return self
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def encode(c: Tensor):
|
14 |
+
return c, None, (None, None, c)
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def decode(c: Tensor):
|
18 |
+
return c
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def to_rgb(c: Tensor):
|
22 |
+
return c
|
taming/models/vqgan.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from main import instantiate_from_config
|
6 |
+
|
7 |
+
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
9 |
+
from taming.modules.vqvae.quantize import GumbelQuantize
|
10 |
+
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
11 |
+
|
12 |
+
class VQModel(pl.LightningModule):
|
13 |
+
def __init__(self,
|
14 |
+
ddconfig,
|
15 |
+
lossconfig,
|
16 |
+
n_embed,
|
17 |
+
embed_dim,
|
18 |
+
ckpt_path=None,
|
19 |
+
ignore_keys=[],
|
20 |
+
image_key="image",
|
21 |
+
colorize_nlabels=None,
|
22 |
+
monitor=None,
|
23 |
+
remap=None,
|
24 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.image_key = image_key
|
28 |
+
self.encoder = Encoder(**ddconfig)
|
29 |
+
self.decoder = Decoder(**ddconfig)
|
30 |
+
self.loss = instantiate_from_config(lossconfig)
|
31 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
32 |
+
remap=remap, sane_index_shape=sane_index_shape)
|
33 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
34 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
35 |
+
if ckpt_path is not None:
|
36 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
37 |
+
self.image_key = image_key
|
38 |
+
if colorize_nlabels is not None:
|
39 |
+
assert type(colorize_nlabels)==int
|
40 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
41 |
+
if monitor is not None:
|
42 |
+
self.monitor = monitor
|
43 |
+
|
44 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
45 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
46 |
+
keys = list(sd.keys())
|
47 |
+
for k in keys:
|
48 |
+
for ik in ignore_keys:
|
49 |
+
if k.startswith(ik):
|
50 |
+
print("Deleting key {} from state_dict.".format(k))
|
51 |
+
del sd[k]
|
52 |
+
self.load_state_dict(sd, strict=False)
|
53 |
+
print(f"Restored from {path}")
|
54 |
+
|
55 |
+
def encode(self, x):
|
56 |
+
h = self.encoder(x)
|
57 |
+
h = self.quant_conv(h)
|
58 |
+
quant, emb_loss, info = self.quantize(h)
|
59 |
+
return quant, emb_loss, info
|
60 |
+
|
61 |
+
def decode(self, quant):
|
62 |
+
quant = self.post_quant_conv(quant)
|
63 |
+
dec = self.decoder(quant)
|
64 |
+
return dec
|
65 |
+
|
66 |
+
def decode_code(self, code_b):
|
67 |
+
quant_b = self.quantize.embed_code(code_b)
|
68 |
+
dec = self.decode(quant_b)
|
69 |
+
return dec
|
70 |
+
|
71 |
+
def forward(self, input):
|
72 |
+
quant, diff, _ = self.encode(input)
|
73 |
+
dec = self.decode(quant)
|
74 |
+
return dec, diff
|
75 |
+
|
76 |
+
def get_input(self, batch, k):
|
77 |
+
x = batch[k]
|
78 |
+
if len(x.shape) == 3:
|
79 |
+
x = x[..., None]
|
80 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
81 |
+
return x.float()
|
82 |
+
|
83 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
84 |
+
x = self.get_input(batch, self.image_key)
|
85 |
+
xrec, qloss = self(x)
|
86 |
+
|
87 |
+
if optimizer_idx == 0:
|
88 |
+
# autoencode
|
89 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
90 |
+
last_layer=self.get_last_layer(), split="train")
|
91 |
+
|
92 |
+
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
93 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
94 |
+
return aeloss
|
95 |
+
|
96 |
+
if optimizer_idx == 1:
|
97 |
+
# discriminator
|
98 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
99 |
+
last_layer=self.get_last_layer(), split="train")
|
100 |
+
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
101 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
102 |
+
return discloss
|
103 |
+
|
104 |
+
def validation_step(self, batch, batch_idx):
|
105 |
+
x = self.get_input(batch, self.image_key)
|
106 |
+
xrec, qloss = self(x)
|
107 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
108 |
+
last_layer=self.get_last_layer(), split="val")
|
109 |
+
|
110 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
111 |
+
last_layer=self.get_last_layer(), split="val")
|
112 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
113 |
+
self.log("val/rec_loss", rec_loss,
|
114 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
115 |
+
self.log("val/aeloss", aeloss,
|
116 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
117 |
+
self.log_dict(log_dict_ae)
|
118 |
+
self.log_dict(log_dict_disc)
|
119 |
+
return self.log_dict
|
120 |
+
|
121 |
+
def configure_optimizers(self):
|
122 |
+
lr = self.learning_rate
|
123 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
124 |
+
list(self.decoder.parameters())+
|
125 |
+
list(self.quantize.parameters())+
|
126 |
+
list(self.quant_conv.parameters())+
|
127 |
+
list(self.post_quant_conv.parameters()),
|
128 |
+
lr=lr, betas=(0.5, 0.9))
|
129 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
130 |
+
lr=lr, betas=(0.5, 0.9))
|
131 |
+
return [opt_ae, opt_disc], []
|
132 |
+
|
133 |
+
def get_last_layer(self):
|
134 |
+
return self.decoder.conv_out.weight
|
135 |
+
|
136 |
+
def log_images(self, batch, **kwargs):
|
137 |
+
log = dict()
|
138 |
+
x = self.get_input(batch, self.image_key)
|
139 |
+
x = x.to(self.device)
|
140 |
+
xrec, _ = self(x)
|
141 |
+
if x.shape[1] > 3:
|
142 |
+
# colorize with random projection
|
143 |
+
assert xrec.shape[1] > 3
|
144 |
+
x = self.to_rgb(x)
|
145 |
+
xrec = self.to_rgb(xrec)
|
146 |
+
log["inputs"] = x
|
147 |
+
log["reconstructions"] = xrec
|
148 |
+
return log
|
149 |
+
|
150 |
+
def to_rgb(self, x):
|
151 |
+
assert self.image_key == "segmentation"
|
152 |
+
if not hasattr(self, "colorize"):
|
153 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
154 |
+
x = F.conv2d(x, weight=self.colorize)
|
155 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class VQSegmentationModel(VQModel):
|
160 |
+
def __init__(self, n_labels, *args, **kwargs):
|
161 |
+
super().__init__(*args, **kwargs)
|
162 |
+
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
163 |
+
|
164 |
+
def configure_optimizers(self):
|
165 |
+
lr = self.learning_rate
|
166 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
167 |
+
list(self.decoder.parameters())+
|
168 |
+
list(self.quantize.parameters())+
|
169 |
+
list(self.quant_conv.parameters())+
|
170 |
+
list(self.post_quant_conv.parameters()),
|
171 |
+
lr=lr, betas=(0.5, 0.9))
|
172 |
+
return opt_ae
|
173 |
+
|
174 |
+
def training_step(self, batch, batch_idx):
|
175 |
+
x = self.get_input(batch, self.image_key)
|
176 |
+
xrec, qloss = self(x)
|
177 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
178 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
179 |
+
return aeloss
|
180 |
+
|
181 |
+
def validation_step(self, batch, batch_idx):
|
182 |
+
x = self.get_input(batch, self.image_key)
|
183 |
+
xrec, qloss = self(x)
|
184 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
185 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
186 |
+
total_loss = log_dict_ae["val/total_loss"]
|
187 |
+
self.log("val/total_loss", total_loss,
|
188 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
189 |
+
return aeloss
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def log_images(self, batch, **kwargs):
|
193 |
+
log = dict()
|
194 |
+
x = self.get_input(batch, self.image_key)
|
195 |
+
x = x.to(self.device)
|
196 |
+
xrec, _ = self(x)
|
197 |
+
if x.shape[1] > 3:
|
198 |
+
# colorize with random projection
|
199 |
+
assert xrec.shape[1] > 3
|
200 |
+
# convert logits to indices
|
201 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
202 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
203 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
204 |
+
x = self.to_rgb(x)
|
205 |
+
xrec = self.to_rgb(xrec)
|
206 |
+
log["inputs"] = x
|
207 |
+
log["reconstructions"] = xrec
|
208 |
+
return log
|
209 |
+
|
210 |
+
|
211 |
+
class VQNoDiscModel(VQModel):
|
212 |
+
def __init__(self,
|
213 |
+
ddconfig,
|
214 |
+
lossconfig,
|
215 |
+
n_embed,
|
216 |
+
embed_dim,
|
217 |
+
ckpt_path=None,
|
218 |
+
ignore_keys=[],
|
219 |
+
image_key="image",
|
220 |
+
colorize_nlabels=None
|
221 |
+
):
|
222 |
+
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
|
223 |
+
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
|
224 |
+
colorize_nlabels=colorize_nlabels)
|
225 |
+
|
226 |
+
def training_step(self, batch, batch_idx):
|
227 |
+
x = self.get_input(batch, self.image_key)
|
228 |
+
xrec, qloss = self(x)
|
229 |
+
# autoencode
|
230 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
231 |
+
output = pl.TrainResult(minimize=aeloss)
|
232 |
+
output.log("train/aeloss", aeloss,
|
233 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
234 |
+
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
235 |
+
return output
|
236 |
+
|
237 |
+
def validation_step(self, batch, batch_idx):
|
238 |
+
x = self.get_input(batch, self.image_key)
|
239 |
+
xrec, qloss = self(x)
|
240 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
241 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
242 |
+
output = pl.EvalResult(checkpoint_on=rec_loss)
|
243 |
+
output.log("val/rec_loss", rec_loss,
|
244 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
245 |
+
output.log("val/aeloss", aeloss,
|
246 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
247 |
+
output.log_dict(log_dict_ae)
|
248 |
+
|
249 |
+
return output
|
250 |
+
|
251 |
+
def configure_optimizers(self):
|
252 |
+
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
|
253 |
+
list(self.decoder.parameters())+
|
254 |
+
list(self.quantize.parameters())+
|
255 |
+
list(self.quant_conv.parameters())+
|
256 |
+
list(self.post_quant_conv.parameters()),
|
257 |
+
lr=self.learning_rate, betas=(0.5, 0.9))
|
258 |
+
return optimizer
|
259 |
+
|
260 |
+
|
261 |
+
class GumbelVQ(VQModel):
|
262 |
+
def __init__(self,
|
263 |
+
ddconfig,
|
264 |
+
lossconfig,
|
265 |
+
n_embed,
|
266 |
+
embed_dim,
|
267 |
+
temperature_scheduler_config,
|
268 |
+
ckpt_path=None,
|
269 |
+
ignore_keys=[],
|
270 |
+
image_key="image",
|
271 |
+
colorize_nlabels=None,
|
272 |
+
monitor=None,
|
273 |
+
kl_weight=1e-8,
|
274 |
+
remap=None,
|
275 |
+
):
|
276 |
+
|
277 |
+
z_channels = ddconfig["z_channels"]
|
278 |
+
super().__init__(ddconfig,
|
279 |
+
lossconfig,
|
280 |
+
n_embed,
|
281 |
+
embed_dim,
|
282 |
+
ckpt_path=None,
|
283 |
+
ignore_keys=ignore_keys,
|
284 |
+
image_key=image_key,
|
285 |
+
colorize_nlabels=colorize_nlabels,
|
286 |
+
monitor=monitor,
|
287 |
+
)
|
288 |
+
|
289 |
+
self.loss.n_classes = n_embed
|
290 |
+
self.vocab_size = n_embed
|
291 |
+
|
292 |
+
self.quantize = GumbelQuantize(z_channels, embed_dim,
|
293 |
+
n_embed=n_embed,
|
294 |
+
kl_weight=kl_weight, temp_init=1.0,
|
295 |
+
remap=remap)
|
296 |
+
|
297 |
+
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
|
298 |
+
|
299 |
+
if ckpt_path is not None:
|
300 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
301 |
+
|
302 |
+
def temperature_scheduling(self):
|
303 |
+
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
304 |
+
|
305 |
+
def encode_to_prequant(self, x):
|
306 |
+
h = self.encoder(x)
|
307 |
+
h = self.quant_conv(h)
|
308 |
+
return h
|
309 |
+
|
310 |
+
def decode_code(self, code_b):
|
311 |
+
raise NotImplementedError
|
312 |
+
|
313 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
314 |
+
self.temperature_scheduling()
|
315 |
+
x = self.get_input(batch, self.image_key)
|
316 |
+
xrec, qloss = self(x)
|
317 |
+
|
318 |
+
if optimizer_idx == 0:
|
319 |
+
# autoencode
|
320 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
321 |
+
last_layer=self.get_last_layer(), split="train")
|
322 |
+
|
323 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
324 |
+
self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
325 |
+
return aeloss
|
326 |
+
|
327 |
+
if optimizer_idx == 1:
|
328 |
+
# discriminator
|
329 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
330 |
+
last_layer=self.get_last_layer(), split="train")
|
331 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
332 |
+
return discloss
|
333 |
+
|
334 |
+
def validation_step(self, batch, batch_idx):
|
335 |
+
x = self.get_input(batch, self.image_key)
|
336 |
+
xrec, qloss = self(x, return_pred_indices=True)
|
337 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
338 |
+
last_layer=self.get_last_layer(), split="val")
|
339 |
+
|
340 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
341 |
+
last_layer=self.get_last_layer(), split="val")
|
342 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
343 |
+
self.log("val/rec_loss", rec_loss,
|
344 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
345 |
+
self.log("val/aeloss", aeloss,
|
346 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
347 |
+
self.log_dict(log_dict_ae)
|
348 |
+
self.log_dict(log_dict_disc)
|
349 |
+
return self.log_dict
|
350 |
+
|
351 |
+
def log_images(self, batch, **kwargs):
|
352 |
+
log = dict()
|
353 |
+
x = self.get_input(batch, self.image_key)
|
354 |
+
x = x.to(self.device)
|
355 |
+
# encode
|
356 |
+
h = self.encoder(x)
|
357 |
+
h = self.quant_conv(h)
|
358 |
+
quant, _, _ = self.quantize(h)
|
359 |
+
# decode
|
360 |
+
x_rec = self.decode(quant)
|
361 |
+
log["inputs"] = x
|
362 |
+
log["reconstructions"] = x_rec
|
363 |
+
return log
|
364 |
+
|
365 |
+
|
366 |
+
class EMAVQ(VQModel):
|
367 |
+
def __init__(self,
|
368 |
+
ddconfig,
|
369 |
+
lossconfig,
|
370 |
+
n_embed,
|
371 |
+
embed_dim,
|
372 |
+
ckpt_path=None,
|
373 |
+
ignore_keys=[],
|
374 |
+
image_key="image",
|
375 |
+
colorize_nlabels=None,
|
376 |
+
monitor=None,
|
377 |
+
remap=None,
|
378 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
379 |
+
):
|
380 |
+
super().__init__(ddconfig,
|
381 |
+
lossconfig,
|
382 |
+
n_embed,
|
383 |
+
embed_dim,
|
384 |
+
ckpt_path=None,
|
385 |
+
ignore_keys=ignore_keys,
|
386 |
+
image_key=image_key,
|
387 |
+
colorize_nlabels=colorize_nlabels,
|
388 |
+
monitor=monitor,
|
389 |
+
)
|
390 |
+
self.quantize = EMAVectorQuantizer(n_embed=n_embed,
|
391 |
+
embedding_dim=embed_dim,
|
392 |
+
beta=0.25,
|
393 |
+
remap=remap)
|
394 |
+
def configure_optimizers(self):
|
395 |
+
lr = self.learning_rate
|
396 |
+
#Remove self.quantize from parameter list since it is updated via EMA
|
397 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
398 |
+
list(self.decoder.parameters())+
|
399 |
+
list(self.quant_conv.parameters())+
|
400 |
+
list(self.post_quant_conv.parameters()),
|
401 |
+
lr=lr, betas=(0.5, 0.9))
|
402 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
403 |
+
lr=lr, betas=(0.5, 0.9))
|
404 |
+
return [opt_ae, opt_disc], []
|
taming/modules/__pycache__/util.cpython-312.pyc
ADDED
Binary file (7.4 kB). View file
|
|
taming/modules/diffusionmodules/__pycache__/model.cpython-312.pyc
ADDED
Binary file (34.6 kB). View file
|
|
taming/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
9 |
+
"""
|
10 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
11 |
+
From Fairseq.
|
12 |
+
Build sinusoidal embeddings.
|
13 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
14 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
15 |
+
"""
|
16 |
+
assert len(timesteps.shape) == 1
|
17 |
+
|
18 |
+
half_dim = embedding_dim // 2
|
19 |
+
emb = math.log(10000) / (half_dim - 1)
|
20 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
21 |
+
emb = emb.to(device=timesteps.device)
|
22 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
23 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
24 |
+
if embedding_dim % 2 == 1: # zero pad
|
25 |
+
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
26 |
+
return emb
|
27 |
+
|
28 |
+
|
29 |
+
def nonlinearity(x):
|
30 |
+
# swish
|
31 |
+
return x*torch.sigmoid(x)
|
32 |
+
|
33 |
+
|
34 |
+
def Normalize(in_channels):
|
35 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
36 |
+
|
37 |
+
|
38 |
+
class Upsample(nn.Module):
|
39 |
+
def __init__(self, in_channels, with_conv):
|
40 |
+
super().__init__()
|
41 |
+
self.with_conv = with_conv
|
42 |
+
if self.with_conv:
|
43 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
44 |
+
in_channels,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
51 |
+
if self.with_conv:
|
52 |
+
x = self.conv(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Downsample(nn.Module):
|
57 |
+
def __init__(self, in_channels, with_conv):
|
58 |
+
super().__init__()
|
59 |
+
self.with_conv = with_conv
|
60 |
+
if self.with_conv:
|
61 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
62 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
63 |
+
in_channels,
|
64 |
+
kernel_size=3,
|
65 |
+
stride=2,
|
66 |
+
padding=0)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.with_conv:
|
70 |
+
pad = (0,1,0,1)
|
71 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
72 |
+
x = self.conv(x)
|
73 |
+
else:
|
74 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class ResnetBlock(nn.Module):
|
79 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
80 |
+
dropout, temb_channels=512):
|
81 |
+
super().__init__()
|
82 |
+
self.in_channels = in_channels
|
83 |
+
out_channels = in_channels if out_channels is None else out_channels
|
84 |
+
self.out_channels = out_channels
|
85 |
+
self.use_conv_shortcut = conv_shortcut
|
86 |
+
|
87 |
+
self.norm1 = Normalize(in_channels)
|
88 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
89 |
+
out_channels,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1)
|
93 |
+
if temb_channels > 0:
|
94 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
95 |
+
out_channels)
|
96 |
+
self.norm2 = Normalize(out_channels)
|
97 |
+
self.dropout = torch.nn.Dropout(dropout)
|
98 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
99 |
+
out_channels,
|
100 |
+
kernel_size=3,
|
101 |
+
stride=1,
|
102 |
+
padding=1)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
if self.use_conv_shortcut:
|
105 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
106 |
+
out_channels,
|
107 |
+
kernel_size=3,
|
108 |
+
stride=1,
|
109 |
+
padding=1)
|
110 |
+
else:
|
111 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
112 |
+
out_channels,
|
113 |
+
kernel_size=1,
|
114 |
+
stride=1,
|
115 |
+
padding=0)
|
116 |
+
|
117 |
+
def forward(self, x, temb):
|
118 |
+
h = x
|
119 |
+
h = self.norm1(h)
|
120 |
+
h = nonlinearity(h)
|
121 |
+
h = self.conv1(h)
|
122 |
+
|
123 |
+
if temb is not None:
|
124 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
125 |
+
|
126 |
+
h = self.norm2(h)
|
127 |
+
h = nonlinearity(h)
|
128 |
+
h = self.dropout(h)
|
129 |
+
h = self.conv2(h)
|
130 |
+
|
131 |
+
if self.in_channels != self.out_channels:
|
132 |
+
if self.use_conv_shortcut:
|
133 |
+
x = self.conv_shortcut(x)
|
134 |
+
else:
|
135 |
+
x = self.nin_shortcut(x)
|
136 |
+
|
137 |
+
return x+h
|
138 |
+
|
139 |
+
|
140 |
+
class AttnBlock(nn.Module):
|
141 |
+
def __init__(self, in_channels):
|
142 |
+
super().__init__()
|
143 |
+
self.in_channels = in_channels
|
144 |
+
|
145 |
+
self.norm = Normalize(in_channels)
|
146 |
+
self.q = torch.nn.Conv2d(in_channels,
|
147 |
+
in_channels,
|
148 |
+
kernel_size=1,
|
149 |
+
stride=1,
|
150 |
+
padding=0)
|
151 |
+
self.k = torch.nn.Conv2d(in_channels,
|
152 |
+
in_channels,
|
153 |
+
kernel_size=1,
|
154 |
+
stride=1,
|
155 |
+
padding=0)
|
156 |
+
self.v = torch.nn.Conv2d(in_channels,
|
157 |
+
in_channels,
|
158 |
+
kernel_size=1,
|
159 |
+
stride=1,
|
160 |
+
padding=0)
|
161 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
162 |
+
in_channels,
|
163 |
+
kernel_size=1,
|
164 |
+
stride=1,
|
165 |
+
padding=0)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
h_ = x
|
170 |
+
h_ = self.norm(h_)
|
171 |
+
q = self.q(h_)
|
172 |
+
k = self.k(h_)
|
173 |
+
v = self.v(h_)
|
174 |
+
|
175 |
+
# compute attention
|
176 |
+
b,c,h,w = q.shape
|
177 |
+
q = q.reshape(b,c,h*w)
|
178 |
+
q = q.permute(0,2,1) # b,hw,c
|
179 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
180 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
181 |
+
w_ = w_ * (int(c)**(-0.5))
|
182 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
183 |
+
|
184 |
+
# attend to values
|
185 |
+
v = v.reshape(b,c,h*w)
|
186 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
187 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
188 |
+
h_ = h_.reshape(b,c,h,w)
|
189 |
+
|
190 |
+
h_ = self.proj_out(h_)
|
191 |
+
|
192 |
+
return x+h_
|
193 |
+
|
194 |
+
|
195 |
+
class Model(nn.Module):
|
196 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
197 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
198 |
+
resolution, use_timestep=True):
|
199 |
+
super().__init__()
|
200 |
+
self.ch = ch
|
201 |
+
self.temb_ch = self.ch*4
|
202 |
+
self.num_resolutions = len(ch_mult)
|
203 |
+
self.num_res_blocks = num_res_blocks
|
204 |
+
self.resolution = resolution
|
205 |
+
self.in_channels = in_channels
|
206 |
+
|
207 |
+
self.use_timestep = use_timestep
|
208 |
+
if self.use_timestep:
|
209 |
+
# timestep embedding
|
210 |
+
self.temb = nn.Module()
|
211 |
+
self.temb.dense = nn.ModuleList([
|
212 |
+
torch.nn.Linear(self.ch,
|
213 |
+
self.temb_ch),
|
214 |
+
torch.nn.Linear(self.temb_ch,
|
215 |
+
self.temb_ch),
|
216 |
+
])
|
217 |
+
|
218 |
+
# downsampling
|
219 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
220 |
+
self.ch,
|
221 |
+
kernel_size=3,
|
222 |
+
stride=1,
|
223 |
+
padding=1)
|
224 |
+
|
225 |
+
curr_res = resolution
|
226 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
227 |
+
self.down = nn.ModuleList()
|
228 |
+
for i_level in range(self.num_resolutions):
|
229 |
+
block = nn.ModuleList()
|
230 |
+
attn = nn.ModuleList()
|
231 |
+
block_in = ch*in_ch_mult[i_level]
|
232 |
+
block_out = ch*ch_mult[i_level]
|
233 |
+
for i_block in range(self.num_res_blocks):
|
234 |
+
block.append(ResnetBlock(in_channels=block_in,
|
235 |
+
out_channels=block_out,
|
236 |
+
temb_channels=self.temb_ch,
|
237 |
+
dropout=dropout))
|
238 |
+
block_in = block_out
|
239 |
+
if curr_res in attn_resolutions:
|
240 |
+
attn.append(AttnBlock(block_in))
|
241 |
+
down = nn.Module()
|
242 |
+
down.block = block
|
243 |
+
down.attn = attn
|
244 |
+
if i_level != self.num_resolutions-1:
|
245 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
246 |
+
curr_res = curr_res // 2
|
247 |
+
self.down.append(down)
|
248 |
+
|
249 |
+
# middle
|
250 |
+
self.mid = nn.Module()
|
251 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
252 |
+
out_channels=block_in,
|
253 |
+
temb_channels=self.temb_ch,
|
254 |
+
dropout=dropout)
|
255 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
256 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
257 |
+
out_channels=block_in,
|
258 |
+
temb_channels=self.temb_ch,
|
259 |
+
dropout=dropout)
|
260 |
+
|
261 |
+
# upsampling
|
262 |
+
self.up = nn.ModuleList()
|
263 |
+
for i_level in reversed(range(self.num_resolutions)):
|
264 |
+
block = nn.ModuleList()
|
265 |
+
attn = nn.ModuleList()
|
266 |
+
block_out = ch*ch_mult[i_level]
|
267 |
+
skip_in = ch*ch_mult[i_level]
|
268 |
+
for i_block in range(self.num_res_blocks+1):
|
269 |
+
if i_block == self.num_res_blocks:
|
270 |
+
skip_in = ch*in_ch_mult[i_level]
|
271 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
272 |
+
out_channels=block_out,
|
273 |
+
temb_channels=self.temb_ch,
|
274 |
+
dropout=dropout))
|
275 |
+
block_in = block_out
|
276 |
+
if curr_res in attn_resolutions:
|
277 |
+
attn.append(AttnBlock(block_in))
|
278 |
+
up = nn.Module()
|
279 |
+
up.block = block
|
280 |
+
up.attn = attn
|
281 |
+
if i_level != 0:
|
282 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
283 |
+
curr_res = curr_res * 2
|
284 |
+
self.up.insert(0, up) # prepend to get consistent order
|
285 |
+
|
286 |
+
# end
|
287 |
+
self.norm_out = Normalize(block_in)
|
288 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
289 |
+
out_ch,
|
290 |
+
kernel_size=3,
|
291 |
+
stride=1,
|
292 |
+
padding=1)
|
293 |
+
|
294 |
+
|
295 |
+
def forward(self, x, t=None):
|
296 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
297 |
+
|
298 |
+
if self.use_timestep:
|
299 |
+
# timestep embedding
|
300 |
+
assert t is not None
|
301 |
+
temb = get_timestep_embedding(t, self.ch)
|
302 |
+
temb = self.temb.dense[0](temb)
|
303 |
+
temb = nonlinearity(temb)
|
304 |
+
temb = self.temb.dense[1](temb)
|
305 |
+
else:
|
306 |
+
temb = None
|
307 |
+
|
308 |
+
# downsampling
|
309 |
+
hs = [self.conv_in(x)]
|
310 |
+
for i_level in range(self.num_resolutions):
|
311 |
+
for i_block in range(self.num_res_blocks):
|
312 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
313 |
+
if len(self.down[i_level].attn) > 0:
|
314 |
+
h = self.down[i_level].attn[i_block](h)
|
315 |
+
hs.append(h)
|
316 |
+
if i_level != self.num_resolutions-1:
|
317 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
318 |
+
|
319 |
+
# middle
|
320 |
+
h = hs[-1]
|
321 |
+
h = self.mid.block_1(h, temb)
|
322 |
+
h = self.mid.attn_1(h)
|
323 |
+
h = self.mid.block_2(h, temb)
|
324 |
+
|
325 |
+
# upsampling
|
326 |
+
for i_level in reversed(range(self.num_resolutions)):
|
327 |
+
for i_block in range(self.num_res_blocks+1):
|
328 |
+
h = self.up[i_level].block[i_block](
|
329 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
330 |
+
if len(self.up[i_level].attn) > 0:
|
331 |
+
h = self.up[i_level].attn[i_block](h)
|
332 |
+
if i_level != 0:
|
333 |
+
h = self.up[i_level].upsample(h)
|
334 |
+
|
335 |
+
# end
|
336 |
+
h = self.norm_out(h)
|
337 |
+
h = nonlinearity(h)
|
338 |
+
h = self.conv_out(h)
|
339 |
+
return h
|
340 |
+
|
341 |
+
|
342 |
+
class Encoder(nn.Module):
|
343 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
344 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
345 |
+
resolution, z_channels, double_z=True, **ignore_kwargs):
|
346 |
+
super().__init__()
|
347 |
+
self.ch = ch
|
348 |
+
self.temb_ch = 0
|
349 |
+
self.num_resolutions = len(ch_mult)
|
350 |
+
self.num_res_blocks = num_res_blocks
|
351 |
+
self.resolution = resolution
|
352 |
+
self.in_channels = in_channels
|
353 |
+
|
354 |
+
# downsampling
|
355 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
356 |
+
self.ch,
|
357 |
+
kernel_size=3,
|
358 |
+
stride=1,
|
359 |
+
padding=1)
|
360 |
+
|
361 |
+
curr_res = resolution
|
362 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
363 |
+
self.down = nn.ModuleList()
|
364 |
+
for i_level in range(self.num_resolutions):
|
365 |
+
block = nn.ModuleList()
|
366 |
+
attn = nn.ModuleList()
|
367 |
+
block_in = ch*in_ch_mult[i_level]
|
368 |
+
block_out = ch*ch_mult[i_level]
|
369 |
+
for i_block in range(self.num_res_blocks):
|
370 |
+
block.append(ResnetBlock(in_channels=block_in,
|
371 |
+
out_channels=block_out,
|
372 |
+
temb_channels=self.temb_ch,
|
373 |
+
dropout=dropout))
|
374 |
+
block_in = block_out
|
375 |
+
if curr_res in attn_resolutions:
|
376 |
+
attn.append(AttnBlock(block_in))
|
377 |
+
down = nn.Module()
|
378 |
+
down.block = block
|
379 |
+
down.attn = attn
|
380 |
+
if i_level != self.num_resolutions-1:
|
381 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
382 |
+
curr_res = curr_res // 2
|
383 |
+
self.down.append(down)
|
384 |
+
|
385 |
+
# middle
|
386 |
+
self.mid = nn.Module()
|
387 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
388 |
+
out_channels=block_in,
|
389 |
+
temb_channels=self.temb_ch,
|
390 |
+
dropout=dropout)
|
391 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
392 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
393 |
+
out_channels=block_in,
|
394 |
+
temb_channels=self.temb_ch,
|
395 |
+
dropout=dropout)
|
396 |
+
|
397 |
+
# end
|
398 |
+
self.norm_out = Normalize(block_in)
|
399 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
400 |
+
2*z_channels if double_z else z_channels,
|
401 |
+
kernel_size=3,
|
402 |
+
stride=1,
|
403 |
+
padding=1)
|
404 |
+
|
405 |
+
|
406 |
+
def forward(self, x):
|
407 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
408 |
+
|
409 |
+
# timestep embedding
|
410 |
+
temb = None
|
411 |
+
|
412 |
+
# downsampling
|
413 |
+
hs = [self.conv_in(x)]
|
414 |
+
for i_level in range(self.num_resolutions):
|
415 |
+
for i_block in range(self.num_res_blocks):
|
416 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
417 |
+
if len(self.down[i_level].attn) > 0:
|
418 |
+
h = self.down[i_level].attn[i_block](h)
|
419 |
+
hs.append(h)
|
420 |
+
if i_level != self.num_resolutions-1:
|
421 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
422 |
+
|
423 |
+
# middle
|
424 |
+
h = hs[-1]
|
425 |
+
h = self.mid.block_1(h, temb)
|
426 |
+
h = self.mid.attn_1(h)
|
427 |
+
h = self.mid.block_2(h, temb)
|
428 |
+
|
429 |
+
# end
|
430 |
+
h = self.norm_out(h)
|
431 |
+
h = nonlinearity(h)
|
432 |
+
h = self.conv_out(h)
|
433 |
+
return h
|
434 |
+
|
435 |
+
|
436 |
+
class Decoder(nn.Module):
|
437 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
438 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
439 |
+
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
440 |
+
super().__init__()
|
441 |
+
self.ch = ch
|
442 |
+
self.temb_ch = 0
|
443 |
+
self.num_resolutions = len(ch_mult)
|
444 |
+
self.num_res_blocks = num_res_blocks
|
445 |
+
self.resolution = resolution
|
446 |
+
self.in_channels = in_channels
|
447 |
+
self.give_pre_end = give_pre_end
|
448 |
+
|
449 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
450 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
451 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
452 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
453 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
454 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
455 |
+
self.z_shape, np.prod(self.z_shape)))
|
456 |
+
|
457 |
+
# z to block_in
|
458 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
459 |
+
block_in,
|
460 |
+
kernel_size=3,
|
461 |
+
stride=1,
|
462 |
+
padding=1)
|
463 |
+
|
464 |
+
# middle
|
465 |
+
self.mid = nn.Module()
|
466 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
467 |
+
out_channels=block_in,
|
468 |
+
temb_channels=self.temb_ch,
|
469 |
+
dropout=dropout)
|
470 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
471 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
472 |
+
out_channels=block_in,
|
473 |
+
temb_channels=self.temb_ch,
|
474 |
+
dropout=dropout)
|
475 |
+
|
476 |
+
# upsampling
|
477 |
+
self.up = nn.ModuleList()
|
478 |
+
for i_level in reversed(range(self.num_resolutions)):
|
479 |
+
block = nn.ModuleList()
|
480 |
+
attn = nn.ModuleList()
|
481 |
+
block_out = ch*ch_mult[i_level]
|
482 |
+
for i_block in range(self.num_res_blocks+1):
|
483 |
+
block.append(ResnetBlock(in_channels=block_in,
|
484 |
+
out_channels=block_out,
|
485 |
+
temb_channels=self.temb_ch,
|
486 |
+
dropout=dropout))
|
487 |
+
block_in = block_out
|
488 |
+
if curr_res in attn_resolutions:
|
489 |
+
attn.append(AttnBlock(block_in))
|
490 |
+
up = nn.Module()
|
491 |
+
up.block = block
|
492 |
+
up.attn = attn
|
493 |
+
if i_level != 0:
|
494 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
495 |
+
curr_res = curr_res * 2
|
496 |
+
self.up.insert(0, up) # prepend to get consistent order
|
497 |
+
|
498 |
+
# end
|
499 |
+
self.norm_out = Normalize(block_in)
|
500 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
501 |
+
out_ch,
|
502 |
+
kernel_size=3,
|
503 |
+
stride=1,
|
504 |
+
padding=1)
|
505 |
+
|
506 |
+
def forward(self, z):
|
507 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
508 |
+
self.last_z_shape = z.shape
|
509 |
+
|
510 |
+
# timestep embedding
|
511 |
+
temb = None
|
512 |
+
|
513 |
+
# z to block_in
|
514 |
+
h = self.conv_in(z)
|
515 |
+
|
516 |
+
# middle
|
517 |
+
h = self.mid.block_1(h, temb)
|
518 |
+
h = self.mid.attn_1(h)
|
519 |
+
h = self.mid.block_2(h, temb)
|
520 |
+
|
521 |
+
# upsampling
|
522 |
+
for i_level in reversed(range(self.num_resolutions)):
|
523 |
+
for i_block in range(self.num_res_blocks+1):
|
524 |
+
h = self.up[i_level].block[i_block](h, temb)
|
525 |
+
if len(self.up[i_level].attn) > 0:
|
526 |
+
h = self.up[i_level].attn[i_block](h)
|
527 |
+
if i_level != 0:
|
528 |
+
h = self.up[i_level].upsample(h)
|
529 |
+
|
530 |
+
# end
|
531 |
+
if self.give_pre_end:
|
532 |
+
return h
|
533 |
+
|
534 |
+
h = self.norm_out(h)
|
535 |
+
h = nonlinearity(h)
|
536 |
+
h = self.conv_out(h)
|
537 |
+
return h
|
538 |
+
|
539 |
+
|
540 |
+
class VUNet(nn.Module):
|
541 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
542 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
543 |
+
in_channels, c_channels,
|
544 |
+
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
545 |
+
super().__init__()
|
546 |
+
self.ch = ch
|
547 |
+
self.temb_ch = self.ch*4
|
548 |
+
self.num_resolutions = len(ch_mult)
|
549 |
+
self.num_res_blocks = num_res_blocks
|
550 |
+
self.resolution = resolution
|
551 |
+
|
552 |
+
self.use_timestep = use_timestep
|
553 |
+
if self.use_timestep:
|
554 |
+
# timestep embedding
|
555 |
+
self.temb = nn.Module()
|
556 |
+
self.temb.dense = nn.ModuleList([
|
557 |
+
torch.nn.Linear(self.ch,
|
558 |
+
self.temb_ch),
|
559 |
+
torch.nn.Linear(self.temb_ch,
|
560 |
+
self.temb_ch),
|
561 |
+
])
|
562 |
+
|
563 |
+
# downsampling
|
564 |
+
self.conv_in = torch.nn.Conv2d(c_channels,
|
565 |
+
self.ch,
|
566 |
+
kernel_size=3,
|
567 |
+
stride=1,
|
568 |
+
padding=1)
|
569 |
+
|
570 |
+
curr_res = resolution
|
571 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
572 |
+
self.down = nn.ModuleList()
|
573 |
+
for i_level in range(self.num_resolutions):
|
574 |
+
block = nn.ModuleList()
|
575 |
+
attn = nn.ModuleList()
|
576 |
+
block_in = ch*in_ch_mult[i_level]
|
577 |
+
block_out = ch*ch_mult[i_level]
|
578 |
+
for i_block in range(self.num_res_blocks):
|
579 |
+
block.append(ResnetBlock(in_channels=block_in,
|
580 |
+
out_channels=block_out,
|
581 |
+
temb_channels=self.temb_ch,
|
582 |
+
dropout=dropout))
|
583 |
+
block_in = block_out
|
584 |
+
if curr_res in attn_resolutions:
|
585 |
+
attn.append(AttnBlock(block_in))
|
586 |
+
down = nn.Module()
|
587 |
+
down.block = block
|
588 |
+
down.attn = attn
|
589 |
+
if i_level != self.num_resolutions-1:
|
590 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
591 |
+
curr_res = curr_res // 2
|
592 |
+
self.down.append(down)
|
593 |
+
|
594 |
+
self.z_in = torch.nn.Conv2d(z_channels,
|
595 |
+
block_in,
|
596 |
+
kernel_size=1,
|
597 |
+
stride=1,
|
598 |
+
padding=0)
|
599 |
+
# middle
|
600 |
+
self.mid = nn.Module()
|
601 |
+
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
602 |
+
out_channels=block_in,
|
603 |
+
temb_channels=self.temb_ch,
|
604 |
+
dropout=dropout)
|
605 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
606 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
607 |
+
out_channels=block_in,
|
608 |
+
temb_channels=self.temb_ch,
|
609 |
+
dropout=dropout)
|
610 |
+
|
611 |
+
# upsampling
|
612 |
+
self.up = nn.ModuleList()
|
613 |
+
for i_level in reversed(range(self.num_resolutions)):
|
614 |
+
block = nn.ModuleList()
|
615 |
+
attn = nn.ModuleList()
|
616 |
+
block_out = ch*ch_mult[i_level]
|
617 |
+
skip_in = ch*ch_mult[i_level]
|
618 |
+
for i_block in range(self.num_res_blocks+1):
|
619 |
+
if i_block == self.num_res_blocks:
|
620 |
+
skip_in = ch*in_ch_mult[i_level]
|
621 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
622 |
+
out_channels=block_out,
|
623 |
+
temb_channels=self.temb_ch,
|
624 |
+
dropout=dropout))
|
625 |
+
block_in = block_out
|
626 |
+
if curr_res in attn_resolutions:
|
627 |
+
attn.append(AttnBlock(block_in))
|
628 |
+
up = nn.Module()
|
629 |
+
up.block = block
|
630 |
+
up.attn = attn
|
631 |
+
if i_level != 0:
|
632 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
633 |
+
curr_res = curr_res * 2
|
634 |
+
self.up.insert(0, up) # prepend to get consistent order
|
635 |
+
|
636 |
+
# end
|
637 |
+
self.norm_out = Normalize(block_in)
|
638 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
639 |
+
out_ch,
|
640 |
+
kernel_size=3,
|
641 |
+
stride=1,
|
642 |
+
padding=1)
|
643 |
+
|
644 |
+
|
645 |
+
def forward(self, x, z):
|
646 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
647 |
+
|
648 |
+
if self.use_timestep:
|
649 |
+
# timestep embedding
|
650 |
+
assert t is not None
|
651 |
+
temb = get_timestep_embedding(t, self.ch)
|
652 |
+
temb = self.temb.dense[0](temb)
|
653 |
+
temb = nonlinearity(temb)
|
654 |
+
temb = self.temb.dense[1](temb)
|
655 |
+
else:
|
656 |
+
temb = None
|
657 |
+
|
658 |
+
# downsampling
|
659 |
+
hs = [self.conv_in(x)]
|
660 |
+
for i_level in range(self.num_resolutions):
|
661 |
+
for i_block in range(self.num_res_blocks):
|
662 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
663 |
+
if len(self.down[i_level].attn) > 0:
|
664 |
+
h = self.down[i_level].attn[i_block](h)
|
665 |
+
hs.append(h)
|
666 |
+
if i_level != self.num_resolutions-1:
|
667 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
668 |
+
|
669 |
+
# middle
|
670 |
+
h = hs[-1]
|
671 |
+
z = self.z_in(z)
|
672 |
+
h = torch.cat((h,z),dim=1)
|
673 |
+
h = self.mid.block_1(h, temb)
|
674 |
+
h = self.mid.attn_1(h)
|
675 |
+
h = self.mid.block_2(h, temb)
|
676 |
+
|
677 |
+
# upsampling
|
678 |
+
for i_level in reversed(range(self.num_resolutions)):
|
679 |
+
for i_block in range(self.num_res_blocks+1):
|
680 |
+
h = self.up[i_level].block[i_block](
|
681 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
682 |
+
if len(self.up[i_level].attn) > 0:
|
683 |
+
h = self.up[i_level].attn[i_block](h)
|
684 |
+
if i_level != 0:
|
685 |
+
h = self.up[i_level].upsample(h)
|
686 |
+
|
687 |
+
# end
|
688 |
+
h = self.norm_out(h)
|
689 |
+
h = nonlinearity(h)
|
690 |
+
h = self.conv_out(h)
|
691 |
+
return h
|
692 |
+
|
693 |
+
|
694 |
+
class SimpleDecoder(nn.Module):
|
695 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
696 |
+
super().__init__()
|
697 |
+
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
698 |
+
ResnetBlock(in_channels=in_channels,
|
699 |
+
out_channels=2 * in_channels,
|
700 |
+
temb_channels=0, dropout=0.0),
|
701 |
+
ResnetBlock(in_channels=2 * in_channels,
|
702 |
+
out_channels=4 * in_channels,
|
703 |
+
temb_channels=0, dropout=0.0),
|
704 |
+
ResnetBlock(in_channels=4 * in_channels,
|
705 |
+
out_channels=2 * in_channels,
|
706 |
+
temb_channels=0, dropout=0.0),
|
707 |
+
nn.Conv2d(2*in_channels, in_channels, 1),
|
708 |
+
Upsample(in_channels, with_conv=True)])
|
709 |
+
# end
|
710 |
+
self.norm_out = Normalize(in_channels)
|
711 |
+
self.conv_out = torch.nn.Conv2d(in_channels,
|
712 |
+
out_channels,
|
713 |
+
kernel_size=3,
|
714 |
+
stride=1,
|
715 |
+
padding=1)
|
716 |
+
|
717 |
+
def forward(self, x):
|
718 |
+
for i, layer in enumerate(self.model):
|
719 |
+
if i in [1,2,3]:
|
720 |
+
x = layer(x, None)
|
721 |
+
else:
|
722 |
+
x = layer(x)
|
723 |
+
|
724 |
+
h = self.norm_out(x)
|
725 |
+
h = nonlinearity(h)
|
726 |
+
x = self.conv_out(h)
|
727 |
+
return x
|
728 |
+
|
729 |
+
|
730 |
+
class UpsampleDecoder(nn.Module):
|
731 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
732 |
+
ch_mult=(2,2), dropout=0.0):
|
733 |
+
super().__init__()
|
734 |
+
# upsampling
|
735 |
+
self.temb_ch = 0
|
736 |
+
self.num_resolutions = len(ch_mult)
|
737 |
+
self.num_res_blocks = num_res_blocks
|
738 |
+
block_in = in_channels
|
739 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
740 |
+
self.res_blocks = nn.ModuleList()
|
741 |
+
self.upsample_blocks = nn.ModuleList()
|
742 |
+
for i_level in range(self.num_resolutions):
|
743 |
+
res_block = []
|
744 |
+
block_out = ch * ch_mult[i_level]
|
745 |
+
for i_block in range(self.num_res_blocks + 1):
|
746 |
+
res_block.append(ResnetBlock(in_channels=block_in,
|
747 |
+
out_channels=block_out,
|
748 |
+
temb_channels=self.temb_ch,
|
749 |
+
dropout=dropout))
|
750 |
+
block_in = block_out
|
751 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
752 |
+
if i_level != self.num_resolutions - 1:
|
753 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
754 |
+
curr_res = curr_res * 2
|
755 |
+
|
756 |
+
# end
|
757 |
+
self.norm_out = Normalize(block_in)
|
758 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
759 |
+
out_channels,
|
760 |
+
kernel_size=3,
|
761 |
+
stride=1,
|
762 |
+
padding=1)
|
763 |
+
|
764 |
+
def forward(self, x):
|
765 |
+
# upsampling
|
766 |
+
h = x
|
767 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
768 |
+
for i_block in range(self.num_res_blocks + 1):
|
769 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
770 |
+
if i_level != self.num_resolutions - 1:
|
771 |
+
h = self.upsample_blocks[k](h)
|
772 |
+
h = self.norm_out(h)
|
773 |
+
h = nonlinearity(h)
|
774 |
+
h = self.conv_out(h)
|
775 |
+
return h
|
776 |
+
|
taming/modules/discriminator/__pycache__/model.cpython-312.pyc
ADDED
Binary file (3.81 kB). View file
|
|
taming/modules/discriminator/model.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
from taming.modules.util import ActNorm
|
6 |
+
|
7 |
+
|
8 |
+
def weights_init(m):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find('Conv') != -1:
|
11 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
+
elif classname.find('BatchNorm') != -1:
|
13 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
+
nn.init.constant_(m.bias.data, 0)
|
15 |
+
|
16 |
+
|
17 |
+
class NLayerDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
22 |
+
"""Construct a PatchGAN discriminator
|
23 |
+
Parameters:
|
24 |
+
input_nc (int) -- the number of channels in input images
|
25 |
+
ndf (int) -- the number of filters in the last conv layer
|
26 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
27 |
+
norm_layer -- normalization layer
|
28 |
+
"""
|
29 |
+
super(NLayerDiscriminator, self).__init__()
|
30 |
+
if not use_actnorm:
|
31 |
+
norm_layer = nn.BatchNorm2d
|
32 |
+
else:
|
33 |
+
norm_layer = ActNorm
|
34 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
35 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
36 |
+
else:
|
37 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
38 |
+
|
39 |
+
kw = 4
|
40 |
+
padw = 1
|
41 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
42 |
+
nf_mult = 1
|
43 |
+
nf_mult_prev = 1
|
44 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
45 |
+
nf_mult_prev = nf_mult
|
46 |
+
nf_mult = min(2 ** n, 8)
|
47 |
+
sequence += [
|
48 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
49 |
+
norm_layer(ndf * nf_mult),
|
50 |
+
nn.LeakyReLU(0.2, True)
|
51 |
+
]
|
52 |
+
|
53 |
+
nf_mult_prev = nf_mult
|
54 |
+
nf_mult = min(2 ** n_layers, 8)
|
55 |
+
sequence += [
|
56 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
57 |
+
norm_layer(ndf * nf_mult),
|
58 |
+
nn.LeakyReLU(0.2, True)
|
59 |
+
]
|
60 |
+
|
61 |
+
sequence += [
|
62 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
63 |
+
self.main = nn.Sequential(*sequence)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
"""Standard forward."""
|
67 |
+
return self.main(input)
|
taming/modules/losses/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from taming.modules.losses.vqperceptual import DummyLoss
|
2 |
+
|
taming/modules/losses/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (253 Bytes). View file
|
|