Upload folder using huggingface_hub
- .gitattributes +8 -0
- README.md +131 -12
- app.py +98 -0
- assets/0.jpg +0 -0
- assets/2.jpg +0 -0
- assets/3.jpg +0 -0
- assets/comparison.png +3 -0
- assets/example1.png +3 -0
- assets/example2.png +3 -0
- assets/example3.png +3 -0
- assets/pipe.png +0 -0
- assets/subtraction.png +3 -0
- assets/tree.png +3 -0
- flagged/Style Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png +3 -0
- flagged/log.csv +2 -0
- ip_adapter/__init__.py +9 -0
- ip_adapter/__pycache__/__init__.cpython-310.pyc +0 -0
- ip_adapter/__pycache__/attention_processor.cpython-310.pyc +0 -0
- ip_adapter/__pycache__/ip_adapter.cpython-310.pyc +0 -0
- ip_adapter/__pycache__/resampler.cpython-310.pyc +0 -0
- ip_adapter/__pycache__/utils.cpython-310.pyc +0 -0
- ip_adapter/attention_processor.py +558 -0
- ip_adapter/ip_adapter.py +471 -0
- ip_adapter/resampler.py +158 -0
- ip_adapter/utils.py +93 -0
- models/image_encoder/config.json +81 -0
- models/image_encoder/model.safetensors +3 -0
- models/image_encoder/pytorch_model.bin +3 -0
- models/ip-adapter_sdxl.bin +3 -0
- result.png +3 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/comparison.png filter=lfs diff=lfs merge=lfs -text
+assets/example1.png filter=lfs diff=lfs merge=lfs -text
+assets/example2.png filter=lfs diff=lfs merge=lfs -text
+assets/example3.png filter=lfs diff=lfs merge=lfs -text
+assets/subtraction.png filter=lfs diff=lfs merge=lfs -text
+assets/tree.png filter=lfs diff=lfs merge=lfs -text
+flagged/Style[[:space:]]Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png filter=lfs diff=lfs merge=lfs -text
+result.png filter=lfs diff=lfs merge=lfs -text
```
README.md
CHANGED

<div align="center">
<h1>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>

[**Haofan Wang**](https://haofanwang.github.io/)<sup>*</sup> · [Matteo Spinelli](https://github.com/cubiq) · [**Qixun Wang**](https://github.com/wangqixun) · [**Xu Bai**](https://huggingface.co/baymin0220) · [**Zekui Qin**](https://github.com/ZekuiQin) · [**Anthony Chen**](https://antonioo-c.github.io/)

InstantX Team

<sup>*</sup>corresponding authors

<a href='https://instantstyle.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2404.02733'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
[](https://github.com/InstantStyle/InstantStyle)

</div>

InstantStyle is a general framework that employs two straightforward yet potent techniques to achieve an effective disentanglement of style and content from reference images.

<img src='assets/pipe.png'>

## Principle

Separating Content from Image. Benefiting from the good characterization of CLIP global features, after subtracting the content text features from the image features, the style and content can be explicitly decoupled. Although simple, this strategy is quite effective in mitigating content leakage.
<p align="center">
  <img src="assets/subtraction.png">
</p>
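To make the subtraction concrete, here is a minimal editorial sketch, not part of the original README or repository, of decoupling style from content with global CLIP features. The checkpoint name `openai/clip-vit-large-patch14` and the unit subtraction scale are illustrative assumptions; inside this repository the same idea is implemented in `IPAdapter.get_image_embeds` via `content_prompt_embeds`.

```python
# Editorial sketch of "separating content from image" with global CLIP features.
# Assumptions: openai/clip-vit-large-patch14 as the encoder and a scale of 1.0.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

style_image = Image.open("./assets/0.jpg")
content_prompt = "a cat"  # the content we want to remove from the style reference

with torch.no_grad():
    image_embeds = model.get_image_features(
        **processor(images=style_image, return_tensors="pt")
    )  # global image feature, shape (1, 768)
    text_embeds = model.get_text_features(
        **processor(text=[content_prompt], return_tensors="pt", padding=True)
    )  # global text feature, shape (1, 768)

# Subtracting the content text feature leaves a style-dominant global embedding
# that an IP-Adapter consuming global embeddings can use as its image prompt.
style_embeds = image_embeds - 1.0 * text_embeds
```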
Injecting into Style Blocks Only. Empirically, each layer of a deep network captures different semantic information. The key observation in our work is that there exist two specific attention layers handling style: we find that up_blocks.0.attentions.1 and down_blocks.2.attentions.1 capture style (color, material, atmosphere) and spatial layout (structure, composition), respectively.
<p align="center">
  <img src="assets/tree.png">
</p>

## Release
- [2024/04/03] 🔥 We release the [technical report](https://arxiv.org/abs/2404.02733).

## Download
Follow [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter?tab=readme-ov-file#download-models) to download pre-trained checkpoints.

## Demos

### Stylized Synthesis

<p align="center">
  <img src="assets/example1.png">
  <img src="assets/example2.png">
</p>

### Image-based Stylized Synthesis

<p align="center">
  <img src="assets/example3.png">
</p>

### Comparison with Previous Works

<p align="center">
  <img src="assets/comparison.png">
</p>

## Usage

Our method is fully compatible with [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter). Note that feature subtraction only works with IP-Adapter variants that use global image embeddings.

```python
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image

from ip_adapter import IPAdapterXL

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
device = "cuda"

# load SDXL pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# load ip-adapter
# target_blocks=["blocks"] for the original IP-Adapter
# target_blocks=["up_blocks.0.attentions.1"] for style blocks only
# target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] for style+layout blocks
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])

image = Image.open("./assets/0.jpg")
image = image.resize((512, 512))  # resize returns a new image, so reassign it

# generate image variations with only an image prompt
images = ip_model.generate(pil_image=image,
                           prompt="a cat, masterpiece, best quality, high quality",
                           negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                           scale=1.0,
                           guidance_scale=5,
                           num_samples=1,
                           num_inference_steps=30,
                           seed=42,
                           # neg_content_prompt="a rabbit",
                           # neg_content_scale=0.5,
                           )

images[0].save("result.png")
```

We will support the diffusers API soon.

## TODO
- Support the diffusers API.
- Support InstantID.

## Sponsor Us
If you find this project useful, you can buy us a coffee via GitHub Sponsors! We support [PayPal](https://ko-fi.com/instantx) and [WeChat Pay](https://tinyurl.com/instantx-pay).

## Cite
If you find InstantStyle useful for your research and applications, please cite us using this BibTeX:

```bibtex
@misc{wang2024instantstyle,
      title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
      author={Haofan Wang and Qixun Wang and Xu Bai and Zekui Qin and Anthony Chen},
      year={2024},
      eprint={2404.02733},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

For any question, please feel free to contact us via [email protected].
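The commented-out `neg_content_prompt` / `neg_content_scale` arguments in the Usage example above are how the feature subtraction from the Principle section is exposed at generation time: `ip_adapter/ip_adapter.py` encodes the content prompt and subtracts its pooled embedding from the CLIP image embedding. A minimal variation of the same call, with the illustrative values from the comments, looks like this:

```python
# Same setup as the Usage example above; only the generate() call changes.
# The content prompt and scale are the illustrative values from the comments.
images = ip_model.generate(
    pil_image=image,
    prompt="a cat, masterpiece, best quality, high quality",
    scale=1.0,
    guidance_scale=5,
    num_samples=1,
    num_inference_steps=30,
    seed=42,
    neg_content_prompt="a rabbit",  # content to subtract from the style reference
    neg_content_scale=0.5,          # how strongly the content embedding is removed
)
```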
app.py
ADDED
```python
import os

download_repo_loc = "./models/image_encoder/"
os.system("pip install -U peft")
# os.system(f"wget -O {download_repo_loc}config.json https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/config.json?download=true")
# os.system(f"wget -O {download_repo_loc}model.safetensors https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors?download=true")
# os.system(f"wget -O {download_repo_loc}pytorch_model.bin https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/pytorch_model.bin?download=true")

import spaces  # the @spaces.GPU decorator below requires the "spaces" package
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from ip_adapter import IPAdapterXL

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda"

image_encoder_path = download_repo_loc  # "sdxl_models/image_encoder"
ip_ckpt = "./models/ip-adapter_sdxl.bin"

# load SDXL pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)


# generate image variations with only an image prompt
@spaces.GPU(enable_queue=True)
def create_image(image_pil, target, prompt, n_prompt, scale, guidance_scale, num_samples, num_inference_steps, seed):
    # load ip-adapter with the attention blocks selected in the UI
    if target == "Load original IP-Adapter":
        # target_blocks=["blocks"] for the original IP-Adapter
        ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"])
    elif target == "Load only style blocks":
        # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
        ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])
    elif target == "Load style+layout block":
        # target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] for style+layout blocks
        ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"])

    image_pil = image_pil.resize((512, 512))
    images = ip_model.generate(pil_image=image_pil,
                               prompt=prompt,
                               negative_prompt=n_prompt,
                               scale=scale,
                               guidance_scale=guidance_scale,
                               num_samples=int(num_samples),  # sliders deliver floats; the pipeline expects ints
                               num_inference_steps=int(num_inference_steps),
                               seed=int(seed),
                               # neg_content_prompt="a rabbit",
                               # neg_content_scale=0.5,
                               )

    # images[0].save("result.png")
    del ip_model

    return images


DESCRIPTION = """
# InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation
**Demo by [ameer azam] - [Twitter](https://twitter.com/Ameerazam18) - [GitHub](https://github.com/AMEERAZAM08) - [Hugging Face](https://huggingface.co/ameerazam08)**
This is a demo of [InstantStyle](https://github.com/InstantStyle/InstantStyle), a style-preserving text-to-image framework built on IP-Adapter and SDXL.
"""

block = gr.Blocks(css="footer {visibility: hidden}").queue()
with block:
    with gr.Row():

        with gr.Column():
            gr.Markdown("## <h1 align='center'>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation </h1>")
            gr.Markdown(DESCRIPTION)
            with gr.Tabs():
                with gr.Row():
                    with gr.Column():
                        image_pil = gr.Image(label="Style Image", type='pil')
                        target = gr.Dropdown(["Load original IP-Adapter", "Load only style blocks", "Load style+layout block"], label="Adapter Variant", info="Which attention blocks should receive the image prompt?")
                        prompt = gr.Textbox(label="Prompt", value="a cat, masterpiece, best quality, high quality")
                        n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
                        scale = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label="scale")
                        guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=5.0, label="guidance_scale")
                        num_samples = gr.Slider(minimum=1, maximum=3.0, step=1.0, value=1.0, label="num_samples")
                        num_inference_steps = gr.Slider(minimum=5, maximum=50.0, step=1.0, value=30, label="num_inference_steps")
                        seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed Value")
                        generate_button = gr.Button("Generate Image")
                    with gr.Column():
                        generated_image = gr.Gallery(label="Generated Image")

        generate_button.click(fn=create_image,
                              inputs=[image_pil, target, prompt, n_prompt, scale, guidance_scale, num_samples, num_inference_steps, seed],
                              outputs=[generated_image])

block.launch(max_threads=10)
```
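For completeness, a hypothetical smoke test of the callback above that bypasses the Gradio UI. It assumes a CUDA GPU and that importing `app` is acceptable despite its heavyweight module-level setup (the `pip install` call and SDXL loading); none of this is part of the Space itself.

```python
# Hypothetical local check of create_image(); not part of the Space.
from PIL import Image

from app import create_image  # triggers the module-level model loading in app.py

style = Image.open("./assets/0.jpg")
images = create_image(
    style,
    "Load only style blocks",
    "a cat, masterpiece, best quality, high quality",
    "text, watermark, lowres, low quality, worst quality",
    1.0,   # scale
    5.0,   # guidance_scale
    1,     # num_samples
    30,    # num_inference_steps
    42,    # seed
)
images[0].save("result.png")
```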
assets/0.jpg
ADDED (binary image)
assets/2.jpg
ADDED (binary image)
assets/3.jpg
ADDED (binary image)
assets/comparison.png
ADDED (binary image, Git LFS)
assets/example1.png
ADDED (binary image, Git LFS)
assets/example2.png
ADDED (binary image, Git LFS)
assets/example3.png
ADDED (binary image, Git LFS)
assets/pipe.png
ADDED (binary image)
assets/subtraction.png
ADDED (binary image, Git LFS)
assets/tree.png
ADDED (binary image, Git LFS)
flagged/Style Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png
ADDED (binary image, Git LFS)
flagged/log.csv
ADDED
```csv
Style Image,Prompt,Negative Prompt,Scale,guidance_scale,num_samples,num_inference_steps,Seed Value,Processed Image,flag,username,timestamp
/home/rnd/Documents/Ameer/InstantStyle/flagged/Style Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png,dfgdfgdf,"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",1,5,1,30,1,,,,2024-04-05 00:34:42.130755
```
ip_adapter/__init__.py
ADDED
```python
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull

__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "IPAdapterFull",
]
```
ip_adapter/__pycache__/__init__.cpython-310.pyc
ADDED (binary file, 313 Bytes)
ip_adapter/__pycache__/attention_processor.cpython-310.pyc
ADDED (binary file, 9.93 kB)
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc
ADDED (binary file, 11.4 kB)
ip_adapter/__pycache__/resampler.cpython-310.pyc
ADDED (binary file, 4.26 kB)
ip_adapter/__pycache__/utils.cpython-310.pyc
ADDED (binary file, 2.83 kB)
ip_adapter/attention_processor.py
ADDED (+558 lines)
```python
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image-prompt (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0).
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image-prompt (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            # attention map over the image-prompt tokens, kept only for visualization
            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
            # print(self.attn_map.shape)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
```
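A small, self-contained illustration of the token split that `IPAttnProcessor` and `IPAttnProcessor2_0` perform on `encoder_hidden_states`: the last `num_tokens` entries are the image-prompt tokens and everything before them is the text prompt. The tensor sizes below are illustrative assumptions (77 text tokens, hidden size 2048), not values taken from the repository.

```python
# Editorial illustration of the text/image token split used by the processors above.
import torch

num_tokens = 4  # 4 for the global-embedding adapter, 16 for ip_adapter_plus
encoder_hidden_states = torch.randn(2, 77 + num_tokens, 2048)  # (batch, seq, dim), sizes assumed

end_pos = encoder_hidden_states.shape[1] - num_tokens
text_tokens = encoder_hidden_states[:, :end_pos, :]  # routed through attn.to_k / attn.to_v
ip_tokens = encoder_hidden_states[:, end_pos:, :]    # routed through to_k_ip / to_v_ip

assert text_tokens.shape[1] == 77 and ip_tokens.shape[1] == num_tokens
```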
ip_adapter/ip_adapter.py
ADDED (+471 lines; the listing below ends where the captured diff is cut off)
```python
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .utils import is_torch2_available, get_generator

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from .attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from .attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MLPProjModel(torch.nn.Module):
    """SD model with image prompt"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim)
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["blocks"]):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        self.target_blocks = target_blocks

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                selected = False
                for block_name in self.target_blocks:
                    if block_name in name:
                        selected = True
                        break
                if selected:
                    attn_procs[name] = IPAttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.num_tokens,
                    ).to(self.device, dtype=torch.float16)
                else:
                    attn_procs[name] = AttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                    ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        if neg_content_prompt is not None:
            with torch.inference_mode():
                (
                    prompt_embeds_,  # torch.Size([1, 77, 2048])
                    negative_prompt_embeds_,
                    pooled_prompt_embeds_,  # torch.Size([1, 1280])
                    negative_pooled_prompt_embeds_,
                ) = self.pipe.encode_prompt(
                    neg_content_prompt,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )
                pooled_prompt_embeds_ *= neg_content_scale
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        if neg_content_prompt is not None:
            with torch.inference_mode():
                (
                    prompt_embeds_,  # torch.Size([1, 77, 2048])
                    negative_prompt_embeds_,
                    pooled_prompt_embeds_,  # torch.Size([1, 1280])
                    negative_pooled_prompt_embeds_,
                ) = self.pipe.encode_prompt(
                    neg_content_prompt,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )
                pooled_prompt_embeds_ *= neg_content_scale
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
```
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
426 |
+
|
427 |
+
if prompt is None:
|
428 |
+
prompt = "best quality, high quality"
|
429 |
+
if negative_prompt is None:
|
430 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
431 |
+
|
432 |
+
if not isinstance(prompt, List):
|
433 |
+
prompt = [prompt] * num_prompts
|
434 |
+
if not isinstance(negative_prompt, List):
|
435 |
+
negative_prompt = [negative_prompt] * num_prompts
|
436 |
+
|
437 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
438 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
439 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
440 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
441 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
442 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
443 |
+
|
444 |
+
with torch.inference_mode():
|
445 |
+
(
|
446 |
+
prompt_embeds,
|
447 |
+
negative_prompt_embeds,
|
448 |
+
pooled_prompt_embeds,
|
449 |
+
negative_pooled_prompt_embeds,
|
450 |
+
) = self.pipe.encode_prompt(
|
451 |
+
prompt,
|
452 |
+
num_images_per_prompt=num_samples,
|
453 |
+
do_classifier_free_guidance=True,
|
454 |
+
negative_prompt=negative_prompt,
|
455 |
+
)
|
456 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
457 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
458 |
+
|
459 |
+
generator = get_generator(seed, self.device)
|
460 |
+
|
461 |
+
images = self.pipe(
|
462 |
+
prompt_embeds=prompt_embeds,
|
463 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
464 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
465 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
466 |
+
num_inference_steps=num_inference_steps,
|
467 |
+
generator=generator,
|
468 |
+
**kwargs,
|
469 |
+
).images
|
470 |
+
|
471 |
+
return images
|
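For orientation, here is a minimal usage sketch of the `IPAdapterXL` class above. The SDXL base model ID, the device string, and the sampling settings are assumptions for illustration; the paths point at files uploaded in this repo, and the exact `IPAdapter.__init__` signature is defined earlier in `ip_adapter.py`, outside this excerpt.

```python
# Usage sketch (assumed model ID, device, and settings; adjust to your setup).
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image

from ip_adapter import IPAdapterXL

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base model
    torch_dtype=torch.float16,
).to("cuda")

# Image encoder and adapter weights uploaded in this repo.
ip_model = IPAdapterXL(pipe, "models/image_encoder", "models/ip-adapter_sdxl.bin", "cuda")

style_image = Image.open("assets/0.jpg")  # any style reference image from assets/
images = ip_model.generate(
    pil_image=style_image,
    prompt="a cat, masterpiece, best quality, high quality",
    scale=1.0,
    num_samples=1,
    num_inference_steps=30,
    seed=42,
)
images[0].save("result.png")
```

When the style reference contains objects that should not leak into the output, `neg_content_prompt` (together with `neg_content_scale`) can be passed to `generate`; as the code above shows, its pooled text embedding is scaled and handed to `get_image_embeds` as `content_prompt_embeds` so the content of the reference can be suppressed while its style is kept.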
ip_adapter/resampler.py
ADDED
@@ -0,0 +1,158 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keep (bs, n_heads, length, dim_per_head), made contiguous for the attention matmul
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
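A quick shape check for the `Resampler` above. The dimensions mirror the SDXL configuration in `IPAdapterPlusXL.init_proj` (1280-dim latents, 1664-dim CLIP hidden states, 2048-dim UNet cross-attention); the batch size and `num_queries=16` are illustrative assumptions standing in for `self.num_tokens`.

```python
# Standalone shape sanity check for the Resampler (sizes are assumptions, see above).
import torch
from ip_adapter.resampler import Resampler

resampler = Resampler(
    dim=1280,
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=16,        # stands in for self.num_tokens in the Plus variants
    embedding_dim=1664,    # hidden_size of the uploaded image encoder (see its config.json)
    output_dim=2048,       # SDXL UNet cross-attention dim
    ff_mult=4,
)

clip_hidden_states = torch.randn(2, 257, 1664)  # penultimate CLIP hidden states (256 patches + CLS)
tokens = resampler(clip_hidden_states)
print(tokens.shape)  # -> torch.Size([2, 16, 2048])
```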
ip_adapter/utils.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

attn_maps = {}


def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook


def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet


def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps


def attnmaps2images(net_attn_maps):
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        image = Image.fromarray(normalized_attn_map)
        images.append(image)

    return images


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator
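The `get_generator` helper above drives reproducible sampling in every `generate` method: it returns one generator per seed when a list is passed, a single seeded generator for an int, and `None` when no seed is given. A small sketch (the device string is an arbitrary choice):

```python
import torch
from ip_adapter.utils import get_generator

print(get_generator(None, "cpu"))   # -> None: the pipeline samples non-deterministically
print(get_generator(42, "cpu"))     # -> a torch.Generator seeded with 42
print([g.initial_seed() for g in get_generator([1, 2, 3], "cpu")])  # -> [1, 2, 3]
```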
models/image_encoder/config.json
ADDED
@@ -0,0 +1,81 @@
{
  "architectures": [
    "CLIPVisionModelWithProjection"
  ],
  "_name_or_path": "",
  "add_cross_attention": false,
  "architectures": null,
  "attention_dropout": 0.0,
  "bad_words_ids": null,
  "begin_suppress_tokens": null,
  "bos_token_id": null,
  "chunk_size_feed_forward": 0,
  "cross_attention_hidden_size": null,
  "decoder_start_token_id": null,
  "diversity_penalty": 0.0,
  "do_sample": false,
  "dropout": 0.0,
  "early_stopping": false,
  "encoder_no_repeat_ngram_size": 0,
  "eos_token_id": null,
  "exponential_decay_length_penalty": null,
  "finetuning_task": null,
  "forced_bos_token_id": null,
  "forced_eos_token_id": null,
  "hidden_act": "gelu",
  "hidden_size": 1664,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "min_length": 0,
  "model_type": "clip_vision_model",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 16,
  "num_beam_groups": 1,
  "num_beams": 1,
  "num_channels": 3,
  "num_hidden_layers": 48,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_scores": false,
  "pad_token_id": null,
  "patch_size": 14,
  "prefix": null,
  "problem_type": null,
  "pruned_heads": {},
  "remove_invalid_values": false,
  "repetition_penalty": 1.0,
  "return_dict": true,
  "return_dict_in_generate": false,
  "sep_token_id": null,
  "suppress_tokens": null,
  "task_specific_params": null,
  "temperature": 1.0,
  "tf_legacy_loss": false,
  "tie_encoder_decoder": false,
  "tie_word_embeddings": true,
  "tokenizer_class": null,
  "top_k": 50,
  "top_p": 1.0,
  "torch_dtype": null,
  "torchscript": false,
  "transformers_version": "4.24.0",
  "typical_p": 1.0,
  "use_bfloat16": false,
  "projection_dim": 1280
}
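The config describes a large CLIP vision tower with a projection head (48 layers, hidden size 1664, projection dim 1280), loadable as `CLIPVisionModelWithProjection`. A hedged loading sketch using the local path from this repo:

```python
# Loading sketch for the uploaded image encoder (local path as in this repo).
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "models/image_encoder", torch_dtype=torch.float16
)
processor = CLIPImageProcessor()

# Values should match the config above: 1664 and 1280.
print(image_encoder.config.hidden_size, image_encoder.config.projection_dim)
```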
models/image_encoder/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:657723e09f46a7c3957df651601029f66b1748afb12b419816330f16ed45d64d
size 3689912664
models/image_encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2999562fbc02f9dc0d9c0acb7cf0970ec3a9b2a578d7d05afe82191d606d2d80
size 3690112753
models/ip-adapter_sdxl.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7525f2731e9e86d1368e0b68467615d55dda459691965bdd7d37fa3d7fd84c12
size 702585097
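A short inspection sketch for the adapter checkpoint above. The two-key layout (`image_proj` / `ip_adapter`) follows the convention used by `IPAdapter.load_ip_adapter` in the upstream IP-Adapter code; it is stated here as an assumption, not something verified from this diff.

```python
import torch

# Inspect the uploaded IP-Adapter SDXL checkpoint (assumed two-key layout).
state_dict = torch.load("models/ip-adapter_sdxl.bin", map_location="cpu")
print(list(state_dict.keys()))        # expected: ['image_proj', 'ip_adapter']
print(len(state_dict["ip_adapter"]))  # number of attention-processor tensors
```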
result.png
ADDED
(binary image; preview omitted, file tracked with Git LFS)