JB-Bai committed on
Commit
7217432
·
1 Parent(s): e8fe280
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ .vscode
2
+ __pycache__/
3
+ *.pyc
4
+
5
+ data
6
+ checkpoint
7
+ eval
8
+
9
+ outputs
10
+ wandb
app.py ADDED
@@ -0,0 +1,212 @@
1
+ import gradio as gr
2
+ import torch
3
+ from src.transformer import SymmetricTransformer2DModel
4
+ from src.pipeline import UnifiedPipeline
5
+ from src.scheduler import Scheduler
6
+ from torchvision import transforms
7
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
8
+ from diffusers import VQModel
9
+ import os
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+
14
+ def load_models(model_path="MeissonFlow/Meissonic",
15
+ transformer_path="MeissonFlow/Muddit/1024",
16
+ device="cuda"):
17
+ model = SymmetricTransformer2DModel.from_pretrained(
18
+ transformer_path or model_path,
19
+ subfolder="transformer",
20
+ )
21
+ vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae")
22
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder")
23
+ tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
24
+ scheduler = Scheduler.from_pretrained(model_path, subfolder="scheduler")
25
+
26
+ pipe = UnifiedPipeline(
27
+ vqvae=vq_model,
28
+ tokenizer=tokenizer,
29
+ text_encoder=text_encoder,
30
+ transformer=model,
31
+ scheduler=scheduler,
32
+ )
33
+ pipe.to(device)
34
+ return pipe
35
+
36
+ # Load models (global variable to avoid reloading)
37
+ pipe = load_models()
38
+
39
+ # Common transform
40
+ def get_transform(resolution):
41
+ return transforms.Compose([
42
+ transforms.Resize((resolution, resolution)),
43
+ transforms.ToTensor(),
44
+ ])
45
+
46
+ # Image-to-Text Function
47
+ def image_to_text(image, prompt, resolution=1024, steps=64, cfg=9.0):
48
+ try:
49
+ transform = get_transform(resolution)
50
+
51
+ if image is not None:
52
+ pil_image = Image.fromarray(image.astype('uint8'), 'RGB') if isinstance(image, np.ndarray) else image
53
+ images = torch.stack([transform(pil_image)])
54
+ questions = [prompt] if prompt else ["Please describe this image."]
55
+ else:
56
+ images = None
57
+ questions = [prompt] if prompt else ["Please generate an image description."]
58
+
59
+ output = pipe(
60
+ prompt=questions,
61
+ image=images,
62
+ height=resolution,
63
+ width=resolution,
64
+ guidance_scale=cfg,
65
+ num_inference_steps=steps,
66
+ mask_token_embedding="MeissonFlow/Muddit",
67
+ generator=torch.manual_seed(42),
68
+ )
69
+
70
+ return output.prompts[0]
71
+
72
+ except Exception as e:
73
+ return f"Error: {str(e)}"
74
+
75
+ # Text-to-Image Function
76
+ def text_to_image(prompt, negative_prompt, num_images=1, resolution=1024, steps=64, cfg=9.0):
77
+ try:
78
+ negative_prompt = negative_prompt or "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
79
+
80
+ output = pipe(
81
+ prompt=[prompt]*num_images,
82
+ negative_prompt=[negative_prompt]*num_images,
83
+ height=resolution,
84
+ width=resolution,
85
+ guidance_scale=cfg,
86
+ num_inference_steps=steps,
87
+ mask_token_embedding="MeissonFlow/Muddit",
88
+ generator=torch.manual_seed(42),
89
+ )
90
+
91
+ return output.images
92
+
93
+ except Exception as e:
94
+ print(f"Error: {str(e)}")
95
+ return None
96
+
97
+ # Create Gradio interface with Soft theme
98
+ with gr.Blocks(theme=gr.themes.Soft(), title="Muddit Unified Model") as demo:
99
+ gr.Markdown("# 🌌 Muddit: Liberating Generation Beyond Text-to-Image with a Unified Discrete Diffusion Model.")
100
+ gr.Markdown(" Muddit is a unified discrete diffusion transformer that enables fast and parallel generation across both text and image modalities.")
101
+
102
+ with gr.Tab("Image to Text"):
103
+ with gr.Row():
104
+ with gr.Column():
105
+ i2t_image_input = gr.Image(label="Upload Image", type="pil")
106
+ i2t_prompt_input = gr.Textbox(label="Prompt", value="Please describe this image.", placeholder="Enter your prompt here...")
107
+
108
+ with gr.Accordion("Advanced Settings", open=False):
109
+ i2t_resolution = gr.Slider(label="Resolution", minimum=256, maximum=1024, value=1024, step=64)
110
+ i2t_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
111
+ i2t_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
112
+
113
+ i2t_submit_btn = gr.Button("Generate Description", variant="primary")
114
+
115
+ with gr.Column():
116
+ i2t_output_text = gr.Textbox(label="Generated Description", interactive=False)
117
+ i2t_examples = gr.Examples(
118
+ examples=[
119
+ ["assets/man.jpg"],
120
+ ["assets/tennis.jpg"],
121
+ ["assets/pizza2.jpg"],
122
+ ["assets/plane.jpg"],
123
+ ["assets/zebra.jpg"],
124
+ ["assets/building.jpg"],
125
+ ["assets/flower.jpg"],
126
+ ],
127
+ inputs=[i2t_image_input],
128
+ label="Example Inputs"
129
+ )
130
+
131
+ with gr.Tab("VQA"):
132
+ with gr.Row():
133
+ with gr.Column():
134
+ vqa_image_input = gr.Image(label="Upload Image", type="pil")
135
+ vqa_prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your question here...")
136
+
137
+ with gr.Accordion("Advanced Settings", open=False):
138
+ vqa_resolution = gr.Slider(label="Resolution", minimum=256, maximum=1024, value=1024, step=64)
139
+ vqa_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
140
+ vqa_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
141
+
142
+ vqa_submit_btn = gr.Button("Generate Answer", variant="primary")
143
+
144
+ with gr.Column():
145
+ vqa_output_text = gr.Textbox(label="Generated Answer", interactive=False)
146
+ vqa_examples = gr.Examples(
147
+ examples=[
148
+ ["assets/kid.jpg", "What color is the kid's hair?"],
149
+ ["assets/street.jpg", "Can someone legally walk across the street right now?"],
150
+ ["assets/dog.jpg", "Where is the dog laying?"],
151
+ ["assets/dog2.jpg", "What color is the toy the dog is holding?"],
152
+ ["assets/pizza.jpg", "What food item is shown?"],
153
+ ["assets/sheep.jpg", "How many sheep are pictured?"],
154
+ ["assets/car.jpg", "Where are the cars?"],
155
+ ],
156
+ inputs=[vqa_image_input, vqa_prompt_input],
157
+ label="Example Inputs"
158
+ )
159
+
160
+ with gr.Tab("Text to Image"):
161
+ with gr.Row():
162
+ with gr.Column():
163
+ t2i_prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...")
164
+ t2i_negative_prompt = gr.Textbox(label="Negative Prompt",
165
+ value="worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark",
166
+ placeholder="What you don't want in the image...",
167
+ lines=5)
168
+ t2i_num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, value=1, step=1)
169
+
170
+ with gr.Accordion("Advanced Settings", open=False):
171
+ t2i_resolution = gr.Slider(label="Resolution", minimum=256, maximum=1024, value=1024, step=64)
172
+ t2i_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
173
+ t2i_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
174
+
175
+ t2i_submit_btn = gr.Button("Generate Images", variant="primary")
176
+
177
+ with gr.Column():
178
+ t2i_gallery = gr.Gallery(label="Generated Images")
179
+ t2i_examples = gr.Examples(
180
+ examples=[
181
+ ["A line art portrait showcasing a human figure with flowing, textured strokes"],
182
+ ["A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head, standing amidst lush, bioluminescent foliage in a vibrant futuristic forest"],
183
+ ["A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear and floral accents, his Mandalorian armor gleaming under the backlighting"],
184
+ ["A translucent, minimalist Porsche 911 GT3RS built from sleek carbon fiber, its aerodynamic body designed in the spirit of '60s Braun and modern Apple minimalism"],
185
+ ["A realistic photograph of a ramadan tent shaped like a crescent moon under a velvety back sky studded with the milky way"],
186
+ ["A portrait of John Lennon, captured in the gritty detail of line art"],
187
+ ["In a world plunged into an unending darkness, remnants of fading starlight pierce through a heavy, smog-filled sky"]
188
+ ],
189
+ inputs=[t2i_prompt_input],
190
+ label="Example Prompts"
191
+ )
192
+
193
+ # Event handlers
194
+ i2t_submit_btn.click(
195
+ fn=image_to_text,
196
+ inputs=[i2t_image_input, i2t_prompt_input, i2t_resolution, i2t_steps, i2t_cfg],
197
+ outputs=i2t_output_text
198
+ )
199
+
200
+ vqa_submit_btn.click(
201
+ fn=image_to_text,
202
+ inputs=[vqa_image_input, vqa_prompt_input, vqa_resolution, vqa_steps, vqa_cfg],
203
+ outputs=vqa_output_text
204
+ )
205
+
206
+ t2i_submit_btn.click(
207
+ fn=text_to_image,
208
+ inputs=[t2i_prompt_input, t2i_negative_prompt, t2i_num_images, t2i_resolution, t2i_steps, t2i_cfg],
209
+ outputs=t2i_gallery
210
+ )
211
+
212
+ demo.launch()
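Usage note (not part of the commit): a minimal sketch of the captioning/VQA call that `image_to_text` above performs, written against the `UnifiedPipeline` returned by `load_models()`. The image path and prompt are illustrative.

```python
# Minimal sketch (assumes `pipe` is the UnifiedPipeline returned by load_models()).
# It mirrors image_to_text() above: resize the image, batch it, pass a prompt list.
import torch
from PIL import Image
from torchvision import transforms


def caption_image(pipe, image_path, question="Please describe this image.",
                  resolution=1024, steps=64, cfg=9.0):
    transform = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),
    ])
    images = torch.stack([transform(Image.open(image_path).convert("RGB"))])
    output = pipe(
        prompt=[question],
        image=images,
        height=resolution,
        width=resolution,
        guidance_scale=cfg,
        num_inference_steps=steps,
        mask_token_embedding="MeissonFlow/Muddit",  # same value app.py passes
        generator=torch.manual_seed(42),
    )
    return output.prompts[0]  # decoded caption / answer


# e.g. caption_image(pipe, "assets/man.jpg", "Please describe this image.")
```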
assets/building.jpg ADDED

Git LFS Details

  • SHA256: a2758e8579477d301e72c10893fb3ca62b2f8e2026fb4362a3abcfb4713c59c9
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
assets/car.jpg ADDED

Git LFS Details

  • SHA256: 71d585a24471298336ad9b921b57c7fcc80a876ddb5ba0b7a9187e59144f26b5
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
assets/dog.jpg ADDED

Git LFS Details

  • SHA256: b9da7cf6d47602c875a370296f593485c639aef5f03c264dbb62adce3db30541
  • Pointer size: 130 Bytes
  • Size of remote file: 47.2 kB
assets/dog2.jpg ADDED

Git LFS Details

  • SHA256: 399d5ebf4a8fdf2817b9d6afe207cbb41b6abb749a8663bf04d3b48d0878beb1
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
assets/flower.jpg ADDED

Git LFS Details

  • SHA256: 2d13c04321526882a1b30e4ed38b2dae140febd8e8a2578e82bc8e278736dab2
  • Pointer size: 130 Bytes
  • Size of remote file: 76.5 kB
assets/giraffe2.jpg ADDED

Git LFS Details

  • SHA256: be0f2c417eca5516d06e67d6cf6c6ed031520e2a5ac019fbd35c19bd45d82e7d
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
assets/girl.jpg ADDED

Git LFS Details

  • SHA256: 9a1822f2c99cfe373a3fc7df9bef5732297f19eb4b4e0cd38f3760925f65d589
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
assets/kid.jpg ADDED

Git LFS Details

  • SHA256: 90ce8bd7589b0ad0ce295858e9bc0206cdc1d601cc599eb4c4660637d1d29b08
  • Pointer size: 130 Bytes
  • Size of remote file: 30.8 kB
assets/man.jpg ADDED

Git LFS Details

  • SHA256: ca09f5e2d747e2bcbed917fe688793d9996f457ac1e4150d7726259e299bbb3a
  • Pointer size: 131 Bytes
  • Size of remote file: 247 kB
assets/pizza.jpg ADDED

Git LFS Details

  • SHA256: 6740d86839f2651a997e93d10c13bc18ce25418d44bc8813b2a8dc53ea21c7a8
  • Pointer size: 130 Bytes
  • Size of remote file: 56.4 kB
assets/pizza2.jpg ADDED

Git LFS Details

  • SHA256: 155c476a191f672beff20670ea20af63e647def8fa1808f1382906e8ca0a46dc
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
assets/plane.jpg ADDED

Git LFS Details

  • SHA256: dc3fd7cc41633da42e39652bf5516dee95d8a611b988ee320f6cf4c777358733
  • Pointer size: 130 Bytes
  • Size of remote file: 59.1 kB
assets/sheep.jpg ADDED

Git LFS Details

  • SHA256: 0a0486e0332df04695b56c1c463be1b7b21f0d5d9da5f0d85c93931e4dc4f8da
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
assets/street.jpg ADDED

Git LFS Details

  • SHA256: 2570f9e21d7650d43e0b38705a7c44f4d006c3f8ed71dd11abcd717d33212a1b
  • Pointer size: 130 Bytes
  • Size of remote file: 50.6 kB
assets/tennis.jpg ADDED

Git LFS Details

  • SHA256: d239391bb63121bc457f049f8d1c89329ff1910aafd7e6b42b59a8e36ec54f7a
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
assets/train.jpg ADDED

Git LFS Details

  • SHA256: 282ea96b499a76b5e93ebfb9bcc5c2788821515ee4f71488f8867fde8e743a7e
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
assets/zebra.jpg ADDED

Git LFS Details

  • SHA256: 65062fba50c78acdc7f478227349059e19ae8388ca01453206e719cbe492acd9
  • Pointer size: 131 Bytes
  • Size of remote file: 172 kB
requirements.txt ADDED
@@ -0,0 +1,342 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aiohappyeyeballs==2.4.0
6
+ aiohttp==3.10.5
7
+ aiosignal==1.3.1
8
+ albucore==0.0.19
9
+ albumentations==1.4.20
10
+ aliyun-python-sdk-core==2.16.0
11
+ aliyun-python-sdk-kms==2.16.5
12
+ all-clip==1.2.0
13
+ aniso8601==9.0.1
14
+ annotated-types==0.7.0
15
+ antlr4-python3-runtime==4.9.3
16
+ anyio==4.9.0
17
+ appdirs==1.4.4
18
+ argon2-cffi==23.1.0
19
+ argon2-cffi-bindings==21.2.0
20
+ arrow==1.3.0
21
+ asttokens==2.4.1
22
+ async-lru==2.0.4
23
+ async-timeout==4.0.3
24
+ attrdict==2.0.1
25
+ attrs==24.2.0
26
+ av==14.0.1
27
+ babel==2.16.0
28
+ beartype==0.19.0
29
+ beautifulsoup4==4.12.3
30
+ bitsandbytes==0.43.3
31
+ bleach==6.2.0
32
+ blessed==1.20.0
33
+ blinker==1.8.2
34
+ boto3==1.35.92
35
+ botocore==1.35.92
36
+ braceexpand==0.1.7
37
+ build==1.2.2.post1
38
+ cachetools==5.5.2
39
+ certifi==2024.8.30
40
+ cffi==1.17.1
41
+ charset-normalizer==2.0.12
42
+ chumpy==0.70
43
+ click==8.1.7
44
+ clip-anytorch==2.6.0
45
+ cloudpickle==3.1.0
46
+ colorama==0.4.6
47
+ comm==0.2.2
48
+ contourpy==1.3.0
49
+ cpm-kernels==1.0.11
50
+ crcmod==1.7
51
+ cryptography==44.0.3
52
+ cycler==0.12.1
53
+ Cython==3.0.12
54
+ dashscope==1.22.2
55
+ dataclasses==0.6
56
+ datasets==3.6.0
57
+ debugpy==1.8.8
58
+ decorator==4.4.2
59
+ decord==0.6.0
60
+ deepspeed==0.16.2
61
+ defusedxml==0.7.1
62
+ diffusers==0.33.1
63
+ dill==0.3.8
64
+ distro==1.9.0
65
+ docker-pycreds==0.4.0
66
+ easydict==1.13
67
+ einops==0.8.1
68
+ eval_type_backport==0.2.0
69
+ exceptiongroup==1.2.2
70
+ executing==2.1.0
71
+ ExifRead-nocycle==3.0.1
72
+ fairscale==0.4.13
73
+ fastapi==0.115.0
74
+ fastdtw==0.3.4
75
+ fastjsonschema==2.20.0
76
+ ffmpy==0.4.0
77
+ filelock==3.14.0
78
+ filterpy==1.4.5
79
+ fire==0.5.0
80
+ flash_attn==2.7.4.post1
81
+ Flask==3.0.3
82
+ Flask-Cors==4.0.2
83
+ Flask-RESTful==0.3.10
84
+ flow-vis==0.1
85
+ fonttools==4.53.1
86
+ fqdn==1.5.1
87
+ freetype-py==2.5.1
88
+ frozenlist==1.4.1
89
+ fsspec==2024.2.0
90
+ ftfy==6.3.1
91
+ func_timeout==4.3.5
92
+ fvcore==0.1.5.post20221221
93
+ gdown==5.2.0
94
+ gitdb==4.0.11
95
+ GitPython==3.1.43
96
+ google-auth==2.38.0
97
+ gpustat==1.1.1
98
+ gradio_client==1.3.0
99
+ grpcio==1.69.0
100
+ h11==0.14.0
101
+ h5py==3.12.1
102
+ hjson==3.1.0
103
+ httpcore==1.0.5
104
+ httpx==0.28.1
105
+ huggingface-hub==0.29.1
106
+ hydra-core==1.3.2
107
+ hydra-submitit-launcher==1.2.0
108
+ idna==3.10
109
+ imageio==2.35.1
110
+ imageio-ffmpeg==0.5.1
111
+ importlib_metadata==8.5.0
112
+ importlib_resources==6.4.5
113
+ imutils==0.5.4
114
+ iopath==0.1.10
115
+ ipykernel==6.29.5
116
+ ipympl==0.9.4
117
+ ipython==8.18.1
118
+ ipython-genutils==0.2.0
119
+ ipywidgets==8.1.5
120
+ isoduration==20.11.0
121
+ itsdangerous==2.2.0
122
+ jaxtyping==0.2.36
123
+ jedi==0.19.2
124
+ Jinja2==3.1.3
125
+ jiter==0.5.0
126
+ jmespath==0.10.0
127
+ joblib==1.4.2
128
+ json-tricks==3.17.3
129
+ json5==0.9.28
130
+ jsonpointer==3.0.0
131
+ jsonschema==4.23.0
132
+ jsonschema-specifications==2024.10.1
133
+ jupyter==1.1.1
134
+ jupyter-console==6.6.3
135
+ jupyter-events==0.10.0
136
+ jupyter-lsp==2.2.5
137
+ jupyter_client==8.6.3
138
+ jupyter_core==5.7.2
139
+ jupyter_server==2.14.2
140
+ jupyter_server_terminals==0.5.3
141
+ jupyterlab==4.2.6
142
+ jupyterlab_pygments==0.3.0
143
+ jupyterlab_server==2.27.3
144
+ jupyterlab_widgets==3.0.13
145
+ kiwisolver==1.4.7
146
+ kornia==0.7.3
147
+ kornia_rs==0.1.8
148
+ lazy_loader==0.4
149
+ lightning-utilities==0.11.9
150
+ llvmlite==0.43.0
151
+ Markdown==3.7
152
+ markdown-it-py==3.0.0
153
+ MarkupSafe==2.1.5
154
+ matplotlib==3.7.0
155
+ matplotlib-inline==0.1.7
156
+ mdurl==0.1.2
157
+ mediapy==1.2.2
158
+ mistune==3.0.2
159
+ mmcv==2.2.0
160
+ mmengine==0.10.7
161
+ mmpose==0.28.0
162
+ model-index==0.1.11
163
+ moviepy==1.0.3
164
+ mpmath==1.3.0
165
+ msgpack==1.1.0
166
+ multidict==6.1.0
167
+ multilingual-clip==1.0.10
168
+ multiprocess==0.70.16
169
+ munkres==1.1.4
170
+ nbclient==0.10.0
171
+ nbconvert==7.16.4
172
+ nbformat==5.10.4
173
+ nest-asyncio==1.6.0
174
+ networkx==3.2.1
175
+ ninja==1.11.1.3
176
+ nltk==3.9.1
177
+ notebook==7.2.2
178
+ notebook_shim==0.2.4
179
+ numba==0.60.0
180
+ numpy==1.24.4
181
+ nvidia-cublas-cu12==12.4.5.8
182
+ nvidia-cuda-cupti-cu12==12.4.127
183
+ nvidia-cuda-nvrtc-cu12==12.4.127
184
+ nvidia-cuda-runtime-cu12==12.4.127
185
+ nvidia-cudnn-cu12==9.1.0.70
186
+ nvidia-cufft-cu12==11.2.1.3
187
+ nvidia-curand-cu12==10.3.5.147
188
+ nvidia-cusolver-cu12==11.6.1.9
189
+ nvidia-cusparse-cu12==12.3.1.170
190
+ nvidia-ml-py==12.560.30
191
+ nvidia-nccl-cu12==2.21.5
192
+ nvidia-nvjitlink-cu12==12.4.127
193
+ nvidia-nvtx-cu12==12.4.127
194
+ omegaconf==2.3.0
195
+ open_clip_torch==2.29.0
196
+ openai==1.47.0
197
+ opencv-python==4.7.0.72
198
+ opencv-python-headless==4.10.0.84
199
+ opendatalab==0.0.10
200
+ openmim==0.3.9
201
+ ordered-set==4.1.0
202
+ orjson==3.10.7
203
+ oss2==2.17.0
204
+ overrides==7.7.0
205
+ packaging==24.1
206
+ pandas==2.2.3
207
+ pandocfilters==1.5.1
208
+ parso==0.8.4
209
+ peft==0.14.0
210
+ pexpect==4.9.0
211
+ Pillow==9.5.0
212
+ pip-tools==7.4.1
213
+ platformdirs==4.3.6
214
+ plotly==5.24.1
215
+ plyfile==1.1
216
+ portalocker==2.10.1
217
+ prodigyopt==1.0
218
+ proglog==0.1.10
219
+ prometheus_client==0.21.0
220
+ prompt_toolkit==3.0.48
221
+ protobuf==3.20.3
222
+ psutil==6.0.0
223
+ ptyprocess==0.7.0
224
+ pure_eval==0.2.3
225
+ py-cpuinfo==9.0.0
226
+ pyarrow==20.0.0
227
+ pyasn1==0.6.1
228
+ pyasn1_modules==0.4.1
229
+ pycocoevalcap==1.2
230
+ pycocotools==2.0.8
231
+ pycparser==2.22
232
+ pycryptodome==3.22.0
233
+ pydantic==2.9.2
234
+ pydantic_core==2.23.4
235
+ pydub==0.25.1
236
+ pyglet==1.5.27
237
+ Pygments==2.18.0
238
+ PyOpenGL==3.1.0
239
+ pyparsing==3.1.4
240
+ pyproject_hooks==1.2.0
241
+ pyrender==0.1.45
242
+ PySocks==1.7.1
243
+ python-dateutil==2.9.0.post0
244
+ python-json-logger==2.0.7
245
+ python-multipart==0.0.10
246
+ pytorch-lightning==2.5.0.post0
247
+ pytz==2023.4
248
+ PyWavelets==1.6.0
249
+ PyYAML==6.0.2
250
+ pyzmq==26.2.0
251
+ qwen-vl-utils==0.0.10
252
+ referencing==0.35.1
253
+ regex==2024.9.11
254
+ requests==2.32.3
255
+ rfc3339-validator==0.1.4
256
+ rfc3986-validator==0.1.1
257
+ rich==13.4.2
258
+ rpds-py==0.21.0
259
+ rsa==4.9
260
+ ruff==0.6.7
261
+ s3transfer==0.10.4
262
+ safetensors==0.4.5
263
+ scikit-image==0.24.0
264
+ scikit-learn==1.5.2
265
+ scipy==1.10.1
266
+ seaborn==0.13.2
267
+ semantic-version==2.10.0
268
+ Send2Trash==1.8.3
269
+ sentence-transformers==2.7.0
270
+ sentencepiece==0.1.99
271
+ sentry-sdk==2.14.0
272
+ setproctitle==1.3.3
273
+ shellingham==1.5.4
274
+ shortuuid==1.0.13
275
+ simple-aesthetics-predictor==0.1.2
276
+ six==1.16.0
277
+ smmap==5.0.1
278
+ smplx==0.1.28
279
+ sniffio==1.3.1
280
+ soupsieve==2.6
281
+ stack-data==0.6.3
282
+ starlette==0.38.6
283
+ stringzilla==3.10.6
284
+ submitit==1.5.2
285
+ supervision==0.25.1
286
+ SwissArmyTransformer==0.4.12
287
+ sympy==1.13.1
288
+ tabulate==0.9.0
289
+ tenacity==9.0.0
290
+ tensorboard==2.18.0
291
+ tensorboard-data-server==0.7.2
292
+ tensorboardX==2.6.2.2
293
+ termcolor==2.5.0
294
+ terminado==0.18.1
295
+ threadpoolctl==3.5.0
296
+ tifffile==2024.8.30
297
+ timm==1.0.11
298
+ tinycss2==1.4.0
299
+ tokenizers==0.21.0
300
+ tomesd==0.1.3
301
+ tomli==2.1.0
302
+ tomlkit==0.12.0
303
+ torch==2.5.1
304
+ torch-fidelity==0.3.0
305
+ torchdata==0.11.0
306
+ torchdiffeq==0.2.5
307
+ torchgeometry==0.1.2
308
+ torchmetrics==1.6.1
309
+ torchsde==0.2.6
310
+ torchvision==0.20.1
311
+ tornado==6.4.1
312
+ tqdm==4.67.1
313
+ traitlets==5.14.3
314
+ trampoline==0.1.2
315
+ transformers==4.49.0
316
+ trimesh==4.6.8
317
+ triton==3.1.0
318
+ tslearn==0.6.3
319
+ typeguard==4.4.2
320
+ typer==0.12.5
321
+ types-python-dateutil==2.9.0.20241003
322
+ typing_extensions==4.12.2
323
+ tzdata==2024.1
324
+ uri-template==1.3.0
325
+ urllib3==1.26.20
326
+ uvicorn==0.30.6
327
+ wandb==0.17.5
328
+ wcwidth==0.2.13
329
+ webcolors==24.11.1
330
+ webdataset==0.2.100
331
+ webencodings==0.5.1
332
+ websocket-client==1.8.0
333
+ websockets==12.0
334
+ Werkzeug==3.0.6
335
+ widgetsnbextension==4.0.13
336
+ wordcloud==1.9.4
337
+ xtcocotools==1.14.3
338
+ xxhash==3.5.0
339
+ yacs==0.1.8
340
+ yapf==0.43.0
341
+ yarl==1.11.1
342
+ zipp==3.20.2
src/pipeline.py ADDED
@@ -0,0 +1,684 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ from dataclasses import dataclass
17
+
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+ import PIL.Image
20
+ import torch
21
+ import PIL
22
+ import numpy as np
23
+
24
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5Tokenizer, T5EncoderModel
25
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
26
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.models import VQModel
30
+ from diffusers.utils import replace_example_docstring
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.utils import BaseOutput
33
+
34
+ from src.scheduler import Scheduler
35
+ from src.transformer import SymmetricTransformer2DModel
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```py
41
+ >>> image = pipe(prompt).images[0]
42
+ ```
43
+ """
44
+
45
+
46
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
47
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
48
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
49
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
50
+
51
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
52
+
53
+ latent_image_ids = latent_image_ids.reshape(
54
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
55
+ )
56
+
57
+ return latent_image_ids.to(device=device, dtype=dtype)
58
+
59
+ def dedup_consecutive_words(text: str) -> str:
60
+ """
61
+ >>> dedup_consecutive_words("hello hello world world world")
62
+ 'hello world'
63
+ """
64
+ words = text.split()
65
+ if not words:
66
+ return text
67
+
68
+ out = [words[0]]
69
+ for w in words[1:]:
70
+ if w != out[-1]:
71
+ out.append(w)
72
+ return " ".join(out)
73
+
74
+ def keep_upto_last_period(text: str) -> str:
75
+ """
76
+ Return the substring up to (and including) the last period-mark.
77
+
78
+ The function searches first for the Chinese full stop “。”;
79
+ if none is found, it falls back to the ASCII dot “.”.
80
+
81
+ Parameters
82
+ ----------
83
+ text : str
84
+ Input string.
85
+
86
+ Returns
87
+ -------
88
+ str
89
+ Substring ending at the final period-mark. If no period is present,
90
+ the original string is returned unchanged.
91
+ """
92
+ # Workaround: strip a recurring degenerate phrase ("such is") from decoded captions
93
+ text = text.replace("such is such", "").replace("is such is", "").replace("such is", "").replace("is such", "")
94
+ # Search for the Chinese full stop "。" first, then fall back to the ASCII "."
95
+ idx = text.rfind("。")
96
+ if idx == -1:
97
+ idx = text.rfind(".")
98
+ # If still not found, return original text
99
+ if idx == -1:
100
+ return text
101
+ # Keep everything up to (and including) the last period
102
+ return text[:idx + 1]
103
+
104
+ @dataclass
105
+ class UnifiedPipelineOutput(BaseOutput):
106
+ """
107
+ Output class for image pipelines.
108
+
109
+ Args:
110
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
111
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
112
+ num_channels)`.
113
+ """
114
+
115
+ images: Union[List[PIL.Image.Image], np.ndarray]
116
+ prompts: List[str]
117
+
118
+
119
+ class UnifiedPipeline(DiffusionPipeline):
120
+ image_processor: VaeImageProcessor
121
+ vqvae: VQModel
122
+ tokenizer: CLIPTokenizer
123
+ tokenizer_2: T5Tokenizer
124
+ text_encoder: CLIPTextModelWithProjection
125
+ text_encoder_2: T5EncoderModel
126
+ transformer: SymmetricTransformer2DModel
127
+ scheduler: Scheduler
128
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
129
+
130
+ def __init__(
131
+ self,
132
+ vqvae: VQModel,
133
+ tokenizer: CLIPTokenizer,
134
+ text_encoder: CLIPTextModelWithProjection,
135
+ transformer: SymmetricTransformer2DModel,
136
+ scheduler: Scheduler,
137
+ tokenizer_2: T5Tokenizer = None,
138
+ text_encoder_2: T5EncoderModel = None,
139
+ ):
140
+ super().__init__()
141
+
142
+ self.register_modules(
143
+ vqvae=vqvae,
144
+ tokenizer=tokenizer,
145
+ tokenizer_2=tokenizer_2,
146
+ text_encoder=text_encoder,
147
+ text_encoder_2=text_encoder_2,
148
+ transformer=transformer,
149
+ scheduler=scheduler,
150
+ )
151
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
152
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
153
+
154
+ @torch.no_grad()
155
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
156
+ def __call__(
157
+ self,
158
+ prompt: Optional[Union[List[str], str]] = None,
159
+ height: Optional[int] = 1024,
160
+ width: Optional[int] = 1024,
161
+ image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
162
+ num_inference_steps: int = 48,
163
+ guidance_scale: float = 9.0,
164
+ negative_prompt: Optional[Union[str, List[str]]] = None,
165
+ num_images_per_prompt: Optional[int] = 1,
166
+ generator: Optional[torch.Generator] = None,
167
+ latents: Optional[torch.IntTensor] = None,
168
+ prompt_embeds: Optional[torch.Tensor] = None,
169
+ encoder_hidden_states: Optional[torch.Tensor] = None,
170
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
171
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
172
+ output_type = "pil",
173
+ return_dict: bool = True,
174
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
175
+ callback_steps: int = 1,
176
+ micro_conditioning_aesthetic_score: int = 6,
177
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
178
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
179
+ mask_token_embedding: Optional[str] = None,
180
+ ):
181
+ """
182
+ The call function to the pipeline for generation.
183
+
184
+ Args:
185
+ prompt (`str` or `List[str]`, *optional*):
186
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
187
+ height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
188
+ The height in pixels of the generated image.
189
+ width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
190
+ The width in pixels of the generated image.
191
+ num_inference_steps (`int`, *optional*, defaults to 48):
192
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
193
+ expense of slower inference.
194
+ guidance_scale (`float`, *optional*, defaults to 9.0):
195
+ A higher guidance scale value encourages the model to generate images closely linked to the text
196
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
197
+ negative_prompt (`str` or `List[str]`, *optional*):
198
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
199
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
200
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
201
+ The number of images to generate per prompt.
202
+ generator (`torch.Generator`, *optional*):
203
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
204
+ generation deterministic.
205
+ latents (`torch.IntTensor`, *optional*):
206
+ Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
207
+ generation. If not provided, the starting latents will be completely masked.
208
+ prompt_embeds (`torch.Tensor`, *optional*):
209
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
210
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
211
+ pooled and projected final hidden states.
212
+ encoder_hidden_states (`torch.Tensor`, *optional*):
213
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
214
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
216
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
217
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
218
+ Analogous to `encoder_hidden_states` for the positive prompt.
219
+ output_type (`str`, *optional*, defaults to `"pil"`):
220
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
221
+ return_dict (`bool`, *optional*, defaults to `True`):
222
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
223
+ plain tuple.
224
+ callback (`Callable`, *optional*):
225
+ A function that calls every `callback_steps` steps during inference. The function is called with the
226
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
227
+ callback_steps (`int`, *optional*, defaults to 1):
228
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
229
+ every step.
230
+ cross_attention_kwargs (`dict`, *optional*):
231
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
232
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
233
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
234
+ The targeted aesthetic score according to the laion aesthetic classifier. See
235
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
236
+ https://arxiv.org/abs/2307.01952.
237
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
238
+ The targeted height, width crop coordinates. See the micro-conditioning section of
239
+ https://arxiv.org/abs/2307.01952.
240
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
241
+ Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`.
242
+
243
+ Examples:
244
+
245
+ Returns:
246
+ [`UnifiedPipelineOutput`] or `tuple`:
247
+ If `return_dict` is `True`, [`UnifiedPipelineOutput`] is returned, otherwise a
248
+ `tuple` is returned where the first element is a list with the generated images.
249
+ """
250
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
251
+ prompt_embeds is None and encoder_hidden_states is not None
252
+ ):
253
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
254
+
255
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
256
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
257
+ ):
258
+ raise ValueError(
259
+ "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
260
+ )
261
+
262
+ if self.text_encoder_2 is not None:
263
+ self.text_encoder_2.to(self._execution_device)
264
+
265
+ text2image = image is None
266
+ image2text = image is not None
267
+
268
+ if image2text:
269
+ if self.text_encoder_2 is not None:
270
+ prompt = "<extra_id_0>" * 256
271
+ prompt = [prompt] * len(image)
272
+
273
+ t5_mask_id = self.tokenizer_2.convert_tokens_to_ids("<extra_id_0>")
274
+ self.scheduler.config.mask_token_id = t5_mask_id
275
+ else:
276
+ mask_token = "<mask>"
277
+ self.tokenizer.add_tokens(mask_token, special_tokens=False)
278
+ clip_mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)
279
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
280
+
281
+ if mask_token_embedding is not None:
282
+ if mask_token_embedding.endswith(".pth"):
283
+ mask_token_embedding = torch.load(mask_token_embedding)
284
+ else:
285
+ mask_token_embedding = os.path.dirname(mask_token_embedding)
286
+ mask_token_embedding_path = os.path.join(mask_token_embedding, "mask_token_embedding.pth")
287
+ assert os.path.exists(mask_token_embedding_path), f"{mask_token_embedding_path} doesn't exists!"
288
+ mask_token_embedding = torch.load(mask_token_embedding_path)
289
+
290
+ mask_token_embedding = mask_token_embedding.to(self._execution_device, dtype=self.text_encoder.dtype)
291
+ self.text_encoder.get_input_embeddings().weight.data[clip_mask_id].copy_(mask_token_embedding)
292
+
293
+ self.scheduler.config.mask_token_id = clip_mask_id
294
+
295
+ input_ids = torch.ones(
296
+ size=(len(image), self.tokenizer.model_max_length),
297
+ dtype=torch.int64,
298
+ device=self._execution_device
299
+ )
300
+ input_ids = input_ids * clip_mask_id
301
+
302
+ question_len = []
303
+ if prompt is None:
304
+ question_len = [0] * len(image)
305
+ elif isinstance(prompt, str):
306
+ question_ids = torch.LongTensor([self.tokenizer.encode(prompt)])
307
+ question_ids = question_ids.repeat(len(image), 1)
308
+
309
+ q_len = len(question_ids[0]) - 1 # remove <eos> token
310
+ question_len = [q_len] * len(image)
311
+
312
+ input_ids[:, :q_len] = question_ids[:, :-1]
313
+ else:
314
+ assert isinstance(prompt, list), f"prompt must be None or str or list!"
315
+ assert len(prompt) == len(image), f"VQA require equal num of images and prompts!"
316
+ for i, p in enumerate(prompt):
317
+ question_ids = torch.LongTensor([self.tokenizer.encode(p)])
318
+
319
+ q_len = len(question_ids[0]) - 1
320
+ question_len.append(q_len)
321
+
322
+ input_ids[i, :q_len] = question_ids[0, :-1]
323
+ else:
324
+ self.scheduler.config.mask_token_id = self.transformer.config.vocab_size - 1
325
+
326
+ if isinstance(prompt, str):
327
+ prompt = [prompt]
328
+
329
+ if image is not None:
330
+ batch_size = len(image)
331
+ else:
332
+ batch_size = len(prompt)
333
+
334
+ if height is None:
335
+ height = self.transformer.config.sample_size * self.vae_scale_factor
336
+
337
+ if width is None:
338
+ width = self.transformer.config.sample_size * self.vae_scale_factor
339
+
340
+ if isinstance(self.text_encoder, CLIPTextModelWithProjection):
341
+ text_encoder_type = "open_clip"
342
+ if isinstance(self.text_encoder_2, Gemma2Model):
343
+ text_encoder_type = "gemma"
344
+
345
+ if prompt_embeds is None:
346
+ if text_encoder_type == "t5_clip":
347
+ if text2image:
348
+ input_ids_clip = self.tokenizer(
349
+ prompt,
350
+ return_tensors="pt",
351
+ padding="max_length",
352
+ truncation=True,
353
+ add_special_tokens=True,
354
+ max_length=77,
355
+ ).input_ids.to(self._execution_device)
356
+ outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
357
+ prompt_embeds = outputs.text_embeds
358
+
359
+ input_ids_t5 = self.tokenizer_2(
360
+ prompt,
361
+ return_tensors="pt",
362
+ padding="max_length",
363
+ truncation=True,
364
+ add_special_tokens=True,
365
+ max_length=256,
366
+ ).input_ids.to(self._execution_device)
367
+
368
+ outputs_2 = self.text_encoder_2(input_ids_t5, return_dict=True, output_hidden_states=True)
369
+ encoder_hidden_states = outputs_2.last_hidden_state
370
+ elif text_encoder_type == "open_clip":
371
+ if text2image:
372
+ input_ids = self.tokenizer(
373
+ prompt,
374
+ return_tensors="pt",
375
+ padding="max_length",
376
+ truncation=True,
377
+ add_special_tokens=True,
378
+ max_length=77,
379
+ ).input_ids.to(self._execution_device)
380
+
381
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
382
+ prompt_embeds = outputs.text_embeds
383
+ encoder_hidden_states = outputs.hidden_states[-2]
384
+ elif text_encoder_type == "gemma":
385
+ if text2image:
386
+ input_ids_clip = self.tokenizer(
387
+ prompt,
388
+ return_tensors="pt",
389
+ padding="max_length",
390
+ truncation=True,
391
+ add_special_tokens=True,
392
+ max_length=77,
393
+ ).input_ids.to(self._execution_device)
394
+ outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
395
+ prompt_embeds = outputs.text_embeds
396
+
397
+ input_ids_2 = self.tokenizer_2(
398
+ prompt,
399
+ truncation=True,
400
+ padding="max_length",
401
+ max_length=256,
402
+ return_tensors="pt",
403
+ ).input_ids.to(self._execution_device)
404
+
405
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
406
+ encoder_hidden_states = outputs_2.last_hidden_state
407
+
408
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
409
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
410
+
411
+ if guidance_scale > 1.0 and text2image:
412
+ if negative_prompt_embeds is None:
413
+ if negative_prompt is None:
414
+ negative_prompt = [""] * len(prompt)
415
+
416
+ if isinstance(negative_prompt, str):
417
+ negative_prompt = [negative_prompt] * len(prompt)
418
+
419
+ if text_encoder_type == "t5_clip":
420
+ input_ids = self.tokenizer(
421
+ negative_prompt,
422
+ return_tensors="pt",
423
+ padding="max_length",
424
+ truncation=True,
425
+ add_special_tokens=True,
426
+ max_length=77,
427
+ ).input_ids.to(self._execution_device)
428
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
429
+ negative_prompt_embeds = outputs.text_embeds
430
+
431
+ input_ids_2 = self.tokenizer_2(
432
+ negative_prompt,
433
+ return_tensors="pt",
434
+ padding="max_length",
435
+ truncation=True,
436
+ add_special_tokens=True,
437
+ max_length=256,
438
+ ).input_ids.to(self._execution_device)
439
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
440
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
441
+
442
+ elif text_encoder_type == "open_clip":
443
+ input_ids = self.tokenizer(
444
+ negative_prompt,
445
+ return_tensors="pt",
446
+ padding="max_length",
447
+ truncation=True,
448
+ add_special_tokens=True,
449
+ max_length=77,
450
+ ).input_ids.to(self._execution_device)
451
+
452
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
453
+
454
+ negative_prompt_embeds = outputs.text_embeds
455
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
456
+
457
+ elif text_encoder_type == "gemma":
458
+ input_ids = self.tokenizer(
459
+ negative_prompt,
460
+ return_tensors="pt",
461
+ padding="max_length",
462
+ truncation=True,
463
+ add_special_tokens=True,
464
+ max_length=77,
465
+ ).input_ids.to(self._execution_device)
466
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
467
+ negative_prompt_embeds = outputs.text_embeds
468
+
469
+ input_ids_2 = self.tokenizer_2(
470
+ negative_prompt,
471
+ truncation=True,
472
+ padding="max_length",
473
+ max_length=256,
474
+ return_tensors="pt",
475
+ ).input_ids.to(self._execution_device)
476
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
477
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
478
+
479
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
480
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
481
+
482
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
483
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
484
+
485
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
486
+ # and the crop coordinates. This is how it was done in the original code base
487
+ micro_conds = torch.tensor(
488
+ [
489
+ width,
490
+ height,
491
+ micro_conditioning_crop_coord[0],
492
+ micro_conditioning_crop_coord[1],
493
+ micro_conditioning_aesthetic_score,
494
+ ],
495
+ device=self._execution_device,
496
+ dtype=encoder_hidden_states.dtype,
497
+ )
498
+ micro_conds = micro_conds.unsqueeze(0)
499
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 and text2image else batch_size, -1)
500
+
501
+ shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
502
+
503
+ if latents is None and text2image:
504
+ latents = torch.full(
505
+ shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
506
+ )
507
+ elif image2text:
508
+ if text_encoder_type in ("t5_clip", "gemma"):
509
+ latents = input_ids_2 # [b, l]
510
+ else:
511
+ latents = input_ids
512
+
513
+ model_input = None
514
+
515
+ step_by_step = []
516
+
517
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
518
+ num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
519
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
520
+ for i, timestep in enumerate(self.scheduler.timesteps):
521
+ if guidance_scale > 1.0 and text2image:
522
+ model_input = torch.cat([latents] * 2)
523
+ encoder_hidden_states = encoder_hidden_states
524
+ elif image2text:
525
+ if model_input is None:
526
+ model_input = self.vqvae.quantize(
527
+ self.vqvae.encode(image.to(self._execution_device, dtype=self.vqvae.dtype)).latents
528
+ )[2][2].reshape(batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
529
+
530
+ if text_encoder_type in ("t5_clip", "gemma"):
531
+ outputs_t5 = self.text_encoder_2(latents, return_dict=True)
532
+ encoder_hidden_states = outputs_t5.last_hidden_state
533
+
534
+ batch_prompt = []
535
+ for i in range(latents.size(0)):
536
+ masked_prompt_input_id = latents[i].tolist()
537
+ prompt = self.tokenizer_2.decode(masked_prompt_input_id, skip_special_tokens=True)
538
+ batch_prompt.append(prompt)
539
+
540
+ masked_prompt_input_ids_clip = self.tokenizer(
541
+ batch_prompt,
542
+ truncation=True,
543
+ padding="max_length",
544
+ max_length=77,
545
+ return_tensors="pt"
546
+ ).input_ids
547
+ masked_prompt_input_ids_clip = masked_prompt_input_ids_clip.to(self._execution_device)
548
+ outputs_clip = self.text_encoder(input_ids=masked_prompt_input_ids_clip, return_dict=True)
549
+ prompt_embeds = outputs_clip.text_embeds
550
+
551
+ else:
552
+ outputs = self.text_encoder(latents, return_dict=True, output_hidden_states=True)
553
+ prompt_embeds = outputs.text_embeds
554
+ encoder_hidden_states = outputs.hidden_states[-2]
555
+ else:
556
+ model_input = latents
557
+ encoder_hidden_states = encoder_hidden_states
558
+
559
+ if height == 1024: #args.resolution == 1024:
560
+ img_ids = _prepare_latent_image_ids(
561
+ model_input.shape[0],
562
+ model_input.shape[-2],
563
+ model_input.shape[-1],
564
+ model_input.device,
565
+ model_input.dtype
566
+ )
567
+ else:
568
+ img_ids = _prepare_latent_image_ids(
569
+ model_input.shape[0],
570
+ model_input.shape[-2],
571
+ model_input.shape[-1],
572
+ model_input.device,
573
+ model_input.dtype
574
+ )
575
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(
576
+ device=encoder_hidden_states.device,
577
+ dtype=encoder_hidden_states.dtype
578
+ )
579
+
580
+ # timestep_ = int(timestep / num_inference_steps * 1000)
581
+ model_output, encoder_hidden_states_tmp = self.transformer(
582
+ hidden_states=model_input,
583
+ micro_conds=micro_conds,
584
+ pooled_projections=prompt_embeds,
585
+ encoder_hidden_states=encoder_hidden_states,
586
+ img_ids=img_ids,
587
+ txt_ids=txt_ids,
588
+ timestep=torch.tensor([timestep / num_inference_steps], device=model_input.device),
589
+ )
590
+
591
+ if image2text:
592
+ encoder_hidden_states = encoder_hidden_states_tmp.clone()
593
+
594
+ if guidance_scale > 1.0 and text2image:
595
+ uncond_logits, cond_logits = model_output.chunk(2)
596
+ to_scheduler = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
597
+ elif image2text:
598
+ to_scheduler = encoder_hidden_states
599
+ else:
600
+ to_scheduler = model_output
601
+
602
+ latents = self.scheduler.step(
603
+ model_output=to_scheduler,
604
+ timestep=timestep,
605
+ sample=latents,
606
+ generator=generator,
607
+ ).prev_sample
608
+
609
+ # this line will print the intermediate results of the image-to-text generation
610
+ # step_by_step.append(self.tokenizer.decode(latents[0].tolist(), skip_special_tokens=True))
611
+
612
+ # this line will print the intermediate results of the text-to-image generation
613
+ # output = self.vqvae.decode(
614
+ # latents,
615
+ # force_not_quantize=True,
616
+ # shape=(
617
+ # batch_size,
618
+ # height // self.vae_scale_factor,
619
+ # width // self.vae_scale_factor,
620
+ # self.vqvae.config.latent_channels,
621
+ # ),
622
+ # ).sample.clip(0, 1)
623
+ # output = self.image_processor.postprocess(output, output_type) # output is a list of PIL.Image, you need to save it.
624
+
625
+ if i == len(self.scheduler.timesteps) - 1 or (
626
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
627
+ ):
628
+ progress_bar.update()
629
+ if callback is not None and i % callback_steps == 0:
630
+ step_idx = i // getattr(self.scheduler, "order", 1)
631
+ callback(step_idx, timestep, latents)
632
+
633
+ # with open("step_by_step.txt", "w") as file:
634
+ # for prompt in step_by_step:
635
+ # file.write(prompt + "\n")
636
+
637
+ if guidance_scale > 1.0 and text2image:
638
+ decoded_input_ids = encoder_hidden_states[encoder_hidden_states.shape[0] // 2:].argmax(-1)
639
+ else:
640
+ decoded_input_ids = encoder_hidden_states.argmax(-1)
641
+
642
+ prompts = []
643
+ for i, prompt in enumerate(decoded_input_ids):
644
+ if image2text:
645
+ q_len = question_len[i]
646
+ prompt = self.tokenizer.decode(prompt.tolist()[q_len:], skip_special_tokens=True)
647
+ prompts.append(keep_upto_last_period(dedup_consecutive_words(prompt)))
648
+ else:
649
+ prompts.append("Placeholder")
650
+
651
+ if output_type == "latent":
652
+ output = latents
653
+ else:
654
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
655
+
656
+ if needs_upcasting:
657
+ self.vqvae.float()
658
+
659
+ if text2image:
660
+ to_vqvae = latents
661
+ else:
662
+ to_vqvae = model_input
663
+
664
+ output = self.vqvae.decode(
665
+ to_vqvae,
666
+ force_not_quantize=True,
667
+ shape=(
668
+ batch_size,
669
+ height // self.vae_scale_factor,
670
+ width // self.vae_scale_factor,
671
+ self.vqvae.config.latent_channels,
672
+ ),
673
+ ).sample.clip(0, 1)
674
+ output = self.image_processor.postprocess(output, output_type)
675
+
676
+ if needs_upcasting:
677
+ self.vqvae.half()
678
+
679
+ self.maybe_free_model_hooks()
680
+
681
+ if not return_dict:
682
+ return (output,)
683
+
684
+ return UnifiedPipelineOutput(images=output, prompts=prompts)
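Usage note (not part of the commit): a minimal text-to-image sketch against this pipeline, assembled from the same pretrained components that `load_models()` in app.py uses. The repository IDs are the ones hard-coded there; the prompt, negative prompt, and output path are illustrative, and a CUDA device is assumed.

```python
# Minimal text-to-image sketch using the pieces this commit adds.
# Repo IDs are the ones used in app.py; a CUDA device is assumed.
import torch
from diffusers import VQModel
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from src.pipeline import UnifiedPipeline
from src.scheduler import Scheduler
from src.transformer import SymmetricTransformer2DModel

model_path = "MeissonFlow/Meissonic"
transformer_path = "MeissonFlow/Muddit/1024"

pipe = UnifiedPipeline(
    vqvae=VQModel.from_pretrained(model_path, subfolder="vqvae"),
    tokenizer=CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer"),
    text_encoder=CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder"),
    transformer=SymmetricTransformer2DModel.from_pretrained(transformer_path, subfolder="transformer"),
    scheduler=Scheduler.from_pretrained(model_path, subfolder="scheduler"),
).to("cuda")

output = pipe(
    prompt=["A line art portrait of John Lennon"],
    negative_prompt=["worst quality, low quality, blurry, watermark"],
    height=1024,
    width=1024,
    guidance_scale=9.0,
    num_inference_steps=64,
    generator=torch.manual_seed(42),
)
output.images[0].save("sample.png")  # output_type defaults to "pil"
```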
src/scheduler.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
+
24
+
25
+ def gumbel_noise(t, generator=None):
26
+ device = generator.device if generator is not None else t.device
27
+ noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
28
+ return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
37
+
38
+
39
+ @dataclass
40
+ class SchedulerOutput(BaseOutput):
41
+ """
42
+ Output class for the scheduler's `step` function output.
43
+
44
+ Args:
45
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
46
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
47
+ denoising loop.
48
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
49
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
50
+ `pred_original_sample` can be used to preview progress or for guidance.
51
+ """
52
+
53
+ prev_sample: torch.Tensor
54
+ pred_original_sample: torch.Tensor = None
55
+
56
+
57
+ class Scheduler(SchedulerMixin, ConfigMixin):
58
+ order = 1
59
+
60
+ temperatures: torch.Tensor
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ mask_token_id: int,
66
+ masking_schedule: str = "cosine",
67
+ ):
68
+ self.temperatures = None
69
+ self.timesteps = None
70
+
71
+ def set_timesteps(
72
+ self,
73
+ num_inference_steps: int,
74
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
75
+ device: Union[str, torch.device] = None,
76
+ ):
77
+ self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
78
+
79
+ if isinstance(temperature, (tuple, list)):
80
+ self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
81
+ else:
82
+ self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
83
+
84
+ def step(
85
+ self,
86
+ model_output: torch.Tensor,
87
+ timestep: torch.long,
88
+ sample: torch.LongTensor,
89
+ starting_mask_ratio: int = 1,
90
+ generator: Optional[torch.Generator] = None,
91
+ return_dict: bool = True,
92
+ ) -> Union[SchedulerOutput, Tuple]:
93
+ two_dim_input = sample.ndim == 3 and model_output.ndim == 4
94
+
95
+ if two_dim_input:
96
+ batch_size, codebook_size, height, width = model_output.shape
97
+ sample = sample.reshape(batch_size, height * width)
98
+ model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)
99
+
100
+ unknown_map = sample == self.config.mask_token_id
101
+
102
+ probs = model_output.softmax(dim=-1)
103
+
104
+ device = probs.device
105
+ probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU
106
+ if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
107
+ probs_ = probs_.float() # multinomial is not implemented for cpu half precision
108
+ probs_ = probs_.reshape(-1, probs.size(-1))
109
+ pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
110
+ pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
111
+ pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
112
+
113
+ if timestep == 0:
114
+ prev_sample = pred_original_sample
115
+ else:
116
+ seq_len = sample.shape[1]
117
+ step_idx = (self.timesteps == timestep).nonzero()
118
+ ratio = (step_idx + 1) / len(self.timesteps)
119
+
120
+ if self.config.masking_schedule == "cosine":
121
+ mask_ratio = torch.cos(ratio * math.pi / 2)
122
+ elif self.config.masking_schedule == "linear":
123
+ mask_ratio = 1 - ratio
124
+ else:
125
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
126
+
127
+ mask_ratio = starting_mask_ratio * mask_ratio
128
+
129
+ mask_len = (seq_len * mask_ratio).floor()
130
+ # do not mask more than amount previously masked
131
+ mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
132
+ # mask at least one
133
+ mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
134
+
135
+ selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
136
+ # Ignores the tokens given in the input by overwriting their confidence.
137
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
138
+
139
+ masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)
140
+
141
+ # Masks tokens with lower confidence.
142
+ prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)
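+ # MaskGIT-style iterative decoding: low-confidence predictions are returned to the mask token and re-sampled at the next step, while high-confidence tokens are kept.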
143
+
144
+ if two_dim_input:
145
+ prev_sample = prev_sample.reshape(batch_size, height, width)
146
+ pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
147
+
148
+ if not return_dict:
149
+ return (prev_sample, pred_original_sample)
150
+
151
+ return SchedulerOutput(prev_sample, pred_original_sample)
152
+
153
+ def add_noise(self, sample, timesteps, generator=None):
154
+ step_idx = (self.timesteps == timesteps).nonzero()
155
+ ratio = (step_idx + 1) / len(self.timesteps)
156
+
157
+ if self.config.masking_schedule == "cosine":
158
+ mask_ratio = torch.cos(ratio * math.pi / 2)
159
+ elif self.config.masking_schedule == "linear":
160
+ mask_ratio = 1 - ratio
161
+ else:
162
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
163
+
164
+ mask_indices = (
165
+ torch.rand(
166
+ sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
167
+ ).to(sample.device)
168
+ < mask_ratio
169
+ )
170
+
171
+ masked_sample = sample.clone()
172
+
173
+ masked_sample[mask_indices] = self.config.mask_token_id
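+ # Corrupts a clean token map: each position is independently replaced by the mask token with probability `mask_ratio`, determined by the schedule at `timesteps`.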
174
+
175
+ return masked_sample
src/transformer.py ADDED
@@ -0,0 +1,1459 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union, List
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ )
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm
33
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
34
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
35
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed
36
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
37
+ from diffusers.models.resnet import Downsample2D, Upsample2D
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+
44
+ def get_3d_rotary_pos_embed(
45
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
46
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
47
+ """
48
+ RoPE for video tokens with 3D structure.
49
+
50
+ Args:
51
+ embed_dim: (`int`):
52
+ The embedding dimension size, corresponding to hidden_size_head.
53
+ crops_coords (`Tuple[int]`):
54
+ The top-left and bottom-right coordinates of the crop.
55
+ grid_size (`Tuple[int]`):
56
+ The grid size of the spatial positional embedding (height, width).
57
+ temporal_size (`int`):
58
+ The size of the temporal dimension.
59
+ theta (`float`):
60
+ Scaling factor for frequency computation.
61
+ use_real (`bool`):
62
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
63
+
64
+ Returns:
65
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
66
+ """
67
+ start, stop = crops_coords
68
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
69
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
70
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
71
+
72
+ # Compute dimensions for each axis
73
+ dim_t = embed_dim // 4
74
+ dim_h = embed_dim // 8 * 3
75
+ dim_w = embed_dim // 8 * 3
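+ # The head dimension is split across axes: 1/4 for time and 3/8 each for height and width.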
76
+
77
+ # Temporal frequencies
78
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
79
+ grid_t = torch.from_numpy(grid_t).float()
80
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
81
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
82
+
83
+ # Spatial frequencies for height and width
84
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
85
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
86
+ grid_h = torch.from_numpy(grid_h).float()
87
+ grid_w = torch.from_numpy(grid_w).float()
88
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
89
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
90
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
91
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
92
+
93
+ # Broadcast and concatenate tensors along specified dimension
94
+ def broadcast(tensors, dim=-1):
95
+ num_tensors = len(tensors)
96
+ shape_lens = {len(t.shape) for t in tensors}
97
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
98
+ shape_len = list(shape_lens)[0]
99
+ dim = (dim + shape_len) if dim < 0 else dim
100
+ dims = list(zip(*(list(t.shape) for t in tensors)))
101
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
102
+ assert all(
103
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
104
+ ), "invalid dimensions for broadcastable concatenation"
105
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
106
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
107
+ expanded_dims.insert(dim, (dim, dims[dim]))
108
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
109
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
110
+ return torch.cat(tensors, dim=dim)
111
+
112
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
113
+
114
+ t, h, w, d = freqs.shape
115
+ freqs = freqs.view(t * h * w, d)
116
+
117
+ # Generate sine and cosine components
118
+ sin = freqs.sin()
119
+ cos = freqs.cos()
120
+
121
+ if use_real:
122
+ return cos, sin
123
+ else:
124
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
125
+ return freqs_cis
126
+
127
+
128
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
129
+ """
130
+ RoPE for image tokens with 2d structure.
131
+
132
+ Args:
133
+ embed_dim: (`int`):
134
+ The embedding dimension size
135
+ crops_coords (`Tuple[int]`)
136
+ The top-left and bottom-right coordinates of the crop.
137
+ grid_size (`Tuple[int]`):
138
+ The grid size of the positional embedding.
139
+ use_real (`bool`):
140
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
141
+
142
+ Returns:
143
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
144
+ """
145
+ start, stop = crops_coords
146
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
147
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
148
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
149
+ grid = np.stack(grid, axis=0) # [2, W, H]
150
+
151
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
152
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
153
+ return pos_embed
154
+
155
+
156
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
157
+ assert embed_dim % 4 == 0
158
+
159
+ # use half of dimensions to encode grid_h
160
+ emb_h = get_1d_rotary_pos_embed(
161
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
162
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
163
+ emb_w = get_1d_rotary_pos_embed(
164
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
165
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
166
+
167
+ if use_real:
168
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
169
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
170
+ return cos, sin
171
+ else:
172
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
173
+ return emb
174
+
175
+
176
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
177
+ assert embed_dim % 4 == 0
178
+
179
+ emb_h = get_1d_rotary_pos_embed(
180
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
181
+ ) # (H, D/4)
182
+ emb_w = get_1d_rotary_pos_embed(
183
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
184
+ ) # (W, D/4)
185
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
186
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
187
+
188
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
189
+ return emb
190
+
191
+
192
+ def get_1d_rotary_pos_embed(
193
+ dim: int,
194
+ pos: Union[np.ndarray, int],
195
+ theta: float = 10000.0,
196
+ use_real=False,
197
+ linear_factor=1.0,
198
+ ntk_factor=1.0,
199
+ repeat_interleave_real=True,
200
+ freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
201
+ ):
202
+ """
203
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
204
+
205
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
206
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
207
+ in complex64 data type.
208
+
209
+ Args:
210
+ dim (`int`): Dimension of the frequency tensor.
211
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
212
+ theta (`float`, *optional*, defaults to 10000.0):
213
+ Scaling factor for frequency computation. Defaults to 10000.0.
214
+ use_real (`bool`, *optional*):
215
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
216
+ linear_factor (`float`, *optional*, defaults to 1.0):
217
+ Scaling factor for the context extrapolation. Defaults to 1.0.
218
+ ntk_factor (`float`, *optional*, defaults to 1.0):
219
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
220
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
221
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
222
+ Otherwise, they are concatenated with themselves.
223
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
224
+ the dtype of the frequency tensor.
225
+ Returns:
226
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
227
+ """
228
+ assert dim % 2 == 0
229
+
230
+ if isinstance(pos, int):
231
+ pos = np.arange(pos)
232
+ theta = theta * ntk_factor
233
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
234
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
235
+ freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
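+ # After the outer product, freqs[s, k] = pos_s * theta**(-2k/dim) / linear_factor: every position gets a geometric spectrum of rotation angles.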
236
+ if use_real and repeat_interleave_real:
237
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
238
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
239
+ return freqs_cos, freqs_sin
240
+ elif use_real:
241
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
242
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
243
+ return freqs_cos, freqs_sin
244
+ else:
245
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
246
+ return freqs_cis
247
+
248
+
249
+ class FluxPosEmbed(nn.Module):
250
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
251
+ def __init__(self, theta: int, axes_dim: List[int]):
252
+ super().__init__()
253
+ self.theta = theta
254
+ self.axes_dim = axes_dim
255
+
256
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
257
+ n_axes = ids.shape[-1]
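+ # Each positional axis (e.g. text index, height, width) gets its own 1D RoPE table of width axes_dim[i]; the per-axis cos/sin tables are concatenated along the feature dimension.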
258
+ cos_out = []
259
+ sin_out = []
260
+ pos = ids.squeeze().float().cpu().numpy()
261
+ is_mps = ids.device.type == "mps"
262
+ freqs_dtype = torch.float32 if is_mps else torch.float64
263
+ for i in range(n_axes):
264
+ cos, sin = get_1d_rotary_pos_embed(
265
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
266
+ )
267
+ cos_out.append(cos)
268
+ sin_out.append(sin)
269
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
270
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
271
+ return freqs_cos, freqs_sin
272
+
273
+
274
+
275
+ class FusedFluxAttnProcessor2_0:
276
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
277
+
278
+ def __init__(self):
279
+ if not hasattr(F, "scaled_dot_product_attention"):
280
+ raise ImportError(
281
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
282
+ )
283
+
284
+ def __call__(
285
+ self,
286
+ attn: Attention,
287
+ hidden_states: torch.FloatTensor,
288
+ encoder_hidden_states: torch.FloatTensor = None,
289
+ attention_mask: Optional[torch.FloatTensor] = None,
290
+ image_rotary_emb: Optional[torch.Tensor] = None,
291
+ ) -> torch.FloatTensor:
292
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
293
+
294
+ # `sample` projections.
295
+ qkv = attn.to_qkv(hidden_states)
296
+ split_size = qkv.shape[-1] // 3
297
+ query, key, value = torch.split(qkv, split_size, dim=-1)
298
+
299
+ inner_dim = key.shape[-1]
300
+ head_dim = inner_dim // attn.heads
301
+
302
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
303
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
304
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
305
+
306
+ if attn.norm_q is not None:
307
+ query = attn.norm_q(query)
308
+ if attn.norm_k is not None:
309
+ key = attn.norm_k(key)
310
+
311
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
312
+ # `context` projections.
313
+ if encoder_hidden_states is not None:
314
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
315
+ split_size = encoder_qkv.shape[-1] // 3
316
+ (
317
+ encoder_hidden_states_query_proj,
318
+ encoder_hidden_states_key_proj,
319
+ encoder_hidden_states_value_proj,
320
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
321
+
322
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
323
+ batch_size, -1, attn.heads, head_dim
324
+ ).transpose(1, 2)
325
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
326
+ batch_size, -1, attn.heads, head_dim
327
+ ).transpose(1, 2)
328
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
329
+ batch_size, -1, attn.heads, head_dim
330
+ ).transpose(1, 2)
331
+
332
+ if attn.norm_added_q is not None:
333
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
334
+ if attn.norm_added_k is not None:
335
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
336
+
337
+ # attention
338
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
339
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
340
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
341
+
342
+ if image_rotary_emb is not None:
343
+ from diffusers.models.embeddings import apply_rotary_emb
344
+
345
+ query = apply_rotary_emb(query, image_rotary_emb)
346
+ key = apply_rotary_emb(key, image_rotary_emb)
347
+
348
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
349
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
350
+ hidden_states = hidden_states.to(query.dtype)
351
+
352
+ if encoder_hidden_states is not None:
353
+ encoder_hidden_states, hidden_states = (
354
+ hidden_states[:, : encoder_hidden_states.shape[1]],
355
+ hidden_states[:, encoder_hidden_states.shape[1] :],
356
+ )
357
+
358
+ # linear proj
359
+ hidden_states = attn.to_out[0](hidden_states)
360
+ # dropout
361
+ hidden_states = attn.to_out[1](hidden_states)
362
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
363
+
364
+ return hidden_states, encoder_hidden_states
365
+ else:
366
+ return hidden_states
367
+
368
+
369
+
370
+ @maybe_allow_in_graph
371
+ class SingleTransformerBlock(nn.Module):
372
+ r"""
373
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
374
+
375
+ Reference: https://arxiv.org/abs/2403.03206
376
+
377
+ Parameters:
378
+ dim (`int`): The number of channels in the input and output.
379
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
380
+ attention_head_dim (`int`): The number of channels in each head.
381
+ mlp_ratio (`float`, *optional*, defaults to 4.0): The ratio of the feed-forward hidden dimension to
382
+ `dim`.
383
+ """
384
+
385
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
386
+ super().__init__()
387
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
388
+
389
+ self.norm = AdaLayerNormZeroSingle(dim)
390
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
391
+ self.act_mlp = nn.GELU(approximate="tanh")
392
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
393
+
394
+ processor = FluxAttnProcessor2_0()
395
+ self.attn = Attention(
396
+ query_dim=dim,
397
+ cross_attention_dim=None,
398
+ dim_head=attention_head_dim,
399
+ heads=num_attention_heads,
400
+ out_dim=dim,
401
+ bias=True,
402
+ processor=processor,
403
+ qk_norm="rms_norm",
404
+ eps=1e-6,
405
+ pre_only=True,
406
+ )
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.FloatTensor,
411
+ temb: torch.FloatTensor,
412
+ image_rotary_emb=None,
413
+ ):
414
+ residual = hidden_states
415
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
416
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
417
+
418
+ attn_output = self.attn(
419
+ hidden_states=norm_hidden_states,
420
+ image_rotary_emb=image_rotary_emb,
421
+ )
422
+
423
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
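+ # Single-stream block: attention and the MLP read the same normalized input; their outputs are concatenated and fused by one output projection, gated by the conditioning embedding.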
424
+ gate = gate.unsqueeze(1)
425
+ hidden_states = gate * self.proj_out(hidden_states)
426
+ hidden_states = residual + hidden_states
427
+ if hidden_states.dtype == torch.float16:
428
+ hidden_states = hidden_states.clip(-65504, 65504)
429
+
430
+ return hidden_states
431
+
432
+ @maybe_allow_in_graph
433
+ class TransformerBlock(nn.Module):
434
+ r"""
435
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
436
+
437
+ Reference: https://arxiv.org/abs/2403.03206
438
+
439
+ Parameters:
440
+ dim (`int`): The number of channels in the input and output.
441
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
442
+ attention_head_dim (`int`): The number of channels in each head.
443
+ qk_norm (`str`, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
444
+ eps (`float`, defaults to 1e-6): The epsilon used by the query/key normalization layers.
445
+ """
446
+
447
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
448
+ super().__init__()
449
+
450
+ self.norm1 = AdaLayerNormZero(dim)
451
+
452
+ self.norm1_context = AdaLayerNormZero(dim)
453
+
454
+ if hasattr(F, "scaled_dot_product_attention"):
455
+ processor = FluxAttnProcessor2_0()
456
+ else:
457
+ raise ValueError(
458
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
459
+ )
460
+ self.attn = Attention(
461
+ query_dim=dim,
462
+ cross_attention_dim=None,
463
+ added_kv_proj_dim=dim,
464
+ dim_head=attention_head_dim,
465
+ heads=num_attention_heads,
466
+ out_dim=dim,
467
+ context_pre_only=False,
468
+ bias=True,
469
+ processor=processor,
470
+ qk_norm=qk_norm,
471
+ eps=eps,
472
+ )
473
+
474
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
475
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
476
+ # self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
477
+
478
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
480
+ # self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
481
+
482
+ # let chunk size default to None
483
+ self._chunk_size = None
484
+ self._chunk_dim = 0
485
+
486
+ def forward(
487
+ self,
488
+ hidden_states: torch.FloatTensor,
489
+ encoder_hidden_states: torch.FloatTensor,
490
+ temb: torch.FloatTensor,
491
+ image_rotary_emb=None,
492
+ ):
493
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
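+ # AdaLayerNormZero derives shift/scale/gate values from the conditioning embedding `temb`; the gates below scale the attention and feed-forward residual updates.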
494
+
495
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
496
+ encoder_hidden_states, emb=temb
497
+ )
498
+ # Attention.
499
+ attn_output, context_attn_output = self.attn(
500
+ hidden_states=norm_hidden_states,
501
+ encoder_hidden_states=norm_encoder_hidden_states,
502
+ image_rotary_emb=image_rotary_emb,
503
+ )
504
+
505
+ # Process attention outputs for the `hidden_states`.
506
+ attn_output = gate_msa.unsqueeze(1) * attn_output
507
+ hidden_states = hidden_states + attn_output
508
+
509
+ norm_hidden_states = self.norm2(hidden_states)
510
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
511
+
512
+ ff_output = self.ff(norm_hidden_states)
513
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
514
+
515
+ hidden_states = hidden_states + ff_output
516
+
517
+ # Process attention outputs for the `encoder_hidden_states`.
518
+
519
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
520
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
521
+
522
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
523
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
524
+
525
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
526
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
527
+ if encoder_hidden_states.dtype == torch.float16:
528
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
529
+
530
+ return encoder_hidden_states, hidden_states
531
+
532
+
533
+ class UVit2DConvEmbed(nn.Module):
534
+ def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
535
+ super().__init__()
536
+ self.embeddings = nn.Embedding(vocab_size, in_channels)
537
+ self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
538
+ self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
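+ # Looks up embeddings for discrete VQ token ids, RMS-normalizes them, and projects to the transformer width with a 1x1 convolution.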
539
+
540
+ def forward(self, input_ids):
541
+ embeddings = self.embeddings(input_ids)
542
+ embeddings = self.layer_norm(embeddings)
543
+ embeddings = embeddings.permute(0, 3, 1, 2)
544
+ embeddings = self.conv(embeddings)
545
+ return embeddings
546
+
547
+ class ConvMlmLayer(nn.Module):
548
+ def __init__(
549
+ self,
550
+ block_out_channels: int,
551
+ in_channels: int,
552
+ use_bias: bool,
553
+ ln_elementwise_affine: bool,
554
+ layer_norm_eps: float,
555
+ codebook_size: int,
556
+ ):
557
+ super().__init__()
558
+ self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
559
+ self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
560
+ self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
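+ # Output head: projects transformer features back to per-position logits over the VQ codebook.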
561
+
562
+ def forward(self, hidden_states):
563
+ hidden_states = self.conv1(hidden_states)
564
+ hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
565
+ logits = self.conv2(hidden_states)
566
+ return logits
567
+
568
+ class SwiGLU(nn.Module):
569
+ r"""
570
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
571
+ but uses SiLU / Swish instead of GeLU.
572
+
573
+ Parameters:
574
+ dim_in (`int`): The number of channels in the input.
575
+ dim_out (`int`): The number of channels in the output.
576
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
577
+ """
578
+
579
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
580
+ super().__init__()
581
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
582
+ self.activation = nn.SiLU()
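+ # One linear layer produces both the value and the gate halves; forward returns value * SiLU(gate).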
583
+
584
+ def forward(self, hidden_states):
585
+ hidden_states = self.proj(hidden_states)
586
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
587
+ return hidden_states * self.activation(gate)
588
+
589
+ class ConvNextBlock(nn.Module):
590
+ def __init__(
591
+ self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
592
+ ):
593
+ super().__init__()
594
+ self.depthwise = nn.Conv2d(
595
+ channels,
596
+ channels,
597
+ kernel_size=3,
598
+ padding=1,
599
+ groups=channels,
600
+ bias=use_bias,
601
+ )
602
+ self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
603
+ self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
604
+ self.channelwise_act = nn.GELU()
605
+ self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
606
+ self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
607
+ self.channelwise_dropout = nn.Dropout(hidden_dropout)
608
+ self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
609
+
610
+ def forward(self, x, cond_embeds):
611
+ x_res = x
612
+
613
+ x = self.depthwise(x)
614
+
615
+ x = x.permute(0, 2, 3, 1)
616
+ x = self.norm(x)
617
+
618
+ x = self.channelwise_linear_1(x)
619
+ x = self.channelwise_act(x)
620
+ x = self.channelwise_norm(x)
621
+ x = self.channelwise_linear_2(x)
622
+ x = self.channelwise_dropout(x)
623
+
624
+ x = x.permute(0, 3, 1, 2)
625
+
626
+ x = x + x_res
627
+
628
+ scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
629
+ x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
630
+
631
+ return x
632
+
633
+ class Simple_UVitBlock(nn.Module):
634
+ def __init__(
635
+ self,
636
+ channels,
637
+ ln_elementwise_affine,
638
+ layer_norm_eps,
639
+ use_bias,
640
+ downsample: bool,
641
+ upsample: bool,
642
+ ):
643
+ super().__init__()
644
+
645
+ if downsample:
646
+ self.downsample = Downsample2D(
647
+ channels,
648
+ use_conv=True,
649
+ padding=0,
650
+ name="Conv2d_0",
651
+ kernel_size=2,
652
+ norm_type="rms_norm",
653
+ eps=layer_norm_eps,
654
+ elementwise_affine=ln_elementwise_affine,
655
+ bias=use_bias,
656
+ )
657
+ else:
658
+ self.downsample = None
659
+
660
+ if upsample:
661
+ self.upsample = Upsample2D(
662
+ channels,
663
+ use_conv_transpose=True,
664
+ kernel_size=2,
665
+ padding=0,
666
+ name="conv",
667
+ norm_type="rms_norm",
668
+ eps=layer_norm_eps,
669
+ elementwise_affine=ln_elementwise_affine,
670
+ bias=use_bias,
671
+ interpolate=False,
672
+ )
673
+ else:
674
+ self.upsample = None
675
+
676
+ def forward(self, x):
677
+ # print("before,", x.shape)
678
+ if self.downsample is not None:
679
+ # print('downsample')
680
+ x = self.downsample(x)
681
+
682
+ if self.upsample is not None:
683
+ # print('upsample')
684
+ x = self.upsample(x)
685
+ # print("after,", x.shape)
686
+ return x
687
+
688
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
689
+ """
690
+ The Transformer model introduced in Flux.
691
+
692
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
693
+
694
+ Parameters:
695
+ patch_size (`int`): Patch size to turn the input data into small patches.
696
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
697
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
698
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
699
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
700
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
701
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
702
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
703
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
704
+ """
705
+
706
+ _supports_gradient_checkpointing = False #True
707
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
708
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
709
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
710
+
711
+ @register_to_config
712
+ def __init__(
713
+ self,
714
+ patch_size: int = 1,
715
+ in_channels: int = 64,
716
+ num_layers: int = 19,
717
+ num_single_layers: int = 38,
718
+ attention_head_dim: int = 128,
719
+ num_attention_heads: int = 24,
720
+ joint_attention_dim: int = 4096,
721
+ pooled_projection_dim: int = 768,
722
+ guidance_embeds: bool = False, # unused in our implementation
723
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
724
+ vocab_size: int = 8256,
725
+ codebook_size: int = 8192,
726
+ downsample: bool = False,
727
+ upsample: bool = False,
728
+ ):
729
+ super().__init__()
730
+ self.out_channels = in_channels
731
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
732
+
733
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
734
+ text_time_guidance_cls = (
735
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
736
+ )
737
+ self.time_text_embed = text_time_guidance_cls(
738
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
739
+ )
740
+
741
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
742
+
743
+ self.transformer_blocks = nn.ModuleList(
744
+ [
745
+ TransformerBlock(
746
+ dim=self.inner_dim,
747
+ num_attention_heads=self.config.num_attention_heads,
748
+ attention_head_dim=self.config.attention_head_dim,
749
+ )
750
+ for i in range(self.config.num_layers)
751
+ ]
752
+ )
753
+
754
+ self.single_transformer_blocks = nn.ModuleList(
755
+ [
756
+ SingleTransformerBlock(
757
+ dim=self.inner_dim,
758
+ num_attention_heads=self.config.num_attention_heads,
759
+ attention_head_dim=self.config.attention_head_dim,
760
+ )
761
+ for i in range(self.config.num_single_layers)
762
+ ]
763
+ )
764
+
765
+
766
+ self.gradient_checkpointing = False
767
+
768
+ in_channels_embed = self.inner_dim
769
+ ln_elementwise_affine = True
770
+ layer_norm_eps = 1e-06
771
+ use_bias = False
772
+ micro_cond_embed_dim = 1280
773
+ self.embed = UVit2DConvEmbed(
774
+ in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
775
+ )
776
+ self.mlm_layer = ConvMlmLayer(
777
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size
778
+ )
779
+ self.cond_embed = TimestepEmbedding(
780
+ micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
781
+ )
782
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
783
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
784
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
785
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
786
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
787
+
788
+ self.down_block = Simple_UVitBlock(
789
+ self.inner_dim,
790
+ ln_elementwise_affine,
791
+ layer_norm_eps,
792
+ use_bias,
793
+ downsample,
794
+ False,
795
+ )
796
+ self.up_block = Simple_UVitBlock(
797
+ self.inner_dim, #block_out_channels,
798
+ ln_elementwise_affine,
799
+ layer_norm_eps,
800
+ use_bias,
801
+ False,
802
+ upsample=upsample,
803
+ )
804
+
805
+ # self.fuse_qkv_projections()
806
+
807
+ @property
808
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
809
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
810
+ r"""
811
+ Returns:
812
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
813
+ indexed by its weight name.
814
+ """
815
+ # set recursively
816
+ processors = {}
817
+
818
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
819
+ if hasattr(module, "get_processor"):
820
+ processors[f"{name}.processor"] = module.get_processor()
821
+
822
+ for sub_name, child in module.named_children():
823
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
824
+
825
+ return processors
826
+
827
+ for name, module in self.named_children():
828
+ fn_recursive_add_processors(name, module, processors)
829
+
830
+ return processors
831
+
832
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
833
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
834
+ r"""
835
+ Sets the attention processor to use to compute attention.
836
+
837
+ Parameters:
838
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
839
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
840
+ for **all** `Attention` layers.
841
+
842
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
843
+ processor. This is strongly recommended when setting trainable attention processors.
844
+
845
+ """
846
+ count = len(self.attn_processors.keys())
847
+
848
+ if isinstance(processor, dict) and len(processor) != count:
849
+ raise ValueError(
850
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
851
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
852
+ )
853
+
854
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
855
+ if hasattr(module, "set_processor"):
856
+ if not isinstance(processor, dict):
857
+ module.set_processor(processor)
858
+ else:
859
+ module.set_processor(processor.pop(f"{name}.processor"))
860
+
861
+ for sub_name, child in module.named_children():
862
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
863
+
864
+ for name, module in self.named_children():
865
+ fn_recursive_attn_processor(name, module, processor)
866
+
867
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
868
+ def fuse_qkv_projections(self):
869
+ """
870
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
871
+ are fused. For cross-attention modules, key and value projection matrices are fused.
872
+
873
+ <Tip warning={true}>
874
+
875
+ This API is 🧪 experimental.
876
+
877
+ </Tip>
878
+ """
879
+ self.original_attn_processors = None
880
+
881
+ for _, attn_processor in self.attn_processors.items():
882
+ if "Added" in str(attn_processor.__class__.__name__):
883
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
884
+
885
+ self.original_attn_processors = self.attn_processors
886
+
887
+ for module in self.modules():
888
+ if isinstance(module, Attention):
889
+ module.fuse_projections(fuse=True)
890
+
891
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
892
+
893
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
894
+ def unfuse_qkv_projections(self):
895
+ """Disables the fused QKV projection if enabled.
896
+
897
+ <Tip warning={true}>
898
+
899
+ This API is 🧪 experimental.
900
+
901
+ </Tip>
902
+
903
+ """
904
+ if self.original_attn_processors is not None:
905
+ self.set_attn_processor(self.original_attn_processors)
906
+
907
+ def _set_gradient_checkpointing(self, module, value=False):
908
+ if hasattr(module, "gradient_checkpointing"):
909
+ module.gradient_checkpointing = value
910
+
911
+ def forward(
912
+ self,
913
+ hidden_states: torch.Tensor,
914
+ encoder_hidden_states: torch.Tensor = None,
915
+ pooled_projections: torch.Tensor = None,
916
+ timestep: torch.LongTensor = None,
917
+ img_ids: torch.Tensor = None,
918
+ txt_ids: torch.Tensor = None,
919
+ guidance: torch.Tensor = None,
920
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
921
+ controlnet_block_samples= None,
922
+ controlnet_single_block_samples=None,
923
+ return_dict: bool = True,
924
+ micro_conds: torch.Tensor = None,
925
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
926
+ """
927
+ The [`FluxTransformer2DModel`] forward method.
928
+
929
+ Args:
930
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
931
+ Input `hidden_states`.
932
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
933
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
934
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
935
+ from the embeddings of input conditions.
936
+ timestep ( `torch.LongTensor`):
937
+ Used to indicate denoising step.
938
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
939
+ A list of tensors that if specified are added to the residuals of transformer blocks.
940
+ joint_attention_kwargs (`dict`, *optional*):
941
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
942
+ `self.processor` in
943
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
944
+ return_dict (`bool`, *optional*, defaults to `True`):
945
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
946
+ tuple.
947
+
948
+ Returns:
949
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
950
+ `tuple` where the first element is the sample tensor.
951
+ """
952
+ micro_cond_encode_dim = 256 # same as self.config.micro_cond_encode_dim = 256 from amused
953
+ micro_cond_embeds = get_timestep_embedding(
954
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
955
+ )
956
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
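+ # The micro conditions are encoded with sinusoidal timestep embeddings (aMUSEd-style) and concatenated to the pooled text projection before the conditioning MLP (`cond_embed`).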
957
+
958
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
959
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
960
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
961
+
962
+
963
+ hidden_states = self.embed(hidden_states)
964
+
965
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
966
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
967
+ hidden_states = self.down_block(hidden_states)
968
+
969
+ batch_size, channels, height, width = hidden_states.shape
970
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
971
+ hidden_states = self.project_to_hidden_norm(hidden_states)
972
+ hidden_states = self.project_to_hidden(hidden_states)
973
+
974
+
975
+ if joint_attention_kwargs is not None:
976
+ joint_attention_kwargs = joint_attention_kwargs.copy()
977
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
978
+ else:
979
+ lora_scale = 1.0
980
+
981
+ if USE_PEFT_BACKEND:
982
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
983
+ scale_lora_layers(self, lora_scale)
984
+ else:
985
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
986
+ logger.warning(
987
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
988
+ )
989
+
990
+ timestep = timestep.to(hidden_states.dtype) * 1000
991
+ if guidance is not None:
992
+ guidance = guidance.to(hidden_states.dtype) * 1000
993
+ else:
994
+ guidance = None
995
+ temb = (
996
+ self.time_text_embed(timestep, pooled_projections)
997
+ if guidance is None
998
+ else self.time_text_embed(timestep, guidance, pooled_projections)
999
+ )
1000
+
1001
+ if txt_ids.ndim == 3:
1002
+ logger.warning(
1003
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1004
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1005
+ )
1006
+ txt_ids = txt_ids[0]
1007
+ if img_ids.ndim == 3:
1008
+ logger.warning(
1009
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1010
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1011
+ )
1012
+ img_ids = img_ids[0]
1013
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1014
+
1015
+ image_rotary_emb = self.pos_embed(ids)
1016
+
1017
+ for index_block, block in enumerate(self.transformer_blocks):
1018
+ if self.training and self.gradient_checkpointing:
1019
+
1020
+ def create_custom_forward(module, return_dict=None):
1021
+ def custom_forward(*inputs):
1022
+ if return_dict is not None:
1023
+ return module(*inputs, return_dict=return_dict)
1024
+ else:
1025
+ return module(*inputs)
1026
+
1027
+ return custom_forward
1028
+
1029
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1030
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1031
+ create_custom_forward(block),
1032
+ hidden_states,
1033
+ encoder_hidden_states,
1034
+ temb,
1035
+ image_rotary_emb,
1036
+ **ckpt_kwargs,
1037
+ )
1038
+
1039
+ else:
1040
+ encoder_hidden_states, hidden_states = block(
1041
+ hidden_states=hidden_states,
1042
+ encoder_hidden_states=encoder_hidden_states,
1043
+ temb=temb,
1044
+ image_rotary_emb=image_rotary_emb,
1045
+ )
1046
+
1047
+
1048
+ # controlnet residual
1049
+ if controlnet_block_samples is not None:
1050
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1051
+ interval_control = int(np.ceil(interval_control))
1052
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
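+ # When fewer ControlNet residuals than transformer blocks are provided, each residual is reused for a contiguous interval of blocks.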
1053
+
1054
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1055
+
1056
+ for index_block, block in enumerate(self.single_transformer_blocks):
1057
+ if self.training and self.gradient_checkpointing:
1058
+
1059
+ def create_custom_forward(module, return_dict=None):
1060
+ def custom_forward(*inputs):
1061
+ if return_dict is not None:
1062
+ return module(*inputs, return_dict=return_dict)
1063
+ else:
1064
+ return module(*inputs)
1065
+
1066
+ return custom_forward
1067
+
1068
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1069
+ hidden_states = torch.utils.checkpoint.checkpoint(
1070
+ create_custom_forward(block),
1071
+ hidden_states,
1072
+ temb,
1073
+ image_rotary_emb,
1074
+ **ckpt_kwargs,
1075
+ )
1076
+
1077
+ else:
1078
+ hidden_states = block(
1079
+ hidden_states=hidden_states,
1080
+ temb=temb,
1081
+ image_rotary_emb=image_rotary_emb,
1082
+ )
1083
+
1084
+ # controlnet residual
1085
+ if controlnet_single_block_samples is not None:
1086
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1087
+ interval_control = int(np.ceil(interval_control))
1088
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1089
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1090
+ + controlnet_single_block_samples[index_block // interval_control]
1091
+ )
1092
+
1093
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1094
+
1095
+
1096
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1097
+ hidden_states = self.project_from_hidden(hidden_states)
1098
+
1099
+
1100
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1101
+
1102
+ hidden_states = self.up_block(hidden_states)
1103
+
1104
+ if USE_PEFT_BACKEND:
1105
+ # remove `lora_scale` from each PEFT layer
1106
+ unscale_lora_layers(self, lora_scale)
1107
+
1108
+ output = self.mlm_layer(hidden_states)
1109
+ # self.unfuse_qkv_projections()
1110
+ if not return_dict:
1111
+ return (output,)
1112
+
1113
+
1114
+ return output
1115
+
1116
+
1117
+ class SymmetricTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
1118
+ """
1119
+ The Transformer model introduced in Flux.
1120
+
1121
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
1122
+
1123
+ Parameters:
1124
+ patch_size (`int`): Patch size to turn the input data into small patches.
1125
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
1126
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
1127
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
1128
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
1129
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
1130
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
1131
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
1132
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
1133
+ """
1134
+
1135
+ _supports_gradient_checkpointing = False #True
1136
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
1137
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
1138
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
1139
+
1140
+ @register_to_config
1141
+ def __init__(
1142
+ self,
1143
+ patch_size: int = 1,
1144
+ in_channels: int = 64,
1145
+ num_layers: int = 19,
1146
+ num_single_layers: int = 38,
1147
+ attention_head_dim: int = 128,
1148
+ num_attention_heads: int = 24,
1149
+ joint_attention_dim: int = 4096,
1150
+ pooled_projection_dim: int = 768,
1151
+ guidance_embeds: bool = False, # unused in our implementation
1152
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
1153
+ vocab_size: int = 8256,
1154
+ codebook_size: int = 8192,
1155
+ tokenizer_vocab_size: Optional[int] = None,
1156
+ t5_dim: Optional[int] = None,
1157
+ downsample: bool = False,
1158
+ upsample: bool = False,
1159
+ ):
1160
+ super().__init__()
1161
+ self.out_channels = in_channels
1162
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
1163
+
1164
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
1165
+ text_time_guidance_cls = (
1166
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
1167
+ )
1168
+ self.time_text_embed = text_time_guidance_cls(
1169
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
1170
+ )
1171
+
1172
+ if t5_dim is not None:
1173
+ self.adapter = nn.Sequential(
1174
+ nn.LayerNorm(t5_dim, elementwise_affine=False, eps=1e-6),
1175
+ nn.Linear(t5_dim, self.config.joint_attention_dim, bias=False)
1176
+ )
1177
+ else:
1178
+ self.adapter = None
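+ # Optional adapter: projects `t5_dim`-wide encoder features into `joint_attention_dim`, presumably so a T5-style text encoder can be used in place of the default one.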
1179
+
1180
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
1181
+
1182
+ self.transformer_blocks = nn.ModuleList(
1183
+ [
1184
+ TransformerBlock(
1185
+ dim=self.inner_dim,
1186
+ num_attention_heads=self.config.num_attention_heads,
1187
+ attention_head_dim=self.config.attention_head_dim,
1188
+ )
1189
+ for i in range(self.config.num_layers)
1190
+ ]
1191
+ )
1192
+
1193
+ self.single_transformer_blocks = nn.ModuleList(
1194
+ [
1195
+ SingleTransformerBlock(
1196
+ dim=self.inner_dim,
1197
+ num_attention_heads=self.config.num_attention_heads,
1198
+ attention_head_dim=self.config.attention_head_dim,
1199
+ )
1200
+ for i in range(self.config.num_single_layers)
1201
+ ]
1202
+ )
1203
+
1204
+ self.gradient_checkpointing = False
1205
+
1206
+ in_channels_embed = self.inner_dim
1207
+ ln_elementwise_affine = True
1208
+ layer_norm_eps = 1e-06
1209
+ use_bias = False
1210
+ micro_cond_embed_dim = 1280
1211
+ self.embed = UVit2DConvEmbed(
1212
+ in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
1213
+ )
1214
+ self.mlm_layer = ConvMlmLayer(
1215
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size
1216
+ )
1217
+ self.cond_embed = TimestepEmbedding(
1218
+ micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
1219
+ )
1220
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
1221
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
1222
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
1223
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
1224
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
1225
+
1226
+ self.down_block = Simple_UVitBlock(
1227
+ self.inner_dim,
1228
+ ln_elementwise_affine,
1229
+ layer_norm_eps,
1230
+ use_bias,
1231
+ downsample,
1232
+ False,
1233
+ )
1234
+ self.up_block = Simple_UVitBlock(
1235
+ self.inner_dim,
1236
+ ln_elementwise_affine,
1237
+ layer_norm_eps,
1238
+ use_bias,
1239
+ False,
1240
+ upsample=upsample,
1241
+ )
1242
+
1243
+ if tokenizer_vocab_size is not None:
1244
+ self.text_decoder = nn.Sequential(
1245
+ nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6),
1246
+ nn.Linear(self.inner_dim, tokenizer_vocab_size, bias=use_bias)
1247
+ )
1248
+ else:
1249
+ self.text_decoder = None
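+ # Optional text head mapping hidden states to tokenizer-vocabulary logits; presumably this provides the text (image-to-text) output branch of the symmetric model.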
1250
+
1251
+
1252
+ def forward(
1253
+ self,
1254
+ hidden_states: torch.Tensor,
1255
+ encoder_hidden_states: torch.Tensor = None,
1256
+ pooled_projections: torch.Tensor = None,
1257
+ timestep: torch.LongTensor = None,
1258
+ img_ids: torch.Tensor = None,
1259
+ txt_ids: torch.Tensor = None,
1260
+ guidance: torch.Tensor = None,
1261
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1262
+ controlnet_block_samples= None,
1263
+ controlnet_single_block_samples=None,
1264
+ return_dict: bool = True,
1265
+ micro_conds: torch.Tensor = None,
1266
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1267
+ """
1268
+ The [`FluxTransformer2DModel`] forward method.
1269
+
1270
+ Args:
1271
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1272
+ Input `hidden_states`.
1273
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1274
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1275
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1276
+ from the embeddings of input conditions.
1277
+ timestep ( `torch.LongTensor`):
1278
+ Used to indicate denoising step.
1279
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
1280
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1281
+ joint_attention_kwargs (`dict`, *optional*):
1282
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1283
+ `self.processor` in
1284
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1285
+ return_dict (`bool`, *optional*, defaults to `True`):
1286
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1287
+ tuple.
1288
+
1289
+ Returns:
1290
+ A `tuple` `(output, encoder_hidden_states)`: `output` contains the image-token logits over the VQ codebook;
1291
+ `encoder_hidden_states` contains the text-token logits over the tokenizer vocabulary when a text decoder is
+ configured (otherwise the final text hidden states). A tuple is currently returned regardless of `return_dict`.
1292
+ """
1293
+ micro_cond_encode_dim = 256 # matches micro_cond_encode_dim = 256 used in aMUSEd
1294
+ micro_cond_embeds = get_timestep_embedding(
1295
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
1296
+ )
1297
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
1298
+
1299
+ if self.adapter is not None:
1300
+ encoder_hidden_states = self.adapter(encoder_hidden_states)
1301
+
1302
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
1303
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
1304
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
1305
+
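+ # Embed the image inputs; project and normalize the text encoder states; run the convolutional down block.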
1306
+ hidden_states = self.embed(hidden_states)
1307
+
1308
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1309
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
1310
+ hidden_states = self.down_block(hidden_states)
1311
+
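+ # Flatten the 2D feature map into a token sequence and project it to the transformer width.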
1312
+ batch_size, channels, height, width = hidden_states.shape
1313
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
1314
+ hidden_states = self.project_to_hidden_norm(hidden_states)
1315
+ hidden_states = self.project_to_hidden(hidden_states)
1316
+
1317
+
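+ # Handle the optional LoRA scale: apply it when the PEFT backend is available, otherwise warn.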
1318
+ if joint_attention_kwargs is not None:
1319
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1320
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1321
+ else:
1322
+ lora_scale = 1.0
1323
+
1324
+ if USE_PEFT_BACKEND:
1325
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1326
+ scale_lora_layers(self, lora_scale)
1327
+ else:
1328
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1329
+ logger.warning(
1330
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1331
+ )
1332
+
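+ # Scale timestep (and guidance, if provided) by 1000 before computing the combined time/text embedding.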
1333
+ timestep = timestep.to(hidden_states.dtype) * 1000
1334
+ if guidance is not None:
1335
+ guidance = guidance.to(hidden_states.dtype) * 1000
1338
+ temb = (
1339
+ self.time_text_embed(timestep, pooled_projections)
1340
+ if guidance is None
1341
+ else self.time_text_embed(timestep, guidance, pooled_projections)
1342
+ )
1343
+
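+ # Build rotary position embeddings from the concatenated text and image position ids.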
1344
+ if txt_ids.ndim == 3:
1345
+ logger.warning(
1346
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1347
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1348
+ )
1349
+ txt_ids = txt_ids[0]
1350
+ if img_ids.ndim == 3:
1351
+ logger.warning(
1352
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1353
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1354
+ )
1355
+ img_ids = img_ids[0]
1356
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1357
+
1358
+ image_rotary_emb = self.pos_embed(ids)
1359
+
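+ # Dual-stream blocks: text and image tokens are kept in separate streams that attend to each other jointly.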
1360
+ for index_block, block in enumerate(self.transformer_blocks):
1361
+ if self.training and self.gradient_checkpointing:
1362
+
1363
+ def create_custom_forward(module, return_dict=None):
1364
+ def custom_forward(*inputs):
1365
+ if return_dict is not None:
1366
+ return module(*inputs, return_dict=return_dict)
1367
+ else:
1368
+ return module(*inputs)
1369
+
1370
+ return custom_forward
1371
+
1372
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1373
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1374
+ create_custom_forward(block),
1375
+ hidden_states,
1376
+ encoder_hidden_states,
1377
+ temb,
1378
+ image_rotary_emb,
1379
+ **ckpt_kwargs,
1380
+ )
1381
+
1382
+ else:
1383
+ encoder_hidden_states, hidden_states = block(
1384
+ hidden_states=hidden_states,
1385
+ encoder_hidden_states=encoder_hidden_states,
1386
+ temb=temb,
1387
+ image_rotary_emb=image_rotary_emb,
1388
+ )
1389
+
1390
+
1391
+ # controlnet residual
1392
+ if controlnet_block_samples is not None:
1393
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1394
+ interval_control = int(np.ceil(interval_control))
1395
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
1396
+
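+ # Concatenate text and image tokens into one sequence for the single-stream blocks.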
1397
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1398
+
1399
+ for index_block, block in enumerate(self.single_transformer_blocks):
1400
+ if self.training and self.gradient_checkpointing:
1401
+
1402
+ def create_custom_forward(module, return_dict=None):
1403
+ def custom_forward(*inputs):
1404
+ if return_dict is not None:
1405
+ return module(*inputs, return_dict=return_dict)
1406
+ else:
1407
+ return module(*inputs)
1408
+
1409
+ return custom_forward
1410
+
1411
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1412
+ hidden_states = torch.utils.checkpoint.checkpoint(
1413
+ create_custom_forward(block),
1414
+ hidden_states,
1415
+ temb,
1416
+ image_rotary_emb,
1417
+ **ckpt_kwargs,
1418
+ )
1419
+
1420
+ else:
1421
+ hidden_states = block(
1422
+ hidden_states=hidden_states,
1423
+ temb=temb,
1424
+ image_rotary_emb=image_rotary_emb,
1425
+ )
1426
+
1427
+ # controlnet residual
1428
+ if controlnet_single_block_samples is not None:
1429
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1430
+ interval_control = int(np.ceil(interval_control))
1431
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1432
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1433
+ + controlnet_single_block_samples[index_block // interval_control]
1434
+ )
1435
+
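+ # Split the sequence back into the text and image streams.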
1436
+ encoder_hidden_states = hidden_states[:, :encoder_hidden_states.shape[1], ...]
1437
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
1438
+
1439
+ if self.text_decoder is not None:
1440
+ encoder_hidden_states = self.text_decoder(encoder_hidden_states)
1441
+
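+ # Map image tokens back to the embedding width, restore the 2D layout, and run the conv up block.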
1442
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1443
+ hidden_states = self.project_from_hidden(hidden_states)
1444
+
1445
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1446
+
1447
+ hidden_states = self.up_block(hidden_states)
1448
+
1449
+ if USE_PEFT_BACKEND:
1450
+ # remove `lora_scale` from each PEFT layer
1451
+ unscale_lora_layers(self, lora_scale)
1452
+
1453
+ output = self.mlm_layer(hidden_states)
1454
+ # self.unfuse_qkv_projections()
1455
+ if not return_dict:
1456
+ return (output, encoder_hidden_states)
1457
+
1458
+
1459
+ return output, encoder_hidden_states # image-token logits over the codebook; text logits [b, l, tokenizer_vocab_size] if a text decoder is set