pattern no overwrite 2048
Browse files- audiocraft/builders.py +3 -3
- audiocraft/lm.py +79 -84
- audiocraft/transformer.py +5 -12
- models.py +1 -2
audiocraft/builders.py
CHANGED
@@ -11,7 +11,7 @@ from .lm import LMModel
|
|
11 |
from .seanet import SEANetDecoder
|
12 |
from .vq import ResidualVectorQuantizer
|
13 |
|
14 |
-
N_REPEAT =
|
15 |
|
16 |
def _shift(x):
|
17 |
n = x.shape[0]
|
@@ -57,7 +57,7 @@ class AudioGen(nn.Module):
|
|
57 |
):
|
58 |
|
59 |
with torch.no_grad():
|
60 |
-
print(duration / N_REPEAT * self.compression_model.frame_rate, 'DURATION TOKENS AudioGen')
|
61 |
|
62 |
gen_tokens = self.lm.generate(
|
63 |
descriptions=[descriptions] * N_REPEAT,
|
@@ -166,7 +166,7 @@ class AudioGen(nn.Module):
|
|
166 |
_best = pkg['best_state']
|
167 |
_best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
|
168 |
_best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
|
169 |
-
model.load_state_dict(pkg['best_state'])
|
170 |
# model.cfg = cfg
|
171 |
self.lm = model.to(torch.float)
|
172 |
|
|
|
11 |
from .seanet import SEANetDecoder
|
12 |
from .vq import ResidualVectorQuantizer
|
13 |
|
14 |
+
N_REPEAT = 7 # num (virtual batch_size) clones of audio sounds
|
15 |
|
16 |
def _shift(x):
|
17 |
n = x.shape[0]
|
|
|
57 |
):
|
58 |
|
59 |
with torch.no_grad():
|
60 |
+
print('\nCUSTOM\n',int(duration / N_REPEAT * self.compression_model.frame_rate), 'DURATION TOKENS AudioGen')
|
61 |
|
62 |
gen_tokens = self.lm.generate(
|
63 |
descriptions=[descriptions] * N_REPEAT,
|
|
|
166 |
_best = pkg['best_state']
|
167 |
_best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
|
168 |
_best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
|
169 |
+
model.load_state_dict(pkg['best_state'], strict=True)
|
170 |
# model.cfg = cfg
|
171 |
self.lm = model.to(torch.float)
|
172 |
|
audiocraft/lm.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
|
8 |
|
9 |
class LMModel(nn.Module):
|
10 |
-
|
11 |
def __init__(self,
|
12 |
n_q = 4,
|
13 |
card = 2048,
|
@@ -16,28 +16,25 @@ class LMModel(nn.Module):
|
|
16 |
hidden_scale = 4, # FFN of Transformer
|
17 |
):
|
18 |
super().__init__()
|
19 |
-
self.condition_provider = T5Conditioner(name='t5-large',
|
20 |
output_dim=dim)
|
21 |
self.card = card # 2048 ?
|
22 |
-
self.n_draw =
|
23 |
# the batch is more expensive than n_draw as it re-runs the model bs times
|
24 |
# n_draw just draws more phonemes from the multinomial - after running the lm
|
25 |
embed_dim = self.card + 1
|
26 |
self.n_q = n_q
|
27 |
self.dim = dim
|
28 |
-
self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
29 |
self.transformer = StreamingTransformer(
|
30 |
-
d_model=dim,
|
31 |
-
num_heads=num_heads,
|
32 |
dim_feedforward=int(hidden_scale * dim),
|
33 |
num_layers=48,
|
34 |
positional_embedding='sin',
|
35 |
)
|
36 |
self.out_norm = nn.LayerNorm(dim, eps=1e-5)
|
37 |
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)]) # LINEAR DOESNT HAVE 2049
|
38 |
-
# self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
39 |
-
# self.__dict__['_fsdp'] = None
|
40 |
-
|
41 |
|
42 |
def forward(self,
|
43 |
sequence,
|
@@ -47,116 +44,114 @@ class LMModel(nn.Module):
|
|
47 |
bs, n_q, time_frames = sequence.shape # [bs, 4, time]
|
48 |
|
49 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
|
50 |
-
|
51 |
-
# duplicate null condition (bs x 2)
|
52 |
-
|
53 |
-
out = self.transformer(torch.cat([input_, input_], 0),
|
54 |
cross_attention_src=condition_tensors,
|
55 |
token_count=token_count
|
56 |
)
|
57 |
-
|
58 |
-
if self.out_norm:
|
59 |
-
out = self.out_norm(out)
|
60 |
|
61 |
-
logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)#[2*bs,4,1,2048]
|
62 |
-
|
63 |
logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :] # [3, 4, 1, 2048]
|
64 |
-
|
|
|
65 |
# SAMPLE TOP K
|
66 |
k = 400 # 450 is nice sound still train honk is clear!
|
67 |
p = torch.softmax(logits, dim=3)
|
68 |
top_k_value, _ = torch.topk(p, k, dim=3) # [3, 4, 1, k]
|
69 |
-
min_value_top_k = top_k_value[:, :, :, -1:]
|
70 |
p *= (p >= min_value_top_k).float() # zero low probs
|
71 |
p.div_(p.sum(dim=-1, keepdim=True)) # renormalise on non-zero probs
|
72 |
|
73 |
-
|
74 |
# BRING THE nq = 4 IN BATCH
|
75 |
p = p.reshape(bs * self.n_q, 2048)
|
76 |
out = torch.multinomial(p, # p=[bs,2048], out=[bs, num_samples]
|
77 |
num_samples=self.n_draw,
|
78 |
replacement=True) # [bs*4, self.n_draw]
|
|
|
79 |
return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2) # [bs=3not6, self.n_draw, 4]
|
80 |
|
81 |
@torch.no_grad()
|
82 |
def generate(self,
|
83 |
descriptions = ['windy day', 'rain storm'],
|
84 |
-
max_tokens =
|
85 |
-
|
86 |
text_condition = self.condition_provider(descriptions)
|
87 |
-
|
88 |
-
# NULL CONDITION
|
89 |
-
# text_condition = cfg_conditions['description'][0]
|
90 |
bs, _, _ = text_condition.shape
|
91 |
text_condition = torch.cat(
|
92 |
[
|
93 |
text_condition,
|
94 |
torch.zeros_like(text_condition)
|
95 |
], 0)
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
out_codes = torch.full((bs, self.n_draw, 4, 4 + max_tokens), # 4 + max_tokens to have sufficient to index the 1st antidiagonal of 4x4
|
102 |
self.card,
|
103 |
-
dtype=torch.long,
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
# [2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048, 2048],
|
109 |
-
# [2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048],
|
110 |
-
# [2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048],
|
111 |
-
# [2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]],
|
112 |
-
#
|
113 |
-
# out codes = un-delayed
|
114 |
-
#
|
115 |
-
# tensor([[0, 1, 2, 3, 4, 5, 6],
|
116 |
-
# [0, 1, 2, 3, 4, 5, 6],
|
117 |
-
# [0, 1, 2, 3, 4, 5, 6],
|
118 |
-
# [0, 1, 2, 3, 4, 5, 6]])
|
119 |
-
#
|
120 |
-
# LM "sees" 4 delayed tokens (diagonal extract)
|
121 |
-
#
|
122 |
-
# SO THE FIRST pack of 4 tokens fed TO LM is [2048, 2048, 2048, 2048]
|
123 |
-
#
|
124 |
-
# IF WE START WITH
|
125 |
-
# 2048 2048 2048 2048
|
126 |
-
# 2048 2048 2048 2048
|
127 |
-
# 2048 2048 2048 2048
|
128 |
-
# 2048 2048 2048 2048
|
129 |
-
#
|
130 |
-
# THE 2nd token pack of 4 fed to LM is [10, 20, 50, 7]
|
131 |
-
#
|
132 |
-
# 2048 2048 2048 2048 10
|
133 |
-
# 2048 2048 2048 2048 20
|
134 |
-
# 2048 2048 2048 2048 50
|
135 |
-
# 2048 2048 2048 2048 7
|
136 |
-
#
|
137 |
# extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
|
138 |
-
#
|
139 |
-
# forward duplicates the query to nullcond - then cfg & returns deduplicate token
|
140 |
-
# only 0 (1st token of n_draw is continued by LM call - rest is supersampled in torch.multinomial)
|
141 |
-
|
142 |
-
# feeds the antidiagonal to LM
|
143 |
next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None], # index diagonal & exapnd to [bs, n_q, dur=1]
|
144 |
#gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
|
145 |
condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
|
146 |
-
token_count=offset) # [bs,
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
for lay in self.transformer.layers:
|
159 |
lay.self_attn.k_history = None
|
160 |
-
lay.self_attn.v_history = None
|
161 |
|
162 |
return out_codes # SKIP THE 4 fill 2048 bs*n_draw, duration -> repeat/shift in api.py
|
|
|
7 |
|
8 |
|
9 |
class LMModel(nn.Module):
|
10 |
+
|
11 |
def __init__(self,
|
12 |
n_q = 4,
|
13 |
card = 2048,
|
|
|
16 |
hidden_scale = 4, # FFN of Transformer
|
17 |
):
|
18 |
super().__init__()
|
19 |
+
self.condition_provider = T5Conditioner(name='t5-large',
|
20 |
output_dim=dim)
|
21 |
self.card = card # 2048 ?
|
22 |
+
self.n_draw = 6 # replicate so many times the generation of each text in batch
|
23 |
# the batch is more expensive than n_draw as it re-runs the model bs times
|
24 |
# n_draw just draws more phonemes from the multinomial - after running the lm
|
25 |
embed_dim = self.card + 1
|
26 |
self.n_q = n_q
|
27 |
self.dim = dim
|
28 |
+
self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
29 |
self.transformer = StreamingTransformer(
|
30 |
+
d_model=dim,
|
31 |
+
num_heads=num_heads,
|
32 |
dim_feedforward=int(hidden_scale * dim),
|
33 |
num_layers=48,
|
34 |
positional_embedding='sin',
|
35 |
)
|
36 |
self.out_norm = nn.LayerNorm(dim, eps=1e-5)
|
37 |
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)]) # LINEAR DOESNT HAVE 2049
|
|
|
|
|
|
|
38 |
|
39 |
def forward(self,
|
40 |
sequence,
|
|
|
44 |
bs, n_q, time_frames = sequence.shape # [bs, 4, time]
|
45 |
|
46 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
|
47 |
+
|
48 |
+
out = self.transformer(torch.cat([input_, input_], 0), # duplicate null condition (bs x 2) for ClassifierFreeGuidance
|
|
|
|
|
49 |
cross_attention_src=condition_tensors,
|
50 |
token_count=token_count
|
51 |
)
|
|
|
|
|
|
|
52 |
|
53 |
+
logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)#[2*bs,4,1,2048]
|
54 |
+
|
55 |
logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :] # [3, 4, 1, 2048]
|
56 |
+
|
57 |
+
|
58 |
# SAMPLE TOP K
|
59 |
k = 400 # 450 is nice sound still train honk is clear!
|
60 |
p = torch.softmax(logits, dim=3)
|
61 |
top_k_value, _ = torch.topk(p, k, dim=3) # [3, 4, 1, k]
|
62 |
+
min_value_top_k = top_k_value[:, :, :, -1:]
|
63 |
p *= (p >= min_value_top_k).float() # zero low probs
|
64 |
p.div_(p.sum(dim=-1, keepdim=True)) # renormalise on non-zero probs
|
65 |
|
|
|
66 |
# BRING THE nq = 4 IN BATCH
|
67 |
p = p.reshape(bs * self.n_q, 2048)
|
68 |
out = torch.multinomial(p, # p=[bs,2048], out=[bs, num_samples]
|
69 |
num_samples=self.n_draw,
|
70 |
replacement=True) # [bs*4, self.n_draw]
|
71 |
+
# print('DRAW','c', out)
|
72 |
return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2) # [bs=3not6, self.n_draw, 4]
|
73 |
|
74 |
@torch.no_grad()
|
75 |
def generate(self,
|
76 |
descriptions = ['windy day', 'rain storm'],
|
77 |
+
max_tokens = None):
|
78 |
+
|
79 |
text_condition = self.condition_provider(descriptions)
|
|
|
|
|
|
|
80 |
bs, _, _ = text_condition.shape
|
81 |
text_condition = torch.cat(
|
82 |
[
|
83 |
text_condition,
|
84 |
torch.zeros_like(text_condition)
|
85 |
], 0)
|
86 |
+
out_codes = torch.full((bs,
|
87 |
+
self.n_draw,
|
88 |
+
4,
|
89 |
+
4 + 3 + max_tokens), # 4 + max_tokens + 4-1 to have sufficient to index the 1st antidiagonal of 4x4 + 4 xtra tokens
|
|
|
|
|
90 |
self.card,
|
91 |
+
dtype=torch.long,
|
92 |
+
device=text_condition.device) # [bs, n_draw, 4, dur]
|
93 |
+
# =========================================
|
94 |
+
for offset in range(0, max_tokens + 4 - 1): # max_tokens + n_q - 1
|
95 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
|
|
|
|
|
|
|
|
|
|
|
97 |
next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None], # index diagonal & exapnd to [bs, n_q, dur=1]
|
98 |
#gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
|
99 |
condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
|
100 |
+
token_count=offset) # [bs, n_draw, 4]
|
101 |
+
|
102 |
+
# next_token must also be written along the antidiagonal [not into a column]
|
103 |
+
|
104 |
+
# Do not overwrite the 2048 fill values of TRIU/TRIL (= START/END padding) => do not fill them with predicted tokens
|
105 |
+
# 0-th antidiagonal should be full of card = [2048, 2048, 2048, 2048]
|
106 |
+
#
|
107 |
+
# [2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048, 2048],
|
108 |
+
# [2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048],
|
109 |
+
# [2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048],
|
110 |
+
# [2048, 2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]]
|
111 |
+
# No overwriting
|
112 |
+
if offset == 0:
|
113 |
+
|
114 |
+
next_token[:, :, 1:4] = 2048 # self.card
|
115 |
+
|
116 |
+
elif offset == 1:
|
117 |
+
|
118 |
+
next_token[:, :, 2:4] = 2048
|
119 |
+
|
120 |
+
elif offset == 2:
|
121 |
+
|
122 |
+
next_token[:, :, 3:4] = 2048
|
123 |
+
|
124 |
+
elif offset == max_tokens:
|
125 |
+
|
126 |
+
next_token[:, :, 0:1] = 2048
|
127 |
+
|
128 |
+
elif offset == (max_tokens + 1):
|
129 |
+
|
130 |
+
next_token[:, :, 0:1] = 2048
|
131 |
+
|
132 |
+
elif offset == (max_tokens + 2):
|
133 |
+
|
134 |
+
next_token[:, :, 0:2] = 2048
|
135 |
+
|
136 |
+
else: # offset 3,4,5,6,7...... max_tokens-1 # FILL Complete n_q = 4 ANTIDIAGONAL ENTRIES
|
137 |
+
|
138 |
+
pass #print('No delete anti-diag')
|
139 |
+
|
140 |
+
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
|
141 |
+
|
142 |
+
# print(out_codes.shape, f'{offset=} \n', out_codes[0:1, 0:1, :, :],'\n______________L_____________________\n')
|
143 |
+
# align 4-rows (shift by 1)
|
144 |
+
# print(out_codes[0, 0, :, :]) # do we pass 2048 to Seanet - There wil result in AtenIndexingError as it has no 2048
|
145 |
+
# out_codes = torch.cat([out_codes[:, :, 0:1, 4:max_tokens+4], # first row starts to be filled at offset = 4
|
146 |
+
# out_codes[:, :, 1:2, 3:max_tokens+3],
|
147 |
+
# out_codes[:, :, 2:3, 2:max_tokens+2],
|
148 |
+
# out_codes[:, :, 3:4, 1:max_tokens+1]], 2)
|
149 |
+
print('\n_____ALIGN____\n', out_codes[0, 0, :, 4:max_tokens+4]) # do we pass 2048 to Seanet - There wil result in AtenIndexingError as it has no 2048
|
150 |
+
|
151 |
+
out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens) # [bs, 4, duration*n_draw] DISCARD FILL 2048
|
152 |
+
|
153 |
for lay in self.transformer.layers:
|
154 |
lay.self_attn.k_history = None
|
155 |
+
lay.self_attn.v_history = None
|
156 |
|
157 |
return out_codes # SKIP THE 4 fill 2048 bs*n_draw, duration -> repeat/shift in api.py
|
audiocraft/transformer.py
CHANGED
@@ -180,21 +180,14 @@ class StreamingTransformer(nn.Module):
|
|
180 |
x,
|
181 |
token_count=None,
|
182 |
cross_attention_src=None):
|
183 |
-
|
184 |
B, T, C = x.shape
|
185 |
-
|
186 |
-
|
187 |
if self.positional_embedding in ['sin', 'sin_rope']:
|
188 |
-
|
189 |
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
190 |
-
|
191 |
-
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
192 |
x = x + pos_emb
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
for j, lay in enumerate(self.layers):
|
197 |
-
# print(f'Transf Layer{j} {pos_emb.sum()=} {pos_emb.shape=}{x.
|
198 |
-
x = lay(x, cross_attention_src=cross_attention_src) # cross_attention_src = txt-cond
|
199 |
-
|
|
|
200 |
return x
|
|
|
180 |
x,
|
181 |
token_count=None,
|
182 |
cross_attention_src=None):
|
|
|
183 |
B, T, C = x.shape
|
|
|
|
|
184 |
if self.positional_embedding in ['sin', 'sin_rope']:
|
|
|
185 |
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
186 |
+
pos_emb = create_sin_embedding(positions + token_count, C, max_period=self.max_period, dtype=x.dtype)
|
|
|
187 |
x = x + pos_emb
|
|
|
|
|
|
|
188 |
for j, lay in enumerate(self.layers):
|
189 |
+
# print(f'Transf Layer c{j} {pos_emb.sum()=} {pos_emb.shape=}{x.sum()=}___________________')
|
190 |
+
x = lay(x, cross_attention_src=cross_attention_src) # cross_attention_src = txt-cond x audio
|
191 |
+
# self attn = audio x audio
|
192 |
+
# Every layer (MHA) keeps its own KV cache
|
193 |
return x
|
models.py
CHANGED
@@ -357,8 +357,7 @@ class DurationEncoder(nn.Module):
|
|
357 |
|
358 |
for block in self.lstms:
|
359 |
if isinstance(block, AdaLayerNorm):
|
360 |
-
|
361 |
-
print(f'\n=========ENTER ADALAYNORM L479 models.py {x.shape=}, {style.shape=}')
|
362 |
x = block(x, style) # [bs, 75, 512]
|
363 |
x = torch.cat([x.transpose(1, 2), style], axis=1) # [bs, 512, 75]
|
364 |
|
|
|
357 |
|
358 |
for block in self.lstms:
|
359 |
if isinstance(block, AdaLayerNorm):
|
360 |
+
# non-LSTM blocks (i.e. AdaLayerNorm) enter here
|
|
|
361 |
x = block(x, style) # [bs, 75, 512]
|
362 |
x = torch.cat([x.transpose(1, 2), style], axis=1) # [bs, 512, 75]
|
363 |
|