import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from random import shuffle, choice, sample
from moviepy.editor import VideoFileClip
import librosa
from scipy import signal
from scipy.io import wavfile
import torchaudio

torchaudio.set_audio_backend("sox_io")
INTERVAL = 1000  # length of the audio window, in ms

# mel spectrogram transform: 16 kHz input, hop of 161 samples (~100 frames/s), 64 mel bins
stft = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, hop_length=161, n_mels=64).cuda()


def log10(x):
    return torch.log10(x)
def norm_range(x, min_val, max_val):
    # linearly map x from [min_val, max_val] to [-1, 1]
    return 2. * (x - min_val) / float(max_val - min_val) - 1.


def normalize_spec(spec, spec_min, spec_max):
    return norm_range(spec, spec_min, spec_max)


def db_from_amp(x, cuda=False):
    # convert amplitude to decibels, flooring at 1e-5 to avoid log(0)
    if cuda:
        return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
    else:
        return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))
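# Example (a sketch, not from the original code): amplitudes at or below the
# 1e-5 floor map to -100 dB, which is why normalize_spec is later called with
# the range (-100., 100.):
#     db_from_amp(torch.tensor(1.))    # -> 0.0 dB
#     db_from_amp(torch.tensor(0.))    # -> -100.0 dB (floored at 1e-5)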
def audio_stft(audio, stft=stft):
    # apply the mel transform per channel, yielding a normalized TxF log-mel matrix
    N, C, A = audio.size()
    audio = audio.view(N * C, A)
    spec = stft(audio)
    spec = spec.transpose(-1, -2)
    spec = db_from_amp(spec, cuda=True)
    spec = normalize_spec(spec, -100., 100.)
    _, T, n_mels = spec.size()
    spec = spec.view(N, C, T, n_mels)
    return spec
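# Usage sketch (hypothetical shapes; assumes a CUDA device, since both `stft`
# and db_from_amp run on GPU). One second of 16 kHz audio with hop_length=161
# gives T = 1 + 16000 // 161 = 100 frames:
#     audio = torch.randn(4, 1, 16000).cuda()   # (N, C, samples)
#     spec = audio_stft(audio)                  # -> (4, 1, 100, 64), roughly in [-1, 1]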
# discard
# def get_spec(
#         wavs,
#         sample_rate=16000,
#         use_volume_jittering=False,
#         center=False,
# ):
#     # Volume jittering - scale volume by factor in range (0.9, 1.1)
#     if use_volume_jittering:
#         wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
#     if center:
#         wavs = [center_only(wav) for wav in wavs]
#     # Convert to log filterbank
#     specs = [logfbank(
#         wav,
#         sample_rate,
#         winlen=0.009,
#         winstep=0.005,  # if num_sec==1 else 0.01,
#         nfilt=256,
#         nfft=1024
#     ).astype('float32').T for wav in wavs]
#     # Convert to 32-bit float and expand dim
#     specs = np.stack(specs, axis=0)
#     specs = np.expand_dims(specs, 1)
#     specs = torch.as_tensor(specs)  # Nx1xFxT
#     return specs
def center_only(audio, sr=16000, L=1.0):
    # center_wav = np.arange(0, L, L/(0.5*sr)) ** 2
    # center_wav = np.concatenate([center_wav, center_wav[::-1]])
    # center_wav[L*sr//2:3*L*sr//4] = 1
    # keep only the central 0.3*L seconds (0.4L to 0.7L); zero out the rest
    center_wav = np.zeros(int(L * sr))
    center_wav[int(0.4 * L * sr):int(0.7 * L * sr)] = 1
    return audio * center_wav
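# Example (sketch): with the defaults sr=16000 and L=1.0, only samples
# 6400:11200 (0.4 s to 0.7 s) survive:
#     wav = np.ones(16000)
#     center_only(wav).sum()    # -> 4800.0, i.e. 0.3 s of audio kept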
def get_spec_librosa(
        wavs,
        sample_rate=16000,
        use_volume_jittering=False,
        center=False,
):
    # Volume jittering - scale volume by factor in range (0.9, 1.1)
    if use_volume_jittering:
        wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
    if center:
        wavs = [center_only(wav) for wav in wavs]
    # Compute mel spectrograms
    specs = [librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=400,
        hop_length=126,
        n_mels=128,
    ).astype('float32') for wav in wavs]
    # Convert power to dB, stack into a batch, and add a channel dim
    specs = [librosa.power_to_db(spec) for spec in specs]
    specs = np.stack(specs, axis=0)
    specs = np.expand_dims(specs, 1)
    specs = torch.as_tensor(specs)  # Nx1xFxT
    return specs
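# Usage sketch (hypothetical input): a batch of two 1-second, 16 kHz mono
# clips. With hop_length=126 and librosa's default center=True padding, each
# clip yields 1 + 16000 // 126 = 127 frames:
#     wavs = [np.random.randn(16000).astype('float32') for _ in range(2)]
#     specs = get_spec_librosa(wavs)    # -> torch.Size([2, 1, 128, 127])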
def calcEuclideanDistance_Mat(X, Y):
    """
    Inputs:
    - X: A torch tensor of shape (N, F)
    - Y: A torch tensor of shape (M, F)
    Returns:
    A torch tensor D of shape (N, M) where D[i, j] is the Euclidean distance
    between X[i] and Y[j].
    """
    sq = torch.sum(X ** 2, dim=1, keepdim=True) + \
        torch.sum(Y ** 2, dim=1, keepdim=True).T - 2 * X @ Y.T
    # clamp to avoid NaNs from tiny negative values caused by floating-point error
    return torch.clamp(sq, min=0.) ** 0.5


def calcEuclideanDistance(x1, x2):
    return torch.sum((x1 - x2) ** 2, dim=1) ** 0.5
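# Usage sketch (hypothetical shapes): the matrix version compares every row of
# X against every row of Y; the pairwise version compares rows elementwise:
#     X, Y = torch.randn(8, 128), torch.randn(16, 128)
#     calcEuclideanDistance_Mat(X, Y).shape    # -> torch.Size([8, 16])
#     calcEuclideanDistance(X, X).shape        # -> torch.Size([8]), all zeros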
def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True):
    # load first, then shuffle: shuffling a path string would raise a TypeError
    if isinstance(in_list, str):
        with open(in_list) as l:
            fw_list = json.load(l)
    elif isinstance(in_list, list):
        fw_list = in_list
    else:
        raise TypeError(f'Invalid input list type: {type(in_list)}')
    if is_shuffle:
        shuffle(fw_list)
    c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1])
    tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:]
    print(
        f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}')
    return tr_list, va_list, te_list
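# Usage sketch (hypothetical data): the default portion=(0.9, 0.95) gives a
# 90/5/5 train/val/test split:
#     tr, va, te = split_data(list(range(100)))
#     # ==> train set: 90, validation set: 5, test set: 5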
def load_one_clip(video_path):
    v = VideoFileClip(video_path)
    fps = int(v.fps)
    frames = [f for f in v.iter_frames()][:-1]
    frame_cnt = len(frames)
    frame_length = 1000. / fps  # duration of one video frame, in ms
    total_length = int(1000 * (frame_cnt / fps))
    a = v.audio
    sr = a.fps
    a = np.array([fa for fa in a.iter_frames()])
    if len(a.shape) > 1:
        # average stereo channels to mono before resampling, so that
        # librosa.resample works along the time axis
        a = np.mean(a, axis=1)
    a = librosa.resample(a, orig_sr=sr, target_sr=48000)
    while True:
        # pick a random frame whose centered 1-second audio window fits the clip
        idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0]
        frame_clip = frames[idx]
        start_time = int(idx * frame_length + 0.5 * frame_length - 500)  # ms
        end_time = start_time + INTERVAL
        if start_time < 0 or end_time > total_length:
            continue
        wave_clip = a[48 * start_time: 48 * end_time]  # 48 samples/ms at 48 kHz
        if wave_clip.shape[0] != 48000:
            continue
        break
    return frame_clip, wave_clip
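# Usage sketch (hypothetical path): returns one random video frame and the
# 1-second, 48 kHz mono waveform centered on it:
#     frame, wave = load_one_clip('clip.mp4')   # 'clip.mp4' is a placeholder
#     frame.shape    # (H, W, 3) uint8 array from moviepy
#     wave.shape     # (48000,)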
def resize_frame(frame):
    # PIL's size is (width, height); scale so the short edge becomes 256
    W, H = frame.size
    short_edge = min(W, H)
    scale = 256 / short_edge
    W_tar, H_tar = int(np.round(W * scale)), int(np.round(H * scale))
    return frame.resize((W_tar, H_tar))
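# Example (sketch, assumes a PIL.Image input): a 640x480 image has short edge
# 480, so scale = 256/480 and the result is 341x256:
#     from PIL import Image
#     im = Image.new('RGB', (640, 480))
#     resize_frame(im).size    # -> (341, 256)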
def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000):
    # random clip-level amplitude jittering
    if amp_jitter:
        amplified = wave * np.random.uniform(*amp_jitter_range)
        if wave.dtype == np.int16:
            amplified[amplified >= 32767] = 32767
            amplified[amplified <= -32768] = -32768
            wave = amplified.astype('int16')
        elif wave.dtype == np.float32 or wave.dtype == np.float64:
            amplified[amplified >= 1] = 1
            amplified[amplified <= -1] = -1
            wave = amplified  # write the clipped, jittered signal back
    # fr, ts, spectrogram = signal.spectrogram(wave[:48000], fs=sr, nperseg=480, noverlap=240, nfft=512)
    # spectrogram = librosa.feature.melspectrogram(S=spectrogram, n_mels=257)  # Try log-mel spectrogram?
    spectrogram = librosa.feature.melspectrogram(
        y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257)
    if log_scale:
        spectrogram = librosa.power_to_db(spectrogram, ref=np.max)
    assert spectrogram.shape[0] == 257
    return spectrogram
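# Usage sketch (hypothetical input): 1 s of 48 kHz audio with hop_length=240
# gives 1 + 48000 // 240 = 201 frames and 257 mel bins:
#     wave = np.random.randn(48000).astype('float32')
#     spec = get_spectrogram(wave, amp_jitter=False, amp_jitter_range=None)
#     spec.shape    # -> (257, 201), log-scaled with 0 dB at the maximum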
def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0):
    # crop a `length`-second audio window aligned to video frame f_idx,
    # optionally shifted left by `left_shift` seconds
    time_per_frame = 1. / fps
    assert audio.shape[0] > sr * length
    start_time = f_idx * time_per_frame - left_shift
    start_time = 0 if start_time < 0 else start_time
    start_idx = int(np.round(sr * start_time))
    end_idx = int(np.round(start_idx + (sr * length)))
    if end_idx > audio.shape[0]:
        # clip to the end of the track, keeping the window length fixed
        end_idx = audio.shape[0]
        start_idx = int(end_idx - (sr * length))
    try:
        assert audio[start_idx:end_idx].shape[0] == sr * length
    except AssertionError:
        print(audio.shape, start_idx, end_idx, end_idx - start_idx)
        exit(1)
    return audio[start_idx:end_idx]
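# Usage sketch (hypothetical values): at fps=10, frame 25 starts at 2.5 s, so
# with sr=48000 this returns samples 120000:168000 (exactly one second):
#     audio = np.random.randn(48000 * 10)
#     cropAudio(audio, sr=48000, f_idx=25).shape    # -> (48000,)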
def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1):
    # sample `cnt` frame indices whose audio windows are at least `gap` seconds
    # away from the window at `idx` (for asynchronous negative pairs)
    assert idx < total_frames - fps * length
    lower_bound = idx - int((length + gap) * fps)
    upper_bound = idx + int((length + gap) * fps)
    proposal = list(range(0, lower_bound)) + \
        list(range(upper_bound, int(total_frames - fps * length)))
    # if fewer candidates than requested, pad by cycling through the proposals
    avail_cnt = len(proposal)
    try:
        for i in range(cnt - avail_cnt):
            proposal.append(proposal[i % avail_cnt])
    except Exception as e:
        print(idx, total_frames, proposal)
        raise e
    return sample(proposal, k=cnt)
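# Usage sketch (hypothetical values): for idx=100 with fps=10, gap=2.0, and
# length=1.0, frames 70..129 are excluded, so candidates come from
# [0, 70) and [130, 290):
#     pick_async_frame_idx(100, total_frames=300, cnt=3)   # e.g. [12, 201, 55]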
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
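# Usage sketch (hypothetical args; assumes args.epoch is the total epoch count
# and args.schedule holds step milestones, matching the attribute names used
# above, and `optimizer` is any torch.optim optimizer). Halfway through a
# cosine schedule the lr falls to half its base value, since
# 0.5 * (1 + cos(pi/2)) = 0.5:
#     import argparse
#     args = argparse.Namespace(lr=0.01, cos=True, epoch=100, schedule=[])
#     adjust_learning_rate(optimizer, epoch=50, args=args)   # lr -> 0.005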