update attn scale
- app.py +8 -4
- flux/modules/layers.py +4 -2
- models/kv_edit.py +38 -0
app.py
CHANGED
@@ -30,6 +30,7 @@ class SamplingOptions:
     seed: int = 42
     re_init: bool = False
     attn_mask: bool = False
+    attn_scale_value: float = 0.0
 
 def resize_image(image_array, max_width=512, max_height=512):
     # convert the numpy array to a PIL image
@@ -96,7 +97,7 @@ def edit(brush_canvas,
          inversion_num_steps, denoise_num_steps,
          skip_step,
          inversion_guidance, denoise_guidance, seed,
-         re_init, attn_mask
+         re_init, attn_mask, attn_scale_value
          ):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.empty_cache()
@@ -136,7 +137,8 @@ def edit(brush_canvas,
         denoise_guidance=denoise_guidance,
         seed=seed,
         re_init=re_init,
-        attn_mask=attn_mask
+        attn_mask=attn_mask,
+        attn_scale_value=attn_scale_value
     )
 
 
@@ -215,7 +217,8 @@ def create_demo(model_name: str):
     3️⃣ Fill in your target prompt, then adjust the hyperparameters. <br>
     4️⃣ Click the "Edit" button to generate your edited image! <br>
 
-    🔔🔔 [<b>Important</b>]
+    🔔🔔 [<b>Important</b>] Fewer skip steps, "re_init" and "attn_mask" enhance editing performance, making the results better aligned with your text, but they may lead to discontinuous images. <br>
+    We recommend increasing "attn_scale" to strengthen the attention between the masked region and the background.<br>
     """
     article = r"""
     If our work is helpful, please help to ⭐ the <a href='https://github.com/Xilluill/KV-Edit' target='_blank'>Github Repo</a>. Thanks!
@@ -252,6 +255,7 @@ def create_demo(model_name: str):
                 skip_step = gr.Slider(0, 30, 0, step=1, label="Number of skip steps")
                 inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
                 denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
+                attn_scale_value = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale")
                 seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
                 with gr.Row():
                     re_init = gr.Checkbox(label="re_init", value=False)
@@ -268,7 +272,7 @@ def create_demo(model_name: str):
                 skip_step,
                 inversion_guidance,
                 denoise_guidance, seed,
-                re_init, attn_mask
+                re_init, attn_mask, attn_scale_value
             ],
             outputs=[output_image]
         )
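On the app side, the commit simply threads one new hyperparameter from the UI into the sampler options. A minimal sketch of how the value flows, assuming edit() forwards its arguments into SamplingOptions as the diff shows (the remaining fields are elided):

    from dataclasses import dataclass

    @dataclass
    class SamplingOptions:
        seed: int = 42
        re_init: bool = False
        attn_mask: bool = False
        attn_scale_value: float = 0.0  # 0.0 disables the extra scaling (see models/kv_edit.py below)

    # what edit() effectively builds from the Gradio inputs, e.g. with the slider at 1.0:
    opts = SamplingOptions(seed=0, attn_scale_value=1.0)

Note that the slider defaults to 1 while the dataclass default is 0.0, so with the default UI settings the scaling is active unless the user drags the slider back to 0.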
flux/modules/layers.py
CHANGED
@@ -300,7 +300,8 @@ class DoubleStreamBlock_kv(DoubleStreamBlock):
             if 'attention_mask' in info:
                 attn = attention(q, k, v, pe=pe, attention_mask=info['attention_mask'])
             else:
-                attn = attention(q, k, v, pe=pe)
+                # attn = attention(q, k, v, pe=pe)
+                attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
 
         # elif feature_k_name in info['feature']:
         else:
@@ -377,7 +378,8 @@ class SingleStreamBlock_kv(SingleStreamBlock):
 
             k = torch.cat((txt_k, source_img_k), dim=2)
             v = torch.cat((txt_v, source_img_v), dim=2)
-            attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'])
+            # attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'])
+            attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
 
             # compute attention
             # attn = attention(q, k, v, pe=pe)
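The attention() helper itself is not part of this diff. For intuition, here is a minimal sketch of the mechanism, under the assumption that a floating-point attention_mask is applied as an additive bias on the pre-softmax logits (as torch.nn.functional.scaled_dot_product_attention does with a float attn_mask); that is what lets info['attention_scale'] boost how strongly every query attends to background key tokens:

    import torch
    import torch.nn.functional as F

    def biased_attention(q, k, v, attention_mask=None):
        # q, k, v: (batch, heads, seq_len, head_dim); RoPE handling elided.
        # A float mask is added to q @ k^T / sqrt(d) before the softmax,
        # so a positive bias on a key position raises its attention weight.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

    B, H, L, D = 1, 2, 8, 16
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
    bias = torch.zeros(1, 1, 1, L)  # broadcasts over batch, heads and queries
    bias[..., 4:] = 1.0             # +1 logit bias toward the last 4 key tokens
    out = biased_attention(q, k, v, attention_mask=bias)  # (1, 2, 8, 16)

Since info['attention_scale'] is set to None when the slider is at 0 (see models/kv_edit.py below), the helper presumably skips the bias for a None mask, just as scaled_dot_product_attention does.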
models/kv_edit.py
CHANGED
@@ -76,6 +76,37 @@ class only_Flux(torch.nn.Module): # only contains the initialization function
         attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True # attend to the background region
 
         return attention_mask.unsqueeze(0)
+
+    def create_attention_scale(self, seq_len, mask_indices, text_len=512, device='cuda', scale=0):
+        """
+        Create a per-token attention scale.
+
+        Args:
+            seq_len (int): sequence length.
+            mask_indices (List[int]): indices of the masked region within the image tokens.
+            text_len (int): number of text tokens, 512 by default.
+            device (str): device type, e.g. 'cuda' or 'cpu'.
+
+        Returns:
+            torch.Tensor: attention scale of shape (1, 1, seq_len).
+        """
+        # initialize the scale to all zeros
+        attention_scale = torch.zeros(1, seq_len, dtype=torch.bfloat16, device=device) # broadcast when added
+
+        # text token indices
+        text_indices = torch.arange(0, text_len, device=device)
+
+        # token indices of the masked region
+        mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)
+
+        # token indices of the background region
+        all_indices = torch.arange(text_len, seq_len, device=device)
+        background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices])
+
+        attention_scale[0, background_token_indices] = scale
+        print(f"attention_scale:{scale}")
+
+        return attention_scale.unsqueeze(0)
 
 class Flux_kv_edit_inf(only_Flux):
     def __init__(self, device, name):
@@ -200,6 +231,13 @@ class Flux_kv_edit(only_Flux):
             inp_target["img"] = zt_noise[:, mask_indices, ...]
         else:
             inp_target["img"] = zt[:, mask_indices, ...]
+
+        if opts.attn_scale_value != 0:
+            attention_scale = self.create_attention_scale(L + 512, mask_indices, device=mask.device, scale=opts.attn_scale_value)
+            info['attention_scale'] = attention_scale
+        else:
+            info['attention_scale'] = None
+
 
         info['inverse'] = False
         x, _ = denoise_kv(self.model, **inp_target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)