Dionyssos committed on
Commit 376e6a0 · 1 Parent(s): b399825

antidiagonal fast delay pattern
audiocraft/builders.py CHANGED
@@ -64,14 +64,14 @@ class AudioGen(nn.Module):
          with torch.no_grad():
              gen_tokens = self.lm.generate(
                  descriptions=[descriptions]*3,
-                 max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+                 max_tokens=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
          x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
 
-         x = x[:, 0, :-250]  # the last samples carry splash sounds - discard the tail
+         x = x[:, 0, :]  # keep the full waveform (tail no longer discarded)
 
          # AudioGen 16 kHz / StyleTTS2 24 kHz / MMS-TTS 24 kHz
 
-         # x = self.resample_fn(x)
+         x = self.resample_fn(x)
 
          # batch size = different sounds for the same txt
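The resampling call is now enabled because AudioGen decodes at 16 kHz while the downstream voices run at 24 kHz. The diff does not show how self.resample_fn is constructed, so the torchaudio transform below is an assumption - a minimal sketch of what it plausibly does:

import torch
import torchaudio

# Assumed stand-in for self.resample_fn: 16 kHz (AudioGen) -> 24 kHz (StyleTTS2 / MMS-TTS).
resample_fn = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)

x = torch.randn(3, 11840)   # mock decoder output: [bs, samples] at 16 kHz
x = resample_fn(x)          # [3, 17760] waveform at 24 kHz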
audiocraft/codebooks_patterns.py DELETED
@@ -1,285 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from collections import namedtuple
- from dataclasses import dataclass
-
- import logging
- import typing as tp
- import torch
-
- LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
- PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class Pattern:
-     """Base implementation of a pattern over a sequence with multiple codebooks.
-
-     The codebook pattern consists of a layout, defining for each sequence step
-     the list of coordinates of each codebook timestep in the resulting interleaved sequence.
-     The first item of the pattern is always an empty list in order to properly insert a special token
-     to start with. For convenience, we also keep track of ``n_q``, the number of codebooks used for the pattern,
-     and ``timesteps``, the number of timesteps corresponding to the original sequence.
-
-     The pattern provides convenient methods to build and revert interleaved sequences from it:
-     ``build_pattern_sequence`` maps a dense input tensor of a multi-codebook sequence from [B, K, T]
-     to the interleaved sequence of shape [B, K, S] by applying the pattern, with B being the batch size,
-     K the number of codebooks, T the number of original timesteps and S the number of sequence steps
-     of the output sequence. The unfilled positions are replaced with a special token and the built sequence
-     is returned along with a mask indicating valid tokens.
-     ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
-     of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
-     to fill and specify invalid positions if needed.
-     See the dedicated methods for more details.
-     """
-     # Pattern layout, for each sequence step, we have a list of coordinates
-     # corresponding to the original codebook timestep and position.
-     # The first list is always an empty list in order to properly insert
-     # a special token to start with.
-     layout: PatternLayout
-     timesteps: int
-     n_q: int
-
-     def __post_init__(self):
-         # assert len(self.layout) > 0
-         # self._validate_layout()
-         self._build_reverted_sequence_scatter_indexes = self._build_reverted_sequence_scatter_indexes
-         self._build_pattern_sequence_scatter_indexes = self._build_pattern_sequence_scatter_indexes
-         print("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
-
-     @property
-     def max_delay(self):
-         max_t_in_seq_coords = 0
-         for seq_coords in self.layout[1:]:
-             for coords in seq_coords:
-                 max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
-         return max_t_in_seq_coords - self.timesteps
-
-     @property
-     def valid_layout(self):
-         valid_step = len(self.layout) - self.max_delay
-         return self.layout[:valid_step]
-
-     def starts_with_special_token(self):
-         return self.layout[0] == []
-
-     def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
-         """Get codebook coordinates in the layout that correspond to the specified timestep t
-         and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
-         and the actual codebook coordinates.
-         """
-         assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
-         if q is not None:
-             assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
-         coords = []
-         for s, seq_codes in enumerate(self.layout):
-             for code in seq_codes:
-                 if code.t == t and (q is None or code.q == q):
-                     coords.append((s, code))
-         return coords
-
-     def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
-         return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
-
-     def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
-         steps_with_timesteps = self.get_steps_with_timestep(t, q)
-         return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
-
-     def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
-                                                 device: tp.Union[torch.device, str] = 'cpu'):
-         """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
-
-         Args:
-             timesteps (int): Maximum number of timesteps to consider.
-             keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
-             device (torch.device or str): Device for created tensors.
-         Returns:
-             indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
-             mask (torch.Tensor): Mask corresponding to indexes that match valid indexes, of shape [K, S].
-         """
-         assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
-         assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
-         # use the proper layout based on whether we limit ourselves to valid steps only or not,
-         # note that using the valid_layout will result in a truncated sequence up to the valid steps
-         ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
-         # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
-         indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
-         mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
-         # fill indexes with the last sequence step value, which will correspond to our special token:
-         # the last value is n_q * timesteps as we have flattened z and appended the special token as the last token,
-         # which will correspond to the index n_q * timesteps
-         indexes[:] = n_q * timesteps
-         # iterate over the pattern and fill scattered indexes and mask
-         for s, sequence_coords in enumerate(ref_layout):
-             for coords in sequence_coords:
-                 if coords.t < timesteps:
-                     indexes[coords.q, s] = coords.t + coords.q * timesteps
-                     mask[coords.q, s] = 1
-         indexes = torch.from_numpy(indexes).to(device)
-         mask = torch.from_numpy(mask).to(device)
-         return indexes, mask
-
-     def build_pattern_sequence(self,
-                                z,
-                                special_token,
-                                keep_only_valid_steps=False):
-         B, K, T = z.shape
-         indexes, mask = self._build_pattern_sequence_scatter_indexes(
-             T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
-         )
-         z = z.view(B, -1)
-         # we append the special token as the last index of our flattened z tensor
-         z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
-         values = z[:, indexes.view(-1)]
-         values = values.view(B, K, indexes.shape[-1])
-
-         # print(values.shape, indexes.shape, mask.shape, 'BUILD PATTERN')
-         # --
-         # torch.Size([1, 4, 39]) torch.Size([4, 39]) torch.Size([4, 39]) BUILD PATTERN
-
-         return values, indexes, mask
-
-     def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
-                                                  keep_only_valid_steps: bool = False,
-                                                  is_model_output: bool = False,
-                                                  device: tp.Union[torch.device, str] = 'cpu'):
-         """Builds scatter indexes required to retrieve the original multi-codebook sequence
-         from the interleaving pattern.
-
-         Args:
-             sequence_steps (int): Sequence steps.
-             n_q (int): Number of codebooks.
-             keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
-                 Steps that are beyond valid steps will be replaced by the special_token in that case.
-             is_model_output (bool): Whether to keep the sequence item corresponding to the initial special token or not.
-             device (torch.device or str): Device for created tensors.
-         Returns:
-             indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
-             mask (torch.Tensor): Mask corresponding to indexes that match valid indexes, of shape [K, T].
-         """
-         ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
-         # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
-         timesteps = self.timesteps
-         assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
-         assert sequence_steps <= len(ref_layout), \
-             f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
-
-         # ensure we take the appropriate indexes to keep the model output from the first special token as well
-         if is_model_output and self.starts_with_special_token():
-             ref_layout = ref_layout[1:]
-
-         # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
-         indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
-         mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
-         # fill indexes with the last sequence step value, which will correspond to our special token
-         indexes[:] = n_q * sequence_steps
-         for s, sequence_codes in enumerate(ref_layout):
-             if s < sequence_steps:
-                 for code in sequence_codes:
-                     if code.t < timesteps:
-                         indexes[code.q, code.t] = s + code.q * sequence_steps  # the codes are linearised per codebook, hence the q * sequence_steps jump
-                         mask[code.q, code.t] = 1
-         indexes = torch.from_numpy(indexes).to(device)
-         mask = torch.from_numpy(mask).to(device)
-         return indexes, mask
-
-     def revert_pattern_sequence(self,
-                                 s,
-                                 special_token,
-                                 keep_only_valid_steps=False):
-         """NOTE: the special token is NOT deleted here!
-
-         Args:
-             s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
-             special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
-         Returns:
-             values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T].
-             indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
-             mask (torch.Tensor): Mask corresponding to indexes that match valid indexes, of shape [K, T];
-                 this mask can be used to delete the special token id.
-         """
-         B, K, S = s.shape
-         indexes, mask = self._build_reverted_sequence_scatter_indexes(
-             S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
-         )
-         s = s.view(B, -1)
-         # we append the special token as the last index of our flattened s tensor
-         s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
-         values = s[:, indexes.view(-1)]
-         values = values.view(B, K, indexes.shape[-1])
-
-         return values, indexes, mask
-
-
- class DelayedPatternProvider():
-     """Provider for the delayed pattern across delayed codebooks.
-     Codebooks are delayed in the sequence and sequence steps will contain codebooks
-     from different timesteps.
-
-     Example:
-         Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
-         [[1, 2, 3, 4],
-          [1, 2, 3, 4],
-          [1, 2, 3, 4]]
-         The resulting sequence obtained from the returned pattern is:
-         [[S, 1, 2, 3, 4],
-          [S, S, 1, 2, 3],
-          [S, S, S, 1, 2]]
-         (with S being a special token)
-
-     Args:
-         n_q (int): Number of codebooks.
-         delays (list of int, optional): Delay for each of the codebooks.
-             If delays are not defined, each codebook is delayed by 1 compared to the previous one.
-         flatten_first (int): Flatten the first N timesteps.
-         empty_initial (int): Prepend N empty lists of coordinates.
-     """
-     def __init__(self,
-                  n_q,
-                  delays,
-                  flatten_first=0,
-                  empty_initial=0):
-         self.n_q = n_q
-         if delays is None:
-             delays = list(range(n_q))
-         print(f'{delays=} PATTERN __init__')
-         self.delays = delays
-         self.flatten_first = flatten_first
-         self.empty_initial = empty_initial
-         assert len(self.delays) == self.n_q
-         assert sorted(self.delays) == self.delays
-
-     def get_pattern(self, timesteps):
-         # build the pattern for the desired length
-         # print(f'{timesteps=} GET_PATTERN')  # 35
-         omit_special_token = self.empty_initial < 0  # False, as empty_initial = 0 when unset
-
-         out: PatternLayout = [] if omit_special_token else [[]]
-         max_delay = max(self.delays)
-         if self.empty_initial:
-             out += [[] for _ in range(self.empty_initial)]
-         if self.flatten_first:
-             for t in range(min(timesteps, self.flatten_first)):
-                 for q in range(self.n_q):
-                     out.append([LayoutCoord(t, q)])
-         for t in range(self.flatten_first, timesteps + max_delay):
-             v = []
-             for q, delay in enumerate(self.delays):
-                 t_for_q = t - delay
-                 if t_for_q >= self.flatten_first:
-                     v.append(LayoutCoord(t_for_q, q))
-             out.append(v)
-         # print(self.n_q, 'N_Q in PATTERN')  # 4 N_Q in PATTERN
-         return Pattern(out, n_q=self.n_q, timesteps=timesteps)
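For reference, the delayed layout that the deleted DelayedPatternProvider produced (per its docstring example) can be reproduced with plain tensor ops. A small illustrative sketch (S = special token; the code values are made up):

import torch

# Rebuild the delayed layout from the docstring example: timesteps=4, n_q=3, delays=[0, 1, 2].
n_q, T, S = 3, 4, 2048
delays = list(range(n_q))
codes = torch.arange(1, T + 1).repeat(n_q, 1)    # [[1,2,3,4],[1,2,3,4],[1,2,3,4]]
seq = torch.full((n_q, 1 + T + max(delays)), S)  # leading column for the special start token
for q, d in enumerate(delays):
    seq[q, 1 + d : 1 + d + T] = codes[q]
print(seq)
# tensor([[2048,    1,    2,    3,    4, 2048, 2048],
#         [2048, 2048,    1,    2,    3,    4, 2048],
#         [2048, 2048, 2048,    1,    2,    3,    4]])

The rewritten lm.py below drops this machinery: instead of building and reverting a delayed copy of the codes, it stores the codes un-delayed and gathers the delayed view on the fly.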
audiocraft/lm.py CHANGED
@@ -2,7 +2,6 @@ import torch
  import torch.nn.functional as F
  from audiocraft.transformer import StreamingTransformer
  from torch import nn
- from audiocraft.codebooks_patterns import DelayedPatternProvider
  from audiocraft.conditioners import T5Conditioner
  import numpy as np
 
@@ -26,7 +25,6 @@ class LMModel(nn.Module):
          embed_dim = self.card + 1
          self.n_q = n_q
          self.dim = dim
-         self.pattern_provider = DelayedPatternProvider()
          self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)])  # embedding has 2049 entries
          self.transformer = StreamingTransformer(
              d_model=dim,
 
@@ -45,14 +43,18 @@
                  sequence,
                  condition_tensors=None,
                  token_count=None):
-         # takes bs=3, duplicates the null condition to bs=6, splits the logits for CFG, returns bs=3
-         bs, _, _ = sequence.shape  # sequence [bs, n_draw, 4]
+ 
+         bs, n_q, time_frames = sequence.shape  # [bs, 4, time]
+ 
          input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
+ 
+         # duplicate for the null condition (bs x 2)
+ 
          out = self.transformer(torch.cat([input_, input_], 0),
                                 cross_attention_src=condition_tensors,
-                                token_count=token_count)
+                                token_count=token_count
+                                )
+ 
          if self.out_norm:
              out = self.out_norm(out)
 
@@ -79,7 +81,7 @@
      @torch.no_grad()
      def generate(self,
                   descriptions=['windy day', 'rain storm'],
-                  max_gen_len=256):
+                  max_tokens=256):
 
          text_condition = self.condition_provider(descriptions)
 
@@ -95,64 +97,66 @@
 
-         pattern = self.pattern_provider.get_pattern(max_gen_len)
-         gen_codes = torch.full((bs,
-                                 self.n_q,
-                                 max_gen_len), -1, dtype=torch.long,
-                                device=text_condition.device)
-
-         gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.card)
-         _, _, audiodur = gen_sequence.shape  # bs, 4, 7=audiodur
-
-         # print(gen_sequence.shape, mask.shape, 'F')  # mask has no batch dim = [4, audio_duration]
-         # print(f'{mask=}')
-         #
-         # torch.Size([3, 4, 7]) torch.Size([4, 7]) F
-         # mask=tensor([[False,  True,  True,  True, False, False, False],
-         #              [False, False,  True,  True,  True, False, False],
-         #              [False, False, False,  True,  True,  True, False],
-         #              [False, False, False, False,  True,  True,  True]], device='cuda:0')
-
-         mask = mask[None, None, :, :].repeat(bs, self.n_draw, 1, 1)  # [bs, n_draw, 4, audio duration]
-         gen_sequence = gen_sequence[:, None, :, :].repeat(1, self.n_draw, 1, 1)  # bs, n_draw, 4, dur
-
-         for offset in range(1, audiodur):
-
-             # forward duplicates the query to the null condition, then applies CFG and returns the deduplicated token
-             next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],  # gen_sequence is indexed by vertical column and fed to the LM, but the LM's prediction is placed diagonally (with delay) back into gen_sequence; gen_sequence has to be un-delayed at the end (for the vocoder), so it is only the LM input that needs to see the delay - we could just feed the diagonal gather a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5]
+         out_codes = torch.full((bs, self.n_draw, 4, 4 + max_tokens),  # 4 + max_tokens so there is enough room to index the first 4x4 antidiagonal
+                                self.card,
+                                dtype=torch.long, device=text_condition.device)  # bs, n_draw, 4, dur
+ 
+         for offset in range(0, max_tokens):
+ 
+             # gen_sequence used to carry fillers (2048) at start & end, e.g. [6, 4, 74]:
+             #
+             # gen_sequence = torch.tensor([[[2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048, 2048],
+             #                               [2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048],
+             #                               [2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048],
+             #                               [2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6]]])
+             #
+             # out_codes is stored un-delayed:
+             #
+             # tensor([[0, 1, 2, 3, 4, 5, 6],
+             #         [0, 1, 2, 3, 4, 5, 6],
+             #         [0, 1, 2, 3, 4, 5, 6],
+             #         [0, 1, 2, 3, 4, 5, 6]])
+             #
+             # The LM "sees" 4 delayed tokens (an antidiagonal extract).
+             #
+             # So the first pack of 4 tokens fed to the LM is [2048, 2048, 2048, 2048]; starting from
+             #
+             # 2048 2048 2048 2048
+             # 2048 2048 2048 2048
+             # 2048 2048 2048 2048
+             # 2048 2048 2048 2048
+             #
+             # once the LM fills the next column, e.g.
+             #
+             # 2048 2048 2048 2048   10
+             # 2048 2048 2048 2048   20
+             # 2048 2048 2048 2048   50
+             # 2048 2048 2048 2048    7
+             #
+             # the 2nd pack of 4 tokens fed to the LM is the antidiagonal [10, 20, 50, 7],
+             # extracted via out_codes[[0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset].
+             #
+             # forward duplicates the query to the null condition, then applies CFG and returns the deduplicated token;
+             # only draw 0 (the 1st of n_draw) is continued by the LM call - the rest are supersampled in torch.multinomial
+ 
+             # feed the antidiagonal to the LM
+             next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index the antidiagonal & expand to [bs, n_q, dur=1]
                  condition_tensors=text_condition,  # uses the attention mask of the text condition
                  token_count=offset-1)  # [bs, 4, 1, 2048]
 
-             # the mask is not all ones - it carries the 4 x audio-duration pattern
-             m = mask[:, :, :, offset]
-             next_token[~m] = self.card
-             gen_sequence[:, :, :, offset] = torch.where(
-                 gen_sequence[:, :, :, offset] == -1,  # unknown_token
-                 next_token,
-                 gen_sequence[:, :, :, offset]
-             )
-
-         # 1. reshape n_draw as bs * n_draw
-         # 2. invert all short sequences
-         # 3. reshape bs * n_draw -> bs, n_draw * audiodur (elongation)
-         out_codes = pattern.revert_pattern_sequence(
-             gen_sequence.reshape(bs * self.n_draw, 4, audiodur),  # [3, 8, 4, 7]
-             special_token=-1)
-         # print(f'{gen_sequence.shape=} {out_codes.shape=}')  # does revert_pattern reduce the duration?
-         _, _, new_len = out_codes.shape  # 4 is preserved after revert!
-         out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
-         out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
-         print(out_codes.shape, 'o')
+             out_codes[:, :, :, offset + 4] = next_token  # [bs, n_draw, 4, duration]
+ 
+         # discard the 4 leading fill (2048) frames
+         out_codes = out_codes[:, :, :, 4:].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration * n_draw]
 
          # Clear the k/v cache (a different k/v is saved by each of the 48 self-attn layers)
          for lay in self.transformer.layers:
              lay.self_attn.k_history = None
              lay.self_attn.v_history = None
 
-         return out_codes  # bs*n_draw, duration -> repeat/shift in api.py
+         return out_codes  # [bs, 4, n_draw * duration], the 4 fill frames already skipped -> repeat/shift in api.py
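To see why the gather indices [3, 2, 1, 0] + offset realize the delay pattern without the build/revert passes, here is a self-contained sketch of the loop's indexing (a constant stands in for the CFG-sampled LM token; shapes follow the diff, with n_draw omitted):

import torch

bs, n_q, max_tokens, FILL = 1, 4, 6, 2048
buf = torch.full((bs, n_q, 4 + max_tokens), FILL, dtype=torch.long)  # un-delayed code buffer

q_idx = torch.tensor([0, 1, 2, 3])
t_idx = torch.tensor([3, 2, 1, 0])  # antidiagonal: codebook q reads column offset + (3 - q)

for offset in range(max_tokens):
    lm_input = buf[:, q_idx, t_idx + offset]    # [bs, 4]; at offset=0 this is [2048, 2048, 2048, 2048]
    next_token = torch.full((bs, n_q), offset)  # stand-in for the sampled token
    buf[:, :, offset + 4] = next_token          # write one un-delayed column

print(buf[0, :, 4:])  # every codebook holds [0, 1, 2, 3, 4, 5]: the delay only ever exists in the gather

Codebook 0 reads the column written at the previous step, while codebook 3 reads a column written four steps earlier - exactly the staircase that the deleted DelayedPatternProvider encoded as explicit LayoutCoord lists.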