|
import torch
|
|
import torch.nn.functional as F
|
|
from diffusers.models.attention_processor import (
|
|
Attention,
|
|
AttnProcessor2_0,
|
|
SlicedAttnProcessor,
|
|
XFormersAttnProcessor
|
|
)
|
|
|
|
# xformers is an optional dependency: when it cannot be imported, the name is
# bound to None so the rest of the module still loads.
try:
    import xformers.ops
except Exception:
    # Deliberately broad (optional accelerators can fail to import in many
    # ways), but no longer a bare ``except:`` so KeyboardInterrupt and
    # SystemExit still propagate.
    xformers = None
|
|
|
|
|
|
# Module-level list read by apply_hypernetworks(): each entry's ``forward`` is
# called in list order. While the list is empty, apply_hypernetworks() returns
# its second argument for both outputs without calling anything.
loaded_networks = []
|
|
|
|
|
|
def apply_single_hypernetwork(hypernetwork, hidden_states, encoder_hidden_states):
    """Run one hypernetwork and return its two outputs as a tuple.

    Args:
        hypernetwork: Object exposing ``forward(hidden_states,
            encoder_hidden_states)`` that yields exactly two values.
        hidden_states: First argument forwarded unchanged to ``forward``.
        encoder_hidden_states: Second argument forwarded unchanged to
            ``forward``.

    Returns:
        A ``(context_k, context_v)`` tuple holding the two values produced by
        ``hypernetwork.forward``.
    """
    # Unpacking enforces that exactly two values come back.
    key_context, value_context = hypernetwork.forward(
        hidden_states, encoder_hidden_states
    )
    return (key_context, value_context)
|
|
|
|
|
|
def apply_hypernetworks(context_k, context_v, layer=None):
    """Pass the attention inputs through every loaded hypernetwork.

    The forward functions in this module call this as
    ``apply_hypernetworks(hidden_states, encoder_hidden_states)``, so despite
    the parameter names ``context_k`` receives the attention input and
    ``context_v`` receives the key/value context.

    Args:
        context_k: First value handed to each hypernetwork's ``forward``.
        context_v: Second value handed to each hypernetwork's ``forward``;
            also the fallback source for both outputs when no hypernetwork is
            loaded.
        layer: Currently unused.

    Returns:
        A ``(context_k, context_v)`` tuple. With no hypernetworks loaded this
        is ``(context_v, context_v)``; otherwise it is the output of the last
        hypernetwork, with the second element cast to the dtype of the first.
    """
    if not loaded_networks:
        # Intentionally (v, v), not (k, v): with the call convention described
        # above, keys and values must both be derived from the context.
        return context_v, context_v

    # Each hypernetwork consumes the previous one's outputs.
    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    # Casting context_k to its own dtype was a no-op and has been dropped;
    # only context_v needs aligning with context_k.
    return context_k, context_v.to(dtype=context_k.dtype)
|
|
|
|
|
|
|
|
def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Attention via xformers, with loaded hypernetworks applied to K/V inputs.

    Installed as ``XFormersAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``.

    Args:
        self: The processor instance; only ``self.attention_op`` is read.
        attn: Attention module supplying the projections and reshape helpers.
        hidden_states: Rank-3 input; its shape is unpacked as
            ``(batch, sequence, _)`` when no encoder states are given.
        encoder_hidden_states: Optional context. When ``None`` the function
            attends over ``hidden_states`` itself.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask`` and then to xformers as
            ``attn_bias``.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    # Mask dimensions follow the context sequence when a context is supplied.
    shape_source = (
        hidden_states if encoder_hidden_states is None else encoder_hidden_states
    )
    batch_size, sequence_length, _ = shape_source.shape
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    q = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Hypernetworks (if any) rewrite the inputs of the key/value projections.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
    k = attn.to_k(context_k)
    v = attn.to_v(context_v)

    q, k, v = (attn.head_to_batch_dim(t).contiguous() for t in (q, k, v))

    out = xformers.ops.memory_efficient_attention(
        q,
        k,
        v,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    out = attn.batch_to_head_dim(out.to(q.dtype))

    # Apply the two output stages in order.
    return attn.to_out[1](attn.to_out[0](out))
|
|
|
|
|
|
def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Sliced attention with loaded hypernetworks applied to the K/V inputs.

    Installed as ``SlicedAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``. Attention is computed
    ``self.slice_size`` rows of the head-flattened batch at a time and written
    into a preallocated output tensor.

    Args:
        self: The processor instance; only ``self.slice_size`` is read.
        attn: Attention module supplying the projections and reshape helpers.
        hidden_states: Rank-3 input; its shape is unpacked as
            ``(batch, sequence, _)`` when no encoder states are given.
        encoder_hidden_states: Optional context. When ``None`` the function
            attends over ``hidden_states`` itself.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask`` and sliced alongside the queries.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Hypernetworks (if any) rewrite the inputs of the key/value projections.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.head_to_batch_dim(attn.to_k(context_k))
    value = attn.head_to_batch_dim(attn.to_v(context_v))

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # Step over the whole batch, including a trailing partial slice. The
    # previous ``range(batch_size_attention // slice_size)`` skipped the
    # remainder when the batch was not a multiple of slice_size, leaving those
    # output rows as zeros. Slicing past the end is clamped automatically.
    for start_idx in range(0, batch_size_attention, self.slice_size):
        end_idx = start_idx + self.slice_size

        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )
        attn_slice = attn.get_attention_scores(
            query[start_idx:end_idx], key[start_idx:end_idx], attn_mask_slice
        )
        hidden_states[start_idx:end_idx] = torch.bmm(
            attn_slice, value[start_idx:end_idx]
        )

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # Apply the two output stages in order.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states
|
|
|
|
|
|
def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    """SDPA-based attention with loaded hypernetworks applied to K/V inputs.

    Installed as ``AttnProcessor2_0.__call__`` by
    ``replace_attentions_for_hypernetwork``; uses
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        self: The processor instance (not read by this function).
        attn: Attention module supplying the projections and ``heads``.
        hidden_states: Rank-3 input; its shape is unpacked as
            ``(batch, sequence, _)`` when no encoder states are given.
        encoder_hidden_states: Optional context. When ``None`` the function
            attends over ``hidden_states`` itself.
        attention_mask: Optional mask; prepared with
            ``attn.prepare_attention_mask`` and reshaped to
            ``(batch, heads, -1, last_dim)`` before use.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # Split the mask's leading dimension into separate batch and head
        # dimensions to line up with the 4-D query/key/value tensors below.
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Hypernetworks (if any) rewrite the inputs of the key/value projections.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    # Take the width from the projected tensor. The previous code used the
    # un-projected hidden_states, which gives a wrong head_dim whenever the
    # projection width differs from the input width.
    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    # Merge the head dimension back into the feature dimension.
    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # Apply the two output stages in order.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
|
|
|
|
|
|
def replace_attentions_for_hypernetwork():
    """Monkey-patch diffusers' attention processors with this module's forwards.

    Overwrites ``__call__`` on ``XFormersAttnProcessor``,
    ``SlicedAttnProcessor`` and ``AttnProcessor2_0`` in
    ``diffusers.models.attention_processor``. The patch is applied to the
    classes themselves, so it affects every instance in the process.
    """
    from diffusers.models import attention_processor as processors

    patches = (
        ("XFormersAttnProcessor", xformers_forward),
        ("SlicedAttnProcessor", sliced_attn_forward),
        ("AttnProcessor2_0", v2_0_forward),
    )
    for class_name, forward in patches:
        setattr(getattr(processors, class_name), "__call__", forward)
|
|
|