import os
import logging
import h5py
import soundfile
import librosa
import numpy as np
import pandas as pd
from scipy import stats
import datetime
import pickle

def create_folder(fd):
    if not os.path.exists(fd):
        os.makedirs(fd)


def get_filename(path):
    path = os.path.realpath(path)
    na_ext = path.split('/')[-1]
    na = os.path.splitext(na_ext)[0]
    return na

def get_sub_filepaths(folder):
    paths = []
    for root, dirs, files in os.walk(folder):
        for name in files:
            path = os.path.join(root, name)
            paths.append(path)
    return paths

def create_logging(log_dir, filemode):
    """Create a log file with the next unused 4-digit index in log_dir and
    also echo INFO-level messages to the console."""
    create_folder(log_dir)
    i1 = 0

    while os.path.isfile(os.path.join(log_dir, '{:04d}.log'.format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, '{:04d}.log'.format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=log_path,
        filemode=filemode)

    # Print to console
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    return logging
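
# A minimal usage sketch (hedged: 'logs/train' is a hypothetical directory,
# not part of the original module). The returned module-level `logging` writes
# to the auto-numbered file and echoes INFO messages to the console.
#
#   logging = create_logging('logs/train', filemode='w')
#   logging.info('Training started')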

def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    Args:
      csv_path: str
      classes_num: int, number of sound classes
      id_to_ix: dict, maps a label id (e.g. '/m/0jbk') to a class index

    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
    """
    with open(csv_path, 'r') as fr:
        lines = fr.readlines()
        lines = lines[3:]   # Remove the csv header lines

    audios_num = len(lines)
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        items = line.split(', ')
        """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""

        audio_name = 'Y{}.wav'.format(items[0])   # Audios are prefixed with an extra 'Y' when downloaded
        label_ids = items[3].split('"')[1].split(',')

        audio_names.append(audio_name)

        # Multi-hot target
        for label_id in label_ids:
            ix = id_to_ix[label_id]
            targets[n, ix] = 1

    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
    return meta_dict
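
# Usage sketch (hedged: the csv path and the `id_to_ix` mapping below are
# illustrative; in practice the mapping is built from the AudioSet label list
# elsewhere in the pipeline).
#
#   id_to_ix = {'/m/0jbk': 0, '/m/0bt9lr': 1}              # hypothetical mapping
#   meta = read_metadata('eval_segments.csv', classes_num=2, id_to_ix=id_to_ix)
#   print(meta['audio_name'].shape, meta['target'].shape)  # (audios_num,), (audios_num, 2)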

def float32_to_int16(x):
    """Convert a float32 waveform in (roughly) [-1, 1] to int16."""
    assert np.max(np.abs(x)) <= 1.2
    x = np.clip(x, -1, 1)
    return (x * 32767.).astype(np.int16)


def int16_to_float32(x):
    """Convert an int16 waveform back to float32 in [-1, 1]."""
    return (x / 32767.).astype(np.float32)
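
# Round-trip sketch: waveforms are stored as int16 and converted back to
# float32 before use. The values below are illustrative.
#
#   wav = np.array([0.0, 0.5, -1.0], dtype=np.float32)
#   packed = float32_to_int16(wav)        # int16, scaled by 32767
#   restored = int16_to_float32(packed)   # float32 in [-1, 1], small rounding error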

def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0)
    else:
        return x[0 : audio_length]
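
# Example (hedged: the sample rate is an assumption, not fixed by this module).
# At a 32 kHz sample rate, a fixed 10-second clip corresponds to 320000 samples:
#
#   clip = pad_or_truncate(waveform, audio_length=320000)   # `waveform` is a 1-D array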

def d_prime(auc):
    d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
    return d_prime
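
# d' (d-prime) converts an ROC-AUC value into a detectability index:
# d' = sqrt(2) * Phi^{-1}(AUC), where Phi^{-1} is the inverse standard normal
# CDF (stats.norm().ppf). For example, AUC = 0.5 gives d' = 0; higher AUC
# gives a larger d'.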

class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator."""
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.

        Args:
          batch_size: int

        Returns:
          mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        # Each consecutive pair of examples shares one Beta-distributed draw:
        # the first gets lam, the second gets 1 - lam.
        for n in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.append(lam)
            mixup_lambdas.append(1. - lam)

        return np.array(mixup_lambdas)
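
# Usage sketch (hedged: batch size and alpha are illustrative). With mixup,
# even-indexed examples in a batch are combined with the following odd-indexed
# ones using the coefficients produced here.
#
#   mixup_augmenter = Mixup(mixup_alpha=1.0)
#   lambdas = mixup_augmenter.get_lambda(batch_size=4)   # e.g. [lam0, 1-lam0, lam1, 1-lam1]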

class StatisticsContainer(object):
    def __init__(self, statistics_path):
        """Contain statistics of different training iterations."""
        self.statistics_path = statistics_path

        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        self.statistics_dict = {'bal': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        pickle.dump(self.statistics_dict, open(self.statistics_path, 'wb'))
        pickle.dump(self.statistics_dict, open(self.backup_statistics_path, 'wb'))
        logging.info('    Dump statistics to {}'.format(self.statistics_path))
        logging.info('    Dump statistics to {}'.format(self.backup_statistics_path))

    def load_state_dict(self, resume_iteration):
        self.statistics_dict = pickle.load(open(self.statistics_path, 'rb'))

        resume_statistics_dict = {'bal': [], 'test': []}

        for key in self.statistics_dict.keys():
            for statistics in self.statistics_dict[key]:
                if statistics['iteration'] <= resume_iteration:
                    resume_statistics_dict[key].append(statistics)

        self.statistics_dict = resume_statistics_dict
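
# Usage sketch (hedged: the path and the statistics dict are illustrative).
# Statistics are appended per evaluation, dumped to a pickle plus a
# timestamped backup, and can be truncated to a checkpoint when resuming.
#
#   container = StatisticsContainer('statistics/stats.pkl')
#   container.append(iteration=1000, statistics={'average_precision': 0.31}, data_type='test')
#   container.dump()
#   container.load_state_dict(resume_iteration=1000)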