# Provenance: uploaded by XiangpengYang in commit 5602c9a ("first commit"); 5.43 kB.
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
# from diffusers.utils
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from diffusers.models.cross_attention import CrossAttention
# Module-level logger from diffusers' logging utilities.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# xformers is optional: bind the module (and its `ops` submodule) when it is
# installed, otherwise leave the name `xformers` set to None.
# LoRAXFormersCrossAttnProcessor calls xformers.ops.memory_efficient_attention,
# so that processor can only be used when xformers is available.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
class LoRALinearLayer(nn.Module):
    """Low-rank adapter built from 1-D convolutions, with a residual connection.

    Despite the name, the projection is not a plain linear map. ``down`` is a
    kernel-3 ``nn.Conv1d`` (optionally strided) from ``in_features`` to
    ``rank`` channels. ``up`` is a kernel-3 ``nn.Conv1d`` from ``rank`` to
    ``out_features`` channels. ``forward`` returns ``up(down(x)) + skip(x)``,
    where the residual branch is ``x`` itself for ``stride == 1`` and an
    average pool with the same stride otherwise. ``up`` is zero-initialised,
    so a freshly built layer returns only the residual branch.

    ``nn.Conv1d`` takes input shaped (batch, channels, length), so callers must
    put the feature axis at dim 1.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels. Because ``forward`` adds the
            input back as a residual, this must match ``in_features`` in
            practice.
        rank: Bottleneck width. A value above ``min(in_features, out_features)``
            is replaced by half that minimum (never less than 1), and a warning
            is emitted.
        stride: Stride of ``down`` along the length axis. The residual branch is
            pooled with the same stride so both branches have the same length.
    """

    def __init__(self, in_features, out_features, rank=4, stride=1):
        super().__init__()
        max_rank = min(in_features, out_features)
        if rank > max_rank:
            import warnings

            # Clamp to >= 1: a zero-width bottleneck would make the
            # `std=1 / rank` initialisation below divide by zero.
            new_rank = max(1, max_rank // 2)
            # Constructing a Warning object alone reports nothing; emit it.
            warnings.warn(
                f"LoRA rank {rank} must be less or equal than {max_rank}, reset to {new_rank}",
                stacklevel=2,
            )
            rank = new_rank

        self.down = nn.Conv1d(in_features, rank, kernel_size=3, stride=stride, padding=1, bias=False)
        self.up = nn.Conv1d(rank, out_features, kernel_size=3, padding=1, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        # Zero-init `up` so the adapter starts as a no-op on top of the residual.
        nn.init.zeros_(self.up.weight)

        if stride > 1:
            # Downsample the residual with the same stride as `down`. With
            # kernel 3 / padding 1, both branches then have ceil(length / stride)
            # positions, so the addition in `forward` lines up for any stride.
            self.skip = nn.AvgPool1d(kernel_size=3, stride=stride, padding=1)

    def forward(self, hidden_states):
        """Return ``up(down(x)) + residual(x)``.

        The LoRA branch runs in the dtype of the adapter weights and is cast
        back to the input dtype before the residual is added.

        Args:
            hidden_states: Tensor expected as (batch, in_features, length).

        Returns:
            Tensor in the input's dtype, shaped
            (batch, out_features, ceil(length / stride)).
        """
        orig_dtype = hidden_states.dtype
        lora_dtype = self.down.weight.dtype
        lora_out = self.up(self.down(hidden_states.to(lora_dtype)))
        # `skip` only exists when stride > 1 (see __init__).
        residual = self.skip(hidden_states) if hasattr(self, "skip") else hidden_states
        return lora_out.to(orig_dtype) + residual
class LoRACrossAttnProcessor(nn.Module):
    """Attention processor that adds LoRA corrections to the q/k/v/out projections.

    Follows the ``CrossAttention`` computation:
    ``to_q``/``to_k``/``to_v``, ``head_to_batch_dim``, ``get_attention_scores``,
    ``torch.bmm``, ``batch_to_head_dim``, then ``to_out``. Each projection
    output gets ``scale`` times the matching LoRA branch added to it.

    The branches are ``LoRALinearLayer`` instances. Those are built from
    ``nn.Conv1d`` and so take (batch, channels, length) input. The attention
    tensors here are (batch, seq_len, dim), as unpacked at the top of
    ``__call__``, so ``_run_lora`` moves the feature axis to dim 1 for the
    branch and back afterwards.

    NOTE(review): ``LoRALinearLayer.forward`` adds its input back as a residual.
    The k/v branches therefore only line up when ``cross_attention_dim`` (if
    given) equals ``hidden_size``. Confirm this is intended for cross-attention.

    Args:
        hidden_size: Feature size of ``hidden_states``. This is the input size of
            the q branch and the output size of every branch.
        cross_attention_dim: Feature size of ``encoder_hidden_states``. When
            None, ``hidden_size`` is used.
        rank: Bottleneck width of every LoRA branch.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()
        kv_dim = cross_attention_dim or hidden_size
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(kv_dim, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(kv_dim, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    @staticmethod
    def _run_lora(lora_layer, hidden_states):
        """Apply a Conv1d-based LoRA branch to a (batch, seq_len, dim) tensor."""
        return lora_layer(hidden_states.transpose(1, 2)).transpose(1, 2)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """Compute attention with LoRA-adapted projections.

        Args:
            attn: The ``CrossAttention`` module whose projections and helper
                methods are used.
            hidden_states: Query-side input, unpacked as (batch, seq_len, dim).
            encoder_hidden_states: Key/value-side input. When None,
                ``hidden_states`` is used (self-attention).
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask``.
            scale: Multiplier applied to every LoRA branch's output.

        Returns:
            The attention output after ``attn.to_out[0]`` (linear projection)
            and ``attn.to_out[1]`` (dropout).
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states) + scale * self._run_lora(self.to_q_lora, hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self._run_lora(self.to_k_lora, encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self._run_lora(self.to_v_lora, encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self._run_lora(self.to_out_lora, hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class LoRAXFormersCrossAttnProcessor(nn.Module):
    """xFormers attention processor that adds LoRA corrections to q/k/v/out.

    Computes the same projections as the plain LoRA processor. The softmax
    attention is replaced by ``xformers.ops.memory_efficient_attention`` on
    contiguous (batch * heads, seq_len, head_dim) tensors produced by
    ``attn.head_to_batch_dim``. Each projection output gets ``scale`` times the
    matching LoRA branch added to it. Requires xformers to be installed
    (``xformers`` is None otherwise).

    The branches are ``LoRALinearLayer`` instances. Those are built from
    ``nn.Conv1d`` and so take (batch, channels, length) input. The attention
    tensors here are (batch, seq_len, dim), as unpacked at the top of
    ``__call__``, so ``_run_lora`` moves the feature axis to dim 1 for the
    branch and back afterwards.

    NOTE(review): ``LoRALinearLayer.forward`` adds its input back as a residual.
    The k/v branches therefore only line up when ``cross_attention_dim`` (if
    given) equals ``hidden_size``. Confirm this is intended for cross-attention.

    Args:
        hidden_size: Feature size of ``hidden_states``. This is the input size of
            the q branch and the output size of every branch.
        cross_attention_dim: Feature size of ``encoder_hidden_states``. When
            None (now the default, matching the non-xFormers processor),
            ``hidden_size`` is used.
        rank: Bottleneck width of every LoRA branch.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()
        kv_dim = cross_attention_dim or hidden_size
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(kv_dim, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(kv_dim, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    @staticmethod
    def _run_lora(lora_layer, hidden_states):
        """Apply a Conv1d-based LoRA branch to a (batch, seq_len, dim) tensor."""
        return lora_layer(hidden_states.transpose(1, 2)).transpose(1, 2)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """Compute memory-efficient attention with LoRA-adapted projections.

        Args:
            attn: The ``CrossAttention`` module whose projections and helper
                methods are used.
            hidden_states: Query-side input, unpacked as (batch, seq_len, dim).
            encoder_hidden_states: Key/value-side input. When None,
                ``hidden_states`` is used (self-attention).
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask`` and then handed to xformers as
                ``attn_bias``.
            scale: Multiplier applied to every LoRA branch's output.

        Returns:
            The attention output after ``attn.to_out[0]`` (linear projection)
            and ``attn.to_out[1]`` (dropout).
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states) + scale * self._run_lora(self.to_q_lora, hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self._run_lora(self.to_k_lora, encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self._run_lora(self.to_v_lora, encoder_hidden_states)
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self._run_lora(self.to_out_lora, hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states