ford442 committed
Commit 58d3e52 · verified · 1 Parent(s): bacb38f

Delete free_lunch_utils.py

Files changed (1)
  1. free_lunch_utils.py +0 -340
free_lunch_utils.py DELETED
@@ -1,340 +0,0 @@
import torch
import torch.fft as fft
from diffusers.models.unets.unet_2d_condition import logger
from diffusers.utils import is_torch_version
from typing import Any, Dict, List, Optional, Tuple, Union


def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """

    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True

    return False


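For illustration only (not part of the deleted file): isinstance_str matches a class purely by name anywhere in the object's MRO, which is what lets the register_* helpers below detect diffusers block types without importing them. A minimal sketch using a stock torch module as the probe object:

    layer = torch.nn.Linear(4, 4)
    isinstance_str(layer, "Linear")   # True  (its own class name)
    isinstance_str(layer, "Module")   # True  (torch.nn.Module is in the ancestry)
    isinstance_str(layer, "Conv2d")   # False (no ancestor with that name)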
def Fourier_filter(x, threshold, scale):
    dtype = x.dtype
    x = x.type(torch.float32)
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W)).cuda()

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered


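As a hedged illustration (not part of the deleted file): Fourier_filter shifts the feature map into frequency space, multiplies a (2*threshold) x (2*threshold) window around the zero-frequency bin by scale, and transforms back, so scale < 1 damps the lowest spatial frequencies. Because the mask is built with .cuda(), the helper assumes a CUDA device:

    # hypothetical smoke test on a CUDA machine
    feat = torch.randn(2, 320, 32, 32, device="cuda", dtype=torch.float16)
    out = Fourier_filter(feat, threshold=1, scale=0.9)  # damp the central 2x2 frequency bins
    assert out.shape == feat.shape and out.dtype == feat.dtype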
def register_upblock2d(model):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)


def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")

                # # --------------- FreeU code -----------------------
                # # Only operate on the first two stages
                # if hidden_states.shape[1] == 1280:
                #     hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                #     res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                # if hidden_states.shape[1] == 640:
                #     hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                #     res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # # ---------------------------------------------------------

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)

                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)


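Illustrative usage (not part of the deleted file): the register_free_* helpers monkey-patch the forward of every matching up-block on model.unet and attach the FreeU coefficients, where b1/b2 rescale the backbone features (modulated by a per-sample, min-max-normalized channel mean) and s1/s2 rescale the low frequencies of the incoming skip connections. A minimal sketch, assuming pipe is any diffusers pipeline exposing a .unet:

    register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
    blk = pipe.unet.up_blocks[0]           # an UpBlock2D in Stable Diffusion 1.x UNets
    print(blk.b1, blk.b2, blk.s1, blk.s2)  # 1.2 1.4 0.9 0.2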
def register_crossattn_upblock2d(model):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)


def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)

                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_mean = hidden_states.mean(1).unsqueeze(1)
                    B = hidden_mean.shape[0]
                    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                    hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1) * hidden_mean + 1)
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    # hidden_states = attn(
                    #     hidden_states,
                    #     encoder_hidden_states=encoder_hidden_states,
                    #     cross_attention_kwargs=cross_attention_kwargs,
                    #     encoder_attention_mask=encoder_attention_mask,
                    #     return_dict=False,
                    # )[0]
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
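Taken together, the deleted module implemented the FreeU patching helpers for Stable Diffusion UNets: register_free_upblock2d and register_free_crossattn_upblock2d enable FreeU on the decoder blocks, while register_upblock2d and register_crossattn_upblock2d re-register the vanilla forwards defined above to turn it off again. A hedged end-to-end sketch of how such a module is typically wired into a diffusers pipeline (the model id and coefficient values are illustrative, not mandated by this file):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # enable FreeU on both the plain and the cross-attention up-blocks
    register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
    register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
    image = pipe("an astronaut riding a horse").images[0]

    # restore stock behaviour by re-registering the plain forwards
    register_upblock2d(pipe)
    register_crossattn_upblock2d(pipe)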