#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
parser.add_argument("--max_steps", default=100000, type=int)
parser.add_argument("--save_steps", default=30, type=int)
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--learning_rate", default=1e-3, type=float)
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
parser.add_argument("--patience", default=5, type=int)
parser.add_argument("--serialization_dir", default="union", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--num_workers", default=0, type=int)
args = parser.parse_args()
return args
def logging_config(file_dir: str):
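    """Configure console logging at DEBUG level plus a daily-rotating INFO log file under `file_dir`."""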
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
logging.basicConfig(format=fmt,
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.DEBUG)
file_handler = TimedRotatingFileHandler(
filename=os.path.join(file_dir, "main.log"),
encoding="utf-8",
when="D",
interval=1,
backupCount=7
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(fmt))
logger = logging.getLogger(__name__)
logger.addHandler(file_handler)
return logger
class CollateFunction(object):
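    """Stack each sample's "waveform" and "label" tensors into a pair of batch tensors."""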
def __init__(self):
pass
def __call__(self, batch: List[dict]):
array_list = list()
label_list = list()
for sample in batch:
array = sample['waveform']
label = sample['label']
array_list.append(array)
label_list.append(label)
array_list = torch.stack(array_list)
label_list = torch.stack(label_list)
return array_list, label_list
collate_fn = CollateFunction()
class DatasetIterator(object):
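    """Endless iterator over a DataLoader: the loader is re-created whenever it is exhausted."""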
def __init__(self, data_loader: DataLoader):
self.data_loader = data_loader
self.data_loader_iter = iter(self.data_loader)
def next(self):
try:
result = self.data_loader_iter.__next__()
except StopIteration:
self.data_loader_iter = iter(self.data_loader)
result = self.data_loader_iter.__next__()
return result
def main():
args = get_args()
serialization_dir = Path(args.serialization_dir)
serialization_dir.mkdir(parents=True, exist_ok=True)
logger = logging_config(args.serialization_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
namespaces = vocabulary._token_to_index.keys()
    # namespace_to_ratio: number of training batches drawn per step for each task.
    max_ratio = (len(namespaces) - 1) * 3
    namespace_to_ratio = {n: 1 for n in namespaces}
    namespace_to_ratio["global_labels"] = max_ratio
# datasets
logger.info("prepare datasets")
namespace_to_datasets = dict()
for namespace in namespaces:
logger.info("prepare datasets - {}".format(namespace))
if namespace == "global_labels":
train_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.train_dataset,
category=None,
category_field="category",
label_field="global_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
valid_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.valid_dataset,
category=None,
category_field="category",
label_field="global_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
else:
train_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.train_dataset,
category=namespace,
category_field="category",
label_field="country_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
valid_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.valid_dataset,
category=namespace,
category_field="category",
label_field="country_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
train_data_loader = DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
            # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
num_workers=args.num_workers,
collate_fn=collate_fn,
pin_memory=False,
# prefetch_factor=64,
)
valid_data_loader = DataLoader(
dataset=valid_dataset,
batch_size=args.batch_size,
shuffle=True,
            # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
num_workers=args.num_workers,
collate_fn=collate_fn,
pin_memory=False,
# prefetch_factor=64,
)
namespace_to_datasets[namespace] = {
"train_data_loader": train_data_loader,
"valid_data_loader": valid_data_loader,
}
# datasets iterator
logger.info("prepare datasets iterator")
namespace_to_datasets_iter = dict()
for namespace in namespaces:
logger.info("prepare datasets iterator - {}".format(namespace))
train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
namespace_to_datasets_iter[namespace] = {
"train_data_loader_iter": DatasetIterator(train_data_loader),
"valid_data_loader_iter": DatasetIterator(valid_data_loader),
}
# models - encoder
logger.info("prepare models - encoder")
wave_encoder = WaveEncoder(
conv2d_block_param_list=[
{
"batch_norm": True,
"in_channels": 1,
"out_channels": 4,
"kernel_size": 3,
"stride": 1,
# "padding": "same",
"dilation": 3,
"activation": "relu",
"dropout": 0.1,
},
{
# "batch_norm": True,
"in_channels": 4,
"out_channels": 4,
"kernel_size": 5,
"stride": 2,
# "padding": "same",
"dilation": 3,
"activation": "relu",
"dropout": 0.1,
},
{
# "batch_norm": True,
"in_channels": 4,
"out_channels": 4,
"kernel_size": 3,
"stride": 1,
# "padding": "same",
"dilation": 2,
"activation": "relu",
"dropout": 0.1,
},
],
mel_spectrogram_param={
'sample_rate': 8000,
'n_fft': 512,
'win_length': 200,
'hop_length': 80,
'f_min': 10,
'f_max': 3800,
'window_fn': 'hamming',
'n_mels': 80,
}
)
# models - cls_head
logger.info("prepare models - cls_head")
namespace_to_cls_heads = dict()
for namespace in namespaces:
logger.info("prepare models - cls_head - {}".format(namespace))
cls_head = ClsHead(
input_dim=352,
num_layers=2,
hidden_dims=[128, 32],
activations="relu",
dropout=0.1,
num_labels=vocabulary.get_vocab_size(namespace=namespace)
)
namespace_to_cls_heads[namespace] = cls_head
# models - classifier
logger.info("prepare models - classifier")
namespace_to_classifier = dict()
for namespace in namespaces:
logger.info("prepare models - classifier - {}".format(namespace))
cls_head = namespace_to_cls_heads[namespace]
wave_classifier = WaveClassifier(
wave_encoder=wave_encoder,
cls_head=cls_head,
)
wave_classifier.to(device)
namespace_to_classifier[namespace] = wave_classifier
# optimizer
logger.info("prepare optimizer")
param_optimizer = list()
param_optimizer.extend(wave_encoder.parameters())
for _, cls_head in namespace_to_cls_heads.items():
param_optimizer.extend(cls_head.parameters())
optimizer = torch.optim.Adam(
param_optimizer,
lr=args.learning_rate,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=10000
)
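    # One focal loss instance is shared by all tasks (class count taken from "global_labels").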
focal_loss = FocalLoss(
num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
reduction="mean",
)
# categorical_accuracy
logger.info("prepare categorical_accuracy")
namespace_to_categorical_accuracy = dict()
for namespace in namespaces:
categorical_accuracy = CategoricalAccuracy()
namespace_to_categorical_accuracy[namespace] = categorical_accuracy
# training loop
logger.info("prepare training loop")
model_list = list()
best_idx_step = None
best_accuracy = None
patience_count = 0
namespace_to_total_loss = defaultdict(float)
namespace_to_total_examples = defaultdict(int)
for idx_step in tqdm(range(args.max_steps)):
# training one step
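        # Draw `ratio` batches from each task, sum the focal losses, then take one optimizer step.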
loss: torch.Tensor = None
for namespace in namespaces:
train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
ratio = namespace_to_ratio[namespace]
model = namespace_to_classifier[namespace]
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
model.train()
for _ in range(ratio):
batch = train_data_loader_iter.next()
input_ids, label_ids = batch
input_ids = input_ids.to(device)
label_ids: torch.LongTensor = label_ids.to(device).long()
logits = model.forward(input_ids)
task_loss = focal_loss.forward(logits, label_ids.view(-1))
categorical_accuracy(logits, label_ids)
if loss is None:
loss = task_loss
else:
loss += task_loss
namespace_to_total_loss[namespace] += task_loss.item()
namespace_to_total_examples[namespace] += input_ids.size(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
# logging
if (idx_step + 1) % args.save_steps == 0:
metrics = dict()
# training
for namespace in namespaces:
total_loss = namespace_to_total_loss[namespace]
total_examples = namespace_to_total_examples[namespace]
training_loss = total_loss / total_examples
training_loss = round(training_loss, 4)
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
training_accuracy = round(training_accuracy, 4)
logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
idx_step, namespace, training_loss, training_accuracy
))
metrics[namespace] = {
"training_loss": training_loss,
"training_accuracy": training_accuracy,
}
namespace_to_total_loss = defaultdict(float)
namespace_to_total_examples = defaultdict(int)
# evaluation
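            # Run each task over its full validation loader and record loss / accuracy.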
for namespace in namespaces:
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
model = namespace_to_classifier[namespace]
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
model.eval()
total_loss = 0
total_examples = 0
for step, batch in enumerate(valid_data_loader):
input_ids, label_ids = batch
input_ids = input_ids.to(device)
label_ids: torch.LongTensor = label_ids.to(device).long()
with torch.no_grad():
logits = model.forward(input_ids)
loss = focal_loss.forward(logits, label_ids.view(-1))
categorical_accuracy(logits, label_ids)
total_loss += loss.item()
total_examples += input_ids.size(0)
evaluation_loss = total_loss / total_examples
evaluation_loss = round(evaluation_loss, 4)
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
evaluation_accuracy = round(evaluation_accuracy, 4)
logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
idx_step, namespace, evaluation_loss, evaluation_accuracy
))
                metrics[namespace].update({
                    "evaluation_loss": evaluation_loss,
                    "evaluation_accuracy": evaluation_accuracy,
                })
            # update ratio: tasks with lower evaluation accuracy get more training batches per step.
            min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
            max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
            width = max_accuracy - min_accuracy
            for namespace, metric in metrics.items():
                evaluation_accuracy = metric["evaluation_accuracy"]
                if width == 0:
                    # all tasks are equally accurate; fall back to equal sampling.
                    ratio = 1
                else:
                    ratio = int((max_accuracy - evaluation_accuracy) / width * max_ratio)
                namespace_to_ratio[namespace] = ratio
msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
logger.info("namespace to ratio: {}".format(msg))
# save path
step_dir = serialization_dir / "step-{}".format(idx_step)
step_dir.mkdir(parents=True, exist_ok=False)
# save models
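            # The shared encoder and each task head are written to separate .pt files.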
wave_encoder_filename = step_dir / "wave_encoder.pt"
torch.save(wave_encoder.state_dict(), wave_encoder_filename)
for namespace in namespaces:
cls_head_filename = step_dir / "{}.pt".format(namespace)
cls_head = namespace_to_cls_heads[namespace]
torch.save(cls_head.state_dict(), cls_head_filename)
model_list.append(step_dir)
            if len(model_list) > args.num_serialized_models_to_keep:
model_to_delete: Path = model_list.pop(0)
shutil.rmtree(model_to_delete.as_posix())
# save metric
this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
            if best_accuracy is None or this_accuracy > best_accuracy:
                best_idx_step = idx_step
                best_accuracy = this_accuracy
metrics_filename = step_dir / "metrics_epoch.json"
metrics.update({
"idx_step": idx_step,
"best_idx_step": best_idx_step,
})
with open(metrics_filename, "w", encoding="utf-8") as f:
json.dump(metrics, f, indent=4, ensure_ascii=False)
# save best
best_dir = serialization_dir / "best"
if best_idx_step == idx_step:
if best_dir.exists():
shutil.rmtree(best_dir)
shutil.copytree(step_dir, best_dir)
# early stop
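            # Patience is counted in evaluation rounds (one round every `save_steps` steps).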
early_stop_flag = False
if best_idx_step == idx_step:
patience_count = 0
else:
patience_count += 1
if patience_count >= args.patience:
early_stop_flag = True
# early stop
if early_stop_flag:
break
return
if __name__ == "__main__":
main()