#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
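
# Summary of what the code below does: it trains one shared WaveEncoder together
# with one ClsHead per vocabulary namespace, i.e. a multi-head audio classifier.
# The "global_labels" head sees samples from every category, while each other
# namespace head only sees the rows of its own category; how many batches each
# head receives per training step is controlled by a per-namespace ratio that is
# re-balanced after every evaluation pass.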
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_steps", default=100000, type=int)
    parser.add_argument("--save_steps", default=30, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="union", type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=0, type=int)
    args = parser.parse_args()
    return args
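
# Example invocation (the script name "train.py" is illustrative; the flags
# match get_args above, and every flag has a default, so all are optional):
#
#   python3 train.py \
#       --vocabulary_dir vocabulary \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --serialization_dir union \
#       --batch_size 1 \
#       --num_workers 0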
def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)
    return logger
class CollateFunction(object):
    """Stack per-sample waveform tensors and label tensors into batch tensors."""
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        array_list = list()
        label_list = list()
        for sample in batch:
            array = sample["waveform"]
            label = sample["label"]

            array_list.append(array)
            label_list.append(label)

        array_list = torch.stack(array_list)
        label_list = torch.stack(label_list)
        return array_list, label_list


collate_fn = CollateFunction()
class DatasetIterator(object):
    """Wrap a DataLoader so that `next()` cycles forever: when the loader is
    exhausted, a fresh iterator is created and iteration restarts."""
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.data_loader_iter = iter(self.data_loader)

    def next(self):
        try:
            result = next(self.data_loader_iter)
        except StopIteration:
            self.data_loader_iter = iter(self.data_loader)
            result = next(self.data_loader_iter)
        return result
def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(args.serialization_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)
    namespaces = vocabulary._token_to_index.keys()

    # namespace_to_ratio: number of batches each namespace contributes per step.
    # Initially "global_labels" gets as many batches as all other heads combined (3x each).
    max_ratio = (len(namespaces) - 1) * 3
    namespace_to_ratio = {n: 1 for n in namespaces}
    namespace_to_ratio["global_labels"] = max_ratio
    # datasets
    logger.info("prepare datasets")
    namespace_to_datasets = dict()
    for namespace in namespaces:
        logger.info("prepare datasets - {}".format(namespace))
        if namespace == "global_labels":
            train_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.train_dataset,
                category=None,
                category_field="category",
                label_field="global_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
            valid_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.valid_dataset,
                category=None,
                category_field="category",
                label_field="global_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
        else:
            train_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.train_dataset,
                category=namespace,
                category_field="category",
                label_field="country_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
            valid_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.valid_dataset,
                category=namespace,
                category_field="category",
                label_field="country_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )

        train_data_loader = DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            # On Linux the data can be loaded with multiple worker processes; on Windows it cannot.
            # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            # prefetch_factor=64,
        )
        valid_data_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            # On Linux the data can be loaded with multiple worker processes; on Windows it cannot.
            # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            # prefetch_factor=64,
        )

        namespace_to_datasets[namespace] = {
            "train_data_loader": train_data_loader,
            "valid_data_loader": valid_data_loader,
        }
    # datasets iterator
    logger.info("prepare datasets iterator")
    namespace_to_datasets_iter = dict()
    for namespace in namespaces:
        logger.info("prepare datasets iterator - {}".format(namespace))
        train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
        valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]

        namespace_to_datasets_iter[namespace] = {
            "train_data_loader_iter": DatasetIterator(train_data_loader),
            "valid_data_loader_iter": DatasetIterator(valid_data_loader),
        }
    # models - encoder
    logger.info("prepare models - encoder")
    wave_encoder = WaveEncoder(
        conv2d_block_param_list=[
            {
                "batch_norm": True,
                "in_channels": 1,
                "out_channels": 4,
                "kernel_size": 3,
                "stride": 1,
                # "padding": "same",
                "dilation": 3,
                "activation": "relu",
                "dropout": 0.1,
            },
            {
                # "batch_norm": True,
                "in_channels": 4,
                "out_channels": 4,
                "kernel_size": 5,
                "stride": 2,
                # "padding": "same",
                "dilation": 3,
                "activation": "relu",
                "dropout": 0.1,
            },
            {
                # "batch_norm": True,
                "in_channels": 4,
                "out_channels": 4,
                "kernel_size": 3,
                "stride": 1,
                # "padding": "same",
                "dilation": 2,
                "activation": "relu",
                "dropout": 0.1,
            },
        ],
        mel_spectrogram_param={
            "sample_rate": 8000,
            "n_fft": 512,
            "win_length": 200,
            "hop_length": 80,
            "f_min": 10,
            "f_max": 3800,
            "window_fn": "hamming",
            "n_mels": 80,
        }
    )
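    # Note (derived from the parameters above): at an 8 kHz sample rate,
    # win_length=200 and hop_length=80 correspond to 25 ms analysis windows with
    # a 10 ms hop, and f_max=3800 stays just under the 4 kHz Nyquist limit,
    # which matches narrow-band (telephone-quality) audio.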
    # models - cls_head
    logger.info("prepare models - cls_head")
    namespace_to_cls_heads = dict()
    for namespace in namespaces:
        logger.info("prepare models - cls_head - {}".format(namespace))
        cls_head = ClsHead(
            input_dim=352,
            num_layers=2,
            hidden_dims=[128, 32],
            activations="relu",
            dropout=0.1,
            num_labels=vocabulary.get_vocab_size(namespace=namespace)
        )
        namespace_to_cls_heads[namespace] = cls_head

    # models - classifier
    logger.info("prepare models - classifier")
    namespace_to_classifier = dict()
    for namespace in namespaces:
        logger.info("prepare models - classifier - {}".format(namespace))
        cls_head = namespace_to_cls_heads[namespace]
        wave_classifier = WaveClassifier(
            wave_encoder=wave_encoder,
            cls_head=cls_head,
        )
        wave_classifier.to(device)
        namespace_to_classifier[namespace] = wave_classifier
    # optimizer
    logger.info("prepare optimizer")
    param_optimizer = list()
    param_optimizer.extend(wave_encoder.parameters())
    for _, cls_head in namespace_to_cls_heads.items():
        param_optimizer.extend(cls_head.parameters())

    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=10000
    )
    focal_loss = FocalLoss(
        num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
        reduction="mean",
    )

    # categorical_accuracy
    logger.info("prepare categorical_accuracy")
    namespace_to_categorical_accuracy = dict()
    for namespace in namespaces:
        categorical_accuracy = CategoricalAccuracy()
        namespace_to_categorical_accuracy[namespace] = categorical_accuracy
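
    # Training scheme (as implemented below): every step draws `ratio` batches
    # from each namespace's cycling train iterator, feeds them through that
    # namespace's classifier (all classifiers share the same wave_encoder), sums
    # the focal losses, and applies a single optimizer step. Every `save_steps`
    # steps the script logs training metrics, evaluates each head on its valid
    # loader, re-balances the per-namespace ratios toward the weaker heads, and
    # serializes a checkpoint.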
    # training loop
    logger.info("prepare training loop")

    model_list = list()
    best_idx_step = None
    best_accuracy = None
    patience_count = 0

    namespace_to_total_loss = defaultdict(float)
    namespace_to_total_examples = defaultdict(int)
    for idx_step in tqdm(range(args.max_steps)):
        # training one step
        loss: torch.Tensor = None
        for namespace in namespaces:
            train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
            ratio = namespace_to_ratio[namespace]
            model = namespace_to_classifier[namespace]
            categorical_accuracy = namespace_to_categorical_accuracy[namespace]

            model.train()
            for _ in range(ratio):
                batch = train_data_loader_iter.next()
                input_ids, label_ids = batch
                input_ids = input_ids.to(device)
                label_ids: torch.LongTensor = label_ids.to(device).long()

                logits = model.forward(input_ids)
                task_loss = focal_loss.forward(logits, label_ids.view(-1))
                categorical_accuracy(logits, label_ids)

                if loss is None:
                    loss = task_loss
                else:
                    loss += task_loss

                namespace_to_total_loss[namespace] += task_loss.item()
                namespace_to_total_examples[namespace] += input_ids.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # logging
        if (idx_step + 1) % args.save_steps == 0:
            metrics = dict()

            # training
            for namespace in namespaces:
                total_loss = namespace_to_total_loss[namespace]
                total_examples = namespace_to_total_examples[namespace]

                training_loss = total_loss / total_examples
                training_loss = round(training_loss, 4)

                categorical_accuracy = namespace_to_categorical_accuracy[namespace]
                training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
                training_accuracy = round(training_accuracy, 4)

                logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
                    idx_step, namespace, training_loss, training_accuracy
                ))
                metrics[namespace] = {
                    "training_loss": training_loss,
                    "training_accuracy": training_accuracy,
                }

            namespace_to_total_loss = defaultdict(float)
            namespace_to_total_examples = defaultdict(int)
            # evaluation
            for namespace in namespaces:
                valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
                model = namespace_to_classifier[namespace]
                categorical_accuracy = namespace_to_categorical_accuracy[namespace]

                model.eval()
                total_loss = 0
                total_examples = 0
                for step, batch in enumerate(valid_data_loader):
                    input_ids, label_ids = batch
                    input_ids = input_ids.to(device)
                    label_ids: torch.LongTensor = label_ids.to(device).long()

                    with torch.no_grad():
                        logits = model.forward(input_ids)
                        loss = focal_loss.forward(logits, label_ids.view(-1))
                        categorical_accuracy(logits, label_ids)

                    total_loss += loss.item()
                    total_examples += input_ids.size(0)

                evaluation_loss = total_loss / total_examples
                evaluation_loss = round(evaluation_loss, 4)

                evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
                evaluation_accuracy = round(evaluation_accuracy, 4)

                logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
                    idx_step, namespace, evaluation_loss, evaluation_accuracy
                ))
                metrics[namespace] = {
                    "evaluation_loss": evaluation_loss,
                    "evaluation_accuracy": evaluation_accuracy,
                }
            # update ratio
            min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
            max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
            width = max_accuracy - min_accuracy
            for namespace, metric in metrics.items():
                evaluation_accuracy = metric["evaluation_accuracy"]
                if width == 0:
                    # all heads are tied on accuracy; fall back to equal sampling
                    # instead of dividing by zero.
                    ratio = 1
                else:
                    # heads with lower accuracy get more batches on the following steps.
                    ratio = int((max_accuracy - evaluation_accuracy) / width * max_ratio)
                namespace_to_ratio[namespace] = ratio

            msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
            logger.info("namespace to ratio: {}".format(msg))
            # save path
            step_dir = serialization_dir / "step-{}".format(idx_step)
            step_dir.mkdir(parents=True, exist_ok=False)

            # save models
            wave_encoder_filename = step_dir / "wave_encoder.pt"
            torch.save(wave_encoder.state_dict(), wave_encoder_filename)
            for namespace in namespaces:
                cls_head_filename = step_dir / "{}.pt".format(namespace)
                cls_head = namespace_to_cls_heads[namespace]
                torch.save(cls_head.state_dict(), cls_head_filename)

            model_list.append(step_dir)
            if len(model_list) >= args.num_serialized_models_to_keep:
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())
            # save metric
            this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
            if best_accuracy is None or this_accuracy > best_accuracy:
                best_idx_step = idx_step
                best_accuracy = this_accuracy

            metrics_filename = step_dir / "metrics_epoch.json"
            metrics.update({
                "idx_step": idx_step,
                "best_idx_step": best_idx_step,
            })
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)
            # save best
            best_dir = serialization_dir / "best"
            if best_idx_step == idx_step:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(step_dir, best_dir)

            # early stop
            early_stop_flag = False
            if best_idx_step == idx_step:
                patience_count = 0
            else:
                patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

            if early_stop_flag:
                break

    return

if __name__ == "__main__":
    main()