cai-qi commited on
Commit
7bc69cd
·
verified ·
1 Parent(s): a53dbd0

Delete hi_diffusers

Browse files
hi_diffusers/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
2
- from .pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline
 
 
 
hi_diffusers/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (361 Bytes)
 
hi_diffusers/models/__pycache__/attention.cpython-310.pyc DELETED
Binary file (3.39 kB)
 
hi_diffusers/models/__pycache__/attention_processor.cpython-310.pyc DELETED
Binary file (2.98 kB)
 
hi_diffusers/models/__pycache__/embeddings.cpython-310.pyc DELETED
Binary file (5.09 kB)
 
hi_diffusers/models/__pycache__/moe.cpython-310.pyc DELETED
Binary file (5.52 kB)
 
hi_diffusers/models/attention.py DELETED
@@ -1,106 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from typing import Optional
4
- from diffusers.models.attention_processor import Attention
5
- from diffusers.utils.torch_utils import maybe_allow_in_graph
6
-
7
- @maybe_allow_in_graph
8
- class HiDreamAttention(Attention):
9
- def __init__(
10
- self,
11
- query_dim: int,
12
- heads: int = 8,
13
- dim_head: int = 64,
14
- upcast_attention: bool = False,
15
- upcast_softmax: bool = False,
16
- scale_qk: bool = True,
17
- eps: float = 1e-5,
18
- processor = None,
19
- out_dim: int = None,
20
- single: bool = False
21
- ):
22
- super(Attention, self).__init__()
23
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
24
- self.query_dim = query_dim
25
- self.upcast_attention = upcast_attention
26
- self.upcast_softmax = upcast_softmax
27
- self.out_dim = out_dim if out_dim is not None else query_dim
28
-
29
- self.scale_qk = scale_qk
30
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
31
-
32
- self.heads = out_dim // dim_head if out_dim is not None else heads
33
- self.sliceable_head_dim = heads
34
- self.single = single
35
-
36
- linear_cls = nn.Linear
37
- self.linear_cls = linear_cls
38
- self.to_q = linear_cls(query_dim, self.inner_dim)
39
- self.to_k = linear_cls(self.inner_dim, self.inner_dim)
40
- self.to_v = linear_cls(self.inner_dim, self.inner_dim)
41
- self.to_out = linear_cls(self.inner_dim, self.out_dim)
42
- self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
43
- self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
44
-
45
- if not single:
46
- self.to_q_t = linear_cls(query_dim, self.inner_dim)
47
- self.to_k_t = linear_cls(self.inner_dim, self.inner_dim)
48
- self.to_v_t = linear_cls(self.inner_dim, self.inner_dim)
49
- self.to_out_t = linear_cls(self.inner_dim, self.out_dim)
50
- self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
51
- self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
52
-
53
- self.set_processor(processor)
54
- self.apply(self._init_weights)
55
-
56
- def _init_weights(self, m):
57
- if isinstance(m, nn.Linear):
58
- nn.init.xavier_uniform_(m.weight)
59
- if m.bias is not None:
60
- nn.init.constant_(m.bias, 0)
61
-
62
- def forward(
63
- self,
64
- norm_image_tokens: torch.FloatTensor,
65
- image_tokens_masks: torch.FloatTensor = None,
66
- norm_text_tokens: torch.FloatTensor = None,
67
- rope: torch.FloatTensor = None,
68
- ) -> torch.Tensor:
69
- return self.processor(
70
- self,
71
- image_tokens = norm_image_tokens,
72
- image_tokens_masks = image_tokens_masks,
73
- text_tokens = norm_text_tokens,
74
- rope = rope,
75
- )
76
-
77
- class FeedForwardSwiGLU(nn.Module):
78
- def __init__(
79
- self,
80
- dim: int,
81
- hidden_dim: int,
82
- multiple_of: int = 256,
83
- ffn_dim_multiplier: Optional[float] = None,
84
- ):
85
- super().__init__()
86
- hidden_dim = int(2 * hidden_dim / 3)
87
- # custom dim factor multiplier
88
- if ffn_dim_multiplier is not None:
89
- hidden_dim = int(ffn_dim_multiplier * hidden_dim)
90
- hidden_dim = multiple_of * (
91
- (hidden_dim + multiple_of - 1) // multiple_of
92
- )
93
-
94
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
95
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
96
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
97
- self.apply(self._init_weights)
98
-
99
- def _init_weights(self, m):
100
- if isinstance(m, nn.Linear):
101
- nn.init.xavier_uniform_(m.weight)
102
- if m.bias is not None:
103
- nn.init.constant_(m.bias, 0)
104
-
105
- def forward(self, x):
106
- return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hi_diffusers/models/attention_processor.py DELETED
@@ -1,95 +0,0 @@
1
- from typing import Optional
2
- import torch
3
- from .attention import HiDreamAttention
4
-
5
- try:
6
- from flash_attn_interface import flash_attn_func
7
- USE_FLASH_ATTN3 = True
8
- except:
9
- from flash_attn import flash_attn_func
10
- USE_FLASH_ATTN3 = False
11
-
12
- # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
13
- def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
14
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
15
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
16
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
17
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
18
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
19
-
20
- def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
21
- if USE_FLASH_ATTN3:
22
- hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
23
- else:
24
- hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
25
- hidden_states = hidden_states.flatten(-2)
26
- hidden_states = hidden_states.to(query.dtype)
27
- return hidden_states
28
-
29
- class HiDreamAttnProcessor_flashattn:
30
- """Attention processor used typically in processing the SD3-like self-attention projections."""
31
-
32
- def __call__(
33
- self,
34
- attn: HiDreamAttention,
35
- image_tokens: torch.FloatTensor,
36
- image_tokens_masks: Optional[torch.FloatTensor] = None,
37
- text_tokens: Optional[torch.FloatTensor] = None,
38
- rope: torch.FloatTensor = None,
39
- *args,
40
- **kwargs,
41
- ) -> torch.FloatTensor:
42
- dtype = image_tokens.dtype
43
- batch_size = image_tokens.shape[0]
44
-
45
- query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
46
- key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
47
- value_i = attn.to_v(image_tokens)
48
-
49
- inner_dim = key_i.shape[-1]
50
- head_dim = inner_dim // attn.heads
51
-
52
- query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
53
- key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
54
- value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
55
- if image_tokens_masks is not None:
56
- key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
57
-
58
- if not attn.single:
59
- query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
60
- key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
61
- value_t = attn.to_v_t(text_tokens)
62
-
63
- query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
64
- key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
65
- value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
66
-
67
- num_image_tokens = query_i.shape[1]
68
- num_text_tokens = query_t.shape[1]
69
- query = torch.cat([query_i, query_t], dim=1)
70
- key = torch.cat([key_i, key_t], dim=1)
71
- value = torch.cat([value_i, value_t], dim=1)
72
- else:
73
- query = query_i
74
- key = key_i
75
- value = value_i
76
-
77
- if query.shape[-1] == rope.shape[-3] * 2:
78
- query, key = apply_rope(query, key, rope)
79
- else:
80
- query_1, query_2 = query.chunk(2, dim=-1)
81
- key_1, key_2 = key.chunk(2, dim=-1)
82
- query_1, key_1 = apply_rope(query_1, key_1, rope)
83
- query = torch.cat([query_1, query_2], dim=-1)
84
- key = torch.cat([key_1, key_2], dim=-1)
85
-
86
- hidden_states = attention(query, key, value)
87
-
88
- if not attn.single:
89
- hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
90
- hidden_states_i = attn.to_out(hidden_states_i)
91
- hidden_states_t = attn.to_out_t(hidden_states_t)
92
- return hidden_states_i, hidden_states_t
93
- else:
94
- hidden_states = attn.to_out(hidden_states)
95
- return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hi_diffusers/models/embeddings.py DELETED
@@ -1,114 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from typing import List
4
- from diffusers.models.embeddings import Timesteps, TimestepEmbedding
5
-
6
- # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
7
- def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
8
- assert dim % 2 == 0, "The dimension must be even."
9
-
10
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
11
- omega = 1.0 / (theta**scale)
12
-
13
- batch_size, seq_length = pos.shape
14
- out = torch.einsum("...n,d->...nd", pos, omega)
15
- cos_out = torch.cos(out)
16
- sin_out = torch.sin(out)
17
-
18
- stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
19
- out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
20
- return out.float()
21
-
22
- # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
23
- class EmbedND(nn.Module):
24
- def __init__(self, theta: int, axes_dim: List[int]):
25
- super().__init__()
26
- self.theta = theta
27
- self.axes_dim = axes_dim
28
-
29
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
30
- n_axes = ids.shape[-1]
31
- emb = torch.cat(
32
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
33
- dim=-3,
34
- )
35
- return emb.unsqueeze(2)
36
-
37
- class PatchEmbed(nn.Module):
38
- def __init__(
39
- self,
40
- patch_size=2,
41
- in_channels=4,
42
- out_channels=1024,
43
- ):
44
- super().__init__()
45
- self.patch_size = patch_size
46
- self.out_channels = out_channels
47
- self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
48
- self.apply(self._init_weights)
49
-
50
- def _init_weights(self, m):
51
- if isinstance(m, nn.Linear):
52
- nn.init.xavier_uniform_(m.weight)
53
- if m.bias is not None:
54
- nn.init.constant_(m.bias, 0)
55
-
56
- def forward(self, latent):
57
- latent = self.proj(latent)
58
- return latent
59
-
60
- class PooledEmbed(nn.Module):
61
- def __init__(self, text_emb_dim, hidden_size):
62
- super().__init__()
63
- self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
64
- self.apply(self._init_weights)
65
-
66
- def _init_weights(self, m):
67
- if isinstance(m, nn.Linear):
68
- nn.init.normal_(m.weight, std=0.02)
69
- if m.bias is not None:
70
- nn.init.constant_(m.bias, 0)
71
-
72
- def forward(self, pooled_embed):
73
- return self.pooled_embedder(pooled_embed)
74
-
75
- class TimestepEmbed(nn.Module):
76
- def __init__(self, hidden_size, frequency_embedding_size=256):
77
- super().__init__()
78
- self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
79
- self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
80
- self.apply(self._init_weights)
81
-
82
- def _init_weights(self, m):
83
- if isinstance(m, nn.Linear):
84
- nn.init.normal_(m.weight, std=0.02)
85
- if m.bias is not None:
86
- nn.init.constant_(m.bias, 0)
87
-
88
- def forward(self, timesteps, wdtype):
89
- t_emb = self.time_proj(timesteps).to(dtype=wdtype)
90
- t_emb = self.timestep_embedder(t_emb)
91
- return t_emb
92
-
93
- class OutEmbed(nn.Module):
94
- def __init__(self, hidden_size, patch_size, out_channels):
95
- super().__init__()
96
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
97
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
98
- self.adaLN_modulation = nn.Sequential(
99
- nn.SiLU(),
100
- nn.Linear(hidden_size, 2 * hidden_size, bias=True)
101
- )
102
- self.apply(self._init_weights)
103
-
104
- def _init_weights(self, m):
105
- if isinstance(m, nn.Linear):
106
- nn.init.zeros_(m.weight)
107
- if m.bias is not None:
108
- nn.init.constant_(m.bias, 0)
109
-
110
- def forward(self, x, adaln_input):
111
- shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
112
- x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
113
- x = self.linear(x)
114
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hi_diffusers/models/moe.py DELETED
@@ -1,154 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- import torch.nn.functional as F
5
- from .attention import FeedForwardSwiGLU
6
- from torch.distributed.nn.functional import all_gather
7
-
8
- _LOAD_BALANCING_LOSS = []
9
- def save_load_balancing_loss(loss):
10
- global _LOAD_BALANCING_LOSS
11
- _LOAD_BALANCING_LOSS.append(loss)
12
-
13
- def clear_load_balancing_loss():
14
- global _LOAD_BALANCING_LOSS
15
- _LOAD_BALANCING_LOSS.clear()
16
-
17
- def get_load_balancing_loss():
18
- global _LOAD_BALANCING_LOSS
19
- return _LOAD_BALANCING_LOSS
20
-
21
- def batched_load_balancing_loss():
22
- aux_losses_arr = get_load_balancing_loss()
23
- alpha = aux_losses_arr[0][-1]
24
- Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0)
25
- fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0)
26
-
27
- fi_list = all_gather(fi)
28
- fi = torch.stack(fi_list, 0).mean(0)
29
-
30
- aux_loss = (Pi * fi).sum(-1).mean() * alpha
31
- return aux_loss
32
-
33
- # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
34
- class MoEGate(nn.Module):
35
- def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
36
- super().__init__()
37
- self.top_k = num_activated_experts
38
- self.n_routed_experts = num_routed_experts
39
-
40
- self.scoring_func = 'softmax'
41
- self.alpha = aux_loss_alpha
42
- self.seq_aux = False
43
-
44
- # topk selection algorithm
45
- self.norm_topk_prob = False
46
- self.gating_dim = embed_dim
47
- self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
48
- self.reset_parameters()
49
-
50
- def reset_parameters(self) -> None:
51
- import torch.nn.init as init
52
- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
53
-
54
- def forward(self, hidden_states):
55
- bsz, seq_len, h = hidden_states.shape
56
- # print(bsz, seq_len, h)
57
- ### compute gating score
58
- hidden_states = hidden_states.view(-1, h)
59
- logits = F.linear(hidden_states, self.weight, None)
60
- if self.scoring_func == 'softmax':
61
- scores = logits.softmax(dim=-1)
62
- else:
63
- raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
64
-
65
- ### select top-k experts
66
- topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
67
-
68
- ### norm gate to sum 1
69
- if self.top_k > 1 and self.norm_topk_prob:
70
- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
71
- topk_weight = topk_weight / denominator
72
-
73
- ### expert-level computation auxiliary loss
74
- if self.training and self.alpha > 0.0:
75
- scores_for_aux = scores
76
- aux_topk = self.top_k
77
- # always compute aux loss based on the naive greedy topk method
78
- topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
79
- if self.seq_aux:
80
- scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
81
- ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
82
- ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
83
- aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
84
- else:
85
- mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
86
- ce = mask_ce.float().mean(0)
87
-
88
- Pi = scores_for_aux.mean(0)
89
- fi = ce * self.n_routed_experts
90
- aux_loss = (Pi * fi).sum() * self.alpha
91
- save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
92
- else:
93
- aux_loss = None
94
- return topk_idx, topk_weight, aux_loss
95
-
96
- # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
97
- class MOEFeedForwardSwiGLU(nn.Module):
98
- def __init__(
99
- self,
100
- dim: int,
101
- hidden_dim: int,
102
- num_routed_experts: int,
103
- num_activated_experts: int,
104
- ):
105
- super().__init__()
106
- self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2)
107
- self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)])
108
- self.gate = MoEGate(
109
- embed_dim = dim,
110
- num_routed_experts = num_routed_experts,
111
- num_activated_experts = num_activated_experts
112
- )
113
- self.num_activated_experts = num_activated_experts
114
-
115
- def forward(self, x):
116
- wtype = x.dtype
117
- identity = x
118
- orig_shape = x.shape
119
- topk_idx, topk_weight, aux_loss = self.gate(x)
120
- x = x.view(-1, x.shape[-1])
121
- flat_topk_idx = topk_idx.view(-1)
122
- if self.training:
123
- x = x.repeat_interleave(self.num_activated_experts, dim=0)
124
- y = torch.empty_like(x, dtype=wtype)
125
- for i, expert in enumerate(self.experts):
126
- y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
127
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
128
- y = y.view(*orig_shape).to(dtype=wtype)
129
- #y = AddAuxiliaryLoss.apply(y, aux_loss)
130
- else:
131
- y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
132
- y = y + self.shared_experts(identity)
133
- return y
134
-
135
- @torch.no_grad()
136
- def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
137
- expert_cache = torch.zeros_like(x)
138
- idxs = flat_expert_indices.argsort()
139
- tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
140
- token_idxs = idxs // self.num_activated_experts
141
- for i, end_idx in enumerate(tokens_per_expert):
142
- start_idx = 0 if i == 0 else tokens_per_expert[i-1]
143
- if start_idx == end_idx:
144
- continue
145
- expert = self.experts[i]
146
- exp_token_idx = token_idxs[start_idx:end_idx]
147
- expert_tokens = x[exp_token_idx]
148
- expert_out = expert(expert_tokens)
149
- expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
150
-
151
- # for fp16 and other dtype
152
- expert_cache = expert_cache.to(expert_out.dtype)
153
- expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
154
- return expert_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hi_diffusers/models/transformers/__pycache__/transformer_hidream_image.cpython-310.pyc DELETED
Binary file (14.1 kB)
 
hi_diffusers/models/transformers/transformer_hidream_image.py DELETED
@@ -1,526 +0,0 @@
1
- from typing import Any, Dict, Optional, Tuple, List
2
-
3
- import torch
4
- import torch.nn as nn
5
- import einops
6
- from einops import repeat
7
-
8
- from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
- from diffusers.models.modeling_utils import ModelMixin
11
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
12
- from diffusers.utils.torch_utils import maybe_allow_in_graph
13
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
14
- from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
15
- from ..attention import HiDreamAttention, FeedForwardSwiGLU
16
- from ..attention_processor import HiDreamAttnProcessor_flashattn
17
- from ..moe import MOEFeedForwardSwiGLU
18
-
19
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
-
21
- class TextProjection(nn.Module):
22
- def __init__(self, in_features, hidden_size):
23
- super().__init__()
24
- self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
25
-
26
- def forward(self, caption):
27
- hidden_states = self.linear(caption)
28
- return hidden_states
29
-
30
- class BlockType:
31
- TransformerBlock = 1
32
- SingleTransformerBlock = 2
33
-
34
- @maybe_allow_in_graph
35
- class HiDreamImageSingleTransformerBlock(nn.Module):
36
- def __init__(
37
- self,
38
- dim: int,
39
- num_attention_heads: int,
40
- attention_head_dim: int,
41
- num_routed_experts: int = 4,
42
- num_activated_experts: int = 2
43
- ):
44
- super().__init__()
45
- self.num_attention_heads = num_attention_heads
46
- self.adaLN_modulation = nn.Sequential(
47
- nn.SiLU(),
48
- nn.Linear(dim, 6 * dim, bias=True)
49
- )
50
- nn.init.zeros_(self.adaLN_modulation[1].weight)
51
- nn.init.zeros_(self.adaLN_modulation[1].bias)
52
-
53
- # 1. Attention
54
- self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
55
- self.attn1 = HiDreamAttention(
56
- query_dim=dim,
57
- heads=num_attention_heads,
58
- dim_head=attention_head_dim,
59
- processor = HiDreamAttnProcessor_flashattn(),
60
- single = True
61
- )
62
-
63
- # 3. Feed-forward
64
- self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
65
- if num_routed_experts > 0:
66
- self.ff_i = MOEFeedForwardSwiGLU(
67
- dim = dim,
68
- hidden_dim = 4 * dim,
69
- num_routed_experts = num_routed_experts,
70
- num_activated_experts = num_activated_experts,
71
- )
72
- else:
73
- self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
74
-
75
- def forward(
76
- self,
77
- image_tokens: torch.FloatTensor,
78
- image_tokens_masks: Optional[torch.FloatTensor] = None,
79
- text_tokens: Optional[torch.FloatTensor] = None,
80
- adaln_input: Optional[torch.FloatTensor] = None,
81
- rope: torch.FloatTensor = None,
82
-
83
- ) -> torch.FloatTensor:
84
- wtype = image_tokens.dtype
85
- shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
86
- self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
87
-
88
- # 1. MM-Attention
89
- norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
90
- norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
91
- attn_output_i = self.attn1(
92
- norm_image_tokens,
93
- image_tokens_masks,
94
- rope = rope,
95
- )
96
- image_tokens = gate_msa_i * attn_output_i + image_tokens
97
-
98
- # 2. Feed-forward
99
- norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
100
- norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
101
- ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
102
- image_tokens = ff_output_i + image_tokens
103
- return image_tokens
104
-
105
- @maybe_allow_in_graph
106
- class HiDreamImageTransformerBlock(nn.Module):
107
- def __init__(
108
- self,
109
- dim: int,
110
- num_attention_heads: int,
111
- attention_head_dim: int,
112
- num_routed_experts: int = 4,
113
- num_activated_experts: int = 2
114
- ):
115
- super().__init__()
116
- self.num_attention_heads = num_attention_heads
117
- self.adaLN_modulation = nn.Sequential(
118
- nn.SiLU(),
119
- nn.Linear(dim, 12 * dim, bias=True)
120
- )
121
- nn.init.zeros_(self.adaLN_modulation[1].weight)
122
- nn.init.zeros_(self.adaLN_modulation[1].bias)
123
-
124
- # 1. Attention
125
- self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
126
- self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
127
- self.attn1 = HiDreamAttention(
128
- query_dim=dim,
129
- heads=num_attention_heads,
130
- dim_head=attention_head_dim,
131
- processor = HiDreamAttnProcessor_flashattn(),
132
- single = False
133
- )
134
-
135
- # 3. Feed-forward
136
- self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
137
- if num_routed_experts > 0:
138
- self.ff_i = MOEFeedForwardSwiGLU(
139
- dim = dim,
140
- hidden_dim = 4 * dim,
141
- num_routed_experts = num_routed_experts,
142
- num_activated_experts = num_activated_experts,
143
- )
144
- else:
145
- self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
146
- self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
147
- self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
148
-
149
- def forward(
150
- self,
151
- image_tokens: torch.FloatTensor,
152
- image_tokens_masks: Optional[torch.FloatTensor] = None,
153
- text_tokens: Optional[torch.FloatTensor] = None,
154
- adaln_input: Optional[torch.FloatTensor] = None,
155
- rope: torch.FloatTensor = None,
156
- ) -> torch.FloatTensor:
157
- wtype = image_tokens.dtype
158
- shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
159
- shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
160
- self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
161
-
162
- # 1. MM-Attention
163
- norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
164
- norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
165
- norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
166
- norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
167
-
168
- attn_output_i, attn_output_t = self.attn1(
169
- norm_image_tokens,
170
- image_tokens_masks,
171
- norm_text_tokens,
172
- rope = rope,
173
- )
174
-
175
- image_tokens = gate_msa_i * attn_output_i + image_tokens
176
- text_tokens = gate_msa_t * attn_output_t + text_tokens
177
-
178
- # 2. Feed-forward
179
- norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
180
- norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
181
- norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
182
- norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
183
-
184
- ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
185
- ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
186
- image_tokens = ff_output_i + image_tokens
187
- text_tokens = ff_output_t + text_tokens
188
- return image_tokens, text_tokens
189
-
190
- @maybe_allow_in_graph
191
- class HiDreamImageBlock(nn.Module):
192
- def __init__(
193
- self,
194
- dim: int,
195
- num_attention_heads: int,
196
- attention_head_dim: int,
197
- num_routed_experts: int = 4,
198
- num_activated_experts: int = 2,
199
- block_type: BlockType = BlockType.TransformerBlock,
200
- ):
201
- super().__init__()
202
- block_classes = {
203
- BlockType.TransformerBlock: HiDreamImageTransformerBlock,
204
- BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
205
- }
206
- self.block = block_classes[block_type](
207
- dim,
208
- num_attention_heads,
209
- attention_head_dim,
210
- num_routed_experts,
211
- num_activated_experts
212
- )
213
-
214
- def forward(
215
- self,
216
- image_tokens: torch.FloatTensor,
217
- image_tokens_masks: Optional[torch.FloatTensor] = None,
218
- text_tokens: Optional[torch.FloatTensor] = None,
219
- adaln_input: torch.FloatTensor = None,
220
- rope: torch.FloatTensor = None,
221
- ) -> torch.FloatTensor:
222
- return self.block(
223
- image_tokens,
224
- image_tokens_masks,
225
- text_tokens,
226
- adaln_input,
227
- rope,
228
- )
229
-
230
- class HiDreamImageTransformer2DModel(
231
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
232
- ):
233
- _supports_gradient_checkpointing = True
234
- _no_split_modules = ["HiDreamImageBlock"]
235
-
236
- @register_to_config
237
- def __init__(
238
- self,
239
- patch_size: Optional[int] = None,
240
- in_channels: int = 64,
241
- out_channels: Optional[int] = None,
242
- num_layers: int = 16,
243
- num_single_layers: int = 32,
244
- attention_head_dim: int = 128,
245
- num_attention_heads: int = 20,
246
- caption_channels: List[int] = None,
247
- text_emb_dim: int = 2048,
248
- num_routed_experts: int = 4,
249
- num_activated_experts: int = 2,
250
- axes_dims_rope: Tuple[int, int] = (32, 32),
251
- max_resolution: Tuple[int, int] = (128, 128),
252
- llama_layers: List[int] = None,
253
- ):
254
- super().__init__()
255
- self.out_channels = out_channels or in_channels
256
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
257
- self.llama_layers = llama_layers
258
-
259
- self.t_embedder = TimestepEmbed(self.inner_dim)
260
- self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
261
- self.x_embedder = PatchEmbed(
262
- patch_size = patch_size,
263
- in_channels = in_channels,
264
- out_channels = self.inner_dim,
265
- )
266
- self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
267
-
268
- self.double_stream_blocks = nn.ModuleList(
269
- [
270
- HiDreamImageBlock(
271
- dim = self.inner_dim,
272
- num_attention_heads = self.config.num_attention_heads,
273
- attention_head_dim = self.config.attention_head_dim,
274
- num_routed_experts = num_routed_experts,
275
- num_activated_experts = num_activated_experts,
276
- block_type = BlockType.TransformerBlock
277
- )
278
- for i in range(self.config.num_layers)
279
- ]
280
- )
281
-
282
- self.single_stream_blocks = nn.ModuleList(
283
- [
284
- HiDreamImageBlock(
285
- dim = self.inner_dim,
286
- num_attention_heads = self.config.num_attention_heads,
287
- attention_head_dim = self.config.attention_head_dim,
288
- num_routed_experts = num_routed_experts,
289
- num_activated_experts = num_activated_experts,
290
- block_type = BlockType.SingleTransformerBlock
291
- )
292
- for i in range(self.config.num_single_layers)
293
- ]
294
- )
295
-
296
- self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)
297
-
298
- caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
299
- caption_projection = []
300
- for caption_channel in caption_channels:
301
- caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
302
- self.caption_projection = nn.ModuleList(caption_projection)
303
- self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
304
-
305
- self.gradient_checkpointing = False
306
-
307
- def _set_gradient_checkpointing(self, module, value=False):
308
- if hasattr(module, "gradient_checkpointing"):
309
- module.gradient_checkpointing = value
310
-
311
- def expand_timesteps(self, timesteps, batch_size, device):
312
- if not torch.is_tensor(timesteps):
313
- is_mps = device.type == "mps"
314
- if isinstance(timesteps, float):
315
- dtype = torch.float32 if is_mps else torch.float64
316
- else:
317
- dtype = torch.int32 if is_mps else torch.int64
318
- timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
319
- elif len(timesteps.shape) == 0:
320
- timesteps = timesteps[None].to(device)
321
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
322
- timesteps = timesteps.expand(batch_size)
323
- return timesteps
324
-
325
- def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
326
- if is_training:
327
- x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
328
- else:
329
- x_arr = []
330
- for i, img_size in enumerate(img_sizes):
331
- pH, pW = img_size
332
- x_arr.append(
333
- einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
334
- p1=self.config.patch_size, p2=self.config.patch_size)
335
- )
336
- x = torch.cat(x_arr, dim=0)
337
- return x
338
-
339
- def patchify(self, x, max_seq, img_sizes=None):
340
- pz2 = self.config.patch_size * self.config.patch_size
341
- if isinstance(x, torch.Tensor):
342
- B, C = x.shape[0], x.shape[1]
343
- device = x.device
344
- dtype = x.dtype
345
- else:
346
- B, C = len(x), x[0].shape[0]
347
- device = x[0].device
348
- dtype = x[0].dtype
349
- x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
350
-
351
- if img_sizes is not None:
352
- for i, img_size in enumerate(img_sizes):
353
- x_masks[i, 0:img_size[0] * img_size[1]] = 1
354
- x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
355
- elif isinstance(x, torch.Tensor):
356
- pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
357
- x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
358
- img_sizes = [[pH, pW]] * B
359
- x_masks = None
360
- else:
361
- raise NotImplementedError
362
- return x, x_masks, img_sizes
363
-
364
- def forward(
365
- self,
366
- hidden_states: torch.Tensor,
367
- timesteps: torch.LongTensor = None,
368
- encoder_hidden_states: torch.Tensor = None,
369
- pooled_embeds: torch.Tensor = None,
370
- img_sizes: Optional[List[Tuple[int, int]]] = None,
371
- img_ids: Optional[torch.Tensor] = None,
372
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
373
- return_dict: bool = True,
374
- ):
375
- if joint_attention_kwargs is not None:
376
- joint_attention_kwargs = joint_attention_kwargs.copy()
377
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
378
- else:
379
- lora_scale = 1.0
380
-
381
- if USE_PEFT_BACKEND:
382
- # weight the lora layers by setting `lora_scale` for each PEFT layer
383
- scale_lora_layers(self, lora_scale)
384
- else:
385
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
386
- logger.warning(
387
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
388
- )
389
-
390
- # spatial forward
391
- batch_size = hidden_states.shape[0]
392
- hidden_states_type = hidden_states.dtype
393
-
394
- # 0. time
395
- timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
396
- timesteps = self.t_embedder(timesteps, hidden_states_type)
397
- p_embedder = self.p_embedder(pooled_embeds)
398
- adaln_input = timesteps + p_embedder
399
-
400
- hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
401
- if image_tokens_masks is None:
402
- pH, pW = img_sizes[0]
403
- img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
404
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
405
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
406
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
407
- hidden_states = self.x_embedder(hidden_states)
408
-
409
- T5_encoder_hidden_states = encoder_hidden_states[0]
410
- encoder_hidden_states = encoder_hidden_states[-1]
411
- encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
412
-
413
- if self.caption_projection is not None:
414
- new_encoder_hidden_states = []
415
- for i, enc_hidden_state in enumerate(encoder_hidden_states):
416
- enc_hidden_state = self.caption_projection[i](enc_hidden_state)
417
- enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
418
- new_encoder_hidden_states.append(enc_hidden_state)
419
- encoder_hidden_states = new_encoder_hidden_states
420
- T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
421
- T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
422
- encoder_hidden_states.append(T5_encoder_hidden_states)
423
-
424
- txt_ids = torch.zeros(
425
- batch_size,
426
- encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
427
- 3,
428
- device=img_ids.device, dtype=img_ids.dtype
429
- )
430
- ids = torch.cat((img_ids, txt_ids), dim=1)
431
- rope = self.pe_embedder(ids)
432
-
433
- # 2. Blocks
434
- block_id = 0
435
- initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
436
- initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
437
- for bid, block in enumerate(self.double_stream_blocks):
438
- cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
439
- cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
440
- if self.training and self.gradient_checkpointing:
441
- def create_custom_forward(module, return_dict=None):
442
- def custom_forward(*inputs):
443
- if return_dict is not None:
444
- return module(*inputs, return_dict=return_dict)
445
- else:
446
- return module(*inputs)
447
- return custom_forward
448
-
449
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
450
- hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
451
- create_custom_forward(block),
452
- hidden_states,
453
- image_tokens_masks,
454
- cur_encoder_hidden_states,
455
- adaln_input,
456
- rope,
457
- **ckpt_kwargs,
458
- )
459
- else:
460
- hidden_states, initial_encoder_hidden_states = block(
461
- image_tokens = hidden_states,
462
- image_tokens_masks = image_tokens_masks,
463
- text_tokens = cur_encoder_hidden_states,
464
- adaln_input = adaln_input,
465
- rope = rope,
466
- )
467
- initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
468
- block_id += 1
469
-
470
- image_tokens_seq_len = hidden_states.shape[1]
471
- hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
472
- hidden_states_seq_len = hidden_states.shape[1]
473
- if image_tokens_masks is not None:
474
- encoder_attention_mask_ones = torch.ones(
475
- (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
476
- device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
477
- )
478
- image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
479
-
480
- for bid, block in enumerate(self.single_stream_blocks):
481
- cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
482
- hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
483
- if self.training and self.gradient_checkpointing:
484
- def create_custom_forward(module, return_dict=None):
485
- def custom_forward(*inputs):
486
- if return_dict is not None:
487
- return module(*inputs, return_dict=return_dict)
488
- else:
489
- return module(*inputs)
490
- return custom_forward
491
-
492
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
493
- hidden_states = torch.utils.checkpoint.checkpoint(
494
- create_custom_forward(block),
495
- hidden_states,
496
- image_tokens_masks,
497
- None,
498
- adaln_input,
499
- rope,
500
- **ckpt_kwargs,
501
- )
502
- else:
503
- hidden_states = block(
504
- image_tokens = hidden_states,
505
- image_tokens_masks = image_tokens_masks,
506
- text_tokens = None,
507
- adaln_input = adaln_input,
508
- rope = rope,
509
- )
510
- hidden_states = hidden_states[:, :hidden_states_seq_len]
511
- block_id += 1
512
-
513
- hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
514
- output = self.final_layer(hidden_states, adaln_input)
515
- output = self.unpatchify(output, img_sizes, self.training)
516
- if image_tokens_masks is not None:
517
- image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
518
-
519
- if USE_PEFT_BACKEND:
520
- # remove `lora_scale` from each PEFT layer
521
- unscale_lora_layers(self, lora_scale)
522
-
523
- if not return_dict:
524
- return (output, image_tokens_masks)
525
- return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
526
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hi_diffusers/pipelines/hidream_image/__pycache__/pipeline_hidream_image.cpython-310.pyc DELETED
Binary file (18.2 kB)
 
hi_diffusers/pipelines/hidream_image/__pycache__/pipeline_output.cpython-310.pyc DELETED
Binary file (1.03 kB)
 
hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py DELETED
@@ -1,733 +0,0 @@
1
- import inspect
2
- from typing import Any, Callable, Dict, List, Optional, Union
3
- import math
4
- import einops
5
- import torch
6
- from transformers import (
7
- CLIPTextModelWithProjection,
8
- CLIPTokenizer,
9
- T5EncoderModel,
10
- T5Tokenizer,
11
- LlamaForCausalLM,
12
- PreTrainedTokenizerFast
13
- )
14
-
15
- from diffusers.image_processor import VaeImageProcessor
16
- from diffusers.loaders import FromSingleFileMixin
17
- from diffusers.models.autoencoders import AutoencoderKL
18
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
19
- from diffusers.utils import (
20
- USE_PEFT_BACKEND,
21
- is_torch_xla_available,
22
- logging,
23
- )
24
- from diffusers.utils.torch_utils import randn_tensor
25
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
- from .pipeline_output import HiDreamImagePipelineOutput
27
- from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
28
- from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
-
30
- if is_torch_xla_available():
31
- import torch_xla.core.xla_model as xm
32
-
33
- XLA_AVAILABLE = True
34
- else:
35
- XLA_AVAILABLE = False
36
-
37
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
-
39
- # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
40
- def calculate_shift(
41
- image_seq_len,
42
- base_seq_len: int = 256,
43
- max_seq_len: int = 4096,
44
- base_shift: float = 0.5,
45
- max_shift: float = 1.15,
46
- ):
47
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
48
- b = base_shift - m * base_seq_len
49
- mu = image_seq_len * m + b
50
- return mu
51
-
52
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
53
- def retrieve_timesteps(
54
- scheduler,
55
- num_inference_steps: Optional[int] = None,
56
- device: Optional[Union[str, torch.device]] = None,
57
- timesteps: Optional[List[int]] = None,
58
- sigmas: Optional[List[float]] = None,
59
- **kwargs,
60
- ):
61
- r"""
62
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
63
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
64
-
65
- Args:
66
- scheduler (`SchedulerMixin`):
67
- The scheduler to get timesteps from.
68
- num_inference_steps (`int`):
69
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
70
- must be `None`.
71
- device (`str` or `torch.device`, *optional*):
72
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
73
- timesteps (`List[int]`, *optional*):
74
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
75
- `num_inference_steps` and `sigmas` must be `None`.
76
- sigmas (`List[float]`, *optional*):
77
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
78
- `num_inference_steps` and `timesteps` must be `None`.
79
-
80
- Returns:
81
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
82
- second element is the number of inference steps.
83
- """
84
- if timesteps is not None and sigmas is not None:
85
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
86
- if timesteps is not None:
87
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
88
- if not accepts_timesteps:
89
- raise ValueError(
90
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
91
- f" timestep schedules. Please check whether you are using the correct scheduler."
92
- )
93
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
94
- timesteps = scheduler.timesteps
95
- num_inference_steps = len(timesteps)
96
- elif sigmas is not None:
97
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
98
- if not accept_sigmas:
99
- raise ValueError(
100
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
101
- f" sigmas schedules. Please check whether you are using the correct scheduler."
102
- )
103
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
104
- timesteps = scheduler.timesteps
105
- num_inference_steps = len(timesteps)
106
- else:
107
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
108
- timesteps = scheduler.timesteps
109
- return timesteps, num_inference_steps
110
-
111
- class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
112
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
113
- _optional_components = ["image_encoder", "feature_extractor"]
114
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
115
-
116
- def __init__(
117
- self,
118
- scheduler: FlowMatchEulerDiscreteScheduler,
119
- vae: AutoencoderKL,
120
- text_encoder: CLIPTextModelWithProjection,
121
- tokenizer: CLIPTokenizer,
122
- text_encoder_2: CLIPTextModelWithProjection,
123
- tokenizer_2: CLIPTokenizer,
124
- text_encoder_3: T5EncoderModel,
125
- tokenizer_3: T5Tokenizer,
126
- text_encoder_4: LlamaForCausalLM,
127
- tokenizer_4: PreTrainedTokenizerFast,
128
- ):
129
- super().__init__()
130
-
131
- self.register_modules(
132
- vae=vae,
133
- text_encoder=text_encoder,
134
- text_encoder_2=text_encoder_2,
135
- text_encoder_3=text_encoder_3,
136
- text_encoder_4=text_encoder_4,
137
- tokenizer=tokenizer,
138
- tokenizer_2=tokenizer_2,
139
- tokenizer_3=tokenizer_3,
140
- tokenizer_4=tokenizer_4,
141
- scheduler=scheduler,
142
- )
143
- self.vae_scale_factor = (
144
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
145
- )
146
- # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
147
- # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
148
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
149
- self.default_sample_size = 128
150
- self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
151
-
152
- def _get_t5_prompt_embeds(
153
- self,
154
- prompt: Union[str, List[str]] = None,
155
- num_images_per_prompt: int = 1,
156
- max_sequence_length: int = 128,
157
- device: Optional[torch.device] = None,
158
- dtype: Optional[torch.dtype] = None,
159
- ):
160
- device = device or self._execution_device
161
- dtype = dtype or self.text_encoder_3.dtype
162
-
163
- prompt = [prompt] if isinstance(prompt, str) else prompt
164
- batch_size = len(prompt)
165
-
166
- text_inputs = self.tokenizer_3(
167
- prompt,
168
- padding="max_length",
169
- max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
170
- truncation=True,
171
- add_special_tokens=True,
172
- return_tensors="pt",
173
- )
174
- text_input_ids = text_inputs.input_ids
175
- attention_mask = text_inputs.attention_mask
176
- untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
177
-
178
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
179
- removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.text_encoder_3.model_max_length - 1 : -1])
180
- logger.warning(
181
- "The following part of your input was truncated because `max_sequence_length` is set to "
182
- f" {self.text_encoder_3.model_max_length} tokens: {removed_text}"
183
- )
184
-
185
- prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
186
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
187
- _, seq_len, _ = prompt_embeds.shape
188
-
189
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
190
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
192
- return prompt_embeds
193
-
194
- def _get_clip_prompt_embeds(
195
- self,
196
- tokenizer,
197
- text_encoder,
198
- prompt: Union[str, List[str]],
199
- num_images_per_prompt: int = 1,
200
- max_sequence_length: int = 128,
201
- device: Optional[torch.device] = None,
202
- dtype: Optional[torch.dtype] = None,
203
- ):
204
- device = device or self._execution_device
205
- dtype = dtype or text_encoder.dtype
206
-
207
- prompt = [prompt] if isinstance(prompt, str) else prompt
208
- batch_size = len(prompt)
209
-
210
- text_inputs = tokenizer(
211
- prompt,
212
- padding="max_length",
213
- max_length=min(max_sequence_length, 218),
214
- truncation=True,
215
- return_tensors="pt",
216
- )
217
-
218
- text_input_ids = text_inputs.input_ids
219
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
220
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
221
- removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
222
- logger.warning(
223
- "The following part of your input was truncated because CLIP can only handle sequences up to"
224
- f" {218} tokens: {removed_text}"
225
- )
226
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
227
-
228
- # Use pooled output of CLIPTextModel
229
- prompt_embeds = prompt_embeds[0]
230
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
231
-
232
- # duplicate text embeddings for each generation per prompt, using mps friendly method
233
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
234
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
235
-
236
- return prompt_embeds
237
-
238
- def _get_llama3_prompt_embeds(
239
- self,
240
- prompt: Union[str, List[str]] = None,
241
- num_images_per_prompt: int = 1,
242
- max_sequence_length: int = 128,
243
- device: Optional[torch.device] = None,
244
- dtype: Optional[torch.dtype] = None,
245
- ):
246
- device = device or self._execution_device
247
- dtype = dtype or self.text_encoder_4.dtype
248
-
249
- prompt = [prompt] if isinstance(prompt, str) else prompt
250
- batch_size = len(prompt)
251
-
252
- text_inputs = self.tokenizer_4(
253
- prompt,
254
- padding="max_length",
255
- max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
256
- truncation=True,
257
- add_special_tokens=True,
258
- return_tensors="pt",
259
- )
260
- text_input_ids = text_inputs.input_ids
261
- attention_mask = text_inputs.attention_mask
262
- untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
263
-
264
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
- removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, self.text_encoder_4.model_max_length - 1 : -1])
266
- logger.warning(
267
- "The following part of your input was truncated because `max_sequence_length` is set to "
268
- f" {self.text_encoder_4.model_max_length} tokens: {removed_text}"
269
- )
270
-
271
- outputs = self.text_encoder_4(
272
- text_input_ids.to(device),
273
- attention_mask=attention_mask.to(device),
274
- output_hidden_states=True,
275
- output_attentions=True
276
- )
277
-
278
- prompt_embeds = outputs.hidden_states[1:]
279
- prompt_embeds = torch.stack(prompt_embeds, dim=0)
280
- _, _, seq_len, dim = prompt_embeds.shape
281
-
282
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
283
- prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
284
- prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
285
- return prompt_embeds
286
-
287
- def encode_prompt(
288
- self,
289
- prompt: Union[str, List[str]],
290
- prompt_2: Union[str, List[str]],
291
- prompt_3: Union[str, List[str]],
292
- prompt_4: Union[str, List[str]],
293
- device: Optional[torch.device] = None,
294
- dtype: Optional[torch.dtype] = None,
295
- num_images_per_prompt: int = 1,
296
- do_classifier_free_guidance: bool = True,
297
- negative_prompt: Optional[Union[str, List[str]]] = None,
298
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
299
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
300
- negative_prompt_4: Optional[Union[str, List[str]]] = None,
301
- prompt_embeds: Optional[List[torch.FloatTensor]] = None,
302
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
303
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
304
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
- max_sequence_length: int = 128,
306
- lora_scale: Optional[float] = None,
307
- ):
308
- prompt = [prompt] if isinstance(prompt, str) else prompt
309
- if prompt is not None:
310
- batch_size = len(prompt)
311
- else:
312
- batch_size = prompt_embeds.shape[0]
313
-
314
- prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
315
- prompt = prompt,
316
- prompt_2 = prompt_2,
317
- prompt_3 = prompt_3,
318
- prompt_4 = prompt_4,
319
- device = device,
320
- dtype = dtype,
321
- num_images_per_prompt = num_images_per_prompt,
322
- prompt_embeds = prompt_embeds,
323
- pooled_prompt_embeds = pooled_prompt_embeds,
324
- max_sequence_length = max_sequence_length,
325
- )
326
-
327
- if do_classifier_free_guidance and negative_prompt_embeds is None:
328
- negative_prompt = negative_prompt or ""
329
- negative_prompt_2 = negative_prompt_2 or negative_prompt
330
- negative_prompt_3 = negative_prompt_3 or negative_prompt
331
- negative_prompt_4 = negative_prompt_4 or negative_prompt
332
-
333
- # normalize str to list
334
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
335
- negative_prompt_2 = (
336
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
337
- )
338
- negative_prompt_3 = (
339
- batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
340
- )
341
- negative_prompt_4 = (
342
- batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
343
- )
344
-
345
- if prompt is not None and type(prompt) is not type(negative_prompt):
346
- raise TypeError(
347
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
348
- f" {type(prompt)}."
349
- )
350
- elif batch_size != len(negative_prompt):
351
- raise ValueError(
352
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
353
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
354
- " the batch size of `prompt`."
355
- )
356
-
357
- negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
358
- prompt = negative_prompt,
359
- prompt_2 = negative_prompt_2,
360
- prompt_3 = negative_prompt_3,
361
- prompt_4 = negative_prompt_4,
362
- device = device,
363
- dtype = dtype,
364
- num_images_per_prompt = num_images_per_prompt,
365
- prompt_embeds = negative_prompt_embeds,
366
- pooled_prompt_embeds = negative_pooled_prompt_embeds,
367
- max_sequence_length = max_sequence_length,
368
- )
369
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
370
-
371
- def _encode_prompt(
372
- self,
373
- prompt: Union[str, List[str]],
374
- prompt_2: Union[str, List[str]],
375
- prompt_3: Union[str, List[str]],
376
- prompt_4: Union[str, List[str]],
377
- device: Optional[torch.device] = None,
378
- dtype: Optional[torch.dtype] = None,
379
- num_images_per_prompt: int = 1,
380
- prompt_embeds: Optional[List[torch.FloatTensor]] = None,
381
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
382
- max_sequence_length: int = 128,
383
- ):
384
- device = device or self._execution_device
385
-
386
- if prompt_embeds is None:
387
- prompt_2 = prompt_2 or prompt
388
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
389
-
390
- prompt_3 = prompt_3 or prompt
391
- prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
392
-
393
- prompt_4 = prompt_4 or prompt
394
- prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
395
-
396
- pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
397
- self.tokenizer,
398
- self.text_encoder,
399
- prompt = prompt,
400
- num_images_per_prompt = num_images_per_prompt,
401
- max_sequence_length = max_sequence_length,
402
- device = device,
403
- dtype = dtype,
404
- )
405
-
406
- pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
407
- self.tokenizer_2,
408
- self.text_encoder_2,
409
- prompt = prompt_2,
410
- num_images_per_prompt = num_images_per_prompt,
411
- max_sequence_length = max_sequence_length,
412
- device = device,
413
- dtype = dtype,
414
- )
415
-
416
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
417
-
418
- t5_prompt_embeds = self._get_t5_prompt_embeds(
419
- prompt = prompt_3,
420
- num_images_per_prompt = num_images_per_prompt,
421
- max_sequence_length = max_sequence_length,
422
- device = device,
423
- dtype = dtype
424
- )
425
- llama3_prompt_embeds = self._get_llama3_prompt_embeds(
426
- prompt = prompt_4,
427
- num_images_per_prompt = num_images_per_prompt,
428
- max_sequence_length = max_sequence_length,
429
- device = device,
430
- dtype = dtype
431
- )
432
- prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
433
-
434
- return prompt_embeds, pooled_prompt_embeds
435
-
436
- def enable_vae_slicing(self):
437
- r"""
438
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
439
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
440
- """
441
- self.vae.enable_slicing()
442
-
443
- def disable_vae_slicing(self):
444
- r"""
445
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
446
- computing decoding in one step.
447
- """
448
- self.vae.disable_slicing()
449
-
450
- def enable_vae_tiling(self):
451
- r"""
452
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
453
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
454
- processing larger images.
455
- """
456
- self.vae.enable_tiling()
457
-
458
- def disable_vae_tiling(self):
459
- r"""
460
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
461
- computing decoding in one step.
462
- """
463
- self.vae.disable_tiling()
464
-
465
- def prepare_latents(
466
- self,
467
- batch_size,
468
- num_channels_latents,
469
- height,
470
- width,
471
- dtype,
472
- device,
473
- generator,
474
- latents=None,
475
- ):
476
- # VAE applies 8x compression on images but we must also account for packing which requires
477
- # latent height and width to be divisible by 2.
478
- height = 2 * (int(height) // (self.vae_scale_factor * 2))
479
- width = 2 * (int(width) // (self.vae_scale_factor * 2))
480
-
481
- shape = (batch_size, num_channels_latents, height, width)
482
-
483
- if latents is None:
484
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
485
- else:
486
- if latents.shape != shape:
487
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
488
- latents = latents.to(device)
489
- return latents
490
-
491
- @property
492
- def guidance_scale(self):
493
- return self._guidance_scale
494
-
495
- @property
496
- def do_classifier_free_guidance(self):
497
- return self._guidance_scale > 1
498
-
499
- @property
500
- def joint_attention_kwargs(self):
501
- return self._joint_attention_kwargs
502
-
503
- @property
504
- def num_timesteps(self):
505
- return self._num_timesteps
506
-
507
- @property
508
- def interrupt(self):
509
- return self._interrupt
510
-
511
- @torch.no_grad()
512
- def __call__(
513
- self,
514
- prompt: Union[str, List[str]] = None,
515
- prompt_2: Optional[Union[str, List[str]]] = None,
516
- prompt_3: Optional[Union[str, List[str]]] = None,
517
- prompt_4: Optional[Union[str, List[str]]] = None,
518
- height: Optional[int] = None,
519
- width: Optional[int] = None,
520
- num_inference_steps: int = 50,
521
- sigmas: Optional[List[float]] = None,
522
- guidance_scale: float = 5.0,
523
- negative_prompt: Optional[Union[str, List[str]]] = None,
524
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
525
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
526
- negative_prompt_4: Optional[Union[str, List[str]]] = None,
527
- num_images_per_prompt: Optional[int] = 1,
528
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
529
- latents: Optional[torch.FloatTensor] = None,
530
- prompt_embeds: Optional[torch.FloatTensor] = None,
531
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
532
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
533
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
534
- output_type: Optional[str] = "pil",
535
- return_dict: bool = True,
536
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
537
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
538
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
539
- max_sequence_length: int = 128,
540
- ):
541
- height = height or self.default_sample_size * self.vae_scale_factor
542
- width = width or self.default_sample_size * self.vae_scale_factor
543
-
544
- division = self.vae_scale_factor * 2
545
- S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
546
- scale = S_max / (width * height)
547
- scale = math.sqrt(scale)
548
- width, height = int(width * scale // division * division), int(height * scale // division * division)
549
-
550
- self._guidance_scale = guidance_scale
551
- self._joint_attention_kwargs = joint_attention_kwargs
552
- self._interrupt = False
553
-
554
- # 2. Define call parameters
555
- if prompt is not None and isinstance(prompt, str):
556
- batch_size = 1
557
- elif prompt is not None and isinstance(prompt, list):
558
- batch_size = len(prompt)
559
- else:
560
- batch_size = prompt_embeds.shape[0]
561
-
562
- device = self._execution_device
563
-
564
- lora_scale = (
565
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
566
- )
567
- (
568
- prompt_embeds,
569
- negative_prompt_embeds,
570
- pooled_prompt_embeds,
571
- negative_pooled_prompt_embeds,
572
- ) = self.encode_prompt(
573
- prompt=prompt,
574
- prompt_2=prompt_2,
575
- prompt_3=prompt_3,
576
- prompt_4=prompt_4,
577
- negative_prompt=negative_prompt,
578
- negative_prompt_2=negative_prompt_2,
579
- negative_prompt_3=negative_prompt_3,
580
- negative_prompt_4=negative_prompt_4,
581
- do_classifier_free_guidance=self.do_classifier_free_guidance,
582
- prompt_embeds=prompt_embeds,
583
- negative_prompt_embeds=negative_prompt_embeds,
584
- pooled_prompt_embeds=pooled_prompt_embeds,
585
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
586
- device=device,
587
- num_images_per_prompt=num_images_per_prompt,
588
- max_sequence_length=max_sequence_length,
589
- lora_scale=lora_scale,
590
- )
591
-
592
- if self.do_classifier_free_guidance:
593
- prompt_embeds_arr = []
594
- for n, p in zip(negative_prompt_embeds, prompt_embeds):
595
- if len(n.shape) == 3:
596
- prompt_embeds_arr.append(torch.cat([n, p], dim=0))
597
- else:
598
- prompt_embeds_arr.append(torch.cat([n, p], dim=1))
599
- prompt_embeds = prompt_embeds_arr
600
- pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
601
-
602
- # 4. Prepare latent variables
603
- num_channels_latents = self.transformer.config.in_channels
604
- latents = self.prepare_latents(
605
- batch_size * num_images_per_prompt,
606
- num_channels_latents,
607
- height,
608
- width,
609
- pooled_prompt_embeds.dtype,
610
- device,
611
- generator,
612
- latents,
613
- )
614
-
615
- if latents.shape[-2] != latents.shape[-1]:
616
- B, C, H, W = latents.shape
617
- pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
618
-
619
- img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
620
- img_ids = torch.zeros(pH, pW, 3)
621
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
622
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
623
- img_ids = img_ids.reshape(pH * pW, -1)
624
- img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
625
- img_ids_pad[:pH*pW, :] = img_ids
626
-
627
- img_sizes = img_sizes.unsqueeze(0).to(latents.device)
628
- img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
629
- if self.do_classifier_free_guidance:
630
- img_sizes = img_sizes.repeat(2 * B, 1)
631
- img_ids = img_ids.repeat(2 * B, 1, 1)
632
- else:
633
- img_sizes = img_ids = None
634
-
635
- # 5. Prepare timesteps
636
- mu = calculate_shift(self.transformer.max_seq)
637
- scheduler_kwargs = {"mu": mu}
638
- if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
639
- self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu))
640
- timesteps = self.scheduler.timesteps
641
- else:
642
- timesteps, num_inference_steps = retrieve_timesteps(
643
- self.scheduler,
644
- num_inference_steps,
645
- device,
646
- sigmas=sigmas,
647
- **scheduler_kwargs,
648
- )
649
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
650
- self._num_timesteps = len(timesteps)
651
-
652
- # 6. Denoising loop
653
- with self.progress_bar(total=num_inference_steps) as progress_bar:
654
- for i, t in enumerate(timesteps):
655
- if self.interrupt:
656
- continue
657
-
658
- # expand the latents if we are doing classifier free guidance
659
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
660
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
661
- timestep = t.expand(latent_model_input.shape[0])
662
-
663
- if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
664
- B, C, H, W = latent_model_input.shape
665
- patch_size = self.transformer.config.patch_size
666
- pH, pW = H // patch_size, W // patch_size
667
- out = torch.zeros(
668
- (B, C, self.transformer.max_seq, patch_size * patch_size),
669
- dtype=latent_model_input.dtype,
670
- device=latent_model_input.device
671
- )
672
- latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
673
- out[:, :, 0:pH*pW] = latent_model_input
674
- latent_model_input = out
675
-
676
- noise_pred = self.transformer(
677
- hidden_states = latent_model_input,
678
- timesteps = timestep,
679
- encoder_hidden_states = prompt_embeds,
680
- pooled_embeds = pooled_prompt_embeds,
681
- img_sizes = img_sizes,
682
- img_ids = img_ids,
683
- return_dict = False,
684
- )[0]
685
- noise_pred = -noise_pred
686
-
687
- # perform guidance
688
- if self.do_classifier_free_guidance:
689
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
690
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
691
-
692
- # compute the previous noisy sample x_t -> x_t-1
693
- latents_dtype = latents.dtype
694
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
695
-
696
- if latents.dtype != latents_dtype:
697
- if torch.backends.mps.is_available():
698
- # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
699
- latents = latents.to(latents_dtype)
700
-
701
- if callback_on_step_end is not None:
702
- callback_kwargs = {}
703
- for k in callback_on_step_end_tensor_inputs:
704
- callback_kwargs[k] = locals()[k]
705
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
706
-
707
- latents = callback_outputs.pop("latents", latents)
708
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
709
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
710
-
711
- # call the callback, if provided
712
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
713
- progress_bar.update()
714
-
715
- if XLA_AVAILABLE:
716
- xm.mark_step()
717
-
718
- if output_type == "latent":
719
- image = latents
720
-
721
- else:
722
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
723
-
724
- image = self.vae.decode(latents, return_dict=False)[0]
725
- image = self.image_processor.postprocess(image, output_type=output_type)
726
-
727
- # Offload all models
728
- self.maybe_free_model_hooks()
729
-
730
- if not return_dict:
731
- return (image,)
732
-
733
- return HiDreamImagePipelineOutput(images=image)
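For reference, a minimal usage sketch of the pipeline removed in this commit. Assumptions: the package is checked out from a revision prior to this deletion, the pipeline exposes the usual diffusers-style `from_pretrained`, and the checkpoint id below is a placeholder, not a real repository id. The keyword arguments mirror the `__call__` signature shown above.

import torch
from hi_diffusers import HiDreamImagePipeline  # import path from the deleted __init__.py

# Placeholder checkpoint id -- substitute a real HiDream checkpoint or local path.
pipe = HiDreamImagePipeline.from_pretrained("<repo-or-local-path>", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Keyword arguments correspond to parameters of __call__ in the deleted pipeline file.
result = pipe(
    prompt="a cat wearing a spacesuit",
    num_inference_steps=50,
    guidance_scale=5.0,
    num_images_per_prompt=1,
    generator=torch.Generator("cuda").manual_seed(0),
)
result.images[0].save("sample.png")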
 
hi_diffusers/pipelines/hidream_image/pipeline_output.py DELETED
@@ -1,21 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import List, Union
3
-
4
- import numpy as np
5
- import PIL.Image
6
-
7
- from diffusers.utils import BaseOutput
8
-
9
-
10
- @dataclass
11
- class HiDreamImagePipelineOutput(BaseOutput):
12
- """
13
- Output class for HiDreamImage pipelines.
14
-
15
- Args:
16
- images (`List[PIL.Image.Image]` or `np.ndarray`)
17
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
- num_channels)`. PIL images or numpy array representing the denoised images of the diffusion pipeline.
19
- """
20
-
21
- images: Union[List[PIL.Image.Image], np.ndarray]
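A brief sketch of how this output class behaves, assuming the module path prior to this deletion; since it inherits diffusers' `BaseOutput`, the field is reachable by attribute, by key, or by index.

import PIL.Image
from hi_diffusers.pipelines.hidream_image.pipeline_output import HiDreamImagePipelineOutput

img = PIL.Image.new("RGB", (8, 8))             # dummy image for illustration
out = HiDreamImagePipelineOutput(images=[img])

# BaseOutput makes the same field reachable three equivalent ways.
assert out.images == out["images"] == out[0]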
 
hi_diffusers/schedulers/__pycache__/flash_flow_match.cpython-310.pyc DELETED
Binary file (12.9 kB)
 
hi_diffusers/schedulers/__pycache__/fm_solvers_unipc.cpython-310.pyc DELETED
Binary file (22.2 kB)
 
hi_diffusers/schedulers/flash_flow_match.py DELETED
@@ -1,428 +0,0 @@
1
- # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
- from dataclasses import dataclass
17
- from typing import List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
- from diffusers.utils import BaseOutput, is_scipy_available, logging
24
- from diffusers.utils.torch_utils import randn_tensor
25
-
26
- if is_scipy_available():
27
- import scipy.stats
28
-
29
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
-
31
-
32
- @dataclass
33
- class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
34
- """
35
- Output class for the scheduler's `step` function output.
36
-
37
- Args:
38
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
- denoising loop.
41
- """
42
-
43
- prev_sample: torch.FloatTensor
44
-
45
-
46
- class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
47
- """
48
- Euler scheduler.
49
-
50
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
51
- methods the library implements for all schedulers such as loading and saving.
52
-
53
- Args:
54
- num_train_timesteps (`int`, defaults to 1000):
55
- The number of diffusion steps to train the model.
56
- timestep_spacing (`str`, defaults to `"linspace"`):
57
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
58
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
59
- shift (`float`, defaults to 1.0):
60
- The shift value for the timestep schedule.
61
- """
62
-
63
- _compatibles = []
64
- order = 1
65
-
66
- @register_to_config
67
- def __init__(
68
- self,
69
- num_train_timesteps: int = 1000,
70
- shift: float = 1.0,
71
- use_dynamic_shifting=False,
72
- base_shift: Optional[float] = 0.5,
73
- max_shift: Optional[float] = 1.15,
74
- base_image_seq_len: Optional[int] = 256,
75
- max_image_seq_len: Optional[int] = 4096,
76
- invert_sigmas: bool = False,
77
- use_karras_sigmas: Optional[bool] = False,
78
- use_exponential_sigmas: Optional[bool] = False,
79
- use_beta_sigmas: Optional[bool] = False,
80
- ):
81
- if self.config.use_beta_sigmas and not is_scipy_available():
82
- raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
83
- if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
84
- raise ValueError(
85
- "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
86
- )
87
- timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
88
- timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
89
-
90
- sigmas = timesteps / num_train_timesteps
91
- if not use_dynamic_shifting:
92
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
93
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
94
-
95
- self.timesteps = sigmas * num_train_timesteps
96
-
97
- self._step_index = None
98
- self._begin_index = None
99
-
100
- self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
101
- self.sigma_min = self.sigmas[-1].item()
102
- self.sigma_max = self.sigmas[0].item()
103
-
104
- @property
105
- def step_index(self):
106
- """
107
- The index counter for current timestep. It will increase 1 after each scheduler step.
108
- """
109
- return self._step_index
110
-
111
- @property
112
- def begin_index(self):
113
- """
114
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
115
- """
116
- return self._begin_index
117
-
118
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
119
- def set_begin_index(self, begin_index: int = 0):
120
- """
121
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
122
-
123
- Args:
124
- begin_index (`int`):
125
- The begin index for the scheduler.
126
- """
127
- self._begin_index = begin_index
128
-
129
- def scale_noise(
130
- self,
131
- sample: torch.FloatTensor,
132
- timestep: Union[float, torch.FloatTensor],
133
- noise: Optional[torch.FloatTensor] = None,
134
- ) -> torch.FloatTensor:
135
- """
136
- Forward process in flow-matching
137
-
138
- Args:
139
- sample (`torch.FloatTensor`):
140
- The input sample.
141
- timestep (`int`, *optional*):
142
- The current timestep in the diffusion chain.
143
-
144
- Returns:
145
- `torch.FloatTensor`:
146
- A scaled input sample.
147
- """
148
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
149
- sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
150
-
151
- if sample.device.type == "mps" and torch.is_floating_point(timestep):
152
- # mps does not support float64
153
- schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
154
- timestep = timestep.to(sample.device, dtype=torch.float32)
155
- else:
156
- schedule_timesteps = self.timesteps.to(sample.device)
157
- timestep = timestep.to(sample.device)
158
-
159
- # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
160
- if self.begin_index is None:
161
- step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
162
- elif self.step_index is not None:
163
- # add_noise is called after first denoising step (for inpainting)
164
- step_indices = [self.step_index] * timestep.shape[0]
165
- else:
166
- # add noise is called before first denoising step to create initial latent(img2img)
167
- step_indices = [self.begin_index] * timestep.shape[0]
168
-
169
- sigma = sigmas[step_indices].flatten()
170
- while len(sigma.shape) < len(sample.shape):
171
- sigma = sigma.unsqueeze(-1)
172
-
173
- sample = sigma * noise + (1.0 - sigma) * sample
174
-
175
- return sample
176
-
177
- def _sigma_to_t(self, sigma):
178
- return sigma * self.config.num_train_timesteps
179
-
180
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
181
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
182
-
183
- def set_timesteps(
184
- self,
185
- num_inference_steps: int = None,
186
- device: Union[str, torch.device] = None,
187
- sigmas: Optional[List[float]] = None,
188
- mu: Optional[float] = None,
189
- ):
190
- """
191
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
192
-
193
- Args:
194
- num_inference_steps (`int`):
195
- The number of diffusion steps used when generating samples with a pre-trained model.
196
- device (`str` or `torch.device`, *optional*):
197
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
198
- """
199
- if self.config.use_dynamic_shifting and mu is None:
200
- raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")
201
-
202
- if sigmas is None:
203
- timesteps = np.linspace(
204
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
205
- )
206
-
207
- sigmas = timesteps / self.config.num_train_timesteps
208
- else:
209
- sigmas = np.array(sigmas).astype(np.float32)
210
- num_inference_steps = len(sigmas)
211
- self.num_inference_steps = num_inference_steps
212
-
213
- if self.config.use_dynamic_shifting:
214
- sigmas = self.time_shift(mu, 1.0, sigmas)
215
- else:
216
- sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
217
-
218
- if self.config.use_karras_sigmas:
219
- sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
220
-
221
- elif self.config.use_exponential_sigmas:
222
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
223
-
224
- elif self.config.use_beta_sigmas:
225
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
226
-
227
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
228
- timesteps = sigmas * self.config.num_train_timesteps
229
-
230
- if self.config.invert_sigmas:
231
- sigmas = 1.0 - sigmas
232
- timesteps = sigmas * self.config.num_train_timesteps
233
- sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
234
- else:
235
- sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
236
-
237
- self.timesteps = timesteps.to(device=device)
238
- self.sigmas = sigmas
239
- self._step_index = None
240
- self._begin_index = None
241
-
242
- def index_for_timestep(self, timestep, schedule_timesteps=None):
243
- if schedule_timesteps is None:
244
- schedule_timesteps = self.timesteps
245
-
246
- indices = (schedule_timesteps == timestep).nonzero()
247
-
248
- # The sigma index that is taken for the **very** first `step`
249
- # is always the second index (or the last index if there is only 1)
250
- # This way we can ensure we don't accidentally skip a sigma in
251
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
252
- pos = 1 if len(indices) > 1 else 0
253
-
254
- return indices[pos].item()
255
-
256
- def _init_step_index(self, timestep):
257
- if self.begin_index is None:
258
- if isinstance(timestep, torch.Tensor):
259
- timestep = timestep.to(self.timesteps.device)
260
- self._step_index = self.index_for_timestep(timestep)
261
- else:
262
- self._step_index = self._begin_index
263
-
264
- def step(
265
- self,
266
- model_output: torch.FloatTensor,
267
- timestep: Union[float, torch.FloatTensor],
268
- sample: torch.FloatTensor,
269
- s_churn: float = 0.0,
270
- s_tmin: float = 0.0,
271
- s_tmax: float = float("inf"),
272
- s_noise: float = 1.0,
273
- generator: Optional[torch.Generator] = None,
274
- return_dict: bool = True,
275
- ) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
276
- """
277
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
278
- process from the learned model outputs (most often the predicted noise).
279
-
280
- Args:
281
- model_output (`torch.FloatTensor`):
282
- The direct output from learned diffusion model.
283
- timestep (`float`):
284
- The current discrete timestep in the diffusion chain.
285
- sample (`torch.FloatTensor`):
286
- A current instance of a sample created by the diffusion process.
287
- s_churn (`float`):
288
- s_tmin (`float`):
289
- s_tmax (`float`):
290
- s_noise (`float`, defaults to 1.0):
291
- Scaling factor for noise added to the sample.
292
- generator (`torch.Generator`, *optional*):
293
- A random number generator.
294
- return_dict (`bool`):
295
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
296
- tuple.
297
-
298
- Returns:
299
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
300
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
301
- returned, otherwise a tuple is returned where the first element is the sample tensor.
302
- """
303
-
304
- if (
305
- isinstance(timestep, int)
306
- or isinstance(timestep, torch.IntTensor)
307
- or isinstance(timestep, torch.LongTensor)
308
- ):
309
- raise ValueError(
310
- (
311
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
312
- " `FlashFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
313
- " one of the `scheduler.timesteps` as a timestep."
314
- ),
315
- )
316
-
317
- if self.step_index is None:
318
- self._init_step_index(timestep)
319
-
320
- # Upcast to avoid precision issues when computing prev_sample
321
-
322
- sigma = self.sigmas[self.step_index]
323
-
324
- # Upcast to avoid precision issues when computing prev_sample
325
- sample = sample.to(torch.float32)
326
-
327
- denoised = sample - model_output * sigma
328
-
329
- if self.step_index < self.num_inference_steps - 1:
330
- sigma_next = self.sigmas[self.step_index + 1]
331
- noise = randn_tensor(
332
- model_output.shape,
333
- generator=generator,
334
- device=model_output.device,
335
- dtype=denoised.dtype,
336
- )
337
- sample = sigma_next * noise + (1.0 - sigma_next) * denoised
338
-
339
- self._step_index += 1
340
- sample = sample.to(model_output.dtype)
341
-
342
- if not return_dict:
343
- return (sample,)
344
-
345
- return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample)
346
-
347
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
348
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
349
- """Constructs the noise schedule of Karras et al. (2022)."""
350
-
351
- # Hack to make sure that other schedulers which copy this function don't break
352
- # TODO: Add this logic to the other schedulers
353
- if hasattr(self.config, "sigma_min"):
354
- sigma_min = self.config.sigma_min
355
- else:
356
- sigma_min = None
357
-
358
- if hasattr(self.config, "sigma_max"):
359
- sigma_max = self.config.sigma_max
360
- else:
361
- sigma_max = None
362
-
363
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
364
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
365
-
366
- rho = 7.0 # 7.0 is the value used in the paper
367
- ramp = np.linspace(0, 1, num_inference_steps)
368
- min_inv_rho = sigma_min ** (1 / rho)
369
- max_inv_rho = sigma_max ** (1 / rho)
370
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
371
- return sigmas
372
-
373
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
374
- def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
375
- """Constructs an exponential noise schedule."""
376
-
377
- # Hack to make sure that other schedulers which copy this function don't break
378
- # TODO: Add this logic to the other schedulers
379
- if hasattr(self.config, "sigma_min"):
380
- sigma_min = self.config.sigma_min
381
- else:
382
- sigma_min = None
383
-
384
- if hasattr(self.config, "sigma_max"):
385
- sigma_max = self.config.sigma_max
386
- else:
387
- sigma_max = None
388
-
389
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
390
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
391
-
392
- sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
393
- return sigmas
394
-
395
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
396
- def _convert_to_beta(
397
- self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
398
- ) -> torch.Tensor:
399
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
400
-
401
- # Hack to make sure that other schedulers which copy this function don't break
402
- # TODO: Add this logic to the other schedulers
403
- if hasattr(self.config, "sigma_min"):
404
- sigma_min = self.config.sigma_min
405
- else:
406
- sigma_min = None
407
-
408
- if hasattr(self.config, "sigma_max"):
409
- sigma_max = self.config.sigma_max
410
- else:
411
- sigma_max = None
412
-
413
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
414
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
415
-
416
- sigmas = np.array(
417
- [
418
- sigma_min + (ppf * (sigma_max - sigma_min))
419
- for ppf in [
420
- scipy.stats.beta.ppf(timestep, alpha, beta)
421
- for timestep in 1 - np.linspace(0, 1, num_inference_steps)
422
- ]
423
- ]
424
- )
425
- return sigmas
426
-
427
- def __len__(self):
428
- return self.config.num_train_timesteps
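As a reading aid, the core update performed by `step()` above for an intermediate timestep can be sketched in isolation (illustrative only; argument validation, step-index bookkeeping, and the final-step branch of the deleted class are omitted):

import torch

def flash_flow_match_update(sample, model_output, sigma, sigma_next, generator=None):
    # Upcast to float32, as the deleted scheduler does, to avoid precision issues.
    sample = sample.float()
    # Flow-matching estimate of the clean sample: x0 = x - sigma * v.
    denoised = sample - model_output.float() * sigma
    # Re-noise the estimate to the next noise level with fresh Gaussian noise.
    noise = torch.randn(sample.shape, generator=generator,
                        dtype=denoised.dtype, device=sample.device)
    prev_sample = sigma_next * noise + (1.0 - sigma_next) * denoised
    return prev_sample.to(model_output.dtype)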
 
hi_diffusers/schedulers/fm_solvers_unipc.py DELETED
@@ -1,800 +0,0 @@
1
- # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
- # Convert unipc for flow matching
3
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
-
5
- import math
6
- from typing import List, Optional, Tuple, Union
7
-
8
- import numpy as np
9
- import torch
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
12
- SchedulerMixin,
13
- SchedulerOutput)
14
- from diffusers.utils import deprecate, is_scipy_available
15
-
16
- if is_scipy_available():
17
- import scipy.stats
18
-
19
-
20
- class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
21
- """
22
- `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
23
-
24
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
25
- methods the library implements for all schedulers such as loading and saving.
26
-
27
- Args:
28
- num_train_timesteps (`int`, defaults to 1000):
29
- The number of diffusion steps to train the model.
30
- solver_order (`int`, default `2`):
31
- The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
32
- due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
33
- unconditional sampling.
34
- prediction_type (`str`, defaults to "flow_prediction"):
35
- Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
36
- the flow of the diffusion process.
37
- thresholding (`bool`, defaults to `False`):
38
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
39
- as Stable Diffusion.
40
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
41
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
42
- sample_max_value (`float`, defaults to 1.0):
43
- The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
44
- predict_x0 (`bool`, defaults to `True`):
45
- Whether to use the updating algorithm on the predicted x0.
46
- solver_type (`str`, default `bh2`):
47
- Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
48
- otherwise.
49
- lower_order_final (`bool`, default `True`):
50
- Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
51
- stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
52
- disable_corrector (`list`, default `[]`):
53
- Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
54
- and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
55
- usually disabled during the first few steps.
56
- solver_p (`SchedulerMixin`, default `None`):
57
- Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
58
- use_karras_sigmas (`bool`, *optional*, defaults to `False`):
59
- Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
60
- the sigmas are determined according to a sequence of noise levels {σi}.
61
- use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
62
- Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
63
- timestep_spacing (`str`, defaults to `"linspace"`):
64
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
65
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
66
- steps_offset (`int`, defaults to 0):
67
- An offset added to the inference steps, as required by some model families.
68
- final_sigmas_type (`str`, defaults to `"zero"`):
69
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
70
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71
- """
72
-
73
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
74
- order = 1
75
-
76
- @register_to_config
77
- def __init__(
78
- self,
79
- num_train_timesteps: int = 1000,
80
- solver_order: int = 2,
81
- prediction_type: str = "flow_prediction",
82
- shift: Optional[float] = 1.0,
83
- use_dynamic_shifting=False,
84
- thresholding: bool = False,
85
- dynamic_thresholding_ratio: float = 0.995,
86
- sample_max_value: float = 1.0,
87
- predict_x0: bool = True,
88
- solver_type: str = "bh2",
89
- lower_order_final: bool = True,
90
- disable_corrector: List[int] = [],
91
- solver_p: SchedulerMixin = None,
92
- timestep_spacing: str = "linspace",
93
- steps_offset: int = 0,
94
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
95
- ):
96
-
97
- if solver_type not in ["bh1", "bh2"]:
98
- if solver_type in ["midpoint", "heun", "logrho"]:
99
- self.register_to_config(solver_type="bh2")
100
- else:
101
- raise NotImplementedError(
102
- f"{solver_type} is not implemented for {self.__class__}")
103
-
104
- self.predict_x0 = predict_x0
105
- # setable values
106
- self.num_inference_steps = None
107
- alphas = np.linspace(1, 1 / num_train_timesteps,
108
- num_train_timesteps)[::-1].copy()
109
- sigmas = 1.0 - alphas
110
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
111
-
112
- if not use_dynamic_shifting:
113
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
114
- sigmas = shift * sigmas / (1 +
115
- (shift - 1) * sigmas) # pyright: ignore
116
-
117
- self.sigmas = sigmas
118
- self.timesteps = sigmas * num_train_timesteps
119
-
120
- self.model_outputs = [None] * solver_order
121
- self.timestep_list = [None] * solver_order
122
- self.lower_order_nums = 0
123
- self.disable_corrector = disable_corrector
124
- self.solver_p = solver_p
125
- self.last_sample = None
126
- self._step_index = None
127
- self._begin_index = None
128
-
129
- self.sigmas = self.sigmas.to(
130
- "cpu") # to avoid too much CPU/GPU communication
131
- self.sigma_min = self.sigmas[-1].item()
132
- self.sigma_max = self.sigmas[0].item()
133
-
134
- @property
135
- def step_index(self):
136
- """
137
- The index counter for current timestep. It will increase 1 after each scheduler step.
138
- """
139
- return self._step_index
140
-
141
- @property
142
- def begin_index(self):
143
- """
144
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
145
- """
146
- return self._begin_index
147
-
148
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
149
- def set_begin_index(self, begin_index: int = 0):
150
- """
151
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
152
-
153
- Args:
154
- begin_index (`int`):
155
- The begin index for the scheduler.
156
- """
157
- self._begin_index = begin_index
158
-
159
- # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
160
- def set_timesteps(
161
- self,
162
- num_inference_steps: Union[int, None] = None,
163
- device: Union[str, torch.device] = None,
164
- sigmas: Optional[List[float]] = None,
165
- mu: Optional[Union[float, None]] = None,
166
- shift: Optional[Union[float, None]] = None,
167
- ):
168
- """
169
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
170
- Args:
171
- num_inference_steps (`int`):
172
- Total number of the spacing of the time steps.
173
- device (`str` or `torch.device`, *optional*):
174
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
175
- """
176
-
177
- if self.config.use_dynamic_shifting and mu is None:
178
- raise ValueError(
179
- " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
180
- )
181
-
182
- if sigmas is None:
183
- sigmas = np.linspace(self.sigma_max, self.sigma_min,
184
- num_inference_steps +
185
- 1).copy()[:-1] # pyright: ignore
186
-
187
- if self.config.use_dynamic_shifting:
188
- sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
189
- else:
190
- if shift is None:
191
- shift = self.config.shift
192
- sigmas = shift * sigmas / (1 +
193
- (shift - 1) * sigmas) # pyright: ignore
194
-
195
- if self.config.final_sigmas_type == "sigma_min":
196
- sigma_last = ((1 - self.alphas_cumprod[0]) /
197
- self.alphas_cumprod[0])**0.5
198
- elif self.config.final_sigmas_type == "zero":
199
- sigma_last = 0
200
- else:
201
- raise ValueError(
202
- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
203
- )
204
-
205
- timesteps = sigmas * self.config.num_train_timesteps
206
- sigmas = np.concatenate([sigmas, [sigma_last]
207
- ]).astype(np.float32) # pyright: ignore
208
-
209
- self.sigmas = torch.from_numpy(sigmas)
210
- self.timesteps = torch.from_numpy(timesteps).to(
211
- device=device, dtype=torch.int64)
212
-
213
- self.num_inference_steps = len(timesteps)
214
-
215
- self.model_outputs = [
216
- None,
217
- ] * self.config.solver_order
218
- self.lower_order_nums = 0
219
- self.last_sample = None
220
- if self.solver_p:
221
- self.solver_p.set_timesteps(self.num_inference_steps, device=device)
222
-
223
- # add an index counter for schedulers that allow duplicated timesteps
224
- self._step_index = None
225
- self._begin_index = None
226
- self.sigmas = self.sigmas.to(
227
- "cpu") # to avoid too much CPU/GPU communication
228
-
229
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
230
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
231
- """
232
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
233
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
234
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
235
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
236
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
237
-
238
- https://arxiv.org/abs/2205.11487
239
- """
240
- dtype = sample.dtype
241
- batch_size, channels, *remaining_dims = sample.shape
242
-
243
- if dtype not in (torch.float32, torch.float64):
244
- sample = sample.float(
245
- ) # upcast for quantile calculation, and clamp not implemented for cpu half
246
-
247
- # Flatten sample for doing quantile calculation along each image
248
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
249
-
250
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
251
-
252
- s = torch.quantile(
253
- abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
254
- s = torch.clamp(
255
- s, min=1, max=self.config.sample_max_value
256
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
257
- s = s.unsqueeze(
258
- 1) # (batch_size, 1) because clamp will broadcast along dim=0
259
- sample = torch.clamp(
260
- sample, -s, s
261
- ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
262
-
263
- sample = sample.reshape(batch_size, channels, *remaining_dims)
264
- sample = sample.to(dtype)
265
-
266
- return sample
267
-
268
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
269
- def _sigma_to_t(self, sigma):
270
- return sigma * self.config.num_train_timesteps
271
-
272
- def _sigma_to_alpha_sigma_t(self, sigma):
273
- return 1 - sigma, sigma
274
-
275
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
276
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
277
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
278
-
279
- def convert_model_output(
280
- self,
281
- model_output: torch.Tensor,
282
- *args,
283
- sample: torch.Tensor = None,
284
- **kwargs,
285
- ) -> torch.Tensor:
286
- r"""
287
- Convert the model output to the corresponding type the UniPC algorithm needs.
288
-
289
- Args:
290
- model_output (`torch.Tensor`):
291
- The direct output from the learned diffusion model.
292
- timestep (`int`):
293
- The current discrete timestep in the diffusion chain.
294
- sample (`torch.Tensor`):
295
- A current instance of a sample created by the diffusion process.
296
-
297
- Returns:
298
- `torch.Tensor`:
299
- The converted model output.
300
- """
301
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
302
- if sample is None:
303
- if len(args) > 1:
304
- sample = args[1]
305
- else:
306
- raise ValueError(
307
- "missing `sample` as a required keyward argument")
308
- if timestep is not None:
309
- deprecate(
310
- "timesteps",
311
- "1.0.0",
312
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
313
- )
314
-
315
- sigma = self.sigmas[self.step_index]
316
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
317
-
318
- if self.predict_x0:
319
- if self.config.prediction_type == "flow_prediction":
320
- sigma_t = self.sigmas[self.step_index]
321
- x0_pred = sample - sigma_t * model_output
322
- else:
323
- raise ValueError(
324
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
325
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
326
- )
327
-
328
- if self.config.thresholding:
329
- x0_pred = self._threshold_sample(x0_pred)
330
-
331
- return x0_pred
332
- else:
333
- if self.config.prediction_type == "flow_prediction":
334
- sigma_t = self.sigmas[self.step_index]
335
- epsilon = sample - (1 - sigma_t) * model_output
336
- else:
337
- raise ValueError(
338
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
339
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
340
- )
341
-
342
- if self.config.thresholding:
343
- sigma_t = self.sigmas[self.step_index]
344
- x0_pred = sample - sigma_t * model_output
345
- x0_pred = self._threshold_sample(x0_pred)
346
- epsilon = model_output + x0_pred
347
-
348
- return epsilon
349
-
350
- def multistep_uni_p_bh_update(
351
- self,
352
- model_output: torch.Tensor,
353
- *args,
354
- sample: torch.Tensor = None,
355
- order: int = None, # pyright: ignore
356
- **kwargs,
357
- ) -> torch.Tensor:
358
- """
359
- One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
360
-
361
- Args:
362
- model_output (`torch.Tensor`):
363
- The direct output from the learned diffusion model at the current timestep.
364
- prev_timestep (`int`):
365
- The previous discrete timestep in the diffusion chain.
366
- sample (`torch.Tensor`):
367
- A current instance of a sample created by the diffusion process.
368
- order (`int`):
369
- The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
370
-
371
- Returns:
372
- `torch.Tensor`:
373
- The sample tensor at the previous timestep.
374
- """
375
- prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
376
- "prev_timestep", None)
377
- if sample is None:
378
- if len(args) > 1:
379
- sample = args[1]
380
- else:
381
- raise ValueError(
382
- " missing `sample` as a required keyward argument")
383
- if order is None:
384
- if len(args) > 2:
385
- order = args[2]
386
- else:
387
- raise ValueError(
388
- " missing `order` as a required keyward argument")
389
- if prev_timestep is not None:
390
- deprecate(
391
- "prev_timestep",
392
- "1.0.0",
393
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
394
- )
395
- model_output_list = self.model_outputs
396
-
397
- s0 = self.timestep_list[-1]
398
- m0 = model_output_list[-1]
399
- x = sample
400
-
401
- if self.solver_p:
402
- x_t = self.solver_p.step(model_output, s0, x).prev_sample
403
- return x_t
404
-
405
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
406
- self.step_index] # pyright: ignore
407
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
408
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
409
-
410
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
411
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
412
-
413
- h = lambda_t - lambda_s0
414
- device = sample.device
415
-
416
- rks = []
417
- D1s = []
418
- for i in range(1, order):
419
- si = self.step_index - i # pyright: ignore
420
- mi = model_output_list[-(i + 1)]
421
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
422
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
423
- rk = (lambda_si - lambda_s0) / h
424
- rks.append(rk)
425
- D1s.append((mi - m0) / rk) # pyright: ignore
426
-
427
- rks.append(1.0)
428
- rks = torch.tensor(rks, device=device)
429
-
430
- R = []
431
- b = []
432
-
433
- hh = -h if self.predict_x0 else h
434
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
435
- h_phi_k = h_phi_1 / hh - 1
436
-
437
- factorial_i = 1
438
-
439
- if self.config.solver_type == "bh1":
440
- B_h = hh
441
- elif self.config.solver_type == "bh2":
442
- B_h = torch.expm1(hh)
443
- else:
444
- raise NotImplementedError()
445
-
446
- for i in range(1, order + 1):
447
- R.append(torch.pow(rks, i - 1))
448
- b.append(h_phi_k * factorial_i / B_h)
449
- factorial_i *= i + 1
450
- h_phi_k = h_phi_k / hh - 1 / factorial_i
451
-
452
- R = torch.stack(R)
453
- b = torch.tensor(b, device=device)
454
-
455
- if len(D1s) > 0:
456
- D1s = torch.stack(D1s, dim=1) # (B, K)
457
- # for order 2, we use a simplified version
458
- if order == 2:
459
- rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
460
- else:
461
- rhos_p = torch.linalg.solve(R[:-1, :-1],
462
- b[:-1]).to(device).to(x.dtype)
463
- else:
464
- D1s = None
465
-
466
- if self.predict_x0:
467
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
468
- if D1s is not None:
469
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
470
- D1s) # pyright: ignore
471
- else:
472
- pred_res = 0
473
- x_t = x_t_ - alpha_t * B_h * pred_res
474
- else:
475
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
476
- if D1s is not None:
477
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
478
- D1s) # pyright: ignore
479
- else:
480
- pred_res = 0
481
- x_t = x_t_ - sigma_t * B_h * pred_res
482
-
483
- x_t = x_t.to(x.dtype)
484
- return x_t
485
-
486
- def multistep_uni_c_bh_update(
487
- self,
488
- this_model_output: torch.Tensor,
489
- *args,
490
- last_sample: torch.Tensor = None,
491
- this_sample: torch.Tensor = None,
492
- order: int = None, # pyright: ignore
493
- **kwargs,
494
- ) -> torch.Tensor:
495
- """
496
- One step for the UniC (B(h) version).
497
-
498
- Args:
499
- this_model_output (`torch.Tensor`):
500
- The model outputs at `x_t`.
501
- this_timestep (`int`):
502
- The current timestep `t`.
503
- last_sample (`torch.Tensor`):
504
- The generated sample before the last predictor `x_{t-1}`.
505
- this_sample (`torch.Tensor`):
506
- The generated sample after the last predictor `x_{t}`.
507
- order (`int`):
508
- The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
509
-
510
- Returns:
511
- `torch.Tensor`:
512
- The corrected sample tensor at the current timestep.
513
- """
514
- this_timestep = args[0] if len(args) > 0 else kwargs.pop(
515
- "this_timestep", None)
516
- if last_sample is None:
517
- if len(args) > 1:
518
- last_sample = args[1]
519
- else:
520
- raise ValueError(
521
- " missing`last_sample` as a required keyward argument")
522
- if this_sample is None:
523
- if len(args) > 2:
524
- this_sample = args[2]
525
- else:
526
- raise ValueError(
527
- " missing`this_sample` as a required keyward argument")
528
- if order is None:
529
- if len(args) > 3:
530
- order = args[3]
531
- else:
532
- raise ValueError(
533
- " missing`order` as a required keyward argument")
534
- if this_timestep is not None:
535
- deprecate(
536
- "this_timestep",
537
- "1.0.0",
538
- "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
539
- )
540
-
541
- model_output_list = self.model_outputs
542
-
543
- m0 = model_output_list[-1]
544
- x = last_sample
545
- x_t = this_sample
546
- model_t = this_model_output
547
-
548
- sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
549
- self.step_index - 1] # pyright: ignore
550
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
551
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
552
-
553
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
554
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
555
-
556
- h = lambda_t - lambda_s0
557
- device = this_sample.device
558
-
559
- rks = []
560
- D1s = []
561
- for i in range(1, order):
562
- si = self.step_index - (i + 1) # pyright: ignore
563
- mi = model_output_list[-(i + 1)]
564
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
565
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
566
- rk = (lambda_si - lambda_s0) / h
567
- rks.append(rk)
568
- D1s.append((mi - m0) / rk) # pyright: ignore
569
-
570
- rks.append(1.0)
571
- rks = torch.tensor(rks, device=device)
572
-
573
- R = []
574
- b = []
575
-
576
- hh = -h if self.predict_x0 else h
577
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
578
- h_phi_k = h_phi_1 / hh - 1
579
-
580
- factorial_i = 1
581
-
582
- if self.config.solver_type == "bh1":
583
- B_h = hh
584
- elif self.config.solver_type == "bh2":
585
- B_h = torch.expm1(hh)
586
- else:
587
- raise NotImplementedError()
588
-
589
- for i in range(1, order + 1):
590
- R.append(torch.pow(rks, i - 1))
591
- b.append(h_phi_k * factorial_i / B_h)
592
- factorial_i *= i + 1
593
- h_phi_k = h_phi_k / hh - 1 / factorial_i
594
-
595
- R = torch.stack(R)
596
- b = torch.tensor(b, device=device)
597
-
598
- if len(D1s) > 0:
599
- D1s = torch.stack(D1s, dim=1)
600
- else:
601
- D1s = None
602
-
603
- # for order 1, we use a simplified version
604
- if order == 1:
605
- rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
606
- else:
607
- rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
608
-
609
- if self.predict_x0:
610
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
611
- if D1s is not None:
612
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
613
- else:
614
- corr_res = 0
615
- D1_t = model_t - m0
616
- x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
617
- else:
618
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
619
- if D1s is not None:
620
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
621
- else:
622
- corr_res = 0
623
- D1_t = model_t - m0
624
- x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
625
- x_t = x_t.to(x.dtype)
626
- return x_t
627
-
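The corrector above boils down to solving a small linear system R · rhos_c = b for the correction weights, then blending the stored finite differences D1s and the fresh D1_t with those weights. Below is a minimal standalone sketch of just that solve for an order-2 step; the ratio values in rks and the log-SNR step hh are made up rather than taken from any real sigma schedule, and the "bh1" solver type with x0-prediction is assumed.

import torch

# Rebuild R and b exactly as in the loop above, for order = 2 (toy values).
order = 2
rks = torch.tensor([0.8, 1.0])      # (lambda_si - lambda_s0) / h ratios; last entry is 1.0
hh = torch.tensor(-0.1)             # hh = -h when predicting x0
h_phi_1 = torch.expm1(hh)           # e^hh - 1
h_phi_k = h_phi_1 / hh - 1
B_h = hh                            # "bh1"; "bh2" would use torch.expm1(hh)

R, b, factorial_i = [], [], 1
for i in range(1, order + 1):
    R.append(torch.pow(rks, i - 1))
    b.append(h_phi_k * factorial_i / B_h)
    factorial_i *= i + 1
    h_phi_k = h_phi_k / hh - 1 / factorial_i

rhos_c = torch.linalg.solve(torch.stack(R), torch.stack(b))
print(rhos_c)                       # rhos_c[:-1] weights D1s, rhos_c[-1] weights D1_t

For order 1 the code skips the solve entirely and hard-codes rhos_c = [0.5], which agrees with the small-step limit of the order-1 system.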
628
- def index_for_timestep(self, timestep, schedule_timesteps=None):
629
- if schedule_timesteps is None:
630
- schedule_timesteps = self.timesteps
631
-
632
- indices = (schedule_timesteps == timestep).nonzero()
633
-
634
- # The sigma index that is taken for the **very** first `step`
635
- # is always the second index (or the last index if there is only 1)
636
- # This way we can ensure we don't accidentally skip a sigma in
637
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
638
- pos = 1 if len(indices) > 1 else 0
639
-
640
- return indices[pos].item()
641
-
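index_for_timestep performs the schedule lookup that the predictor/corrector bookkeeping relies on; the one subtlety is that it prefers the second match when a timestep occurs more than once, so starting in the middle of the schedule (e.g. image-to-image) does not skip a sigma. A small self-contained illustration with a made-up schedule:

import torch

# Toy schedule with a duplicated timestep; the values are illustrative only.
schedule_timesteps = torch.tensor([999, 750, 750, 500, 250, 0])

def index_for_timestep(timestep, schedule_timesteps):
    indices = (schedule_timesteps == timestep).nonzero()
    pos = 1 if len(indices) > 1 else 0   # prefer the second occurrence
    return indices[pos].item()

print(index_for_timestep(750, schedule_timesteps))  # 2 -> second of the duplicates
print(index_for_timestep(500, schedule_timesteps))  # 3 -> unique match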
642
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
643
- def _init_step_index(self, timestep):
644
- """
645
- Initialize the step_index counter for the scheduler.
646
- """
647
-
648
- if self.begin_index is None:
649
- if isinstance(timestep, torch.Tensor):
650
- timestep = timestep.to(self.timesteps.device)
651
- self._step_index = self.index_for_timestep(timestep)
652
- else:
653
- self._step_index = self._begin_index
654
-
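_init_step_index only runs on the first step() call: if the pipeline never pinned a starting position via begin_index, the index is resolved from the timestep itself; otherwise the pinned value wins. A toy sketch of that branch, with a made-up schedule (simplified to take the first match instead of index_for_timestep's duplicate handling):

import torch

timesteps = torch.tensor([999, 666, 333, 0])   # made-up schedule

def init_step_index(timestep, begin_index=None):
    if begin_index is None:
        # first denoising call: resolve the index from the timestep
        return (timesteps == timestep).nonzero()[0].item()
    # pipeline pinned a start (e.g. img2img with strength < 1)
    return begin_index

print(init_step_index(666))                 # 1, looked up from the schedule
print(init_step_index(666, begin_index=2))  # 2, begin_index takes precedence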
655
- def step(self,
656
- model_output: torch.Tensor,
657
- timestep: Union[int, torch.Tensor],
658
- sample: torch.Tensor,
659
- return_dict: bool = True,
660
- generator=None) -> Union[SchedulerOutput, Tuple]:
661
- """
662
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
663
- the multistep UniPC.
664
-
665
- Args:
666
- model_output (`torch.Tensor`):
667
- The direct output from learned diffusion model.
668
- timestep (`int`):
669
- The current discrete timestep in the diffusion chain.
670
- sample (`torch.Tensor`):
671
- A current instance of a sample created by the diffusion process.
672
- return_dict (`bool`):
673
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
674
-
675
- Returns:
676
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
677
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
678
- tuple is returned where the first element is the sample tensor.
679
-
680
- """
681
- if self.num_inference_steps is None:
682
- raise ValueError(
683
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
684
- )
685
-
686
- if self.step_index is None:
687
- self._init_step_index(timestep)
688
-
689
- use_corrector = (
690
- self.step_index > 0 and
691
- self.step_index - 1 not in self.disable_corrector and
692
- self.last_sample is not None # pyright: ignore
693
- )
694
-
695
- model_output_convert = self.convert_model_output(
696
- model_output, sample=sample)
697
- if use_corrector:
698
- sample = self.multistep_uni_c_bh_update(
699
- this_model_output=model_output_convert,
700
- last_sample=self.last_sample,
701
- this_sample=sample,
702
- order=self.this_order,
703
- )
704
-
705
- for i in range(self.config.solver_order - 1):
706
- self.model_outputs[i] = self.model_outputs[i + 1]
707
- self.timestep_list[i] = self.timestep_list[i + 1]
708
-
709
- self.model_outputs[-1] = model_output_convert
710
- self.timestep_list[-1] = timestep # pyright: ignore
711
-
712
- if self.config.lower_order_final:
713
- this_order = min(self.config.solver_order,
714
- len(self.timesteps) -
715
- self.step_index) # pyright: ignore
716
- else:
717
- this_order = self.config.solver_order
718
-
719
- self.this_order = min(this_order,
720
- self.lower_order_nums + 1) # warmup for multistep
721
- assert self.this_order > 0
722
-
723
- self.last_sample = sample
724
- prev_sample = self.multistep_uni_p_bh_update(
725
- model_output=model_output, # pass the original non-converted model output, in case solver-p is used
726
- sample=sample,
727
- order=self.this_order,
728
- )
729
-
730
- if self.lower_order_nums < self.config.solver_order:
731
- self.lower_order_nums += 1
732
-
733
- # upon completion increase step index by one
734
- self._step_index += 1 # pyright: ignore
735
-
736
- if not return_dict:
737
- return (prev_sample,)
738
-
739
- return SchedulerOutput(prev_sample=prev_sample)
740
-
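Putting step() in context: a pipeline calls it once per timestep, and each call converts the model output, applies the UniC corrector to the previous sample when allowed, then runs the UniP predictor and advances the internal step index. A generic usage sketch, not taken from this repository: `scheduler` is assumed to be an instance of this class with the usual diffusers-style set_timesteps(num_inference_steps, device=...) signature, and `model` is a placeholder denoiser.

import torch

def denoise(model, scheduler, latents, num_inference_steps=20):
    # Assumes the standard diffusers-style set_timesteps signature.
    scheduler.set_timesteps(num_inference_steps, device=latents.device)
    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_pred = model(latents, t)   # raw (unconverted) output at x_t
        # step() handles output conversion, the corrector, and the predictor.
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents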
741
- def scale_model_input(self, sample: torch.Tensor, *args,
742
- **kwargs) -> torch.Tensor:
743
- """
744
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
745
- current timestep.
746
-
747
- Args:
748
- sample (`torch.Tensor`):
749
- The input sample.
750
-
751
- Returns:
752
- `torch.Tensor`:
753
- A scaled input sample.
754
- """
755
- return sample
756
-
757
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
758
- def add_noise(
759
- self,
760
- original_samples: torch.Tensor,
761
- noise: torch.Tensor,
762
- timesteps: torch.IntTensor,
763
- ) -> torch.Tensor:
764
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
765
- sigmas = self.sigmas.to(
766
- device=original_samples.device, dtype=original_samples.dtype)
767
- if original_samples.device.type == "mps" and torch.is_floating_point(
768
- timesteps):
769
- # mps does not support float64
770
- schedule_timesteps = self.timesteps.to(
771
- original_samples.device, dtype=torch.float32)
772
- timesteps = timesteps.to(
773
- original_samples.device, dtype=torch.float32)
774
- else:
775
- schedule_timesteps = self.timesteps.to(original_samples.device)
776
- timesteps = timesteps.to(original_samples.device)
777
-
778
- # begin_index is None when the scheduler is used for training or the pipeline does not implement set_begin_index
779
- if self.begin_index is None:
780
- step_indices = [
781
- self.index_for_timestep(t, schedule_timesteps)
782
- for t in timesteps
783
- ]
784
- elif self.step_index is not None:
785
- # add_noise is called after the first denoising step (for inpainting)
786
- step_indices = [self.step_index] * timesteps.shape[0]
787
- else:
788
- # add_noise is called before the first denoising step to create the initial latent (img2img)
789
- step_indices = [self.begin_index] * timesteps.shape[0]
790
-
791
- sigma = sigmas[step_indices].flatten()
792
- while len(sigma.shape) < len(original_samples.shape):
793
- sigma = sigma.unsqueeze(-1)
794
-
795
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
796
- noisy_samples = alpha_t * original_samples + sigma_t * noise
797
- return noisy_samples
798
-
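add_noise implements the forward process x_t = alpha_t * x_0 + sigma_t * noise, with alpha_t and sigma_t derived from the sigma matched to each timestep. A minimal sketch of just that blend, using placeholder coefficients rather than values from a real schedule:

import torch

x0 = torch.randn(1, 4, 8, 8)         # clean latent
noise = torch.randn_like(x0)
alpha_t, sigma_t = 0.7, 0.5          # placeholders; the scheduler derives these
                                     # via _sigma_to_alpha_sigma_t(self.sigmas[idx])
noisy = alpha_t * x0 + sigma_t * noise
print(noisy.shape)                   # same shape as x0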
799
- def __len__(self):
800
- return self.config.num_train_timesteps