ford442 committed
Commit 72e088a · verified · 1 Parent(s): 3089ad7

Create free_lunch_utils.py

Files changed (1):
  1. free_lunch_utils.py  +340 -0
free_lunch_utils.py ADDED
@@ -0,0 +1,340 @@
import torch
import torch.fft as fft
from diffusers.models.unet_2d_condition import logger
from diffusers.utils import is_torch_version
from typing import Any, Dict, List, Optional, Tuple, Union


def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """

    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True

    return False

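# Illustrative note (not part of the original commit): the register_* helpers
# below rely on isinstance_str to detect diffusers block types by class name
# alone, e.g.
#
#   isinstance_str(pipe.unet.up_blocks[1], "CrossAttnUpBlock2D")
#
# where `pipe` is assumed to be a loaded diffusers pipeline.
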
def Fourier_filter(x, threshold, scale):
    dtype = x.dtype
    x = x.type(torch.float32)
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the input's device so the filter is not CUDA-only.
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered

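# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original commit): Fourier_filter returns
# a tensor with the input's shape and dtype, with the central low-frequency
# band (of half-width `threshold`) scaled by `scale`. The random tensor below
# is a stand-in for a 640-channel decoder skip feature map.
def _fourier_filter_demo(device="cpu"):
    feats = torch.randn(2, 640, 32, 32, device=device)
    filtered = Fourier_filter(feats, threshold=1, scale=0.9)
    assert filtered.shape == feats.shape and filtered.dtype == feats.dtype
    return filtered
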
def register_upblock2d(model):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)

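# ----------------------------------------------------------------------------
# Illustrative usage (assumption, not part of the original commit): `pipe` is a
# diffusers pipeline exposing `pipe.unet`. This baseline helper re-installs the
# stock UpBlock2D forward pass, which is handy for reverting a FreeU patch
# without reloading weights:
#
#   register_upblock2d(pipe)
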
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")

                # # --------------- FreeU code -----------------------
                # # Only operate on the first two stages
                # if hidden_states.shape[1] == 1280:
                #     hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                #     res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                # if hidden_states.shape[1] == 640:
                #     hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                #     res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # # ---------------------------------------------------------

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    # Per-sample, min-max-normalized channel-mean map modulates the backbone scaling.
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)

                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:, :640] = hidden_states[:, :640] * ((self.b1 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:, :320] = hidden_states[:, :320] * ((self.b2 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)

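# ----------------------------------------------------------------------------
# Illustrative usage (assumption, not part of the original commit): enable the
# FreeU backbone/skip scaling on the plain UpBlock2D blocks of a pipeline
# `pipe`. The values shown are simply this file's defaults; other UNets may
# prefer different settings.
#
#   register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
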
def register_crossattn_upblock2d(model):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)

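# ----------------------------------------------------------------------------
# Illustrative usage (assumption, not part of the original commit): the
# CrossAttnUpBlock2D counterpart of the baseline patch; combined with
# register_upblock2d it restores the unmodified decoder path:
#
#   register_crossattn_upblock2d(pipe)
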
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)

                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:, :640] = hidden_states[:, :640] * ((self.b1 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:, :320] = hidden_states[:, :320] * ((self.b2 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    # NOTE: the non-checkpointed call below omits the attention masks;
                    # the fuller call is kept commented here for reference.
                    # hidden_states = attn(
                    #     hidden_states,
                    #     encoder_hidden_states=encoder_hidden_states,
                    #     cross_attention_kwargs=cross_attention_kwargs,
                    #     encoder_attention_mask=encoder_attention_mask,
                    #     return_dict=False,
                    # )[0]
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
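
# ----------------------------------------------------------------------------
# End-to-end sketch (assumption, not part of the original commit): patch both
# decoder block types of a Stable Diffusion pipeline and sample with FreeU
# enabled. The checkpoint id and prompt are placeholders.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
#   register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
#   register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
#   image = pipe("an astronaut riding a horse on the moon").images[0]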