Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
from torch import nn | |
from typing import Optional | |
from diffusers.models.attention_processor import Attention | |
from diffusers.utils.torch_utils import maybe_allow_in_graph | |
class HiDreamAttention(Attention):
    """Attention module with separate image-stream and text-stream projections.

    Subclasses diffusers' ``Attention`` but deliberately skips its
    ``__init__`` and builds its own layers:

    * image stream: ``to_q``, ``to_k``, ``to_v``, ``to_out`` linear layers plus
      ``q_rms_norm`` / ``k_rms_norm`` RMSNorm layers;
    * text stream (only when ``single`` is False): the same set of layers with
      a ``_t`` suffix.

    The attention computation itself is delegated to ``self.processor``.

    Args:
        query_dim: Input feature size of ``to_q`` (and ``to_q_t``).
        heads: Number of heads. When ``out_dim`` is given, ``self.heads`` is
            instead derived as ``out_dim // dim_head``.
        dim_head: Per-head feature size; also sets ``self.scale``.
        upcast_attention: Stored on the module; not used by this class itself.
        upcast_softmax: Stored on the module; not used by this class itself.
        scale_qk: If True, ``self.scale = dim_head ** -0.5``; otherwise 1.0.
        eps: Epsilon passed to every RMSNorm layer.
        processor: Object installed via ``set_processor`` and called by
            ``forward``.
        out_dim: If given, used as both the inner projection width and the
            output width; otherwise inner width is ``dim_head * heads`` and
            output width is ``query_dim``.
        single: If True, the text-stream (``*_t``) layers are not created.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor=None,
        out_dim: Optional[int] = None,
        single: bool = False,
    ):
        # Deliberately bypass ``Attention.__init__``: initialization starts at
        # the next class after ``Attention`` in the MRO, so the only layers on
        # this module are the ones created below.
        super(Attention, self).__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        # With an explicit out_dim the head count is derived from it; integer
        # division silently drops any remainder if out_dim % dim_head != 0.
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = nn.Linear
        self.linear_cls = linear_cls

        # NOTE: keep the layer creation order below stable. nn.Linear draws
        # from the global RNG when constructed, and ``self.apply`` visits
        # submodules in registration order, so reordering would change the
        # initial weights obtained for a given seed.

        # Image stream. NOTE(review): to_k / to_v expect inner_dim-wide inputs
        # while to_q expects query_dim-wide inputs; if the processor feeds the
        # same tensor to all three, query_dim must equal inner_dim -- confirm.
        self.to_q = linear_cls(query_dim, self.inner_dim)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim)
        self.to_out = linear_cls(self.inner_dim, self.out_dim)
        self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
        self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)

        if not single:
            # Text stream: same shapes as the image stream, ``_t`` suffix.
            self.to_q_t = linear_cls(query_dim, self.inner_dim)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim)
            self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
            self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)

        self.set_processor(processor)
        # Re-initialize every nn.Linear reachable from this module.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform initialize an ``nn.Linear``'s weight and zero its bias.

        Any module that is not an ``nn.Linear`` is left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        norm_text_tokens: Optional[torch.FloatTensor] = None,
        rope: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Run attention by delegating to ``self.processor``.

        The processor is called with this module as its first positional
        argument and the inputs as keywords ``image_tokens``,
        ``image_tokens_masks``, ``text_tokens`` and ``rope``. Expected shapes
        and the meaning of the mask / rope tensors are defined by the
        processor, not here.

        Returns:
            Whatever the processor returns.
        """
        return self.processor(
            self,
            image_tokens=norm_image_tokens,
            image_tokens_masks=image_tokens_masks,
            text_tokens=norm_text_tokens,
            rope=rope,
        )
class FeedForwardSwiGLU(nn.Module):
    """Gated (SwiGLU) feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is derived from ``hidden_dim`` in three steps:

    1. scale by 2/3 and truncate to ``int``;
    2. if ``ffn_dim_multiplier`` is given, scale by it and truncate again;
    3. round up to the next multiple of ``multiple_of``.

    All three projections are bias-free and are re-initialized with
    Xavier-uniform weights.

    Args:
        dim: Input and output feature size.
        hidden_dim: Nominal hidden size before the adjustments above.
        multiple_of: The final hidden size is rounded up to a multiple of this.
        ffn_dim_multiplier: Optional extra scale factor for the hidden size.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            # custom dim factor multiplier
            width = int(ffn_dim_multiplier * width)
        # Ceil-divide, then scale back up: smallest multiple of `multiple_of`
        # that is >= width.
        width = (width + multiple_of - 1) // multiple_of * multiple_of

        # Creation order (w1, w2, w3) is kept fixed so that seeded
        # initialization stays reproducible.
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for an ``nn.Linear`` weight; zero any bias."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))``."""
        gate = torch.nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)