Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.utils.checkpoint | |
| from typing import Any, Optional, Tuple, Union | |
class Attention(nn.Module):
    """Multi-headed self-attention from the 'Attention Is All You Need' paper.

    Queries, keys and values are projected from ``hidden_size`` to
    ``num_attention_heads * attention_head_dim`` (``inner_dim``), and the
    concatenated head outputs are projected back to ``hidden_size``. The head
    geometry therefore does not have to divide ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Per-head dimension of queries, keys and values.
        attention_dropout: Dropout probability applied to the attention
            probabilities while the module is in training mode.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        # Scaled dot-product attention: logits are scaled by 1/sqrt(head_dim).
        self.scale = self.head_dim**-0.5
        self.dropout = attention_dropout
        self.inner_dim = self.head_dim * self.num_heads

        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split ``(bsz, seq_len, inner_dim)`` into ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _add_mask(
        self,
        attn_weights: torch.Tensor,
        mask: torch.Tensor,
        bsz: int,
        tgt_len: int,
        src_len: int,
    ) -> torch.Tensor:
        """Add an additive ``(bsz, 1, tgt_len, src_len)`` mask, broadcast over all heads.

        Raises:
            ValueError: If ``mask`` does not have shape ``(bsz, 1, tgt_len, src_len)``.
        """
        if mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention to ``hidden_states`` (Batch x Time x Channel).

        Args:
            hidden_states: Input of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional additive mask of shape
                ``(batch, 1, seq_len, seq_len)``.
            causal_attention_mask: Optional additive mask of the same shape,
                added before ``attention_mask``.
            output_attentions: If true, also return the post-softmax,
                pre-dropout attention probabilities.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has shape
            ``(batch, seq_len, hidden_size)``. ``attn_weights`` has shape
            ``(batch, num_heads, seq_len, seq_len)``, or is ``None`` when
            ``output_attentions`` is false.

        Raises:
            ValueError: If a mask has the wrong shape, or an intermediate
                result has an unexpected shape.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Scaling the queries is equivalent to scaling the attention logits.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the heads into the batch dimension so torch.bmm can be used.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Apply the causal_attention_mask first, then the padding mask.
        if causal_attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, causal_attention_mask, bsz, tgt_len, src_len)
        if attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, attention_mask, bsz, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # This operation is a bit awkward, but it is required to make sure
            # that attn_weights keeps its gradient. To do so, attn_weights has
            # to be reshaped twice and reused in the following computation.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # The message now reports the shape that is actually checked.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Un-fold the heads and concatenate them back into inner_dim.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
class MLP(nn.Module):
    """Two-layer feed-forward network with a SiLU non-linearity.

    Maps ``hidden_size`` up to ``intermediate_size * mult`` features and back
    down to ``hidden_size``.

    Args:
        hidden_size: Width of the input and the output.
        intermediate_size: Base width of the hidden layer.
        mult: Multiplier applied to ``intermediate_size`` for the hidden width.
    """

    def __init__(self, hidden_size, intermediate_size, mult=4):
        super().__init__()
        expanded_width = intermediate_size * mult
        self.activation_fn = nn.SiLU()
        self.fc1 = nn.Linear(hidden_size, expanded_width)
        self.fc2 = nn.Linear(expanded_width, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> SiLU -> fc2 over the last dimension."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Transformer(nn.Module):
    """A stack of ``depth`` ``TransformerBlock`` layers applied in sequence.

    Args:
        depth: Number of blocks.
        **block_kwargs: Optional keyword arguments forwarded to every
            ``TransformerBlock`` (e.g. ``hidden_size``, ``num_attention_heads``,
            ``attention_head_dim``). When omitted, each block is built with
            its own defaults, exactly as before.
    """

    def __init__(self, depth=12, **block_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(**block_kwargs) for _ in range(depth)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run ``hidden_states`` through every block in order.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, embed_dim)``.
            attention_mask: Optional additive mask of shape
                ``(batch, 1, seq_len, seq_len)``; padding positions are
                indicated by very large negative values.
            causal_attention_mask: Optional additive mask of the same shape,
                passed unchanged to every block.
            output_attentions: Forwarded to every block. The blocks return
                only their hidden states, so no attention weights are returned.

        Returns:
            The output of the last block.
        """
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )
        return hidden_states
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then MLP, each with a residual.

    Args:
        hidden_size: Model width (last dimension of inputs and outputs).
        num_attention_heads: Number of attention heads.
        attention_head_dim: Per-head dimension inside the attention layer.
        attention_dropout: Dropout probability on the attention probabilities.
        dropout: Currently unused; accepted for interface compatibility.
        eps: Epsilon of both LayerNorms.
    """

    def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            # Forward the block's setting; previously it was silently ignored
            # and Attention always used its 0.0 default.
            attention_dropout=attention_dropout,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        # With MLP's default mult=4 the hidden width is 4 * hidden_size.
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states``.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, embed_dim)``.
            attention_mask: Optional additive mask of shape
                ``(batch, 1, seq_len, seq_len)``; padding positions are
                indicated by very large negative values.
            causal_attention_mask: Optional additive mask of the same shape.
            output_attentions: Passed to the attention layer. The weights
                it returns are discarded; only hidden states are returned.

        Returns:
            A tensor with the same shape as ``hidden_states``.
        """
        # Attention sub-layer (normalize first, add the un-normalized residual).
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-layer.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states
class DiffusionTransformerBlock(nn.Module):
    """Prepend a learned token, run one transformer block, return only that token.

    The sequence (with the extra learned ``output_token`` in front) goes through
    a pre-LayerNorm self-attention + MLP block. Only the updated first position
    is returned.

    Args:
        hidden_size: Model width (last dimension of inputs and outputs).
        num_attention_heads: Number of attention heads.
        attention_head_dim: Per-head dimension inside the attention layer.
        attention_dropout: Dropout probability on the attention probabilities.
        dropout: Currently unused; accepted for interface compatibility.
        eps: Epsilon of both LayerNorms.
    """

    def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            # Forward the block's setting; previously it was silently ignored.
            attention_dropout=attention_dropout,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
        # Learned token prepended to every sequence; shape (1, hidden_size).
        self.output_token = nn.Parameter(torch.randn(1, hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Return the block output at the prepended token position.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, embed_dim)``.
            attention_mask: Optional additive mask. The token is prepended
                before attention, so the mask must have shape
                ``(batch, 1, seq_len + 1, seq_len + 1)``.
            causal_attention_mask: Optional additive mask of the same shape.
            output_attentions: Passed to the attention layer. The weights
                it returns are discarded.

        Returns:
            A tensor of shape ``(batch, 1, embed_dim)``.
        """
        # (1, hidden) -> (batch, 1, hidden). expand() avoids an extra copy;
        # torch.cat below materializes the result anyway.
        output_token = self.output_token.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        hidden_states = torch.cat([output_token, hidden_states], dim=1)

        # Attention sub-layer.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-layer.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)

        # Keep only the prepended token, preserving the sequence dimension.
        return hidden_states[:, :1]
class V2AMapperMLP(nn.Module):
    """Linear -> SiLU -> LayerNorm -> Linear mapping network.

    Args:
        input_dim: Width of the input features.
        output_dim: Width of the output features.
        expansion_rate: The hidden width is ``input_dim * expansion_rate``.
    """

    def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
        super().__init__()
        hidden_dim = input_dim * expansion_rate
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.silu = nn.SiLU()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the four stages in order over the last dimension of ``x``."""
        for stage in (self.linear, self.silu, self.layer_norm, self.linear2):
            x = stage(x)
        return x
class ImageProjModel(torch.nn.Module):
    """Projection Model: map one embedding to a fixed number of context tokens.

    A single linear layer maps ``clip_embeddings_dim`` features to
    ``clip_extra_context_tokens * cross_attention_dim`` features. The result
    is reshaped into ``clip_extra_context_tokens`` tokens and layer-normalized.
    The linear layer is zero-initialized at construction. With LayerNorm's
    default affine init, a freshly built model therefore outputs zeros.

    Args:
        cross_attention_dim: Width of each output token.
        clip_embeddings_dim: Width of the input embedding.
        clip_extra_context_tokens: Number of tokens produced per embedding.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)
        self.zero_initialize_last_layer()

    def zero_initialize_last_layer(self):
        """Zero the weight (and bias, if any) of the last ``nn.Linear`` sub-module.

        "Last" follows ``self.modules()`` (registration) order. This does
        nothing if there is no linear sub-module. A layer built with
        ``bias=False`` has only its weight zeroed; previously that case
        crashed on ``None.data``.
        """
        last_layer = None
        for layer in self.modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer
        if last_layer is None:
            return
        # nn.init functions modify parameters in place without recording autograd history.
        torch.nn.init.zeros_(last_layer.weight)
        if last_layer.bias is not None:
            torch.nn.init.zeros_(last_layer.bias)

    def forward(self, image_embeds):
        """Project ``(..., clip_embeddings_dim)`` inputs to ``(N, tokens, cross_attention_dim)``.

        All leading dimensions of ``image_embeds`` are flattened into ``N`` by
        the reshape.
        """
        clip_extra_context_tokens = self.proj(image_embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(clip_extra_context_tokens)
class VisionAudioAdapter(torch.nn.Module):
    """Run an embedding through ``V2AMapperMLP`` and then ``ImageProjModel``.

    Args:
        embedding_size: Input width. It is also the mapper's output width and
            the width of each token emitted by the projection.
        expand_dim: Hidden expansion factor of the mapper MLP.
        token_num: Number of tokens the projection emits per input embedding.
    """

    def __init__(self, embedding_size=768, expand_dim=4, token_num=4):
        super().__init__()
        # Construction order (mapper, then proj) is kept as-is.
        self.mapper = V2AMapperMLP(embedding_size, embedding_size, expansion_rate=expand_dim)
        self.proj = ImageProjModel(
            cross_attention_dim=embedding_size,
            clip_embeddings_dim=embedding_size,
            clip_extra_context_tokens=token_num,
        )

    def forward(self, image_embeds):
        """Return ``proj(mapper(image_embeds))``."""
        return self.proj(self.mapper(image_embeds))