import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging

from torch import Tensor
from diffusers import ModelMixin
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention, T5LayerFF, T5LayerNorm

logger = logging.getLogger(__name__)


class T5EncoderBlock(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        # Encoder-only block: no cached key/values and no cross-attention.
        self_attn_past_key_value = None
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=False,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Keep the relative position bias (and attention weights, if requested).
        attention_outputs = self_attention_outputs[2:]

        # Clamp inf values to enable fp16 training.
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # Apply the feed-forward layer.
        hidden_states = self.layer[-1](hidden_states)

        # Clamp inf values to enable fp16 training.
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,) + attention_outputs
        return outputs  # hidden_states, (self-attention position bias), (self-attention weights)


class T5EncoderBlockByT5Mapper(ModelMixin):
    def __init__(self, byt5_config, num_layers, sdxl_channels=None):
        super().__init__()
        if num_layers > 0:
            self.blocks = nn.ModuleList(
                [
                    T5EncoderBlock(byt5_config, has_relative_attention_bias=bool(i == 0))
                    for i in range(num_layers)
                ]
            )
        else:
            self.blocks = None

        self.layer_norm = T5LayerNorm(byt5_config.d_model, eps=byt5_config.layer_norm_epsilon)

        if sdxl_channels is not None:
            self.channel_mapper = nn.Linear(byt5_config.d_model, sdxl_channels)
            self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=byt5_config.layer_norm_epsilon)
        else:
            self.channel_mapper = None
            self.final_layer_norm = None

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int, ...],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """
        Makes a broadcastable attention mask so that padded (masked) tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int, ...]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor`: The extended attention mask, cast to `dtype` (the module's dtype by default).
        """
        if dtype is None:
            dtype = self.dtype

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves, in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]: make it
            # broadcastable to [batch_size, num_heads, seq_length, seq_length].
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

    def forward(self, inputs_embeds, attention_mask):
        input_shape = inputs_embeds.size()[:-1]
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        hidden_states = inputs_embeds
        position_bias = None
        if self.blocks is not None:
            for layer_module in self.blocks:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                )
                hidden_states, position_bias = layer_outputs

        hidden_states = self.layer_norm(hidden_states)

        if self.channel_mapper is not None:
            hidden_states = self.channel_mapper(hidden_states)
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
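

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it only illustrates
# the expected tensor shapes when mapping ByT5-style embeddings to a target
# channel width. The tiny T5Config values below are invented for the example;
# in practice the config would come from a real ByT5 checkpoint, and the
# output width is set by `sdxl_channels`. It also assumes a transformers
# version whose T5LayerSelfAttention returns
# (hidden_states, key_value_state, position_bias, ...) as consumed by
# T5EncoderBlock.forward above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import T5Config

    # Small stand-in config; only d_model, d_kv, d_ff and num_heads matter here.
    byt5_config = T5Config(d_model=64, d_kv=16, d_ff=128, num_heads=4)

    mapper = T5EncoderBlockByT5Mapper(byt5_config, num_layers=2, sdxl_channels=32)

    batch_size, seq_len = 2, 16
    inputs_embeds = torch.randn(batch_size, seq_len, byt5_config.d_model)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        mapped = mapper(inputs_embeds, attention_mask)

    # The channel mapper projects features from d_model (64) to sdxl_channels (32).
    print(mapped.shape)  # torch.Size([2, 16, 32])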