zs38 committed
Commit a0e3aec · 1 Parent(s): caf3d0c
Files changed (4)
  1. app.py +87 -146
  2. config.json +15 -0
  3. projection.py +46 -0
  4. transformer_flux_custom.py +890 -0
app.py CHANGED
@@ -1,154 +1,95 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
+ import os
2
+ import sys
3
+ sys.path.append('app/')
4
 
5
  import torch
6
+ import spaces
7
+ import safetensors.torch
8
+ import gradio as gr
9
+ from PIL import Image
10
+ from loguru import logger
11
+ from torchvision import transforms
12
+ from huggingface_hub import hf_hub_download, login
13
+ from diffusers import FluxPipeline, FluxTransformer2DModel
14
 
15
+ from projection import ImageEncoder
16
+ from transformer_flux_custom import FluxTransformer2DModel as FluxTransformer2DModelWithIP
17
 
18
 
19
+ model_config = './config.json'
20
+ pretrained_model_name = 'black-forest-labs/FLUX.1-dev'
21
+ adapter_path = 'model.safetensors'
22
+ adapter_repo_id = "ashen0209/Flux-Character-Consitancy"
23
 
24
+ conditioner_base_model = 'eva02_large_patch14_448.mim_in22k_ft_in1k'
25
+ conditioner_layer_num = 12
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ output_dim = 4096
28
+
29
+ logger.info("init model")
30
+ model = FluxTransformer2DModelWithIP.from_config(model_config, torch_dtype=torch.bfloat16) # type: ignore
31
+ logger.info("load model")
32
+ copy = FluxTransformer2DModel.from_pretrained(pretrained_model_name, subfolder='transformer', torch_dtype=torch.bfloat16)
33
+ model.load_state_dict(copy.state_dict(), strict=False)
34
+ del copy
35
+
36
+ logger.info("load proj")
37
+ extra_embedder = ImageEncoder(output_dim, layer_num=conditioner_layer_num, seq_len=2, device=device, base_model=conditioner_base_model).to(device=device, dtype=torch.bfloat16)
38
+
39
+ logger.info("load pipe")
40
+ pipe = FluxPipeline.from_pretrained(pretrained_model_name, transformer=model, torch_dtype=torch.bfloat16)
41
+ pipe.to(dtype=torch.bfloat16, device=device)
42
+
43
+ logger.info("download adapter")
44
+ login(token=os.environ['HF_TOKEN'])
45
+ file_path = hf_hub_download(repo_id=adapter_repo_id, filename=adapter_path)
46
+
47
+ logger.info("load adapter")
48
+ state_dict = safetensors.torch.load_file(file_path)  # load the file returned by hf_hub_download
49
+ state_dict = {'.'.join(k.split('.')[1:]): state_dict[k] for k in state_dict.keys()}
50
+ diff = model.load_state_dict(state_dict, strict=False)
51
+ diff = extra_embedder.load_state_dict(state_dict, strict=False)
52
+
53
+
54
+ IMAGE_PROCESS_TRANSFORM = transforms.Compose([
55
+ transforms.Resize((448, 448)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.276])
58
+ ])
59
+
60
+ @spaces.GPU
61
+ def generate_image(ref_image, prompt, height=512, width=512, num_steps=25, guidance_scale=3.5, ip_scale=1.0):
62
+ global pipe  # pipe is defined at module scope; nonlocal here would be a SyntaxError
63
+ with torch.no_grad():
64
+ image_refs = map(torch.stack, [
65
+ [IMAGE_PROCESS_TRANSFORM(i) for i in [ref_image, ]]
66
+ ])
67
+ image_refs = [i.to(dtype=torch.bfloat16, device='cuda') for i in image_refs]
68
+ prompt_embeds, pooled_prompt_embeds, txt_ids = pipe.encode_prompt(prompt, prompt)
69
+ visual_prompt_embeds = extra_embedder(image_refs)
70
+ prompt_embeds_with_ref = torch.cat([prompt_embeds, visual_prompt_embeds], dim=1)
71
+ pipe.transformer.ip_scale = ip_scale
72
+ image = pipe(
73
+ prompt_embeds=prompt_embeds_with_ref,
74
+ pooled_prompt_embeds=pooled_prompt_embeds,
75
+ # negative_prompt_embeds=negative_prompt_embeds,
76
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
77
+ height=height,
78
+ width=width,
79
+ num_inference_steps=num_steps,
80
+ guidance_scale=guidance_scale,
81
+ ).images[0]
82
+ return image
83
+
84
+ iface = gr.Interface(
85
+ fn=generate_image,
86
+ inputs=[
87
+ gr.Image(type="pil", label="Upload Reference Subject Image"),
88
+ gr.Textbox(lines=2, placeholder="Describe the desired contents", label="Description Text"),
89
+ ],
90
+ outputs=gr.Image(type="pil", label="Generated Image"),
91
+ live=True
92
+ )
93
 
94
  if __name__ == "__main__":
95
+ iface.launch()
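
Note on the adapter loading above: the checkpoint keys are remapped by stripping their leading module prefix, so a single model.safetensors file can populate both the transformer and the image encoder via load_state_dict(strict=False). A minimal, self-contained sketch of that remapping (the key names below are hypothetical, chosen only to illustrate the prefix stripping):

import torch

state_dict = {
    "transformer.transformer_blocks.0.ip_k_proj.weight": torch.zeros(3072, 4096),
    "extra_embedder.project.weight": torch.zeros(4096, 1024),
}
remapped = {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}
print(list(remapped))
# ['transformer_blocks.0.ip_k_proj.weight', 'project.weight']
# With strict=False, each module simply ignores the keys that do not belong to it.
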
config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "_class_name": "FluxTransformer2DModel",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "_name_or_path": "../checkpoints/flux-dev/transformer",
5
+ "attention_head_dim": 128,
6
+ "guidance_embeds": true,
7
+ "in_channels": 64,
8
+ "joint_attention_dim": 4096,
9
+ "num_attention_heads": 24,
10
+ "num_layers": 19,
11
+ "num_single_layers": 38,
12
+ "patch_size": 1,
13
+ "pooled_projection_dim": 768
14
+ }
15
+
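
Note: this config mirrors the FLUX.1-dev transformer configuration, which is what lets from_config build a skeleton whose shapes match the pretrained weights copied over in app.py. A quick sanity check (a sketch, assuming the file is saved as config.json next to the script):

import json

with open("config.json") as f:
    cfg = json.load(f)

print(cfg["num_attention_heads"] * cfg["attention_head_dim"])  # 3072, the inner width of every Flux block
print(cfg["joint_attention_dim"])                              # 4096, matching ip_dim=4096 in the custom IP blocks
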
projection.py ADDED
@@ -0,0 +1,46 @@
1
+ import timm
2
+ import torch
3
+ from torch import nn
4
+ from loguru import logger
5
+ from torch.utils.checkpoint import checkpoint
6
+ # from sbp.nn.model_paths import MODEL_PATHS
7
+
8
+
9
+
10
+ class ImageEncoder(nn.Module):
11
+
12
+ def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3, device='cpu'):
13
+ super().__init__()
14
+ self.output_dim = output_dim
15
+ if base_model == 'eva02_base_patch14_224.mim_in22k':
16
+ self.img_seq = 257
17
+ elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
18
+ self.img_seq = 1025
19
+ else:
20
+ raise ValueError(f" unknown {base_model}, supported: {list(paths.keys())}")
21
+ self.base_model = timm.create_model(base_model, pretrained=False)
22
+ del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
23
+ del self.base_model.blocks[layer_num:]
24
+ self.project = nn.Linear(self.base_model.num_features, output_dim)
25
+ self.final_norm = nn.LayerNorm(output_dim)
26
+ self.seq_len = seq_len
27
+ self.device = device
28
+
29
+ def forward(self, image_list):
30
+ splits = [len(lst) for lst in image_list]
31
+ if sum(splits) == 0:
32
+ return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
33
+ x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
34
+ x = self.base_model.patch_embed(x)
35
+ x, rot_pos_embed = self.base_model._pos_embed(x)
36
+ for blk in self.base_model.blocks:
37
+ x = blk(x, rope=rot_pos_embed)
38
+ x = self.project(x)
39
+ x = self.final_norm(x)
40
+ b, seq_len, c = x.shape
41
+ split_patches = torch.split(x, splits, dim=0)
42
+ split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
43
+ x = torch.stack(split_patches, dim=0)
44
+ x = x.reshape((len(splits), self.seq_len * seq_len, c))
45
+ return x
46
+
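
Note: ImageEncoder returns a fixed-length sequence of (seq_len * img_seq) tokens per sample, zero-padding when a sample has fewer reference images than seq_len. A minimal CPU shape check (a sketch, assuming projection.py is importable and using the smaller EVA-02 backbone with 257 patch tokens so no pretrained weights are needed):

import torch
from projection import ImageEncoder

enc = ImageEncoder(output_dim=4096, base_model='eva02_base_patch14_224.mim_in22k',
                   layer_num=6, seq_len=2, device='cpu').to(dtype=torch.bfloat16)
# first sample has two reference images, second has one (padded up to seq_len=2)
refs = [torch.randn(2, 3, 224, 224), torch.randn(1, 3, 224, 224)]
with torch.no_grad():
    out = enc(refs)
print(out.shape)  # torch.Size([2, 514, 4096]) == (batch, seq_len * 257, output_dim)
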
transformer_flux_custom.py ADDED
@@ -0,0 +1,890 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from einops import rearrange
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import (
28
+ Attention,
29
+ AttentionProcessor,
30
+ FluxAttnProcessor2_0,
31
+ FluxAttnProcessor2_0_NPU,
32
+ FusedFluxAttnProcessor2_0,
33
+ )
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
36
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
37
+ from diffusers.utils.import_utils import is_torch_npu_available
38
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
39
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
40
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
41
+
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ class FluxIPAttnProcessor2_0:
47
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
48
+
49
+ def __init__(self):
50
+ if not hasattr(F, "scaled_dot_product_attention"):
51
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
52
+
53
+ def __call__(
54
+ self,
55
+ attn: Attention,
56
+ hidden_states: torch.FloatTensor,
57
+ encoder_hidden_states: torch.FloatTensor = None,
58
+ attention_mask: Optional[torch.FloatTensor] = None,
59
+ image_rotary_emb: Optional[torch.Tensor] = None,
60
+ ) -> torch.FloatTensor:
61
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
62
+
63
+ # `sample` projections.
64
+ query = attn.to_q(hidden_states)
65
+ key = attn.to_k(hidden_states)
66
+ value = attn.to_v(hidden_states)
67
+
68
+ inner_dim = key.shape[-1]
69
+ head_dim = inner_dim // attn.heads
70
+
71
+ query = img_q = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
72
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
73
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
74
+
75
+ if attn.norm_q is not None:
76
+ query = attn.norm_q(query)
77
+ if attn.norm_k is not None:
78
+ key = attn.norm_k(key)
79
+
80
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
81
+ if encoder_hidden_states is not None:
82
+ # `context` projections.
83
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
84
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
85
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
86
+
87
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
88
+ batch_size, -1, attn.heads, head_dim
89
+ ).transpose(1, 2)
90
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
91
+ batch_size, -1, attn.heads, head_dim
92
+ ).transpose(1, 2)
93
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
94
+ batch_size, -1, attn.heads, head_dim
95
+ ).transpose(1, 2)
96
+
97
+ if attn.norm_added_q is not None:
98
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
99
+ if attn.norm_added_k is not None:
100
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
101
+
102
+ # attention
103
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
104
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
105
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
106
+
107
+ if image_rotary_emb is not None:
108
+ from diffusers.models.embeddings import apply_rotary_emb
109
+ query = apply_rotary_emb(query, image_rotary_emb)
110
+ key = apply_rotary_emb(key, image_rotary_emb)
111
+
112
+ hidden_states = F.scaled_dot_product_attention(
113
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
114
+ )
115
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
116
+ hidden_states = hidden_states.to(query.dtype)
117
+
118
+ if encoder_hidden_states is not None:
119
+ encoder_hidden_states, hidden_states = (
120
+ hidden_states[:, : encoder_hidden_states.shape[1]],
121
+ hidden_states[:, encoder_hidden_states.shape[1] :],
122
+ )
123
+
124
+ # linear proj
125
+ hidden_states = attn.to_out[0](hidden_states)
126
+ # dropout
127
+ hidden_states = attn.to_out[1](hidden_states)
128
+
129
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
130
+
131
+ return hidden_states, encoder_hidden_states, img_q
132
+ else:
133
+ return hidden_states, img_q
134
+
135
+
136
+ @maybe_allow_in_graph
137
+ class FluxSingleTransformerBlock(nn.Module):
138
+ r"""
139
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
140
+
141
+ Reference: https://arxiv.org/abs/2403.03206
142
+
143
+ Parameters:
144
+ dim (`int`): The number of channels in the input and output.
145
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
146
+ attention_head_dim (`int`): The number of channels in each head.
147
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
148
+ processing of `context` conditions.
149
+ """
150
+
151
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
152
+ super().__init__()
153
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
154
+
155
+ self.norm = AdaLayerNormZeroSingle(dim)
156
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
157
+ self.act_mlp = nn.GELU(approximate="tanh")
158
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
159
+
160
+ if is_torch_npu_available():
161
+ processor = FluxAttnProcessor2_0_NPU()
162
+ else:
163
+ processor = FluxAttnProcessor2_0()
164
+ self.attn = Attention(
165
+ query_dim=dim,
166
+ cross_attention_dim=None,
167
+ dim_head=attention_head_dim,
168
+ heads=num_attention_heads,
169
+ out_dim=dim,
170
+ bias=True,
171
+ processor=processor,
172
+ qk_norm="rms_norm",
173
+ eps=1e-6,
174
+ pre_only=True,
175
+ )
176
+
177
+ def forward(
178
+ self,
179
+ hidden_states: torch.FloatTensor,
180
+ temb: torch.FloatTensor,
181
+ image_rotary_emb=None,
182
+ joint_attention_kwargs=None,
183
+ ):
184
+ residual = hidden_states
185
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
186
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
187
+ joint_attention_kwargs = joint_attention_kwargs or {}
188
+ attn_output = self.attn(
189
+ hidden_states=norm_hidden_states,
190
+ image_rotary_emb=image_rotary_emb,
191
+ **joint_attention_kwargs,
192
+ )
193
+
194
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
195
+ gate = gate.unsqueeze(1)
196
+ hidden_states = gate * self.proj_out(hidden_states)
197
+ hidden_states = residual + hidden_states
198
+ if hidden_states.dtype == torch.float16:
199
+ hidden_states = hidden_states.clip(-65504, 65504)
200
+
201
+ return hidden_states
202
+
203
+
204
+ @maybe_allow_in_graph
205
+ class FluxTransformerBlock(nn.Module):
206
+ r"""
207
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
208
+
209
+ Reference: https://arxiv.org/abs/2403.03206
210
+
211
+ Parameters:
212
+ dim (`int`): The number of channels in the input and output.
213
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
214
+ attention_head_dim (`int`): The number of channels in each head.
215
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
216
+ processing of `context` conditions.
217
+ """
218
+
219
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
220
+ super().__init__()
221
+
222
+ self.norm1 = AdaLayerNormZero(dim)
223
+
224
+ self.norm1_context = AdaLayerNormZero(dim)
225
+
226
+ if hasattr(F, "scaled_dot_product_attention"):
227
+ processor = FluxAttnProcessor2_0()
228
+ else:
229
+ raise ValueError(
230
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
231
+ )
232
+ self.attn = Attention(
233
+ query_dim=dim,
234
+ cross_attention_dim=None,
235
+ added_kv_proj_dim=dim,
236
+ dim_head=attention_head_dim,
237
+ heads=num_attention_heads,
238
+ out_dim=dim,
239
+ context_pre_only=False,
240
+ bias=True,
241
+ processor=processor,
242
+ qk_norm=qk_norm,
243
+ eps=eps,
244
+ )
245
+
246
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
247
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
248
+
249
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
250
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
251
+
252
+ # let chunk size default to None
253
+ self._chunk_size = None
254
+ self._chunk_dim = 0
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.FloatTensor,
259
+ encoder_hidden_states: torch.FloatTensor,
260
+ temb: torch.FloatTensor,
261
+ image_rotary_emb=None,
262
+ joint_attention_kwargs=None,
263
+ ):
264
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
265
+
266
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
267
+ encoder_hidden_states, emb=temb
268
+ )
269
+ joint_attention_kwargs = joint_attention_kwargs or {}
270
+ # Attention.
271
+ attn_output, context_attn_output = self.attn(
272
+ hidden_states=norm_hidden_states,
273
+ encoder_hidden_states=norm_encoder_hidden_states,
274
+ image_rotary_emb=image_rotary_emb,
275
+ **joint_attention_kwargs,
276
+ )
277
+
278
+ # Process attention outputs for the `hidden_states`.
279
+ attn_output = gate_msa.unsqueeze(1) * attn_output
280
+ hidden_states = hidden_states + attn_output
281
+
282
+ norm_hidden_states = self.norm2(hidden_states)
283
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
284
+
285
+ ff_output = self.ff(norm_hidden_states)
286
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
287
+
288
+ hidden_states = hidden_states + ff_output
289
+
290
+ # Process attention outputs for the `encoder_hidden_states`.
291
+
292
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
293
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
294
+
295
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
296
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
297
+
298
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
299
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
300
+ if encoder_hidden_states.dtype == torch.float16:
301
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
302
+
303
+ return encoder_hidden_states, hidden_states
304
+
305
+
306
+ @maybe_allow_in_graph
307
+ class FluxTransformerIPBlock(nn.Module):
308
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6, ip_dim=3072):
309
+ super().__init__()
310
+
311
+ self.norm1 = AdaLayerNormZero(dim)
312
+
313
+ self.norm1_context = AdaLayerNormZero(dim)
314
+
315
+ if hasattr(F, "scaled_dot_product_attention"):
316
+ processor = FluxIPAttnProcessor2_0()
317
+ else:
318
+ raise ValueError(
319
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
320
+ )
321
+ self.attn = Attention(
322
+ query_dim=dim,
323
+ cross_attention_dim=None,
324
+ added_kv_proj_dim=dim,
325
+ dim_head=attention_head_dim,
326
+ heads=num_attention_heads,
327
+ out_dim=dim,
328
+ context_pre_only=False,
329
+ bias=True,
330
+ processor=processor,
331
+ qk_norm=qk_norm,
332
+ eps=eps,
333
+ )
334
+ self.ip_k_proj = nn.Linear(ip_dim, num_attention_heads * attention_head_dim, bias=True)
335
+ self.ip_v_proj = nn.Linear(ip_dim, num_attention_heads * attention_head_dim, bias=True)
336
+ self.ip_dim = ip_dim
337
+ self.num_heads = num_attention_heads
338
+ self.head_dim = attention_head_dim
339
+ nn.init.zeros_(self.ip_v_proj.weight)
340
+ nn.init.zeros_(self.ip_v_proj.bias)
341
+
342
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
343
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
344
+
345
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
346
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
347
+
348
+ # let chunk size default to None
349
+ self._chunk_size = None
350
+ self._chunk_dim = 0
351
+
352
+ def forward(
353
+ self,
354
+ hidden_states: torch.FloatTensor,
355
+ encoder_hidden_states: torch.FloatTensor,
356
+ temb: torch.FloatTensor,
357
+ image_rotary_emb=None,
358
+ joint_attention_kwargs=None,
359
+ image_proj=None,
360
+ ip_scale=1.0
361
+ ):
362
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
363
+
364
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
365
+ encoder_hidden_states, emb=temb
366
+ )
367
+ joint_attention_kwargs = joint_attention_kwargs or {}
368
+ # Attention.
369
+ attn_output, context_attn_output, img_q = self.attn(
370
+ hidden_states=norm_hidden_states,
371
+ encoder_hidden_states=norm_encoder_hidden_states,
372
+ image_rotary_emb=image_rotary_emb,
373
+ **joint_attention_kwargs,
374
+ )
375
+
376
+ # Process attention outputs for the `hidden_states`.
377
+ attn_output = gate_msa.unsqueeze(1) * attn_output
378
+ hidden_states = hidden_states + attn_output
379
+
380
+ norm_hidden_states = self.norm2(hidden_states)
381
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
382
+
383
+ ff_output = self.ff(norm_hidden_states)
384
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
385
+
386
+ hidden_states = hidden_states + ff_output
387
+
388
+ # Process attention outputs for the `encoder_hidden_states`.
389
+
390
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
391
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
392
+
393
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
394
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
395
+
396
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
397
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
398
+ if encoder_hidden_states.dtype == torch.float16:
399
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
400
+
401
+ ip_q = img_q
402
+ # image_proj = encoder_hidden_states[:, -512:, :]
403
+ # print("image_proj:", image_proj.shape, "encoder_hidden_states:", encoder_hidden_states.shape)
404
+ ip_k = self.ip_k_proj(image_proj)
405
+ ip_v = self.ip_v_proj(image_proj)
406
+ ip_k = rearrange(ip_k, 'B L (H D) -> B H L D', H=self.num_heads, D=self.head_dim)
407
+ ip_v = rearrange(ip_v, 'B L (H D) -> B H L D', H=self.num_heads, D=self.head_dim)
408
+ # print("qkv shape:", ip_q.shape, ip_k.shape, ip_v.shape)
409
+ ip_attention = F.scaled_dot_product_attention(ip_q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
410
+ ip_attention = rearrange(ip_attention, 'B H L D -> B L (H D)', H=self.num_heads, D=self.head_dim)
411
+ hidden_states = hidden_states + ip_scale * ip_attention
412
+ return encoder_hidden_states, hidden_states
413
+
414
+
415
+
416
+ @maybe_allow_in_graph
417
+ class FluxSingleTransformerIPBlock(nn.Module):
418
+ r"""
419
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
420
+
421
+ Reference: https://arxiv.org/abs/2403.03206
422
+
423
+ Parameters:
424
+ dim (`int`): The number of channels in the input and output.
425
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
426
+ attention_head_dim (`int`): The number of channels in each head.
427
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
428
+ processing of `context` conditions.
429
+ """
430
+
431
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0, ip_dim=4096):
432
+ super().__init__()
433
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
434
+
435
+ self.norm = AdaLayerNormZeroSingle(dim)
436
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
437
+ self.act_mlp = nn.GELU(approximate="tanh")
438
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
439
+
440
+ if is_torch_npu_available():
441
+ processor = FluxAttnProcessor2_0_NPU()
442
+ else:
443
+ processor = FluxIPAttnProcessor2_0()
444
+ self.attn = Attention(
445
+ query_dim=dim,
446
+ cross_attention_dim=None,
447
+ dim_head=attention_head_dim,
448
+ heads=num_attention_heads,
449
+ out_dim=dim,
450
+ bias=True,
451
+ processor=processor,
452
+ qk_norm="rms_norm",
453
+ eps=1e-6,
454
+ pre_only=True,
455
+ )
456
+ self.ip_k_proj = nn.Linear(ip_dim, num_attention_heads * attention_head_dim, bias=True)
457
+ self.ip_v_proj = nn.Linear(ip_dim, num_attention_heads * attention_head_dim, bias=True)
458
+ nn.init.zeros_(self.ip_v_proj.weight)
459
+ nn.init.zeros_(self.ip_v_proj.bias)
460
+ self.ip_dim = ip_dim
461
+ self.num_heads = num_attention_heads
462
+ self.head_dim = attention_head_dim
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.FloatTensor,
467
+ temb: torch.FloatTensor,
468
+ image_rotary_emb=None,
469
+ joint_attention_kwargs=None,
470
+ image_proj=None,
471
+ ip_scale=1.0
472
+ ):
473
+ residual = hidden_states
474
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
475
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
476
+ joint_attention_kwargs = joint_attention_kwargs or {}
477
+ attn_output, img_q = self.attn(
478
+ hidden_states=norm_hidden_states,
479
+ image_rotary_emb=image_rotary_emb,
480
+ **joint_attention_kwargs,
481
+ )
482
+
483
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
484
+ gate = gate.unsqueeze(1)
485
+ hidden_states = gate * self.proj_out(hidden_states)
486
+ hidden_states = residual + hidden_states
487
+ if hidden_states.dtype == torch.float16:
488
+ hidden_states = hidden_states.clip(-65504, 65504)
489
+
490
+ ip_q = img_q
491
+ # image_proj = encoder_hidden_states[:, -512:, :]
492
+ ip_k = self.ip_k_proj(image_proj)
493
+ ip_v = self.ip_v_proj(image_proj)
494
+ ip_k = rearrange(ip_k, 'B L (H D) -> B H L D', H=self.num_heads, D=self.head_dim)
495
+ ip_v = rearrange(ip_v, 'B L (H D) -> B H L D', H=self.num_heads, D=self.head_dim)
496
+ # print("qkv shape:", ip_q.shape, ip_k.shape, ip_v.shape)
497
+ ip_attention = F.scaled_dot_product_attention(ip_q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
498
+ ip_attention = rearrange(ip_attention, 'B H L D -> B L (H D)', H=self.num_heads, D=self.head_dim)
499
+ hidden_states = hidden_states + ip_scale * ip_attention
500
+
501
+ return hidden_states
502
+
503
+
504
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
505
+ """
506
+ The Transformer model introduced in Flux.
507
+
508
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
509
+
510
+ Parameters:
511
+ patch_size (`int`): Patch size to turn the input data into small patches.
512
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
513
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
514
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
515
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
516
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
517
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
518
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
519
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
520
+ """
521
+
522
+ _supports_gradient_checkpointing = True
523
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
524
+
525
+ @register_to_config
526
+ def __init__(
527
+ self,
528
+ patch_size: int = 1,
529
+ in_channels: int = 64,
530
+ out_channels: Optional[int] = None,
531
+ num_layers: int = 19,
532
+ num_single_layers: int = 38,
533
+ attention_head_dim: int = 128,
534
+ num_attention_heads: int = 24,
535
+ joint_attention_dim: int = 4096,
536
+ pooled_projection_dim: int = 768,
537
+ guidance_embeds: bool = False,
538
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
539
+ ):
540
+ super().__init__()
541
+ self.out_channels = out_channels or in_channels
542
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
543
+ self.ip_scale = 1.0
544
+
545
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
546
+
547
+ text_time_guidance_cls = (
548
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
549
+ )
550
+ self.time_text_embed = text_time_guidance_cls(
551
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
552
+ )
553
+
554
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
555
+ self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
556
+
557
+ self.transformer_blocks = nn.ModuleList(
558
+ [
559
+ FluxTransformerIPBlock(
560
+ dim=self.inner_dim,
561
+ num_attention_heads=self.config.num_attention_heads,
562
+ attention_head_dim=self.config.attention_head_dim,
563
+ ip_dim=4096
564
+ )
565
+ for i in range(self.config.num_layers)
566
+ ]
567
+ )
568
+
569
+ self.single_transformer_blocks = nn.ModuleList(
570
+ [
571
+ FluxSingleTransformerIPBlock(
572
+ dim=self.inner_dim,
573
+ num_attention_heads=self.config.num_attention_heads,
574
+ attention_head_dim=self.config.attention_head_dim,
575
+ ip_dim=4096
576
+ )
577
+ for i in range(self.config.num_single_layers)
578
+ ]
579
+ )
580
+
581
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
582
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
583
+
584
+ self.gradient_checkpointing = False
585
+
586
+ @property
587
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
588
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
589
+ r"""
590
+ Returns:
591
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
592
+ indexed by its weight name.
593
+ """
594
+ # set recursively
595
+ processors = {}
596
+
597
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
598
+ if hasattr(module, "get_processor"):
599
+ processors[f"{name}.processor"] = module.get_processor()
600
+
601
+ for sub_name, child in module.named_children():
602
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
603
+
604
+ return processors
605
+
606
+ for name, module in self.named_children():
607
+ fn_recursive_add_processors(name, module, processors)
608
+
609
+ return processors
610
+
611
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
612
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
613
+ r"""
614
+ Sets the attention processor to use to compute attention.
615
+
616
+ Parameters:
617
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
618
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
619
+ for **all** `Attention` layers.
620
+
621
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
622
+ processor. This is strongly recommended when setting trainable attention processors.
623
+
624
+ """
625
+ count = len(self.attn_processors.keys())
626
+
627
+ if isinstance(processor, dict) and len(processor) != count:
628
+ raise ValueError(
629
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
630
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
631
+ )
632
+
633
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
634
+ if hasattr(module, "set_processor"):
635
+ if not isinstance(processor, dict):
636
+ module.set_processor(processor)
637
+ else:
638
+ module.set_processor(processor.pop(f"{name}.processor"))
639
+
640
+ for sub_name, child in module.named_children():
641
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
642
+
643
+ for name, module in self.named_children():
644
+ fn_recursive_attn_processor(name, module, processor)
645
+
646
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
647
+ def fuse_qkv_projections(self):
648
+ """
649
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
650
+ are fused. For cross-attention modules, key and value projection matrices are fused.
651
+
652
+ <Tip warning={true}>
653
+
654
+ This API is 🧪 experimental.
655
+
656
+ </Tip>
657
+ """
658
+ self.original_attn_processors = None
659
+
660
+ for _, attn_processor in self.attn_processors.items():
661
+ if "Added" in str(attn_processor.__class__.__name__):
662
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
663
+
664
+ self.original_attn_processors = self.attn_processors
665
+
666
+ for module in self.modules():
667
+ if isinstance(module, Attention):
668
+ module.fuse_projections(fuse=True)
669
+
670
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
671
+
672
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
673
+ def unfuse_qkv_projections(self):
674
+ """Disables the fused QKV projection if enabled.
675
+
676
+ <Tip warning={true}>
677
+
678
+ This API is 🧪 experimental.
679
+
680
+ </Tip>
681
+
682
+ """
683
+ if self.original_attn_processors is not None:
684
+ self.set_attn_processor(self.original_attn_processors)
685
+
686
+ def _set_gradient_checkpointing(self, module, value=False):
687
+ if hasattr(module, "gradient_checkpointing"):
688
+ module.gradient_checkpointing = value
689
+
690
+ def forward(
691
+ self,
692
+ hidden_states: torch.Tensor,
693
+ encoder_hidden_states: torch.Tensor = None,
694
+ pooled_projections: torch.Tensor = None,
695
+ timestep: torch.LongTensor = None,
696
+ img_ids: torch.Tensor = None,
697
+ txt_ids: torch.Tensor = None,
698
+ guidance: torch.Tensor = None,
699
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
700
+ controlnet_block_samples=None,
701
+ controlnet_single_block_samples=None,
702
+ return_dict: bool = True,
703
+ controlnet_blocks_repeat: bool = False,
704
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
705
+ """
706
+ The [`FluxTransformer2DModel`] forward method.
707
+
708
+ Args:
709
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
710
+ Input `hidden_states`.
711
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
712
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
713
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
714
+ from the embeddings of input conditions.
715
+ timestep ( `torch.LongTensor`):
716
+ Used to indicate denoising step.
717
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
718
+ A list of tensors that if specified are added to the residuals of transformer blocks.
719
+ joint_attention_kwargs (`dict`, *optional*):
720
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
721
+ `self.processor` in
722
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
723
+ return_dict (`bool`, *optional*, defaults to `True`):
724
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
725
+ tuple.
726
+
727
+ Returns:
728
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
729
+ `tuple` where the first element is the sample tensor.
730
+ """
731
+ if joint_attention_kwargs is not None:
732
+ joint_attention_kwargs = joint_attention_kwargs.copy()
733
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
734
+ else:
735
+ lora_scale = 1.0
736
+
737
+ if USE_PEFT_BACKEND:
738
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
739
+ scale_lora_layers(self, lora_scale)
740
+ else:
741
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
742
+ logger.warning(
743
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
744
+ )
745
+
746
+ hidden_states = self.x_embedder(hidden_states)
747
+
748
+ timestep = timestep.to(hidden_states.dtype) * 1000
749
+ if guidance is not None:
750
+ guidance = guidance.to(hidden_states.dtype) * 1000
751
+ else:
752
+ guidance = None
753
+
754
+ temb = (
755
+ self.time_text_embed(timestep, pooled_projections)
756
+ if guidance is None
757
+ else self.time_text_embed(timestep, guidance, pooled_projections)
758
+ )
759
+ _, _s, _ = encoder_hidden_states.shape
760
+ if _s > 2048:
761
+ _im_len = -2050
762
+ elif _s > 512:
763
+ _im_len = -514
764
+ else:
765
+ _im_len = -1
766
+ image_proj = encoder_hidden_states[:, _im_len:, :]
767
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states[:, :_im_len, :])
768
+ txt_ids = txt_ids[:_im_len, :]
769
+ if txt_ids.ndim == 3:
770
+ logger.warning(
771
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
772
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
773
+ )
774
+ txt_ids = txt_ids[0]
775
+ if img_ids.ndim == 3:
776
+ logger.warning(
777
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
778
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
779
+ )
780
+ img_ids = img_ids[0]
781
+
782
+ ids = torch.cat((txt_ids, img_ids), dim=0)
783
+ image_rotary_emb = self.pos_embed(ids)
784
+
785
+ for index_block, block in enumerate(self.transformer_blocks):
786
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
787
+
788
+ def create_custom_forward(module, return_dict=None):
789
+ def custom_forward(*inputs):
790
+ if return_dict is not None:
791
+ return module(*inputs, return_dict=return_dict)
792
+ else:
793
+ return module(*inputs)
794
+
795
+ return custom_forward
796
+
797
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
798
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
799
+ create_custom_forward(block),
800
+ hidden_states,
801
+ encoder_hidden_states,
802
+ temb,
803
+ image_rotary_emb,
804
+ joint_attention_kwargs,
805
+ image_proj,
806
+ self.ip_scale,
807
+ **ckpt_kwargs
808
+ )
809
+
810
+ else:
811
+ encoder_hidden_states, hidden_states = block(
812
+ hidden_states=hidden_states,
813
+ encoder_hidden_states=encoder_hidden_states,
814
+ temb=temb,
815
+ image_rotary_emb=image_rotary_emb,
816
+ joint_attention_kwargs=joint_attention_kwargs,
817
+ image_proj=image_proj,
818
+ ip_scale=self.ip_scale
819
+ )
820
+
821
+ # controlnet residual
822
+ if controlnet_block_samples is not None:
823
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
824
+ interval_control = int(np.ceil(interval_control))
825
+ # For Xlabs ControlNet.
826
+ if controlnet_blocks_repeat:
827
+ hidden_states = (
828
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
829
+ )
830
+ else:
831
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
832
+
833
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
834
+
835
+ for index_block, block in enumerate(self.single_transformer_blocks):
836
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
837
+
838
+ def create_custom_forward(module, return_dict=None):
839
+ def custom_forward(*inputs):
840
+ if return_dict is not None:
841
+ return module(*inputs, return_dict=return_dict)
842
+ else:
843
+ return module(*inputs)
844
+
845
+ return custom_forward
846
+
847
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
848
+ hidden_states = torch.utils.checkpoint.checkpoint(
849
+ create_custom_forward(block),
850
+ hidden_states,
851
+ temb,
852
+ image_rotary_emb,
853
+ joint_attention_kwargs,
854
+ image_proj,
855
+ self.ip_scale,
856
+ **ckpt_kwargs,
857
+ )
858
+
859
+ else:
860
+ hidden_states = block(
861
+ hidden_states=hidden_states,
862
+ temb=temb,
863
+ image_rotary_emb=image_rotary_emb,
864
+ joint_attention_kwargs=joint_attention_kwargs,
865
+ image_proj=image_proj,
866
+ ip_scale=self.ip_scale
867
+ )
868
+
869
+ # controlnet residual
870
+ if controlnet_single_block_samples is not None:
871
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
872
+ interval_control = int(np.ceil(interval_control))
873
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
874
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
875
+ + controlnet_single_block_samples[index_block // interval_control]
876
+ )
877
+
878
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
879
+
880
+ hidden_states = self.norm_out(hidden_states, temb)
881
+ output = self.proj_out(hidden_states)
882
+
883
+ if USE_PEFT_BACKEND:
884
+ # remove `lora_scale` from each PEFT layer
885
+ unscale_lora_layers(self, lora_scale)
886
+
887
+ if not return_dict:
888
+ return (output,)
889
+
890
+ return Transformer2DModelOutput(sample=output)
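
Note on the forward pass above: the custom transformer assumes app.py has already concatenated the EVA-02 reference tokens after the T5 prompt tokens, and it recovers the two parts purely from the sequence length via _im_len. A small sketch of that convention with dummy tensors (shapes only, no model weights involved):

import torch

prompt_embeds = torch.randn(1, 512, 4096)               # T5 tokens (FLUX max_sequence_length=512)
visual_prompt_embeds = torch.randn(1, 2 * 1025, 4096)   # 2 reference slots x 1025 EVA-02 large tokens
joint = torch.cat([prompt_embeds, visual_prompt_embeds], dim=1)  # what app.py passes as prompt_embeds

_, s, _ = joint.shape
_im_len = -2050 if s > 2048 else (-514 if s > 512 else -1)
image_proj = joint[:, _im_len:, :]   # (1, 2050, 4096): keys/values for the extra IP attention
text_part = joint[:, :_im_len, :]    # (1, 512, 4096): goes through context_embedder as usual
print(image_proj.shape, text_part.shape)
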