xilluill committed
Commit 9b891da · Parent: ca71d6b

update attn scale

Files changed (3):
  1. app.py +8 -4
  2. flux/modules/layers.py +4 -2
  3. models/kv_edit.py +38 -0
app.py CHANGED
@@ -30,6 +30,7 @@ class SamplingOptions:
     seed: int = 42
     re_init: bool = False
     attn_mask: bool = False
+    attn_scale_value: float = 0.0
 
 def resize_image(image_array, max_width=512, max_height=512):
     # Convert the numpy array to a PIL image
@@ -96,7 +97,7 @@ def edit(brush_canvas,
          inversion_num_steps, denoise_num_steps,
          skip_step,
          inversion_guidance, denoise_guidance, seed,
-         re_init, attn_mask
+         re_init, attn_mask, attn_scale_value
          ):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.empty_cache()
@@ -136,7 +137,8 @@ def edit(brush_canvas,
         denoise_guidance=denoise_guidance,
         seed=seed,
         re_init=re_init,
-        attn_mask=attn_mask
+        attn_mask=attn_mask,
+        attn_scale_value=attn_scale_value
     )
@@ -215,7 +217,8 @@ def create_demo(model_name: str):
     3️⃣ Fill in your target prompt, then adjust the hyperparameters. <br>
     4️⃣ Click the "Edit" button to generate your edited image! <br>
 
-    🔔🔔 [<b>Important</b>] We suggest trying less skip steps, "re_init" and "attn_mask" only when the result is too similar to the original content (e.g. removing objects or changing color).<br>
+    🔔🔔 [<b>Important</b>] Fewer skip steps, "re_init" and "attn_mask" all strengthen the edit, aligning the result more closely with your text, but they may produce discontinuous images. <br>
+    We recommend increasing "attn_scale" to strengthen the attention between the masked region and the background.<br>
     """
     article = r"""
     If our work is helpful, please help to ⭐ the <a href='https://github.com/Xilluill/KV-Edit' target='_blank'>Github Repo</a>. Thanks!
@@ -252,6 +255,7 @@ def create_demo(model_name: str):
     skip_step = gr.Slider(0, 30, 0, step=1, label="Number of skip steps")
     inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
     denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
+    attn_scale_value = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale")
     seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
     with gr.Row():
         re_init = gr.Checkbox(label="re_init", value=False)
@@ -268,7 +272,7 @@ def create_demo(model_name: str):
             skip_step,
             inversion_guidance,
             denoise_guidance, seed,
-            re_init, attn_mask
+            re_init, attn_mask, attn_scale_value
         ],
         outputs=[output_image]
     )
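
For orientation, the new option takes the same path as the existing re_init and attn_mask flags: the slider value is passed positionally into edit() and ends up on SamplingOptions. A minimal sketch of that flow, with everything except the fields visible in this diff elided (the trimmed edit() signature below is hypothetical; the real one also receives the canvas, prompts, step counts and guidances):

from dataclasses import dataclass

@dataclass
class SamplingOptions:
    seed: int = 42
    re_init: bool = False
    attn_mask: bool = False
    attn_scale_value: float = 0.0  # new field; 0.0 leaves the attention bias disabled

def edit(re_init, attn_mask, attn_scale_value):
    # hypothetical trimmed signature, shown only to trace the new argument
    return SamplingOptions(re_init=re_init, attn_mask=attn_mask,
                           attn_scale_value=attn_scale_value)

opts = edit(re_init=False, attn_mask=False, attn_scale_value=1.0)
assert opts.attn_scale_value == 1.0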
flux/modules/layers.py CHANGED
@@ -300,7 +300,8 @@ class DoubleStreamBlock_kv(DoubleStreamBlock):
         if 'attention_mask' in info:
             attn = attention(q, k, v, pe=pe, attention_mask=info['attention_mask'])
         else:
-            attn = attention(q, k, v, pe=pe)
+            # attn = attention(q, k, v, pe=pe)
+            attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
 
     # elif feature_k_name in info['feature']:
     else:
@@ -377,7 +378,8 @@ class SingleStreamBlock_kv(SingleStreamBlock):
 
         k = torch.cat((txt_k, source_img_k), dim=2)
         v = torch.cat((txt_v, source_img_v), dim=2)
-        attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'])
+        # attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'])
+        attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
 
         # compute attention
         # attn = attention(q, k, v, pe=pe)
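
How info['attention_scale'] changes the result depends on the repo's attention helper, which is not shown in this diff. One reading consistent with create_attention_scale producing an additive bfloat16 row vector (zeros for text and mask tokens, scale for background tokens) is that the tensor is added to the pre-softmax logits, uniformly boosting attention toward background keys. A self-contained sketch of that mechanism under that assumption (attention_with_bias is illustrative, not the repo's function; pe and pe_q handling is omitted):

from math import sqrt
import torch
import torch.nn.functional as F

def attention_with_bias(q, k, v, bias=None):
    # q, k, v: (batch, heads, tokens, dim); bias broadcastable to the logit matrix
    logits = q @ k.transpose(-2, -1) / sqrt(q.shape[-1])
    if bias is not None:
        logits = logits + bias  # positive entries raise the weight on those key tokens
    return F.softmax(logits, dim=-1) @ v

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 6, 8)
v = torch.randn(1, 1, 6, 8)
bias = torch.zeros(1, 1, 1, 6)
bias[..., 3:] = 1.0  # boost keys 3..5, analogous to boosting background tokens
out = attention_with_bias(q, k, v, bias)  # (1, 1, 4, 8)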
models/kv_edit.py CHANGED
@@ -76,6 +76,37 @@ class only_Flux(torch.nn.Module):  # contains only the initialization functions
         attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True  # attend to the background region
 
         return attention_mask.unsqueeze(0)
+
+    def create_attention_scale(self, seq_len, mask_indices, text_len=512, device='cuda', scale=0):
+        """
+        Create a per-token attention scale.
+
+        Args:
+            seq_len (int): sequence length.
+            mask_indices (List[int]): indices of the masked region within the image tokens.
+            text_len (int): number of text tokens, 512 by default.
+            device (str): device type, e.g. 'cuda' or 'cpu'.
+
+        Returns:
+            torch.Tensor: attention scale of shape (1, 1, seq_len).
+        """
+        # initialize the scale to all zeros; it broadcasts when added
+        attention_scale = torch.zeros(1, seq_len, dtype=torch.bfloat16, device=device)
+
+        # text token indices
+        text_indices = torch.arange(0, text_len, device=device)
+
+        # token indices of the masked region
+        mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)
+
+        # token indices of the background region
+        all_indices = torch.arange(text_len, seq_len, device=device)
+        background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices])
+
+        attention_scale[0, background_token_indices] = scale
+        print(f"attention_scale:{scale}")
+
+        return attention_scale.unsqueeze(0)
 
 class Flux_kv_edit_inf(only_Flux):
     def __init__(self, device, name):
@@ -200,6 +231,13 @@ class Flux_kv_edit(only_Flux):
             inp_target["img"] = zt_noise[:, mask_indices, ...]
         else:
             inp_target["img"] = zt[:, mask_indices, ...]
+
+        if opts.attn_scale_value != 0:
+            attention_scale = self.create_attention_scale(L + 512, mask_indices, device=mask.device, scale=opts.attn_scale_value)
+            info['attention_scale'] = attention_scale
+        else:
+            info['attention_scale'] = None
+
 
         info['inverse'] = False
         x, _ = denoise_kv(self.model, **inp_target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)
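
To make the shape and broadcast behavior of the new tensor concrete: with L image tokens, the sequence length passed in is L + 512 (512 text tokens), and the returned tensor is (1, 1, seq_len), so it broadcasts across query positions. Here is the function's core logic run standalone on a toy sequence (CPU, float32 instead of bfloat16, and text_len shrunk from 512 to 4 for readability):

import torch

def toy_attention_scale(seq_len, mask_indices, text_len=4, scale=1.0):
    # zeros everywhere; `scale` is written only at background (non-masked) image tokens
    s = torch.zeros(1, seq_len)
    mask_tokens = [i + text_len for i in mask_indices]
    for i in range(text_len, seq_len):
        if i not in mask_tokens:
            s[0, i] = scale
    return s.unsqueeze(0)  # (1, 1, seq_len): broadcasts over query positions

s = toy_attention_scale(seq_len=10, mask_indices=[0, 1])
print(s.shape)   # torch.Size([1, 1, 10])
print(s[0, 0])   # tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1.])

Note the gate in the second hunk: when the slider sits at 0, info['attention_scale'] is set to None and the bias is skipped entirely, so the pipeline's default behavior is unchanged.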