"""
This module provides the implementation of an Audio Projection Model, which is designed for
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
that can be used for various downstream applications, such as audio analysis or synthesis.
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
provides a foundation for building custom models. This implementation includes multiple linear
layers with ReLU activation functions and a LayerNorm for normalization.
Key Features:
- Audio embedding input with flexible sequence length and block structure.
- Multiple linear layers for feature transformation.
- ReLU activation for non-linear transformation.
- LayerNorm for stabilizing and speeding up training.
- Rearrangement of input embeddings to match the model's expected input shape.
- Customizable number of blocks, channels, and context tokens for adaptability.
The module is structured to be easily integrated into larger systems or used as a standalone
component for audio feature extraction and processing.
Classes:
- AudioProjModel: A class representing the audio projection model with configurable parameters.
Functions:
- (none)
Dependencies:
- torch: For tensor operations and neural network components.
- diffusers: For the ModelMixin base class.
- einops: For tensor rearrangement operations.
"""
import math

import torch
import torch.nn as nn
import torch.nn.init as init
from diffusers import ModelMixin
from einops import rearrange

# These names are only needed by the sequence-parallel variant of
# PerceiverAttentionCA.forward, kept below in commented-out form.
from .parallel_states import (
    initialize_sequence_parallel_state,
    nccl_info,
    get_sequence_parallel_state,
    parallel_attention,
    all_gather,
    all_to_all_4D,
)
class AudioProjNet2(ModelMixin):
    """Audio Projection Model.

    This class defines an audio projection model that takes audio embeddings as
    input and produces context tokens as output. The model is based on the
    ModelMixin class and consists of multiple linear layers with ReLU
    activations followed by a LayerNorm.

    Attributes:
        seq_len (int): The length of the audio window per video frame.
        blocks (int): The number of blocks in the audio embedding.
        channels (int): The number of channels per block.
        intermediate_dim (int): The hidden dimension of the projection MLP.
        context_tokens (int): The number of context tokens in the output.
        output_dim (int): The dimension of each output context token.

    Methods:
        __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, output_dim=768, context_tokens=4):
            Initializes AudioProjNet2 with the given parameters.
        forward(self, audio_embeds):
            Defines the forward pass.
            Parameters:
                audio_embeds (torch.Tensor): Input audio embeddings with shape
                    (batch_size, video_length, seq_len, blocks, channels).
            Returns:
                context_tokens (torch.Tensor): Output context tokens with shape
                    (batch_size, video_length, context_tokens, output_dim).
    """
    def __init__(
        self,
        seq_len=5,
        blocks=12,  # number of blocks in the input audio embedding
        channels=768,  # number of channels per block
        intermediate_dim=512,
        output_dim=768,
        context_tokens=4,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        # The per-frame window is flattened before projection; with the
        # defaults: 5 * 12 * 768 = 46080.
        self.input_dim = seq_len * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # Projection MLP: input_dim -> intermediate -> intermediate -> tokens.
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)

        self.norm = nn.LayerNorm(output_dim)
    def forward(self, audio_embeds):
        # audio_embeds: (batch, frames, window, blocks, channels)
        video_length = audio_embeds.shape[1]
        # Fold the frame axis into the batch so each frame's audio window is
        # projected independently.
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds = torch.relu(self.proj2(audio_embeds))

        context_tokens = self.proj3(audio_embeds).reshape(
            batch_size, self.context_tokens, self.output_dim
        )
        context_tokens = self.norm(context_tokens)
        # Restore the frame axis: (batch, frames, context_tokens, output_dim).
        out_all = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return out_all
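
# A minimal usage sketch of AudioProjNet2. The batch and frame sizes below are
# hypothetical (not values from this repo); the remaining shapes follow the
# constructor defaults above.
def _example_audio_proj():
    proj = AudioProjNet2(
        seq_len=5, blocks=12, channels=768,
        intermediate_dim=512, output_dim=768, context_tokens=4,
    )
    # (batch, frames, window, blocks, channels)
    audio = torch.randn(2, 16, 5, 12, 768)
    tokens = proj(audio)
    assert tokens.shape == (2, 16, 4, 768)
    return tokens
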
def reshape_tensor(x, heads):
    """Split the last dimension of ``x`` into ``heads`` attention heads.

    (bs, length, width) -> (bs, heads, length, width // heads)
    """
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x
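
# A quick shape check for reshape_tensor (hypothetical sizes): a (2, 10, 768)
# tensor split into 8 heads becomes (2, 8, 10, 96).
def _example_reshape_tensor():
    x = torch.randn(2, 10, 768)
    y = reshape_tensor(x, heads=8)
    assert y.shape == (2, 8, 10, 96)
    return y
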
class PerceiverAttentionCA(nn.Module):
    """Perceiver-style cross-attention: latent features attend to audio/image features."""

    def __init__(self, *, dim=3072, dim_head=1024, heads=33):
        super().__init__()
        self.scale = dim_head**-0.5  # unused; forward computes its own scale
        self.dim_head = dim_head
        self.heads = heads
        # Note: inner_dim is intentionally dim_head alone, not dim_head * heads.
        inner_dim = dim_head

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # Zero-initialize the output projection so the block starts as a
        # no-op, a common trick when grafting new cross-attention layers onto
        # a pretrained network.
        init.zeros_(self.to_out.weight)
        if self.to_out.bias is not None:
            init.zeros_(self.to_out.bias)
    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image/audio features,
                shape (b, t, aa, D)
            latents (torch.Tensor): latent features,
                shape (b, t, hw, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        q = self.to_q(latents)
        k, v = self.to_kv(x).chunk(2, dim=-1)

        # Attention. Scaling q and k each by 1/d^(1/4) is equivalent to the
        # usual 1/sqrt(d) applied to the logits, but is more stable in fp16
        # than dividing after the matmul.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        return self.to_out(out)
    # An alternative, sequence-parallel implementation of forward, kept here
    # commented out. It shards latents across ranks, redistributes heads with
    # all_to_all, and relies on an external flash_attn_no_pad helper that is
    # not imported in this module. The original draft referenced undefined
    # names (query/key/value); they are corrected to q/k/v below.
    #
    # def forward(self, x, latents):
    #     """
    #     Args:
    #         x (torch.Tensor): image features, shape (b, t, aa, D)
    #         latents (torch.Tensor): latent features, shape (b, t, hw, D)
    #     """
    #     if get_sequence_parallel_state():
    #         sp_size = nccl_info.sp_size
    #         sp_rank = nccl_info.rank_within_group
    #         # Each rank keeps only its shard of the latent sequence.
    #         latents = torch.chunk(latents, sp_size, dim=1)[sp_rank]
    #
    #     x = self.norm1(x)
    #     latents = self.norm2(latents)
    #
    #     q = self.to_q(latents)
    #     k, v = self.to_kv(x).chunk(2, dim=-1)
    #
    #     def shrink_head(encoder_state, dim):
    #         local_heads = encoder_state.shape[dim] // nccl_info.sp_size
    #         return encoder_state.narrow(
    #             dim, nccl_info.rank_within_group * local_heads, local_heads
    #         )
    #
    #     if get_sequence_parallel_state():
    #         # (batch_size, seq_len, attn_heads, head_dim)
    #         q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
    #         k = shrink_head(k, dim=2)
    #         v = shrink_head(v, dim=2)
    #
    #     qkv = torch.stack([q, k, v], dim=2)
    #     attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None)
    #
    #     out = self.to_out(attn)
    #     if get_sequence_parallel_state():
    #         out = all_gather(out, dim=1)
    #     return out
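
# A minimal usage sketch of PerceiverAttentionCA (hypothetical sizes; `dim`
# matches the constructor default). The latents attend to per-frame audio
# context tokens such as those produced by AudioProjNet2.
def _example_perceiver_ca():
    attn = PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33)
    x = torch.randn(1, 16, 4, 3072)          # audio context tokens per frame
    latents = torch.randn(1, 16, 128, 3072)  # latent tokens per frame
    out = attn(x, latents)
    assert out.shape == latents.shape  # (1, 16, 128, 3072)
    return out
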
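
# A small smoke test tying the example sketches above together. Note that this
# module uses a relative import (.parallel_states), so it must be run as part
# of its package, e.g. `python -m <package>.<this_module>` (the package name
# depends on the repo layout).
if __name__ == "__main__":
    print(_example_audio_proj().shape)
    print(_example_reshape_tensor().shape)
    print(_example_perceiver_ca().shape)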