File size: 1,132 Bytes
d358e26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from dataclasses import dataclass, asdict
import torch
from torch import Tensor
import torch.nn as nn
import torchaudio
import torchaudio.transforms

from config import MelConfig

class LogMelSpectrogram(nn.Module):
    """Mel spectrogram front-end that returns log-compressed magnitudes.

    The wrapped ``torchaudio.transforms.MelSpectrogram`` is configured
    directly from the fields of the supplied config dataclass.
    """

    def __init__(self, config: MelConfig):
        """Build the underlying mel transform.

        Args:
            config: Dataclass whose fields are forwarded as keyword
                arguments to ``torchaudio.transforms.MelSpectrogram``.
        """
        super().__init__()
        mel_kwargs = asdict(config)
        self.spec = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Return the log-compressed mel spectrogram of ``x``."""
        mel = self.spec(x)
        return self.compress(mel)

    def compress(self, x: Tensor) -> Tensor:
        """Take the natural log of ``x`` after flooring it at 1e-5.

        The floor keeps zero (or tiny) values from producing ``-inf``.
        """
        floored = x.clamp(min=1e-5)
        return floored.log()

    def decompress(self, x: Tensor) -> Tensor:
        """Undo ``compress`` by exponentiating ``x``.

        Values that were floored during compression are not recovered.
        """
        return x.exp()

    
def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor:
    """Load an audio file, reduce it to one channel, and resample it.

    Args:
        audio_path: Location of the audio file, handed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.
        device: Device the waveform is moved to before any further
            processing. Defaults to ``'cpu'``.

    Returns:
        The waveform on ``device`` at ``target_sr``. When the file has
        several channels only the first one is kept (no averaging), so the
        result has a leading dimension of size 1. Returns ``None`` if the
        file could not be loaded; the error message is printed.
    """
    try:
        y, sr = torchaudio.load(audio_path)
    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip
        # this file instead of aborting.
        print(str(e))
        return None

    # Tensor.to() is not in-place; the result must be rebound or the
    # requested device is silently ignored.
    y = y.to(device)

    # Convert to mono by keeping only the first channel.
    # shape: [channels, time] -> [1, time]
    if y.size(0) > 1:
        y = y[:1]

    # Resample audio to the target sample rate.
    if sr != target_sr:
        y = torchaudio.functional.resample(y, sr, target_sr)
    return y