KingNish committed
Commit d2c70a4 · verified · 1 Parent(s): 2d64f5c

Upload ./vocos/helpers.py with huggingface_hub

Files changed (1)
  1. vocos/helpers.py +71 -0
vocos/helpers.py ADDED
@@ -0,0 +1,71 @@
+ import matplotlib
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+ from pytorch_lightning import Callback
+
+ matplotlib.use("Agg")
+
+
+ def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+     """
+     Save a matplotlib figure to a numpy array.
+
+     Args:
+         fig (Figure): Matplotlib figure object.
+
+     Returns:
+         ndarray: Numpy array representing the figure.
+     """
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     return data
+
+
+ def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+     """
+     Plot a spectrogram and convert it to a numpy array.
+
+     Args:
+         spectrogram (ndarray): Spectrogram data.
+
+     Returns:
+         ndarray: Numpy array representing the plotted spectrogram.
+     """
+     spectrogram = spectrogram.astype(np.float32)
+     fig, ax = plt.subplots(figsize=(12, 3))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = save_figure_to_numpy(fig)
+     plt.close(fig)
+     return data
+
+
+ class GradNormCallback(Callback):
+     """
+     Callback to log the gradient norm.
+     """
+
+     def on_after_backward(self, trainer, model):
+         model.log("grad_norm", gradient_norm(model))
+
+
+ def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+     """
+     Compute the gradient norm.
+
+     Args:
+         model (Module): PyTorch model.
+         norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+     Returns:
+         Tensor: Gradient norm.
+     """
+     grads = [p.grad for p in model.parameters() if p.grad is not None]
+     total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+     return total_norm
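
For context, here is a minimal sketch (not part of this commit) of how these helpers might be wired into a pytorch_lightning training script: GradNormCallback is passed to the Trainer so grad_norm is logged after each backward pass, and plot_spectrogram_to_numpy renders an image for TensorBoard. The ToyModule model, the dummy data, and the logger path below are hypothetical placeholders; only the vocos.helpers imports come from this repository.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.loggers import TensorBoardLogger

from vocos.helpers import GradNormCallback, plot_spectrogram_to_numpy


class ToyModule(pl.LightningModule):
    """Hypothetical stand-in model, used only to exercise the helpers."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(80, 80)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = torch.nn.functional.mse_loss(self.net(x), x)
        self.log("train_loss", loss)
        return loss

    def on_train_epoch_end(self):
        # Render a (channels x frames) array to an RGB image and push it to
        # TensorBoard as an HWC uint8 array.
        spec = torch.randn(80, 200).numpy()  # placeholder spectrogram
        image = plot_spectrogram_to_numpy(spec)
        self.logger.experiment.add_image(
            "spectrogram", image, self.global_step, dataformats="HWC"
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 80)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        logger=TensorBoardLogger("tb_logs"),
        # GradNormCallback logs "grad_norm" after every backward pass.
        callbacks=[GradNormCallback()],
    )
    trainer.fit(ToyModule(), data)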