Commit 12dc48a · hugo flores garcia committed · 1 Parent(s): 2d0bc4e

for use with sound objects

Files changed (2)
  1. app.py +106 -33
  2. vampnet/newmask.py +365 -0
app.py CHANGED
@@ -57,6 +57,47 @@ def shift_pitch(signal, interval: int):
     return signal
 
 
+def onsets(sig: at.AudioSignal, hop_length: int):
+    assert sig.batch_size == 1, "batch size must be 1"
+    assert sig.num_channels == 1, "mono signals only"
+    import librosa
+    onset_frame_idxs = librosa.onset.onset_detect(
+        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
+        hop_length=hop_length,
+        backtrack=True,
+    )
+    return onset_frame_idxs
+
+
+def new_vampnet_mask(self,
+    codes,
+    onset_idxs,
+    width: int = 5,
+    periodic_prompt=2,
+    upper_codebook_mask=1,
+    drop_amt: float = 0.1
+):
+    from vampnet.newmask import mask_and, mask_or, onset_mask, periodic_mask, drop_ones, codebook_mask
+    mask = mask_and(
+        periodic_mask(codes, periodic_prompt, 1, random_roll=False),
+        mask_or(  # this re-masks the onsets, according to a periodic schedule
+            onset_mask(onset_idxs, codes, width=width),
+            periodic_mask(codes, periodic_prompt, 1, random_roll=False),
+        )
+    ).int()
+    # make sure the onset idxs themselves are unmasked
+    # mask = 1 - mask
+    mask[:, :, onset_idxs] = 0
+    mask = mask.cpu()  # debug
+    mask = 1 - drop_ones(1 - mask, drop_amt)
+    mask = codebook_mask(mask, upper_codebook_mask)
+
+
+    # save mask as txt (ints)
+    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')
+    mask = mask.to(self.device)
+    return mask[:, :, :]
+
 @spaces.GPU
 def _vamp(
     seed, input_audio, model_choice,
@@ -78,7 +119,7 @@ def _vamp(
     sr, input_audio = input_audio
     input_audio = input_audio / np.iinfo(input_audio.dtype).max
 
-    sig = at.AudioSignal(input_audio, sr)
+    sig = at.AudioSignal(input_audio, sr).to_mono()
 
     # reload the model if necessary
     interface.load_finetuned(model_choice)
@@ -88,18 +129,15 @@ def _vamp(
 
     codes = interface.encode(sig)
 
-    mask = interface.build_mask(
-        codes, sig,
-        rand_mask_intensity=1.0,
-        prefix_s=0.0,
-        suffix_s=0.0,
-        periodic_prompt=int(periodic_p),
-        periodic_prompt_width=periodic_w,
-        onset_mask_width=onset_mask_width,
-        _dropout=dropout,
-        upper_codebook_mask=int(n_mask_codebooks),
-    )
-
+    mask = new_vampnet_mask(
+        interface,
+        codes,
+        onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
+        width=onset_mask_width,
+        periodic_prompt=periodic_p,
+        upper_codebook_mask=n_mask_codebooks,
+        drop_amt=dropout
+    ).long()
 
     # save the mask as a txt file
     interface.set_chunk_size(10.0)
@@ -145,24 +183,45 @@ def vamp(data):
         api=False,
     )
 
-def api_vamp(data):
+# def api_vamp(data):
+#     return _vamp(
+#         seed=data[seed],
+#         input_audio=data[input_audio],
+#         model_choice=data[model_choice],
+#         pitch_shift_amt=data[pitch_shift_amt],
+#         periodic_p=data[periodic_p],
+#         n_mask_codebooks=data[n_mask_codebooks],
+#         periodic_w=data[periodic_w],
+#         onset_mask_width=data[onset_mask_width],
+#         dropout=data[dropout],
+#         sampletemp=data[sampletemp],
+#         typical_filtering=data[typical_filtering],
+#         typical_mass=data[typical_mass],
+#         typical_min_tokens=data[typical_min_tokens],
+#         top_p=data[top_p],
+#         sample_cutoff=data[sample_cutoff],
+#         stretch_factor=data[stretch_factor],
+#         api=True,
+#     )
+
+def api_vamp(input_audio, sampletemp, top_p, periodic_p, periodic_w, dropout, stretch_factor, onset_mask_width, typical_filtering, typical_mass, typical_min_tokens, seed, model_choice, n_mask_codebooks, pitch_shift_amt, sample_cutoff):
     return _vamp(
-        seed=data[seed],
-        input_audio=data[input_audio],
-        model_choice=data[model_choice],
-        pitch_shift_amt=data[pitch_shift_amt],
-        periodic_p=data[periodic_p],
-        n_mask_codebooks=data[n_mask_codebooks],
-        periodic_w=data[periodic_w],
-        onset_mask_width=data[onset_mask_width],
-        dropout=data[dropout],
-        sampletemp=data[sampletemp],
-        typical_filtering=data[typical_filtering],
-        typical_mass=data[typical_mass],
-        typical_min_tokens=data[typical_min_tokens],
-        top_p=data[top_p],
-        sample_cutoff=data[sample_cutoff],
-        stretch_factor=data[stretch_factor],
+        seed=seed,
+        input_audio=input_audio,
+        model_choice=model_choice,
+        pitch_shift_amt=pitch_shift_amt,
+        periodic_p=periodic_p,
+        n_mask_codebooks=n_mask_codebooks,
+        periodic_w=periodic_w,
+        onset_mask_width=onset_mask_width,
+        dropout=dropout,
+        sampletemp=sampletemp,
+        typical_filtering=typical_filtering,
+        typical_mass=typical_mass,
+        typical_min_tokens=typical_min_tokens,
+        top_p=top_p,
+        sample_cutoff=sample_cutoff,
+        stretch_factor=stretch_factor,
         api=True,
     )
 
@@ -258,7 +317,7 @@ with gr.Blocks() as demo:
             minimum=0,
             maximum=100,
             step=1,
-            value=0, visible=False
+            value=0, visible=True
         )
 
         n_mask_codebooks = gr.Slider(
@@ -419,8 +478,22 @@ with gr.Blocks() as demo:
        api_vamp_button = gr.Button("api vamp", visible=True)
        api_vamp_button.click(
            fn=api_vamp,
-            inputs=_inputs,
-            outputs=[audio_outs[0]],
+            inputs=[input_audio,
+                    sampletemp, top_p,
+                    periodic_p, periodic_w,
+                    dropout,
+                    stretch_factor,
+                    onset_mask_width,
+                    typical_filtering,
+                    typical_mass,
+                    typical_min_tokens,
+                    seed,
+                    model_choice,
+                    n_mask_codebooks,
+                    pitch_shift_amt,
+                    sample_cutoff
+            ],
+            outputs=[audio_outs[0]],
            api_name="vamp"
        )
 
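Note: taken together, the app.py changes replace the old interface.build_mask call with a mask built in the app itself: onsets are detected at the codec hop length, then new_vampnet_mask combines a periodic prompt with onset unmasking, random prompt dropout, and a codebook cutoff. A minimal sketch of that flow (illustrative only; it assumes an already-loaded `interface`, audiotools imported as `at`, and a Gradio audio input providing `sr, input_audio`; the generation step that consumes the mask is outside these hunks):

# illustrative sketch of the new masking path in app.py (not part of the commit)
sig = at.AudioSignal(input_audio, sr).to_mono()
codes = interface.encode(sig)          # (batch, n_codebooks, time) token grid

# onset frame indices, computed at the codec's frame rate
onset_idxs = onsets(sig, hop_length=interface.codec.hop_length)

# 1 = masked (regenerated), 0 = kept as prompt
mask = new_vampnet_mask(
    interface, codes,
    onset_idxs=onset_idxs,
    width=5,                 # frames left unmasked around each onset
    periodic_prompt=7,       # keep every 7th frame as a prompt
    upper_codebook_mask=3,   # codebooks 3 and up are always regenerated
    drop_amt=0.1,            # re-mask 10% of the kept prompt tokens
).long()
# the mask is then handed to the generation call inside _vamp (not shown in this diff)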
vampnet/newmask.py ADDED
@@ -0,0 +1,365 @@
+from typing import Optional
+
+import torch
+
+from .util import scalar_to_batch_tensor
+
+def _gamma(r):
+    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
+
+def _invgamma(y):
+    if not torch.is_tensor(y):
+        y = torch.tensor(y)[None]
+    return 2 * y.acos() / torch.pi
+
+def full_mask(x: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    return torch.ones_like(x).int()
+
+def empty_mask(x: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    return torch.zeros_like(x).int()
+
+def apply_mask(
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_token: int
+):
+    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
+    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
+    assert mask.dtype == torch.int, f"mask must be int dtype, but got {mask.dtype}"
+    assert ~torch.any(mask > 1), "mask must be binary"
+    assert ~torch.any(mask < 0), "mask must be binary"
+    mask = mask.int()
+
+    fill_x = torch.full_like(x, mask_token)
+    x = x * (1 - mask) + fill_x * mask
+
+    return x
+
+def random(
+    x: torch.Tensor,
+    r: torch.Tensor
+):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    r = _gamma(r)[:, None, None]
+    probs = torch.ones_like(x) * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask, torch.zeros_like(mask).bool()
+
+def random_along_time(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, channel, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    x = x[:, 0, :]
+    r = _gamma(r)[:, None]
+    probs = torch.ones_like(x) * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask
+
+
+def stemgen_random(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
+
+    # Assuming x is your input tensor and r is the probability for the Bernoulli distribution
+    nb, nc, nt = x.shape
+
+    # Randomly sample a codebook level to infer for each item in the batch
+    c = torch.randint(0, nc, (nb,)).to(x.device)
+
+    # Create a mask tensor of the same shape as x, initially filled with ones
+    mask = torch.ones_like(x).long().to(x.device)
+    ignore_indices_mask = torch.zeros_like(x).long().to(x.device)
+
+    # Iterate over each item in the batch
+    for i in range(nb):
+        # Create the Bernoulli mask for the sampled level
+        level_mask = torch.bernoulli(torch.ones(nt).to(x.device) * r[i]).long()
+
+        # Apply the mask to the sampled level
+        mask[i, c[i]] = level_mask
+
+        # All levels below the sampled level are unmasked (zeros)
+        mask[i, :c[i]] = 0
+        ignore_indices_mask[i, :c[i]] = 1
+
+        # All levels above the sampled level are masked (ones)
+        mask[i, c[i]+1:] = 1
+        ignore_indices_mask[i, c[i]+1:] = 1
+
+    # save a debug mask to np txt
+    # import numpy as np
+    # np.savetxt("mask.txt", mask[0].cpu().numpy(), fmt="%d")
+    # np.savetxt("ignore_indices_mask.txt", ignore_indices_mask[0].cpu().numpy(), fmt="%d")
+
+    return mask.int(), ignore_indices_mask.bool()
+
+
+def hugo_random(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+
+    r = _gamma(r)[:, None, None]
+
+    nb, nc, nt = x.shape
+
+    probs = torch.ones_like(x) * r
+    mask = torch.bernoulli(probs)
+    # alternatively, the mask level could be the cumsum of the mask
+    mask = mask.round().long().to(x.device)
+    mask_levels = nc - mask.sum(dim=1) - 1
+
+    # create a new mask, where all levels below the mask level are masked
+    # shape (nb, nc, nt) where new_mask[i, CB:, t] = 1, CB = mask_level[i, t]
+    # mask = mask_levels[:, :, None] > torch.arange(nc)[None, None, :]
+    mask = (mask_levels[:, None, :] < torch.arange(nc, device=x.device)[None, :, None]).long()
+
+    ignore_levels = mask_levels + 1
+    ignore_indices_mask = (ignore_levels[:, None, :] < torch.arange(nc, device=x.device)[None, :, None]).long()
+
+    # for _b in range(nb):
+    #     for _t in range(nt):
+    #         for _c in range(nc):
+    #             if mask[_b, _c, _t] == 1:
+    #                 mask[_b, _c:, _t] = 1
+    #                 ignore_indices_mask[_b, _c + 1:, _t] = 1
+    #                 break
+
+    return mask.long(), ignore_indices_mask.bool()
+
+
+def better_cond_random_but_not_working(x: torch.Tensor, r: torch.Tensor):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+
+    r = _gamma(r)[:, None, None]
+
+    nb, nc, nt = x.shape
+
+    probs = torch.ones_like(x) * r
+    mask = torch.bernoulli(probs)
+
+    mask = mask.round().long().to(x.device)
+
+    # there cannot be anything unmasked if there's a masked token
+    # in the same timestep and below it
+    for i in range(nb):
+        for j in range(nc):
+            for t in range(nt):
+                if mask[i, j, t] == 1:
+                    mask[i, j:, t] = 1
+                    break
+
+    # an ignore indices mask, since we can truly only predict one token
+    # per timestep
+    ignore_indices = torch.zeros_like(x)
+    for i in range(nb):
+        for j in range(nc):
+            for t in range(nt):
+                if mask[i, j, t] == 1:
+                    ignore_indices[i, j, t+1:] = 1
+                    break
+    return mask.int(), ignore_indices
+
+
+@torch.jit.script_if_tracing
+def linear_random(
+    x: torch.Tensor,
+    r: torch.Tensor,
+):
+    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
+    if not isinstance(r, torch.Tensor):
+        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+    r = r[:, None, None]
+
+    probs = torch.ones_like(x).to(x.device).float()
+    # expand to batch and codebook dims
+    probs = probs.expand(x.shape[0], x.shape[1], -1)
+    probs = probs * r
+
+    mask = torch.bernoulli(probs)
+    mask = mask.round().int()
+
+    return mask
+
+@torch.jit.script_if_tracing
+def inpaint(x: torch.Tensor, n_prefix: int, n_suffix: int,):
+    assert n_prefix is not None
+    assert n_suffix is not None
+
+    mask = full_mask(x)
+
+    # if we have a prefix or suffix, set their mask prob to 0
+    if n_prefix > 0:
+        if not isinstance(n_prefix, torch.Tensor):
+            n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
+        for i, n in enumerate(n_prefix):
+            if n > 0:
+                mask[i, :, :n] = 0.0
+    if n_suffix > 0:
+        if not isinstance(n_suffix, torch.Tensor):
+            n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
+        for i, n in enumerate(n_suffix):
+            if n > 0:
+                mask[i, :, -n:] = 0.0
+    return mask
+
+@torch.jit.script_if_tracing
+def periodic_mask(x: torch.Tensor, period: int,
+                  width: int = 1, random_roll: bool = False,):
+    mask = full_mask(x)
+    if period == 0:
+        return full_mask(x)
+
+    if not isinstance(period, torch.Tensor):
+        period = scalar_to_batch_tensor(period, x.shape[0])
+    if period.ndim == 0:
+        period = period[None]
+
+    for i, factor in enumerate(period):
+        if factor == 0:
+            continue
+        for j in range(mask.shape[-1]):
+            if j % factor == 0:
+                # figure out how wide the mask should be
+                j_start = max(0, j - width // 2)
+                j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
+                # flip a coin for each position in the mask
+                j_mask = torch.bernoulli(torch.ones(j_end - j_start))
+                assert torch.all(j_mask == 1)
+                j_fill = torch.ones_like(j_mask) * (1 - j_mask)
+                assert torch.all(j_fill == 0)
+                # fill
+                mask[i, :, j_start:j_end] = j_fill
+
+    return mask
+
+def codebook_unmask(
+    mask: torch.Tensor,
+    n_conditioning_codebooks: int
+):
+    if n_conditioning_codebooks == None:
+        return mask
+    # if we have any conditioning codebooks, set their mask to 0
+    mask = mask.clone()
+    mask[:, :n_conditioning_codebooks, :] = 0
+    return mask
+
+def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
+    mask = mask.clone()
+    mask[:, val1:, :] = 1
+    # val2 = val2 or val1
+    # vs = torch.linspace(val1, val2, mask.shape[1])
+    # for t, v in enumerate(vs):
+    #     v = int(v)
+    #     mask[:, v:, t] = 1
+
+    return mask
+
+@torch.jit.script_if_tracing
+def mask_and(
+    mask1: torch.Tensor,
+    mask2: torch.Tensor
+):
+    assert mask1.shape == mask2.shape, "masks must be same shape"
+    return torch.min(mask1, mask2)
+
+def drop_ones(mask: torch.Tensor, p: float):
+    oldshp = mask.shape
+    mask = mask.view(-1)
+
+    # find ones idxs
+    ones_idxs = torch.where(mask == 1)[0]
+    # shuffle idxs
+    ones_idxs_idxs = torch.randperm(len(ones_idxs))
+    ones_idxs = ones_idxs[ones_idxs_idxs]
+    # drop p% of ones
+    ones_idxs = ones_idxs[:int(len(ones_idxs) * p)]
+    # set those idxs to 0
+    mask[ones_idxs] = 0
+
+    mask = mask.view(oldshp)
+    return mask
+
+
+def mask_or(
+    mask1: torch.Tensor,
+    mask2: torch.Tensor
+):
+    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
+    assert mask1.max() <= 1, "mask1 must be binary"
+    assert mask2.max() <= 1, "mask2 must be binary"
+    assert mask1.min() >= 0, "mask1 must be binary"
+    assert mask2.min() >= 0, "mask2 must be binary"
+    return (mask1 + mask2).clamp(0, 1)
+
+def time_stretch_mask(
+    x: torch.Tensor,
+    stretch_factor: int,
+):
+    assert stretch_factor >= 1, "stretch factor must be >= 1"
+    c_seq_len = x.shape[-1]
+    x = x.repeat_interleave(stretch_factor, dim=-1)
+
+    # trim cz to the original length
+    x = x[:, :, :c_seq_len]
+
+    mask = periodic_mask(x, stretch_factor, width=1)
+    return mask
+
+def onset_mask(
+    onset_frame_idxs: torch.Tensor,
+    z: torch.Tensor,
+    width: int = 1,
+):
+    if len(onset_frame_idxs) == 0:
+        print("no onsets detected")
+    # print("onset_frame_idxs", onset_frame_idxs)
+    # print("mask shape", z.shape)
+
+    mask = torch.ones_like(z).int()
+    for idx in onset_frame_idxs:
+        mask[:, :, idx-width:idx+width] = 0
+
+    return mask.int()
+
+def tria_mask(
+    codes: torch.Tensor,
+    min_amt: float = 0.1,
+    max_amt: float = 0.4,
+):
+    """
+    unmasks a prefix of the codes tensor,
+    in the range provided
+    """
+
+    mask = full_mask(codes)
+    nb, nc, nt = codes.shape
+    for i in range(nb):
+        amt = torch.rand(1) * (max_amt - min_amt) + min_amt
+        amt = int(amt * nt)
+        mask[i, :, :amt] = 0
+
+    return mask
+
+
+
+
+
+
+if __name__ == "__main__":
+    sig = AudioSignal("assets/example.wav")
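
Note: vampnet/newmask.py is a library of composable binary masks over a (batch, n_codebooks, time) token grid, where 1 marks a position to regenerate and 0 a position to keep. A small usage sketch of how the primitives compose (illustrative only; it assumes vampnet is installed so the module and its .util dependency import, and the mask-token id 1024 below is a placeholder, not the codec's actual special token):

# illustrative sketch, not part of the commit
import torch
from vampnet.newmask import periodic_mask, inpaint, codebook_mask, mask_and, apply_mask

codes = torch.randint(0, 1024, (1, 4, 100))   # toy token grid: 1 item, 4 codebooks, 100 frames

# keep every 10th frame plus a 5-frame prefix and suffix; everything else is masked
mask = mask_and(
    periodic_mask(codes, period=10, width=1),
    inpaint(codes, n_prefix=5, n_suffix=5),
)
# codebooks 2 and up are always regenerated, regardless of the masks above
mask = codebook_mask(mask, 2)

# replace masked positions with a placeholder mask-token id
MASK_TOKEN = 1024
masked_codes = apply_mask(codes, mask.int(), MASK_TOKEN)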