import torch
from audiocraft.transformer import StreamingTransformer
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
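# Text-conditioned autoregressive LM over n_q = 4 parallel codebooks of cardinality 2048,
# decoded with a MusicGen-style "delay" interleaving pattern and classifier-free guidance;
# a frozen t5-large encoder provides the cross-attention text condition.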
class T5(nn.Module):
def __init__(self):
super().__init__()
self.output_proj = nn.Linear(1024, # t5-large
1536) # lm hidden
self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
        # assign via self.__dict__ so the frozen T5 encoder is not registered as a submodule
        # and therefore is not part of the saved checkpoint
self.__dict__['t5'] = t5.to('cuda:0')
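    # forward() expects the prompt list to hold the text prompts followed by an equal number of
    # duplicates acting as the null condition: the second half of the batch gets a zeroed attention
    # mask and zeroed output embeddings, so the LM can run classifier-free guidance on a 2*bs batch.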
def forward(self, prompt):
with torch.set_grad_enabled(False), torch.autocast(device_type='cuda', dtype=torch.float32):
bs = len(prompt) // 2
d = self.t5_tokenizer(prompt,
return_tensors='pt',
padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0  # the null-condition half of the batch gets an all-zero T5 attention mask
x = self.t5(input_ids=d['input_ids'],
attention_mask=d['attention_mask']).last_hidden_state # no kv
        # Precision note: self.output_proj() runs outside the float32 autocast used for T5 above,
        # but inside the LM's float16 autocast, so it is computed in torch.float16.
        x = self.output_proj(x)  # nn.Linear(); the result differs if the text condition is not duplicated in the batch
        x[bs:, :, :] = 0  # zero the null-condition half, cf. venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
return x
class LMModel(nn.Module):
def __init__(self,
n_q = 4,
card = 2048,
dim = 1536
):
super().__init__()
self.t5 = T5()
self.card = card # 2048
        self.n_draw = 1  # draw > 1 tokens per step, each with a different CFG scale
        # (getting variations via n_draw is cheaper than batch size > 1, which calls the transformer on a larger batch)
        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])  # embeddings have card + 1 = 2049 entries (index 2048 is the special filler token)
self.transformer = StreamingTransformer()
self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # output heads have only card = 2048 classes; the special token is never predicted
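    # forward() runs a single decoding step: embed one antidiagonal of the delay pattern, call the
    # streaming transformer on the duplicated (conditional + null) batch, apply classifier-free
    # guidance, and sample n_draw tokens per codebook from the top-k distribution.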
def forward(self,
sequence,
condition_tensors=None,
cache_position=None):
bs, n_q, time_frames = sequence.shape # [bs, 4, time]
input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
        out = self.transformer(torch.cat([input_, input_], 0),  # duplicate the input (2*bs); the second half sees the null condition, for classifier-free guidance
cross_attention_src=condition_tensors,
cache_position=cache_position
)
        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(n_q)], dim=1)  # [2*bs, 4, 1, 2048]
        # classifier-free guidance: conditional minus scaled null logits; self._scale is [1, 1, n_draw, 1] and broadcasts
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # [bs, 4, n_draw, 2048]
        k = 24
        logits = torch.softmax(logits / 1.0, dim=3)  # [bs, 4, n_draw, 2048]
        p, ix = torch.topk(logits, k, dim=3)  # p, ix = [bs, 4, n_draw, 24]
        # Sample from the top-k categorical distribution via the exponential-race (Gumbel-max) trick:
        # if E_i ~ Exp(1), then argmax_i p_i / E_i is distributed proportionally to p_i.
        deflation = torch.empty_like(p).exponential_(lambd=1)
        p = p / deflation
        p = p.argmax(dim=3, keepdim=True)  # [bs, 4, n_draw, 1], index into the top-k
        tok = ix.gather(dim=3, index=p).to(torch.int64)  # [bs, 4, n_draw, 1]
        return tok[:, :, :, 0].transpose(1, 2)  # [bs, n_draw, 4]
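    # generate() decodes max_tokens frames autoregressively in the delayed layout: codebook k is
    # shifted right by k steps, so each step predicts one antidiagonal of out_codes. The buffer is
    # pre-filled with the special token (card = 2048); the first and last few steps only partially
    # overwrite their antidiagonal so the triangular filler regions stay intact, and the KV cache is
    # flushed periodically (attention-sink style), keeping the first few positions.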
@torch.no_grad()
def generate(self,
max_tokens=None,
text_condition=None):
x = self.t5(text_condition)
        bs = x.shape[0] // 2  # x also contains the null conditions; bs*2*N_REPEAT is applied in builders.py
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94  # random CFG scale in [1.94, 2.24) per draw
cache_position = 0
        out_codes = torch.full((bs,
                                self.n_draw,
                                4,
                                4 + 3 + max_tokens),  # 4 leading filler columns + max_tokens + (n_q - 1) delay tail
                               self.card,
                               dtype=torch.long,
                               device=x.device)  # [bs, n_draw, 4, dur]
# A/R
for offset in range(0, max_tokens + 4 - 1): # max_tokens + n_q - 1
            # Extract the current antidiagonal of the delay pattern, out_codes[[0, 1, 2, 3], [3, 2, 1, 0] + offset],
            # and expand it to [bs, n_q, dur=1] as the LM input.
            # (Column indexing, e.g. gen_sequence[:, 0, :, offset-1:offset], would need an un-delayed buffer; since the
            # buffer has to be de-delayed for the vocoder anyway and only the LM input needs to see the delay,
            # feeding the antidiagonal directly keeps out_codes in the delayed layout.)
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],
                                      condition_tensors=x,  # does this make use of the attention mask of the text condition?
                                      cache_position=cache_position)  # [bs, n_draw, 4]
            # The predicted next_token must also be written onto an antidiagonal (not a column).
            # Do not overwrite the 2048 filler of the TRIU/TRIL = START/END regions with predicted tokens;
            # the 0-th antidiagonal stays full of card, i.e. [2048, 2048, 2048, 2048].
            #
            # [2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048, 2048],
            # [2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048],
            # [2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048],
            # [2048, 2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]
if offset == 0:
next_token[:, :, 1:4] = 2048 # self.card - bottom 3 entries of the antidiagonal should remain 2048
elif offset == 1:
next_token[:, :, 2:4] = 2048 # bottom 2 entries of the antidiagonal should remain 2048
elif offset == 2:
next_token[:, :, 3:4] = 2048
            elif offset == max_tokens:
                next_token[:, :, 0:1] = 2048  # top 1 entry of the antidiagonal should remain 2048
            elif offset == (max_tokens + 1):
                next_token[:, :, 0:2] = 2048  # top 2 entries
            elif offset == (max_tokens + 2):
                next_token[:, :, 0:3] = 2048  # top 3 entries
            else:  # offset 3, 4, ..., max_tokens - 1: fill all n_q = 4 antidiagonal entries
                pass
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
            # Attention sink: periodically flush the KV cache, keeping only the first n_preserve positions
if (offset > 0) and (offset % 71) == 0:
n_preserve = 4
self.transformer._flush(n_preserve=n_preserve)
cache_position = n_preserve
else:
cache_position += 1
# [bs, n_draw, 4, time+xtra] -> [bs, 4, n_draw, time] -> [bs, 4, time * n_draw]
out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)
# flush for next API call
self.transformer._flush()
        return out_codes  # the 4 leading filler (2048) columns were dropped by the slicing above
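# Minimal usage sketch (hypothetical: the surrounding repo wires this model up through its own
# builders and vocoder; the weight file name and prompt layout below are assumptions, the latter
# based on T5.forward() above which expects the text condition to be duplicated):
#
#   lm = LMModel().to('cuda:0')
#   lm.load_state_dict(torch.load('weights.pt', map_location='cuda:0'), strict=False)  # placeholder path
#   text = ['lush synth arpeggio with warm pads']
#   codes = lm.generate(max_tokens=350, text_condition=text + text)  # second copy is zeroed internally (null condition)
#   # codes: [bs, 4, n_draw * 350] discrete codec tokens for the repo's vocoder/decoder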