DDCM committed commit b273838 · 0 parent(s): initial commit

Files changed (49)
  1. .gitattributes +35 -0
  2. DDCM_blind_face_image_restoration.py +247 -0
  3. README.md +20 -0
  4. app.py +293 -0
  5. examples/bfr/00000055.png +0 -0
  6. examples/bfr/00000085.png +0 -0
  7. examples/bfr/00000113.png +0 -0
  8. examples/bfr/00000137.png +0 -0
  9. examples/bfr/01.png +0 -0
  10. examples/bfr/03.jpg +0 -0
  11. examples/bfr/lfw/Ana_Palacio_0001_00.jpg +0 -0
  12. examples/bfr/webphoto/00042_00.jpg +0 -0
  13. examples/bfr/wider/0005.jpg +0 -0
  14. examples/bfr/wider/0022.jpg +0 -0
  15. examples/bfr/wider/0034.jpg +0 -0
  16. examples/compression/1.jpg +0 -0
  17. examples/compression/13.jpg +0 -0
  18. examples/compression/15.jpg +0 -0
  19. examples/compression/17.jpg +0 -0
  20. examples/compression/18.jpg +0 -0
  21. examples/compression/19.jpg +0 -0
  22. examples/compression/2.jpg +0 -0
  23. examples/compression/20.jpg +0 -0
  24. examples/compression/21.jpg +0 -0
  25. examples/compression/22.jpg +0 -0
  26. examples/compression/23.jpg +0 -0
  27. examples/compression/4.jpg +0 -0
  28. examples/compression/7.jpg +0 -0
  29. examples/compression/8.jpg +0 -0
  30. guided_diffusion/__init__.py +3 -0
  31. guided_diffusion/condition_methods.py +106 -0
  32. guided_diffusion/diffusion_config.yaml +9 -0
  33. guided_diffusion/ffhq512_model_config.yaml +24 -0
  34. guided_diffusion/fp16_util.py +234 -0
  35. guided_diffusion/gaussian_diffusion.py +864 -0
  36. guided_diffusion/measurements.py +314 -0
  37. guided_diffusion/nn.py +170 -0
  38. guided_diffusion/posterior_mean_variance.py +264 -0
  39. guided_diffusion/swinir.py +904 -0
  40. guided_diffusion/unet.py +1148 -0
  41. latent_DDCM_CCFG.py +45 -0
  42. latent_DDCM_compression.py +47 -0
  43. latent_models.py +278 -0
  44. latent_utils.py +322 -0
  45. requirements.txt +15 -0
  46. util/__init__.py +0 -0
  47. util/basicsr_img_util.py +172 -0
  48. util/file.py +55 -0
  49. util/img_utils.py +423 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
DDCM_blind_face_image_restoration.py ADDED
@@ -0,0 +1,247 @@
+ import os
+ from functools import partial
+
+ import cv2
+ import gradio as gr
+ import spaces
+ from util.file import generate_binary_file, load_numpy_from_binary_bitwise
+ import torch
+ import yaml
+ from util.basicsr_img_util import img2tensor, tensor2img
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ from torchvision.transforms.functional import resize
+
+ from guided_diffusion.gaussian_diffusion import create_sampler
+ from guided_diffusion.swinir import SwinIR
+ from guided_diffusion.unet import create_model
+
+
+ def create_swinir_model(ckpt_path):
+     cfg = {
+         'in_channels': 3,
+         'out_channels': 3,
+         'embed_dim': 180,
+         'depths': [6, 6, 6, 6, 6, 6, 6, 6],
+         'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
+         'resi_connection': '1conv',
+         'sf': 8
+     }
+     mmse_model = SwinIR(
+         img_size=64,
+         patch_size=1,
+         in_chans=cfg['in_channels'],
+         num_out_ch=cfg['out_channels'],
+         embed_dim=cfg['embed_dim'],
+         depths=cfg['depths'],
+         num_heads=cfg['num_heads'],
+         window_size=8,
+         mlp_ratio=2,
+         sf=cfg['sf'],
+         img_range=1.0,
+         upsampler="nearest+conv",
+         resi_connection=cfg['resi_connection'],
+         unshuffle=True,
+         unshuffle_scale=8
+     )
+     ckpt = torch.load(ckpt_path, map_location="cpu")
+
+     if 'params_ema' in ckpt:
+         mmse_model.load_state_dict(ckpt['params_ema'])
+     else:
+         state_dict = ckpt['state_dict']
+         state_dict = {layer_name.replace('model.', ''): weights for layer_name, weights in
+                       state_dict.items()}
+         state_dict = {layer_name.replace('module.', ''): weights for layer_name, weights in
+                       state_dict.items()}
+         mmse_model.load_state_dict(state_dict)
+     for param in mmse_model.parameters():
+         param.requires_grad = False
+     return mmse_model
+
+
+ ffhq_diffusion_model = "./guided_diffusion/iddpm_ffhq512_ema500000.pth"
+ mmse_model_ckpt = "./guided_diffusion/swinir_restoration512_L1.pth"
+
+ if not os.path.exists(ffhq_diffusion_model):
+     os.system(
+         "wget https://github.com/zsyOAOA/DifFace/releases/download/V1.0/iddpm_ffhq512_ema500000.pth -O ./guided_diffusion/iddpm_ffhq512_ema500000.pth"
+     )
+ if not os.path.exists(mmse_model_ckpt):
+     os.system(
+         "wget https://github.com/zsyOAOA/DifFace/releases/download/V1.0/swinir_restoration512_L1.pth -O ./guided_diffusion/swinir_restoration512_L1.pth"
+     )
+
+
+ def load_yaml(file_path: str) -> dict:
+     with open(file_path) as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+     return config
+
+
+ model_config = './guided_diffusion/ffhq512_model_config.yaml'
+ diffusion_config = './guided_diffusion/diffusion_config.yaml'
+ model_config = load_yaml(model_config)
+ diffusion_config = load_yaml(diffusion_config)
+
+ models = {
+     'main_model': create_model(**model_config),
+     'mmse_model': create_swinir_model('./guided_diffusion/swinir_restoration512_L1.pth')
+ }
+ models['main_model'].eval()
+ models['mmse_model'].eval()
+
+
+ @torch.no_grad()
+ @spaces.GPU(duration=80)
+ def generate_reconstruction(degraded_face_img, K, T, iqa_metric, iqa_coef, loaded_indices):
+     assert iqa_metric in ['niqe', 'clipiqa+', 'topiq_nr-face']
+     diffusion_config['timestep_respacing'] = T
+     sampler = create_sampler(**diffusion_config)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = models['main_model'].to(device)
+     mmse_model = models['mmse_model'].to(device)
+
+     sample_fn = partial(sampler.p_sample_loop_blind_restoration, model=model, num_opt_noises=K,
+                         eta=1.0, iqa_metric=iqa_metric, iqa_coef=iqa_coef)
+
+     if degraded_face_img is not None:
+         mmse_img = mmse_model(degraded_face_img).clip(0, 1) * 2 - 1
+         x_start = torch.randn(mmse_img.shape, device=device)
+     else:
+         mmse_img = None
+         x_start = torch.randn(1, 3, 512, 512, device=device)
+     restored_face, indices = sample_fn(x_start=x_start, mmse_img=mmse_img, loaded_indices=loaded_indices)
+
+     return restored_face, indices
+
+
+ def resize(img, size):
+     # From https://github.com/sczhou/CodeFormer/blob/master/facelib/utils/face_restoration_helper.py
+     h, w = img.shape[0:2]
+     scale = size / min(h, w)
+     h, w = int(h * scale), int(w * scale)
+     interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+     return cv2.resize(img, (w, h), interpolation=interp)
+
+
+ @torch.no_grad()
+ @spaces.GPU(duration=80)
+ def enhance_faces(img, face_helper, has_aligned, K, T, iqa_metric, iqa_coef, loaded_indices):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     face_helper.clean_all()
+     if has_aligned:  # The inputs are already aligned
+         img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+         face_helper.cropped_faces = [img]
+     else:
+         face_helper.read_image(img)
+         face_helper.input_img = resize(face_helper.input_img, 640)
+         face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
+         face_helper.align_warp_face()
+         if len(face_helper.cropped_faces) == 0:
+             raise gr.Error("Could not identify any face in the image.")
+     if has_aligned and len(face_helper.cropped_faces) > 1:
+         raise gr.Error(
+             "You marked that the input image is aligned, but multiple faces were detected."
+         )
+     restored_faces = []
+     generated_indices = []
+     for i, cropped_face in enumerate(face_helper.cropped_faces):
+         cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
+         cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+         cur_loaded_indices = loaded_indices[i] if loaded_indices is not None else None
+
+         output, indices = generate_reconstruction(
+             cropped_face_t,
+             K,
+             T,
+             iqa_metric,
+             iqa_coef,
+             cur_loaded_indices
+         )
+
+         restored_face = tensor2img(
+             output.to(torch.float32).squeeze(0), rgb2bgr=False, min_max=(-1, 1)
+         )
+
+         restored_face = restored_face.astype("uint8")
+         restored_faces.append(restored_face)
+         generated_indices.append(indices)
+     return restored_faces, generated_indices
+
+
+ @torch.no_grad()
+ @spaces.GPU()
+ def decompress_face(K, T, iqa_metric, iqa_coef, loaded_indices):
+     assert loaded_indices is not None
+
+     output, indices = generate_reconstruction(
+         None,
+         K,
+         T,
+         iqa_metric,
+         iqa_coef,
+         loaded_indices
+     )
+
+     restored_face = tensor2img(
+         output.to(torch.float32).squeeze(0), rgb2bgr=False, min_max=(-1, 1)
+     ).astype("uint8")
+
+     return restored_face, loaded_indices
+
+ @torch.no_grad()
+ @spaces.GPU(duration=80)
+ def inference(
+     img,
+     T,
+     K,
+     iqa_metric,
+     iqa_coef,
+     aligned,
+     bitstream=None,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     iqa_metric_to_pyiqa_name = {
+         'NIQE': 'niqe',
+         'TOPIQ': 'topiq_nr-face',
+         'CLIP-IQA': 'clipiqa+'
+     }
+     iqa_metric = iqa_metric_to_pyiqa_name[iqa_metric]
+     indices = load_numpy_from_binary_bitwise(bitstream, K, T, 'ffhq', T)
+     if indices is not None:
+         indices = indices.to(device)
+
+     if img is not None:
+         img = cv2.imread(img, cv2.IMREAD_COLOR)
+         h, w = img.shape[0:2]
+         if h > 4500 or w > 4500:
+             raise gr.Error("Image size too large.")
+
+         face_helper = FaceRestoreHelper(
+             1,
+             face_size=512,
+             crop_ratio=(1, 1),
+             det_model="retinaface_resnet50",
+             save_ext="png",
+             use_parse=True,
+             device=device,
+             model_rootpath=None,
+         )
+
+         x, indices = enhance_faces(
+             img, face_helper, aligned, K=K, T=T, iqa_metric=iqa_metric, iqa_coef=iqa_coef,
+             loaded_indices=indices,
+         )
+     else:
+         x, indices = decompress_face(
+             K=K, T=T, iqa_metric=iqa_metric, iqa_coef=iqa_coef, loaded_indices=indices,
+         )
+
+     torch.cuda.empty_cache()
+
+     if bitstream is None:
+         indices = [generate_binary_file(index.numpy(), K, T, 'ffhq') for index in indices]
+         return x, indices
+     return x
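For orientation, here is a minimal sketch of driving the `inference` entry point above outside the Gradio UI. It assumes the two checkpoints are downloaded by the module-level code and that a GPU is available; the argument values are illustrative only (the demo's defaults live in app.py).

```python
# Hedged sketch: call DDCM_blind_face_image_restoration.inference() directly on one
# of the example images shipped with this commit. Values are illustrative.
import DDCM_blind_face_image_restoration as bfr

restored_faces, bitstream_files = bfr.inference(
    img="examples/bfr/03.jpg",   # path to a degraded face photo
    T=1000,                      # diffusion timesteps (timestep_respacing)
    K=2048,                      # codebook size per timestep
    iqa_metric="NIQE",           # perceptual quality measure to optimize
    iqa_coef=1.0,                # perception-distortion tradeoff coefficient
    aligned=False,               # let FaceRestoreHelper detect and align faces
)
# restored_faces: list of uint8 numpy images (one per detected face);
# bitstream_files: matching compressed bit-stream outputs from generate_binary_file().
```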
README.md ADDED
@@ -0,0 +1,20 @@
+ ---
+ title: Compressed Image Generation with Denoising Diffusion Codebook Models
+ emoji: 📖
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.14.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ tags:
+ - image-generation
+ - blind-face-image-restoration
+ - image-compression
+ - text-to-image-generation
+ - compressed-image-generation
+ short_description: Generate compressed images given different input conditions
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,293 @@
1
+ import gradio as gr
2
+ from functools import partial
3
+ import torch
4
+ import spaces
5
+
6
+ import DDCM_blind_face_image_restoration
7
+ import latent_DDCM_CCFG
8
+ import latent_DDCM_compression
9
+ from latent_models import load_model
10
+ import os
11
+ # import transformers
12
+ # transformers.utils.move_cache()
13
+
14
+
15
+ if os.getenv("SPACES_ZERO_GPU") == "true":
16
+ os.environ["SPACES_ZERO_GPU"] = "1"
17
+
18
+
19
+ avail_models = {'512x512': load_model('stabilityai/stable-diffusion-2-1-base', 1000, float16=True, device=torch.device("cpu"), compile=False)[0],
20
+ '768x768': load_model('stabilityai/stable-diffusion-2-1', 1000, float16=True, device=torch.device("cpu"), compile=False)[0]
21
+ }
22
+
23
+ compression_func = partial(latent_DDCM_compression.main, avail_models=avail_models)
24
+
25
+
26
+ def get_t_and_k_from_file_name(file_name):
27
+ T = int(file_name.split('T')[1].split('-')[0])
28
+ K = int(file_name.split('K')[1].split('-')[0])
29
+ model_type = file_name.split('M')[1].split('-')[0]
30
+ return T, K, model_type
31
+
32
+
33
+ def ccfg(text_input, T, K, ccfg_scale, model_type, compressed_file_in=None):
34
+ return latent_DDCM_CCFG.main(text_input, T, K, min(ccfg_scale, K), model_type, compressed_file_in,
35
+ avail_models=avail_models)
36
+ # return latent_DDCM_CCFG.main(text_input, T, K, min(ccfg_scale, K), compressed_file_in)
37
+
38
+
39
+ @spaces.GPU
40
+ def decompress_given_bitstream(bitstream, method):
41
+ if bitstream is None:
42
+ gr.Error("Please provide a bit-stream file when performing decompression")
43
+ file_name = bitstream.name
44
+ T, K, model_type = get_t_and_k_from_file_name(file_name)
45
+ if method == 'compression':
46
+ return compression_func(None, T, K, model_type, bitstream)
47
+ elif method == 'blind':
48
+ return DDCM_blind_face_image_restoration.inference(None, T, K, 'NIQE', 1, True, bitstream)
49
+ elif method == 'ccfg':
50
+ return ccfg(None, T, K, -1, model_type, bitstream)
51
+ else:
52
+ raise NotImplementedError()
53
+
54
+
55
+ def validate_K(K):
56
+ if (K & (K - 1)) != 0:
57
+ gr.Warning("For efficient bit usage, K should be a power of 2.")
58
+
59
+
60
+ method_to_func = {
61
+ 'compression': partial(decompress_given_bitstream, method='compression'),
62
+ 'blind': partial(decompress_given_bitstream, method='blind'),
63
+ 'ccfg': partial(decompress_given_bitstream, method='ccfg'),
64
+ }
65
+
66
+ title = "<div style='text-align: center; font-size: 36px; font-weight: bold;'>Compressed Image Generation with Denoising Diffusion Codebook Models</div>"
67
+ intro = """
68
+ <h3 style="margin-bottom: 10px; text-align: center;">
69
+ <a href="https://ohayonguy.github.io/">Guy Ohayon*</a>&nbsp;,&nbsp;
70
+ <a href="https://hilamanor.github.io/">Hila Manor*</a>&nbsp;,&nbsp;
71
+ <a href="https://tomer.net.technion.ac.il/">Tomer Michaeli</a>&nbsp;,&nbsp;
72
+ <a href="https://elad.cs.technion.ac.il/">Michael Elad</a>
73
+ </h3>
74
+ <p style="font-size: 12px; text-align: center; margin-bottom: 10px;">
75
+ * Equal contribution
76
+ </p>
77
+ <h4 style="margin-bottom: 10px; text-align: center;">
78
+ Technion - Israel Institute of Technology
79
+ </h5>
80
+ <h3 style="margin-bottom: 10px; text-align: center;">
81
+ <a href="https://www.arxiv.org/abs/2502.01189/">[Paper]</a>&nbsp;|&nbsp;
82
+ <a href="https://ddcm-2025.github.io/">[Project Page]</a>&nbsp;|&nbsp;
83
+ <a href="https://github.com/DDCM-2025/ddcm-compressed-image-generation/">[Code]</a>
84
+ </h3>
85
+ </br></br>
86
+ Denoising Diffusion Codebook Models (DDCM) is a novel (and simple) generative approach based on any Denoising Diffusion Model (DDM), that is able to produce high-quality image samples along with their losslessly compressed bit-stream representations.
87
+ DDCM can easily be utilized for perceptual image compression, as well as for solving a variety of compressed conditional generation tasks such as text-conditional image generation and image restoration, where each generated sample is accompanied by a compressed bit-stream.
88
+ </br></br>
89
+ The tabs below correspond to demos of different practical applications. Open each tab to see the application's specific instructions.
90
+ </br></br>
91
+ <b>Note: The demos below rely on relatively old pre-trained diffusion models such as Stable Diffusion 2.1, simply for the purpose of demonstrating the capabilities of DDCM. Feel free to implement our DDCM-based methods using newer diffusion models to further improve performance.</b>
92
+ """
93
+
94
+ article = r"""
95
+ If you find our work useful, please ⭐ our <a href='https://github.com/DDCM-2025/ddcm-compressed-image-generation' target='_blank'>GitHub repository</a>. Thanks!
96
+
97
+ 📝 **Citation**
98
+ ```bibtex
99
+ @article{ohayon2025compressedimagegenerationdenoising,
100
+ title={Compressed Image Generation with Denoising Diffusion Codebook Models},
101
+ author={Guy Ohayon and Hila Manor and Tomer Michaeli and Michael Elad},
102
+ year={2025},
103
+ eprint={2502.01189},
104
+ journal={arXiv},
105
+ primaryClass={eess.IV},
106
+ url={https://arxiv.org/abs/2502.01189},
107
+ }
108
+ ```
109
+
110
+ 📋 **License**
111
+ This project is released under the <a rel="license" href="https://github.com/DDCM-2025/ddcm-compressed-image-generation/blob/master/LICENSE">MIT license</a>.
112
+
113
+ 📧 **Contact**
114
+ If you have any questions, please feel free to contact us at <b>[email protected]</b> (Guy Ohayon) and <b>[email protected]</b> (Hila Manor).
115
+ """
116
+
117
+ custom_css = """
118
+ .tabs button {
119
+ font-size: 21px !important;
120
+ font-weight: bold !important;
121
+ }
122
+ """
123
+
124
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
125
+ gr.HTML(title)
126
+ gr.HTML(intro)
127
+ # gr.Markdown("# Compressed Image Generation with Denoising Diffusion Codebook Models")
128
+
129
+ with gr.Tab("Image Compression"):
130
+ gr.Markdown(
131
+ "- To change the bit rate, modify the number of diffusion timesteps (T) and/or the codebook sizes (K).")
132
+ gr.Markdown("- The input image will be center-cropped and resized to the specified size (512x512 or 768x768).")
133
+ # gr.Markdown("#### Notes:")
134
+ # gr.Markdown('* Since our methods relies on Stable Diffusion, we resize the input image to 512512 pixels')
135
+
136
+ with gr.Row():
137
+ with gr.Column(scale=2):
138
+ input_image = gr.Image(label="Input image", scale=2, image_mode='RGB', type='pil')
139
+ with gr.Group():
140
+ with gr.Row():
141
+ T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000, scale=2)
142
+ K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048, scale=3)
143
+ with gr.Row():
144
+ model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
145
+ compress = gr.Button("Compress image")
146
+
147
+ with gr.Column(scale=3):
148
+ decompressed_image = gr.Image(label="Decompressed image", scale=2)
149
+ compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)
150
+
151
+ compress.click(validate_K, inputs=[K]).then(compression_func, inputs=[input_image, T, K, model_type],
152
+ outputs=[decompressed_image, compressed_file_out])
153
+
154
+ gr.Examples([
155
+ ["examples/compression/1.jpg", 1000, 256, '512x512'],
156
+ ["examples/compression/2.jpg", 1000, 256, '512x512'],
157
+ ["examples/compression/4.jpg", 1000, 256, '512x512'],
158
+ ["examples/compression/7.jpg", 1000, 256, '512x512'],
159
+ ["examples/compression/8.jpg", 1000, 256, '512x512'],
160
+ ["examples/compression/13.jpg", 1000, 256, '512x512'],
161
+ ["examples/compression/15.jpg", 1000, 256, '512x512'],
162
+ ["examples/compression/17.jpg", 1000, 256, '512x512'],
163
+ ["examples/compression/18.jpg", 1000, 256, '512x512'],
164
+ ["examples/compression/19.jpg", 1000, 256, '512x512'],
165
+ ["examples/compression/21.jpg", 1000, 256, '512x512'],
166
+ ["examples/compression/22.jpg", 1000, 256, '512x512'],
167
+ ["examples/compression/23.jpg", 1000, 256, '512x512'],
168
+ ],
169
+ inputs=[input_image, T, K, model_type],
170
+ outputs=[decompressed_image, compressed_file_out],
171
+ fn=compression_func,
172
+ cache_examples='lazy')
173
+
174
+ gr.Markdown("### Decompress a previously generated bit-stream")
175
+ with gr.Row():
176
+ with gr.Column(scale=2):
177
+ bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
178
+ decompress = gr.Button("Decompress image")
179
+
180
+ with gr.Column(scale=3):
181
+ decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)
182
+
183
+ decompress.click(method_to_func['compression'], inputs=bitstream, outputs=decompressed_image)
184
+
185
+ with gr.Tab("Real-World Face Image Restoration"):
186
+ gr.Markdown( # "Restore any degraded face image. "
187
+ "Please mark if your input face image is already aligned. "
188
+ "If not, we will try to automatically detect, crop and align the faces, and raise an error if no faces are found. Expect better results if your input image is already aligned.")
189
+
190
+ with gr.Row():
191
+ with gr.Column(scale=2):
192
+ with gr.Group():
193
+ input_image = gr.Image(label="Input image", scale=2, type='filepath')
194
+ aligned = gr.Checkbox(label='Input face image is aligned')
195
+ with gr.Group():
196
+ with gr.Row():
197
+ T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000)
198
+ K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048)
199
+ iqa_metric = gr.Radio(['NIQE', 'TOPIQ', 'CLIP-IQA'], label='Perceptual quality measure to optimize',
200
+ value='NIQE')
201
+ iqa_coef = gr.Number(
202
+ label="Perception-distortion tradeoff coefficient (λ)",
203
+ info="Higher -> better perceptual quality",
204
+ # label="Coefficient controlling the perception-distortion tradeoff (higher means better perceptual quality)",
205
+ minimum=0, maximum=1, value=1)
206
+ restore = gr.Button("Restore and compress")
207
+
208
+ with gr.Column(scale=3):
209
+ decompressed_image = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True,
210
+ format="png")
211
+ compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0, file_count='multiple')
212
+
213
+ restore.click(validate_K, inputs=[K]).then(DDCM_blind_face_image_restoration.inference,
214
+ inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
215
+ outputs=[decompressed_image, compressed_file_out])
216
+ gr.Examples([
217
+ ["examples/bfr/00000055.png", 1000, 4096, 'TOPIQ', 0.1, True],
218
+ ["examples/bfr/00000085.png", 1000, 4096, 'TOPIQ', 0.1, True],
219
+ ["examples/bfr/00000113.png", 1000, 4096, 'TOPIQ', 0.1, True],
220
+ ["examples/bfr/00000137.png", 1000, 4096, 'TOPIQ', 0.1, True],
221
+ ["examples/bfr/wider/0034.jpg", 1000, 4096, 'NIQE', 1, True],
222
+ ["examples/bfr/webphoto/00042_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
223
+ ["examples/bfr/lfw/Ana_Palacio_0001_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
224
+ ["examples/bfr/01.png", 1000, 4096, 'NIQE', 0.1, False],
225
+ ["examples/bfr/03.jpg", 1000, 4096, 'TOPIQ', 0.1, False],
226
+ ],
227
+ inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
228
+ outputs=[decompressed_image, compressed_file_out],
229
+ fn=DDCM_blind_face_image_restoration.inference,
230
+ cache_examples='lazy')
231
+
232
+ gr.Markdown("### Decompress a previously generated bit-stream")
233
+ with gr.Row():
234
+ with gr.Column(scale=2):
235
+ bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
236
+ decompress = gr.Button("Decompress image")
237
+
238
+ with gr.Column(scale=3):
239
+ decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)
240
+
241
+ decompress.click(method_to_func['blind'], inputs=bitstream, outputs=decompressed_image)
242
+
243
+ with gr.Tab("Compressed Text-to-Image Generation"):
244
+ gr.Markdown(
245
+ "This application demonstrates the capabilities of our new *compressed* classifier-free guidance method, which *does not require the input condition for decompression*."
246
+ " \n" # newline
247
+ "Each image is generated along with its compressed bit-stream representation, and the input condition is implicitly encoded in the bit-stream.")
248
+ # gr.Markdown("### Generate an image and its compressed bit-stream given an input text prompt")
249
+ # gr.Markdown("#### Notes:")
250
+ # gr.Markdown("* The size of the generated image is 512x512")
251
+
252
+ with gr.Row():
253
+ with gr.Column(scale=2):
254
+ with gr.Group():
255
+ text_input = gr.Textbox(label="Input text prompt", scale=1, value="An image of a dog")
256
+ with gr.Row():
257
+ T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000, scale=1)
258
+ K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=256, value=128, scale=1)
259
+ K_tilde = gr.Number(label=r"Sub-sampled codebooks' sizes (K̃)", scale=1,
260
+ info="Behaves like a guidance scale", minimum=2, maximum=256, value=32)
261
+ model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
262
+ button = gr.Button("Generate and compress")
263
+
264
+ with gr.Column(scale=3):
265
+ decompressed_image = gr.Image(label="Generated image", scale=2)
266
+ compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)
267
+
268
+ button.click(validate_K, inputs=[K]).then(ccfg, inputs=[text_input, T, K, K_tilde, model_type],
269
+ outputs=[decompressed_image, compressed_file_out])
270
+
271
+ gr.Examples([
272
+ ["An image of a dog", 1000, 64, 4, '512x512'],
273
+ ["Rainbow over the mountains", 1000, 64, 4, '512x512'],
274
+ ["A cat playing soccer", 1000, 64, 4, '512x512'],
275
+ ],
276
+ inputs=[text_input, T, K, K_tilde, model_type],
277
+ outputs=[decompressed_image, compressed_file_out],
278
+ fn=ccfg,
279
+ cache_examples='lazy')
280
+ gr.Markdown("### Decompress a previously generated bit-stream")
281
+ with gr.Row():
282
+ with gr.Column(scale=2):
283
+ bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
284
+ button = gr.Button("Decompress")
285
+ with gr.Column(scale=3):
286
+ decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)
287
+ button.click(method_to_func['ccfg'], inputs=bitstream, outputs=decompressed_image)
288
+
289
+ gr.Markdown(article)
290
+
291
+ demo.queue()
292
+ demo.launch(state_session_capacity=500)
293
+
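A brief aside on the `validate_K` check in app.py above: each selected codebook index is presumably written with a whole number of bits, so the warning nudges users toward powers of two, where the log2(K) bits per index are used exactly. A tiny, self-contained illustration of the same bit-trick (not part of the Space's code):

```python
# Illustration only: the (K & (K - 1)) == 0 trick used by validate_K() in app.py.
import math

for K in (64, 100, 2048):
    is_pow2 = (K & (K - 1)) == 0
    bits = math.ceil(math.log2(K))
    print(f"K={K}: {'power of two' if is_pow2 else 'not a power of two'}, "
          f"~{bits} bits per stored index")
```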
examples/bfr/00000055.png ADDED
examples/bfr/00000085.png ADDED
examples/bfr/00000113.png ADDED
examples/bfr/00000137.png ADDED
examples/bfr/01.png ADDED
examples/bfr/03.jpg ADDED
examples/bfr/lfw/Ana_Palacio_0001_00.jpg ADDED
examples/bfr/webphoto/00042_00.jpg ADDED
examples/bfr/wider/0005.jpg ADDED
examples/bfr/wider/0022.jpg ADDED
examples/bfr/wider/0034.jpg ADDED
examples/compression/1.jpg ADDED
examples/compression/13.jpg ADDED
examples/compression/15.jpg ADDED
examples/compression/17.jpg ADDED
examples/compression/18.jpg ADDED
examples/compression/19.jpg ADDED
examples/compression/2.jpg ADDED
examples/compression/20.jpg ADDED
examples/compression/21.jpg ADDED
examples/compression/22.jpg ADDED
examples/compression/23.jpg ADDED
examples/compression/4.jpg ADDED
examples/compression/7.jpg ADDED
examples/compression/8.jpg ADDED
guided_diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """
+ Codebase for "Improved Denoising Diffusion Probabilistic Models".
+ """
guided_diffusion/condition_methods.py ADDED
@@ -0,0 +1,106 @@
+ from abc import ABC, abstractmethod
+ import torch
+
+ __CONDITIONING_METHOD__ = {}
+
+ def register_conditioning_method(name: str):
+     def wrapper(cls):
+         if __CONDITIONING_METHOD__.get(name, None):
+             raise NameError(f"Name {name} is already registered!")
+         __CONDITIONING_METHOD__[name] = cls
+         return cls
+     return wrapper
+
+ def get_conditioning_method(name: str, operator, noiser, **kwargs):
+     if __CONDITIONING_METHOD__.get(name, None) is None:
+         raise NameError(f"Name {name} is not defined!")
+     return __CONDITIONING_METHOD__[name](operator=operator, noiser=noiser, **kwargs)
+
+
+ class ConditioningMethod(ABC):
+     def __init__(self, operator, noiser, **kwargs):
+         self.operator = operator
+         self.noiser = noiser
+
+     def project(self, data, noisy_measurement, **kwargs):
+         return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)
+
+     def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
+         if self.noiser.__name__ == 'gaussian':
+             difference = measurement - self.operator.forward(x_0_hat, **kwargs)
+             norm = torch.linalg.norm(difference)
+             norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+
+         elif self.noiser.__name__ == 'poisson':
+             Ax = self.operator.forward(x_0_hat, **kwargs)
+             difference = measurement - Ax
+             norm = torch.linalg.norm(difference) / measurement.abs()
+             norm = norm.mean()
+             norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+
+         else:
+             raise NotImplementedError
+
+         return norm_grad, norm
+
+     @abstractmethod
+     def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs):
+         pass
+
+ @register_conditioning_method(name='vanilla')
+ class Identity(ConditioningMethod):
+     # just pass the input without conditioning
+     def conditioning(self, x_t):
+         return x_t
+
+ @register_conditioning_method(name='projection')
+ class Projection(ConditioningMethod):
+     def conditioning(self, x_t, noisy_measurement, **kwargs):
+         x_t = self.project(data=x_t, noisy_measurement=noisy_measurement)
+         return x_t
+
+
+ @register_conditioning_method(name='mcg')
+ class ManifoldConstraintGradient(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, noisy_measurement, **kwargs):
+         # posterior sampling
+         norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
+         x_t -= norm_grad * self.scale
+
+         # projection
+         x_t = self.project(data=x_t, noisy_measurement=noisy_measurement, **kwargs)
+         return x_t, norm
+
+ @register_conditioning_method(name='ps')
+ class PosteriorSampling(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
+         norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
+         x_t -= norm_grad * self.scale
+         return x_t, norm
+
+ @register_conditioning_method(name='ps+')
+ class PosteriorSamplingPlus(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.num_sampling = kwargs.get('num_sampling', 5)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
+         norm = 0
+         for _ in range(self.num_sampling):
+             # TODO: use noiser?
+             x_0_hat_noise = x_0_hat + 0.05 * torch.rand_like(x_0_hat)
+             difference = measurement - self.operator.forward(x_0_hat_noise)
+             norm += torch.linalg.norm(difference) / self.num_sampling
+
+         norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+         x_t -= norm_grad * self.scale
+         return x_t, norm
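To show how the registry above is meant to be used, here is a small hedged sketch with stand-in operator and noiser objects; the real ones are defined in guided_diffusion/measurements.py, which is not reproduced here, and `_IdentityOperator` / `_GaussianNoiser` are hypothetical helpers.

```python
# Sketch only: exercising the 'ps' (posterior sampling) conditioning method with
# dummy stand-ins, not the operators from guided_diffusion/measurements.py.
import torch
from guided_diffusion.condition_methods import get_conditioning_method


class _IdentityOperator:
    def forward(self, x, **kwargs):
        return x


class _GaussianNoiser:
    def __init__(self):
        self.__name__ = 'gaussian'  # grad_and_value() dispatches on this attribute


cond = get_conditioning_method('ps', _IdentityOperator(), _GaussianNoiser(), scale=0.5)

x_prev = torch.randn(1, 3, 8, 8, requires_grad=True)
x_0_hat = 0.9 * x_prev                 # keep the graph so autograd.grad() can reach x_prev
x_t = torch.randn(1, 3, 8, 8)
measurement = torch.zeros(1, 3, 8, 8)

x_t, norm = cond.conditioning(x_prev=x_prev, x_t=x_t, x_0_hat=x_0_hat,
                              measurement=measurement)
print(norm.item())  # data-fidelity norm whose gradient nudged x_t
```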
guided_diffusion/diffusion_config.yaml ADDED
@@ -0,0 +1,9 @@
+ sampler: ddim
+ steps: 1000
+ noise_schedule: linear
+ model_mean_type: epsilon
+ model_var_type: learned_range
+ dynamic_threshold: False
+ clip_denoised: True
+ rescale_timesteps: False
+ timestep_respacing: 1000
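This file is the keyword dictionary that DDCM_blind_face_image_restoration.py unpacks into create_sampler(); a minimal sketch of that flow, with an illustrative override of timestep_respacing as the demo does when T changes:

```python
# Sketch mirroring DDCM_blind_face_image_restoration.py: load diffusion_config.yaml,
# optionally override timestep_respacing, and unpack it into create_sampler().
import yaml
from guided_diffusion.gaussian_diffusion import create_sampler

with open('./guided_diffusion/diffusion_config.yaml') as f:
    diffusion_config = yaml.load(f, Loader=yaml.FullLoader)

diffusion_config['timestep_respacing'] = 100  # illustrative: fewer sampling steps
sampler = create_sampler(**diffusion_config)
```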
guided_diffusion/ffhq512_model_config.yaml ADDED
@@ -0,0 +1,24 @@
+ # Defaults for image training.
+
+ image_size: 512
+ num_channels: 32
+ num_res_blocks: "1,2,2,2,2,3,4"
+ learn_sigma: True
+ class_cond: False
+ conv_resample: True
+ attention_resolutions: "32,16,8"
+ num_head_channels: 64
+ use_scale_shift_norm: True
+ resblock_updown: False
+ use_fp16: False
+ use_checkpoint: False
+ channel_mult: "1,2,4,8,8,16,16"
+ num_heads: 1
+ num_heads_upsample: -1
+ dropout: 0.0
+ dims: 2
+ use_new_attention_order: False
+
+ model_path: ./guided_diffusion/iddpm_ffhq512_ema500000.pth
+
+
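Similarly, this config is the keyword dictionary for guided_diffusion.unet.create_model(); a small sketch of the loading path used elsewhere in this commit:

```python
# Sketch mirroring DDCM_blind_face_image_restoration.py: the whole YAML (including
# model_path) is unpacked into create_model(), which builds the 512x512 FFHQ UNet.
import yaml
from guided_diffusion.unet import create_model

with open('./guided_diffusion/ffhq512_model_config.yaml') as f:
    model_config = yaml.load(f, Loader=yaml.FullLoader)

model = create_model(**model_config)
model.eval()  # the demo only runs the model in inference mode
```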
guided_diffusion/fp16_util.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ INITIAL_LOG_LOSS_SCALE = 20.0
11
+
12
+
13
+ def convert_module_to_f16(l):
14
+ """
15
+ Convert primitive modules to float16.
16
+ """
17
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
18
+ l.weight.data = l.weight.data.half()
19
+ if l.bias is not None:
20
+ l.bias.data = l.bias.data.half()
21
+
22
+
23
+ def convert_module_to_f32(l):
24
+ """
25
+ Convert primitive modules to float32, undoing convert_module_to_f16().
26
+ """
27
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
28
+ l.weight.data = l.weight.data.float()
29
+ if l.bias is not None:
30
+ l.bias.data = l.bias.data.float()
31
+
32
+
33
+ def make_master_params(param_groups_and_shapes):
34
+ """
35
+ Copy model parameters into a (differently-shaped) list of full-precision
36
+ parameters.
37
+ """
38
+ master_params = []
39
+ for param_group, shape in param_groups_and_shapes:
40
+ master_param = nn.Parameter(
41
+ _flatten_dense_tensors(
42
+ [param.detach().float() for (_, param) in param_group]
43
+ ).view(shape)
44
+ )
45
+ master_param.requires_grad = True
46
+ master_params.append(master_param)
47
+ return master_params
48
+
49
+
50
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
51
+ """
52
+ Copy the gradients from the model parameters into the master parameters
53
+ from make_master_params().
54
+ """
55
+ for master_param, (param_group, shape) in zip(
56
+ master_params, param_groups_and_shapes
57
+ ):
58
+ master_param.grad = _flatten_dense_tensors(
59
+ [param_grad_or_zeros(param) for (_, param) in param_group]
60
+ ).view(shape)
61
+
62
+
63
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
64
+ """
65
+ Copy the master parameter data back into the model parameters.
66
+ """
67
+ # Without copying to a list, if a generator is passed, this will
68
+ # silently not copy any parameters.
69
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
70
+ for (_, param), unflat_master_param in zip(
71
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
72
+ ):
73
+ param.detach().copy_(unflat_master_param)
74
+
75
+
76
+ def unflatten_master_params(param_group, master_param):
77
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
78
+
79
+
80
+ def get_param_groups_and_shapes(named_model_params):
81
+ named_model_params = list(named_model_params)
82
+ scalar_vector_named_params = (
83
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
84
+ (-1),
85
+ )
86
+ matrix_named_params = (
87
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
88
+ (1, -1),
89
+ )
90
+ return [scalar_vector_named_params, matrix_named_params]
91
+
92
+
93
+ def master_params_to_state_dict(
94
+ model, param_groups_and_shapes, master_params, use_fp16
95
+ ):
96
+ if use_fp16:
97
+ state_dict = model.state_dict()
98
+ for master_param, (param_group, _) in zip(
99
+ master_params, param_groups_and_shapes
100
+ ):
101
+ for (name, _), unflat_master_param in zip(
102
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
103
+ ):
104
+ assert name in state_dict
105
+ state_dict[name] = unflat_master_param
106
+ else:
107
+ state_dict = model.state_dict()
108
+ for i, (name, _value) in enumerate(model.named_parameters()):
109
+ assert name in state_dict
110
+ state_dict[name] = master_params[i]
111
+ return state_dict
112
+
113
+
114
+ def state_dict_to_master_params(model, state_dict, use_fp16):
115
+ if use_fp16:
116
+ named_model_params = [
117
+ (name, state_dict[name]) for name, _ in model.named_parameters()
118
+ ]
119
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
120
+ master_params = make_master_params(param_groups_and_shapes)
121
+ else:
122
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
123
+ return master_params
124
+
125
+
126
+ def zero_master_grads(master_params):
127
+ for param in master_params:
128
+ param.grad = None
129
+
130
+
131
+ def zero_grad(model_params):
132
+ for param in model_params:
133
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
134
+ if param.grad is not None:
135
+ param.grad.detach_()
136
+ param.grad.zero_()
137
+
138
+
139
+ def param_grad_or_zeros(param):
140
+ if param.grad is not None:
141
+ return param.grad.data.detach()
142
+ else:
143
+ return th.zeros_like(param)
144
+
145
+
146
+ class MixedPrecisionTrainer:
147
+ def __init__(
148
+ self,
149
+ *,
150
+ model,
151
+ use_fp16=False,
152
+ fp16_scale_growth=1e-3,
153
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
154
+ ):
155
+ self.model = model
156
+ self.use_fp16 = use_fp16
157
+ self.fp16_scale_growth = fp16_scale_growth
158
+
159
+ self.model_params = list(self.model.parameters())
160
+ self.master_params = self.model_params
161
+ self.param_groups_and_shapes = None
162
+ self.lg_loss_scale = initial_lg_loss_scale
163
+
164
+ if self.use_fp16:
165
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
166
+ self.model.named_parameters()
167
+ )
168
+ self.master_params = make_master_params(self.param_groups_and_shapes)
169
+ self.model.convert_to_fp16()
170
+
171
+ def zero_grad(self):
172
+ zero_grad(self.model_params)
173
+
174
+ def backward(self, loss: th.Tensor):
175
+ if self.use_fp16:
176
+ loss_scale = 2 ** self.lg_loss_scale
177
+ (loss * loss_scale).backward()
178
+ else:
179
+ loss.backward()
180
+
181
+ def optimize(self, opt: th.optim.Optimizer):
182
+ if self.use_fp16:
183
+ return self._optimize_fp16(opt)
184
+ else:
185
+ return self._optimize_normal(opt)
186
+
187
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
188
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
189
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
190
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
191
+ if check_overflow(grad_norm):
192
+ self.lg_loss_scale -= 1
193
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
194
+ zero_master_grads(self.master_params)
195
+ return False
196
+
197
+ logger.logkv_mean("grad_norm", grad_norm)
198
+ logger.logkv_mean("param_norm", param_norm)
199
+
200
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
201
+ opt.step()
202
+ zero_master_grads(self.master_params)
203
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
204
+ self.lg_loss_scale += self.fp16_scale_growth
205
+ return True
206
+
207
+ def _optimize_normal(self, opt: th.optim.Optimizer):
208
+ grad_norm, param_norm = self._compute_norms()
209
+ logger.logkv_mean("grad_norm", grad_norm)
210
+ logger.logkv_mean("param_norm", param_norm)
211
+ opt.step()
212
+ return True
213
+
214
+ def _compute_norms(self, grad_scale=1.0):
215
+ grad_norm = 0.0
216
+ param_norm = 0.0
217
+ for p in self.master_params:
218
+ with th.no_grad():
219
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
220
+ if p.grad is not None:
221
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
222
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
223
+
224
+ def master_params_to_state_dict(self, master_params):
225
+ return master_params_to_state_dict(
226
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
227
+ )
228
+
229
+ def state_dict_to_master_params(self, state_dict):
230
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
231
+
232
+
233
+ def check_overflow(value):
234
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
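One observation on fp16_util.py above: `_optimize_fp16()` and `_optimize_normal()` call a module-level `logger` that this file never imports, so the training/optimize path would fail at runtime as committed; the sampling-only demo in this Space never reaches it. A minimal hedged sketch of the trainer's flow that avoids that path:

```python
# Sketch only: MixedPrecisionTrainer with use_fp16=False, stepping the optimizer
# directly because trainer.optimize() relies on a `logger` that fp16_util.py does
# not import as committed.
import torch as th
import torch.nn as nn
from guided_diffusion.fp16_util import MixedPrecisionTrainer

model = nn.Linear(8, 1)
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.Adam(trainer.master_params, lr=1e-3)

x, y = th.randn(4, 8), th.randn(4, 1)
trainer.zero_grad()
loss = ((model(x) - y) ** 2).mean()
trainer.backward(loss)  # only scales the loss when use_fp16=True
opt.step()
```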
guided_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,864 @@
1
+ import math
2
+ import os
3
+ # from functools import partial
4
+ # from clip_fiqa.inference import get_model, compute_quality
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from tqdm.auto import tqdm
9
+ # from torchmetrics.multimodal import CLIPImageQualityAssessment
10
+ import random
11
+ # from torch.nn.functional import cosine_similarity
12
+ import pyiqa
13
+
14
+ from util.img_utils import clear_color
15
+ from .posterior_mean_variance import get_mean_processor, get_var_processor
16
+
17
+
18
+ def set_seed(seed):
19
+ torch.manual_seed(seed)
20
+ np.random.seed(seed)
21
+ random.seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ # torch.backends.cudnn.deterministic = True
24
+ # torch.backends.cudnn.benchmark = False
25
+
26
+ __SAMPLER__ = {}
27
+
28
+ def register_sampler(name: str):
29
+ def wrapper(cls):
30
+ if __SAMPLER__.get(name, None):
31
+ raise NameError(f"Name {name} is already registered!")
32
+ __SAMPLER__[name] = cls
33
+ return cls
34
+ return wrapper
35
+
36
+
37
+ def get_sampler(name: str):
38
+ if __SAMPLER__.get(name, None) is None:
39
+ raise NameError(f"Name {name} is not defined!")
40
+ return __SAMPLER__[name]
41
+
42
+
43
+ def create_sampler(sampler,
44
+ steps,
45
+ noise_schedule,
46
+ model_mean_type,
47
+ model_var_type,
48
+ dynamic_threshold,
49
+ clip_denoised,
50
+ rescale_timesteps,
51
+ timestep_respacing=""):
52
+
53
+ sampler = get_sampler(name=sampler)
54
+
55
+ betas = get_named_beta_schedule(noise_schedule, steps)
56
+ if not timestep_respacing:
57
+ timestep_respacing = [steps]
58
+
59
+ return sampler(use_timesteps=space_timesteps(steps, timestep_respacing),
60
+ betas=betas,
61
+ model_mean_type=model_mean_type,
62
+ model_var_type=model_var_type,
63
+ dynamic_threshold=dynamic_threshold,
64
+ clip_denoised=clip_denoised,
65
+ rescale_timesteps=rescale_timesteps)
66
+
67
+ def compute_psnr(img1, img2):
68
+ """
69
+ Computes the Peak Signal-to-Noise Ratio (PSNR) between two images.
70
+ The images should have pixel values in the range [-1, 1].
71
+
72
+ Args:
73
+ img1 (torch.Tensor): The first image tensor (e.g., reference image).
74
+ Shape: (N, C, H, W) or (C, H, W).
75
+ img2 (torch.Tensor): The second image tensor (e.g., generated image).
76
+ Shape: same as img1.
77
+
78
+ Returns:
79
+ psnr (float): The computed PSNR value in decibels (dB).
80
+ """
81
+ # Ensure the input tensors are in the same shape
82
+ assert img1.shape == img2.shape, "Input images must have the same shape"
83
+
84
+ # Compute Mean Squared Error (MSE)
85
+ mse = torch.mean((img1 - img2) ** 2)
86
+
87
+ # Avoid division by zero in case of identical images
88
+ if mse == 0:
89
+ return float('inf')
90
+
91
+ # Maximum possible pixel value difference in the range [-1, 1] is 2
92
+ max_pixel_value = 2.0
93
+
94
+ # Compute PSNR
95
+ psnr = 20 * torch.log10(max_pixel_value / torch.sqrt(mse))
96
+
97
+ return psnr.item()
98
+
99
+ class GaussianDiffusion:
100
+ def __init__(self,
101
+ betas,
102
+ model_mean_type,
103
+ model_var_type,
104
+ dynamic_threshold,
105
+ clip_denoised,
106
+ rescale_timesteps
107
+ ):
108
+
109
+ # use float64 for accuracy.
110
+ betas = np.array(betas, dtype=np.float64)
111
+ self.betas = betas
112
+ assert self.betas.ndim == 1, "betas must be 1-D"
113
+ assert (0 < self.betas).all() and (self.betas <=1).all(), "betas must be in (0..1]"
114
+
115
+ self.num_timesteps = int(self.betas.shape[0])
116
+ self.rescale_timesteps = rescale_timesteps
117
+
118
+ alphas = 1.0 - self.betas
119
+ self.alphas = alphas
120
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
121
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
122
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
123
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
124
+
125
+ # calculations for diffusion q(x_t | x_{t-1}) and others
126
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
127
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
128
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
129
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
130
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
131
+
132
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
133
+ self.posterior_variance = (
134
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
135
+ )
136
+ # log calculation clipped because the posterior variance is 0 at the
137
+ # beginning of the diffusion chain.
138
+ self.posterior_log_variance_clipped = np.log(
139
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
140
+ )
141
+ self.posterior_mean_coef1 = (
142
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
143
+ )
144
+ self.posterior_mean_coef2 = (
145
+ (1.0 - self.alphas_cumprod_prev)
146
+ * np.sqrt(alphas)
147
+ / (1.0 - self.alphas_cumprod)
148
+ )
149
+
150
+ self.mean_processor = get_mean_processor(model_mean_type,
151
+ betas=betas,
152
+ dynamic_threshold=dynamic_threshold,
153
+ clip_denoised=clip_denoised)
154
+
155
+ self.var_processor = get_var_processor(model_var_type,
156
+ betas=betas)
157
+
158
+ def q_mean_variance(self, x_start, t):
159
+ """
160
+ Get the distribution q(x_t | x_0).
161
+
162
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
163
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
164
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
165
+ """
166
+
167
+ mean = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start) * x_start
168
+ variance = extract_and_expand(1.0 - self.alphas_cumprod, t, x_start)
169
+ log_variance = extract_and_expand(self.log_one_minus_alphas_cumprod, t, x_start)
170
+
171
+ return mean, variance, log_variance
172
+
173
+ def q_sample(self, x_start, t):
174
+ """
175
+ Diffuse the data for a given number of diffusion steps.
176
+
177
+ In other words, sample from q(x_t | x_0).
178
+
179
+ :param x_start: the initial data batch.
180
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
181
+ :param noise: if specified, the split-out normal noise.
182
+ :return: A noisy version of x_start.
183
+ """
184
+ noise = torch.randn_like(x_start)
185
+ assert noise.shape == x_start.shape
186
+
187
+ coef1 = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
188
+ coef2 = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start)
189
+
190
+ return coef1 * x_start + coef2 * noise
191
+
192
+ def q_posterior_mean_variance(self, x_start, x_t, t):
193
+ """
194
+ Compute the mean and variance of the diffusion posterior:
195
+
196
+ q(x_{t-1} | x_t, x_0)
197
+
198
+ """
199
+ assert x_start.shape == x_t.shape
200
+ coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
201
+ coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
202
+ posterior_mean = coef1 * x_start + coef2 * x_t
203
+ posterior_variance = extract_and_expand(self.posterior_variance, t, x_t)
204
+ posterior_log_variance_clipped = extract_and_expand(self.posterior_log_variance_clipped, t, x_t)
205
+
206
+ assert (
207
+ posterior_mean.shape[0]
208
+ == posterior_variance.shape[0]
209
+ == posterior_log_variance_clipped.shape[0]
210
+ == x_start.shape[0]
211
+ )
212
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
213
+
214
+ torch.no_grad()
215
+ def p_sample_loop_compression(self,
216
+ model,
217
+ x_start,
218
+ ref_img,
219
+ record,
220
+ save_root,
221
+ num_opt_noises,
222
+ num_random_noises,
223
+ loss_type,
224
+ decode_residual_gap,
225
+ fname,
226
+ eta,
227
+ num_best_opt_noises,
228
+ num_pursuit_noises,
229
+ num_pursuit_coef_bits,
230
+ random_opt_mse_noises):
231
+ """
232
+ The function used for sampling from noise.
233
+ """
234
+ assert num_best_opt_noises + num_random_noises > 0
235
+ # loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
236
+ # loss_fn_alex = lpips.LPIPS(net='alex').cuda()
237
+
238
+ set_seed(100000)
239
+ device = x_start.device
240
+ img = torch.randn(1 + random_opt_mse_noises, *x_start.shape[1:], device=device)
241
+
242
+ plt.imsave(os.path.join(save_root, f"progress/img_to_compress.png"), clear_color(ref_img))
243
+ best_indices_list = []
244
+ x_hat_0_list = []
245
+
246
+ pbar = tqdm(list(range(self.num_timesteps))[::-1])
247
+ num_noises_total = 0
248
+ num_steps_total = 0
249
+ for idx in pbar:
250
+ set_seed(idx)
251
+ time = torch.tensor([idx] * img.shape[0], device=device)
252
+ if len(x_hat_0_list) >= 2:
253
+ x_hat_0_list = x_hat_0_list[-decode_residual_gap:]
254
+ x_hat_0_list_tensor = torch.stack(x_hat_0_list, dim=0)
255
+
256
+ # TODO: think about different probs schedulings
257
+ probs = torch.linspace(0, 1, len(x_hat_0_list) - 1, device=device)
258
+ probs /= torch.sum(probs)
259
+
260
+ residual = torch.sum(probs.view(-1, 1) * (x_hat_0_list_tensor[1:] - x_hat_0_list_tensor[:-1]).view(len(x_hat_0_list) - 1, -1), dim=0)
261
+
262
+ new_noise = torch.randn(num_opt_noises, *img.shape[1:], device=device)
263
+ similarity = torch.matmul(new_noise.view(num_opt_noises, -1),
264
+ residual.view(-1, 1)).squeeze(1)
265
+ sorted_similarity, sorted_indices = torch.sort(similarity, descending=False)
266
+
267
+ noise = new_noise[sorted_indices][:num_best_opt_noises]
268
+ if num_random_noises > 0:
269
+ noise = torch.cat((noise, torch.randn(num_random_noises, *img.shape[1:], device=device)), dim=0)
270
+
271
+ else:
272
+ noise = torch.randn(num_best_opt_noises + num_random_noises, *img.shape[1:], device=device)
273
+ num_noises_total += noise.shape[0]
274
+ num_steps_total += 1
275
+ # perceptual_loss_weight = (1 - (idx / len(pbar))) * lpips_loss_mult
276
+ out = self.p_sample(x=img,
277
+ t=time,
278
+ model=model,
279
+ noise=noise,
280
+ ref=ref_img,
281
+ loss_type=loss_type,
282
+ random_opt_mse_noises=random_opt_mse_noises,
283
+ eta=eta,
284
+ num_pursuit_noises=num_pursuit_noises,
285
+ num_pursuit_coef_bits=num_pursuit_coef_bits)
286
+ best_idx = out['best_idx']
287
+ best_indices_list.append(best_idx.cpu().numpy())
288
+ # print(best_indices_list, '\n\n', flush=True)
289
+
290
+ img = out['sample']
291
+ x_0_hat = out['pred_xstart']
292
+ x_hat_0_list.append(x_0_hat[0].unsqueeze(0))
293
+ # chosen_noises_list.append(noise[best_idx])
294
+
295
+ # pbar.set_postfix({'distance': out['mse']}, refresh=False)
296
+ if record:
297
+ if idx % 50 == 0:
298
+ plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(idx).zfill(4)}.png"), clear_color(x_0_hat[0].unsqueeze(0).clip(-1, 1)))
299
+ plt.imsave(os.path.join(save_root, f"progress/x_t_{str(idx).zfill(4)}.png"), clear_color(img[0].unsqueeze(0).clip(-1, 1)))
300
+ plt.imsave(os.path.join(save_root, f"progress/noise_t_{str(idx).zfill(4)}.png"), clear_color(noise[0].unsqueeze(0).clip(-1, 1)))
301
+ plt.imsave(os.path.join(save_root, f"progress/err_t_{str(idx).zfill(4)}.png"), clear_color((ref_img - x_0_hat)[0].unsqueeze(0)))
302
+ del noise
303
+
304
+ # lpips_vgg = loss_fn_vgg(img, ref_img).squeeze().item()
305
+ # lpips_alex = loss_fn_alex(img, ref_img).squeeze().item()
306
+ plt.imsave(os.path.join(save_root,
307
+ f"progress/x_0_hat_final_psnr={compute_psnr(img[0].unsqueeze(0), ref_img)}_bpp={np.log2(num_noises_total / num_steps_total)}.png"),
308
+ clear_color(img[0].unsqueeze(0)))
309
+ indices_save_folder = os.path.join(save_root, 'best_indices')
310
+ os.makedirs(indices_save_folder, exist_ok=True)
311
+ np.save(os.path.join(indices_save_folder, os.path.splitext(os.path.basename(fname))[0] + '.bestindices'), np.array(best_indices_list))
312
+
313
+ return img
314
+
315
+ @torch.no_grad()
316
+ def p_sample_loop_blind_restoration(self,
317
+ model,
318
+ x_start,
319
+ mmse_img,
320
+ num_opt_noises,
321
+ iqa_metric,
322
+ iqa_coef,
323
+ eta,
324
+ loaded_indices):
325
+
326
+ assert iqa_metric == 'niqe' or iqa_metric == 'clipiqa+' or iqa_metric == 'topiq_nr-face'
327
+ iqa = pyiqa.create_metric(iqa_metric, device=x_start.device)
328
+ device = x_start.device
329
+
330
+ set_seed(100000)
331
+ img = torch.randn(2, *x_start.shape[1:], device=device)
332
+
333
+ pbar = tqdm(list(range(self.num_timesteps))[::-1])
334
+ next_idx = np.array([0, 1])
335
+ if loaded_indices is not None:
336
+ indices = loaded_indices
337
+ loaded_indices = torch.cat((loaded_indices, torch.tensor([0], device=device, dtype=loaded_indices.dtype)), dim=0)
338
+ else:
339
+ indices = []
340
+ for i, idx in enumerate(pbar):
341
+ set_seed(idx)
342
+
343
+
344
+ noise = torch.randn(num_opt_noises, *img.shape[1:], device=device)
345
+ if loaded_indices is None:
346
+ time = torch.tensor([idx] * img.shape[0], device=device)
347
+ out = self.p_sample(x=img,
348
+ t=time,
349
+ model=model,
350
+ noise=noise,
351
+ ref=mmse_img,
352
+ loss_type='dot_prod',
353
+ optimize_iqa=True,
354
+ eta=eta,
355
+ iqa=iqa,
356
+ iqa_coef=iqa_coef)
357
+ img = out['sample']
358
+ best_perceptual_idx_cur = out['best_perceptual_idx']
359
+ indices.append(next_idx[best_perceptual_idx_cur])
360
+ next_idx = out['best_idx']
361
+ else:
362
+ time = torch.tensor([idx], device=device)
363
+ if i == 0:
364
+ img = img[loaded_indices[0]].unsqueeze(0)
365
+ out = self.p_sample(x=img,
366
+ t=time,
367
+ model=model,
368
+ noise=noise[loaded_indices[i+1]].unsqueeze(0),
369
+ ref=img,
370
+ loss_type='dot_prod',
371
+ optimize_iqa=False,
372
+ eta=eta,
373
+ iqa='niqe',
374
+ iqa_coef=0.0)
375
+ img = out['sample']
376
+
377
+
378
+ if type(indices) is list:
379
+ indices = torch.tensor(indices).flatten()
380
+ return img[0].unsqueeze(0), indices
381
+
382
+
383
+ @torch.no_grad()
384
+ def p_sample_loop_linear_restoration(self,
385
+ model,
386
+ x_start,
387
+ ref_img,
388
+ linear_operator,
389
+ y_n,
390
+ num_pursuit_noises,
391
+ num_pursuit_coef_bits,
392
+ record,
393
+ save_root,
394
+ num_opt_noises,
395
+ fname,
396
+ eta):
397
+ """
398
+ Sampling loop that starts from pure noise and, at each step, selects the noise candidate that best matches the linear measurement y_n.
399
+ """
400
+
401
+ set_seed(100000)
402
+ device = x_start.device
403
+ img = torch.randn(1, *x_start.shape[1:], device=device)
404
+
405
+
406
+ pbar = tqdm(list(range(self.num_timesteps))[::-1])
407
+ for idx in pbar:
408
+ set_seed(idx)
409
+ time = torch.tensor([idx] * img.shape[0], device=device)
410
+
411
+ noise = torch.randn(num_opt_noises, *img.shape[1:], device=device)
412
+ # perceptual_loss_weight = (1 - (idx / len(pbar))) * lpips_loss_mult
413
+ out = self.p_sample(x=img,
414
+ t=time,
415
+ model=model,
416
+ noise=noise,
417
+ ref=ref_img,
418
+ loss_type='mse',
419
+ eta=eta,
420
+ y_n=y_n,
421
+ linear_operator=linear_operator,
422
+ num_pursuit_noises=num_pursuit_noises,
423
+ num_pursuit_coef_bits=num_pursuit_coef_bits,
424
+ optimize_iqa=False,
425
+ iqa=None,
426
+ iqa_coef=None)
427
+ x_0_hat = out['pred_xstart']
428
+ img = out['sample']
429
+ # loss = (((x_0_hat - mmse_img) ** 2).mean()
430
+ # - perceptual_quality_coef * clip_iqa((x_0_hat * 0.5 + 0.5).clip(0, 1)))
431
+
432
+ # pbar.set_postfix({'perceptual_quality': loss[best_perceptual_idx].item()}, refresh=False)
433
+ if record:
434
+ if idx % 50 == 0:
435
+ plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(idx).zfill(4)}.png"), clear_color(x_0_hat[0].unsqueeze(0).clip(-1, 1)))
436
+ plt.imsave(os.path.join(save_root, f"progress/x_t_{str(idx).zfill(4)}.png"), clear_color(img[0].unsqueeze(0).clip(-1, 1)))
437
+
438
+
439
+ # plt.imsave(os.path.join(save_root,
440
+ # f"progress/x_0_hat_final_lpips-vgg={lpips_vgg:.4f}_lpips-alex"
441
+ # f"={lpips_alex:.4f}_psnr={compute_psnr(img[0].unsqueeze(0), ref_img)}_bpp={np.log2(num_noises_total / num_steps_total)}.png"),
442
+ # clear_color(img[0].unsqueeze(0)))
443
+ # indices_save_folder = os.path.join(save_root, 'best_indices')
444
+ # os.makedirs(indices_save_folder, exist_ok=True)
445
+ # np.save(os.path.join(indices_save_folder, os.path.splitext(os.path.basename(fname))[0] + '.bestindices'), np.array(best_indices_list))
446
+
447
+ return img
448
+ def p_sample(self, model, x, t, noise, ref, loss_type, eta=None):
449
+ raise NotImplementedError
450
+
451
+ def p_mean_variance(self, model, x, t):
452
+ model_output = model(x, self._scale_timesteps(t))
453
+
454
+ # In the case of "learned" variance, model will give twice channels.
455
+ if model_output.shape[1] == 2 * x.shape[1]:
456
+ model_output, model_var_values = torch.split(model_output, x.shape[1], dim=1)
457
+ else:
458
+ # The variable name is misleading here: this branch only provides shape
459
+ # information, and the value is not actually used when computing the
460
+ # variance.
461
+ model_var_values = model_output
462
+
463
+ model_mean, pred_xstart = self.mean_processor.get_mean_and_xstart(x, t, model_output)
464
+ model_variance, model_log_variance = self.var_processor.get_variance(model_var_values, t)
465
+
466
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
467
+
468
+ return {'mean': model_mean,
469
+ 'variance': model_variance,
470
+ 'log_variance': model_log_variance,
471
+ 'pred_xstart': pred_xstart}
472
+
473
+
474
+ def _scale_timesteps(self, t):
475
+ if self.rescale_timesteps:
476
+ return t.float() * (1000.0 / self.num_timesteps)
477
+ return t
478
+
479
+ def space_timesteps(num_timesteps, section_counts):
480
+ """
481
+ Create a list of timesteps to use from an original diffusion process,
482
+ given the number of timesteps we want to take from equally-sized portions
483
+ of the original process.
484
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
485
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
486
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
487
+ If the stride is a string starting with "ddim", then the fixed striding
488
+ from the DDIM paper is used, and only one section is allowed.
489
+ :param num_timesteps: the number of diffusion steps in the original
490
+ process to divide up.
491
+ :param section_counts: either a list of numbers, or a string containing
492
+ comma-separated numbers, indicating the step count
493
+ per section. As a special case, use "ddimN" where N
494
+ is a number of steps to use the striding from the
495
+ DDIM paper.
496
+ :return: a set of diffusion steps from the original process to use.
497
+ """
498
+ if isinstance(section_counts, str):
499
+ if section_counts.startswith("ddim"):
500
+ desired_count = int(section_counts[len("ddim") :])
501
+ for i in range(1, num_timesteps):
502
+ if len(range(0, num_timesteps, i)) == desired_count:
503
+ return set(range(0, num_timesteps, i))
504
+ raise ValueError(
505
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
506
+ )
507
+ section_counts = [int(x) for x in section_counts.split(",")]
508
+ elif isinstance(section_counts, int):
509
+ section_counts = [section_counts]
510
+
511
+ size_per = num_timesteps // len(section_counts)
512
+ extra = num_timesteps % len(section_counts)
513
+ start_idx = 0
514
+ all_steps = []
515
+ for i, section_count in enumerate(section_counts):
516
+ size = size_per + (1 if i < extra else 0)
517
+ if size < section_count:
518
+ raise ValueError(
519
+ f"cannot divide section of {size} steps into {section_count}"
520
+ )
521
+ if section_count <= 1:
522
+ frac_stride = 1
523
+ else:
524
+ frac_stride = (size - 1) / (section_count - 1)
525
+ cur_idx = 0.0
526
+ taken_steps = []
527
+ for _ in range(section_count):
528
+ taken_steps.append(start_idx + round(cur_idx))
529
+ cur_idx += frac_stride
530
+ all_steps += taken_steps
531
+ start_idx += size
532
+ return set(all_steps)
533
+
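To make the spacing rule above concrete, here is a minimal usage sketch (not part of the committed file; it assumes `space_timesteps` is importable from this module):

steps = space_timesteps(1000, "ddim25")       # fixed DDIM striding: {0, 40, 80, ..., 960}
assert len(steps) == 25 and 0 in steps
steps = space_timesteps(1000, [10, 15, 20])   # three equal sections of the 1000-step chain
assert len(steps) == 10 + 15 + 20             # 45 retained timesteps in total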
534
+
535
+ class SpacedDiffusion(GaussianDiffusion):
536
+ """
537
+ A diffusion process which can skip steps in a base diffusion process.
538
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
539
+ original diffusion process to retain.
540
+ :param kwargs: the kwargs to create the base diffusion process.
541
+ """
542
+
543
+ def __init__(self, use_timesteps, **kwargs):
544
+ self.use_timesteps = set(use_timesteps)
545
+ self.timestep_map = []
546
+ self.original_num_steps = len(kwargs["betas"])
547
+
548
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
549
+ last_alpha_cumprod = 1.0
550
+ new_betas = []
551
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
552
+ if i in self.use_timesteps:
553
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
554
+ last_alpha_cumprod = alpha_cumprod
555
+ self.timestep_map.append(i)
556
+ kwargs["betas"] = np.array(new_betas)
557
+ super().__init__(**kwargs)
558
+
559
+ def p_mean_variance(
560
+ self, model, *args, **kwargs
561
+ ): # pylint: disable=signature-differs
562
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
563
+
564
+ def training_losses(
565
+ self, model, *args, **kwargs
566
+ ): # pylint: disable=signature-differs
567
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
568
+
569
+ def condition_mean(self, cond_fn, *args, **kwargs):
570
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
571
+
572
+ def condition_score(self, cond_fn, *args, **kwargs):
573
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
574
+
575
+ def _wrap_model(self, model):
576
+ if isinstance(model, _WrappedModel):
577
+ return model
578
+ return _WrappedModel(
579
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
580
+ )
581
+
582
+ def _scale_timesteps(self, t):
583
+ # Scaling is done by the wrapped model.
584
+ return t
585
+
586
+
587
+ class _WrappedModel:
588
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
589
+ self.model = model
590
+ self.timestep_map = timestep_map
591
+ self.rescale_timesteps = rescale_timesteps
592
+ self.original_num_steps = original_num_steps
593
+
594
+ def __call__(self, x, ts, **kwargs):
595
+ map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
596
+ new_ts = map_tensor[ts]
597
+ if self.rescale_timesteps:
598
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
599
+ return self.model(x, new_ts, **kwargs)
600
+
601
+
602
+ @register_sampler(name='ddpm')
603
+ class DDPM(SpacedDiffusion):
604
+ def __init__(self, *args, **kwargs):
605
+ super().__init__(*args, **kwargs)
606
+
607
+ def p_sample(self, model, x, t, noise, ref, perceptual_loss_weight, loss_type='mse', eta=None):
608
+ out = self.p_mean_variance(model, x, t)
609
+ pred_xstart = out['pred_xstart']
610
+
611
+ # if loss_type == 'mse':
612
+ # loss = - ((pred_xstart + noise - ref).view(noise.shape[0], -1) ** 2).mean(1)
613
+ # elif loss_type == 'mse_alpha':
614
+ # loss = - ((pred_xstart + torch.exp(0.5 * out['log_variance']) * noise - ref).view(noise.shape[0], -1) ** 2).mean(1)
615
+ if loss_type == 'dot_prod':
616
+ loss = torch.matmul(noise.view(noise.shape[0], -1), (ref - pred_xstart).view(pred_xstart.shape[0], -1).transpose(0, 1))
617
+ elif loss_type == 'mse':
618
+ # NOTE: this MSE objective is the quantity we ultimately care about; the dot-product loss above approximates it.
619
+ sqrt_recip_alphas_cumprod = extract_and_expand(self.sqrt_recip_alphas_cumprod, t-1 if t[0] > 0 else torch.zeros_like(t), noise)
620
+ loss = - ((pred_xstart + sqrt_recip_alphas_cumprod * torch.exp(0.5 * out['log_variance']) * noise - ref).view(noise.shape[0], -1) ** 2).mean(1)
621
+ elif loss_type == 'l1':
622
+ sqrt_recip_alphas_cumprod = extract_and_expand(self.sqrt_recip_alphas_cumprod, t-1 if t[0] > 0 else torch.zeros_like(t), noise)
623
+ loss = - torch.abs(pred_xstart + sqrt_recip_alphas_cumprod * torch.exp(0.5 * out['log_variance']) * noise - ref).view(noise.shape[0], -1).mean(1)
624
+
625
+ # elif loss_type == 'ddpm_inversion':
626
+ # sqrt_alphas_cumprod = extract_and_expand(self.sqrt_alphas_cumprod, t-1 if t[0] > 0 else torch.zeros_like(t), ref)
627
+ # sqrt_one_minus_alphas_cumprod = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t-1 if t[0] > 0 else torch.zeros_like(t), ref)
628
+ #
629
+ # forward_noise = torch.randn_like(ref)
630
+ # loss = torch.matmul(noise.view(noise.shape[0], -1),
631
+ # (sqrt_alphas_cumprod * ref + sqrt_one_minus_alphas_cumprod * forward_noise - out['mean']).view(pred_xstart.shape[0], -1).transpose(0, 1))
632
+ #
633
+ #
634
+
635
+ else:
636
+ raise NotImplementedError()
637
+
638
+ best_idx = torch.argmax(loss)
639
+ samples = out['mean'] + torch.exp(0.5 * out['log_variance']) * noise[best_idx].unsqueeze(0)
640
+
641
+ return {'sample': samples if t[0] > 0 else pred_xstart,
642
+ 'pred_xstart': pred_xstart,
643
+ 'mse': loss[best_idx].item(),
644
+ 'best_idx': best_idx}
645
+
646
+
647
+ @register_sampler(name='ddim')
648
+ class DDIM(SpacedDiffusion):
649
+ @torch.no_grad()
650
+ def p_sample(self, model, x, t, noise, ref, loss_type='mse', eta=0.0, iqa=None, iqa_coef=1.0,
651
+ optimize_iqa=False, linear_operator=None, y_n=None, random_opt_mse_noises=0,
652
+ num_pursuit_noises=1, num_pursuit_coef_bits=1,
653
+ cond_fn=None,
654
+ cls=None
655
+ ):
656
+
657
+ out = self.p_mean_variance(model, x, t)
658
+ pred_xstart = out['pred_xstart']
659
+ best_perceptual_idx = None
660
+ if optimize_iqa:
661
+ assert not random_opt_mse_noises
662
+ coef_sign = 1 if iqa.lower_better else -1
663
+ if iqa.metric_name == 'topiq_nr-face':
664
+ assert not iqa.lower_better
665
+ # topiq_nr-face doesn't support a batch size larger than 1.
666
+ scores = []
667
+ for elem in pred_xstart:
668
+ try:
669
+ scores.append(iqa((elem.unsqueeze(0) * 0.5 + 0.5).clip(0, 1)).squeeze().view(1))
670
+ except AssertionError:
671
+ # no face detected...
672
+ scores.append(torch.zeros(1, device=x.device))
673
+ scores = torch.stack(scores, dim=0).squeeze()
674
+ loss = (((ref - pred_xstart) ** 2).view(pred_xstart.shape[0], -1).mean(1) + coef_sign * iqa_coef * scores)
675
+ else:
676
+ loss = (((ref - pred_xstart) ** 2).view(pred_xstart.shape[0], -1).mean(1) + coef_sign * iqa_coef * iqa((pred_xstart * 0.5 + 0.5).clip(0, 1)).squeeze())
677
+ best_perceptual_idx = torch.argmin(loss)
678
+ out['pred_xstart'] = out['pred_xstart'][best_perceptual_idx].unsqueeze(0)
679
+ pred_xstart = pred_xstart[best_perceptual_idx].unsqueeze(0)
680
+ t = t[best_perceptual_idx]
681
+ x = x[best_perceptual_idx].unsqueeze(0)
682
+ elif random_opt_mse_noises > 0:
683
+ loss = (((ref - pred_xstart) ** 2).view(pred_xstart.shape[0], -1).mean(1))
684
+ best_mse_idx = torch.argmin(loss)
685
+ out['pred_xstart'] = out['pred_xstart'][best_mse_idx].unsqueeze(0)
686
+ pred_xstart = pred_xstart[best_mse_idx].unsqueeze(0)
687
+ t = t[best_mse_idx]
688
+ x = x[best_mse_idx].unsqueeze(0)
689
+
690
+ eps = self.predict_eps_from_x_start(x, t, out['pred_xstart'])
691
+ alpha_bar = extract_and_expand(self.alphas_cumprod, t, x)
692
+ alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, t, x)
693
+ sigma = (
694
+ eta
695
+ * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
696
+ * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
697
+ )
698
+ mean_pred = (
699
+ out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
700
+ + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
701
+ )
702
+ sample = mean_pred
703
+
704
+ if y_n is not None:
705
+ assert linear_operator is not None
706
+ y_n = ref if y_n is None else y_n
707
+
708
+ if not optimize_iqa and random_opt_mse_noises <= 0 and cond_fn is None:
709
+ if loss_type == 'dot_prod':
710
+ if linear_operator is None:
711
+ compute_loss = lambda noise_cur: torch.matmul(noise_cur.view(noise_cur.shape[0], -1), (ref - pred_xstart).view(pred_xstart.shape[0], -1).transpose(0, 1))
712
+ else:
713
+ compute_loss = lambda noise_cur: torch.matmul(linear_operator.forward(noise_cur).reshape(noise_cur.shape[0], -1), (y_n - linear_operator.forward(pred_xstart)).reshape(pred_xstart.shape[0], -1).transpose(0, 1))
714
+ elif loss_type == 'mse':
715
+ if linear_operator is None:
716
+ compute_loss = lambda noise_cur: - (((sigma / torch.sqrt(alpha_bar_prev)) * noise_cur + pred_xstart - y_n) ** 2).mean((1, 2, 3))
717
+ else:
718
+ compute_loss = lambda noise_cur: - (((sigma / torch.sqrt(alpha_bar_prev))[:, :, :y_n.shape[2], :y_n.shape[3]] * linear_operator.forward(noise_cur) + linear_operator.forward(pred_xstart) - y_n) ** 2).mean((1, 2, 3))
719
+ else:
720
+ raise NotImplementedError()
721
+ # print("getting loss")
722
+ loss = compute_loss(noise)
723
+ best_idx = torch.argmax(loss)
724
+ best_noise = noise[best_idx]
725
+ best_loss = loss[best_idx]
726
+
727
+ if num_pursuit_noises > 1:
728
+ pursuit_coefs = np.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
729
+
730
+ for _ in range(num_pursuit_noises - 1):
731
+ next_best_noise = best_noise
732
+ for pursuit_coef in pursuit_coefs:
733
+ new_noise = best_noise.unsqueeze(0) * np.sqrt(pursuit_coef) + noise * np.sqrt(1 - pursuit_coef)
734
+ new_noise /= new_noise.view(noise.shape[0], -1).std(1).view(noise.shape[0], 1, 1, 1)
735
+ cur_loss = compute_loss(new_noise)
736
+ cur_best_idx = torch.argmax(cur_loss)
737
+ cur_best_loss = cur_loss[cur_best_idx]
738
+
739
+ if cur_best_loss > best_loss:
740
+ next_best_noise = new_noise[cur_best_idx]
741
+ best_loss = cur_best_loss
742
+
743
+ best_noise = next_best_noise
744
+
745
+ if t != 0:
746
+ sample += sigma * best_noise.unsqueeze(0)
747
+
748
+ return {'sample': sample if t[0] > 0 else pred_xstart,
749
+ 'pred_xstart': pred_xstart,
750
+ 'mse': loss[best_idx].item(),
751
+ 'best_idx': best_idx}
752
+ else:
753
+ if random_opt_mse_noises > 0 and not optimize_iqa:
754
+ num_rand_indices = random_opt_mse_noises
755
+ elif optimize_iqa and random_opt_mse_noises <= 0:
756
+ num_rand_indices = 1
757
+ elif cond_fn is not None:
758
+ num_rand_indices = 2
759
+ else:
760
+ raise NotImplementedError()
761
+ loss = torch.matmul(noise.view(noise.shape[0], -1),
762
+ (ref - pred_xstart).view(pred_xstart.shape[0], -1).transpose(0, 1)).squeeze()
763
+ best_idx = torch.argmax(loss).reshape(1)
764
+ rand_idx = torch.randint(0, noise.shape[0], size=(num_rand_indices, ), device=best_idx.device).reshape(num_rand_indices)
765
+ best_and_rand_idx = torch.cat((best_idx, rand_idx), dim=0).flatten()
766
+ if t != 0:
767
+ sample = sample + sigma * noise[best_and_rand_idx]
768
+ return {'sample': sample,
769
+ 'pred_xstart': pred_xstart,
770
+ 'best_idx': best_and_rand_idx,
771
+ 'best_perceptual_idx': best_perceptual_idx}
772
+
773
+ def predict_eps_from_x_start(self, x_t, t, pred_xstart):
774
+ coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
775
+ coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t)
776
+ return (coef1 * x_t - pred_xstart) / coef2
777
+
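For orientation, the selection rule used by the 'dot_prod' branch of p_sample above can be isolated as follows. This is a minimal sketch (tensor names are illustrative, not from the file): from a bank of candidate noises, keep the one whose inner product with the residual ref - pred_xstart is largest, and record its index.

import torch

def select_best_noise(noise_bank, pred_xstart, ref):
    # noise_bank: (K, C, H, W) candidates; pred_xstart, ref: (1, C, H, W)
    residual = (ref - pred_xstart).reshape(1, -1)
    scores = noise_bank.reshape(noise_bank.shape[0], -1) @ residual.t()   # (K, 1) dot products
    best_idx = torch.argmax(scores.squeeze(1))                            # index that gets stored/transmitted
    return noise_bank[best_idx].unsqueeze(0), best_idx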
778
+ # =================
779
+ # Helper functions
780
+ # =================
781
+
782
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
783
+ """
784
+ Get a pre-defined beta schedule for the given name.
785
+
786
+ The beta schedule library consists of beta schedules which remain similar
787
+ in the limit of num_diffusion_timesteps.
788
+ Beta schedules may be added, but should not be removed or changed once
789
+ they are committed to maintain backwards compatibility.
790
+ """
791
+ if schedule_name == "linear":
792
+ # Linear schedule from Ho et al, extended to work for any number of
793
+ # diffusion steps.
794
+ scale = 1000 / num_diffusion_timesteps
795
+ beta_start = scale * 0.0001
796
+ beta_end = scale * 0.02
797
+ return np.linspace(
798
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
799
+ )
800
+ elif schedule_name == "cosine":
801
+ return betas_for_alpha_bar(
802
+ num_diffusion_timesteps,
803
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
804
+ )
805
+ else:
806
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
807
+
808
+
809
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
810
+ """
811
+ Create a beta schedule that discretizes the given alpha_t_bar function,
812
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
813
+
814
+ :param num_diffusion_timesteps: the number of betas to produce.
815
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
816
+ produces the cumulative product of (1-beta) up to that
817
+ part of the diffusion process.
818
+ :param max_beta: the maximum beta to use; use values lower than 1 to
819
+ prevent singularities.
820
+ """
821
+ betas = []
822
+ for i in range(num_diffusion_timesteps):
823
+ t1 = i / num_diffusion_timesteps
824
+ t2 = (i + 1) / num_diffusion_timesteps
825
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
826
+ return np.array(betas)
827
+
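A small sanity check of the schedules defined above (illustrative only; it assumes `get_named_beta_schedule` is importable from this module):

import numpy as np

betas = get_named_beta_schedule("linear", 1000)
assert betas.shape == (1000,)
assert np.isclose(betas[0], 1e-4) and np.isclose(betas[-1], 0.02)

betas = get_named_beta_schedule("cosine", 1000)
assert betas.max() <= 0.999                      # capped by max_beta in betas_for_alpha_bar
alphas_cumprod = np.cumprod(1.0 - betas)         # decays smoothly from ~1 toward 0
assert alphas_cumprod[-1] < alphas_cumprod[0]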
828
+ # ================
829
+ # Helper function
830
+ # ================
831
+
832
+ def extract_and_expand(array, time, target):
833
+ array = torch.from_numpy(array).to(target.device)[time].float()
834
+ while array.ndim < target.ndim:
835
+ array = array.unsqueeze(-1)
836
+ return array.expand_as(target)
837
+
838
+
839
+ def expand_as(array, target):
840
+ if isinstance(array, np.ndarray):
841
+ array = torch.from_numpy(array)
842
+ elif isinstance(array, np.float):
843
+ array = torch.tensor([array])
844
+
845
+ while array.ndim < target.ndim:
846
+ array = array.unsqueeze(-1)
847
+
848
+ return array.expand_as(target).to(target.device)
849
+
850
+
851
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
852
+ """
853
+ Extract values from a 1-D numpy array for a batch of indices.
854
+
855
+ :param arr: the 1-D numpy array.
856
+ :param timesteps: a tensor of indices into the array to extract.
857
+ :param broadcast_shape: a larger shape of K dimensions with the batch
858
+ dimension equal to the length of timesteps.
859
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
860
+ """
861
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
862
+ while len(res.shape) < len(broadcast_shape):
863
+ res = res[..., None]
864
+ return res.expand(broadcast_shape)
guided_diffusion/measurements.py ADDED
@@ -0,0 +1,314 @@
1
+ '''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''
2
+
3
+ from abc import ABC, abstractmethod
4
+ from functools import partial
5
+ import yaml
6
+ from torch.nn import functional as F
7
+ import torch
8
+
9
+ # NOTE: MotionBlurOperator below relies on a `Kernel` class; in the original DPS
+ # codebase this comes from the external `motionblur` package.
+ from motionblur.motionblur import Kernel
+ from util.resizer import Resizer
10
+ from util.img_utils import Blurkernel, fft2_m
11
+
12
+
13
+ # =================
14
+ # Operation classes
15
+ # =================
16
+
17
+ __OPERATOR__ = {}
18
+
19
+ def register_operator(name: str):
20
+ def wrapper(cls):
21
+ if __OPERATOR__.get(name, None):
22
+ raise NameError(f"Name {name} is already registered!")
23
+ __OPERATOR__[name] = cls
24
+ return cls
25
+ return wrapper
26
+
27
+
28
+ def get_operator(name: str, **kwargs):
29
+ if __OPERATOR__.get(name, None) is None:
30
+ raise NameError(f"Name {name} is not defined.")
31
+ return __OPERATOR__[name](**kwargs)
32
+
33
+
34
+ class LinearOperator(ABC):
35
+ @abstractmethod
36
+ def forward(self, data, **kwargs):
37
+ # calculate A * X
38
+ pass
39
+
40
+ @abstractmethod
41
+ def transpose(self, data, **kwargs):
42
+ # calculate A^T * X
43
+ pass
44
+
45
+ def ortho_project(self, data, **kwargs):
46
+ # calculate (I - A^T * A)X
47
+ return data - self.transpose(self.forward(data, **kwargs), **kwargs)
48
+
49
+ def project(self, data, measurement, **kwargs):
50
+ # calculate (I - A^T * A)Y - AX
51
+ return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)
52
+
53
+
54
+ @register_operator(name='noise')
55
+ class DenoiseOperator(LinearOperator):
56
+ def __init__(self, device):
57
+ self.device = device
58
+
59
+ def forward(self, data):
60
+ return data
61
+
62
+ def transpose(self, data):
63
+ return data
64
+
65
+ def ortho_project(self, data):
66
+ return data
67
+
68
+ def project(self, data):
69
+ return data
70
+
71
+
72
+ @register_operator(name='super_resolution')
73
+ class SuperResolutionOperator(LinearOperator):
74
+ def __init__(self, in_shape, scale_factor, device):
75
+ self.device = device
76
+ self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
77
+ self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)
78
+
79
+ def forward(self, data, **kwargs):
80
+ return self.down_sample(data)
81
+
82
+ def transpose(self, data, **kwargs):
83
+ return self.up_sample(data)
84
+
85
+ def project(self, data, measurement, **kwargs):
86
+ return data - self.transpose(self.forward(data)) + self.transpose(measurement)
87
+
88
+
89
+
90
+ @register_operator(name='motion_blur')
91
+ class MotionBlurOperator(LinearOperator):
92
+ def __init__(self, kernel_size, intensity, device):
93
+ self.device = device
94
+ self.kernel_size = kernel_size
95
+ self.conv = Blurkernel(blur_type='motion',
96
+ kernel_size=kernel_size,
97
+ std=intensity,
98
+ device=device).to(device) # should we keep this device term?
99
+
100
+ self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
101
+ kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
102
+ self.conv.update_weights(kernel)
103
+
104
+ def forward(self, data, **kwargs):
105
+ # A^T * A
106
+ return self.conv(data)
107
+
108
+ def transpose(self, data, **kwargs):
109
+ return data
110
+
111
+ def get_kernel(self):
112
+ kernel = self.kernel.kernelMatrix.type(torch.float32).to(self.device)
113
+ return kernel.view(1, 1, self.kernel_size, self.kernel_size)
114
+
115
+
116
+ @register_operator(name='colorization')
117
+ class ColorizationOperator(LinearOperator):
118
+ def __init__(self, device):
119
+ self.device = device
120
+
121
+ def forward(self, data, **kwargs):
122
+ return (1/3) * torch.sum(data, dim=1, keepdim=True)
123
+
124
+ def transpose(self, data, **kwargs):
125
+ return data
126
+
127
+
128
+
129
+ @register_operator(name='gaussian_blur')
130
+ class GaussialBlurOperator(LinearOperator):
131
+ def __init__(self, kernel_size, intensity, device):
132
+ self.device = device
133
+ self.kernel_size = kernel_size
134
+ self.conv = Blurkernel(blur_type='gaussian',
135
+ kernel_size=kernel_size,
136
+ std=intensity,
137
+ device=device).to(device)
138
+ self.kernel = self.conv.get_kernel()
139
+ self.conv.update_weights(self.kernel.type(torch.float32))
140
+
141
+ def forward(self, data, **kwargs):
142
+ return self.conv(data)
143
+
144
+ def transpose(self, data, **kwargs):
145
+ return data
146
+
147
+ def get_kernel(self):
148
+ return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
149
+
150
+ def project(self, data, measurement, **kwargs):
151
+ # calculate (I - A^T * A)Y - AX
152
+ return data - self.forward(data, **kwargs) + measurement
153
+
154
+ @register_operator(name='inpainting')
155
+ class InpaintingOperator(LinearOperator):
156
+ '''This operator takes a pre-defined mask and returns the masked image.'''
157
+ def __init__(self, device):
158
+ self.device = device
159
+
160
+ def set_mask(self, mask):
161
+ self.mask = mask
162
+
163
+ def forward(self, data, **kwargs):
164
+ try:
165
+ return data * self.mask.to(self.device)
166
+ except AttributeError:
167
+ raise ValueError("A mask must be set via set_mask() before calling forward().")
168
+
169
+ def transpose(self, data, **kwargs):
170
+ return data
171
+
172
+ def ortho_project(self, data, **kwargs):
173
+ return data - self.forward(data, **kwargs)
174
+
175
+ def project(self, data, measurement, **kwargs):
176
+ return data - self.forward(data, **kwargs) + measurement
177
+
178
+
179
+ class NonLinearOperator(ABC):
180
+ @abstractmethod
181
+ def forward(self, data, **kwargs):
182
+ pass
183
+
184
+ def project(self, data, measurement, **kwargs):
185
+ return data + measurement - self.forward(data)
186
+
187
+ @register_operator(name='phase_retrieval')
188
+ class PhaseRetrievalOperator(NonLinearOperator):
189
+ def __init__(self, oversample, device):
190
+ self.pad = int((oversample / 8.0) * 256)
191
+ self.device = device
192
+
193
+ def forward(self, data, **kwargs):
194
+ padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
195
+ amplitude = fft2_m(padded).abs()
196
+ return amplitude
197
+
198
+ @register_operator(name='nonlinear_blur')
199
+ class NonlinearBlurOperator(NonLinearOperator):
200
+ def __init__(self, opt_yml_path, device):
201
+ self.device = device
202
+ self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)
203
+
204
+ def prepare_nonlinear_blur_model(self, opt_yml_path):
205
+ '''
206
+ Nonlinear deblur requires external codes (bkse).
207
+ '''
208
+ from bkse.models.kernel_encoding.kernel_wizard import KernelWizard
209
+
210
+ with open(opt_yml_path, "r") as f:
211
+ opt = yaml.safe_load(f)["KernelWizard"]
212
+ model_path = opt["pretrained"]
213
+ blur_model = KernelWizard(opt)
214
+ blur_model.eval()
215
+ blur_model.load_state_dict(torch.load(model_path))
216
+ blur_model = blur_model.to(self.device)
217
+ return blur_model
218
+
219
+ def forward(self, data, **kwargs):
220
+ random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
221
+ data = (data + 1.0) / 2.0 #[-1, 1] -> [0, 1]
222
+ blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
223
+ blurred = (blurred * 2.0 - 1.0).clamp(-1, 1) #[0, 1] -> [-1, 1]
224
+ return blurred
225
+
226
+ # =============
227
+ # Noise classes
228
+ # =============
229
+
230
+
231
+ __NOISE__ = {}
232
+
233
+ def register_noise(name: str):
234
+ def wrapper(cls):
235
+ if __NOISE__.get(name, None):
236
+ raise NameError(f"Name {name} is already defined!")
237
+ __NOISE__[name] = cls
238
+ return cls
239
+ return wrapper
240
+
241
+ def get_noise(name: str, **kwargs):
242
+ if __NOISE__.get(name, None) is None:
243
+ raise NameError(f"Name {name} is not defined.")
244
+ noiser = __NOISE__[name](**kwargs)
245
+ noiser.__name__ = name
246
+ return noiser
247
+
248
+ class Noise(ABC):
249
+ def __call__(self, data):
250
+ return self.forward(data)
251
+
252
+ @abstractmethod
253
+ def forward(self, data):
254
+ pass
255
+
256
+ @register_noise(name='clean')
257
+ class Clean(Noise):
258
+ def forward(self, data):
259
+ return data
260
+
261
+ @register_noise(name='gaussian')
262
+ class GaussianNoise(Noise):
263
+ def __init__(self, sigma):
264
+ self.sigma = sigma
265
+
266
+ def forward(self, data):
267
+ # data lives in [-1, 1] while sigma appears to be specified for a [0, 1] range, hence the factor of 2.
+ return data + torch.randn_like(data, device=data.device) * self.sigma * 2
268
+
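Usage of the noise registry mirrors the operator registry above (a sketch):

import torch

noiser = get_noise(name="gaussian", sigma=0.05)
x = torch.zeros(1, 3, 32, 32)
y = noiser(x)                                # Noise.__call__ dispatches to forward()
assert noiser.__name__ == "gaussian" and y.std().item() > 0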
269
+
270
+ @register_noise(name='poisson')
271
+ class PoissonNoise(Noise):
272
+ def __init__(self, rate):
273
+ self.rate = rate
274
+
275
+ def forward(self, data):
276
+ '''
277
+ Follow skimage.util.random_noise.
278
+ '''
279
+
280
+ # TODO: set one version of poisson
281
+
282
+ # version 3 (stack-overflow)
283
+ import numpy as np
284
+ data = (data + 1.0) / 2.0
285
+ data = data.clamp(0, 1)
286
+ device = data.device
287
+ data = data.detach().cpu()
288
+ data = torch.from_numpy(np.random.poisson(data * 255.0 * self.rate) / 255.0 / self.rate)
289
+ data = data * 2.0 - 1.0
290
+ data = data.clamp(-1, 1)
291
+ return data.to(device)
292
+
293
+ # version 2 (skimage)
294
+ # if data.min() < 0:
295
+ # low_clip = -1
296
+ # else:
297
+ # low_clip = 0
298
+
299
+
300
+ # # Determine unique values in image & calculate the next power of two
301
+ # vals = torch.Tensor([len(torch.unique(data))])
302
+ # vals = 2 ** torch.ceil(torch.log2(vals))
303
+ # vals = vals.to(data.device)
304
+
305
+ # if low_clip == -1:
306
+ # old_max = data.max()
307
+ # data = (data + 1.0) / (old_max + 1.0)
308
+
309
+ # data = torch.poisson(data * vals) / float(vals)
310
+
311
+ # if low_clip == -1:
312
+ # data = data * (old_max + 1.0) - 1.0
313
+
314
+ # return data.clamp(low_clip, 1.0)
guided_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
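A quick check of the embedding above (illustrative; assumes `timestep_embedding` is importable from this module):

import torch as th

t = th.tensor([0, 250, 999])                   # timesteps may also be fractional
emb = timestep_embedding(t, dim=128)
assert emb.shape == (3, 128)
assert th.allclose(emb[0, :64], th.ones(64))   # cos(0) = 1 fills the first half for t = 0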
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
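Typical usage of `checkpoint` inside a module's forward pass looks like the sketch below (the `TinyBlock` module is purely illustrative; the UNet in this repo follows the same pattern with its `use_checkpoint` flag):

import torch as th
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, use_checkpoint=True):
        super().__init__()
        self.lin = nn.Linear(16, 16)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        # Recompute _forward during backward instead of caching activations.
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        return th.relu(self.lin(x))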
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
guided_diffusion/posterior_mean_variance.py ADDED
@@ -0,0 +1,264 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from util.img_utils import dynamic_thresholding
7
+
8
+
9
+
10
+ # ====================
11
+ # Model Mean Processor
12
+ # ====================
13
+
14
+ __MODEL_MEAN_PROCESSOR__ = {}
15
+
16
+ def register_mean_processor(name: str):
17
+ def wrapper(cls):
18
+ if __MODEL_MEAN_PROCESSOR__.get(name, None):
19
+ raise NameError(f"Name {name} is already registerd.")
20
+ __MODEL_MEAN_PROCESSOR__[name] = cls
21
+ return cls
22
+ return wrapper
23
+
24
+ def get_mean_processor(name: str, **kwargs):
25
+ if __MODEL_MEAN_PROCESSOR__.get(name, None) is None:
26
+ raise NameError(f"Name {name} is not defined.")
27
+ return __MODEL_MEAN_PROCESSOR__[name](**kwargs)
28
+
29
+ class MeanProcessor(ABC):
30
+ """Predict x_start and calculate mean value"""
31
+ @abstractmethod
32
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
33
+ self.dynamic_threshold = dynamic_threshold
34
+ self.clip_denoised = clip_denoised
35
+
36
+ @abstractmethod
37
+ def get_mean_and_xstart(self, x, t, model_output):
38
+ pass
39
+
40
+ def process_xstart(self, x):
41
+ if self.dynamic_threshold:
42
+ x = dynamic_thresholding(x, s=0.95)
43
+ if self.clip_denoised:
44
+ x = x.clamp(-1, 1)
45
+ return x
46
+
47
+ @register_mean_processor(name='previous_x')
48
+ class PreviousXMeanProcessor(MeanProcessor):
49
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
50
+ super().__init__(betas, dynamic_threshold, clip_denoised)
51
+ alphas = 1.0 - betas
52
+ alphas_cumprod = np.cumprod(alphas, axis=0)
53
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
54
+
55
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
56
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
57
+
58
+ def predict_xstart(self, x_t, t, x_prev):
59
+ coef1 = extract_and_expand(1.0/self.posterior_mean_coef1, t, x_t)
60
+ coef2 = extract_and_expand(self.posterior_mean_coef2/self.posterior_mean_coef1, t, x_t)
61
+ return coef1 * x_prev - coef2 * x_t
62
+
63
+ def get_mean_and_xstart(self, x, t, model_output):
64
+ mean = model_output
65
+ pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
66
+ return mean, pred_xstart
67
+
68
+ @register_mean_processor(name='start_x')
69
+ class StartXMeanProcessor(MeanProcessor):
70
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
71
+ super().__init__(betas, dynamic_threshold, clip_denoised)
72
+ alphas = 1.0 - betas
73
+ alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
75
+
76
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
77
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
78
+
79
+ def q_posterior_mean(self, x_start, x_t, t):
80
+ """
81
+ Compute the mean of the diffusion posteriro:
82
+ q(x_{t-1} | x_t, x_0)
83
+ """
84
+ assert x_start.shape == x_t.shape
85
+ coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
86
+ coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
87
+
88
+ return coef1 * x_start + coef2 * x_t
89
+
90
+ def get_mean_and_xstart(self, x, t, model_output):
91
+ pred_xstart = self.process_xstart(model_output)
92
+ mean = self.q_posterior_mean(x_start=pred_xstart, x_t=x, t=t)
93
+
94
+ return mean, pred_xstart
95
+
96
+ @register_mean_processor(name='epsilon')
97
+ class EpsilonXMeanProcessor(MeanProcessor):
98
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
99
+ super().__init__(betas, dynamic_threshold, clip_denoised)
100
+ alphas = 1.0 - betas
101
+ alphas_cumprod = np.cumprod(alphas, axis=0)
102
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
103
+
104
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
105
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
106
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
107
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
108
+
109
+
110
+ def q_posterior_mean(self, x_start, x_t, t):
111
+ """
112
+ Compute the mean of the diffusion posteriro:
113
+ q(x_{t-1} | x_t, x_0)
114
+ """
115
+ assert x_start.shape == x_t.shape
116
+ coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
117
+ coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
118
+ return coef1 * x_start + coef2 * x_t
119
+
120
+ def predict_xstart(self, x_t, t, eps):
121
+ coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
122
+ coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, eps)
123
+ return coef1 * x_t - coef2 * eps
124
+
125
+ def get_mean_and_xstart(self, x, t, model_output):
126
+ pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
127
+ mean = self.q_posterior_mean(pred_xstart, x, t)
128
+
129
+ return mean, pred_xstart
130
+
131
+ # =========================
132
+ # Model Variance Processor
133
+ # =========================
134
+
135
+ __MODEL_VAR_PROCESSOR__ = {}
136
+
137
+ def register_var_processor(name: str):
138
+ def wrapper(cls):
139
+ if __MODEL_VAR_PROCESSOR__.get(name, None):
140
+ raise NameError(f"Name {name} is already registerd.")
141
+ __MODEL_VAR_PROCESSOR__[name] = cls
142
+ return cls
143
+ return wrapper
144
+
145
+ def get_var_processor(name: str, **kwargs):
146
+ if __MODEL_VAR_PROCESSOR__.get(name, None) is None:
147
+ raise NameError(f"Name {name} is not defined.")
148
+ return __MODEL_VAR_PROCESSOR__[name](**kwargs)
149
+
150
+ class VarianceProcessor(ABC):
151
+ @abstractmethod
152
+ def __init__(self, betas):
153
+ pass
154
+
155
+ @abstractmethod
156
+ def get_variance(self, x, t):
157
+ pass
158
+
159
+ @register_var_processor(name='fixed_small')
160
+ class FixedSmallVarianceProcessor(VarianceProcessor):
161
+ def __init__(self, betas):
162
+ alphas = 1.0 - betas
163
+ alphas_cumprod = np.cumprod(alphas, axis=0)
164
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
165
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
166
+ self.posterior_variance = (
167
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
168
+ )
169
+
170
+ def get_variance(self, x, t):
171
+ model_variance = self.posterior_variance
172
+ model_log_variance = np.log(model_variance)
173
+
174
+ model_variance = extract_and_expand(model_variance, t, x)
175
+ model_log_variance = extract_and_expand(model_log_variance, t, x)
176
+
177
+ return model_variance, model_log_variance
178
+
179
+ @register_var_processor(name='fixed_large')
180
+ class FixedLargeVarianceProcessor(VarianceProcessor):
181
+ def __init__(self, betas):
182
+ self.betas = betas
183
+
184
+ alphas = 1.0 - betas
185
+ alphas_cumprod = np.cumprod(alphas, axis=0)
186
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
190
+ )
191
+
192
+ def get_variance(self, x, t):
193
+ model_variance = np.append(self.posterior_variance[1], self.betas[1:])
194
+ model_log_variance = np.log(model_variance)
195
+
196
+ model_variance = extract_and_expand(model_variance, t, x)
197
+ model_log_variance = extract_and_expand(model_log_variance, t, x)
198
+
199
+ return model_variance, model_log_variance
200
+
201
+ @register_var_processor(name='learned')
202
+ class LearnedVarianceProcessor(VarianceProcessor):
203
+ def __init__(self, betas):
204
+ pass
205
+
206
+ def get_variance(self, x, t):
207
+ model_log_variance = x
208
+ model_variance = torch.exp(model_log_variance)
209
+ return model_variance, model_log_variance
210
+
211
+ @register_var_processor(name='learned_range')
212
+ class LearnedRangeVarianceProcessor(VarianceProcessor):
213
+ def __init__(self, betas):
214
+ self.betas = betas
215
+
216
+ alphas = 1.0 - betas
217
+ alphas_cumprod = np.cumprod(alphas, axis=0)
218
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
219
+
220
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
221
+ posterior_variance = (
222
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
223
+ )
224
+ # log calculation clipped because the posterior variance is 0 at the
225
+ # beginning of the diffusion chain.
226
+ self.posterior_log_variance_clipped = np.log(
227
+ np.append(posterior_variance[1], posterior_variance[1:])
228
+ )
229
+
230
+ def get_variance(self, x, t):
231
+ model_var_values = x
232
+ min_log = self.posterior_log_variance_clipped
233
+ max_log = np.log(self.betas)
234
+
235
+ min_log = extract_and_expand(min_log, t, x)
236
+ max_log = extract_and_expand(max_log, t, x)
237
+
238
+ # The model_var_values is [-1, 1] for [min_var, max_var]
239
+ frac = (model_var_values + 1.0) / 2.0
240
+ model_log_variance = frac * max_log + (1-frac) * min_log
241
+ model_variance = torch.exp(model_log_variance)
242
+ return model_variance, model_log_variance
243
+
244
+ # ================
245
+ # Helper function
246
+ # ================
247
+
248
+ def extract_and_expand(array, time, target):
249
+ array = torch.from_numpy(array).to(target.device)[time].float()
250
+ while array.ndim < target.ndim:
251
+ array = array.unsqueeze(-1)
252
+ return array.expand_as(target)
253
+
254
+
255
+ def expand_as(array, target):
256
+ if isinstance(array, np.ndarray):
257
+ array = torch.from_numpy(array)
258
+ elif isinstance(array, np.float):
259
+ array = torch.tensor([array])
260
+
261
+ while array.ndim < target.ndim:
262
+ array = array.unsqueeze(-1)
263
+
264
+ return array.expand_as(target).to(target.device)
guided_diffusion/swinir.py ADDED
@@ -0,0 +1,904 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+ # Borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py)
6
+
7
+ import math
8
+ from typing import Set
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.act = act_layer()
24
+ self.fc2 = nn.Linear(hidden_features, out_features)
25
+ self.drop = nn.Dropout(drop)
26
+
27
+ def forward(self, x):
28
+ x = self.fc1(x)
29
+ x = self.act(x)
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ x = self.drop(x)
33
+ return x
34
+
35
+
36
+ def window_partition(x, window_size):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size (int): window size
41
+
42
+ Returns:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
47
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48
+ return windows
49
+
50
+
51
+ def window_reverse(windows, window_size, H, W):
52
+ """
53
+ Args:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ window_size (int): Window size
56
+ H (int): Height of image
57
+ W (int): Width of image
58
+
59
+ Returns:
60
+ x: (B, H, W, C)
61
+ """
62
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
63
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
64
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
65
+ return x
66
+
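The two helpers above are exact inverses of each other; a quick round-trip check (illustrative shapes matching a 56x56 feature map with 7x7 windows):

import torch

x = torch.randn(2, 56, 56, 96)                   # (B, H, W, C)
windows = window_partition(x, window_size=7)     # (B * 8 * 8, 7, 7, 96)
assert windows.shape == (2 * 64, 7, 7, 96)
x_back = window_reverse(windows, 7, 56, 56)
assert torch.equal(x_back, x)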
67
+
68
+ class WindowAttention(nn.Module):
69
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
70
+ It supports both of shifted and non-shifted window.
71
+
72
+ Args:
73
+ dim (int): Number of input channels.
74
+ window_size (tuple[int]): The height and width of the window.
75
+ num_heads (int): Number of attention heads.
76
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
+ """
81
+
82
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+
84
+ super().__init__()
85
+ self.dim = dim
86
+ self.window_size = window_size # Wh, Ww
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+
91
+ # define a parameter table of relative position bias
92
+ self.relative_position_bias_table = nn.Parameter(
93
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(self.window_size[0])
97
+ coords_w = torch.arange(self.window_size[1])
98
+ # coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ # Fix: Pass indexing="ij" to avoid warning
100
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
101
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
102
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
103
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
104
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
105
+ relative_coords[:, :, 1] += self.window_size[1] - 1
106
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
107
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+
110
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
111
+ self.attn_drop = nn.Dropout(attn_drop)
112
+ self.proj = nn.Linear(dim, dim)
113
+
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+
116
+ trunc_normal_(self.relative_position_bias_table, std=.02)
117
+ self.softmax = nn.Softmax(dim=-1)
118
+
119
+ def forward(self, x, mask=None):
120
+ """
121
+ Args:
122
+ x: input features with shape of (num_windows*B, N, C)
123
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
124
+ """
125
+ B_, N, C = x.shape
126
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
127
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
128
+
129
+ q = q * self.scale
130
+ attn = (q @ k.transpose(-2, -1))
131
+
132
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
134
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
135
+ attn = attn + relative_position_bias.unsqueeze(0)
136
+
137
+ if mask is not None:
138
+ nW = mask.shape[0]
139
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
140
+ attn = attn.view(-1, self.num_heads, N, N)
141
+ attn = self.softmax(attn)
142
+ else:
143
+ attn = self.softmax(attn)
144
+
145
+ attn = self.attn_drop(attn)
146
+
147
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
148
+ x = self.proj(x)
149
+ x = self.proj_drop(x)
150
+ return x
151
+
152
+ def extra_repr(self) -> str:
153
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
154
+
155
+ def flops(self, N):
156
+ # calculate flops for 1 window with token length of N
157
+ flops = 0
158
+ # qkv = self.qkv(x)
159
+ flops += N * self.dim * 3 * self.dim
160
+ # attn = (q @ k.transpose(-2, -1))
161
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
162
+ # x = (attn @ v)
163
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
164
+ # x = self.proj(x)
165
+ flops += N * self.dim * self.dim
166
+ return flops
167
+
168
+
169
+ class SwinTransformerBlock(nn.Module):
170
+ r""" Swin Transformer Block.
171
+
172
+ Args:
173
+ dim (int): Number of input channels.
174
+ input_resolution (tuple[int]): Input resolution.
175
+ num_heads (int): Number of attention heads.
176
+ window_size (int): Window size.
177
+ shift_size (int): Shift size for SW-MSA.
178
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
179
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
180
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
181
+ drop (float, optional): Dropout rate. Default: 0.0
182
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
183
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
184
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
185
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
186
+ """
187
+
188
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
189
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
190
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
191
+ super().__init__()
192
+ self.dim = dim
193
+ self.input_resolution = input_resolution
194
+ self.num_heads = num_heads
195
+ self.window_size = window_size
196
+ self.shift_size = shift_size
197
+ self.mlp_ratio = mlp_ratio
198
+ if min(self.input_resolution) <= self.window_size:
199
+ # if window size is larger than input resolution, we don't partition windows
200
+ self.shift_size = 0
201
+ self.window_size = min(self.input_resolution)
202
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
203
+
204
+ self.norm1 = norm_layer(dim)
205
+ self.attn = WindowAttention(
206
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
207
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
208
+
209
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
210
+ self.norm2 = norm_layer(dim)
211
+ mlp_hidden_dim = int(dim * mlp_ratio)
212
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
213
+
214
+ if self.shift_size > 0:
215
+ attn_mask = self.calculate_mask(self.input_resolution)
216
+ else:
217
+ attn_mask = None
218
+
219
+ self.register_buffer("attn_mask", attn_mask)
220
+
221
+ def calculate_mask(self, x_size):
222
+ # calculate attention mask for SW-MSA
223
+ H, W = x_size
224
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
225
+ h_slices = (slice(0, -self.window_size),
226
+ slice(-self.window_size, -self.shift_size),
227
+ slice(-self.shift_size, None))
228
+ w_slices = (slice(0, -self.window_size),
229
+ slice(-self.window_size, -self.shift_size),
230
+ slice(-self.shift_size, None))
231
+ cnt = 0
232
+ for h in h_slices:
233
+ for w in w_slices:
234
+ img_mask[:, h, w, :] = cnt
235
+ cnt += 1
236
+
237
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
238
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
239
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
240
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
241
+
242
+ return attn_mask
243
+
244
+ def forward(self, x, x_size):
245
+ H, W = x_size
246
+ B, L, C = x.shape
247
+ # assert L == H * W, "input feature has wrong size"
248
+
249
+ shortcut = x
250
+ x = self.norm1(x)
251
+ x = x.view(B, H, W, C)
252
+
253
+ # cyclic shift
254
+ if self.shift_size > 0:
255
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
256
+ else:
257
+ shifted_x = x
258
+
259
+ # partition windows
260
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
261
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
262
+
263
+ # W-MSA/SW-MSA (the attention mask is recomputed when the runtime image size differs from the training input_resolution)
264
+ if self.input_resolution == x_size:
265
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
266
+ else:
267
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
268
+
269
+ # merge windows
270
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
271
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
272
+
273
+ # reverse cyclic shift
274
+ if self.shift_size > 0:
275
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
276
+ else:
277
+ x = shifted_x
278
+ x = x.view(B, H * W, C)
279
+
280
+ # FFN
281
+ x = shortcut + self.drop_path(x)
282
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
283
+
284
+ return x
285
+
286
+ def extra_repr(self) -> str:
287
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
288
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
289
+
290
+ def flops(self):
291
+ flops = 0
292
+ H, W = self.input_resolution
293
+ # norm1
294
+ flops += self.dim * H * W
295
+ # W-MSA/SW-MSA
296
+ nW = H * W / self.window_size / self.window_size
297
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
298
+ # mlp
299
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
300
+ # norm2
301
+ flops += self.dim * H * W
302
+ return flops
303
+
304
+
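
The cyclic shift and window (un)partitioning used in SwinTransformerBlock.forward round-trip losslessly, which is what makes shifted-window attention cheap. The sketch below re-implements the two helpers purely for illustration (the real ones are assumed to exist elsewhere in this file as window_partition / window_reverse) and checks the round trip on a dummy tensor.

    import torch

    def partition(x, ws):                      # x: B, H, W, C -> nW*B, ws, ws, C
        B, H, W, C = x.shape
        x = x.view(B, H // ws, ws, W // ws, ws, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)

    def reverse(windows, ws, H, W):            # inverse of partition
        B = windows.shape[0] // ((H // ws) * (W // ws))
        x = windows.view(B, H // ws, W // ws, ws, ws, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

    B, H, W, C, ws, shift = 1, 8, 8, 4, 4, 2
    x = torch.randn(B, H, W, C)
    shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))   # cyclic shift
    wins = partition(shifted, ws)                                   # nW*B, ws, ws, C
    restored = torch.roll(reverse(wins, ws, H, W), shifts=(shift, shift), dims=(1, 2))
    assert torch.allclose(restored, x)                              # the round trip is lossless
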
305
+ class PatchMerging(nn.Module):
306
+ r""" Patch Merging Layer.
307
+
308
+ Args:
309
+ input_resolution (tuple[int]): Resolution of input feature.
310
+ dim (int): Number of input channels.
311
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
312
+ """
313
+
314
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
315
+ super().__init__()
316
+ self.input_resolution = input_resolution
317
+ self.dim = dim
318
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
319
+ self.norm = norm_layer(4 * dim)
320
+
321
+ def forward(self, x):
322
+ """
323
+ x: B, H*W, C
324
+ """
325
+ H, W = self.input_resolution
326
+ B, L, C = x.shape
327
+ assert L == H * W, "input feature has wrong size"
328
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
329
+
330
+ x = x.view(B, H, W, C)
331
+
332
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
333
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
334
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
335
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
336
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
337
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
338
+
339
+ x = self.norm(x)
340
+ x = self.reduction(x)
341
+
342
+ return x
343
+
344
+ def extra_repr(self) -> str:
345
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
346
+
347
+ def flops(self):
348
+ H, W = self.input_resolution
349
+ flops = H * W * self.dim
350
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
351
+ return flops
352
+
353
+
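
PatchMerging, which the SwinIR model below never activates (its BasicLayer is built with downsample=None), folds each 2x2 spatial neighborhood into the channel dimension before a LayerNorm and a Linear(4*dim, 2*dim) reduction. A toy shape check with illustrative sizes:

    import torch

    B, H, W, C = 1, 8, 8, 32
    x = torch.randn(B, H * W, C).view(B, H, W, C)
    x = torch.cat([x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :],
                   x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]], dim=-1)   # B, H/2, W/2, 4C
    x = x.view(B, (H // 2) * (W // 2), 4 * C)
    print(x.shape)   # torch.Size([1, 16, 128]); LayerNorm(4C) + Linear(4C, 2C) follow in the module
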
354
+ class BasicLayer(nn.Module):
355
+ """ A basic Swin Transformer layer for one stage.
356
+
357
+ Args:
358
+ dim (int): Number of input channels.
359
+ input_resolution (tuple[int]): Input resolution.
360
+ depth (int): Number of blocks.
361
+ num_heads (int): Number of attention heads.
362
+ window_size (int): Local window size.
363
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
364
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
365
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
366
+ drop (float, optional): Dropout rate. Default: 0.0
367
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
368
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
369
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
370
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
371
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
372
+ """
373
+
374
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
375
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
376
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
377
+
378
+ super().__init__()
379
+ self.dim = dim
380
+ self.input_resolution = input_resolution
381
+ self.depth = depth
382
+ self.use_checkpoint = use_checkpoint
383
+
384
+ # build blocks
385
+ self.blocks = nn.ModuleList([
386
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
387
+ num_heads=num_heads, window_size=window_size,
388
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
389
+ mlp_ratio=mlp_ratio,
390
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
391
+ drop=drop, attn_drop=attn_drop,
392
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
393
+ norm_layer=norm_layer)
394
+ for i in range(depth)])
395
+
396
+ # patch merging layer
397
+ if downsample is not None:
398
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
399
+ else:
400
+ self.downsample = None
401
+
402
+ def forward(self, x, x_size):
403
+ for blk in self.blocks:
404
+ if self.use_checkpoint:
405
+ x = checkpoint.checkpoint(blk, x, x_size)
406
+ else:
407
+ x = blk(x, x_size)
408
+ if self.downsample is not None:
409
+ x = self.downsample(x)
410
+ return x
411
+
412
+ def extra_repr(self) -> str:
413
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
414
+
415
+ def flops(self):
416
+ flops = 0
417
+ for blk in self.blocks:
418
+ flops += blk.flops()
419
+ if self.downsample is not None:
420
+ flops += self.downsample.flops()
421
+ return flops
422
+
423
+
424
+ class RSTB(nn.Module):
425
+ """Residual Swin Transformer Block (RSTB).
426
+
427
+ Args:
428
+ dim (int): Number of input channels.
429
+ input_resolution (tuple[int]): Input resolution.
430
+ depth (int): Number of blocks.
431
+ num_heads (int): Number of attention heads.
432
+ window_size (int): Local window size.
433
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
434
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
435
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
436
+ drop (float, optional): Dropout rate. Default: 0.0
437
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
438
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
439
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
440
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
441
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
442
+ img_size: Input image size.
443
+ patch_size: Patch size.
444
+ resi_connection: The convolutional block before residual connection.
445
+ """
446
+
447
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
448
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
449
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
450
+ img_size=224, patch_size=4, resi_connection='1conv'):
451
+ super(RSTB, self).__init__()
452
+
453
+ self.dim = dim
454
+ self.input_resolution = input_resolution
455
+
456
+ self.residual_group = BasicLayer(dim=dim,
457
+ input_resolution=input_resolution,
458
+ depth=depth,
459
+ num_heads=num_heads,
460
+ window_size=window_size,
461
+ mlp_ratio=mlp_ratio,
462
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
463
+ drop=drop, attn_drop=attn_drop,
464
+ drop_path=drop_path,
465
+ norm_layer=norm_layer,
466
+ downsample=downsample,
467
+ use_checkpoint=use_checkpoint)
468
+
469
+ if resi_connection == '1conv':
470
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
471
+ elif resi_connection == '3conv':
472
+ # to save parameters and memory
473
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
474
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
475
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
476
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
477
+
478
+ self.patch_embed = PatchEmbed(
479
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
480
+ norm_layer=None)
481
+
482
+ self.patch_unembed = PatchUnEmbed(
483
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
484
+ norm_layer=None)
485
+
486
+ def forward(self, x, x_size):
487
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
488
+
489
+ def flops(self):
490
+ flops = 0
491
+ flops += self.residual_group.flops()
492
+ H, W = self.input_resolution
493
+ flops += H * W * self.dim * self.dim * 9
494
+ flops += self.patch_embed.flops()
495
+ flops += self.patch_unembed.flops()
496
+
497
+ return flops
498
+
499
+
500
+ class PatchEmbed(nn.Module):
501
+ r""" Image to Patch Embedding
502
+
503
+ Args:
504
+ img_size (int): Image size. Default: 224.
505
+ patch_size (int): Patch token size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
509
+ """
510
+
511
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
512
+ super().__init__()
513
+ img_size = to_2tuple(img_size)
514
+ patch_size = to_2tuple(patch_size)
515
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
516
+ self.img_size = img_size
517
+ self.patch_size = patch_size
518
+ self.patches_resolution = patches_resolution
519
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
520
+
521
+ self.in_chans = in_chans
522
+ self.embed_dim = embed_dim
523
+
524
+ if norm_layer is not None:
525
+ self.norm = norm_layer(embed_dim)
526
+ else:
527
+ self.norm = None
528
+
529
+ def forward(self, x):
530
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
531
+ if self.norm is not None:
532
+ x = self.norm(x)
533
+ return x
534
+
535
+ def flops(self):
536
+ flops = 0
537
+ H, W = self.img_size
538
+ if self.norm is not None:
539
+ flops += H * W * self.embed_dim
540
+ return flops
541
+
542
+
543
+ class PatchUnEmbed(nn.Module):
544
+ r""" Image to Patch Unembedding
545
+
546
+ Args:
547
+ img_size (int): Image size. Default: 224.
548
+ patch_size (int): Patch token size. Default: 4.
549
+ in_chans (int): Number of input image channels. Default: 3.
550
+ embed_dim (int): Number of linear projection output channels. Default: 96.
551
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
552
+ """
553
+
554
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
555
+ super().__init__()
556
+ img_size = to_2tuple(img_size)
557
+ patch_size = to_2tuple(patch_size)
558
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
559
+ self.img_size = img_size
560
+ self.patch_size = patch_size
561
+ self.patches_resolution = patches_resolution
562
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
563
+
564
+ self.in_chans = in_chans
565
+ self.embed_dim = embed_dim
566
+
567
+ def forward(self, x, x_size):
568
+ B, HW, C = x.shape
569
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
570
+ return x
571
+
572
+ def flops(self):
573
+ flops = 0
574
+ return flops
575
+
576
+
577
+ class Upsample(nn.Sequential):
578
+ """Upsample module.
579
+
580
+ Args:
581
+ scale (int): Scale factor. Supported scales: 2^n and 3.
582
+ num_feat (int): Channel number of intermediate features.
583
+ """
584
+
585
+ def __init__(self, scale, num_feat):
586
+ m = []
587
+ if (scale & (scale - 1)) == 0: # scale = 2^n
588
+ for _ in range(int(math.log(scale, 2))):
589
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
590
+ m.append(nn.PixelShuffle(2))
591
+ elif scale == 3:
592
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
593
+ m.append(nn.PixelShuffle(3))
594
+ else:
595
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
596
+ super(Upsample, self).__init__(*m)
597
+
598
+
599
+ class UpsampleOneStep(nn.Sequential):
600
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
601
+ Used in lightweight SR to save parameters.
602
+
603
+ Args:
604
+ scale (int): Scale factor. Supported scales: 2^n and 3.
605
+ num_feat (int): Channel number of intermediate features.
606
+
607
+ """
608
+
609
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
610
+ self.num_feat = num_feat
611
+ self.input_resolution = input_resolution
612
+ m = []
613
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
614
+ m.append(nn.PixelShuffle(scale))
615
+ super(UpsampleOneStep, self).__init__(*m)
616
+
617
+ def flops(self):
618
+ H, W = self.input_resolution
619
+ flops = H * W * self.num_feat * 3 * 9
620
+ return flops
621
+
622
+
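
Both Upsample and UpsampleOneStep boil down to the sub-pixel convolution trick: a 3x3 convolution widens the channels by scale**2 and nn.PixelShuffle rearranges those channels into space. A minimal shape check with toy sizes:

    import torch
    import torch.nn as nn

    num_feat, scale = 16, 2
    up = nn.Sequential(nn.Conv2d(num_feat, (scale ** 2) * num_feat, 3, 1, 1),
                       nn.PixelShuffle(scale))
    x = torch.randn(1, num_feat, 24, 24)
    print(up(x).shape)   # torch.Size([1, 16, 48, 48]) -- channels preserved, H and W doubled
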
623
+ class SwinIR(nn.Module):
624
+ r""" SwinIR
625
+ A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
626
+
627
+ Args:
628
+ img_size (int | tuple(int)): Input image size. Default 64
629
+ patch_size (int | tuple(int)): Patch size. Default: 1
630
+ in_chans (int): Number of input image channels. Default: 3
631
+ embed_dim (int): Patch embedding dimension. Default: 96
632
+ depths (tuple(int)): Depth of each Swin Transformer layer.
633
+ num_heads (tuple(int)): Number of attention heads in different layers.
634
+ window_size (int): Window size. Default: 7
635
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
636
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
637
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
638
+ drop_rate (float): Dropout rate. Default: 0
639
+ attn_drop_rate (float): Attention dropout rate. Default: 0
640
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
641
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
642
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
643
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
644
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
645
+ sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
646
+ img_range: Image range. 1. or 255.
647
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
648
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
649
+ """
650
+
651
+ def __init__(
652
+ self,
653
+ img_size=64,
654
+ patch_size=1,
655
+ in_chans=3,
656
+ num_out_ch=3,
657
+ embed_dim=96,
658
+ depths=[6, 6, 6, 6],
659
+ num_heads=[6, 6, 6, 6],
660
+ window_size=7,
661
+ mlp_ratio=4.,
662
+ qkv_bias=True,
663
+ qk_scale=None,
664
+ drop_rate=0.,
665
+ attn_drop_rate=0.,
666
+ drop_path_rate=0.1,
667
+ norm_layer=nn.LayerNorm,
668
+ ape=False,
669
+ patch_norm=True,
670
+ use_checkpoint=False,
671
+ sf=4,
672
+ img_range=1.,
673
+ upsampler='',
674
+ resi_connection='1conv',
675
+ unshuffle=False,
676
+ unshuffle_scale=None,
677
+ hq_key: str = "jpg",
678
+ lq_key: str = "hint",
679
+ learning_rate: float = None,
680
+ weight_decay: float = None
681
+ ) -> "SwinIR":
682
+ super(SwinIR, self).__init__()
683
+ num_in_ch = in_chans * (unshuffle_scale ** 2) if unshuffle else in_chans
684
+ num_feat = 64
685
+ self.img_range = img_range
686
+ if in_chans == 3:
687
+ rgb_mean = (0.4488, 0.4371, 0.4040)
688
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
689
+ else:
690
+ self.mean = torch.zeros(1, 1, 1, 1)
691
+ self.upscale = sf
692
+ self.upsampler = upsampler
693
+ self.window_size = window_size
694
+ self.unshuffle_scale = unshuffle_scale
695
+ self.unshuffle = unshuffle
696
+
697
+ #####################################################################################################
698
+ ################################### 1, shallow feature extraction ###################################
699
+ if unshuffle:
700
+ assert unshuffle_scale is not None
701
+ self.conv_first = nn.Sequential(
702
+ nn.PixelUnshuffle(sf),
703
+ nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
704
+ )
705
+ else:
706
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
707
+
708
+ #####################################################################################################
709
+ ################################### 2, deep feature extraction ######################################
710
+ self.num_layers = len(depths)
711
+ self.embed_dim = embed_dim
712
+ self.ape = ape
713
+ self.patch_norm = patch_norm
714
+ self.num_features = embed_dim
715
+ self.mlp_ratio = mlp_ratio
716
+
717
+ # split image into non-overlapping patches
718
+ self.patch_embed = PatchEmbed(
719
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
720
+ norm_layer=norm_layer if self.patch_norm else None
721
+ )
722
+ num_patches = self.patch_embed.num_patches
723
+ patches_resolution = self.patch_embed.patches_resolution
724
+ self.patches_resolution = patches_resolution
725
+
726
+ # merge non-overlapping patches into image
727
+ self.patch_unembed = PatchUnEmbed(
728
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
729
+ norm_layer=norm_layer if self.patch_norm else None
730
+ )
731
+
732
+ # absolute position embedding
733
+ if self.ape:
734
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
735
+ trunc_normal_(self.absolute_pos_embed, std=.02)
736
+
737
+ self.pos_drop = nn.Dropout(p=drop_rate)
738
+
739
+ # stochastic depth
740
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
741
+
742
+ # build Residual Swin Transformer blocks (RSTB)
743
+ self.layers = nn.ModuleList()
744
+ for i_layer in range(self.num_layers):
745
+ layer = RSTB(
746
+ dim=embed_dim,
747
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
748
+ depth=depths[i_layer],
749
+ num_heads=num_heads[i_layer],
750
+ window_size=window_size,
751
+ mlp_ratio=self.mlp_ratio,
752
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
753
+ drop=drop_rate, attn_drop=attn_drop_rate,
754
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
755
+ norm_layer=norm_layer,
756
+ downsample=None,
757
+ use_checkpoint=use_checkpoint,
758
+ img_size=img_size,
759
+ patch_size=patch_size,
760
+ resi_connection=resi_connection
761
+ )
762
+ self.layers.append(layer)
763
+ self.norm = norm_layer(self.num_features)
764
+
765
+ # build the last conv layer in deep feature extraction
766
+ if resi_connection == '1conv':
767
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
768
+ elif resi_connection == '3conv':
769
+ # to save parameters and memory
770
+ self.conv_after_body = nn.Sequential(
771
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
772
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
773
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
774
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
775
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
776
+ )
777
+
778
+ #####################################################################################################
779
+ ################################ 3, high quality image reconstruction ################################
780
+ if self.upsampler == 'pixelshuffle':
781
+ # for classical SR
782
+ self.conv_before_upsample = nn.Sequential(
783
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
784
+ nn.LeakyReLU(inplace=True)
785
+ )
786
+ self.upsample = Upsample(sf, num_feat)
787
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
788
+ elif self.upsampler == 'pixelshuffledirect':
789
+ # for lightweight SR (to save parameters)
790
+ self.upsample = UpsampleOneStep(
791
+ sf, embed_dim, num_out_ch,
792
+ (patches_resolution[0], patches_resolution[1])
793
+ )
794
+ elif self.upsampler == 'nearest+conv':
795
+ # for real-world SR (less artifacts)
796
+ self.conv_before_upsample = nn.Sequential(
797
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
798
+ nn.LeakyReLU(inplace=True)
799
+ )
800
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
801
+ if self.upscale == 4:
802
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
803
+ elif self.upscale == 8:
804
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
805
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
806
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
807
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
808
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
809
+ else:
810
+ # for image denoising and JPEG compression artifact reduction
811
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
812
+
813
+ self.apply(self._init_weights)
814
+
815
+ def _init_weights(self, m: nn.Module) -> None:
816
+ if isinstance(m, nn.Linear):
817
+ trunc_normal_(m.weight, std=.02)
818
+ if isinstance(m, nn.Linear) and m.bias is not None:
819
+ nn.init.constant_(m.bias, 0)
820
+ elif isinstance(m, nn.LayerNorm):
821
+ nn.init.constant_(m.bias, 0)
822
+ nn.init.constant_(m.weight, 1.0)
823
+
824
+ # Parameter names excluded from weight decay by timm-style optimizers (standard Swin convention).
825
+ @torch.jit.ignore
826
+ def no_weight_decay(self) -> Set[str]:
827
+ return {'absolute_pos_embed'}
828
+
829
+ @torch.jit.ignore
830
+ def no_weight_decay_keywords(self) -> Set[str]:
831
+ return {'relative_position_bias_table'}
832
+
833
+ def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
834
+ _, _, h, w = x.size()
835
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
836
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
837
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
838
+ return x
839
+
840
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
841
+ x_size = (x.shape[2], x.shape[3])
842
+ x = self.patch_embed(x)
843
+ if self.ape:
844
+ x = x + self.absolute_pos_embed
845
+ x = self.pos_drop(x)
846
+
847
+ for layer in self.layers:
848
+ x = layer(x, x_size)
849
+
850
+ x = self.norm(x) # B L C
851
+ x = self.patch_unembed(x, x_size)
852
+
853
+ return x
854
+
855
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
856
+ H, W = x.shape[2:]
857
+ x = self.check_image_size(x)
858
+
859
+ self.mean = self.mean.type_as(x)
860
+ x = (x - self.mean) * self.img_range
861
+
862
+ if self.upsampler == 'pixelshuffle':
863
+ # for classical SR
864
+ x = self.conv_first(x)
865
+ x = self.conv_after_body(self.forward_features(x)) + x
866
+ x = self.conv_before_upsample(x)
867
+ x = self.conv_last(self.upsample(x))
868
+ elif self.upsampler == 'pixelshuffledirect':
869
+ # for lightweight SR
870
+ x = self.conv_first(x)
871
+ x = self.conv_after_body(self.forward_features(x)) + x
872
+ x = self.upsample(x)
873
+ elif self.upsampler == 'nearest+conv':
874
+ # for real-world SR
875
+ x = self.conv_first(x)
876
+ x = self.conv_after_body(self.forward_features(x)) + x
877
+ x = self.conv_before_upsample(x)
878
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
879
+ if self.upscale == 4:
880
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
881
+ elif self.upscale == 8:
882
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
883
+ x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
884
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
885
+ else:
886
+ # for image denoising and JPEG compression artifact reduction
887
+ x_first = self.conv_first(x)
888
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
889
+ x = x + self.conv_last(res)
890
+
891
+ x = x / self.img_range + self.mean
892
+
893
+ return x[:, :, :H * self.upscale, :W * self.upscale]
894
+
895
+ def flops(self) -> int:
896
+ flops = 0
897
+ H, W = self.patches_resolution
898
+ flops += H * W * 3 * self.embed_dim * 9
899
+ flops += self.patch_embed.flops()
900
+ for i, layer in enumerate(self.layers):
901
+ flops += layer.flops()
902
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
903
+ flops += self.upsample.flops()
904
+ return flops
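
As a quick smoke test for the SwinIR class above, it can be instantiated directly and run on a dummy degraded input. The hyperparameters below are small illustrative values, not the configuration this repository actually loads for blind face restoration; the only other assumption is that the package is importable as guided_diffusion.swinir.

    import torch
    from guided_diffusion.swinir import SwinIR   # import path assumed from this repo's layout

    # Small illustrative SwinIR for 1x restoration (no upsampler branch).
    model = SwinIR(img_size=64, patch_size=1, in_chans=3, embed_dim=60,
                   depths=[2, 2], num_heads=[6, 6], window_size=8,
                   mlp_ratio=2., sf=1, img_range=1., upsampler='', resi_connection='1conv')
    model.eval()

    with torch.no_grad():
        lq = torch.randn(1, 3, 70, 70)      # not a multiple of window_size; check_image_size pads it
        out = model(lq)
    print(out.shape)                         # torch.Size([1, 3, 70, 70])
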
guided_diffusion/unet.py ADDED
@@ -0,0 +1,1148 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import functools
10
+ from collections import OrderedDict
11
+
12
+
13
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
14
+ from .nn import (
15
+ checkpoint,
16
+ conv_nd,
17
+ linear,
18
+ avg_pool_nd,
19
+ zero_module,
20
+ normalization,
21
+ timestep_embedding,
22
+ )
23
+
24
+
25
+ NUM_CLASSES = 1000
26
+
27
+ def create_model(
28
+ image_size,
29
+ num_channels,
30
+ num_res_blocks,
31
+ channel_mult="",
32
+ learn_sigma=False,
33
+ class_cond=False,
34
+ conv_resample=True,
35
+ use_checkpoint=False,
36
+ attention_resolutions="16",
37
+ num_heads=1,
38
+ num_head_channels=-1,
39
+ num_heads_upsample=-1,
40
+ use_scale_shift_norm=False,
41
+ dropout=0,
42
+ resblock_updown=False,
43
+ use_fp16=False,
44
+ use_new_attention_order=False,
45
+ dims=2,
46
+ model_path='',
47
+ ):
48
+ if channel_mult == "":
49
+ if image_size == 512:
50
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
51
+ elif image_size == 256:
52
+ channel_mult = (1, 1, 2, 2, 4, 4)
53
+ elif image_size == 128:
54
+ channel_mult = (1, 1, 2, 3, 4)
55
+ elif image_size == 64:
56
+ channel_mult = (1, 2, 3, 4)
57
+ else:
58
+ raise ValueError(f"unsupported image size: {image_size}")
59
+ else:
60
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
61
+ print(channel_mult)
62
+ attention_ds = []
63
+ if isinstance(attention_resolutions, int):
64
+ attention_ds.append(image_size // attention_resolutions)
65
+ elif isinstance(attention_resolutions, str):
66
+ for res in attention_resolutions.split(","):
67
+ attention_ds.append(image_size // int(res))
68
+ else:
69
+ raise NotImplementedError
70
+
71
+ if isinstance(num_res_blocks, str):
72
+ num_res_blocks_res = []
73
+ for res in num_res_blocks.split(","):
74
+ num_res_blocks_res.append(int(res))
75
+ else:
76
+ assert isinstance(num_res_blocks, int)
77
+ num_res_blocks_res = num_res_blocks
78
+
79
+ model= UNetModel(
80
+ image_size=image_size,
81
+ in_channels=3,
82
+ model_channels=num_channels,
83
+ out_channels=(3 if not learn_sigma else 6),
84
+ num_res_blocks=num_res_blocks_res,
85
+ attention_resolutions=tuple(attention_ds),
86
+ dropout=dropout,
87
+ channel_mult=channel_mult,
88
+ num_classes=(NUM_CLASSES if class_cond else None),
89
+ use_checkpoint=use_checkpoint,
90
+ use_fp16=use_fp16,
91
+ num_heads=num_heads,
92
+ dims=dims,
93
+ num_head_channels=num_head_channels,
94
+ num_heads_upsample=num_heads_upsample,
95
+ use_scale_shift_norm=use_scale_shift_norm,
96
+ resblock_updown=resblock_updown,
97
+ use_new_attention_order=use_new_attention_order,
98
+ conv_resample=conv_resample
99
+ )
100
+
101
+ try:
102
+ ckpt = th.load(model_path, map_location='cpu')
103
+ if list(model.state_dict().keys())[0].startswith('module.'):
104
+ if list(ckpt.keys())[0].startswith('module.'):
105
+ ckpt = ckpt
106
+ else:
107
+ ckpt = OrderedDict({f'module.{key}': value for key, value in ckpt.items()})
108
+ else:
109
+ if list(ckpt.keys())[0].startswith('module.'):
110
+ ckpt = OrderedDict({key[7:]: value for key, value in ckpt.items()})
111
+ else:
112
+ ckpt = ckpt
113
+
114
+ model.load_state_dict(ckpt)
115
+ except Exception as e:
116
+ print(f"Got exception: {e} / Randomly initialize")
117
+ return model
118
+
119
+ class AttentionPool2d(nn.Module):
120
+ """
121
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ spacial_dim: int,
127
+ embed_dim: int,
128
+ num_heads_channels: int,
129
+ output_dim: int = None,
130
+ ):
131
+ super().__init__()
132
+ self.positional_embedding = nn.Parameter(
133
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
134
+ )
135
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
136
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
137
+ self.num_heads = embed_dim // num_heads_channels
138
+ self.attention = QKVAttention(self.num_heads)
139
+
140
+ def forward(self, x):
141
+ b, c, *_spatial = x.shape
142
+ x = x.reshape(b, c, -1) # NC(HW)
143
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
144
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
145
+ x = self.qkv_proj(x)
146
+ x = self.attention(x)
147
+ x = self.c_proj(x)
148
+ return x[:, :, 0]
149
+
150
+
151
+ class TimestepBlock(nn.Module):
152
+ """
153
+ Any module where forward() takes timestep embeddings as a second argument.
154
+ """
155
+
156
+ @abstractmethod
157
+ def forward(self, x, emb):
158
+ """
159
+ Apply the module to `x` given `emb` timestep embeddings.
160
+ """
161
+
162
+
163
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
164
+ """
165
+ A sequential module that passes timestep embeddings to the children that
166
+ support it as an extra input.
167
+ """
168
+
169
+ def forward(self, x, emb):
170
+ for layer in self:
171
+ if isinstance(layer, TimestepBlock):
172
+ x = layer(x, emb)
173
+ else:
174
+ x = layer(x)
175
+ return x
176
+
177
+
178
+ class Upsample(nn.Module):
179
+ """
180
+ An upsampling layer with an optional convolution.
181
+
182
+ :param channels: channels in the inputs and outputs.
183
+ :param use_conv: a bool determining if a convolution is applied.
184
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
185
+ upsampling occurs in the inner-two dimensions.
186
+ """
187
+
188
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
189
+ super().__init__()
190
+ self.channels = channels
191
+ self.out_channels = out_channels or channels
192
+ self.use_conv = use_conv
193
+ self.dims = dims
194
+ if use_conv:
195
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
196
+
197
+ def forward(self, x):
198
+ assert x.shape[1] == self.channels
199
+ if self.dims == 3:
200
+ x = F.interpolate(
201
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
202
+ )
203
+ else:
204
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
205
+ if self.use_conv:
206
+ x = self.conv(x)
207
+ return x
208
+
209
+
210
+ class Downsample(nn.Module):
211
+ """
212
+ A downsampling layer with an optional convolution.
213
+
214
+ :param channels: channels in the inputs and outputs.
215
+ :param use_conv: a bool determining if a convolution is applied.
216
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
217
+ downsampling occurs in the inner-two dimensions.
218
+ """
219
+
220
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
221
+ super().__init__()
222
+ self.channels = channels
223
+ self.out_channels = out_channels or channels
224
+ self.use_conv = use_conv
225
+ self.dims = dims
226
+ stride = 2 if dims != 3 else (1, 2, 2)
227
+ if use_conv:
228
+ self.op = conv_nd(
229
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
230
+ )
231
+ else:
232
+ assert self.channels == self.out_channels
233
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
234
+
235
+ def forward(self, x):
236
+ assert x.shape[1] == self.channels
237
+ return self.op(x)
238
+
239
+
240
+ class ResBlock(TimestepBlock):
241
+ """
242
+ A residual block that can optionally change the number of channels.
243
+
244
+ :param channels: the number of input channels.
245
+ :param emb_channels: the number of timestep embedding channels.
246
+ :param dropout: the rate of dropout.
247
+ :param out_channels: if specified, the number of out channels.
248
+ :param use_conv: if True and out_channels is specified, use a spatial
249
+ convolution instead of a smaller 1x1 convolution to change the
250
+ channels in the skip connection.
251
+ :param dims: determines if the signal is 1D, 2D, or 3D.
252
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
253
+ :param up: if True, use this block for upsampling.
254
+ :param down: if True, use this block for downsampling.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ channels,
260
+ emb_channels,
261
+ dropout,
262
+ out_channels=None,
263
+ use_conv=False,
264
+ use_scale_shift_norm=False,
265
+ dims=2,
266
+ use_checkpoint=False,
267
+ up=False,
268
+ down=False,
269
+ ):
270
+ super().__init__()
271
+ self.channels = channels
272
+ self.emb_channels = emb_channels
273
+ self.dropout = dropout
274
+ self.out_channels = out_channels or channels
275
+ self.use_conv = use_conv
276
+ self.use_checkpoint = use_checkpoint
277
+ self.use_scale_shift_norm = use_scale_shift_norm
278
+
279
+ self.in_layers = nn.Sequential(
280
+ normalization(channels),
281
+ nn.SiLU(),
282
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
283
+ )
284
+
285
+ self.updown = up or down
286
+
287
+ if up:
288
+ self.h_upd = Upsample(channels, False, dims)
289
+ self.x_upd = Upsample(channels, False, dims)
290
+ elif down:
291
+ self.h_upd = Downsample(channels, False, dims)
292
+ self.x_upd = Downsample(channels, False, dims)
293
+ else:
294
+ self.h_upd = self.x_upd = nn.Identity()
295
+
296
+ self.emb_layers = nn.Sequential(
297
+ nn.SiLU(),
298
+ linear(
299
+ emb_channels,
300
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
301
+ ),
302
+ )
303
+ self.out_layers = nn.Sequential(
304
+ normalization(self.out_channels),
305
+ nn.SiLU(),
306
+ nn.Dropout(p=dropout),
307
+ zero_module(
308
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
309
+ ),
310
+ )
311
+
312
+ if self.out_channels == channels:
313
+ self.skip_connection = nn.Identity()
314
+ elif use_conv:
315
+ self.skip_connection = conv_nd(
316
+ dims, channels, self.out_channels, 3, padding=1
317
+ )
318
+ else:
319
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
320
+
321
+ def forward(self, x, emb):
322
+ """
323
+ Apply the block to a Tensor, conditioned on a timestep embedding.
324
+
325
+ :param x: an [N x C x ...] Tensor of features.
326
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
327
+ :return: an [N x C x ...] Tensor of outputs.
328
+ """
329
+ return checkpoint(
330
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
331
+ )
332
+
333
+ def _forward(self, x, emb):
334
+ if self.updown:
335
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
336
+ h = in_rest(x)
337
+ h = self.h_upd(h)
338
+ x = self.x_upd(x)
339
+ h = in_conv(h)
340
+ else:
341
+ h = self.in_layers(x)
342
+ emb_out = self.emb_layers(emb).type(h.dtype)
343
+ while len(emb_out.shape) < len(h.shape):
344
+ emb_out = emb_out[..., None]
345
+ if self.use_scale_shift_norm:
346
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
347
+ scale, shift = th.chunk(emb_out, 2, dim=1)
348
+ h = out_norm(h) * (1 + scale) + shift
349
+ h = out_rest(h)
350
+ else:
351
+ h = h + emb_out
352
+ h = self.out_layers(h)
353
+ return self.skip_connection(x) + h
354
+
355
+
356
+ class AttentionBlock(nn.Module):
357
+ """
358
+ An attention block that allows spatial positions to attend to each other.
359
+
360
+ Originally ported from here, but adapted to the N-d case.
361
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ channels,
367
+ num_heads=1,
368
+ num_head_channels=-1,
369
+ use_checkpoint=False,
370
+ use_new_attention_order=False,
371
+ ):
372
+ super().__init__()
373
+ self.channels = channels
374
+ if num_head_channels == -1:
375
+ self.num_heads = num_heads
376
+ else:
377
+ assert (
378
+ channels % num_head_channels == 0
379
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
380
+ self.num_heads = channels // num_head_channels
381
+ self.use_checkpoint = use_checkpoint
382
+ self.norm = normalization(channels)
383
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
384
+ if use_new_attention_order:
385
+ # split qkv before split heads
386
+ self.attention = QKVAttention(self.num_heads)
387
+ else:
388
+ # split heads before split qkv
389
+ self.attention = QKVAttentionLegacy(self.num_heads)
390
+
391
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
392
+
393
+ def forward(self, x):
394
+ return checkpoint(self._forward, (x,), self.parameters(), True)
395
+
396
+ def _forward(self, x):
397
+ b, c, *spatial = x.shape
398
+ x = x.reshape(b, c, -1)
399
+ qkv = self.qkv(self.norm(x))
400
+ h = self.attention(qkv)
401
+ h = self.proj_out(h)
402
+ return (x + h).reshape(b, c, *spatial)
403
+
404
+
405
+ def count_flops_attn(model, _x, y):
406
+ """
407
+ A counter for the `thop` package to count the operations in an
408
+ attention operation.
409
+ Meant to be used like:
410
+ macs, params = thop.profile(
411
+ model,
412
+ inputs=(inputs, timestamps),
413
+ custom_ops={QKVAttention: QKVAttention.count_flops},
414
+ )
415
+ """
416
+ b, c, *spatial = y[0].shape
417
+ num_spatial = int(np.prod(spatial))
418
+ # We perform two matmuls with the same number of ops.
419
+ # The first computes the weight matrix, the second computes
420
+ # the combination of the value vectors.
421
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
422
+ model.total_ops += th.DoubleTensor([matmul_ops])
423
+
424
+
425
+ class QKVAttentionLegacy(nn.Module):
426
+ """
427
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
428
+ """
429
+
430
+ def __init__(self, n_heads):
431
+ super().__init__()
432
+ self.n_heads = n_heads
433
+
434
+ def forward(self, qkv):
435
+ """
436
+ Apply QKV attention.
437
+
438
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
439
+ :return: an [N x (H * C) x T] tensor after attention.
440
+ """
441
+ bs, width, length = qkv.shape
442
+ assert width % (3 * self.n_heads) == 0
443
+ ch = width // (3 * self.n_heads)
444
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
445
+ scale = 1 / math.sqrt(math.sqrt(ch))
446
+ weight = th.einsum(
447
+ "bct,bcs->bts", q * scale, k * scale
448
+ ) # More stable with f16 than dividing afterwards
449
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
450
+ a = th.einsum("bts,bcs->bct", weight, v)
451
+ return a.reshape(bs, -1, length)
452
+
453
+ @staticmethod
454
+ def count_flops(model, _x, y):
455
+ return count_flops_attn(model, _x, y)
456
+
457
+
458
+ class QKVAttention(nn.Module):
459
+ """
460
+ A module which performs QKV attention and splits in a different order.
461
+ """
462
+
463
+ def __init__(self, n_heads):
464
+ super().__init__()
465
+ self.n_heads = n_heads
466
+
467
+ def forward(self, qkv):
468
+ """
469
+ Apply QKV attention.
470
+
471
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
472
+ :return: an [N x (H * C) x T] tensor after attention.
473
+ """
474
+ bs, width, length = qkv.shape
475
+ assert width % (3 * self.n_heads) == 0
476
+ ch = width // (3 * self.n_heads)
477
+ q, k, v = qkv.chunk(3, dim=1)
478
+ scale = 1 / math.sqrt(math.sqrt(ch))
479
+ weight = th.einsum(
480
+ "bct,bcs->bts",
481
+ (q * scale).view(bs * self.n_heads, ch, length),
482
+ (k * scale).view(bs * self.n_heads, ch, length),
483
+ ) # More stable with f16 than dividing afterwards
484
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
485
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
486
+ return a.reshape(bs, -1, length)
487
+
488
+ @staticmethod
489
+ def count_flops(model, _x, y):
490
+ return count_flops_attn(model, _x, y)
491
+
492
+
493
+ class UNetModel(nn.Module):
494
+ """
495
+ The full UNet model with attention and timestep embedding.
496
+
497
+ :param in_channels: channels in the input Tensor.
498
+ :param model_channels: base channel count for the model.
499
+ :param out_channels: channels in the output Tensor.
500
+ :param num_res_blocks: number of residual blocks per downsample.
501
+ :param attention_resolutions: a collection of downsample rates at which
502
+ attention will take place. May be a set, list, or tuple.
503
+ For example, if this contains 4, then at 4x downsampling, attention
504
+ will be used.
505
+ :param dropout: the dropout probability.
506
+ :param channel_mult: channel multiplier for each level of the UNet.
507
+ :param conv_resample: if True, use learned convolutions for upsampling and
508
+ downsampling.
509
+ :param dims: determines if the signal is 1D, 2D, or 3D.
510
+ :param num_classes: if specified (as an int), then this model will be
511
+ class-conditional with `num_classes` classes.
512
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
513
+ :param num_heads: the number of attention heads in each attention layer.
514
+ :param num_heads_channels: if specified, ignore num_heads and instead use
515
+ a fixed channel width per attention head.
516
+ :param num_heads_upsample: works with num_heads to set a different number
517
+ of heads for upsampling. Deprecated.
518
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
519
+ :param resblock_updown: use residual blocks for up/downsampling.
520
+ :param use_new_attention_order: use a different attention pattern for potentially
521
+ increased efficiency.
522
+ """
523
+
524
+ def __init__(
525
+ self,
526
+ image_size,
527
+ in_channels,
528
+ model_channels,
529
+ out_channels,
530
+ num_res_blocks,
531
+ attention_resolutions,
532
+ dropout=0,
533
+ channel_mult=(1, 2, 4, 8),
534
+ conv_resample=True,
535
+ dims=2,
536
+ num_classes=None,
537
+ use_checkpoint=False,
538
+ use_fp16=False,
539
+ num_heads=1,
540
+ num_head_channels=-1,
541
+ num_heads_upsample=-1,
542
+ use_scale_shift_norm=False,
543
+ resblock_updown=False,
544
+ use_new_attention_order=False,
545
+ ):
546
+ super().__init__()
547
+ if isinstance(num_res_blocks, int):
548
+ num_res_blocks = [num_res_blocks, ] * len(channel_mult)
549
+ else:
550
+ assert len(num_res_blocks) == len(channel_mult)
551
+ self.num_res_blocks = num_res_blocks
552
+
553
+ if num_heads_upsample == -1:
554
+ num_heads_upsample = num_heads
555
+
556
+ self.image_size = image_size
557
+ self.in_channels = in_channels
558
+ self.model_channels = model_channels
559
+ self.out_channels = out_channels
560
+ self.num_res_blocks = num_res_blocks
561
+ self.attention_resolutions = attention_resolutions
562
+ self.dropout = dropout
563
+ self.channel_mult = channel_mult
564
+ self.conv_resample = conv_resample
565
+ self.num_classes = num_classes
566
+ self.use_checkpoint = use_checkpoint
567
+ self.dtype = th.float16 if use_fp16 else th.float32
568
+ self.num_heads = num_heads
569
+ self.num_head_channels = num_head_channels
570
+ self.num_heads_upsample = num_heads_upsample
571
+
572
+ time_embed_dim = model_channels * 4
573
+ self.time_embed = nn.Sequential(
574
+ linear(model_channels, time_embed_dim),
575
+ nn.SiLU(),
576
+ linear(time_embed_dim, time_embed_dim),
577
+ )
578
+
579
+ if self.num_classes is not None:
580
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
581
+
582
+ ch = input_ch = int(channel_mult[0] * model_channels)
583
+ self.input_blocks = nn.ModuleList(
584
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
585
+ )
586
+ self._feature_size = ch
587
+ input_block_chans = [ch]
588
+ ds = 1
589
+ for level, mult in enumerate(channel_mult):
590
+ for _ in range(num_res_blocks[level]):
591
+ layers = [
592
+ ResBlock(
593
+ ch,
594
+ time_embed_dim,
595
+ dropout,
596
+ out_channels=int(mult * model_channels),
597
+ dims=dims,
598
+ use_checkpoint=use_checkpoint,
599
+ use_scale_shift_norm=use_scale_shift_norm,
600
+ )
601
+ ]
602
+ ch = int(mult * model_channels)
603
+ if ds in attention_resolutions:
604
+ layers.append(
605
+ AttentionBlock(
606
+ ch,
607
+ use_checkpoint=use_checkpoint,
608
+ num_heads=num_heads,
609
+ num_head_channels=num_head_channels,
610
+ use_new_attention_order=use_new_attention_order,
611
+ )
612
+ )
613
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
614
+ self._feature_size += ch
615
+ input_block_chans.append(ch)
616
+ if level != len(channel_mult) - 1:
617
+ out_ch = ch
618
+ self.input_blocks.append(
619
+ TimestepEmbedSequential(
620
+ ResBlock(
621
+ ch,
622
+ time_embed_dim,
623
+ dropout,
624
+ out_channels=out_ch,
625
+ dims=dims,
626
+ use_checkpoint=use_checkpoint,
627
+ use_scale_shift_norm=use_scale_shift_norm,
628
+ down=True,
629
+ )
630
+ if resblock_updown
631
+ else Downsample(
632
+ ch, conv_resample, dims=dims, out_channels=out_ch
633
+ )
634
+ )
635
+ )
636
+ ch = out_ch
637
+ input_block_chans.append(ch)
638
+ ds *= 2
639
+ self._feature_size += ch
640
+
641
+ self.middle_block = TimestepEmbedSequential(
642
+ ResBlock(
643
+ ch,
644
+ time_embed_dim,
645
+ dropout,
646
+ dims=dims,
647
+ use_checkpoint=use_checkpoint,
648
+ use_scale_shift_norm=use_scale_shift_norm,
649
+ ),
650
+ AttentionBlock(
651
+ ch,
652
+ use_checkpoint=use_checkpoint,
653
+ num_heads=num_heads,
654
+ num_head_channels=num_head_channels,
655
+ use_new_attention_order=use_new_attention_order,
656
+ ),
657
+ ResBlock(
658
+ ch,
659
+ time_embed_dim,
660
+ dropout,
661
+ dims=dims,
662
+ use_checkpoint=use_checkpoint,
663
+ use_scale_shift_norm=use_scale_shift_norm,
664
+ ),
665
+ )
666
+ self._feature_size += ch
667
+
668
+ self.output_blocks = nn.ModuleList([])
669
+ for level, mult in list(enumerate(channel_mult))[::-1]:
670
+ for i in range(num_res_blocks[level] + 1):
671
+ ich = input_block_chans.pop()
672
+ layers = [
673
+ ResBlock(
674
+ ch + ich,
675
+ time_embed_dim,
676
+ dropout,
677
+ out_channels=int(model_channels * mult),
678
+ dims=dims,
679
+ use_checkpoint=use_checkpoint,
680
+ use_scale_shift_norm=use_scale_shift_norm,
681
+ )
682
+ ]
683
+ ch = int(model_channels * mult)
684
+ if ds in attention_resolutions:
685
+ layers.append(
686
+ AttentionBlock(
687
+ ch,
688
+ use_checkpoint=use_checkpoint,
689
+ num_heads=num_heads_upsample,
690
+ num_head_channels=num_head_channels,
691
+ use_new_attention_order=use_new_attention_order,
692
+ )
693
+ )
694
+ if level and i == num_res_blocks[level]:
695
+ out_ch = ch
696
+ layers.append(
697
+ ResBlock(
698
+ ch,
699
+ time_embed_dim,
700
+ dropout,
701
+ out_channels=out_ch,
702
+ dims=dims,
703
+ use_checkpoint=use_checkpoint,
704
+ use_scale_shift_norm=use_scale_shift_norm,
705
+ up=True,
706
+ )
707
+ if resblock_updown
708
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
709
+ )
710
+ ds //= 2
711
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
712
+ self._feature_size += ch
713
+
714
+ self.out = nn.Sequential(
715
+ normalization(ch),
716
+ nn.SiLU(),
717
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
718
+ )
719
+
720
+ def convert_to_fp16(self):
721
+ """
722
+ Convert the torso of the model to float16.
723
+ """
724
+ self.input_blocks.apply(convert_module_to_f16)
725
+ self.middle_block.apply(convert_module_to_f16)
726
+ self.output_blocks.apply(convert_module_to_f16)
727
+
728
+ def convert_to_fp32(self):
729
+ """
730
+ Convert the torso of the model to float32.
731
+ """
732
+ self.input_blocks.apply(convert_module_to_f32)
733
+ self.middle_block.apply(convert_module_to_f32)
734
+ self.output_blocks.apply(convert_module_to_f32)
735
+
736
+ def forward(self, x, timesteps, y=None):
737
+ """
738
+ Apply the model to an input batch.
739
+
740
+ :param x: an [N x C x ...] Tensor of inputs.
741
+ :param timesteps: a 1-D batch of timesteps.
742
+ :param y: an [N] Tensor of labels, if class-conditional.
743
+ :return: an [N x C x ...] Tensor of outputs.
744
+ """
745
+ assert (y is not None) == (
746
+ self.num_classes is not None
747
+ ), "must specify y if and only if the model is class-conditional"
748
+
749
+ hs = []
750
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
751
+
752
+ if self.num_classes is not None:
753
+ assert y.shape == (x.shape[0],)
754
+ emb = emb + self.label_emb(y)
755
+
756
+ h = x.type(self.dtype)
757
+ for module in self.input_blocks:
758
+ h = module(h, emb)
759
+ hs.append(h)
760
+ h = self.middle_block(h, emb)
761
+ for module in self.output_blocks:
762
+ h = th.cat([h, hs.pop()], dim=1)
763
+ h = module(h, emb)
764
+ h = h.type(x.dtype)
765
+ return self.out(h)
766
+
767
+
768
+ class SuperResModel(UNetModel):
769
+ """
770
+ A UNetModel that performs super-resolution.
771
+
772
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
773
+ """
774
+
775
+ def __init__(self, image_size, in_channels, *args, **kwargs):
776
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
777
+
778
+ def forward(self, x, timesteps, low_res=None, **kwargs):
779
+ _, _, new_height, new_width = x.shape
780
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
781
+ x = th.cat([x, upsampled], dim=1)
782
+ return super().forward(x, timesteps, **kwargs)
783
+
784
+
785
+ class EncoderUNetModel(nn.Module):
786
+ """
787
+ The half UNet model with attention and timestep embedding.
788
+
789
+ For usage, see UNet.
790
+ """
791
+
792
+ def __init__(
793
+ self,
794
+ image_size,
795
+ in_channels,
796
+ model_channels,
797
+ out_channels,
798
+ num_res_blocks,
799
+ attention_resolutions,
800
+ dropout=0,
801
+ channel_mult=(1, 2, 4, 8),
802
+ conv_resample=True,
803
+ dims=2,
804
+ use_checkpoint=False,
805
+ use_fp16=False,
806
+ num_heads=1,
807
+ num_head_channels=-1,
808
+ num_heads_upsample=-1,
809
+ use_scale_shift_norm=False,
810
+ resblock_updown=False,
811
+ use_new_attention_order=False,
812
+ pool="adaptive",
813
+ ):
814
+ super().__init__()
815
+
816
+ if num_heads_upsample == -1:
817
+ num_heads_upsample = num_heads
818
+
819
+ self.in_channels = in_channels
820
+ self.model_channels = model_channels
821
+ self.out_channels = out_channels
822
+ self.num_res_blocks = num_res_blocks
823
+ self.attention_resolutions = attention_resolutions
824
+ self.dropout = dropout
825
+ self.channel_mult = channel_mult
826
+ self.conv_resample = conv_resample
827
+ self.use_checkpoint = use_checkpoint
828
+ self.dtype = th.float16 if use_fp16 else th.float32
829
+ self.num_heads = num_heads
830
+ self.num_head_channels = num_head_channels
831
+ self.num_heads_upsample = num_heads_upsample
832
+
833
+ time_embed_dim = model_channels * 4
834
+ self.time_embed = nn.Sequential(
835
+ linear(model_channels, time_embed_dim),
836
+ nn.SiLU(),
837
+ linear(time_embed_dim, time_embed_dim),
838
+ )
839
+
840
+ ch = int(channel_mult[0] * model_channels)
841
+ self.input_blocks = nn.ModuleList(
842
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
843
+ )
844
+ self._feature_size = ch
845
+ input_block_chans = [ch]
846
+ ds = 1
847
+ for level, mult in enumerate(channel_mult):
848
+ for _ in range(num_res_blocks[level]):
849
+ layers = [
850
+ ResBlock(
851
+ ch,
852
+ time_embed_dim,
853
+ dropout,
854
+ out_channels=int(mult * model_channels),
855
+ dims=dims,
856
+ use_checkpoint=use_checkpoint,
857
+ use_scale_shift_norm=use_scale_shift_norm,
858
+ )
859
+ ]
860
+ ch = int(mult * model_channels)
861
+ if ds in attention_resolutions:
862
+ layers.append(
863
+ AttentionBlock(
864
+ ch,
865
+ use_checkpoint=use_checkpoint,
866
+ num_heads=num_heads,
867
+ num_head_channels=num_head_channels,
868
+ use_new_attention_order=use_new_attention_order,
869
+ )
870
+ )
871
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
872
+ self._feature_size += ch
873
+ input_block_chans.append(ch)
874
+ if level != len(channel_mult) - 1:
875
+ out_ch = ch
876
+ self.input_blocks.append(
877
+ TimestepEmbedSequential(
878
+ ResBlock(
879
+ ch,
880
+ time_embed_dim,
881
+ dropout,
882
+ out_channels=out_ch,
883
+ dims=dims,
884
+ use_checkpoint=use_checkpoint,
885
+ use_scale_shift_norm=use_scale_shift_norm,
886
+ down=True,
887
+ )
888
+ if resblock_updown
889
+ else Downsample(
890
+ ch, conv_resample, dims=dims, out_channels=out_ch
891
+ )
892
+ )
893
+ )
894
+ ch = out_ch
895
+ input_block_chans.append(ch)
896
+ ds *= 2
897
+ self._feature_size += ch
898
+
899
+ self.middle_block = TimestepEmbedSequential(
900
+ ResBlock(
901
+ ch,
902
+ time_embed_dim,
903
+ dropout,
904
+ dims=dims,
905
+ use_checkpoint=use_checkpoint,
906
+ use_scale_shift_norm=use_scale_shift_norm,
907
+ ),
908
+ AttentionBlock(
909
+ ch,
910
+ use_checkpoint=use_checkpoint,
911
+ num_heads=num_heads,
912
+ num_head_channels=num_head_channels,
913
+ use_new_attention_order=use_new_attention_order,
914
+ ),
915
+ ResBlock(
916
+ ch,
917
+ time_embed_dim,
918
+ dropout,
919
+ dims=dims,
920
+ use_checkpoint=use_checkpoint,
921
+ use_scale_shift_norm=use_scale_shift_norm,
922
+ ),
923
+ )
924
+ self._feature_size += ch
925
+ self.pool = pool
926
+ if pool == "adaptive":
927
+ self.out = nn.Sequential(
928
+ normalization(ch),
929
+ nn.SiLU(),
930
+ nn.AdaptiveAvgPool2d((1, 1)),
931
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
932
+ nn.Flatten(),
933
+ )
934
+ elif pool == "attention":
935
+ assert num_head_channels != -1
936
+ self.out = nn.Sequential(
937
+ normalization(ch),
938
+ nn.SiLU(),
939
+ AttentionPool2d(
940
+ (image_size // ds), ch, num_head_channels, out_channels
941
+ ),
942
+ )
943
+ elif pool == "spatial":
944
+ self.out = nn.Sequential(
945
+ nn.Linear(self._feature_size, 2048),
946
+ nn.ReLU(),
947
+ nn.Linear(2048, self.out_channels),
948
+ )
949
+ elif pool == "spatial_v2":
950
+ self.out = nn.Sequential(
951
+ nn.Linear(self._feature_size, 2048),
952
+ normalization(2048),
953
+ nn.SiLU(),
954
+ nn.Linear(2048, self.out_channels),
955
+ )
956
+ else:
957
+ raise NotImplementedError(f"Unexpected {pool} pooling")
958
+
959
+ def convert_to_fp16(self):
960
+ """
961
+ Convert the torso of the model to float16.
962
+ """
963
+ self.input_blocks.apply(convert_module_to_f16)
964
+ self.middle_block.apply(convert_module_to_f16)
965
+
966
+ def convert_to_fp32(self):
967
+ """
968
+ Convert the torso of the model to float32.
969
+ """
970
+ self.input_blocks.apply(convert_module_to_f32)
971
+ self.middle_block.apply(convert_module_to_f32)
972
+
973
+ def forward(self, x, timesteps):
974
+ """
975
+ Apply the model to an input batch.
976
+
977
+ :param x: an [N x C x ...] Tensor of inputs.
978
+ :param timesteps: a 1-D batch of timesteps.
979
+ :return: an [N x K] Tensor of outputs.
980
+ """
981
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
982
+
983
+ results = []
984
+ h = x.type(self.dtype)
985
+ for module in self.input_blocks:
986
+ h = module(h, emb)
987
+ if self.pool.startswith("spatial"):
988
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
989
+ h = self.middle_block(h, emb)
990
+ if self.pool.startswith("spatial"):
991
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
992
+ h = th.cat(results, axis=-1)
993
+ return self.out(h)
994
+ else:
995
+ h = h.type(x.dtype)
996
+ return self.out(h)
997
+
998
+
999
+ class NLayerDiscriminator(nn.Module):
1000
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
1001
+ super(NLayerDiscriminator, self).__init__()
1002
+ if type(norm_layer) == functools.partial:
1003
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1004
+ else:
1005
+ use_bias = norm_layer == nn.InstanceNorm2d
1006
+
1007
+ kw = 4
1008
+ padw = 1
1009
+ sequence = [
1010
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
1011
+ nn.LeakyReLU(0.2, True)
1012
+ ]
1013
+
1014
+ nf_mult = 1
1015
+ nf_mult_prev = 1
1016
+ for n in range(1, n_layers):
1017
+ nf_mult_prev = nf_mult
1018
+ nf_mult = min(2**n, 8)
1019
+ sequence += [
1020
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1021
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1022
+ norm_layer(ndf * nf_mult),
1023
+ nn.LeakyReLU(0.2, True)
1024
+ ]
1025
+
1026
+ nf_mult_prev = nf_mult
1027
+ nf_mult = min(2**n_layers, 8)
1028
+ sequence += [
1029
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1030
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1031
+ norm_layer(ndf * nf_mult),
1032
+ nn.LeakyReLU(0.2, True)
1033
+ ]
1034
+
1035
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=2, padding=padw)] + [nn.Dropout(0.5)]
1036
+ if use_sigmoid:
1037
+ sequence += [nn.Sigmoid()]
1038
+
1039
+ self.model = nn.Sequential(*sequence)
1040
+
1041
+ def forward(self, input):
1042
+ return self.model(input)
1043
+
1044
+
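NLayerDiscriminator above is a PatchGAN-style critic: instead of one real/fake score per image it emits a grid of per-patch logits. A minimal sketch of how it could be exercised, assuming it stays importable from guided_diffusion.unet as laid out in this commit:

import torch
import torch.nn as nn
from guided_diffusion.unet import NLayerDiscriminator  # import path per this repo's layout

netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
x = torch.randn(1, 3, 512, 512)      # placeholder RGB batch
patch_logits = netD(x)               # one logit per receptive-field patch
print(patch_logits.shape)            # roughly [1, 1, 16, 16] for a 512x512 input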
1045
+ class GANLoss(nn.Module):
1046
+ """Define different GAN objectives.
1047
+
1048
+ The GANLoss class abstracts away the need to create the target label tensor
1049
+ that has the same size as the input.
1050
+ """
1051
+
1052
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
1053
+ """ Initialize the GANLoss class.
1054
+
1055
+ Parameters:
1056
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
1057
+ target_real_label (float) - - label for a real image
1058
+ target_fake_label (float) - - label for a fake image
1059
+
1060
+ Note: Do not use sigmoid as the last layer of Discriminator.
1061
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
1062
+ """
1063
+ super(GANLoss, self).__init__()
1064
+ self.register_buffer('real_label', th.tensor(target_real_label))
1065
+ self.register_buffer('fake_label', th.tensor(target_fake_label))
1066
+ self.gan_mode = gan_mode
1067
+ if gan_mode == 'lsgan':
1068
+ self.loss = nn.MSELoss()
1069
+ elif gan_mode == 'vanilla':
1070
+ self.loss = nn.BCEWithLogitsLoss()
1071
+ elif gan_mode in ['wgangp']:
1072
+ self.loss = None
1073
+ else:
1074
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
1075
+
1076
+ def get_target_tensor(self, prediction, target_is_real):
1077
+ """Create label tensors with the same size as the input.
1078
+
1079
+ Parameters:
1080
+ prediction (tensor) - - typically the prediction from a discriminator
1081
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1082
+
1083
+ Returns:
1084
+ A label tensor filled with ground truth label, and with the size of the input
1085
+ """
1086
+
1087
+ if target_is_real:
1088
+ target_tensor = self.real_label
1089
+ else:
1090
+ target_tensor = self.fake_label
1091
+ return target_tensor.expand_as(prediction)
1092
+
1093
+ def __call__(self, prediction, target_is_real):
1094
+ """Calculate loss given Discriminator's output and grount truth labels.
1095
+
1096
+ Parameters:
1097
+ prediction (tensor) - - typically the prediction output from a discriminator
1098
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1099
+
1100
+ Returns:
1101
+ the calculated loss.
1102
+ """
1103
+ if self.gan_mode in ['lsgan', 'vanilla']:
1104
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
1105
+ loss = self.loss(prediction, target_tensor)
1106
+ elif self.gan_mode == 'wgangp':
1107
+ if target_is_real:
1108
+ loss = -prediction.mean()
1109
+ else:
1110
+ loss = prediction.mean()
1111
+ return loss
1112
+
1113
+
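GANLoss hides the target-tensor bookkeeping: the caller passes raw discriminator predictions plus a boolean, and the class builds a matching label tensor (or flips signs for WGAN-GP). A small usage sketch under the same import assumption as above:

import torch
from guided_diffusion.unet import GANLoss

criterion = GANLoss('lsgan')                 # MSE against 1.0 / 0.0 targets
pred_real = torch.randn(4, 1, 16, 16)        # stand-in discriminator outputs
pred_fake = torch.randn(4, 1, 16, 16)

loss_d = criterion(pred_real, True) + criterion(pred_fake, False)   # critic update
loss_g = criterion(pred_fake, True)          # generator wants fakes judged real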
1114
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
1115
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
1116
+
1117
+ Arguments:
1118
+ netD (network) -- discriminator network
1119
+ real_data (tensor array) -- real images
1120
+ fake_data (tensor array) -- generated images from the generator
1121
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
1122
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
1123
+ constant (float) -- the constant used in formula (||gradient||_2 - constant)^2
1124
+ lambda_gp (float) -- weight for this loss
1125
+
1126
+ Returns the gradient penalty loss
1127
+ """
1128
+ if lambda_gp > 0.0:
1129
+ if type == 'real': # either use real images, fake images, or a linear interpolation of the two.
1130
+ interpolatesv = real_data
1131
+ elif type == 'fake':
1132
+ interpolatesv = fake_data
1133
+ elif type == 'mixed':
1134
+ alpha = th.rand(real_data.shape[0], 1, device=device)
1135
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
1136
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
1137
+ else:
1138
+ raise NotImplementedError('{} not implemented'.format(type))
1139
+ interpolatesv.requires_grad_(True)
1140
+ disc_interpolates = netD(interpolatesv)
1141
+ gradients = th.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
1142
+ grad_outputs=th.ones(disc_interpolates.size()).to(device),
1143
+ create_graph=True, retain_graph=True, only_inputs=True)
1144
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
1145
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
1146
+ return gradient_penalty, gradients
1147
+ else:
1148
+ return 0.0, None
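cal_gradient_penalty returns both the WGAN-GP penalty and the interpolate gradients, so a critic step only has to add the penalty to its loss before backprop. A hedged sketch of that wiring (the real/fake batches below are placeholders, not data from this repo):

import torch
from guided_diffusion.unet import NLayerDiscriminator, cal_gradient_penalty

device = 'cuda' if torch.cuda.is_available() else 'cpu'
netD = NLayerDiscriminator(input_nc=3).to(device)
real = torch.randn(2, 3, 128, 128, device=device)   # placeholder real batch
fake = torch.randn(2, 3, 128, 128, device=device)   # placeholder generator output

loss_d = netD(fake).mean() - netD(real).mean()       # WGAN critic loss
gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed', lambda_gp=10.0)
(loss_d + gp).backward()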
latent_DDCM_CCFG.py ADDED
@@ -0,0 +1,45 @@
1
+ import spaces
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+
6
+ from util.file import generate_binary_file, load_numpy_from_binary_bitwise
7
+ from latent_utils import generate_ours
8
+
9
+
10
+ @torch.no_grad()
11
+ @spaces.GPU(duration=80)
12
+ def main(prompt, T, K, K_tilde, model_type='512x512', bitstream=None, avail_models=None,
13
+ progress=gr.Progress(track_tqdm=True)):
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ indices = load_numpy_from_binary_bitwise(bitstream, K, T, model_type, T - 1)
17
+ if indices is not None:
18
+ indices = indices.to(device)
19
+
20
+ # model, _ = load_model(img_size_to_id[img_size], T, device, float16=True, compile=False)
21
+ model = avail_models[model_type].to(device)
22
+
23
+ model.device = device
24
+ model.model.to(device=device)
25
+
26
+ model.model.scheduler.device = device
27
+
28
+ model.set_timesteps(T, device=device)
29
+
30
+ with torch.no_grad():
31
+ x, indices = generate_ours(model,
32
+ num_noises=K,
33
+ num_noises_to_optimize=K_tilde,
34
+ prompt=prompt,
35
+ negative_prompt=None,
36
+ indices=indices)
37
+ x = (x / 2 + 0.5).clamp(0, 1)
38
+ x = x.detach().cpu().squeeze().numpy()
39
+ x = np.transpose(x, (1, 2, 0))
40
+ torch.cuda.empty_cache()
41
+
42
+ if bitstream is None:
43
+ indices = generate_binary_file(indices.numpy(), K, T, model_type)
44
+ return x, indices
45
+ return x
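At every denoising step except the last, generate_ours stores a single index into a codebook of K Gaussian noises, so an image costs about (T - 1) * ceil(log2(K)) bits. A tiny arithmetic sketch of that bookkeeping (the 512x512 figure is only an example resolution):

import math

def ddcm_bits_per_pixel(T: int, K: int, image_size: int = 512) -> float:
    """Approximate bits-per-pixel of a DDCM bitstream with T timesteps and codebook size K."""
    bits = (T - 1) * math.ceil(math.log2(K))   # one codebook index per stored step
    return bits / (image_size * image_size)

print(ddcm_bits_per_pixel(T=1000, K=64))   # ~0.023 bpp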
latent_DDCM_compression.py ADDED
@@ -0,0 +1,47 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import spaces
4
+ import torch
5
+ import torchvision
6
+
7
+ from latent_utils import compress
8
+ from util.file import generate_binary_file, load_numpy_from_binary_bitwise
9
+ from util.img_utils import resize_and_crop
10
+
11
+
12
+ @torch.no_grad()
13
+ @spaces.GPU(duration=80)
14
+ def main(img_to_compress, T, K, model_type='512x512', bitstream=None, avail_models=None,
15
+ progress=gr.Progress(track_tqdm=True)):
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ indices = load_numpy_from_binary_bitwise(bitstream, K, T, model_type, T - 1)
18
+ if indices is not None:
19
+ indices = indices.to(device)
20
+ if indices is None:
21
+ img_to_compress = resize_and_crop(img_to_compress, int(model_type.split('x')[0]))
22
+ img_to_compress = (torchvision.transforms.ToTensor()(img_to_compress) * 2) - 1
23
+ img_to_compress = img_to_compress.unsqueeze(0).to(device)
24
+ else:
25
+ img_to_compress = None
26
+ print(T, K, model_type)
27
+ # model, _ = load_model(img_size_to_id[img_size], T, device, float16=True, compile=False)
28
+ model = avail_models[model_type].to(device)
29
+
30
+ model.device = device
31
+ model.model.to(device=device)
32
+
33
+ model.model.scheduler.device = device
34
+ # model.model.scheduler.scheduler = model.model.scheduler.scheduler.to(device)
35
+
36
+ model.set_timesteps(T, device=device)
37
+ model.num_timesteps = T
38
+ with torch.no_grad():
39
+ x, indices = compress(model, img_to_compress, K, indices, device=device)
40
+ x = (x / 2 + 0.5).clamp(0, 1)
41
+ x = x.detach().cpu().squeeze().numpy()
42
+ x = np.transpose(x, (1, 2, 0))
43
+ torch.cuda.empty_cache()
44
+ indices = generate_binary_file(indices.numpy(), K, T, model_type)
45
+ if bitstream is None:
46
+ return x, indices
47
+ return x
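main() above expects the input already resized, center-cropped to the model resolution and scaled to [-1, 1] before encoding. The same preprocessing in isolation, assuming one of the bundled example images:

import torchvision
from PIL import Image
from util.img_utils import resize_and_crop

img = Image.open('examples/compression/1.jpg').convert('RGB')
img = resize_and_crop(img, 512)                        # shorter side to 512, then center crop
x = torchvision.transforms.ToTensor()(img) * 2 - 1     # [0, 1] -> [-1, 1]
x = x.unsqueeze(0)                                     # batch dim: [1, 3, 512, 512]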
latent_models.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+ from diffusers import DDIMScheduler, StableDiffusionPipeline
3
+ from typing import Optional, Tuple, Union
4
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
5
+
6
+
7
+ class PipelineWrapper(torch.nn.Module):
8
+ def __init__(self, model_id: str,
9
+ timesteps: int,
10
+ device: torch.device,
11
+ float16: bool = False,
12
+ compile: bool = True,
13
+ token: Optional[str] = None, *args, **kwargs) -> None:
14
+ super().__init__(*args, **kwargs)
15
+ self.model_id = model_id
16
+ self.num_timesteps = timesteps
17
+ self.device = device
18
+ self.float16 = float16
19
+ self.token = token
20
+ self.compile = compile
21
+ self.model = None
22
+
23
+ # def get_sigma(self, timestep: int) -> float:
24
+ # sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
25
+ # return sqrt_recipm1_alphas_cumprod[timestep]
26
+
27
+ @property
28
+ def timesteps(self) -> torch.Tensor:
29
+ return self.model.scheduler.timesteps
30
+
31
+ @property
32
+ def dtype(self) -> torch.dtype:
33
+ if self.model is None:
34
+ raise AttributeError("Model is not initialized.")
35
+ return self.model.unet.dtype
36
+
37
+ def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
38
+ return self.model.scheduler.get_x_0_hat(xt, epst, timestep)
39
+
40
+ def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
41
+ timestep: torch.Tensor, variance_noise: torch.Tensor,
42
+ **kwargs) -> torch.Tensor:
43
+ return self.model.scheduler.finish_step(xt, pred_x0, epst, timestep, variance_noise, **kwargs)
44
+
45
+ def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
46
+ return self.model.scheduler.get_variance(timestep)
47
+
48
+ def set_timesteps(self, timesteps: int, device: torch.device) -> None:
49
+ self.model.scheduler.set_timesteps(timesteps, device=device)
50
+
51
+ def encode_image(self, x: torch.Tensor) -> torch.Tensor:
52
+ pass
53
+
54
+ def decode_image(self, x: torch.Tensor) -> torch.Tensor:
55
+ pass
56
+
57
+ def encode_prompt(self, prompt: torch.Tensor, negative_prompt=None) -> Tuple[torch.Tensor, torch.Tensor]:
58
+ pass
59
+
60
+ def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
61
+ guidance_scale: Optional[float] = None, **kwargs) -> torch.Tensor:
62
+ pass
63
+
64
+ def get_image_size(self) -> Tuple[int, int]:
65
+ return self.model.unet.config.sample_size * self.model.vae_scale_factor
66
+
67
+ def get_noise_shape(self, imsize: Union[int, Tuple[int]], batch_size: int) -> Tuple[int, ...]:
68
+ if isinstance(imsize, int):
69
+ imsize = (imsize, imsize)
70
+ variance_noise_shape = (batch_size,
71
+ self.model.unet.config.in_channels,
72
+ imsize[-2],
73
+ imsize[-1])
74
+ return variance_noise_shape
75
+
76
+ def get_latent_shape(self, orig_image_shape: Union[int, Tuple[int, int]]) -> Tuple[int, ...]:
77
+ if isinstance(orig_image_shape, int):
78
+ orig_image_shape = (orig_image_shape, orig_image_shape)
79
+ return (self.model.unet.config.in_channels,
80
+ orig_image_shape[0] // self.model.vae_scale_factor,
81
+ orig_image_shape[1] // self.model.vae_scale_factor)
82
+
83
+ def get_pre_kwargs(self, **kwargs) -> dict:
84
+ return {}
85
+
86
+
87
+ class StableDiffWrapper(PipelineWrapper):
88
+ def __init__(self, scheduler='ddpm', *args, **kwargs) -> None:
89
+ super().__init__(*args, **kwargs)
90
+ self.scheduler_type = scheduler
91
+ try:
92
+ self.model = StableDiffusionPipeline.from_pretrained(
93
+ self.model_id,
94
+ torch_dtype=torch.float16 if self.float16 else torch.float32,
95
+ token=self.token).to(self.device)
96
+ except OSError:
97
+ self.model = StableDiffusionPipeline.from_pretrained(
98
+ self.model_id,
99
+ torch_dtype=torch.float16 if self.float16 else torch.float32,
100
+ token=self.token, force_download=True
101
+ ).to(self.device)
102
+
103
+ if scheduler == 'ddpm' or 'ddim' in scheduler:
104
+ eta = 1.0 if 'ddpm' in scheduler else float(scheduler.split('-')[1])
105
+ self.model.scheduler = DDIMWrapper(model_id=self.model_id, device=self.device,
106
+ eta=eta,
107
+ float16=self.float16, token=self.token)
108
+
109
+ self.model.scheduler.set_timesteps(self.num_timesteps, device=self.device)
110
+ if self.compile:
111
+ try:
112
+ self.model.unet = torch.compile(self.model.unet, mode="reduce-overhead", fullgraph=True)
113
+ except Exception as e:
114
+ print(f"Error compiling model: {e}")
115
+
116
+ def encode_image(self, x: torch.Tensor) -> torch.Tensor:
117
+ return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor) # .float()
118
+
119
+ def decode_image(self, x: torch.Tensor) -> torch.Tensor:
120
+ if x.device != self.device:
121
+ orig_device = self.model.vae.device
122
+ self.model.vae.to(x.device)
123
+ ret = self.model.vae.decode(x / self.model.vae.config.scaling_factor).sample.clamp(-1, 1)
124
+ self.model.vae.to(orig_device)
125
+ return ret
126
+ return self.model.vae.decode(x / self.model.vae.config.scaling_factor).sample.clamp(-1, 1)
127
+
128
+ def encode_prompt(self, prompt: torch.Tensor, negative_prompt=None) -> Tuple[torch.Tensor, torch.Tensor]:
129
+ do_cfg = (negative_prompt is not None) or prompt != ""
130
+
131
+ prompt_embeds, negative_prompt_embeds = self.model.encode_prompt(
132
+ prompt, self.device, 1,
133
+ do_cfg,
134
+ negative_prompt,
135
+ )
136
+
137
+ if do_cfg:
138
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
139
+ return prompt_embeds
140
+
141
+ def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
142
+ guidance_scale: Optional[float] = None, return_everything=False, **kwargs):
143
+ do_cfg = prompt_embeds.shape[0] > 1
144
+ xt = torch.cat([xt] * 2) if do_cfg else xt
145
+
146
+ # predict the noise residual
147
+ noise_pred = self.model.unet(xt, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
148
+
149
+ # perform guidance
150
+ if do_cfg:
151
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
152
+ return None, noise_pred_uncond, noise_pred_text
153
+ return None, noise_pred, None
154
+
155
+
156
+ class SchedulerWrapper(object):
157
+ def __init__(self, model_id: str, device: torch.device,
158
+ float16: bool = False, token: Optional[str] = None, *args, **kwargs) -> None:
159
+ super().__init__(*args, **kwargs)
160
+ self.model_id = model_id
161
+ self.device = device
162
+ self.float16 = float16
163
+ self.token = token
164
+ self.scheduler = None
165
+
166
+ @property
167
+ def timesteps(self) -> torch.Tensor:
168
+ return self.scheduler.timesteps
169
+
170
+ def set_timesteps(self, timesteps: int, device: torch.device) -> None:
171
+ self.scheduler.set_timesteps(timesteps, device=device)
172
+ if self.scheduler.timesteps[0] == 1000:
173
+ self.scheduler.timesteps -= 1
174
+
175
+ def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
176
+ pass
177
+
178
+ def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
179
+ timestep: torch.Tensor, variance_noise: torch.Tensor,
180
+ **kwargs) -> torch.Tensor:
181
+ pass
182
+
183
+ def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
184
+ pass
185
+
186
+
187
+ class DDIMWrapper(SchedulerWrapper):
188
+ def __init__(self, eta, *args, **kwargs) -> None:
189
+ super().__init__(*args, **kwargs)
190
+ self.scheduler = DDIMScheduler.from_pretrained(
191
+ self.model_id, subfolder="scheduler",
192
+ torch_dtype=torch.float16 if self.float16 else torch.float32,
193
+ token=self.token,
194
+ device=self.device, timestep_spacing='linspace')
195
+ self.eta = eta
196
+
197
+ def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
198
+ # compute alphas, betas
199
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
200
+ beta_prod_t = 1 - alpha_prod_t
201
+ # compute predicted original sample from predicted noise also called
202
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
203
+ if self.scheduler.config.prediction_type == 'epsilon':
204
+ pred_original_sample = (xt - beta_prod_t ** (0.5) * epst) / alpha_prod_t ** (0.5)
205
+ elif self.scheduler.config.prediction_type == 'v_prediction':
206
+ pred_original_sample = (alpha_prod_t ** 0.5) * xt - (beta_prod_t ** 0.5) * epst
207
+
208
+ return pred_original_sample
209
+
210
+ def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
211
+ timestep: torch.Tensor, variance_noise: torch.Tensor,
212
+ eta=None) -> torch.Tensor:
213
+ if eta is None:
214
+ eta = self.eta
215
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // \
216
+ self.scheduler.num_inference_steps
217
+ # 2. compute alphas, betas
218
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
219
+ alpha_prod_t_prev = self._get_alpha_prod_t_prev(prev_timestep)
220
+ beta_prod_t = 1 - alpha_prod_t
221
+
222
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
223
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
224
+ variance = self.get_variance(timestep)
225
+ std_dev_t = eta * variance ** (0.5)
226
+
227
+ # std_dev_t = eta * variance ** (0.5)
228
+ # Take care of asymmetric reverse process (asyrp)
229
+ if self.scheduler.config.prediction_type == 'epsilon':
230
+ model_output_direction = epst
231
+ elif self.scheduler.config.prediction_type == 'v_prediction':
232
+ model_output_direction = (alpha_prod_t**0.5) * epst + (beta_prod_t**0.5) * xt
233
+
234
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
235
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
236
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
237
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_x0 + pred_sample_direction
238
+
239
+ # 8. Add noise if eta > 0
240
+ if eta > 0:
241
+ sigma_z = std_dev_t * variance_noise
242
+ prev_sample = prev_sample + sigma_z
243
+
244
+ return prev_sample
245
+
246
+ def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
247
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // \
248
+ self.scheduler.num_inference_steps
249
+ variance = self.scheduler._get_variance(timestep, prev_timestep)
250
+ return variance
251
+
252
+ def _get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
253
+ return self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
254
+ else self.scheduler.final_alpha_cumprod
255
+
256
+ def load_model(model_id: str, timesteps: int,
257
+ device: torch.device, blip: bool = False,
258
+ float16: bool = False, token: Optional[str] = None,
259
+ compile: bool = True,
260
+ blip_model="Salesforce/blip2-opt-2.7b-coco", scheduler: str = 'ddpm') -> PipelineWrapper:
261
+ pipeline = StableDiffWrapper(model_id=model_id, timesteps=timesteps, device=device,
262
+ scheduler=scheduler,
263
+ float16=float16, token=token, compile=compile)
264
+
265
+ pipeline = pipeline.to(device)
266
+ if blip:
267
+ pipeline.blip_processor = Blip2Processor.from_pretrained(blip_model)
268
+ try:
269
+ print(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu')
270
+ pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained(
271
+ blip_model,).to(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu')
272
+ except OSError:
273
+ pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained(
274
+ blip_model, force_download=True).to(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu')
275
+ pipeline.blip_max_words = 32
276
+
277
+ image_size = pipeline.get_image_size()
278
+ return pipeline, image_size
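load_model is the single entry point the rest of the code uses: it builds a StableDiffWrapper with a DDPM/DDIM scheduler (and optionally a BLIP-2 captioner) and returns the pipeline plus its native image size. A hedged usage sketch; the model id below is only illustrative, not necessarily the checkpoint the demo ships with:

import torch
from latent_models import load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipeline, image_size = load_model(
    'stabilityai/stable-diffusion-2-1-base',   # illustrative Hugging Face model id
    timesteps=1000, device=device,
    float16=True, compile=False, scheduler='ddpm')
print(image_size)                               # sample_size * vae_scale_factor, e.g. 512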
latent_utils.py ADDED
@@ -0,0 +1,322 @@
1
+ import os
2
+ import spaces
3
+ import time
4
+ from glob import glob
5
+ from typing import Callable, Optional, Tuple, Union, Dict
6
+ import random
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ from PIL import Image
12
+ from torch.utils.data import DataLoader
13
+ from torchvision.datasets import VisionDataset
14
+ from tqdm import tqdm
15
+ from util.img_utils import clear_color
16
+
17
+ from latent_models import PipelineWrapper
18
+
19
+
20
+ def set_seed(seed: int) -> None:
21
+ torch.manual_seed(seed)
22
+ np.random.seed(seed)
23
+ random.seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ # torch.backends.cudnn.deterministic = True
26
+ # torch.backends.cudnn.benchmark = False
27
+
28
+
29
+ class MinusOneToOne(torch.nn.Module):
30
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
31
+ return tensor * 2 - 1
32
+
33
+
34
+ class ResizePIL(torch.nn.Module):
35
+ def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
36
+ super().__init__()
37
+ if isinstance(image_size, int):
38
+ image_size = (image_size, image_size)
39
+ self.image_size = image_size
40
+
41
+ def forward(self, pil_image: Image.Image) -> Image.Image:
42
+ if self.image_size is not None:
43
+ pil_image = pil_image.resize(self.image_size)
44
+ return pil_image
45
+
46
+
47
+ def get_loader(datadir: str, batch_size: int = 1,
48
+ crop_to: Optional[Union[int, Tuple[int, int]]] = None,
49
+ include_path: bool = False) -> DataLoader:
50
+ transform = transforms.Compose([
51
+ ResizePIL(crop_to),
52
+ transforms.ToTensor(),
53
+ MinusOneToOne(),
54
+ ])
55
+ loader = DataLoader(FoldersDataset(datadir, transform, include_path=include_path),
56
+ batch_size=batch_size,
57
+ shuffle=True, num_workers=0, drop_last=False)
58
+ return loader
59
+
60
+
61
+ class FoldersDataset(VisionDataset):
62
+ def __init__(self, root: str, transforms: Optional[Callable] = None,
63
+ include_path: bool = False) -> None:
64
+ super().__init__(root, transforms)
65
+ self.include_path = include_path
66
+ self.root = root
67
+
68
+ if os.path.isdir(root):
69
+ self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
70
+ self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
71
+ self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
72
+ self.fpaths = sorted(self.fpaths)
73
+ assert len(self.fpaths) > 0, "File list is empty. Check the root."
74
+ elif os.path.exists(root):
75
+ self.fpaths = [root]
76
+ else:
77
+ raise FileNotFoundError(f"File not found: {root}")
78
+
79
+ def __len__(self):
80
+ return len(self.fpaths)
81
+
82
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, str]:
83
+ fpath = self.fpaths[index]
84
+ img = Image.open(fpath).convert('RGB')
85
+
86
+ if self.transforms is not None:
87
+ img = self.transforms(img)
88
+
89
+ path = ""
90
+ if self.include_path:
91
+ dirname = os.path.dirname(fpath)
92
+ # remove root from dirname
93
+ path = dirname[len(self.root) + 1:]
94
+ return img, os.path.basename(fpath).split(os.extsep)[0], path
95
+
96
+
97
+ @spaces.GPU
98
+ def compress(model: PipelineWrapper,
99
+ img_to_compress: torch.Tensor,
100
+ num_noises: int,
101
+ loaded_indices,
102
+ device,
103
+ ):
104
+ # model.set_timesteps(model.num_timesteps, device=device)
105
+ dtype = model.dtype
106
+
107
+ prompt_embeds = model.encode_prompt("", None)
108
+
109
+ set_seed(88888888)
110
+ if img_to_compress is None:
111
+ img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(), device=device)
112
+ enc_im = model.encode_image(img_to_compress.to(dtype))
113
+ kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2], width=img_to_compress.shape[-1],
114
+ prompt_embeds=prompt_embeds)
115
+
116
+ set_seed(100000)
117
+ xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)
118
+
119
+ result_noise_indices = []
120
+
121
+ pbar = tqdm(model.timesteps)
122
+ for idx, t in enumerate(pbar):
123
+ set_seed(idx)
124
+ noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)
125
+
126
+ _, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
127
+ x_0_hat = model.get_x_0_hat(xt, epst, t)
128
+ if loaded_indices is None:
129
+
130
+ if t >= 1:
131
+ dot_prod = torch.matmul(noise.view(noise.shape[0], -1),
132
+ (enc_im - x_0_hat).view(enc_im.shape[0], -1).transpose(0, 1))
133
+ best_idx = torch.argmax(dot_prod)
134
+ best_noise = noise[best_idx]
135
+ else:
136
+ best_noise = noise[0]
137
+ else:
138
+ if t >= 1:
139
+ best_idx = loaded_indices[idx]
140
+ best_noise = noise[best_idx]
141
+ else:
142
+ best_noise = noise[0]
143
+ if t >= 1:
144
+ result_noise_indices.append(best_idx)
145
+
146
+ xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)
147
+
148
+ try:
149
+ img = model.decode_image(xt)
150
+ except torch.OutOfMemoryError:
151
+ img = model.decode_image(xt.to('cpu'))
152
+
153
+ return img, torch.tensor(result_noise_indices).squeeze().cpu()
154
+
155
+
156
+ @spaces.GPU
157
+ def generate_ours(model: PipelineWrapper,
158
+ num_noises: int,
159
+ num_noises_to_optimize: int,
160
+ prompt: str = "",
161
+ negative_prompt: Optional[str] = None,
162
+ indices = None,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ device = model.device
165
+ dtype = model.dtype
166
+ # print(num_noises, num_noises_to_optimize, flush=True)
167
+ # model.set_timesteps(model.num_timesteps, device=device)
168
+
169
+ set_seed(88888888)
170
+ if prompt is None:
171
+ prompt = ""
172
+ prompt_embeds = model.encode_prompt(prompt, negative_prompt)
173
+
174
+ kwargs = model.get_pre_kwargs(height=model.get_image_size(),
175
+ width=model.get_image_size(),
176
+ prompt_embeds=prompt_embeds)
177
+
178
+ set_seed(100000)
179
+ xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)
180
+
181
+ result_noise_indices = []
182
+ pbar = tqdm(model.timesteps)
183
+ for idx, t in enumerate(pbar):
184
+ set_seed(idx)
185
+ noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype) # Codebook
186
+
187
+ _, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0, return_everything=True, **kwargs)
188
+
189
+ x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)
190
+ if t >= 1:
191
+ if indices is None:
192
+ prev_classif_score = epst_uncond - epst_cond
193
+ set_seed(int(time.time_ns() & 0xFFFFFFFF))
194
+ noise_indices = torch.randint(0, num_noises, size=(num_noises_to_optimize,), device=device)
195
+ loss = torch.matmul(noise[noise_indices].view(num_noises_to_optimize, -1),
196
+ prev_classif_score.view(prev_classif_score.shape[0], -1).transpose(0, 1))
197
+ best_idx = noise_indices[torch.argmax(loss)]
198
+ else:
199
+ best_idx = indices[idx]
200
+ best_noise = noise[best_idx]
201
+ result_noise_indices.append(best_idx)
202
+
203
+ else:
204
+ best_noise = torch.zeros_like(noise[0])
205
+ xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)
206
+
207
+ try:
208
+ img = model.decode_image(xt)
209
+ except torch.OutOfMemoryError:
210
+ img = model.decode_image(xt.to('cpu'))
211
+ return img, torch.stack(result_noise_indices).squeeze().cpu()
212
+
213
+
214
+ def decompress(model: PipelineWrapper,
215
+ image_size: Tuple[int, int],
216
+ indices: Dict[str, torch.Tensor],
217
+ num_noises: int,
218
+ prompt: str = "",
219
+ negative_prompt: Optional[str] = None,
220
+ tedit: int = 0,
221
+ new_prompt: str = "",
222
+ new_negative_prompt: Optional[str] = None,
223
+ guidance_scale: float = 3.0,
224
+ num_pursuit_noises: Optional[int] = 1,
225
+ num_pursuit_coef_bits: Optional[int] = 3,
226
+ t_range: Tuple[int, int] = (999, 0),
227
+ robust_randn: bool = False
228
+ ) -> torch.Tensor:
229
+ noise_indices = indices['noise_indices']
230
+ coeffs_indices = indices['coeff_indices']
231
+ num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
232
+ num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1
233
+
234
+ device = model.device
235
+ dtype = model.dtype
236
+ # model.set_timesteps(model.num_timesteps, device=device)
237
+
238
+ set_seed(88888888)
239
+ orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
240
+ kwargs_orig = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
241
+ prompt_embeds=orig_prompt_embeds)
242
+ if new_prompt != prompt or new_negative_prompt != negative_prompt:
243
+ new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
244
+ kwargs_new = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
245
+ prompt_embeds=new_prompt_embeds)
246
+ else:
247
+ new_prompt_embeds = orig_prompt_embeds
248
+ kwargs_new = kwargs_orig
249
+
250
+ set_seed(100000)
251
+ xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)
252
+
253
+ pbar = tqdm(model.timesteps)
254
+ for idx, t in enumerate(pbar):
255
+ set_seed(idx)
256
+
257
+ dont_optimize_t = not (t_range[0] >= t >= t_range[1])
258
+ # No intermittent support
259
+
260
+ if robust_randn:
261
+ noise = get_robust_randn(num_noises if not dont_optimize_t else 1, xt.shape[1:], device, dtype)
262
+ else:
263
+ noise = torch.randn(num_noises if not dont_optimize_t else 1, *xt.shape[1:], device=device, dtype=dtype)
264
+
265
+ curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
266
+ curr_kwargs = kwargs_orig if idx < tedit else kwargs_new
267
+ epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
268
+ x_0_hat = model.get_x_0_hat(xt, epst, t)
269
+
270
+ curr_t_noise_indices = noise_indices[idx]
271
+ best_noise = noise[curr_t_noise_indices[0]]
272
+ pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
273
+ if num_pursuit_noises > 1:
274
+ curr_t_coeffs_indices = coeffs_indices[idx]
275
+ if curr_t_coeffs_indices[0] == -1:
276
+ continue
277
+ for pursuit_idx in range(1, num_pursuit_noises):
278
+ pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
279
+ best_noise = best_noise * torch.sqrt(pursuit_coef) + noise[
280
+ curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef)
281
+ best_noise /= best_noise.std()
282
+ best_noise = best_noise.unsqueeze(0)
283
+ xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)
284
+ img = model.decode_image(xt)
285
+ return img
286
+
287
+
288
+ def inf_generate(model: PipelineWrapper,
289
+ prompt: str = "",
290
+ negative_prompt: Optional[str] = None,
291
+ guidance_scale: float = 7.0,
292
+ record: int = 0,
293
+ save_root: str = "") -> Tuple[torch.Tensor, torch.Tensor]:
294
+ device = model.device
295
+ dtype = model.dtype
296
+
297
+ model.set_timesteps(model.num_timesteps, device=device)
298
+
299
+ prompt_embeds = model.encode_prompt(prompt, negative_prompt)
300
+ kwargs = model.get_pre_kwargs(height=model.get_image_size(),
301
+ width=model.get_image_size(),
302
+ prompt_embeds=prompt_embeds)
303
+
304
+ xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)
305
+ pbar = tqdm(model.timesteps)
306
+ for idx, t in enumerate(pbar):
307
+ noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)
308
+
309
+ epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
310
+ x_0_hat = model.get_x_0_hat(xt, epst, t)
311
+ xt = model.finish_step(xt, x_0_hat, epst, t, noise)
312
+
313
+ if record and not idx % record:
314
+ img = model.decode_image(x_0_hat)
315
+ plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(t.item()).zfill(4)}.png"),
316
+ clear_color(img[0].unsqueeze(0), normalize=False))
317
+ try:
318
+ img = model.decode_image(xt)
319
+ except torch.OutOfMemoryError:
320
+ img = model.decode_image(xt.to('cpu'))
321
+
322
+ return img
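The core of compress() is a per-step codebook search: among K candidate noises, keep the one best aligned with the residual direction enc_im - x_0_hat, and record only its index. The selection rule in isolation, with made-up latent shapes:

import torch

K = 64
latent_shape = (4, 64, 64)                        # e.g. an SD latent for a 512x512 image
noise = torch.randn(K, *latent_shape)             # codebook for this timestep (re-seeded identically at decode time)
residual = torch.randn(1, *latent_shape)          # stands in for enc_im - x_0_hat

scores = noise.view(K, -1) @ residual.view(1, -1).t()   # [K, 1] inner products
best_idx = torch.argmax(scores)                   # the only thing written to the bitstream
best_noise = noise[best_idx]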
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ numpy
2
+ opencv-python
3
+ scipy
4
+ tqdm
5
+ lmdb
6
+ pyyaml
7
+ yapf
8
+ dctorch
9
+ einops
10
+ timm
11
+ diffusers
12
+ facexlib
13
+ pyiqa
14
+ torch==2.4.0
15
+ torchvision==0.19.0
util/__init__.py ADDED
File without changes
util/basicsr_img_util.py ADDED
@@ -0,0 +1,172 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
10
+ """Numpy array to tensor.
11
+
12
+ Args:
13
+ imgs (list[ndarray] | ndarray): Input images.
14
+ bgr2rgb (bool): Whether to change bgr to rgb.
15
+ float32 (bool): Whether to change to float32.
16
+
17
+ Returns:
18
+ list[tensor] | tensor: Tensor images. If returned results only have
19
+ one element, just return tensor.
20
+ """
21
+
22
+ def _totensor(img, bgr2rgb, float32):
23
+ if img.shape[2] == 3 and bgr2rgb:
24
+ if img.dtype == 'float64':
25
+ img = img.astype('float32')
26
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27
+ img = torch.from_numpy(img.transpose(2, 0, 1))
28
+ if float32:
29
+ img = img.float()
30
+ return img
31
+
32
+ if isinstance(imgs, list):
33
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
34
+ else:
35
+ return _totensor(imgs, bgr2rgb, float32)
36
+
37
+
38
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39
+ """Convert torch Tensors into image numpy arrays.
40
+
41
+ After clamping to [min, max], values will be normalized to [0, 1].
42
+
43
+ Args:
44
+ tensor (Tensor or list[Tensor]): Accept shapes:
45
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46
+ 2) 3D Tensor of shape (3/1 x H x W);
47
+ 3) 2D Tensor of shape (H x W).
48
+ Tensor channel should be in RGB order.
49
+ rgb2bgr (bool): Whether to change rgb to bgr.
50
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
51
+ to uint8 type with range [0, 255]; otherwise, float type with
52
+ range [0, 1]. Default: ``np.uint8``.
53
+ min_max (tuple[int]): min and max values for clamp.
54
+
55
+ Returns:
56
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
+ shape (H x W). The channel order is BGR.
58
+ """
59
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
+
62
+ if torch.is_tensor(tensor):
63
+ tensor = [tensor]
64
+ result = []
65
+ for _tensor in tensor:
66
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
+
69
+ n_dim = _tensor.dim()
70
+ if n_dim == 4:
71
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
+ img_np = img_np.transpose(1, 2, 0)
73
+ if rgb2bgr:
74
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
+ elif n_dim == 3:
76
+ img_np = _tensor.numpy()
77
+ img_np = img_np.transpose(1, 2, 0)
78
+ if img_np.shape[2] == 1: # gray image
79
+ img_np = np.squeeze(img_np, axis=2)
80
+ else:
81
+ if rgb2bgr:
82
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
+ elif n_dim == 2:
84
+ img_np = _tensor.numpy()
85
+ else:
86
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
+ if out_type == np.uint8:
88
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
89
+ img_np = (img_np * 255.0).round()
90
+ img_np = img_np.astype(out_type)
91
+ result.append(img_np)
92
+ if len(result) == 1:
93
+ result = result[0]
94
+ return result
95
+
96
+
97
+ def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98
+ """This implementation is slightly faster than tensor2img.
99
+ It now only supports torch tensor with shape (1, c, h, w).
100
+
101
+ Args:
102
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104
+ min_max (tuple[int]): min and max values for clamp.
105
+ """
106
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108
+ output = output.type(torch.uint8).cpu().numpy()
109
+ if rgb2bgr:
110
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111
+ return output
112
+
113
+
114
+ def imfrombytes(content, flag='color', float32=False):
115
+ """Read an image from bytes.
116
+
117
+ Args:
118
+ content (bytes): Image bytes got from files or other streams.
119
+ flag (str): Flags specifying the color type of a loaded image,
120
+ candidates are `color`, `grayscale` and `unchanged`.
121
+ float32 (bool): Whether to change to float32., If True, will also norm
122
+ to [0, 1]. Default: False.
123
+
124
+ Returns:
125
+ ndarray: Loaded image array.
126
+ """
127
+ img_np = np.frombuffer(content, np.uint8)
128
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129
+ img = cv2.imdecode(img_np, imread_flags[flag])
130
+ if float32:
131
+ img = img.astype(np.float32) / 255.
132
+ return img
133
+
134
+
135
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
136
+ """Write image to file.
137
+
138
+ Args:
139
+ img (ndarray): Image array to be written.
140
+ file_path (str): Image file path.
141
+ params (None or list): Same as opencv's :func:`imwrite` interface.
142
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143
+ whether to create it automatically.
144
+
145
+ Returns:
146
+ bool: Successful or not.
147
+ """
148
+ if auto_mkdir:
149
+ dir_name = os.path.abspath(os.path.dirname(file_path))
150
+ os.makedirs(dir_name, exist_ok=True)
151
+ ok = cv2.imwrite(file_path, img, params)
152
+ if not ok:
153
+ raise IOError('Failed in writing images.')
154
+
155
+
156
+ def crop_border(imgs, crop_border):
157
+ """Crop borders of images.
158
+
159
+ Args:
160
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161
+ crop_border (int): Crop border for each end of height and weight.
162
+
163
+ Returns:
164
+ list[ndarray]: Cropped images.
165
+ """
166
+ if crop_border == 0:
167
+ return imgs
168
+ else:
169
+ if isinstance(imgs, list):
170
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171
+ else:
172
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
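img2tensor and tensor2img are near-inverses: HWC BGR float array in, CHW RGB float tensor out, and back to HWC BGR uint8. A small round-trip sketch:

import numpy as np
from util.basicsr_img_util import img2tensor, tensor2img

bgr = np.random.rand(64, 64, 3).astype(np.float32)   # fake BGR image in [0, 1]
t = img2tensor(bgr, bgr2rgb=True, float32=True)      # CHW, RGB, float32
back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))   # HWC, BGR, uint8 in [0, 255]
print(t.shape, back.dtype)                           # torch.Size([3, 64, 64]) uint8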
util/file.py ADDED
@@ -0,0 +1,55 @@
1
+ import tempfile
2
+
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+
7
+
8
+ def save_numpy_as_binary_bitwise(array, K, filename):
9
+ """Save a NumPy array as a binary file with bitwise storage."""
10
+ bits_per_value = int(np.ceil(np.log2(K))) # Number of bits required per value
11
+ bitstring = ''.join(format(val, f'0{bits_per_value}b') for val in array) # Convert each number to binary
12
+
13
+ # Convert bitstring to bytes
14
+ byte_array = int(bitstring, 2).to_bytes((len(bitstring) + 7) // 8, byteorder='big')
15
+
16
+ # Write to binary file
17
+ with open(filename, 'wb') as f:
18
+ f.write(byte_array)
19
+
20
+
21
+ def load_numpy_from_binary_bitwise(filename, K, T, model_type, effective_num_values):
22
+ if filename is None:
23
+ return None
24
+ """Load a NumPy array from a binary file stored in bitwise format."""
25
+ bits_per_value = int(np.ceil(np.log2(K))) # Number of bits required per value
26
+
27
+ if f'-K{K}-' not in filename:
28
+ raise gr.Error("Please set the codebook size to match the bitstream file you provided")
29
+
30
+ if f'-T{T}-' not in filename:
31
+ raise gr.Error("Please set the number of diffusion timesteps to match the bitstream file you provided")
32
+
33
+ if f'-M{model_type}-' not in filename:
34
+ raise gr.Error("Please set the image size to match the bitstream file you provided")
35
+
36
+ # Read the binary file as bytes
37
+ with open(filename, 'rb') as f:
38
+ byte_data = f.read()
39
+ # Convert bytes to a binary string
40
+ bitstring = bin(int.from_bytes(byte_data, byteorder='big'))[2:] # Remove '0b' prefix
41
+
42
+ # Pad with leading zeros if needed
43
+ bitstring = bitstring.zfill(effective_num_values * bits_per_value)
44
+
45
+ # Extract values from bitstring
46
+ values = [int(bitstring[i:i + bits_per_value], 2) for i in range(0, len(bitstring), bits_per_value)]
47
+
48
+ return torch.from_numpy(np.array(values, dtype=np.int32)).squeeze()
49
+
50
+
51
+ def generate_binary_file(np_arr, num_noises, timesteps, model_type):
52
+ temp_file = tempfile.NamedTemporaryFile(delete=False,
53
+ suffix=f".bitstream-T{timesteps}-K{num_noises}-M{model_type}-")
54
+ save_numpy_as_binary_bitwise(np_arr, num_noises, temp_file.name)
55
+ return temp_file.name
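The bitstream is nothing more than the per-step codebook indices packed back-to-back at ceil(log2(K)) bits each, with T, K and the model type encoded in the file suffix so they can be validated on load. The packing round trip in isolation (without the filename checks):

import numpy as np

K = 64                                                 # codebook size -> 6 bits per index
indices = np.array([3, 17, 42, 63, 0], dtype=np.int32)
bits_per_value = int(np.ceil(np.log2(K)))

bitstring = ''.join(format(v, f'0{bits_per_value}b') for v in indices)
packed = int(bitstring, 2).to_bytes((len(bitstring) + 7) // 8, byteorder='big')

# unpack: zfill restores the leading zeros that bin() drops
unpacked = bin(int.from_bytes(packed, 'big'))[2:].zfill(len(indices) * bits_per_value)
decoded = [int(unpacked[i:i + bits_per_value], 2) for i in range(0, len(unpacked), bits_per_value)]
assert decoded == indices.tolist()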
util/img_utils.py ADDED
@@ -0,0 +1,423 @@
1
+ import numpy as np
2
+ import torch
3
+ import scipy
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.autograd import Variable
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+
10
+ """
11
+ Helper functions for new types of inverse problems
12
+ """
13
+
14
+
15
+ def fft2(x):
16
+ """ FFT with shifting DC to the center of the image"""
17
+ return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2])
18
+
19
+
20
+ def ifft2(x):
21
+ """ IFFT with shifting DC to the corner of the image prior to transform"""
22
+ return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2]))
23
+
24
+
25
+ def fft2_m(x):
26
+ """ FFT for multi-coil """
27
+ if not torch.is_complex(x):
28
+ x = x.type(torch.complex64)
29
+ return torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
30
+
31
+
32
+ def ifft2_m(x):
33
+ """ IFFT for multi-coil """
34
+ if not torch.is_complex(x):
35
+ x = x.type(torch.complex64)
36
+ return torch.view_as_complex(ifft2c_new(torch.view_as_real(x)))
37
+
38
+
39
+ def clear(x):
40
+ x = x.detach().cpu().squeeze().numpy()
41
+ return normalize_np(x)
42
+
43
+
44
+ def resize_and_crop(image, imsize=512):
45
+ width, height = image.size
46
+
47
+ if width < height:
48
+ new_width = imsize
49
+ new_height = int((imsize / width) * height)
50
+ else:
51
+ new_height = imsize
52
+ new_width = int((imsize / height) * width)
53
+
54
+ image_resized = image.resize((new_width, new_height))
55
+
56
+ left = (new_width - imsize) / 2
57
+ top = (new_height - imsize) / 2
58
+ right = (new_width + imsize) / 2
59
+ bottom = (new_height + imsize) / 2
60
+
61
+ image_cropped = image_resized.crop((left, top, right, bottom))
62
+
63
+ return image_cropped
64
+
65
+
66
+ def clear_color(x, normalize=True):
67
+ if torch.is_complex(x):
68
+ x = torch.abs(x)
69
+ if normalize:
70
+ x = x.detach().cpu().squeeze().numpy()
71
+ if x.ndim == 3:
72
+ return normalize_np(np.transpose(x, (1, 2, 0)))
73
+ else:
74
+ return normalize_np(x)
75
+ else:
76
+ x = (x / 2 + 0.5).clamp(0, 1)
77
+ x = x.detach().cpu().squeeze().numpy()
78
+ if x.ndim == 3:
79
+ return np.transpose(x, (1, 2, 0))
80
+ else:
81
+ return x
82
+
83
+
84
+ def normalize_np(img):
85
+ """ Normalize img in arbitrary range to [0, 1] """
86
+ img -= np.min(img)
87
+ img /= np.max(img)
88
+ return img
89
+
90
+
91
+ def prepare_im(load_dir, image_size, device):
92
+ ref_img = torch.from_numpy(normalize_np(plt.imread(load_dir)[:, :, :3].astype(np.float32))).to(device)
93
+ ref_img = ref_img.permute(2, 0, 1)
94
+ ref_img = ref_img.view(1, 3, image_size, image_size)
95
+ ref_img = ref_img * 2 - 1
96
+ return ref_img
97
+
98
+
99
+ def fold_unfold(img_t, kernel, stride):
100
+ img_shape = img_t.shape
101
+ B, C, H, W = img_shape
102
+ print("\n----- input shape: ", img_shape)
103
+
104
+ patches = img_t.unfold(3, kernel, stride).unfold(2, kernel, stride).permute(0, 1, 2, 3, 5, 4)
105
+
106
+ print("\n----- patches shape:", patches.shape)
107
+ # reshape output to match F.fold input
108
+ patches = patches.contiguous().view(B, C, -1, kernel * kernel)
109
+ print("\n", patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
110
+ patches = patches.permute(0, 1, 3, 2)
111
+ print("\n", patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
112
+ patches = patches.contiguous().view(B, C * kernel * kernel, -1)
113
+ print("\n", patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
114
+
115
+ output = F.fold(patches, output_size=(H, W),
116
+ kernel_size=kernel, stride=stride)
117
+ # mask that mimics the original folding:
118
+ recovery_mask = F.fold(torch.ones_like(patches), output_size=(
119
+ H, W), kernel_size=kernel, stride=stride)
120
+ output = output / recovery_mask
121
+
122
+ return patches, output
123
+
124
+
125
+ def reshape_patch(x, crop_size=128, dim_size=3):
126
+ x = x.transpose(0, 2).squeeze() # [9, 3*(128**2)]
127
+ x = x.view(dim_size ** 2, 3, crop_size, crop_size)
128
+ return x
129
+
130
+
131
+ def reshape_patch_back(x, crop_size=128, dim_size=3):
132
+ x = x.view(dim_size ** 2, 3 * (crop_size ** 2)).unsqueeze(dim=-1)
133
+ x = x.transpose(0, 2)
134
+ return x
135
+
136
+
137
+ class Unfolder:
138
+ def __init__(self, img_size=256, crop_size=128, stride=64):
139
+ self.img_size = img_size
140
+ self.crop_size = crop_size
141
+ self.stride = stride
142
+
143
+ self.unfold = nn.Unfold(crop_size, stride=stride)
144
+ self.dim_size = (img_size - crop_size) // stride + 1
145
+
146
+ def __call__(self, x):
147
+ patch1D = self.unfold(x)
148
+ patch2D = reshape_patch(patch1D, crop_size=self.crop_size, dim_size=self.dim_size)
149
+ return patch2D
150
+
151
+
152
+ def center_crop(img, new_width=None, new_height=None):
153
+ width = img.shape[1]
154
+ height = img.shape[0]
155
+
156
+ if new_width is None:
157
+ new_width = min(width, height)
158
+
159
+ if new_height is None:
160
+ new_height = min(width, height)
161
+
162
+ left = int(np.ceil((width - new_width) / 2))
163
+ right = width - int(np.floor((width - new_width) / 2))
164
+
165
+ top = int(np.ceil((height - new_height) / 2))
166
+ bottom = height - int(np.floor((height - new_height) / 2))
167
+
168
+ if len(img.shape) == 2:
169
+ center_cropped_img = img[top:bottom, left:right]
170
+ else:
171
+ center_cropped_img = img[top:bottom, left:right, ...]
172
+
173
+ return center_cropped_img
174
+
175
+
176
+ class Folder:
177
+ def __init__(self, img_size=256, crop_size=128, stride=64):
178
+ self.img_size = img_size
179
+ self.crop_size = crop_size
180
+ self.stride = stride
181
+
182
+ self.fold = nn.Fold(img_size, crop_size, stride=stride)
183
+ self.dim_size = (img_size - crop_size) // stride + 1
184
+
185
+ def __call__(self, patch2D):
186
+ patch1D = reshape_patch_back(patch2D, crop_size=self.crop_size, dim_size=self.dim_size)
187
+ return self.fold(patch1D)
188
+
189
+
190
+ def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
191
+ """Generate a random sqaure mask for inpainting
192
+ """
193
+ B, C, H, W = img.shape
194
+ h, w = mask_shape
195
+ margin_height, margin_width = margin
196
+ maxt = image_size - margin_height - h
197
+ maxl = image_size - margin_width - w
198
+
199
+ # bb
200
+ t = np.random.randint(margin_height, maxt)
201
+ l = np.random.randint(margin_width, maxl)
202
+
203
+ # make mask
204
+ mask = torch.ones([B, C, H, W], device=img.device)
205
+ mask[..., t:t + h, l:l + w] = 0
206
+
207
+ return mask, t, t + h, l, l + w
208
+
209
+
210
+ class mask_generator:
+     def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
+                  image_size=256, margin=(16, 16)):
+         """
+         mask_len_range: (min, max) tuple specifying the range of the box size
+             in each dimension (used for 'box' and 'extreme' masks).
+         mask_prob_range: (min, max) tuple specifying the probability of an
+             individual pixel being masked (used for 'random' masks).
+         """
+         assert mask_type in ['box', 'random', 'both', 'extreme']
+         self.mask_type = mask_type
+         self.mask_len_range = mask_len_range
+         self.mask_prob_range = mask_prob_range
+         self.image_size = image_size
+         self.margin = margin
+
+     def _retrieve_box(self, img):
+         l, h = self.mask_len_range
+         l, h = int(l), int(h)
+         mask_h = np.random.randint(l, h)
+         mask_w = np.random.randint(l, h)
+         mask, t, tl, w, wh = random_sq_bbox(img,
+                                             mask_shape=(mask_h, mask_w),
+                                             image_size=self.image_size,
+                                             margin=self.margin)
+         return mask, t, tl, w, wh
+
+     def _retrieve_random(self, img):
+         total = self.image_size ** 2
+         # random pixel sampling
+         l, h = self.mask_prob_range
+         prob = np.random.uniform(l, h)
+         mask_vec = torch.ones([1, self.image_size * self.image_size])
+         samples = np.random.choice(self.image_size * self.image_size, int(total * prob), replace=False)
+         mask_vec[:, samples] = 0
+         mask_b = mask_vec.view(1, self.image_size, self.image_size)
+         mask_b = mask_b.repeat(3, 1, 1)
+         mask = torch.ones_like(img, device=img.device)
+         mask[:, ...] = mask_b
+         return mask
+
+     def __call__(self, img):
+         # note: 'both' passes the assert above but has no branch here, so it returns None
+         if self.mask_type == 'random':
+             mask = self._retrieve_random(img)
+             return mask
+         elif self.mask_type == 'box':
+             mask, t, th, w, wl = self._retrieve_box(img)
+             return mask
+         elif self.mask_type == 'extreme':
+             mask, t, th, w, wl = self._retrieve_box(img)
+             mask = 1. - mask
+             return mask
+
+
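The __main__ block at the end of this file demonstrates the 'random' mask type; a minimal sketch of the 'box' variant, under the same assumptions (repo root on PYTHONPATH, a 1x3x256x256 input):

    import torch
    from util.img_utils import mask_generator

    img = torch.rand(1, 3, 256, 256)
    gen = mask_generator(mask_type='box', mask_len_range=(32, 128), image_size=256)
    mask = gen(img)      # ones everywhere, zeros inside a random square
    masked = img * mask  # the masked measurement for inpainting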
+ def unnormalize(img, s=0.95):
+     scaling = torch.quantile(img.abs(), s)
+     return img / scaling
+
+
+ def normalize(img, s=0.95):
+     scaling = torch.quantile(img.abs(), s)
+     return img * scaling
+
+
+ def dynamic_thresholding(img, s=0.95):
+     img = normalize(img, s=s)
+     return torch.clip(img, -1., 1.)
+
+
+ def get_gaussian_kernel(kernel_size=31, std=0.5):
+     n = np.zeros([kernel_size, kernel_size])
+     n[kernel_size // 2, kernel_size // 2] = 1
+     k = scipy.ndimage.gaussian_filter(n, sigma=std)
+     k = k.astype(np.float32)
+     return k
+
+
+ def init_kernel_torch(kernel, device="cuda:0"):
+     h, w = kernel.shape
+     kernel = Variable(torch.from_numpy(kernel).to(device), requires_grad=True)
+     kernel = kernel.view(1, 1, h, w)
+     kernel = kernel.repeat(1, 3, 1, 1)
+     return kernel
+
+
+ class Blurkernel(nn.Module):
+     def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
+         super().__init__()
+         self.blur_type = blur_type
+         self.kernel_size = kernel_size
+         self.std = std
+         self.device = device
+         self.seq = nn.Sequential(
+             nn.ReflectionPad2d(self.kernel_size // 2),
+             nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
+         )
+
+         self.weights_init()
+
+     def forward(self, x):
+         return self.seq(x)
+
+     def weights_init(self):
+         if self.blur_type == "gaussian":
+             n = np.zeros((self.kernel_size, self.kernel_size))
+             n[self.kernel_size // 2, self.kernel_size // 2] = 1
+             k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
+             k = torch.from_numpy(k)
+             self.k = k
+             for name, f in self.named_parameters():
+                 f.data.copy_(k)
+         elif self.blur_type == "motion":
+             # the motion branch relies on an external motion-blur Kernel generator
+             k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
+             k = torch.from_numpy(k)
+             self.k = k
+             for name, f in self.named_parameters():
+                 f.data.copy_(k)
+
+     def update_weights(self, k):
+         if not torch.is_tensor(k):
+             k = torch.from_numpy(k).to(self.device)
+         for name, f in self.named_parameters():
+             f.data.copy_(k)
+
+     def get_kernel(self):
+         return self.k
+
+
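A minimal sketch of the 'gaussian' branch of Blurkernel on CPU (the 'motion' branch additionally depends on the external motion-blur Kernel class and is not exercised here), assuming the repo root is on PYTHONPATH:

    import torch
    from util.img_utils import Blurkernel

    blur = Blurkernel(blur_type='gaussian', kernel_size=31, std=3.0)
    y = blur(torch.rand(1, 3, 256, 256))  # blurred batch; reflection padding keeps the spatial size
    k = blur.get_kernel()                 # the 31x31 Gaussian kernel built in weights_init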
+ class exact_posterior():
+     def __init__(self, betas, sigma_0, label_dim, input_dim):
+         self.betas = betas
+         self.sigma_0 = sigma_0
+         self.label_dim = label_dim
+         self.input_dim = input_dim
+
+     def py_given_x0(self, x0, y, A, verbose=False):
+         norm_const = 1 / ((2 * np.pi) ** self.input_dim * self.sigma_0 ** 2)
+         exp_in = -1 / (2 * self.sigma_0 ** 2) * torch.linalg.norm(y - A(x0)) ** 2
+         if not verbose:
+             return norm_const * torch.exp(exp_in)
+         else:
+             return norm_const * torch.exp(exp_in), norm_const, exp_in
+
+     def pxt_given_x0(self, x0, xt, t, verbose=False):
+         beta_t = self.betas[t]
+         norm_const = 1 / ((2 * np.pi) ** self.label_dim * beta_t)
+         exp_in = -1 / (2 * beta_t) * torch.linalg.norm(xt - np.sqrt(1 - beta_t) * x0) ** 2
+         if not verbose:
+             return norm_const * torch.exp(exp_in)
+         else:
+             return norm_const * torch.exp(exp_in), norm_const, exp_in
+
+     def prod_logsumexp(self, x0, xt, y, A, t):
+         py_given_x0_density, pyx0_nc, pyx0_ei = self.py_given_x0(x0, y, A, verbose=True)
+         pxt_given_x0_density, pxtx0_nc, pxtx0_ei = self.pxt_given_x0(x0, xt, t, verbose=True)
+         # product of the two densities: (nc_y * nc_xt) * exp(exp_in_y + exp_in_xt)
+         summand = (pyx0_nc * pxtx0_nc) * torch.exp(pyx0_ei + pxtx0_ei)
+         return torch.logsumexp(summand, dim=0)
+
+
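For reference, up to the normalizing constants used above, py_given_x0 and pxt_given_x0 evaluate the Gaussian measurement likelihood and the diffusion forward kernel (sigma_0 is the measurement noise level, beta_t the diffusion variance at step t):

    p(y \mid x_0) \propto \exp\!\Big(-\tfrac{1}{2\sigma_0^2}\,\|y - A(x_0)\|^2\Big),
    \qquad
    p(x_t \mid x_0) \propto \exp\!\Big(-\tfrac{1}{2\beta_t}\,\|x_t - \sqrt{1-\beta_t}\,x_0\|^2\Big).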
+ def map2tensor(gray_map):
+     """Move gray maps to GPU, no normalization is done"""
+     return torch.FloatTensor(gray_map).unsqueeze(0).unsqueeze(0).cuda()
+
+
+ def create_penalty_mask(k_size, penalty_scale):
+     """Generate a mask of weights penalizing values close to the boundaries"""
+     center_size = k_size // 2 + k_size % 2
+     mask = create_gaussian(size=k_size, sigma1=k_size, is_tensor=False)
+     mask = 1 - mask / np.max(mask)
+     margin = (k_size - center_size) // 2 - 1
+     mask[margin:-margin, margin:-margin] = 0
+     return penalty_scale * mask
+
+
+ def create_gaussian(size, sigma1, sigma2=-1, is_tensor=False):
+     """Return a 2D Gaussian, built as the outer product of two 1D Gaussians"""
+     func1 = [np.exp(-z ** 2 / (2 * sigma1 ** 2)) / np.sqrt(2 * np.pi * sigma1 ** 2) for z in
+              range(-size // 2 + 1, size // 2 + 1)]
+     func2 = func1 if sigma2 == -1 else [np.exp(-z ** 2 / (2 * sigma2 ** 2)) / np.sqrt(2 * np.pi * sigma2 ** 2) for z in
+              range(-size // 2 + 1, size // 2 + 1)]
+     return torch.FloatTensor(np.outer(func1, func2)).cuda() if is_tensor else np.outer(func1, func2)
+
+
+ def total_variation_loss(img, weight):
+     tv_h = ((img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2)).mean()
+     tv_w = ((img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2)).mean()
+     return weight * (tv_h + tv_w)
+
+
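A quick sanity check for the loss above, assuming NCHW tensors: a constant image has zero total variation, while a single vertical edge contributes only through the horizontal-difference term:

    import torch
    from util.img_utils import total_variation_loss

    flat = torch.ones(1, 3, 8, 8)
    edge = torch.zeros(1, 3, 8, 8)
    edge[..., :, 4:] = 1.0                          # vertical edge down the middle

    print(total_variation_loss(flat, weight=1.0))   # tensor(0.)
    print(total_variation_loss(edge, weight=1.0))   # positive, driven by tv_w only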
+ if __name__ == '__main__':
+     import numpy as np
+     from torch import nn
+     import matplotlib.pyplot as plt
+
+     device = 'cuda:0'
+     load_path = '/media/harry/tomo/FFHQ/256/test/00000.png'
+     img = torch.tensor(plt.imread(load_path)[:, :, :3])  # rgb
+     img = torch.permute(img, (2, 0, 1)).view(1, 3, 256, 256).to(device)
+
+     mask_len_range = (32, 128)
+     mask_prob_range = (0.3, 0.7)
+     image_size = 256
+     # mask_type is a required argument of mask_generator; 'random' matches the prob range used here
+     mask_gen = mask_generator(
+         mask_type='random',
+         mask_len_range=mask_len_range,
+         mask_prob_range=mask_prob_range,
+         image_size=image_size
+     )
+     mask = mask_gen(img)
+
+     mask = np.transpose(mask.squeeze().cpu().detach().numpy(), (1, 2, 0))
+
+     plt.imshow(mask)
+     plt.show()