# Dionyssos — commit bc7f42e ("determinis")
import torch
from audiocraft.transformer import StreamingTransformer
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
class T5(nn.Module):
    """Frozen T5-large text encoder plus a trainable projection to the LM width.

    The tokenizer and encoder both come from the ``'t5-large'`` checkpoint. The
    encoder is written straight into ``self.__dict__`` instead of being
    registered as a child module, so it does not appear in ``state_dict()``
    (i.e. it is not part of saved checkpoints) and is not moved by ``.to()`` on
    this module; it is placed on ``cuda:0`` once, at construction.
    """

    def __init__(self):
        super().__init__()
        # 1024 -> t5-large hidden size; 1536 -> LM hidden size.
        self.output_proj = nn.Linear(1024, 1536)
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        encoder = T5EncoderModel.from_pretrained('t5-large').eval()
        # Bypass nn.Module.__setattr__ so the encoder is not registered as a
        # submodule and therefore stays out of the saved checkpoint.
        self.__dict__['t5'] = encoder.to('cuda:0')

    def forward(self, prompt):
        """Encode prompts into cross-attention conditioning for the LM.

        Args:
            prompt: list of strings. Entries from index ``len(prompt) // 2``
                onward are treated as null (unconditional) prompts for
                classifier-free guidance.

        Returns:
            Projected T5 hidden states (last dim 1536), one row per prompt; the
            null-prompt rows are all zeros.
        """
        n_cond = len(prompt) // 2
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float32):
            tokens = self.t5_tokenizer(prompt, return_tensors='pt', padding=True)
            tokens = tokens.to(self.output_proj.bias.device)
            # Null-condition rows get an all-zero attention mask.
            tokens['attention_mask'][n_cond:] = 0
            hidden = self.t5(input_ids=tokens['input_ids'],
                             attention_mask=tokens['attention_mask']).last_hidden_state
        # The projection sits outside the float32 autocast region above, so it
        # follows whatever autocast context the caller (e.g. the LM, fp16) uses.
        # Note: nn.Linear may give slightly different results depending on the
        # batch composition (e.g. whether the text condition is duplicated).
        hidden = self.output_proj(hidden)
        # Zero the null-condition rows, mirroring audiocraft's
        # modules/conditioners.py tokenize().
        hidden[n_cond:] = 0
        return hidden
class LMModel(nn.Module):
    """Multi-codebook autoregressive token LM with classifier-free guidance (CFG).

    Each step embeds one code per codebook (``n_q`` codebooks of ``card`` codes,
    plus the extra index ``card`` used as start/end padding), runs the streaming
    transformer with the text condition as cross-attention source, and samples
    the next code of every codebook.

    Args:
        n_q: number of codebooks.
        card: number of real codes per codebook; index ``card`` is the special
            padding token (it can be embedded but is never predicted).
        dim: width of the embeddings, of the normalized transformer output and
            of the per-codebook output heads.
    """

    def __init__(self,
                 n_q=4,
                 card=2048,
                 dim=1536
                 ):
        super().__init__()
        self.t5 = T5()
        self.card = card
        # Number of CFG scales sampled per step. Extra draws only broadcast at
        # the logits stage, while a larger batch also enlarges the transformer
        # call, so n_draw > 1 is the cheaper way to get several variants.
        self.n_draw = 1
        # card + 1 rows so the padding token `card` can be embedded ...
        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        # ... but only `card` output logits: the padding token is never sampled.
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])

    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None):
        """Sample the next code of every codebook, once per CFG draw.

        Args:
            sequence: integer codes ``[bs, n_q, time]``; ``generate`` passes
                ``time == 1``.
            condition_tensors: cross-attention source with ``2 * bs`` rows; rows
                ``bs:`` are used as the null (unconditional) condition.
            cache_position: forwarded unchanged to the streaming transformer.

        Returns:
            int64 tensor ``[bs, n_draw, n_q]`` of codes in ``[0, card)``.

        Note:
            Reads ``self._scale`` (one CFG scale per draw), which is set by
            ``generate``; calling ``forward`` first raises ``AttributeError``.
        """
        bs, n_q, _ = sequence.shape
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
        # Duplicate the batch: rows [:bs] see the text condition, rows [bs:]
        # the null condition (classifier-free guidance).
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position)
        out = self.out_norm(out)  # shared by every codebook head; compute once
        logits = torch.stack([self.linears[k](out) for k in range(n_q)], dim=1)  # [2*bs, n_q, time, card]
        # Guided logits 3 * cond - scale * uncond. `_scale` is [1, 1, n_draw, 1],
        # so with time == 1 this yields one row per draw: [bs, n_q, n_draw, card].
        logits = 3 * logits[:bs] - self._scale * logits[bs:]
        top_k = 24
        temperature = 1.0
        probs = torch.softmax(logits / temperature, dim=3)
        p, ix = torch.topk(probs, top_k, dim=3)  # both [bs, n_q, n_draw, top_k]
        # Exponential-race sampling: argmax_i p_i / E_i with E_i ~ Exp(1) picks
        # index i with probability p_i / sum(p), i.e. top-k categorical sampling.
        deflation = torch.empty_like(p).exponential_(lambd=1)
        choice = (p / deflation).argmax(dim=3, keepdim=True)  # [bs, n_q, n_draw, 1]
        tok = ix.gather(dim=3, index=choice).to(torch.int64)  # [bs, n_q, n_draw, 1]
        return tok[:, :, :, 0].transpose(1, 2)  # [bs, n_draw, n_q]

    @torch.no_grad()
    def generate(self,
                 max_tokens=None,
                 text_condition=None):
        """Generate ``max_tokens`` frames of codes per prompt.

        Args:
            max_tokens: number of frames to generate per codebook (required).
            text_condition: prompts passed to ``self.t5``; its output must have
                ``2 * bs`` rows, the second half being the null condition.

        Returns:
            long tensor ``[bs, n_q, n_draw * max_tokens]``; the ``n_draw``
            variants are concatenated along the last (time) axis.

        Raises:
            TypeError: if ``max_tokens`` is None.
        """
        if max_tokens is None:
            raise TypeError('max_tokens must be an int, got None')
        n_q = len(self.emb)
        x = self.t5(text_condition)
        bs = x.shape[0] // 2
        # One CFG scale per draw, uniform in [1.94, 2.24).
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
        codebook = torch.arange(n_q, device=x.device)
        delay = (n_q - 1) - codebook
        # Layout of out_codes[b, d, q, :]: frame t of codebook q is stored at
        # column n_q + t (already un-delayed). Reading the anti-diagonal
        # (q, delay[q] + offset) re-creates the delay pattern for the LM input:
        # at step `offset` codebook q sees its frame offset - 1 - q. Columns
        # outside [n_q, n_q + max_tokens) hold the padding token `card`.
        out_codes = torch.full((bs, self.n_draw, n_q, max_tokens + 2 * n_q - 1),
                               self.card,
                               dtype=torch.long,
                               device=x.device)  # [bs, n_draw, n_q, dur]
        cache_position = 0
        try:
            for offset in range(max_tokens + n_q - 1):
                # Only draw 0 is fed back; [bs, n_q] -> [bs, n_q, time=1].
                next_token = self.forward(out_codes[:, 0, codebook, delay + offset][:, :, None],
                                          condition_tensors=x,
                                          cache_position=cache_position)  # [bs, n_draw, n_q]
                # Codebook q predicts frame offset - q. Outside [0, max_tokens)
                # that slot is start/end padding: keep `card`, not the sample.
                frame = offset - codebook
                valid = (frame >= 0) & (frame < max_tokens)
                next_token = next_token.masked_fill(~valid, self.card)
                out_codes[:, :, codebook, delay + offset + 1] = next_token
                # Sink attention: every 71 steps flush the transformer cache,
                # keeping n_preserve positions (see StreamingTransformer._flush).
                if offset > 0 and offset % 71 == 0:
                    n_preserve = 4
                    self.transformer._flush(n_preserve=n_preserve)
                    cache_position = n_preserve
                else:
                    cache_position += 1
        finally:
            # Always reset the streaming state so the next call starts clean,
            # even if generation raised.
            self.transformer._flush()
        # [bs, n_draw, n_q, T] -> [bs, n_q, n_draw, T] -> [bs, n_q, n_draw * T]
        return out_codes[:, :, :, n_q:max_tokens + n_q].transpose(1, 2).reshape(
            bs, n_q, self.n_draw * max_tokens)