Commit 12dc48a
hugo flores garcia committed
Parent(s): 2d0bc4e

for use with sound objects

Files changed:
- app.py (+106 -33)
- vampnet/newmask.py (+365 -0)
app.py
CHANGED
@@ -57,6 +57,47 @@ def shift_pitch(signal, interval: int):
     return signal
 
 
+def onsets(sig: at.AudioSignal, hop_length: int):
+    assert sig.batch_size == 1, "batch size must be 1"
+    assert sig.num_channels == 1, "mono signals only"
+    import librosa
+    onset_frame_idxs = librosa.onset.onset_detect(
+        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
+        hop_length=hop_length,
+        backtrack=True,
+    )
+    return onset_frame_idxs
+
+
+def new_vampnet_mask(self,
+    codes,
+    onset_idxs,
+    width: int = 5,
+    periodic_prompt=2,
+    upper_codebook_mask=1,
+    drop_amt: float = 0.1
+):
+    from vampnet.newmask import mask_and, mask_or, onset_mask, periodic_mask, drop_ones, codebook_mask
+    mask = mask_and(
+        periodic_mask(codes, periodic_prompt, 1, random_roll=False),
+        mask_or(  # this re-masks the onsets, according to a periodic schedule
+            onset_mask(onset_idxs, codes, width=width),
+            periodic_mask(codes, periodic_prompt, 1, random_roll=False),
+        )
+    ).int()
+    # make sure the onset idxs themselves are unmasked
+    # mask = 1 - mask
+    mask[:, :, onset_idxs] = 0
+    mask = mask.cpu()  # debug
+    mask = 1 - drop_ones(1 - mask, drop_amt)
+    mask = codebook_mask(mask, upper_codebook_mask)
+
+
+    # save mask as txt (ints)
+    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')
+    mask = mask.to(self.device)
+    return mask[:, :, :]
+
 @spaces.GPU
 def _vamp(
     seed, input_audio, model_choice,
@@ -78,7 +119,7 @@ def _vamp(
     sr, input_audio = input_audio
     input_audio = input_audio / np.iinfo(input_audio.dtype).max
 
-    sig = at.AudioSignal(input_audio, sr)
+    sig = at.AudioSignal(input_audio, sr).to_mono()
 
     # reload the model if necessary
     interface.load_finetuned(model_choice)
@@ -88,18 +129,15 @@ def _vamp(
 
     codes = interface.encode(sig)
 
-    mask = 
-
-
-
-
-        periodic_prompt=
-
-
-
-        upper_codebook_mask=int(n_mask_codebooks),
-    )
-
+    mask = new_vampnet_mask(
+        interface,
+        codes,
+        onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
+        width=onset_mask_width,
+        periodic_prompt=periodic_p,
+        upper_codebook_mask=n_mask_codebooks,
+        drop_amt=dropout
+    ).long()
 
     # save the mask as a txt file
     interface.set_chunk_size(10.0)
@@ -145,24 +183,45 @@ def vamp(data):
         api=False,
     )
 
-def api_vamp(data):
+# def api_vamp(data):
+#     return _vamp(
+#         seed=data[seed],
+#         input_audio=data[input_audio],
+#         model_choice=data[model_choice],
+#         pitch_shift_amt=data[pitch_shift_amt],
+#         periodic_p=data[periodic_p],
+#         n_mask_codebooks=data[n_mask_codebooks],
+#         periodic_w=data[periodic_w],
+#         onset_mask_width=data[onset_mask_width],
+#         dropout=data[dropout],
+#         sampletemp=data[sampletemp],
+#         typical_filtering=data[typical_filtering],
+#         typical_mass=data[typical_mass],
+#         typical_min_tokens=data[typical_min_tokens],
+#         top_p=data[top_p],
+#         sample_cutoff=data[sample_cutoff],
+#         stretch_factor=data[stretch_factor],
+#         api=True,
+#     )
+
+def api_vamp(input_audio, sampletemp, top_p, periodic_p, periodic_w, dropout, stretch_factor, onset_mask_width, typical_filtering, typical_mass, typical_min_tokens, seed, model_choice, n_mask_codebooks, pitch_shift_amt, sample_cutoff):
     return _vamp(
-        seed=
-        input_audio=
-        model_choice=
-        pitch_shift_amt=
-        periodic_p=
-        n_mask_codebooks=
-        periodic_w=
-        onset_mask_width=
-        dropout=
-        sampletemp=
-        typical_filtering=
-        typical_mass=
-        typical_min_tokens=
-        top_p=
-        sample_cutoff=
-        stretch_factor=
+        seed=seed,
+        input_audio=input_audio,
+        model_choice=model_choice,
+        pitch_shift_amt=pitch_shift_amt,
+        periodic_p=periodic_p,
+        n_mask_codebooks=n_mask_codebooks,
+        periodic_w=periodic_w,
+        onset_mask_width=onset_mask_width,
+        dropout=dropout,
+        sampletemp=sampletemp,
+        typical_filtering=typical_filtering,
+        typical_mass=typical_mass,
+        typical_min_tokens=typical_min_tokens,
+        top_p=top_p,
+        sample_cutoff=sample_cutoff,
+        stretch_factor=stretch_factor,
         api=True,
     )
 
@@ -258,7 +317,7 @@ with gr.Blocks() as demo:
         minimum=0,
         maximum=100,
         step=1,
-        value=0, visible=
+        value=0, visible=True
    )
 
    n_mask_codebooks = gr.Slider(
@@ -419,8 +478,22 @@ with gr.Blocks() as demo:
    api_vamp_button = gr.Button("api vamp", visible=True)
    api_vamp_button.click(
        fn=api_vamp,
-       inputs=
-
+       inputs=[input_audio,
+               sampletemp, top_p,
+               periodic_p, periodic_w,
+               dropout,
+               stretch_factor,
+               onset_mask_width,
+               typical_filtering,
+               typical_mass,
+               typical_min_tokens,
+               seed,
+               model_choice,
+               n_mask_codebooks,
+               pitch_shift_amt,
+               sample_cutoff
+       ],
+       outputs=[audio_outs[0]],
        api_name="vamp"
    )
 
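Below is a minimal sketch of calling the new "vamp" endpoint from Python with gradio_client. The Space id is a placeholder, the argument values are illustrative defaults, and the positional order simply mirrors the inputs list wired to api_vamp above; the exact accepted types (e.g. how audio files are passed) depend on the Gradio version the Space runs.

from gradio_client import Client

client = Client("user/space-name")  # placeholder Space id (assumption)
result = client.predict(
    "input.wav",   # input_audio (local file path)
    1.0,           # sampletemp
    0.9,           # top_p
    2,             # periodic_p
    1,             # periodic_w
    0.0,           # dropout
    1,             # stretch_factor
    5,             # onset_mask_width
    False,         # typical_filtering
    0.15,          # typical_mass
    64,            # typical_min_tokens
    0,             # seed
    "default",     # model_choice (placeholder)
    3,             # n_mask_codebooks
    0,             # pitch_shift_amt
    1.0,           # sample_cutoff
    api_name="/vamp",
)
print(result)  # path to the generated audio (audio_outs[0])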
vampnet/newmask.py
ADDED
@@ -0,0 +1,365 @@
+from typing import Optional
+
+import torch
+
+from .util import scalar_to_batch_tensor
+
+def _gamma(r):
+    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
+
+def _invgamma(y):
+    if not torch.is_tensor(y):
+        y = torch.tensor(y)[None]
+    return 2 * y.acos() / torch.pi
+
+def full_mask(x: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    return torch.ones_like(x).int()
+
+def empty_mask(x: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    return torch.zeros_like(x).int()
+
+def apply_mask(
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_token: int
+):
+    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
+    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
+    assert mask.dtype == torch.int, f"mask must be int dtype, but got {mask.dtype}"
+    assert ~torch.any(mask > 1), "mask must be binary"
+    assert ~torch.any(mask < 0), "mask must be binary"
+    mask = mask.int()
+
+    fill_x = torch.full_like(x, mask_token)
+    x = x * (1 - mask) + fill_x * mask
+
+    return x
+
+def random(
+    x: torch.Tensor,
+    r: torch.Tensor
+):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    r = _gamma(r)[:, None, None]
+    probs = torch.ones_like(x) * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask, torch.zeros_like(mask).bool()
+
+def random_along_time(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, channel, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    x = x[:, 0, :]
+    r = _gamma(r)[:, None]
+    probs = torch.ones_like(x) * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask
+
+
+def stemgen_random(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    # Assuming x is your input tensor and r is the probability for the Bernoulli distribution
+    nb, nc, nt = x.shape
+
+    # Randomly sample a codebook level to infer for each item in the batch
+    c = torch.randint(0, nc, (nb,)).to(x.device)
+
+    # Create a mask tensor of the same shape as x, initially filled with ones
+    mask = torch.ones_like(x).long().to(x.device)
+    ignore_indices_mask = torch.zeros_like(x).long().to(x.device)
+
+    # Iterate over each item in the batch
+    for i in range(nb):
+        # Create the Bernoulli mask for the sampled level
+        level_mask = torch.bernoulli(torch.ones(nt).to(x.device) * r[i]).long()
+
+        # Apply the mask to the sampled level
+        mask[i, c[i]] = level_mask
+
+        # All levels below the sampled level are unmasked (zeros)
+        mask[i, :c[i]] = 0
+        ignore_indices_mask[i, :c[i]] = 1
+
+        # All levels above the sampled level are masked (ones)
+        mask[i, c[i]+1:] = 1
+        ignore_indices_mask[i, c[i]+1:] = 1
+
+    # save a debug mask to np txt
+    # import numpy as np
+    # np.savetxt("mask.txt", mask[0].cpu().numpy(), fmt="%d")
+    # np.savetxt("ignore_indices_mask.txt", ignore_indices_mask[0].cpu().numpy(), fmt="%d")
+
+    return mask.int(), ignore_indices_mask.bool()
+
+
+def hugo_random(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+
+    r = _gamma(r)[:, None, None]
+
+    nb, nc, nt = x.shape
+
+    probs = torch.ones_like(x) * r
+    mask = torch.bernoulli(probs)
+    # alternatively, the mask level could be the cumsum of the mask
+    mask = mask.round().long().to(x.device)
+    mask_levels = nc - mask.sum(dim=1) - 1
+
+    # create a new mask, where all levels below the mask level are masked
+    # shape (nb, nc, nt) where new_mask[i, CB:, t] = 1, CB = mask_level[i, t]
+    # mask = mask_levels[:, :, None] > torch.arange(nc)[None, None, :]
+    mask = (mask_levels[:, None, :] < torch.arange(nc, device=x.device)[None, :, None]).long()
+
+    ignore_levels = mask_levels + 1
+    ignore_indices_mask = (ignore_levels[:, None, :] < torch.arange(nc, device=x.device)[None, :, None]).long()
+
+    # for _b in range(nb):
+    #     for _t in range(nt):
+    #         for _c in range(nc):
+    #             if mask[_b, _c, _t] == 1:
+    #                 mask[_b, _c:, _t] = 1
+    #                 ignore_indices_mask[_b, _c + 1:, _t] = 1
+    #                 break
+
+    return mask.long(), ignore_indices_mask.bool()
+
+
+def better_cond_random_but_not_working(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+
+    r = _gamma(r)[:, None, None]
+
+    nb, nc, nt = x.shape
+
+    probs = torch.ones_like(x) * r
+    mask = torch.bernoulli(probs)
+
+    mask = mask.round().long().to(x.device)
+
+    # there cannot be anything unmasked if there's a masked token
+    # in the same timestep and below it
+    for i in range(nb):
+        for j in range(nc):
+            for t in range(nt):
+                if mask[i, j, t] == 1:
+                    mask[i, j:, t] = 1
+                    break
+
+    # an ignore indices mask, since we can truly only predict one token
+    # per timestep
+    ignore_indices = torch.zeros_like(x)
+    for i in range(nb):
+        for j in range(nc):
+            for t in range(nt):
+                if mask[i, j, t] == 1:
+                    ignore_indices[i, j, t+1:] = 1
+                    break
+    return mask.int(), ignore_indices
+
+
+@torch.jit.script_if_tracing
+def linear_random(
+    x: torch.Tensor,
+    r: torch.Tensor,
+):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+    r = r[:, None, None]
+
+    probs = torch.ones_like(x).to(x.device).float()
+    # expand to batch and codebook dims
+    probs = probs.expand(x.shape[0], x.shape[1], -1)
+    probs = probs * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask
+
+@torch.jit.script_if_tracing
+def inpaint(x: torch.Tensor, n_prefix: int, n_suffix: int):
+    assert n_prefix is not None
+    assert n_suffix is not None
+
+    mask = full_mask(x)
+
+    # if we have a prefix or suffix, set their mask prob to 0
+    if n_prefix > 0:
+        if not isinstance(n_prefix, torch.Tensor):
+            n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
+        for i, n in enumerate(n_prefix):
+            if n > 0:
+                mask[i, :, :n] = 0.0
+    if n_suffix > 0:
+        if not isinstance(n_suffix, torch.Tensor):
+            n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
+        for i, n in enumerate(n_suffix):
+            if n > 0:
+                mask[i, :, -n:] = 0.0
+    return mask
+
+@torch.jit.script_if_tracing
+def periodic_mask(x: torch.Tensor, period: int,
+                  width: int = 1, random_roll: bool = False):
+    mask = full_mask(x)
+    if period == 0:
+        return full_mask(x)
+
+    if not isinstance(period, torch.Tensor):
+        period = scalar_to_batch_tensor(period, x.shape[0])
+    if period.ndim == 0:
+        period = period[None]
+
+    for i, factor in enumerate(period):
+        if factor == 0:
+            continue
+        for j in range(mask.shape[-1]):
+            if j % factor == 0:
+                # figure out how wide the mask should be
+                j_start = max(0, j - width // 2)
+                j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
+                # flip a coin for each position in the mask
+                j_mask = torch.bernoulli(torch.ones(j_end - j_start))
+                assert torch.all(j_mask == 1)
+                j_fill = torch.ones_like(j_mask) * (1 - j_mask)
+                assert torch.all(j_fill == 0)
+                # fill
+                mask[i, :, j_start:j_end] = j_fill
+
+    return mask
+
+def codebook_unmask(
+    mask: torch.Tensor,
+    n_conditioning_codebooks: int
+):
+    if n_conditioning_codebooks is None:
+        return mask
+    # if we have any conditioning codebooks, set their mask to 0
+    mask = mask.clone()
+    mask[:, :n_conditioning_codebooks, :] = 0
+    return mask
+
+def codebook_mask(mask: torch.Tensor, val1: int, val2: Optional[int] = None):
+    mask = mask.clone()
+    mask[:, val1:, :] = 1
+    # val2 = val2 or val1
+    # vs = torch.linspace(val1, val2, mask.shape[1])
+    # for t, v in enumerate(vs):
+    #     v = int(v)
+    #     mask[:, v:, t] = 1
+
+    return mask
+
+@torch.jit.script_if_tracing
+def mask_and(
+    mask1: torch.Tensor,
+    mask2: torch.Tensor
+):
+    assert mask1.shape == mask2.shape, "masks must be same shape"
+    return torch.min(mask1, mask2)
+
+def drop_ones(mask: torch.Tensor, p: float):
+    oldshp = mask.shape
+    mask = mask.view(-1)
+
+    # find ones idxs
+    ones_idxs = torch.where(mask == 1)[0]
+    # shuffle idxs
+    ones_idxs_idxs = torch.randperm(len(ones_idxs))
+    ones_idxs = ones_idxs[ones_idxs_idxs]
+    # drop p% of ones
+    ones_idxs = ones_idxs[:int(len(ones_idxs) * p)]
+    # set those idxs to 0
+    mask[ones_idxs] = 0
+
+    mask = mask.view(oldshp)
+    return mask
+
+
+def mask_or(
+    mask1: torch.Tensor,
+    mask2: torch.Tensor
+):
+    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
+    assert mask1.max() <= 1, "mask1 must be binary"
+    assert mask2.max() <= 1, "mask2 must be binary"
+    assert mask1.min() >= 0, "mask1 must be binary"
+    assert mask2.min() >= 0, "mask2 must be binary"
+    return (mask1 + mask2).clamp(0, 1)
+
+def time_stretch_mask(
+    x: torch.Tensor,
+    stretch_factor: int,
+):
+    assert stretch_factor >= 1, "stretch factor must be >= 1"
+    c_seq_len = x.shape[-1]
+    x = x.repeat_interleave(stretch_factor, dim=-1)
+
+    # trim x back to the original length
+    x = x[:, :, :c_seq_len]
+
+    mask = periodic_mask(x, stretch_factor, width=1)
+    return mask
+
+def onset_mask(
+    onset_frame_idxs: torch.Tensor,
+    z: torch.Tensor,
+    width: int = 1,
+):
+    if len(onset_frame_idxs) == 0:
+        print("no onsets detected")
+    # print("onset_frame_idxs", onset_frame_idxs)
+    # print("mask shape", z.shape)
+
+    mask = torch.ones_like(z).int()
+    for idx in onset_frame_idxs:
+        mask[:, :, idx-width:idx+width] = 0
+
+    return mask.int()
+
+def tria_mask(
+    codes: torch.Tensor,
+    min_amt: float = 0.1,
+    max_amt: float = 0.4,
+):
+    """
+    unmasks a prefix of the codes tensor,
+    in the range provided
+    """
+
+    mask = full_mask(codes)
+    nb, nc, nt = codes.shape
+    for i in range(nb):
+        amt = torch.rand(1) * (max_amt - min_amt) + min_amt
+        amt = int(amt * nt)
+        mask[i, :, :amt] = 0
+
+    return mask
+
+
+if __name__ == "__main__":
+    from audiotools import AudioSignal  # import needed for this demo block; missing in the original
+    sig = AudioSignal("assets/example.wav")
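For reference, here is a small, self-contained sketch of how the primitives in newmask.py compose into the onset-aware mask that new_vampnet_mask builds in app.py (the drop_ones dropout step is omitted). The codes tensor, onset indices, and mask_token value are made up for illustration; in the app, onsets come from librosa onset detection and the mask token from the codec.

import torch
from vampnet.newmask import (
    periodic_mask, onset_mask, mask_and, mask_or, codebook_mask, apply_mask
)

codes = torch.randint(0, 1024, (1, 4, 100))   # (batch, n_codebooks, seq), dummy codes
onset_idxs = torch.tensor([10, 37, 64])       # made-up onset frame indices

# keep a periodic grid of prompt frames, plus extra context around onsets
mask = mask_and(
    periodic_mask(codes, period=7, width=1),
    mask_or(
        onset_mask(onset_idxs, codes, width=5),
        periodic_mask(codes, period=7, width=1),
    ),
).int()
mask[:, :, onset_idxs] = 0          # keep the onset frames themselves unmasked
mask = codebook_mask(mask, val1=3)  # always mask codebooks 3 and above

# replace masked positions with a (made-up) mask token id
masked_codes = apply_mask(codes, mask, mask_token=1024)
print(masked_codes.shape)  # torch.Size([1, 4, 100])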