Commit c4e6a63 · 1 Parent(s): 04d6ff1
add code

Files changed:
- README.md +1 -1
- gradio_app.py +204 -0
- main.py +150 -0
- requirements.txt +121 -0
- src/attention_based_segmentation.py +67 -0
- src/attention_utils.py +99 -0
- src/diffusion_model_wrapper.py +252 -0
- src/null_text_inversion.py +201 -0
- src/prompt_mixing.py +86 -0
- src/prompt_to_prompt_controllers.py +205 -0
- style.css +3 -0
- vocab.json +0 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
 colorTo: yellow
 sdk: gradio
 sdk_version: 3.23.0
-app_file:
+app_file: gradio_app.py
 pinned: false
 license: mit
 ---
gradio_app.py ADDED
@@ -0,0 +1,204 @@
from __future__ import annotations


import gradio as gr
import numpy as np
from PIL import Image

import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

from main import LPMConfig, main

DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models
This is a demo for our ''Localizing Object-level Shape Variations with Text-to-Image Diffusion Models'' [paper](https://arxiv.org/abs/2303.11306).
We introduce a method that generates object-level shape variation for a given image.
This demo allows using a real image as well as a generated image. For a real image, a matching prompt is required.
'''

def main_pipeline(
        prompt: str,
        object_of_interest: str,
        proxy_words: str,
        number_of_variations: int,
        start_prompt_range: int,
        end_prompt_range: int,
        objects_to_preserve: str,
        background_nouns: str,
        seed: int,
        input_image: str):
    prompt = prompt.replace(object_of_interest, '{word}')
    print(number_of_variations)
    print(proxy_words)
    proxy_words = proxy_words.split(',') if proxy_words != '' else []
    objects_to_preserve = objects_to_preserve.split(',') if objects_to_preserve != '' else []
    background_nouns = background_nouns.split(',') if background_nouns != '' else []
    args = LPMConfig(
        seed=seed,
        prompt=prompt,
        object_of_interest=object_of_interest,
        proxy_words=proxy_words,
        number_of_variations=number_of_variations,
        start_prompt_range=start_prompt_range,
        end_prompt_range=end_prompt_range,
        objects_to_preserve=objects_to_preserve,
        background_nouns=background_nouns,
        real_image_path="" if input_image is None else input_image
    )

    result_images, result_proxy_words = main(args)
    result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
    result_images = [(im * 255).astype(np.uint8) for im in result_images]
    result_images = [Image.fromarray(im) for im in result_images]

    return result_images, ",".join(result_proxy_words)


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input image (optional)",
                type="filepath"
            )
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                placeholder='A table below a lamp',
            )
            object_of_interest = gr.Text(
                label='Object of interest',
                max_lines=1,
                placeholder='lamp',
            )
            proxy_words = gr.Text(
                label='Proxy words - words used to obtain variations (a comma-separated list of words, can leave empty)',
                max_lines=1,
                placeholder=''
            )
            number_of_variations = gr.Slider(
                label='Number of variations (used only for automatic proxy-words)',
                minimum=2,
                maximum=30,
                value=20,
                step=1
            )
            start_prompt_range = gr.Slider(
                label='Number of steps before starting shape interval',
                minimum=0,
                maximum=50,
                value=7,
                step=1
            )
            end_prompt_range = gr.Slider(
                label='Number of steps before ending shape interval',
                minimum=1,
                maximum=50,
                value=17,
                step=1
            )
            objects_to_preserve = gr.Text(
                label='Words corresponding to objects to preserve (a comma-separated list of words, can leave empty)',
                max_lines=1,
                placeholder='table',
            )
            background_nouns = gr.Text(
                label='Words corresponding to objects that should be copied from original image (a comma-separated list of words, can leave empty)',
                max_lines=1,
                placeholder='',
            )
            seed = gr.Slider(
                label='Seed',
                minimum=1,
                maximum=100000,
                value=0,
                step=1
            )

            run_button = gr.Button('Generate')
        with gr.Column():
            result = gr.Gallery(label='Result').style(grid=4)
            proxy_words_result = gr.Text(label='Used proxy words')

    examples = [
        [
            "hamster eating watermelon on the beach",
            "watermelon",
            "",
            20,
            6,
            16,
            "",
            "hamster,beach",
            48,
            None
        ],
        [
            "A decorated lamp in the livingroom",
            "lamp",
            "",
            20,
            4,
            14,
            "livingroom",
            "",
            42,
            None
        ],
        [
            "a snake in the field eats an apple",
            "snake",
            "",
            20,
            7,
            17,
            "apple",
            "apple,field",
            10,
            None
        ]
    ]

    gr.Examples(examples=examples,
                inputs=[
                    prompt,
                    object_of_interest,
                    proxy_words,
                    number_of_variations,
                    start_prompt_range,
                    end_prompt_range,
                    objects_to_preserve,
                    background_nouns,
                    seed,
                    input_image
                ],
                outputs=[
                    result,
                    proxy_words_result
                ],
                fn=main_pipeline,
                cache_examples=False)


    inputs = [
        prompt,
        object_of_interest,
        proxy_words,
        number_of_variations,
        start_prompt_range,
        end_prompt_range,
        objects_to_preserve,
        background_nouns,
        seed,
        input_image
    ]
    outputs = [
        result,
        proxy_words_result
    ]
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

demo.queue(max_size=50).launch(share=False)
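
For clarity, a minimal illustration (with made-up values, not part of the commit) of how main_pipeline above normalizes the UI inputs before building an LPMConfig: the object of interest is replaced by the '{word}' placeholder and the comma-separated text fields become lists:

    # Illustrative values only; mirrors the first lines of main_pipeline.
    prompt = "a snake in the field eats an apple".replace("snake", "{word}")
    # -> "a {word} in the field eats an apple"
    background_nouns = "apple,field".split(',') if "apple,field" != '' else []
    # -> ["apple", "field"]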
main.py ADDED
@@ -0,0 +1,150 @@
import json
import os
from dataclasses import dataclass, field
from typing import List

import pyrallis
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.transforms import ToTensor
from tqdm import tqdm

from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
from src.null_text_inversion import invert_image
from src.prompt_utils import get_proxy_prompts
from src.prompt_mixing import PromptMixing
from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
    generate_original_image


def save_args_dict(args, similar_words):
    exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}")
    os.makedirs(exp_path, exist_ok=True)

    args_dict = vars(args)
    args_dict['similar_words'] = similar_words
    with open(os.path.join(exp_path, "opt.json"), 'w') as fp:
        json.dump(args_dict, fp, sort_keys=True, indent=4)

    return exp_path


def main(args):
    ldm_stable = get_stable_diffusion_model(args)
    ldm_stable_config = get_stable_diffusion_config(args)

    similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
    exp_path = save_args_dict(args, similar_words)

    images = []
    x_t = None
    uncond_embeddings = None

    if args.real_image_path != "":
        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)

    image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
    save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
    save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg")
    images.append(image[0])

    object_of_interest_index = args.prompt.split().index('{word}') + 1
    pm = PromptMixing(args, object_of_interest_index, average_attention)

    do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0
    do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0
    if do_other_obj_self_attn_masking:
        print("Do self attn other obj masking")
    if do_self_or_cross_attn_inject:
        print(f'Do self attn inject for {args.self_attn_inject_steps} steps')
        print(f'Do cross attn inject for {args.cross_attn_inject_steps} steps')

    another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False)

    for another_prompt_batch in tqdm(another_prompts_dataloader):
        batch_size = len(another_prompt_batch["word"])
        batch_prompts = prompts * batch_size
        batch_another_prompt = another_prompt_batch["prompt"]
        if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking:
            batch_prompts.append(prompts[0])
            batch_another_prompt.insert(0, prompts[0])

        if do_self_or_cross_attn_inject:
            controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device,
                                          ldm_stable_config["low_resource"], ldm_stable_config["num_diffusion_steps"],
                                          cross_replace_steps=args.cross_attn_inject_steps,
                                          self_replace_steps=args.self_attn_inject_steps)
        else:
            controller = AttentionStore(ldm_stable_config["low_resource"])

        diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, prompt_mixing=pm)
        with torch.no_grad():
            image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t, other_prompt=batch_another_prompt,
                                                                  post_background=args.background_post_process, orig_all_latents=orig_all_latents,
                                                                  orig_mask=orig_mask, uncond_embeddings=uncond_embeddings)

        for i in range(batch_size):
            image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i
            save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg")
            if mask is not None:
                save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg")
            images.append(image[image_index])

    images = [ToTensor()(image) for image in images]
    save_image(images, f"{exp_path}/grid.jpg", nrow=min(max([i for i in range(2, 8) if len(images) % i == 0]), 8))
    return images, similar_words


@dataclass
class LPMConfig:

    # general config
    seed: int = 10
    batch_size: int = 1
    exp_dir: str = "results"
    exp_name: str = ""
    display_images: bool = False
    gpu_id: int = 0

    # Stable Diffusion config
    auth_token: str = ""
    low_resource: bool = True
    num_diffusion_steps: int = 50
    guidance_scale: float = 7.5
    max_num_words: int = 77

    # prompt-mixing
    prompt: str = "a {word} in the field eats an apple"
    object_of_interest: str = "snake"  # The object for which we generate variations
    proxy_words: List[str] = field(default_factory=lambda: [])  # Leave empty for automatic proxy words
    number_of_variations: int = 20
    start_prompt_range: int = 7  # Number of steps before prompt-mixing begins
    end_prompt_range: int = 17  # Number of steps before prompt-mixing ends

    # attention based shape localization
    objects_to_preserve: List[str] = field(default_factory=lambda: [])  # Objects to which attention-based shape localization is applied
    remove_obj_from_self_mask: bool = True  # If set to True, removes the object of interest from the self-attention mask
    obj_pixels_injection_threshold: float = 0.05
    end_preserved_obj_self_attn_masking: int = 40

    # real image
    real_image_path: str = ""

    # controllable background preservation
    background_post_process: bool = True
    background_nouns: List[str] = field(default_factory=lambda: [])  # Objects to take from the original image in addition to the background
    num_segments: int = 5  # Number of clusters for the segmentation
    background_segment_threshold: float = 0.3  # Threshold for the segments labeling
    background_blend_timestep: int = 35  # Number of steps before background blending

    # other
    cross_attn_inject_steps: float = 0.0
    self_attn_inject_steps: float = 0.0


if __name__ == '__main__':
    args = pyrallis.parse(config_class=LPMConfig)

    print(args)
    main(args)
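
For reference, a minimal sketch (assumptions noted in comments, not part of the commit) of driving LPMConfig and main above without the Gradio UI; pyrallis should also expose the same dataclass fields as command-line flags when main.py is run directly:

    from main import LPMConfig, main

    # Illustrative configuration: '{word}' marks the slot of the object of interest.
    args = LPMConfig(
        prompt="a {word} in the field eats an apple",
        object_of_interest="snake",
        proxy_words=[],                      # empty -> proxy words are chosen automatically
        number_of_variations=20,
        start_prompt_range=7,                # prompt-mixing interval, in diffusion steps
        end_prompt_range=17,
        objects_to_preserve=["apple"],
        background_nouns=["apple", "field"],
        seed=10,
    )
    images, proxy_words = main(args)         # image tensors and the proxy words that were used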
requirements.txt ADDED
@@ -0,0 +1,121 @@
accelerate==0.18.0
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.2.1
attrs==22.2.0
backcall==0.2.0
backports.functools-lru-cache==1.6.4
beautifulsoup4==4.11.2
bleach==6.0.0
brotlipy==0.7.0
certifi==2022.12.7
cffi==1.15.1
chardet==4.0.0
charset-normalizer==2.0.4
click==8.1.3
comm==0.1.2
contourpy==1.0.5
cryptography==38.0.4
cycler==0.11.0
debugpy==1.5.1
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.10.2
entrypoints==0.4
executing==1.2.0
fastjsonschema==2.16.2
filelock==3.10.4
flit_core==3.6.0
fonttools==4.25.0
huggingface-hub==0.13.3
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.10.2
ipykernel==6.19.2
ipython==8.8.0
ipython-genutils==0.2.0
jedi==0.18.2
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.17.3
jupyter-client==7.3.4
jupyter_core==4.12.0
jupyter-server==1.23.5
jupyterlab-pygments==0.2.2
kiwisolver==1.4.4
MarkupSafe==2.1.2
matplotlib==3.6.2
matplotlib-inline==0.1.6
mistune==2.0.5
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
munkres==1.1.4
mypy-extensions==1.0.0
nbclassic==0.5.1
nbclient==0.7.2
nbconvert==7.2.9
nbformat==5.7.3
nest-asyncio==1.5.6
nltk==3.8.1
notebook==6.5.2
notebook_shim==0.2.2
numpy==1.23.5
opencv-python==4.7.0.72
packaging==23.0
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
pip==23.0.1
pkgutil_resolve_name==1.3.10
ply==3.11
prometheus-client==0.16.0
prompt-toolkit==3.0.36
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.14.0
pyOpenSSL==22.0.0
pyparsing==3.0.9
PyQt5-sip==12.11.0
pyrallis==0.3.1
pyrsistent==0.19.3
PySocks==1.7.1
python-dateutil==2.8.2
PyYAML==6.0
pyzmq==25.0.0
regex==2023.3.23
requests==2.28.1
scikit-learn==1.2.2
scipy==1.10.1
Send2Trash==1.8.0
setuptools==65.6.3
sip==6.6.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.3.2.post1
stack-data==0.6.2
terminado==0.17.1
threadpoolctl==3.1.0
tinycss2==1.2.1
tokenizers==0.13.2
toml==0.10.2
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tornado==6.2
tqdm==4.65.0
traitlets==5.7.1
transformers==4.25.1
typing_extensions==4.4.0
typing-inspect==0.8.0
urllib3==1.26.14
wcwidth==0.2.6
webencodings==0.5.1
websocket-client==1.5.1
wheel==0.37.1
zipp==3.11.0
src/attention_based_segmentation.py ADDED
@@ -0,0 +1,67 @@
import nltk
from sklearn.cluster import KMeans
import numpy as np

from src.attention_utils import aggregate_attention


class Segmentor:

    def __init__(self, controller, prompts, num_segments, background_segment_threshold, res=32, background_nouns=[]):
        self.controller = controller
        self.prompts = prompts
        self.num_segments = num_segments
        self.background_segment_threshold = background_segment_threshold
        self.resolution = res
        self.background_nouns = background_nouns

        self.self_attention = aggregate_attention(controller, res=32, from_where=("up", "down"), prompts=prompts,
                                                  is_cross=False, select=len(prompts) - 1)
        self.cross_attention = aggregate_attention(controller, res=16, from_where=("up", "down"), prompts=prompts,
                                                   is_cross=True, select=len(prompts) - 1)
        tokenized_prompt = nltk.word_tokenize(prompts[-1])
        self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt)) if pos[:2] == 'NN']

    def __call__(self, *args, **kwargs):
        clusters = self.cluster()
        cluster2noun = self.cluster2noun(clusters)
        return cluster2noun

    def cluster(self):
        np.random.seed(1)
        resolution = self.self_attention.shape[0]
        attn = self.self_attention.cpu().numpy().reshape(resolution ** 2, resolution ** 2)
        kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(attn)
        clusters = kmeans.labels_
        clusters = clusters.reshape(resolution, resolution)
        return clusters

    def cluster2noun(self, clusters):
        result = {}
        nouns_indices = [index for (index, word) in self.nouns]
        nouns_maps = self.cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
        normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
        for i in range(nouns_maps.shape[-1]):
            curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
            normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
        for c in range(self.num_segments):
            cluster_mask = np.zeros_like(clusters)
            cluster_mask[clusters == c] = 1
            score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
            scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
            result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > self.background_segment_threshold else "BG"
        return result

    def get_background_mask(self, obj_token_index):
        clusters = self.cluster()
        cluster2noun = self.cluster2noun(clusters)
        mask = clusters.copy()
        obj_segments = [c for c in cluster2noun if cluster2noun[c][0] == obj_token_index - 1]
        background_segments = [c for c in cluster2noun if cluster2noun[c] == "BG" or cluster2noun[c][1] in self.background_nouns]
        for c in range(self.num_segments):
            if c in background_segments and c not in obj_segments:
                mask[clusters == c] = 0
            else:
                mask[clusters == c] = 1
        return mask
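
For context, a usage sketch of this Segmentor, mirroring generate_original_image in src/diffusion_model_wrapper.py; controller and prompts are assumed to come from a finished diffusion pass, and the other values are illustrative:

    from src.attention_based_segmentation import Segmentor

    # `controller` is an AttentionStore filled during generation, `prompts` the prompt list.
    segmentor = Segmentor(controller, prompts,
                          num_segments=5,                      # LPMConfig defaults
                          background_segment_threshold=0.3,
                          background_nouns=["apple", "field"])
    prompt_template = "a {word} in the field eats an apple"
    obj_token_index = prompt_template.split(' ').index("{word}") + 1   # +1 for the start-of-text token
    mask = segmentor.get_background_mask(obj_token_index)              # 0 marks background segments, 1 the rest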
src/attention_utils.py ADDED
@@ -0,0 +1,99 @@
import torch
import numpy as np
from typing import Tuple, List
from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX
import matplotlib.pyplot as plt
from PIL import Image

from src.prompt_to_prompt_controllers import AttentionStore

def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], prompts, tokenizer, select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], prompts,
                             max_com=10, select: int = 0):
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape(
        (res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def display(image):
    global display_index
    plt.imshow(image)
    plt.show()
src/diffusion_model_wrapper.py ADDED
@@ -0,0 +1,252 @@
import torch
import numpy as np
from typing import Optional, List

from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm
from cv2 import dilate

from src.attention_utils import show_cross_attention
from src.attention_based_segmentation import Segmentor
from src.prompt_to_prompt_controllers import DummyController, AttentionStore


def get_stable_diffusion_model(args):
    device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
    if args.real_image_path != "":
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token, scheduler=scheduler).to(device)
    else:
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token).to(device)

    return ldm_stable

def get_stable_diffusion_config(args):
    return {
        "low_resource": args.low_resource,
        "num_diffusion_steps": args.num_diffusion_steps,
        "guidance_scale": args.guidance_scale,
        "max_num_words": args.max_num_words
    }


def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
    g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
    controller = AttentionStore(ldm_stable_config["low_resource"])
    diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu)
    image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts,
                                                                      latent=latent,
                                                                      uncond_embeddings=uncond_embeddings)
    orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold, background_nouns=args.background_nouns)\
        .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
    average_attention = controller.get_average_attention()
    return image, x_t, orig_all_latents, orig_mask, average_attention


class DiffusionModelWrapper:
    def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
        self.args = args
        self.model = model
        self.model_config = model_config
        self.controller = controller
        if self.controller is None:
            self.controller = DummyController()
        self.prompt_mixing = prompt_mixing
        self.device = model.device
        self.generator = generator

        self.height = 512
        self.width = 512

        self.diff_step = 0
        self.register_attention_control()


    def diffusion_step(self, latents, context, t, other_context=None):
        if self.model_config["low_resource"]:
            self.uncond_pred = True
            noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
            self.uncond_pred = False
            noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"]
        else:
            latents_input = torch.cat([latents] * 2)
            noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
        latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
        latents = self.controller.step_callback(latents)
        return latents


    def latent2image(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.model.vae.decode(latents)['sample']
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
        return image


    def init_latent(self, latent, batch_size):
        if latent is None:
            latent = torch.randn(
                (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
                generator=self.generator, device=self.model.device
            )
        latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device)
        return latent, latents


    def register_attention_control(self):
        def ca_forward(model_self, place_in_unet):
            to_out = model_self.to_out
            if type(to_out) is torch.nn.modules.container.ModuleList:
                to_out = model_self.to_out[0]
            else:
                to_out = model_self.to_out

            def forward(x, context=None, mask=None):
                batch_size, sequence_length, dim = x.shape
                h = model_self.heads
                q = model_self.to_q(x)
                is_cross = context is not None
                context = context if is_cross else (x, None)

                k = model_self.to_k(context[0])
                if is_cross and self.prompt_mixing is not None:
                    v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
                    v = model_self.to_v(v_context)
                else:
                    v = model_self.to_v(context[0])

                q = model_self.reshape_heads_to_batch_dim(q)
                k = model_self.reshape_heads_to_batch_dim(k)
                v = model_self.reshape_heads_to_batch_dim(v)

                sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale

                if mask is not None:
                    mask = mask.reshape(batch_size, -1)
                    max_neg_value = -torch.finfo(sim.dtype).max
                    mask = mask[:, None, :].repeat(h, 1, 1)
                    sim.masked_fill_(~mask, max_neg_value)

                # attention, what we cannot get enough of
                attn = sim.softmax(dim=-1)
                if self.enbale_attn_controller_changes:
                    attn = self.controller(attn, is_cross, place_in_unet)

                if is_cross and context[1] is not None and self.prompt_mixing is not None:
                    attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
                    attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = model_self.reshape_batch_dim_to_heads(out)
                return to_out(out)

            return forward

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'CrossAttention':
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.model.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
        self.controller.num_att_layers = cross_att_count


    def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        return text_embeddings, max_length


    @torch.no_grad()
    def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None,
                other_prompt: List[str] = None, post_background=False, orig_all_latents=None, orig_mask=None,
                uncond_embeddings=None, start_time=51, return_type='image'):
        self.enbale_attn_controller_changes = True
        batch_size = len(prompt)

        text_embeddings, max_length = self.get_text_embedding(prompt)
        if uncond_embeddings is None:
            uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False)
        else:
            uncond_embeddings_ = None

        other_context = None
        if other_prompt is not None:
            other_text_embeddings, _ = self.get_text_embedding(other_prompt)
            other_context = other_text_embeddings

        latent, latents = self.init_latent(latent, batch_size)

        # set timesteps
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
        all_latents = []

        object_mask = None
        self.diff_step = 0
        for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
            if uncond_embeddings_ is None:
                context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
            else:
                context = [uncond_embeddings_, text_embeddings]
            if not self.model_config["low_resource"]:
                context = torch.cat(context)

            self.down_cross_index = 0
            self.mid_cross_index = 0
            self.up_cross_index = 0
            latents = self.diffusion_step(latents, context, t, other_context)

            if post_background and self.diff_step == self.args.background_blend_timestep:
                object_mask = Segmentor(self.controller,
                                        prompt,
                                        self.args.num_segments,
                                        self.args.background_segment_threshold,
                                        background_nouns=self.args.background_nouns)\
                    .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
                self.enbale_attn_controller_changes = False
                mask = object_mask.astype(np.bool8) + orig_mask.astype(np.bool8)
                mask = torch.from_numpy(mask).float().cuda()
                shape = (1, 1, mask.shape[0], mask.shape[1])
                mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
                mask_eroded = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
                mask = torch.from_numpy(mask_eroded).float().cuda().view(1, 1, 64, 64)
                latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]

            all_latents.append(latents)
            self.diff_step += 1

        if return_type == 'image':
            image = self.latent2image(latents)
        else:
            image = latents

        return image, latent, all_latents, object_mask


    def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
        show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select)
src/null_text_inversion.py ADDED
@@ -0,0 +1,201 @@
from typing import Union

from torchvision.transforms import ToTensor
from torchvision.utils import save_image
from tqdm import tqdm
import torch
from torch.optim.adam import Adam
import torch.nn.functional as nnf
import numpy as np
from PIL import Image


def load_512(image_path, left=0, right=0, top=0, bottom=0):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w-1)
    right = min(right, w - left - 1)
    top = min(top, h - left - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h-bottom, left:w-right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image


def invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path):
    print("Start null text inversion")
    null_inversion = NullInversion(ldm_stable, ldm_stable_config)
    (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(args.real_image_path, prompts[0], offsets=(0,0,0,0), verbose=True)
    save_image(ToTensor()(image_gt), f"{exp_path}/real_image.jpg")
    save_image(ToTensor()(image_enc), f"{exp_path}/image_enc.jpg")
    print("End null text inversion")
    return x_t, uncond_embeddings


class NullInversion:

    def __init__(self, model, model_config):
        self.model = model
        self.model_config = model_config
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
        self.prompt = None
        self.context = None


    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else self.model_config["guidance_scale"]
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        with torch.no_grad():
            if type(image) is Image:
                image = np.array(image)
            if type(image) is torch.Tensor and image.dim() == 4:
                latents = image
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in tqdm(range(self.model_config["num_diffusion_steps"])):
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self, image):
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        with tqdm(total=num_inner_steps * (self.model_config["num_diffusion_steps"])) as bar:
            for i in range(self.model_config["num_diffusion_steps"]):
                uncond_embeddings = uncond_embeddings.clone().detach()
                uncond_embeddings.requires_grad = True
                optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
                latent_prev = latents[len(latents) - i - 2]
                t = self.model.scheduler.timesteps[i]
                with torch.no_grad():
                    noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
                for j in range(num_inner_steps):
                    noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                    noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
                    latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                    loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_item = loss.item()
                    bar.update()
                    if loss_item < epsilon + i * 2e-5:
                        break
                bar.update(num_inner_steps - j - 1)
                uncond_embeddings_list.append(uncond_embeddings[:1].detach())
                with torch.no_grad():
                    context = torch.cat([uncond_embeddings, cond_embeddings])
                    latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        # bar.close()
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
src/prompt_mixing.py ADDED
@@ -0,0 +1,86 @@
import torch
from scipy.signal import medfilt2d

class PromptMixing:
    def __init__(self, args, object_of_interest_index, avg_cross_attn=None):
        self.object_of_interest_index = object_of_interest_index
        self.objects_to_preserve = [args.prompt.split().index(o) + 1 for o in args.objects_to_preserve]
        self.obj_pixels_injection_threshold = args.obj_pixels_injection_threshold

        self.start_other_prompt_range = args.start_prompt_range
        self.end_other_prompt_range = args.end_prompt_range

        self.start_cross_attn_replace_range = args.num_diffusion_steps
        self.end_cross_attn_replace_range = args.num_diffusion_steps

        self.start_self_attn_replace_range = 0
        self.end_self_attn_replace_range = args.end_preserved_obj_self_attn_masking
        self.remove_obj_from_self_mask = args.remove_obj_from_self_mask
        self.avg_cross_attn = avg_cross_attn

        self.low_resource = args.low_resource

    def get_context_for_v(self, t, context, other_context):
        if other_context is not None and \
                self.start_other_prompt_range <= t < self.end_other_prompt_range:
            if self.low_resource:
                return other_context
            else:
                v_context = context.clone()
                # first half of context is for the unconditioned image
                v_context[v_context.shape[0]//2:] = other_context
                return v_context
        else:
            return context

    def get_cross_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
        if self.start_cross_attn_replace_range <= t < self.end_cross_attn_replace_range:
            if self.low_resource:
                attn[:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
                                                            0.8 * attn[:, :, self.object_of_interest_index]
            else:
                # first half of attn maps is for the unconditioned image
                min_h = attn.shape[0] // 2
                attn[min_h:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[min_h:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
                                                                 0.8 * attn[min_h:, :, self.object_of_interest_index]
        return attn

    def get_self_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
        if attn.shape[1] <= 32 ** 2 and \
                self.avg_cross_attn is not None and \
                self.start_self_attn_replace_range <= t < self.end_self_attn_replace_range:

            key = f"{place_in_unet}_cross"
            attn_index = getattr(diffusion_model_wrapper, f'{key}_index')
            cr = self.avg_cross_attn[key][attn_index]
            setattr(diffusion_model_wrapper, f'{key}_index', attn_index+1)

            if self.low_resource:
                attn = self.mask_self_attn_patches(attn, cr, batch_size)
            else:
                # first half of attn maps is for the unconditioned image
                attn[attn.shape[0]//2:] = self.mask_self_attn_patches(attn[attn.shape[0]//2:], cr, batch_size//2)

        return attn

    def mask_self_attn_patches(self, self_attn, cross_attn, batch_size):
        h = self_attn.shape[0] // batch_size
        tokens = self.objects_to_preserve
        obj_token = self.object_of_interest_index

        normalized_cross_attn = cross_attn - cross_attn.min()
        normalized_cross_attn /= normalized_cross_attn.max()

        mask = torch.zeros_like(self_attn[0])
        for tk in tokens:
            mask_tk_in = torch.unique((normalized_cross_attn[:, :, tk] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
            mask[mask_tk_in, :] = 1
            mask[:, mask_tk_in] = 1

        if self.remove_obj_from_self_mask:
            obj_patches = torch.unique((normalized_cross_attn[:, :, obj_token] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
            mask[obj_patches, :] = 0
            mask[:, obj_patches] = 0

        self_attn[h:] = self_attn[h:] * (1 - mask) + self_attn[:h].repeat(batch_size - 1, 1, 1) * mask
        return self_attn
src/prompt_to_prompt_controllers.py
ADDED
@@ -0,0 +1,205 @@
import torch
import numpy as np
import abc
from typing import Optional, Union, Tuple, Dict
import src.seq_aligner as seq_aligner


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if self.low_resource else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if self.low_resource:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self, low_resource):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.low_resource = low_resource


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class DummyController:

    def __call__(self, *args):
        return args[0]

    def __init__(self):
        self.num_att_layers = 0


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, low_resource):
        super(AttentionStore, self).__init__(low_resource)
        self.step_store = self.get_empty_store()
        self.attention_store = {}


class AttentionControlEdit(AttentionStore, abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_repalce = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
                        1 - alpha_words) * attn_repalce
                attn[1:] = attn_repalce_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]]):
        super(AttentionControlEdit, self).__init__(low_resource)
        self.batch_size = len(prompts)
        self.tokenizer = tokenizer
        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
                                                                  self.tokenizer).to(device)
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])


class AttentionReplace(AttentionControlEdit):

    def replace_cross_attention(self, attn_base, att_replace):
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper.to(attn_base.dtype))

    def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int, cross_replace_steps: float, self_replace_steps: float):
        super(AttentionReplace, self).__init__(prompts, tokenizer, device, low_resource, num_steps, cross_replace_steps, self_replace_steps)
        self.mapper = seq_aligner.get_replacement_mapper(prompts, self.tokenizer).to(device)


def get_word_inds(text: str, word_place: int, tokenizer):
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None):
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)  # time, batch, heads, pixels, words
    return alpha_time_words
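A possible usage sketch (not part of the commit) for the controllers above: building an AttentionReplace controller for two prompts that differ in a single word. The CLIP tokenizer checkpoint, the prompts, and the step fractions are assumptions; the diffusion wrapper is expected to register the controller on every attention layer, set controller.num_att_layers, and call controller(attn, is_cross, place_in_unet) during sampling.

import torch
from transformers import CLIPTokenizer
from src.prompt_to_prompt_controllers import AttentionReplace

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompts = ["a photo of a cat on a bench", "a photo of a dog on a bench"]

controller = AttentionReplace(prompts, tokenizer, device="cpu", low_resource=True,
                              num_steps=50, cross_replace_steps=0.8, self_replace_steps=0.4)
# cross-attention maps of prompt B are replaced by prompt A's maps (remapped token-wise)
# for the first 80% of the steps; self-attention is shared for the first 40% of the steps.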
src/prompt_utils.py
ADDED
@@ -0,0 +1,64 @@
import json
import torch
import numpy as np
from tqdm import tqdm


def get_topk_similar_words(model, prompt, base_word, vocab, k=30):
    text_input = model.tokenizer(
        [prompt.format(word=base_word)],
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        encoder_output = model.text_encoder(text_input.input_ids.to(model.device))
        full_prompt_embedding = encoder_output.pooler_output
        full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True)

    prompts = [prompt.format(word=word) for word in vocab]
    batch_size = 1000
    all_prompts_embeddings = []
    for i in tqdm(range(0, len(prompts), batch_size)):
        curr_prompts = prompts[i:i + batch_size]
        with torch.no_grad():
            text_input = model.tokenizer(
                curr_prompts,
                padding="max_length",
                max_length=model.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            curr_embeddings = model.text_encoder(text_input.input_ids.to(model.device)).pooler_output
            all_prompts_embeddings.append(curr_embeddings)

    all_prompts_embeddings = torch.cat(all_prompts_embeddings)
    all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True)
    prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1))
    sorted_prompts_similarities = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort())

    print(f"prompt: {prompt}")
    print(f"initial word: {base_word}")
    print(f"TOP {k} SIMILAR WORDS:")
    similar_words = [vocab[index] for index in sorted_prompts_similarities[:k]]
    print(similar_words)
    return similar_words


def get_proxy_words(args, ldm_stable):
    if len(args.proxy_words) > 0:
        return [args.object_of_interest] + args.proxy_words
    vocab = list(json.load(open("vocab.json")).keys())
    vocab = [word for word in vocab if word.isalpha() and len(word) > 1]
    filtered_vocab = get_topk_similar_words(ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50)
    proxy_words = get_topk_similar_words(ldm_stable, args.prompt, args.object_of_interest, filtered_vocab, k=args.number_of_variations)
    if proxy_words[0] != args.object_of_interest:
        proxy_words = [args.object_of_interest] + proxy_words

    return proxy_words


def get_proxy_prompts(args, ldm_stable):
    proxy_words = get_proxy_words(args, ldm_stable)
    prompts = [args.prompt.format(word=args.object_of_interest)]
    proxy_prompts = [{"word": word, "prompt": args.prompt.format(word=word)} for word in proxy_words]
    return proxy_words, prompts, proxy_prompts
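A usage sketch (not part of the commit) for get_proxy_prompts above, in the simple case where proxy words are supplied explicitly so the CLIP vocabulary search is skipped. SimpleNamespace stands in for the real config object and is an assumption; only the fields read by get_proxy_words/get_proxy_prompts are set.

from types import SimpleNamespace
from src.prompt_utils import get_proxy_prompts

args = SimpleNamespace(prompt="a photo of a {word} on a bench",
                       object_of_interest="cat",
                       proxy_words=["dog", "rabbit"],
                       number_of_variations=3)

# ldm_stable is only needed when proxy_words is empty (vocabulary search), so None is fine here
proxy_words, prompts, proxy_prompts = get_proxy_prompts(args, ldm_stable=None)
# proxy_words -> ["cat", "dog", "rabbit"]
# prompts     -> ["a photo of a cat on a bench"]
# proxy_prompts -> one {"word", "prompt"} dict per proxy word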
src/seq_aligner.py
ADDED
@@ -0,0 +1,195 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np


class ScoreParams:

    def __init__(self, gap, match, mismatch):
        self.gap = gap
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        if x != y:
            return self.mismatch
        else:
            return self.match


def get_matrix(size_x, size_y, gap):
    matrix = []
    for i in range(len(size_x) + 1):
        sub_matrix = []
        for j in range(len(size_y) + 1):
            sub_matrix.append(0)
        matrix.append(sub_matrix)
    for j in range(1, len(size_y) + 1):
        matrix[0][j] = j * gap
    for i in range(1, len(size_x) + 1):
        matrix[i][0] = i * gap
    return matrix


def get_matrix(size_x, size_y, gap):
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
    return matrix


def get_traceback_matrix(size_x, size_y):
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = 1
    matrix[1:, 0] = 2
    matrix[0, 0] = 4
    return matrix


def global_align(x, y, score):
    matrix = get_matrix(len(x), len(y), score.gap)
    trace_back = get_traceback_matrix(len(x), len(y))
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            left = matrix[i, j - 1] + score.gap
            up = matrix[i - 1, j] + score.gap
            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
            matrix[i, j] = max(left, up, diag)
            if matrix[i, j] == left:
                trace_back[i, j] = 1
            elif matrix[i, j] == up:
                trace_back[i, j] = 2
            else:
                trace_back[i, j] = 3
    return matrix, trace_back


def get_aligned_sequences(x, y, trace_back):
    x_seq = []
    y_seq = []
    i = len(x)
    j = len(y)
    mapper_y_to_x = []
    while i > 0 or j > 0:
        if trace_back[i, j] == 3:
            x_seq.append(x[i - 1])
            y_seq.append(y[j - 1])
            i = i - 1
            j = j - 1
            mapper_y_to_x.append((j, i))
        elif trace_back[i][j] == 1:
            x_seq.append('-')
            y_seq.append(y[j - 1])
            j = j - 1
            mapper_y_to_x.append((j, -1))
        elif trace_back[i][j] == 2:
            x_seq.append(x[i - 1])
            y_seq.append('-')
            i = i - 1
        elif trace_back[i][j] == 4:
            break
    mapper_y_to_x.reverse()
    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)


def get_mapper(x: str, y: str, tokenizer, max_len=77):
    x_seq = tokenizer.encode(x)
    y_seq = tokenizer.encode(y)
    score = ScoreParams(0, 1, -1)
    matrix, trace_back = global_align(x_seq, y_seq, score)
    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
    alphas = torch.ones(max_len)
    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
    return mapper, alphas


def get_refinement_mapper(prompts, tokenizer, max_len=77):
    x_seq = prompts[0]
    mappers, alphas = [], []
    for i in range(1, len(prompts)):
        mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
        mappers.append(mapper)
        alphas.append(alpha)
    return torch.stack(mappers), torch.stack(alphas)


def get_word_inds(text: str, word_place: int, tokenizer):
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    words_x = x.split(' ')
    words_y = y.split(' ')
    if len(words_x) != len(words_y):
        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                mapper[inds_source_, inds_target_] = 1
            else:
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_:
                    mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        elif cur_inds < len(inds_source):
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            mapper[j, j] = 1
            i += 1
            j += 1

    return torch.from_numpy(mapper).float()


def get_replacement_mapper(prompts, tokenizer, max_len=77):
    x_seq = prompts[0]
    mappers = []
    for i in range(1, len(prompts)):
        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
        mappers.append(mapper)
    return torch.stack(mappers)
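A sketch (not part of the commit) of how the aligners above are typically called: a word-for-word replacement mapper for same-length prompts, and a refinement alignment when the edited prompt adds tokens. The CLIP tokenizer checkpoint and the example prompts are assumptions.

from transformers import CLIPTokenizer
from src.seq_aligner import get_refinement_mapper, get_replacement_mapper

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# word-for-word replacement: prompts must have the same number of words
rep = get_replacement_mapper(["a photo of a cat", "a photo of a dog"], tokenizer)
print(rep.shape)    # torch.Size([1, 77, 77]) -- one token-mapping matrix per edited prompt

# refinement: the second prompt adds tokens, so a global alignment (mapper, alphas) is computed
mapper, alphas = get_refinement_mapper(["a photo of a cat", "a photo of a small fluffy cat"], tokenizer)
print(mapper.shape, alphas.shape)   # torch.Size([1, 77]) torch.Size([1, 77])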
style.css
ADDED
@@ -0,0 +1,3 @@
h1 {
  text-align: center;
}
vocab.json
ADDED
The diff for this file is too large to render.