YinuoGuo27 committed
Commit c39612f · verified · 1 Parent(s): 689c6a3

Upload 5 files

difpoint/dataset_process/.DS_Store ADDED
Binary file (6.15 kB).
 
difpoint/dataset_process/__pycache__/audio.cpython-310.pyc ADDED
Binary file (4.61 kB).
 
difpoint/dataset_process/__pycache__/audio.cpython-38.pyc ADDED
Binary file (4.65 kB).
 
difpoint/dataset_process/audio.py ADDED
@@ -0,0 +1,156 @@
+ import librosa
+ import librosa.filters
+ import numpy as np
+ import soundfile as sf
+ from scipy import signal
+ from scipy.io import wavfile
+ from difpoint.src.utils.hparams import hparams as hp
+
+
+ def load_wav(path, sr):
+     return librosa.load(path, sr=sr)[0]
+
+
+ def save_wav(wav, path, sr):
+     # Rescale to the int16 range, guarding against near-silent clips.
+     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+     # proposed by @dsmiller
+     wavfile.write(path, sr, wav.astype(np.int16))
+
+
+ def save_wavenet_wav(wav, path, sr):
+     # librosa.output.write_wav was removed in librosa 0.8; soundfile is the drop-in replacement.
+     sf.write(path, wav, sr)
+
+
+ def preemphasis(wav, k, preemphasize=True):
+     if preemphasize:
+         return signal.lfilter([1, -k], [1], wav)
+     return wav
+
+
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
+     if inv_preemphasize:
+         return signal.lfilter([1], [1, -k], wav)
+     return wav
+
+
+ def get_hop_size():
+     hop_size = hp.hop_size
+     if hop_size is None:
+         assert hp.frame_shift_ms is not None
+         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+     return hop_size
+
+
+ def linearspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def melspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def _lws_processor():
+     import lws
+     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+
+ def _stft(y):
+     if hp.use_lws:
+         return _lws_processor().stft(y).T
+     else:
+         return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
+
+
+ ##########################################################
+ # These are only correct when using lws! (This was hurting WaveNet quality for a long time.)
+ def num_frames(length, fsize, fshift):
+     """Compute the number of time frames of a spectrogram."""
+     pad = (fsize - fshift)
+     if length % fshift == 0:
+         M = (length + pad * 2 - fsize) // fshift + 1
+     else:
+         M = (length + pad * 2 - fsize) // fshift + 2
+     return M
+
+
+ def pad_lr(x, fsize, fshift):
+     """Compute left and right padding."""
+     M = num_frames(len(x), fsize, fshift)
+     pad = (fsize - fshift)
+     T = len(x) + 2 * pad
+     r = (M - 1) * fshift + fsize - T
+     return pad, pad + r
+
+
+ ##########################################################
+ # Librosa-correct padding
+ def librosa_pad_lr(x, fsize, fshift):
+     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+
+ # Conversions
+ _mel_basis = None
+
+
+ def _linear_to_mel(spectrogram):
+     global _mel_basis
+     if _mel_basis is None:
+         _mel_basis = _build_mel_basis()
+     return np.dot(_mel_basis, spectrogram)
+
+
+ def _build_mel_basis():
+     assert hp.fmax <= hp.sample_rate // 2
+     return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
+                                fmin=hp.fmin, fmax=hp.fmax)
+
+
+ def _amp_to_db(x):
+     min_level = np.exp(hp.min_level_db / 20 * np.log(10))
+     return 20 * np.log10(np.maximum(min_level, x))
+
+
+ def _db_to_amp(x):
+     return np.power(10.0, x * 0.05)
+
+
+ def _normalize(S):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
+                            -hp.max_abs_value, hp.max_abs_value)
+         else:
+             return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
+
+     assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+     if hp.symmetric_mels:
+         return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+     else:
+         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+
+ def _denormalize(D):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return (((np.clip(D, -hp.max_abs_value,
+                               hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+                     + hp.min_level_db)
+         else:
+             return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+     if hp.symmetric_mels:
+         return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+     else:
+         return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
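For reference, a minimal usage sketch of the pipeline above (not part of the commit: it assumes hparams supplies sample_rate and the STFT/mel fields this module reads, and example.wav is a hypothetical input file):

import numpy as np
from difpoint.dataset_process import audio
from difpoint.src.utils.hparams import hparams as hp

wav = audio.load_wav('example.wav', hp.sample_rate)  # hypothetical path; audio is resampled to hp.sample_rate
mel = audio.melspectrogram(wav)                      # (num_mels, num_frames), normalized if hp.signal_normalization
print(mel.shape)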
difpoint/dataset_process/wav2lip.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+         self.use_act = use_act
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             # Residual add requires cin == cout and unchanged spatial dims.
+             out += x
+
+         if self.use_act:
+             return self.act(out)
+         else:
+             return out
+
+
+ class AudioEncoder(nn.Module):
+     def __init__(self, wav2lip_checkpoint, device):
+         super(AudioEncoder, self).__init__()
+
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+         # Load the pre-trained audio_encoder weights from the Wav2Lip checkpoint.
+         wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+         state_dict = self.audio_encoder.state_dict()
+
+         for k, v in wav2lip_state_dict.items():
+             if 'audio_encoder' in k:
+                 state_dict[k.replace('module.audio_encoder.', '')] = v
+         self.audio_encoder.load_state_dict(state_dict)
+
+     def forward(self, audio_sequences):
+         # audio_sequences: (B, T, 1, 80, 16)
+         B = audio_sequences.size(0)
+
+         # Fold the time dimension into the batch so all mel chunks are encoded in one pass.
+         audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+         audio_embedding = self.audio_encoder(audio_sequences)  # (B * T, 512, 1, 1)
+         dim = audio_embedding.shape[1]
+         audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+         return audio_embedding.squeeze(-1).squeeze(-1)  # (B, T, 512)
+
+
+ wav2lip_checkpoint = 'ckpts/wav2lip.pth'
+ # Fall back to CPU when CUDA is unavailable, so importing this module does not crash.
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ wav2lip_model = AudioEncoder(wav2lip_checkpoint, device)
+ wav2lip_model.to(device)
+ wav2lip_model.eval()
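A quick shape check of the encoder's contract (a sketch using random data in place of real mel chunks; it assumes the checkpoint exists at ckpts/wav2lip.pth so the module-level model above loads):

import torch

dummy_mels = torch.randn(2, 5, 1, 80, 16, device=device)  # (B, T, 1, n_mels, mel_step_size)
with torch.no_grad():
    embedding = wav2lip_model(dummy_mels)
print(embedding.shape)  # torch.Size([2, 5, 512])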