File size: 637 Bytes
09b47fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from dataclasses import dataclass, asdict
import torch
from torch import Tensor
import torch.nn as nn
import torchaudio
import torchaudio.transforms

from config import MelConfig

class LogMelSpectrogram(nn.Module):
    """Mel spectrogram followed by natural-log dynamic-range compression.

    Every field of ``config`` is forwarded (via ``dataclasses.asdict``) as a
    keyword argument to ``torchaudio.transforms.MelSpectrogram``. The
    resulting spectrogram is passed through :meth:`compress`.
    """

    def __init__(self, config: MelConfig):
        super().__init__()
        mel_kwargs = asdict(config)
        # Assigned as an attribute so nn.Module tracks it as a child module.
        self.spec = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Return the log-compressed mel spectrogram of ``x``."""
        mel = self.spec(x)
        return self.compress(mel)

    def compress(self, x: Tensor) -> Tensor:
        """Floor ``x`` at 1e-5, then take the natural log.

        The floor keeps zero-energy bins from producing ``-inf``.
        """
        floored = torch.clamp(x, min=1e-5)
        return torch.log(floored)

    def decompress(self, x: Tensor) -> Tensor:
        """Map log values back to linear scale with ``exp``.

        This inverts :meth:`compress` only for inputs that were above the
        1e-5 floor. Anything clamped there comes back as 1e-5.
        """
        return torch.exp(x)