antidiagonal fast delay pattern
Browse files- audiocraft/builders.py +3 -3
- audiocraft/codebooks_patterns.py +0 -285
- audiocraft/lm.py +57 -53
audiocraft/builders.py
CHANGED
@@ -64,14 +64,14 @@ class AudioGen(nn.Module):
|
|
64 |
with torch.no_grad():
|
65 |
gen_tokens = self.lm.generate(
|
66 |
descriptions=[descriptions]*3,
|
67 |
-
|
68 |
x = self.compression_model.decode(gen_tokens, None) #[bs, 1, 11840]
|
69 |
|
70 |
-
x = x[:, 0,
|
71 |
|
72 |
# AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
|
73 |
|
74 |
-
|
75 |
|
76 |
# batch size = different sounds for same txt
|
77 |
|
|
|
64 |
with torch.no_grad():
|
65 |
gen_tokens = self.lm.generate(
|
66 |
descriptions=[descriptions]*3,
|
67 |
+
max_tokens=int(self.duration * self.frame_rate)) # [bs, 4, 37 * self.lm.n_draw]
|
68 |
x = self.compression_model.decode(gen_tokens, None) #[bs, 1, 11840]
|
69 |
|
70 |
+
x = x[:, 0, :] # last samples have splash sounds DISCARD 25000 last samples
|
71 |
|
72 |
# AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
|
73 |
|
74 |
+
x = self.resample_fn(x)
|
75 |
|
76 |
# batch size = different sounds for same txt
|
77 |
|
audiocraft/codebooks_patterns.py
DELETED
@@ -1,285 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from collections import namedtuple
|
8 |
-
from dataclasses import dataclass
|
9 |
-
|
10 |
-
import logging
|
11 |
-
import typing as tp
|
12 |
-
import torch
|
13 |
-
|
14 |
-
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
15 |
-
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
16 |
-
logger = logging.getLogger(__name__)
|
17 |
-
|
18 |
-
|
19 |
-
@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists of a layout that defines, for each sequence step,
    the list of coordinates ``(t, q)`` of the codebook timesteps placed at that step
    of the interleaved sequence. The first item of the layout is expected to be an
    empty list so that a special token can be inserted at the start (see
    ``starts_with_special_token``). For convenience, the pattern also keeps track of
    ``n_q``, the number of codebooks, and ``timesteps``, the number of timesteps of
    the original (non-interleaved) sequence.

    The pattern provides methods to build and revert interleaved sequences:

    ``build_pattern_sequence`` maps a dense multi-codebook tensor of shape [B, K, T]
    to the interleaved sequence of shape [B, K, S], with B the batch size, K the
    number of codebooks, T the number of original timesteps and S the number of
    sequence steps. Unfilled positions are replaced with a special token, and the
    built sequence is returned along with a mask indicating valid tokens.

    ``revert_pattern_sequence`` maps an interleaved sequence of shape [B, K, S] back
    to the original codebook alignment, producing a tensor of shape [B, K, T], again
    using a special token and a mask to fill and flag invalid positions.

    Attributes are documented inline below.
    """
    # For each sequence step, the list of coordinates (original timestep, codebook
    # index) placed at that step. The first entry is expected to be an empty list
    # (room for the initial special token).
    layout: PatternLayout
    # Number of timesteps of the original, non-interleaved sequence.
    timesteps: int
    # Number of codebooks.
    n_q: int

    def __post_init__(self):
        # The layout is not validated here.
        # The message is formatted explicitly: ``print`` does not apply %-style
        # arguments the way ``logging`` does.
        print(f"New pattern, time steps: {self.timesteps}, sequence steps: {len(self.layout)}")

    @property
    def max_delay(self):
        """Return how many timesteps the layout reaches beyond ``self.timesteps``.

        Computed as (largest ``t`` found in ``layout[1:]``) + 1 - ``timesteps``.
        """
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        """Return the layout truncated by ``max_delay`` trailing steps."""
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def starts_with_special_token(self):
        """Return True when the first sequence step holds no coordinates."""
        return self.layout[0] == []

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get the layout coordinates that correspond to timestep ``t``.

        Args:
            t (int): Original timestep to look up.
            q (int, optional): If given, restrict the result to this codebook index.
        Returns:
            list of ``(sequence_step, coordinate)`` tuples, in layout order.
        Raises:
            AssertionError: If ``t`` exceeds ``self.timesteps`` or ``q`` exceeds ``self.n_q``.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        """Return the sequence steps at which timestep ``t`` (and codebook ``q``, if given) appears."""
        return [step for step, _ in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        """Return the first sequence step containing timestep ``t``, or None if there is none."""
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if steps_with_timesteps else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build the gather indexes that turn a flattened [K * T] tensor into the interleaved layout.

        Args:
            timesteps (int): Maximum number of timesteps to consider.
            n_q (int): Number of codebooks; must equal ``self.n_q``.
            keep_only_valid_steps (bool): Restrict the pattern layout to ``valid_layout``.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Boolean mask flagging positions filled from the layout, of shape [K, S].
        Raises:
            AssertionError: If ``n_q`` differs from ``self.n_q`` or ``timesteps`` exceeds ``self.timesteps``.
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # Using valid_layout results in a sequence truncated to the valid steps.
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # Single-item indexing is much slower on torch tensors than on numpy arrays,
        # so the tables are filled through numpy views.
        # Every position defaults to n_q * timesteps: the index of the special token
        # that the caller appends after the flattened [K * T] values.
        indexes = torch.full((n_q, len(ref_layout)), n_q * timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    # Codebook q occupies the slice [q * timesteps, (q + 1) * timesteps) once flattened.
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = True
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self,
                               z,
                               special_token,
                               keep_only_valid_steps=False):
        """Build the interleaved sequence from a dense multi-codebook tensor.

        Args:
            z (torch.Tensor): Input tensor of shape [B, K, T].
            special_token (int): Value used for positions the layout does not fill.
            keep_only_valid_steps (bool): Restrict the pattern layout to ``valid_layout``.
        Returns:
            values (torch.Tensor): Interleaved sequence, of shape [B, K, S].
            indexes (torch.Tensor): Gather indexes used to build it, of shape [K, S].
            mask (torch.Tensor): Boolean mask of positions filled from the layout, of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        # reshape (rather than view) also accepts non-contiguous inputs.
        z = z.reshape(B, -1)
        # Append the special token as the last entry of the flattened tensor so
        # that the default index (K * T) resolves to it.
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Build the gather indexes required to retrieve the original multi-codebook
        sequence from the interleaved one.

        Args:
            sequence_steps (int): Number of steps of the interleaved sequence.
            n_q (int): Number of codebooks; must equal ``self.n_q``.
            keep_only_valid_steps (bool): Restrict the pattern layout to ``valid_layout``.
            is_model_output (bool): When True and the layout starts with the empty
                special-token step, that first step is dropped before indexing.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Boolean mask of positions filled from the layout, of shape [K, T].
        Raises:
            AssertionError: If ``n_q`` differs from ``self.n_q`` or ``sequence_steps``
                exceeds the length of the layout in use.
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # Skip the leading empty step so that step indexes line up with a sequence
        # that does not contain the initial special token.
        if is_model_output and self.starts_with_special_token():
            ref_layout = ref_layout[1:]

        # Single-item indexing is much slower on torch tensors than on numpy arrays.
        # Every position defaults to n_q * sequence_steps: the index of the special
        # token that the caller appends after the flattened [K * S] values.
        indexes = torch.full((n_q, timesteps), n_q * sequence_steps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        # Codebook q occupies [q * sequence_steps, (q + 1) * sequence_steps) once flattened.
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = True
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self,
                                s,
                                special_token,
                                keep_only_valid_steps=False):
        """Revert an interleaved sequence back to the original codebook alignment.

        Positions of the output that the layout does not cover are filled with
        ``special_token``; they are not removed. Use the returned mask to tell them apart.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Value used to fill positions not covered by the pattern.
            keep_only_valid_steps (bool): Restrict the pattern layout to ``valid_layout``.
        Returns:
            values (torch.Tensor): Re-aligned sequence, of shape [B, K, T] with T = ``self.timesteps``.
            indexes (torch.Tensor): Gather indexes used to build it, of shape [K, T].
            mask (torch.Tensor): Boolean mask of positions filled from the layout, of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        # reshape (rather than view) also accepts non-contiguous inputs.
        s = s.reshape(B, -1)
        # Append the special token as the last entry of the flattened tensor so
        # that the default index (K * S) resolves to it.
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
class DelayedPatternProvider():
    """Provider for delayed pattern across delayed codebooks.

    Codebooks are delayed in the sequence, so a sequence step contains codebooks
    coming from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        The first steps of the sequence obtained from the returned pattern are:
        [[S, 1, 2, 3, 4],
         [S, S, 1, 2, 3],
         [S, S, S, 1, 2]]
        (with S being a special token). The layout continues for ``max(delays)``
        further steps so that the delayed codebooks can reach their last timestep.

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks, in
            non-decreasing order. If not given, each codebook is delayed by 1
            compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty lists of coordinates. A negative
            value omits the initial empty step reserved for the special token.
    Raises:
        AssertionError: If ``delays`` does not hold exactly ``n_q`` values or is not sorted.
    """
    def __init__(self,
                 n_q,
                 delays=None,
                 flatten_first=0,
                 empty_initial=0):
        self.n_q = n_q
        if delays is None:
            delays = list(range(n_q))
            print(f'{delays=} PATTERN __ini')
        # Normalise to a list so that tuples (or other sequences) pass the
        # sortedness check below, which compares against a list.
        self.delays = list(delays)
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps):
        """Build the delayed ``Pattern`` for a sequence of ``timesteps`` timesteps.

        Args:
            timesteps (int): Number of timesteps of the original sequence.
        Returns:
            Pattern: Pattern whose layout holds, in order: the initial empty step
            (unless omitted), ``empty_initial`` extra empty steps, one step per
            (timestep, codebook) pair for the flattened timesteps, then one step
            per delayed timestep up to ``timesteps + max(delays)``.
        """
        # A negative empty_initial drops the leading empty (special token) step.
        omit_special_token = self.empty_initial < 0

        out: PatternLayout = [] if omit_special_token else [[]]
        max_delay = max(self.delays)
        if self.empty_initial:
            out += [[] for _ in range(self.empty_initial)]
        if self.flatten_first:
            # Flattened timesteps get one sequence step per codebook.
            for t in range(min(timesteps, self.flatten_first)):
                for q in range(self.n_q):
                    out.append([LayoutCoord(t, q)])
        for t in range(self.flatten_first, timesteps + max_delay):
            # At step t, codebook q shows its timestep t - delay, once that
            # timestep is past the flattened region.
            out.append([
                LayoutCoord(t - delay, q)
                for q, delay in enumerate(self.delays)
                if t - delay >= self.flatten_first
            ])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/lm.py
CHANGED
@@ -2,7 +2,6 @@ import torch
|
|
2 |
import torch.nn.functional as F
|
3 |
from audiocraft.transformer import StreamingTransformer
|
4 |
from torch import nn
|
5 |
-
from audiocraft.codebooks_patterns import DelayedPatternProvider
|
6 |
from audiocraft.conditioners import T5Conditioner
|
7 |
import numpy as np
|
8 |
|
@@ -26,7 +25,6 @@ class LMModel(nn.Module):
|
|
26 |
embed_dim = self.card + 1
|
27 |
self.n_q = n_q
|
28 |
self.dim = dim
|
29 |
-
self.pattern_provider = DelayedPatternProvider()
|
30 |
self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
31 |
self.transformer = StreamingTransformer(
|
32 |
d_model=dim,
|
@@ -45,14 +43,18 @@ class LMModel(nn.Module):
|
|
45 |
sequence,
|
46 |
condition_tensors=None,
|
47 |
token_count=None):
|
48 |
-
# takes bs=3 duplicates null condition to bs=6 splits logits to cfg returns bs=3
|
49 |
|
50 |
-
bs,
|
51 |
|
52 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
|
|
|
|
|
|
|
53 |
out = self.transformer(torch.cat([input_, input_], 0),
|
54 |
cross_attention_src=condition_tensors,
|
55 |
-
token_count=token_count
|
|
|
|
|
56 |
if self.out_norm:
|
57 |
out = self.out_norm(out)
|
58 |
|
@@ -79,7 +81,7 @@ class LMModel(nn.Module):
|
|
79 |
@torch.no_grad()
|
80 |
def generate(self,
|
81 |
descriptions = ['windy day', 'rain storm'],
|
82 |
-
|
83 |
|
84 |
text_condition = self.condition_provider(descriptions)
|
85 |
|
@@ -95,64 +97,66 @@ class LMModel(nn.Module):
|
|
95 |
|
96 |
|
97 |
|
98 |
-
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
99 |
-
gen_codes = torch.full((bs,
|
100 |
-
self.n_q,
|
101 |
-
max_gen_len), -1, dtype=torch.long,
|
102 |
-
device=text_condition.device)
|
103 |
-
|
104 |
-
gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.card)
|
105 |
-
_, _, audiodur = gen_sequence.shape # bs, 4, 7=audiodur
|
106 |
-
|
107 |
-
# print(gen_sequence.shape, mask.shape, 'F') # mask has no batch = [4,audio_duration]
|
108 |
-
# print(f'{mask=}')
|
109 |
-
#
|
110 |
-
# torch.Size([3, 4, 7]) torch.Size([4, 7]) F
|
111 |
-
# mask=tensor([[False, True, True, True, False, False, False],
|
112 |
-
# [False, False, True, True, True, False, False],
|
113 |
-
# [False, False, False, True, True, True, False],
|
114 |
-
# [False, False, False, False, True, True, True]], device='cuda:0')
|
115 |
-
|
116 |
-
mask = mask[None, None, :, :].repeat(bs, self.n_draw, 1, 1) # [bs, n_draw, 4, audio duration]
|
117 |
-
gen_sequence = gen_sequence[:, None, :, :].repeat(1, self.n_draw, 1, 1) # bs,n_draw,4,dur
|
118 |
-
|
119 |
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
#
|
124 |
-
next_token = self.forward(
|
|
|
125 |
condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
|
126 |
token_count=offset-1) # [bs, 4, 1, 2048]
|
127 |
|
128 |
|
129 |
-
|
|
|
|
|
|
|
130 |
|
131 |
-
|
132 |
-
m = mask[:, :, :, offset]
|
133 |
-
next_token[~m] = self.card
|
134 |
-
gen_sequence[:, :, :, offset] = torch.where(
|
135 |
-
gen_sequence[:, :, :, offset] == -1, #unknown_token,
|
136 |
-
next_token,
|
137 |
-
gen_sequence[:, :, :, offset]
|
138 |
-
)
|
139 |
|
140 |
-
|
141 |
-
# 1. reshape n_draw as bs * n_draw
|
142 |
-
# 2. invert all short-sequences
|
143 |
-
# 3. reshape bs * n_draw -> bs, n_draw * audiodur ELONGATION
|
144 |
-
out_codes = pattern.revert_pattern_sequence(
|
145 |
-
gen_sequence.reshape(bs * self.n_draw, 4, audiodur), # [3,8,4,7]
|
146 |
-
special_token=-1)
|
147 |
-
# print(f'{gen_sequence.shape=} {out_codes.shape=} Ha') # REVERT PATTERN REDUCES DURATION?
|
148 |
-
_, _, new_len = out_codes.shape # 4 IS PRESERVED AFTER REVERT!
|
149 |
-
out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
|
150 |
-
out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
|
151 |
-
print(out_codes.shape, 'o')
|
152 |
|
153 |
# Clear k/v cache (Different kv is saved by every 48x selfattn)
|
154 |
for lay in self.transformer.layers:
|
155 |
lay.self_attn.k_history = None
|
156 |
lay.self_attn.v_history = None
|
157 |
|
158 |
-
return out_codes # bs*n_draw, duration -> repeat/shift in api.py
|
|
|
2 |
import torch.nn.functional as F
|
3 |
from audiocraft.transformer import StreamingTransformer
|
4 |
from torch import nn
|
|
|
5 |
from audiocraft.conditioners import T5Conditioner
|
6 |
import numpy as np
|
7 |
|
|
|
25 |
embed_dim = self.card + 1
|
26 |
self.n_q = n_q
|
27 |
self.dim = dim
|
|
|
28 |
self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
29 |
self.transformer = StreamingTransformer(
|
30 |
d_model=dim,
|
|
|
43 |
sequence,
|
44 |
condition_tensors=None,
|
45 |
token_count=None):
|
|
|
46 |
|
47 |
+
bs, n_q, time_frames = sequence.shape # [bs, 4, time]
|
48 |
|
49 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
|
50 |
+
|
51 |
+
# duplicate null condition (bs x 2)
|
52 |
+
|
53 |
out = self.transformer(torch.cat([input_, input_], 0),
|
54 |
cross_attention_src=condition_tensors,
|
55 |
+
token_count=token_count
|
56 |
+
)
|
57 |
+
|
58 |
if self.out_norm:
|
59 |
out = self.out_norm(out)
|
60 |
|
|
|
81 |
@torch.no_grad()
|
82 |
def generate(self,
|
83 |
descriptions = ['windy day', 'rain storm'],
|
84 |
+
max_tokens = 256):
|
85 |
|
86 |
text_condition = self.condition_provider(descriptions)
|
87 |
|
|
|
97 |
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
out_codes = torch.full((bs, self.n_draw, 4, 4 + max_tokens), # 4 + max_tokens to have sufficient to index the 1st antidiagonal of 4x4
|
102 |
+
self.card,
|
103 |
+
dtype=torch.long,device=text_condition.device) # bs,n_draw,4,dur
|
104 |
+
for offset in range(0, max_tokens):
|
105 |
+
|
106 |
+
# GEN_SEQUENCE has fillers start & end = 2048
|
107 |
+
# [6,4,74] = gen_sequence = torch.tensor([[[
|
108 |
+
# [2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048, 2048],
|
109 |
+
# [2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048],
|
110 |
+
# [2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048],
|
111 |
+
# [2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]],
|
112 |
+
#
|
113 |
+
# out codes = un-delayed
|
114 |
+
#
|
115 |
+
# tensor([[0, 1, 2, 3, 4, 5, 6],
|
116 |
+
# [0, 1, 2, 3, 4, 5, 6],
|
117 |
+
# [0, 1, 2, 3, 4, 5, 6],
|
118 |
+
# [0, 1, 2, 3, 4, 5, 6]])
|
119 |
+
#
|
120 |
+
# LM "sees" 4 delayed tokens (diagonal extract)
|
121 |
+
#
|
122 |
+
# SO THE FIRST pack of 4 tokens fed TO LM is [2048, 2048, 2048, 2048]
|
123 |
+
#
|
124 |
+
# IF WE START WITH
|
125 |
+
# 2048 2048 2048 2048
|
126 |
+
# 2048 2048 2048 2048
|
127 |
+
# 2048 2048 2048 2048
|
128 |
+
# 2048 2048 2048 2048
|
129 |
+
#
|
130 |
+
# THE 2nd token pack of 4 fed to LM is [10, 20, 50, 7]
|
131 |
+
#
|
132 |
+
# 2048 2048 2048 2048 10
|
133 |
+
# 2048 2048 2048 2048 20
|
134 |
+
# 2048 2048 2048 2048 50
|
135 |
+
# 2048 2048 2048 2048 7
|
136 |
+
#
|
137 |
+
# extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
|
138 |
+
#
|
139 |
+
# forward duplicates the query to nullcond - then cfg & returns deduplicate token
|
140 |
+
# only 0 (1st token of n_draw is continued by LM call - rest is supersampled in torch.multinomial)
|
141 |
|
142 |
+
# feeds the antidiagonal to LM
|
143 |
+
next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None], # index diagonal & exapnd to [bs, n_q, dur=1]
|
144 |
+
#gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
|
145 |
condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
|
146 |
token_count=offset-1) # [bs, 4, 1, 2048]
|
147 |
|
148 |
|
149 |
+
|
150 |
+
out_codes[:, :, :, offset + 4] = next_token # [bs, n_draw, 4, duration]
|
151 |
+
|
152 |
+
# DISCARD FILL
|
153 |
|
154 |
+
out_codes = out_codes[:, :, :, 4:].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens) # [bs, 4, duration*n_draw] DISCARD FILL 2048
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
# Clear k/v cache (Different kv is saved by every 48x selfattn)
|
158 |
for lay in self.transformer.layers:
|
159 |
lay.self_attn.k_history = None
|
160 |
lay.self_attn.v_history = None
|
161 |
|
162 |
+
return out_codes # SKIP THE 4 fill 2048 bs*n_draw, duration -> repeat/shift in api.py
|