import torch
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer

from audiocraft.transformer import StreamingTransformer


class T5(nn.Module):
    """Frozen T5-large text encoder used as the cross-attention conditioner.

    Projects the 1024-dim T5 hidden states to the transformer width (1536) and
    zeroes the second half of every batch so it can act as the unconditional
    branch for classifier-free guidance.
    """

    def __init__(self):
        super().__init__()
        self.output_proj = nn.Linear(1024, 1536)
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
        # Stored via __dict__ so the frozen encoder is not registered as a
        # submodule (its weights stay out of parameters() and state_dict()).
        self.__dict__['t5'] = t5.to('cuda:0')

    def forward(self, prompt):
        # The encoder is frozen; run it without gradients.
        with torch.set_grad_enabled(False), torch.autocast(device_type='cuda', dtype=torch.float32):
            # `prompt` holds the conditional batch followed by its duplicate;
            # bs is the size of the conditional half.
            bs = len(prompt) // 2
            d = self.t5_tokenizer(prompt,
                                  return_tensors='pt',
                                  padding=True).to(self.output_proj.bias.device)
            # Blank out the text for the second half of the batch
            # (unconditional branch for classifier-free guidance).
            d['attention_mask'][bs:, :] = 0

            x = self.t5(input_ids=d['input_ids'],
                        attention_mask=d['attention_mask']).last_hidden_state

            x = self.output_proj(x)
            # Zero the unconditional embeddings as well.
            x[bs:, :, :] = 0
            return x


class LMModel(nn.Module):
    """Transformer language model over n_q parallel token codebooks.

    Conditioned on T5 text embeddings via cross-attention; generate() decodes
    the codebooks with a delay pattern and classifier-free guidance.
    """

    def __init__(self,
                 n_q=4,
                 card=2048,
                 dim=1536):
        super().__init__()
        self.t5 = T5()
        self.card = card    # codebook size; id `card` (2048) doubles as the special token
        self.n_draw = 1     # number of parallel samples drawn per step

        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])

    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None):
        bs, n_q, time_frames = sequence.shape

        # Sum the per-codebook embeddings into a single input sequence.
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])

        # Duplicate the batch: the first half cross-attends to the text
        # condition, the second half to the zeroed (unconditional) embeddings.
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position)

        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(n_q)], dim=1)
        # Classifier-free guidance. self._scale is set in generate() with shape
        # (1, 1, n_draw, 1); for single-step inputs (time_frames == 1) it also
        # broadcasts the time axis into n_draw parallel draws.
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]

        # Top-k sampling (k=24, temperature 1.0) via the exponential-race trick:
        # dividing each probability by i.i.d. Exp(1) noise and taking the argmax
        # draws an index proportionally to the top-k probabilities.
        k = 24
        logits = torch.softmax(logits / 1.0, dim=3)
        p, ix = torch.topk(logits, k, dim=3)

        deflation = torch.empty_like(p).exponential_(lambd=1)
        p = p / deflation

        p = p.argmax(dim=3, keepdim=True)
        tok = ix.gather(dim=3, index=p).to(torch.int64)
        # (bs, n_draw, n_q) token ids for the current step during generation.
        return tok[:, :, :, 0].transpose(1, 2)

    @torch.no_grad()
    def generate(self,
                 max_tokens=None,
                 text_condition=None):
        # Encode the prompts (conditional + unconditional halves).
        x = self.t5(text_condition)
        bs = x.shape[0] // 2
        # Random guidance scale per draw, uniform in [1.94, 2.24).
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
        cache_position = 0

        # Token buffer for the delay pattern: 4 codebooks over
        # 4 + 3 + max_tokens columns, initialised to the special id.
        out_codes = torch.full((bs, self.n_draw, 4, 4 + 3 + max_tokens),
                               self.card,
                               dtype=torch.long,
                               device=x.device)

        for offset in range(0, max_tokens + 4 - 1):
            # Gather one step per codebook along the delay pattern's
            # anti-diagonal (codebook k reads column (3 - k) + offset; only
            # draw 0 is fed back as context) and predict the next step.
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],
                                      condition_tensors=x,
                                      cache_position=cache_position)

            # Ramp-in / ramp-out of the delay pattern: codebooks that have not
            # started yet (first three steps) or have already emitted all
            # max_tokens real tokens (last three steps) are forced to the
            # special token id 2048.
            if offset == 0:
                next_token[:, :, 1:4] = 2048
            elif offset == 1:
                next_token[:, :, 2:4] = 2048
            elif offset == 2:
                next_token[:, :, 3:4] = 2048
            elif offset == max_tokens:
                next_token[:, :, 0:1] = 2048
            elif offset == (max_tokens + 1):
                next_token[:, :, 0:2] = 2048
            elif offset == (max_tokens + 2):
                next_token[:, :, 0:3] = 2048

            # Write the sampled step back, one delayed column per codebook.
            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token

            # Periodically trim the streaming KV cache (keeping n_preserve
            # positions) so memory stays bounded for long generations.
            if (offset > 0) and (offset % 71) == 0:
                n_preserve = 4
                self.transformer._flush(n_preserve=n_preserve)
                cache_position = n_preserve
            else:
                cache_position += 1

        # Drop the delay-pattern padding and fold the draw axis into time:
        # (bs, n_draw, 4, max_tokens) -> (bs, 4, n_draw * max_tokens).
        out_codes = out_codes[:, :, :, 4:max_tokens + 4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)

        # Reset the streaming cache before returning.
        self.transformer._flush()

        return out_codes
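

if __name__ == '__main__':
    # Minimal usage sketch: it assumes the LMModel weights are loaded elsewhere
    # (e.g. via model.load_state_dict(...)) and that a CUDA device is available,
    # since the T5 conditioner is pinned to 'cuda:0'. The prompt count and
    # max_tokens below are arbitrary example values.
    model = LMModel().to('cuda:0').eval()
    prompt = ['lo-fi hip hop with soft piano']
    # The prompt list is doubled: the second copy becomes the unconditional
    # branch (its attention mask is zeroed inside T5.forward).
    codes = model.generate(max_tokens=140, text_condition=prompt + prompt)
    print(codes.shape)  # -> (1, 4, n_draw * 140) codebook token ids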