wangqixun committed on
Commit d4c1bb7 · verified · 1 Parent(s): c86fe71

Upload 5 files

Files changed (5)
  1. models/attn_processor.py +164 -0
  2. models/norm_layer.py +46 -0
  3. models/resampler.py +365 -0
  4. models/utils.py +139 -0
  5. pipeline.py +550 -0
models/attn_processor.py ADDED
@@ -0,0 +1,164 @@
from typing import Optional

import torch.nn as nn
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import apply_rotary_emb
from einops import rearrange

from .norm_layer import RMSNorm


class FluxIPAttnProcessor(nn.Module):
    """Flux attention processor with an additional IP-Adapter branch for subject-image conditioning."""

    def __init__(
        self,
        hidden_size=None,
        ip_hidden_states_dim=None,
    ):
        super().__init__()
        # RMSNorm is applied per attention head; 128 is the Flux per-head dimension.
        self.norm_ip_q = RMSNorm(128, eps=1e-6)
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size)
        self.norm_ip_k = RMSNorm(128, eps=1e-6)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        emb_dict={},
        subject_emb_dict={},
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # IP-Adapter branch: attend from the image-token queries to the subject tokens.
        ip_hidden_states = self._get_ip_hidden_states(
            attn,
            query if encoder_hidden_states is not None else query[:, emb_dict['length_encoder_hidden_states']:],
            subject_emb_dict.get('ip_hidden_states', None)
        )

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            if ip_hidden_states is not None:
                hidden_states = hidden_states + ip_hidden_states * subject_emb_dict.get('scale', 1.0)

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:

            if ip_hidden_states is not None:
                hidden_states[:, emb_dict['length_encoder_hidden_states']:] = \
                    hidden_states[:, emb_dict['length_encoder_hidden_states']:] + \
                    ip_hidden_states * subject_emb_dict.get('scale', 1.0)

            return hidden_states

    def _scaled_dot_product_attention(self, query, key, value, attention_mask=None, heads=None):
        query = rearrange(query, '(b h) l c -> b h l c', h=heads)
        key = rearrange(key, '(b h) l c -> b h l c', h=heads)
        value = rearrange(value, '(b h) l c -> b h l c', h=heads)
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
        hidden_states = rearrange(hidden_states, 'b h l c -> (b h) l c', h=heads)
        hidden_states = hidden_states.to(query)
        return hidden_states

    def _get_ip_hidden_states(
        self,
        attn,
        img_query,
        ip_hidden_states,
    ):
        """Decoupled cross-attention: image-token queries attend to the projected subject (IP) tokens."""
        if ip_hidden_states is None:
            return None

        if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
            return None

        ip_query = self.norm_ip_q(rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads))
        ip_query = rearrange(ip_query, 'b h l d -> (b h) l d')
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_key = self.norm_ip_k(rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads))
        ip_key = rearrange(ip_key, 'b h l d -> (b h) l d')
        ip_value = self.to_v_ip(ip_hidden_states)
        ip_value = attn.head_to_batch_dim(ip_value)
        ip_hidden_states = self._scaled_dot_product_attention(
            ip_query.to(ip_value.dtype), ip_key.to(ip_value.dtype), ip_value, None, attn.heads)
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
        return ip_hidden_states
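The processor above is not installed by itself; `pipeline.py` (below) attaches one instance per attention layer of the Flux transformer and feeds it the subject tokens through `joint_attention_kwargs`. As a minimal sketch of that wiring, assuming a placeholder base checkpoint and the T5 feature width of 4096 used by this repo:

```py
# Sketch only: how FluxIPAttnProcessor is registered, mirroring
# InstantCharacterFluxPipeline.init_ccp_and_attn_processor below.
import torch
from diffusers import FluxTransformer2DModel
from models.attn_processor import FluxIPAttnProcessor

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder base model id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
attn_procs = {
    name: FluxIPAttnProcessor(hidden_size=hidden_size, ip_hidden_states_dim=4096).to(
        transformer.device, dtype=transformer.dtype
    )
    for name in transformer.attn_processors
}
transformer.set_attn_processor(attn_procs)

# At denoising time the pipeline passes the subject tokens per step via
# joint_attention_kwargs, e.g.:
#   {"emb_dict": {"length_encoder_hidden_states": prompt_embeds.shape[1]},
#    "subject_emb_dict": {"ip_hidden_states": subject_tokens, "scale": 0.8}}
```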
models/norm_layer.py ADDED
@@ -0,0 +1,46 @@
import torch.nn as nn
import torch

class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether to use a bias term for RMSNorm; disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)

    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed
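RMSNorm rescales by the root-mean-square of the last dimension only (no mean subtraction), which is how the attention processor above normalizes per-head queries and keys. A quick shape check, purely illustrative:

```py
# Illustrative sanity check of RMSNorm (not part of the committed files).
import torch
from models.norm_layer import RMSNorm

norm = RMSNorm(128, eps=1e-6)        # d = per-head feature size, as in FluxIPAttnProcessor
x = torch.randn(2, 24, 77, 128)      # (batch, heads, tokens, head_dim)
y = norm(x)                          # normalized over the last dim, then scaled by norm.scale
print(y.shape)                       # torch.Size([2, 24, 77, 128])
```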
models/resampler.py ADDED
@@ -0,0 +1,365 @@
import torch.nn as nn
import torch
import math

from diffusers.models.transformers.transformer_2d import BasicTransformerBlock
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from timm.models.vision_transformer import Mlp

from .norm_layer import RMSNorm


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # stays (bs, n_heads, length, dim_per_head); the reshape only makes the layout contiguous
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, shift=None, scale=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        if shift is not None and scale is not None:
            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class ReshapeExpandToken(nn.Module):
    def __init__(self, expand_token, token_dim):
        super().__init__()
        self.expand_token = expand_token
        self.token_dim = token_dim

    def forward(self, x):
        x = x.reshape(-1, self.expand_token, self.token_dim)
        return x


class TimeResampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
        expand_token=None,
        extra_dim=None,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.expand_token = expand_token is not None
        if expand_token:
            self.expand_proj = torch.nn.Sequential(
                torch.nn.Linear(embedding_dim, embedding_dim * 2),
                torch.nn.GELU(),
                torch.nn.Linear(embedding_dim * 2, embedding_dim * expand_token),
                ReshapeExpandToken(expand_token, embedding_dim),
                RMSNorm(embedding_dim, eps=1e-8),
            )

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.extra_feature = extra_dim is not None
        if self.extra_feature:
            self.proj_in_norm = RMSNorm(dim, eps=1e-8)
            self.extra_proj_in = torch.nn.Sequential(
                nn.Linear(extra_dim, dim),
                RMSNorm(dim, eps=1e-8),
            )

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        # msa
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff
                        FeedForward(dim=dim, mult=ff_mult),
                        # adaLN
                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
                    ]
                )
            )

        # time
        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
        self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")

    def forward(self, x, timestep, need_temb=False, extra_feature=None):
        timestep_emb = self.embedding_time(x, timestep)  # bs, dim

        latents = self.latents.repeat(x.size(0), 1, 1)

        if self.expand_token:
            x = self.expand_proj(x)

        x = self.proj_in(x)

        if self.extra_feature:
            extra_feature = self.extra_proj_in(extra_feature)
            x = self.proj_in_norm(x)
            x = torch.cat([x, extra_feature], dim=1)

        x = x + timestep_emb[:, None]

        for attn, ff, adaLN_modulation in self.layers:
            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
            latents = attn(x, latents, shift_msa, scale_msa) + latents

            res = latents
            for idx_ff in range(len(ff)):
                layer_ff = ff[idx_ff]
                latents = layer_ff(latents)
                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                    latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
            latents = latents + res

            # latents = ff(latents) + latents

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)

        if need_temb:
            return latents, timestep_emb
        else:
            return latents

    def embedding_time(self, sample, timestep):

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, None)
        return emb


class CrossLayerCrossScaleProjector(nn.Module):
    def __init__(
        self,
        inner_dim=2688,
        num_attention_heads=42,
        attention_head_dim=64,
        cross_attention_dim=2688,
        num_layers=4,

        # resampler
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=1024,
        embedding_dim=1152 + 1536,
        output_dim=4096,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
    ):
        super().__init__()

        self.cross_layer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=0,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn="geglu",
                    num_embeds_ada_norm=None,
                    attention_bias=False,
                    only_cross_attention=False,
                    double_self_attention=False,
                    upcast_attention=False,
                    norm_type='layer_norm',
                    norm_elementwise_affine=True,
                    norm_eps=1e-6,
                    attention_type="default",
                )
                for _ in range(num_layers)
            ]
        )

        self.cross_scale_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=0,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn="geglu",
                    num_embeds_ada_norm=None,
                    attention_bias=False,
                    only_cross_attention=False,
                    double_self_attention=False,
                    upcast_attention=False,
                    norm_type='layer_norm',
                    norm_elementwise_affine=True,
                    norm_eps=1e-6,
                    attention_type="default",
                )
                for _ in range(num_layers)
            ]
        )

        self.proj = Mlp(
            in_features=inner_dim,
            hidden_features=int(inner_dim*2),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0
        )

        self.proj_cross_layer = Mlp(
            in_features=inner_dim,
            hidden_features=int(inner_dim*2),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0
        )

        self.proj_cross_scale = Mlp(
            in_features=inner_dim,
            hidden_features=int(inner_dim*2),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0
        )

        self.resampler = TimeResampler(
            dim=dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            num_queries=num_queries,
            embedding_dim=embedding_dim,
            output_dim=output_dim,
            ff_mult=ff_mult,
            timestep_in_dim=timestep_in_dim,
            timestep_flip_sin_to_cos=timestep_flip_sin_to_cos,
            timestep_freq_shift=timestep_freq_shift,
        )

    def forward(self, low_res_shallow, low_res_deep, high_res_deep, timesteps, cross_attention_kwargs=None, need_temb=True):
        '''
        low_res_shallow [bs, 729*l, c]
        low_res_deep    [bs, 729, c]
        high_res_deep   [bs, 729*4, c]
        '''

        cross_layer_hidden_states = low_res_deep
        for block in self.cross_layer_blocks:
            cross_layer_hidden_states = block(
                cross_layer_hidden_states,
                encoder_hidden_states=low_res_shallow,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        cross_layer_hidden_states = self.proj_cross_layer(cross_layer_hidden_states)

        cross_scale_hidden_states = low_res_deep
        for block in self.cross_scale_blocks:
            cross_scale_hidden_states = block(
                cross_scale_hidden_states,
                encoder_hidden_states=high_res_deep,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        cross_scale_hidden_states = self.proj_cross_scale(cross_scale_hidden_states)

        hidden_states = self.proj(low_res_deep) + cross_scale_hidden_states
        hidden_states = torch.cat([hidden_states, cross_layer_hidden_states], dim=1)

        hidden_states, timestep_emb = self.resampler(hidden_states, timesteps, need_temb=True)
        return hidden_states, timestep_emb
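The resampler turns a variable number of image-feature tokens into a fixed set of `num_queries` subject tokens of width `output_dim`, conditioned on the diffusion timestep. A shape-level sketch of `TimeResampler` with small, illustrative dimensions (the production configuration is the one hard-coded in `pipeline.py`):

```py
# Illustrative shape check of TimeResampler (toy dimensions, not the shipped config).
import torch
from models.resampler import TimeResampler

resampler = TimeResampler(
    dim=256, depth=2, dim_head=64, heads=4,
    num_queries=32, embedding_dim=128, output_dim=512,
    timestep_in_dim=320,
)
feats = torch.randn(2, 729, 128)                       # (batch, image tokens, embedding_dim)
tokens, temb = resampler(feats, timestep=torch.tensor(999), need_temb=True)
print(tokens.shape, temb.shape)                        # torch.Size([2, 32, 512]) torch.Size([2, 256])
```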
models/utils.py ADDED
@@ -0,0 +1,139 @@
from safetensors.torch import load_file
import torch
from tqdm import tqdm

__all__ = [
    'flux_load_lora'
]


def is_int(d):
    try:
        d = int(d)
        return True
    except Exception as e:
        return False


def flux_load_lora(self, lora_file, lora_weight=1.0):
    device = self.transformer.device

    # DiT (transformer) part
    state_dict, network_alphas = self.lora_state_dict(lora_file, return_alphas=True)
    state_dict = {k: v.to(device) for k, v in state_dict.items()}

    model = self.transformer
    keys = list(state_dict.keys())
    keys = [k for k in keys if k.startswith('transformer.')]

    for k_lora in tqdm(keys, total=len(keys), desc=f"loading lora in transformer ..."):
        v_lora = state_dict[k_lora]

        # skip everything that is not a lora_B (up) weight
        if '.lora_A.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("transformer.", "")
        k_lora_name = k_lora_name.replace(".lora_B.weight", "")
        attr_name_list = k_lora_name.split('.')

        cur_attr = model
        latest_attr_name = ''
        for idx in range(0, len(attr_name_list)):
            attr_name = attr_name_list[idx]
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = cur_attr.__getattr__(f"{latest_attr_name}.{attr_name}")
                    else:
                        cur_attr = cur_attr.__getattr__(attr_name)
                    latest_attr_name = ''
                except Exception as e:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}.{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_B.weight', '.lora_A.weight')]

        # merge the LoRA delta into the base weight
        einsum_a = f"ijabcdefg"
        einsum_b = f"jkabcdefg"
        einsum_res = f"ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_weight

    # text encoder part
    raw_state_dict = load_file(lora_file)
    raw_state_dict = {k: v.to(device) for k, v in raw_state_dict.items()}

    # text encoder
    state_dict = {k: v for k, v in raw_state_dict.items() if 'lora_te1_' in k}
    model = self.text_encoder
    keys = list(state_dict.keys())
    keys = [k for k in keys if k.startswith('lora_te1_')]

    for k_lora in tqdm(keys, total=len(keys), desc=f"loading lora in text_encoder ..."):
        v_lora = state_dict[k_lora]

        # skip everything that is not a lora_up weight
        if '.lora_down.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("lora_te1_", "")
        k_lora_name = k_lora_name.replace(".lora_up.weight", "")
        attr_name_list = k_lora_name.split('_')

        cur_attr = model
        latest_attr_name = ''
        for idx in range(0, len(attr_name_list)):
            attr_name = attr_name_list[idx]
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = cur_attr.__getattr__(f"{latest_attr_name}_{attr_name}")
                    else:
                        cur_attr = cur_attr.__getattr__(attr_name)
                    latest_attr_name = ''
                except Exception as e:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}_{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_up.weight', '.lora_down.weight')]

        alpha = state_dict.get(k_lora.replace('.lora_up.weight', '.alpha'), None)
        if alpha is None:
            lora_scale = 1
        else:
            rank = up_w.shape[1]
            lora_scale = alpha / rank

        # merge the LoRA delta into the base weight
        einsum_a = f"ijabcdefg"
        einsum_b = f"jkabcdefg"
        einsum_res = f"ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_scale * lora_weight
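`flux_load_lora` merges the LoRA delta (`B @ A`, scaled) directly into the transformer and text-encoder weights rather than keeping separate LoRA modules, which is why `with_style_lora` in `pipeline.py` can "unload" by re-applying the same file with a negated weight. A sketch of that pattern, assuming `pipe` is an already constructed `InstantCharacterFluxPipeline` and `style.safetensors` is a placeholder path:

```py
# Merge / unmerge pattern used by with_style_lora (restores weights up to dtype rounding).
from models.utils import flux_load_lora

flux_load_lora(pipe, "style.safetensors", lora_weight=1.0)   # W <- W + lora_weight * scale * (B @ A)
images = pipe(prompt="a sketch-style portrait").images
flux_load_lora(pipe, "style.safetensors", lora_weight=-1.0)  # subtract the same delta to restore W
```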
pipeline.py ADDED
@@ -0,0 +1,550 @@
# Copyright 2025 Tencent InstantX Team. All rights reserved.
#

from PIL import Image
from einops import rearrange
import torch
from diffusers.pipelines.flux.pipeline_flux import *
from transformers import SiglipVisionModel, SiglipImageProcessor, AutoModel, AutoImageProcessor

from models.attn_processor import FluxIPAttnProcessor
from models.resampler import CrossLayerCrossScaleProjector
from models.utils import flux_load_lora


# TODO: replace with an InstantCharacterFluxPipeline example (this is the stock FluxPipeline docstring)
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import FluxPipeline

        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
        >>> image.save("flux.png")
        ```
"""


class InstantCharacterFluxPipeline(FluxPipeline):

    @torch.inference_mode()
    def encode_siglip_image_emb(self, siglip_image, device, dtype):
        siglip_image = siglip_image.to(device, dtype=dtype)
        res = self.siglip_image_encoder(siglip_image, output_hidden_states=True)

        siglip_image_embeds = res.last_hidden_state

        siglip_image_shallow_embeds = torch.cat([res.hidden_states[i] for i in [7, 13, 26]], dim=1)

        return siglip_image_embeds, siglip_image_shallow_embeds

    @torch.inference_mode()
    def encode_dinov2_image_emb(self, dinov2_image, device, dtype):
        dinov2_image = dinov2_image.to(device, dtype=dtype)
        res = self.dino_image_encoder_2(dinov2_image, output_hidden_states=True)

        dinov2_image_embeds = res.last_hidden_state[:, 1:]

        dinov2_image_shallow_embeds = torch.cat([res.hidden_states[i][:, 1:] for i in [9, 19, 29]], dim=1)

        return dinov2_image_embeds, dinov2_image_shallow_embeds

    @torch.inference_mode()
    def encode_image_emb(self, siglip_image, device, dtype):
        object_image_pil = siglip_image
        object_image_pil_low_res = [object_image_pil.resize((384, 384))]
        object_image_pil_high_res = object_image_pil.resize((768, 768))
        object_image_pil_high_res = [
            object_image_pil_high_res.crop((0, 0, 384, 384)),
            object_image_pil_high_res.crop((384, 0, 768, 384)),
            object_image_pil_high_res.crop((0, 384, 384, 768)),
            object_image_pil_high_res.crop((384, 384, 768, 768)),
        ]
        nb_split_image = len(object_image_pil_high_res)

        siglip_image_embeds = self.encode_siglip_image_emb(
            self.siglip_image_processor(images=object_image_pil_low_res, return_tensors="pt").pixel_values,
            device,
            dtype
        )
        dinov2_image_embeds = self.encode_dinov2_image_emb(
            self.dino_image_processor_2(images=object_image_pil_low_res, return_tensors="pt").pixel_values,
            device,
            dtype
        )

        image_embeds_low_res_deep = torch.cat([siglip_image_embeds[0], dinov2_image_embeds[0]], dim=2)
        image_embeds_low_res_shallow = torch.cat([siglip_image_embeds[1], dinov2_image_embeds[1]], dim=2)

        siglip_image_high_res = self.siglip_image_processor(images=object_image_pil_high_res, return_tensors="pt").pixel_values
        siglip_image_high_res = siglip_image_high_res[None]
        siglip_image_high_res = rearrange(siglip_image_high_res, 'b n c h w -> (b n) c h w')
        siglip_image_high_res_embeds = self.encode_siglip_image_emb(siglip_image_high_res, device, dtype)
        siglip_image_high_res_deep = rearrange(siglip_image_high_res_embeds[0], '(b n) l c -> b (n l) c', n=nb_split_image)
        dinov2_image_high_res = self.dino_image_processor_2(images=object_image_pil_high_res, return_tensors="pt").pixel_values
        dinov2_image_high_res = dinov2_image_high_res[None]
        dinov2_image_high_res = rearrange(dinov2_image_high_res, 'b n c h w -> (b n) c h w')
        dinov2_image_high_res_embeds = self.encode_dinov2_image_emb(dinov2_image_high_res, device, dtype)
        dinov2_image_high_res_deep = rearrange(dinov2_image_high_res_embeds[0], '(b n) l c -> b (n l) c', n=nb_split_image)
        image_embeds_high_res_deep = torch.cat([siglip_image_high_res_deep, dinov2_image_high_res_deep], dim=2)

        image_embeds_dict = dict(
            image_embeds_low_res_shallow=image_embeds_low_res_shallow,
            image_embeds_low_res_deep=image_embeds_low_res_deep,
            image_embeds_high_res_deep=image_embeds_high_res_deep,
        )
        return image_embeds_dict

    @torch.inference_mode()
    def init_ccp_and_attn_processor(self, *args, **kwargs):
        subject_ip_adapter_path = kwargs['subject_ip_adapter_path']
        nb_token = kwargs['nb_token']
        state_dict = torch.load(subject_ip_adapter_path, map_location="cpu")
        device, dtype = self.transformer.device, self.transformer.dtype

        print(f"=> init attn processor")
        attn_procs = {}
        for idx_attn, (name, v) in enumerate(self.transformer.attn_processors.items()):
            attn_procs[name] = FluxIPAttnProcessor(
                hidden_size=self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads,
                ip_hidden_states_dim=self.text_encoder_2.config.d_model,
            ).to(device, dtype=dtype)
        self.transformer.set_attn_processor(attn_procs)
        tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values())
        key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
        print(f"=> load attn processor: {key_name}")

        print(f"=> init project")
        image_proj_model = CrossLayerCrossScaleProjector(
            inner_dim=1152 + 1536,
            num_attention_heads=42,
            attention_head_dim=64,
            cross_attention_dim=1152 + 1536,
            num_layers=4,
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=nb_token,
            embedding_dim=1152 + 1536,
            output_dim=4096,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        image_proj_model.eval()
        image_proj_model.to(device, dtype=dtype)

        key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
        print(f"=> load project: {key_name}")
        self.subject_image_proj_model = image_proj_model

    @torch.inference_mode()
    def init_adapter(
        self,
        image_encoder_path=None,
        image_encoder_2_path=None,
        subject_ipadapter_cfg=None,
    ):
        device, dtype = self.transformer.device, self.transformer.dtype

        # image encoder
        print(f"=> loading image_encoder_1: {image_encoder_path}")
        image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
        image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path)
        image_encoder.eval()
        image_encoder.to(device, dtype=dtype)
        self.siglip_image_encoder = image_encoder
        self.siglip_image_processor = image_processor

        # image encoder 2
        print(f"=> loading image_encoder_2: {image_encoder_2_path}")
        image_encoder_2 = AutoModel.from_pretrained(image_encoder_2_path)
        image_processor_2 = AutoImageProcessor.from_pretrained(image_encoder_2_path)
        image_encoder_2.eval()
        image_encoder_2.to(device, dtype=dtype)
        image_processor_2.crop_size = dict(height=384, width=384)
        image_processor_2.size = dict(shortest_edge=384)
        self.dino_image_encoder_2 = image_encoder_2
        self.dino_image_processor_2 = image_processor_2

        # ccp and adapter
        self.init_ccp_and_attn_processor(**subject_ipadapter_cfg)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        true_cfg_scale: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        subject_image: Image.Image = None,
        subject_scale: float = 0.8,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_ip_adapter_image:
                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
            subject_image (`PIL.Image.Image`, *optional*):
                Reference image of the subject used by the InstantCharacter adapter.
            subject_scale (`float`, *optional*, defaults to 0.8):
                Strength of the subject (IP-Adapter) conditioning added to the image tokens.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        dtype = self.transformer.dtype

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                _,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 3.1 Prepare subject emb
        if subject_image is not None:
            subject_image = subject_image.resize((max(subject_image.size), max(subject_image.size)))
            subject_image_embeds_dict = self.encode_image_emb(subject_image, device, dtype)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                # subject adapter
                if subject_image is not None:
                    subject_image_prompt_embeds = self.subject_image_proj_model(
                        low_res_shallow=subject_image_embeds_dict['image_embeds_low_res_shallow'],
                        low_res_deep=subject_image_embeds_dict['image_embeds_low_res_deep'],
                        high_res_deep=subject_image_embeds_dict['image_embeds_high_res_deep'],
                        timesteps=timestep.to(dtype=latents.dtype),
                        need_temb=True
                    )[0]
                    self._joint_attention_kwargs['emb_dict'] = dict(
                        length_encoder_hidden_states=prompt_embeds.shape[1]
                    )
                    self._joint_attention_kwargs['subject_emb_dict'] = dict(
                        ip_hidden_states=subject_image_prompt_embeds,
                        scale=subject_scale,
                    )

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if do_true_cfg:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)

    def with_style_lora(self, lora_file_path, lora_weight=1.0, trigger='', *args, **kwargs):
        flux_load_lora(self, lora_file_path, lora_weight)
        kwargs['prompt'] = f"{trigger}, {kwargs['prompt']}"
        res = self.__call__(*args, **kwargs)
        flux_load_lora(self, lora_file_path, -lora_weight)
        return res
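Putting the five files together, a minimal end-to-end sketch: load the base Flux pipeline, attach the SigLIP/DINOv2 encoders and the adapter with `init_adapter`, then generate with a subject reference image. The encoder repo ids, the adapter checkpoint path, and `nb_token=1024` are assumptions (the exact released weights are not part of this commit); the call arguments mirror the defaults in `__call__`.

```py
# End-to-end sketch (placeholder checkpoint ids/paths; adjust to the released weights).
import torch
from PIL import Image
from pipeline import InstantCharacterFluxPipeline

pipe = InstantCharacterFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.init_adapter(
    image_encoder_path="google/siglip-so400m-patch14-384",          # SigLIP encoder (assumed repo id)
    image_encoder_2_path="facebook/dinov2-giant",                   # DINOv2 encoder (assumed repo id)
    subject_ipadapter_cfg=dict(
        subject_ip_adapter_path="instantcharacter_ip-adapter.bin",  # adapter checkpoint (placeholder)
        nb_token=1024,                                              # matches the projector's default num_queries
    ),
)

subject = Image.open("subject.png").convert("RGB")
image = pipe(
    prompt="a character riding a bicycle in a park",
    num_inference_steps=28,
    guidance_scale=3.5,
    subject_image=subject,
    subject_scale=0.8,
).images[0]
image.save("result.png")
```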