ginipick commited on
Commit
37f1d99
ยท
verified ยท
1 Parent(s): 97f2938

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -689,12 +689,13 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
689
  seed: Optional[int] = None,
690
  c : Optional[float] = 0.3,
691
  ):
692
-
693
  """
694
  Stable Diffusion XL ๊ธฐ๋ฐ˜ AccDiffusion ํŒŒ์ดํ”„๋ผ์ธ์„ ํ†ตํ•ด ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
695
 
696
  ์ด ํ•จ์ˆ˜๋Š” ์ฃผ์–ด์ง„ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๊ณ , ๋””๋…ธ์ด์ง• ๋ฐ progressive upscaling ๊ณผ์ •์„ ๊ฑฐ์ณ ์ตœ์ข… ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
697
  ์ž์„ธํ•œ ์‚ฌ์šฉ๋ฒ•์€ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.
 
 
698
  """
699
 
700
  if debug:
@@ -1093,8 +1094,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1093
  count_local[:, :, h_start:h_end, w_start:w_end] += 1
1094
 
1095
  if random_jitter:
1096
- value_local = value_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor]
1097
- count_local = count_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor]
1098
 
1099
  noise_index = i + 1 if i != (len(timesteps) - 1) else i
1100
 
@@ -1145,12 +1146,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1145
 
1146
  if shuffle:
1147
  shape = latents_for_view.shape
 
1148
  shuffle_index = torch.stack([torch.randperm(shape[0]) for _ in range(latents_for_view.reshape(-1).shape[0]//shape[0])])
1149
-
1150
- shuffle_index = shuffle_index.view(shape[1],shape[2],shape[3],shape[0])
1151
  original_index = torch.zeros_like(shuffle_index).scatter_(3, shuffle_index, torch.arange(shape[0]).repeat(shape[1], shape[2], shape[3], 1))
1152
- shuffle_index = shuffle_index.permute(3,0,1,2).to(device)
1153
- original_index = original_index.permute(3,0,1,2).to(device)
1154
  latents_for_view_gaussian = latents_for_view_gaussian.gather(0, shuffle_index)
1155
 
1156
  vb_size = latents_for_view.size(0)
@@ -1195,7 +1196,7 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1195
  value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
1196
  count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
1197
 
1198
- value_global = value_global[: ,:, h_pad:, w_pad:]
1199
 
1200
  if use_multidiffusion:
1201
  c2 = cosine_factor ** cosine_scale_2
 
689
  seed: Optional[int] = None,
690
  c : Optional[float] = 0.3,
691
  ):
 
692
  """
693
  Stable Diffusion XL ๊ธฐ๋ฐ˜ AccDiffusion ํŒŒ์ดํ”„๋ผ์ธ์„ ํ†ตํ•ด ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
694
 
695
  ์ด ํ•จ์ˆ˜๋Š” ์ฃผ์–ด์ง„ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๊ณ , ๋””๋…ธ์ด์ง• ๋ฐ progressive upscaling ๊ณผ์ •์„ ๊ฑฐ์ณ ์ตœ์ข… ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
696
  ์ž์„ธํ•œ ์‚ฌ์šฉ๋ฒ•์€ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.
697
+
698
+ Examples:
699
  """
700
 
701
  if debug:
 
1094
  count_local[:, :, h_start:h_end, w_start:w_end] += 1
1095
 
1096
  if random_jitter:
1097
+ value_local = value_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]
1098
+ count_local = count_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]
1099
 
1100
  noise_index = i + 1 if i != (len(timesteps) - 1) else i
1101
 
 
1146
 
1147
  if shuffle:
1148
  shape = latents_for_view.shape
1149
+ # ์ˆ˜์ •: range(...)์˜ ๊ด„ํ˜ธ๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
1150
  shuffle_index = torch.stack([torch.randperm(shape[0]) for _ in range(latents_for_view.reshape(-1).shape[0]//shape[0])])
1151
+ shuffle_index = shuffle_index.view(shape[1], shape[2], shape[3], shape[0])
 
1152
  original_index = torch.zeros_like(shuffle_index).scatter_(3, shuffle_index, torch.arange(shape[0]).repeat(shape[1], shape[2], shape[3], 1))
1153
+ shuffle_index = shuffle_index.permute(3, 0, 1, 2).to(device)
1154
+ original_index = original_index.permute(3, 0, 1, 2).to(device)
1155
  latents_for_view_gaussian = latents_for_view_gaussian.gather(0, shuffle_index)
1156
 
1157
  vb_size = latents_for_view.size(0)
 
1196
  value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
1197
  count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
1198
 
1199
+ value_global = value_global[:, :, h_pad:, w_pad:]
1200
 
1201
  if use_multidiffusion:
1202
  c2 = cosine_factor ** cosine_scale_2