hugo flores garcia committed on
Commit
39bff10
·
1 Parent(s): 12dc48a

plant debug stuff

Browse files
Files changed (2) hide show
  1. hello.py +7 -2
  2. vampnet/modules/transformer.py +63 -0
hello.py CHANGED
@@ -13,6 +13,10 @@ print(f"available finetuned models: {finetuned_model_choices}")
13
  model_choice = random.choice(finetuned_model_choices)
14
  print(f"choosing model: {model_choice}")
15
 
 
 
 
 
16
  # load a finetuned model
17
  interface.load_finetuned(model_choice)
18
 
@@ -25,7 +29,7 @@ codes = interface.encode(signal)
25
  # build a mask for the audio
26
  mask = interface.build_mask(
27
  codes, signal,
28
- periodic_prompt=7,
29
  upper_codebook_mask=3,
30
  )
31
 
@@ -33,7 +37,8 @@ mask = interface.build_mask(
33
  output_tokens = interface.vamp(
34
  codes, mask, return_mask=False,
35
  temperature=1.0,
36
- typical_filtering=True,
 
37
  )
38
 
39
  # convert them to a signal
 
13
  model_choice = random.choice(finetuned_model_choices)
14
  print(f"choosing model: {model_choice}")
15
 
16
+ # or pick a specific finetuned model
17
+ print(f"actually, forcing model: default")
18
+ model_choice = "default"
19
+
20
  # load a finetuned model
21
  interface.load_finetuned(model_choice)
22
 
 
29
  # build a mask for the audio
30
  mask = interface.build_mask(
31
  codes, signal,
32
+ periodic_prompt=13,
33
  upper_codebook_mask=3,
34
  )
35
 
 
37
  output_tokens = interface.vamp(
38
  codes, mask, return_mask=False,
39
  temperature=1.0,
40
+ typical_filtering=False,
41
+ debug=True
42
  )
43
 
44
  # convert them to a signal
vampnet/modules/transformer.py CHANGED
@@ -669,11 +669,50 @@ class VampNet(at.ml.BaseModel):
669
  (mask, torch.full_like(mask, 1)), dim=0
670
  )
671
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  #################
673
  # begin sampling #
674
  #################
675
  from tqdm import tqdm
676
  for i in range(sampling_steps):
 
 
 
 
 
677
 
678
  # our current schedule step
679
  r = scalar_to_batch_tensor(
@@ -706,6 +745,19 @@ class VampNet(at.ml.BaseModel):
706
  top_k=None, top_p=top_p, return_probs=True,
707
  )
708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
 
710
  # flatten z_masked and mask, so we can deal with the sampling logic
711
  # we'll unflatten them at the end of the loop for the next forward pass
@@ -713,6 +765,17 @@ class VampNet(at.ml.BaseModel):
713
  z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
714
 
715
  mask = (z_masked == self.mask_token).int()
 
 
 
 
 
 
 
 
 
 
 
716
 
717
  # update the mask, remove conditioning codebooks from the mask
718
  # add z back into sampled z where the mask was false
 
669
  (mask, torch.full_like(mask, 1)), dim=0
670
  )
671
 
672
+ if debug:
673
+ DEBUG_FOLDER = "vampnet-debug"
674
+ import matplotlib.pyplot as plt
675
+ from pathlib import Path
676
+
677
+ Path(DEBUG_FOLDER).mkdir(exist_ok=True)
678
+ plt.rcParams['figure.dpi'] = 100 # Default DPI for figures
679
+ plt.rcParams['figure.figsize'] = (20, 0.4) # Default size for a 4x2000 grid (2000/100, 4/100)
680
+ plt.rcParams['image.interpolation'] = 'nearest' # Ensures no smoothing for imshow
681
+ plt.rcParams['image.aspect'] = 'auto' # Maintains proper aspect ratio
682
+ plt.rcParams['axes.axisbelow'] = True # Ensures axis is beneath the data (for clarity)
683
+ plt.rcParams['axes.xmargin'] = 0 # No extra space around data
684
+ plt.rcParams['axes.ymargin'] = 0 # Same for Y-axis
685
+
686
+ from functools import partial
687
+ plt.imshow = partial(plt.imshow, origin='lower')
688
+
689
+
690
+ # save the initial mask
691
+ plt.clf()
692
+ plt.imshow(mask[0].cpu().numpy())
693
+ plt.savefig(f"{DEBUG_FOLDER}/mask.png")
694
+
695
+ # save the initial z_masked
696
+ plt.clf()
697
+ plt.imshow(z_masked[0].cpu().numpy())
698
+ plt.savefig(f"{DEBUG_FOLDER}/z_masked.png")
699
+
700
+ # save the initial z
701
+ plt.clf()
702
+ plt.imshow(z[0].cpu().numpy())
703
+ plt.savefig(f"{DEBUG_FOLDER}/z.png")
704
+
705
+
706
  #################
707
  # begin sampling #
708
  #################
709
  from tqdm import tqdm
710
  for i in range(sampling_steps):
711
+ if debug:
712
+ # save the mask at step i
713
+ # make a folder called step i
714
+ STEP_FOLDER = (f"{DEBUG_FOLDER}/step_{i}")
715
+ Path(STEP_FOLDER).mkdir(exist_ok=True)
716
 
717
  # our current schedule step
718
  r = scalar_to_batch_tensor(
 
745
  top_k=None, top_p=top_p, return_probs=True,
746
  )
747
 
748
+ if debug:
749
+ # log the selected probs and sampled
750
+ plt.clf()
751
+ _selected_probs = codebook_unflatten(selected_probs, n_infer_codebooks)
752
+ plt.imshow(_selected_probs[0].cpu().numpy(), )
753
+ plt.colorbar()
754
+ plt.savefig(f"{STEP_FOLDER}/selected_probs.png")
755
+
756
+ plt.clf()
757
+ plt.imshow(sampled_z.cpu().numpy())
758
+ plt.savefig(f"{STEP_FOLDER}/sampled_z.png")
759
+
760
+
761
 
762
  # flatten z_masked and mask, so we can deal with the sampling logic
763
  # we'll unflatten them at the end of the loop for the next forward pass
 
765
  z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
766
 
767
  mask = (z_masked == self.mask_token).int()
768
+
769
+ if debug:
770
+
771
+ plt.clf()
772
+ # plt.imshow(mask.cpu().numpy())
773
+ _mask = codebook_unflatten(mask, n_infer_codebooks)
774
+ plt.imshow(_mask[0].cpu().numpy())
775
+ plt.savefig(f"{STEP_FOLDER}/mask.png")
776
+
777
+
778
+
779
 
780
  # update the mask, remove conditioning codebooks from the mask
781
  # add z back into sampled z where the mask was false