Dionyssos committed
Commit 999bd69 · 1 Parent(s): 17a68db

pattern no overwrite 2048

audiocraft/builders.py CHANGED
@@ -11,7 +11,7 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
-N_REPEAT = 3  # num (virtual batch_size) clones of audio sounds
+N_REPEAT = 7  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
     n = x.shape[0]
@@ -57,7 +57,7 @@ class AudioGen(nn.Module):
         ):
 
         with torch.no_grad():
-            print(duration / N_REPEAT * self.compression_model.frame_rate, 'DURATION TOKENS AudioGen')
+            print('\nCUSTOM\n', int(duration / N_REPEAT * self.compression_model.frame_rate), 'DURATION TOKENS AudioGen')
 
             gen_tokens = self.lm.generate(
                 descriptions=[descriptions] * N_REPEAT,
@@ -166,7 +166,7 @@ class AudioGen(nn.Module):
         _best = pkg['best_state']
         _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
         _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
-        model.load_state_dict(pkg['best_state'])
+        model.load_state_dict(pkg['best_state'], strict=True)
         # model.cfg = cfg
         self.lm = model.to(torch.float)
 
audiocraft/lm.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 
 
 class LMModel(nn.Module):
 
     def __init__(self,
                  n_q = 4,
                  card = 2048,
@@ -16,28 +16,25 @@ class LMModel(nn.Module):
                  hidden_scale = 4,  # FFN of Transformer
                  ):
         super().__init__()
         self.condition_provider = T5Conditioner(name='t5-large',
                                                 output_dim=dim)
         self.card = card  # 2048 ?
-        self.n_draw = 1  # replicate the generation of each text in the batch this many times
+        self.n_draw = 6  # replicate the generation of each text in the batch this many times
         # the batch is more expensive than n_draw as it re-runs the model bs times
         # n_draw just draws more samples from the multinomial - after running the lm
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
         self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
         self.transformer = StreamingTransformer(
             d_model=dim,
             num_heads=num_heads,
             dim_feedforward=int(hidden_scale * dim),
             num_layers=48,
             positional_embedding='sin',
         )
         self.out_norm = nn.LayerNorm(dim, eps=1e-5)
         self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # LINEAR DOESN'T HAVE 2049
-        # self._init_weights(weight_init, depthwise_init, zero_bias_init)
-        # self.__dict__['_fsdp'] = None
-
 
     def forward(self,
                 sequence,
@@ -47,116 +44,114 @@ class LMModel(nn.Module):
         bs, n_q, time_frames = sequence.shape  # [bs, 4, time]
 
         input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
-
-        # duplicate null condition (bs x 2)
-
-        out = self.transformer(torch.cat([input_, input_], 0),
+
+        out = self.transformer(torch.cat([input_, input_], 0),  # duplicate null condition (bs x 2) for ClassifierFreeGuidance
                                cross_attention_src=condition_tensors,
                                token_count=token_count
                                )
-
-        if self.out_norm:
-            out = self.out_norm(out)
 
-        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
-
+        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
+
         logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
 
         # SAMPLE TOP K
         k = 400  # 450 is nice sound, still the train honk is clear!
         p = torch.softmax(logits, dim=3)
         top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
         min_value_top_k = top_k_value[:, :, :, -1:]
         p *= (p >= min_value_top_k).float()  # zero low probs
         p.div_(p.sum(dim=-1, keepdim=True))  # renormalise on non-zero probs
 
         # BRING THE n_q = 4 IN BATCH
         p = p.reshape(bs * self.n_q, 2048)
         out = torch.multinomial(p,  # p=[bs, 2048], out=[bs, num_samples]
                                 num_samples=self.n_draw,
                                 replacement=True)  # [bs*4, self.n_draw]
+        # print('DRAW', 'c', out)
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs=3 not 6, self.n_draw, 4]
 
     @torch.no_grad()
     def generate(self,
                  descriptions = ['windy day', 'rain storm'],
-                 max_tokens = 256):
-
+                 max_tokens = None):
+
         text_condition = self.condition_provider(descriptions)
-
-        # NULL CONDITION
-        # text_condition = cfg_conditions['description'][0]
         bs, _, _ = text_condition.shape
         text_condition = torch.cat(
             [
                 text_condition,
                 torch.zeros_like(text_condition)
             ], 0)
-
-        out_codes = torch.full((bs, self.n_draw, 4, 4 + max_tokens),  # 4 + max_tokens is sufficient to index the 1st antidiagonal of 4x4
+        out_codes = torch.full((bs,
+                                self.n_draw,
+                                4,
+                                4 + 3 + max_tokens),  # 4 + max_tokens + 4-1 is sufficient to index the 1st antidiagonal of 4x4 plus 4 extra tokens
                                self.card,
-                               dtype=torch.long, device=text_condition.device)  # [bs, n_draw, 4, dur]
-        for offset in range(0, max_tokens):
-
-            # GEN_SEQUENCE has fillers start & end = 2048
-            # [6, 4, 74] = gen_sequence = torch.tensor([[[
-            #     [2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048, 2048],
-            #     [2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048],
-            #     [2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048],
-            #     [2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6]],
-            #
-            # out_codes = un-delayed
-            #
-            # tensor([[0, 1, 2, 3, 4, 5, 6],
-            #         [0, 1, 2, 3, 4, 5, 6],
-            #         [0, 1, 2, 3, 4, 5, 6],
-            #         [0, 1, 2, 3, 4, 5, 6]])
-            #
-            # LM "sees" 4 delayed tokens (diagonal extract)
-            #
-            # SO THE FIRST pack of 4 tokens fed TO THE LM is [2048, 2048, 2048, 2048]
-            #
-            # IF WE START WITH
-            #     2048 2048 2048 2048
-            #     2048 2048 2048 2048
-            #     2048 2048 2048 2048
-            #     2048 2048 2048 2048
-            #
-            # THE 2nd token pack of 4 fed to the LM is [10, 20, 50, 7]
-            #
-            #     2048 2048 2048 2048 10
-            #     2048 2048 2048 2048 20
-            #     2048 2048 2048 2048 50
-            #     2048 2048 2048 2048 7
-            #
+                               dtype=torch.long,
+                               device=text_condition.device)  # [bs, n_draw, 4, dur]
+        # =========================================
+        for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1
+
             # extract diagonal via indexing out_codes[[0, 1, 2, 3], [0, 1, 2, 3]]
-            #
-            # forward duplicates the query to the null condition - then applies cfg & returns the deduplicated token
-            # only index 0 (1st token of n_draw) is continued by the LM call - the rest is supersampled in torch.multinomial
-
-            # feed the antidiagonal to the LM
             next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index diagonal & expand to [bs, n_q, dur=1]
                                       # gen_sequence[:, 0, :, offset-1:offset]  # DIAG-INDEXING for setting the LM prediction into gen_sequence.
                                       # gen_sequence has to be un-delayed in the end (it must be de-delayed for the vocoder), so only the
                                       # LM input needs to see the delay - we could just feed by diag-gather so it matches gen_codes - 1,
                                       # e.g. a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5]: gen_sequence is indexed by vertical column
                                       # and fed to the LM, but the LM's prediction is placed diagonally, with delay, into gen_sequence.
                                       condition_tensors=text_condition,  # utilisation of the attention mask of the txt condition?
-                                      token_count=offset)  # [bs, 4, 1, 2048]
-
-            out_codes[:, :, :, offset + 4] = next_token  # [bs, n_draw, 4, duration]
-
-        # DISCARD FILL
-
-        out_codes = out_codes[:, :, :, 4:].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration*n_draw] DISCARD FILL 2048
-
-        # Clear k/v cache (a different kv is saved by each of the 48x selfattn)
+                                      token_count=offset)  # [bs, n_draw, 4]
+
+            # The fill of next_token should also be placed on the antidiagonal [not the column]
+
+            # Do Not Overwrite 2048 of TRIU/TRIL = START/END => Do Not Fill them with Predicted Tokens
+            # 0-th antidiagonal should be full of card = [2048, 2048, 2048, 2048]
+            #
+            # [2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048, 2048],
+            # [2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048],
+            # [2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048],
+            # [2048, 2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6]]
+            # NO OVERWRITING
+            if offset == 0:
+                next_token[:, :, 1:4] = 2048  # self.card
+            elif offset == 1:
+                next_token[:, :, 2:4] = 2048
+            elif offset == 2:
+                next_token[:, :, 3:4] = 2048
+            elif offset == max_tokens:
+                next_token[:, :, 0:1] = 2048
+            elif offset == (max_tokens + 1):
+                next_token[:, :, 0:1] = 2048
+            elif offset == (max_tokens + 2):
+                next_token[:, :, 0:2] = 2048
+            else:  # offset 3, 4, 5, 6, 7, ..., max_tokens-1 => FILL the complete n_q = 4 ANTIDIAGONAL ENTRIES
+                pass  # print('No delete anti-diag')
+
+            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
+
+            # print(out_codes.shape, f'{offset=} \n', out_codes[0:1, 0:1, :, :], '\n______________L_____________________\n')
+        # align the 4 rows (shift by 1)
+        # print(out_codes[0, 0, :, :])  # do we pass 2048 to SEANet? That would raise an ATen indexing error as it has no 2048
+        # out_codes = torch.cat([out_codes[:, :, 0:1, 4:max_tokens+4],  # first row starts to be filled at offset = 4
+        #                        out_codes[:, :, 1:2, 3:max_tokens+3],
+        #                        out_codes[:, :, 2:3, 2:max_tokens+2],
+        #                        out_codes[:, :, 3:4, 1:max_tokens+1]], 2)
+        print('\n_____ALIGN____\n', out_codes[0, 0, :, 4:max_tokens+4])  # do we pass 2048 to SEANet? That would raise an ATen indexing error as it has no 2048
+
+        out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration*n_draw] DISCARD FILL 2048
+
         for lay in self.transformer.layers:
             lay.self_attn.k_history = None
             lay.self_attn.v_history = None
 
         return out_codes  # SKIP THE 4 fill-2048; bs*n_draw, duration -> repeat/shift in api.py
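The forward pass above combines classifier-free guidance with top-k multinomial sampling, and n_draw amortises one LM pass into several candidate continuations. A minimal sketch under the shapes given in the comments (bs=3, n_q=4, card=2048; the 3/-2 weights are the commit's own, equal to w*cond - (w-1)*uncond with w = 3):

    import torch

    bs, n_q, card, n_draw, k = 3, 4, 2048, 6, 400  # values as in the code above

    # the transformer returns stacked [conditional; unconditional] logits
    logits = torch.randn(2 * bs, n_q, 1, card)
    logits = 3 * logits[:bs] - 2 * logits[bs:]          # classifier-free guidance, w = 3

    p = torch.softmax(logits, dim=3)
    top_k_value, _ = torch.topk(p, k, dim=3)
    p = p * (p >= top_k_value[..., -1:]).float()        # zero everything below the k-th prob
    p = p / p.sum(dim=-1, keepdim=True)                 # renormalise on the survivors

    # one LM pass, n_draw samples per position: cheaper than enlarging the batch
    p = p.reshape(bs * n_q, card)
    draws = torch.multinomial(p, num_samples=n_draw, replacement=True)  # [bs*n_q, n_draw]
    print(draws.reshape(bs, n_q, n_draw).transpose(1, 2).shape)         # [bs, n_draw, n_q]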
 
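The commit title refers to the guard in generate above: the 2048 sentinel (self.card) marks the start/end triangles of the delay pattern and must not be overwritten by predictions, and the fixed 4:max_tokens+4 slice then de-delays all four codebook rows at once. A standalone sketch of the antidiagonal write, the no-overwrite guard, and the alignment slice (sizes are illustrative; the end-of-sequence guard is a closed-form equivalent of the elif ladder, whose max_tokens+1 branch appears to write outside the extracted slice anyway):

    import torch

    CARD, N_Q, MAX_TOKENS = 2048, 4, 7  # illustrative sizes

    codes = torch.full((N_Q, N_Q + N_Q - 1 + MAX_TOKENS), CARD, dtype=torch.long)
    rows = torch.tensor([0, 1, 2, 3])
    cols = torch.tensor([3, 2, 1, 0])

    for offset in range(MAX_TOKENS + N_Q - 1):
        # stand-in for the LM prediction: row k's t-th token is simply t
        tok = torch.tensor([offset - k for k in range(N_Q)]).clamp(0, MAX_TOKENS - 1)
        # no-overwrite guards: row k only holds real tokens for k <= offset <= MAX_TOKENS + k - 1
        tok[offset + 1:] = CARD                       # leading (TRIL) triangle stays 2048
        tok[:max(0, offset - MAX_TOKENS + 1)] = CARD  # trailing (TRIU) triangle stays 2048
        codes[rows, cols + offset + 1] = tok          # place the pack of 4 on the antidiagonal

    aligned = codes[:, 4:MAX_TOKENS + 4]  # the fixed slice de-delays all 4 rows at once
    print(aligned)  # each row reads 0..MAX_TOKENS-1; no 2048 sentinel left inside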
audiocraft/transformer.py CHANGED
@@ -180,21 +180,14 @@ class StreamingTransformer(nn.Module):
                 x,
                 token_count=None,
                 cross_attention_src=None):
-
         B, T, C = x.shape
-
-
         if self.positional_embedding in ['sin', 'sin_rope']:
-
             positions = torch.arange(T, device=x.device).view(1, -1, 1)
-            positions = positions + token_count  # offsets.view(-1, 1, 1)
-            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+            pos_emb = create_sin_embedding(positions + token_count, C, max_period=self.max_period, dtype=x.dtype)
             x = x + pos_emb
-
-
-
         for j, lay in enumerate(self.layers):
-            # print(f'Transf Layer{j} {pos_emb.sum()=} {pos_emb.shape=}{x.shape=}___________________')
-            x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond
-            # each layer (mha) keeps history of its own k, v for all tokens
+            # print(f'Transf Layer c{j} {pos_emb.sum()=} {pos_emb.shape=}{x.sum()=}___________________')
+            x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond x audio
+            # self attn = audio x audio
+            # every layer (mha) keeps its own kv cache
        return x
 
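Folding token_count into positions before the sinusoidal embedding is what gives each incrementally decoded token its absolute position: with the per-layer k/v cache, each step feeds T == 1, so torch.arange(T) alone would always yield position 0. A sketch in the spirit of audiocraft's create_sin_embedding (half cos / half sin; the helper internals here are an assumption, not the repo's exact code):

    import torch

    def sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
        # positions: [B, T, 1] -> [B, T, dim], concatenated cos/sin halves
        half = dim // 2
        adim = torch.arange(half, device=positions.device).view(1, 1, -1)
        phase = positions / (max_period ** (adim / (half - 1)))
        return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

    token_count = 5                                       # illustrative decode step
    positions = torch.arange(1).view(1, -1, 1) + token_count  # T == 1 during streaming
    pos_emb = sin_embedding(positions.float(), dim=64)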
models.py CHANGED
@@ -357,8 +357,7 @@ class DurationEncoder(nn.Module):
 
         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
-
-                print(f'\n=========ENTER ADALAYNORM L479 models.py {x.shape=}, {style.shape=}')
+                # the non-LSTM blocks enter here
                 x = block(x, style)  # [bs, 75, 512]
                 x = torch.cat([x.transpose(1, 2), style], axis=1)  # [bs, 512, 75]
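For context on the branch above: AdaLayerNorm modulates a plain LayerNorm with a per-channel scale and shift predicted from the style vector. A hedged sketch of that pattern (the class name, dimensions, and the 1 + gamma convention follow StyleTTS2-like implementations; the actual class in models.py may differ):

    import torch
    import torch.nn as nn

    class AdaLayerNormSketch(nn.Module):
        # style-adaptive LayerNorm: style predicts gamma/beta applied after normalisation
        def __init__(self, style_dim: int, channels: int):
            super().__init__()
            self.fc = nn.Linear(style_dim, channels * 2)
            self.norm = nn.LayerNorm(channels, elementwise_affine=False)

        def forward(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
            # x: [bs, T, channels], style: [bs, style_dim]
            gamma, beta = self.fc(style).chunk(2, dim=-1)  # [bs, channels] each
            return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)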