import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
torch.backends.cuda.enable_mem_efficient_sdp(True)  # prefer the memory-efficient SDPA kernel on CUDA


def create_sin_embedding(positions,
                         dim,
                         max_period=10000):
    """Sinusoidal positional embedding: [cos(phase) | sin(phase)] along the last dim."""
    # assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device,
                        dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([],
                                   max_period,
                                   device=positions.device,
                                   dtype=torch.float)  # avoid a host/device sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    # The official implementation computes this in torch.float32, although
    # self_attn.in_proj_weight may be torch.float16.
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
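
# Shape note (inferred from the call in StreamingTransformer.forward below): positions is
# broadcast against adim of shape (1, 1, half_dim), so for positions of shape (B, 1, 1) and
# dim=1536 the returned embedding has shape (B, 1, 1536), with the cos terms in the first
# 768 channels and the sin terms in the last 768.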


class StreamingMultiheadAttention(nn.Module):

    def __init__(self,
                 embed_dim,
                 num_heads,
                 cross_attention=False):
        super().__init__()
        self.cross_attention = cross_attention
        # If not cross-attention, this module keeps a kv cache.
        # The history is cleared through the LM during generation; each of the
        # 0..47 attention layers keeps its own kv history.
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Placeholder weights; presumably overwritten when a pretrained checkpoint is loaded.
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))

    def forward(self,
                query,
                key=None,
                value=None):
        layout = "b h t d"
        if self.cross_attention:
            # Different queries, keys, values -> split in_proj_weight into thirds.
            dim = self.in_proj_weight.shape[0] // 3
            q = nn.functional.linear(query, self.in_proj_weight[:dim])
            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
            q, k, v = [
                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            # This branch is self-attention of the audio stream with itself
            # (the branch above is cross-attention with the text conditioning).
            # A single projection produces q, k, v for the current (instantaneous) step.
            # The kv history is separate for each transformer layer.
            # Floating-point values here differ slightly from the official implementation.
            projected = nn.functional.linear(query, self.in_proj_weight, None)
            # print(query.sum(), projected.sum(), self.in_proj_weight.sum(), 'Lc')  # verified against official AudioGen values
            bound_layout = "b h p t d"
            packed = rearrange(
                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
            if self.k_history is not None:
                # If the live demo is interrupted (Ctrl-C), the two assignments below are
                # not atomic (k != v), so the next call would try to continue with
                # incompatible k/v dims.
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v
            # Use the accumulated history as k / v.
            k = self.k_history
            v = self.v_history
            # The kv cache only applies when not self.cross_attention.
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x
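
# Note on the kv cache above: in self-attention mode every forward call appends the
# current step's k/v to k_history/v_history along the time axis, so each new step attends
# over all previously seen steps. StreamingTransformer._flush (below) resets or trims
# these per-layer histories between generations.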


class StreamingTransformerLayer(nn.Module):

    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self,
                x,
                cross_attention_src=None):
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)  # text conditioning
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x


class StreamingTransformer(nn.Module):

    def __init__(self,
                 d_model=1536,
                 num_heads=24,
                 num_layers=48,
                 dim_feedforward=6144):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
            ]
        )

    def forward(self,
                x,
                cache_position=None,
                cross_attention_src=None):
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)
        for lay in self.layers:
            x = lay(x,
                    cross_attention_src=cross_attention_src)
        return x

    def _flush(self,
               n_preserve=None):
        for lay in self.layers:
            if n_preserve is not None:
                # It is hard to pick a cache_position that would also preserve kv from the
                # end, so only the first n_preserve positions of the cache are kept.
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
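

# Minimal usage sketch (not part of the original module): the shapes, the tiny num_layers,
# and the zero inputs below are illustrative assumptions only. In practice the weights
# (in_proj_weight buffers, projections, norms) come from a pretrained checkpoint, and the
# output would feed an LM head that is not defined in this file.
if __name__ == "__main__":
    model = StreamingTransformer(d_model=1536, num_heads=24, num_layers=2,
                                 dim_feedforward=6144)
    txt = torch.zeros(1, 7, 1536)     # hypothetical text-conditioning states (b, t_text, d)
    for step in range(4):             # one token per step; each call grows the kv cache by 1
        tok = torch.zeros(1, 1, 1536)
        out = model(tok, cache_position=step, cross_attention_src=txt)
    print(out.shape)                                   # torch.Size([1, 1, 1536])
    print(model.layers[0].self_attn.k_history.shape)   # torch.Size([1, 24, 4, 64])
    model._flush()                    # drop the per-layer kv cache before the next generation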