#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The previous version of this code reached an accuracy of 0.8423.
This version reaches an accuracy of 0.8379.
This code works.
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from pathlib import Path
import platform
import sys
from typing import List, Optional
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.modules.loss import FocalLoss
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="global_classifier", type=str)
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    return args
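

# Example invocation (the script name and file paths are illustrative):
#   python3 step_2_train_model.py \
#       --vocabulary_dir vocabulary \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --serialization_dir global_classifier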


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)
    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        array_list = list()
        label_list = list()
        for sample in batch:
            array = sample["waveform"]
            label = sample["label"]

            array_list.append(array)
            label_list.append(label)

        array_list = torch.stack(array_list)
        label_list = torch.stack(label_list)
        return array_list, label_list


collate_fn = CollateFunction()
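
# Note: torch.stack requires every "waveform" tensor in a batch to have the same
# length, so WaveClassifierExcelDataset is assumed to emit fixed-length clips.
# A minimal sketch of what the collate function produces (shapes are illustrative):
#   batch = [{"waveform": torch.zeros(16000), "label": torch.tensor(0)} for _ in range(4)]
#   waveforms, labels = collate_fn(batch)   # -> shapes (4, 16000) and (4,)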


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(args.serialization_dir)

    # Seed for reproducibility.
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU count: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    # datasets
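    # expected_sample_rate / max_wave_value suggest the Excel files reference
    # 8 kHz int16 PCM audio, normalized to [-1, 1] by dividing by 32768.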
    train_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.train_dataset,
        category=None,
        category_field="category",
        label_field="global_labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.valid_dataset,
        category=None,
        category_field="category",
        label_field="global_labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        # There is no need to shuffle the evaluation data.
        shuffle=False,
        # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    # models - classifier
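    # At the 8 kHz sample rate, win_length=200 and hop_length=80 correspond to
    # 25 ms analysis windows with a 10 ms hop.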
    wave_encoder = WaveEncoder(
        conv1d_block_param_list=[
            {
                'batch_norm': True,
                'in_channels': 80,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
            {
                # 'batch_norm': True,
                'in_channels': 16,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
            {
                # 'batch_norm': True,
                'in_channels': 16,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
        ],
        mel_spectrogram_param={
            "sample_rate": 8000,
            "n_fft": 512,
            "win_length": 200,
            "hop_length": 80,
            "f_min": 10,
            "f_max": 3800,
            "window_fn": "hamming",
            "n_mels": 80,
        }
    )
    cls_head = ClsHead(
        input_dim=16,
        num_layers=2,
        hidden_dims=[32, 16],
        activations="relu",
        dropout=0.1,
        num_labels=vocabulary.get_vocab_size(namespace="global_labels")
    )
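    # ClsHead is presumably a 2-layer MLP (16 -> 32 -> 16) followed by a
    # projection onto the label set.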
    model = WaveClassifier(
        wave_encoder=wave_encoder,
        cls_head=cls_head,
    )
    model.to(device)
    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.learning_rate
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=30000
    )
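    # StepLR defaults to gamma=0.1; since lr_scheduler.step() is called once per
    # batch below, the learning rate is divided by 10 every 30000 training steps.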
    focal_loss = FocalLoss(
        num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
        reduction="mean",
    )
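    # Focal loss down-weights easy, well-classified examples, which helps when
    # the label distribution is imbalanced.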
    categorical_accuracy = CategoricalAccuracy()

    # training
    best_idx_epoch: Optional[int] = None
    best_accuracy: Optional[float] = None
    patience_count = 0

    global_step = 0
    model_filename_list = list()
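    # model_filename_list acts as a rolling window over recent checkpoint paths;
    # once it grows past num_serialized_models_to_keep, the oldest file on disk
    # is deleted inside the loop below.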
    for idx_epoch in range(args.max_epochs):
        # training
        model.train()
        total_loss = 0
        total_examples = 0
        for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            logits = model(input_ids)
            loss = focal_loss(logits, label_ids.view(-1))

            categorical_accuracy(logits, label_ids)

            # loss is a batch mean; weight it by the batch size so that the
            # epoch-level value below is a per-example average.
            total_loss += loss.item() * input_ids.size(0)
            total_examples += input_ids.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            global_step += 1

        training_loss = total_loss / total_examples
        training_loss = round(training_loss, 4)
        training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
        training_accuracy = round(training_accuracy, 4)
        logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
            idx_epoch, training_loss, training_accuracy
        ))
        # evaluation
        model.eval()
        total_loss = 0
        total_examples = 0
        for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            with torch.no_grad():
                logits = model(input_ids)
                loss = focal_loss(logits, label_ids.view(-1))

                categorical_accuracy(logits, label_ids)

            # As in training, accumulate a per-example loss.
            total_loss += loss.item() * input_ids.size(0)
            total_examples += input_ids.size(0)

        evaluation_loss = total_loss / total_examples
        evaluation_loss = round(evaluation_loss, 4)
        evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
        evaluation_accuracy = round(evaluation_accuracy, 4)
        logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
            idx_epoch, evaluation_loss, evaluation_accuracy
        ))
        # save metric
        metrics = {
            "training_loss": training_loss,
            "training_accuracy": training_accuracy,
            "evaluation_loss": evaluation_loss,
            "evaluation_accuracy": evaluation_accuracy,
            "best_idx_epoch": best_idx_epoch,
            "best_accuracy": best_accuracy,
        }
        metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)
        # save model
        model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
        model_filename_list.append(model_filename)
        # Use a strict comparison so that exactly num_serialized_models_to_keep
        # checkpoints remain on disk.
        if len(model_filename_list) > args.num_serialized_models_to_keep:
            model_filename_to_delete = model_filename_list.pop(0)
            os.remove(model_filename_to_delete)
        torch.save(model.state_dict(), model_filename)
        # early stop
        best_model_filename = os.path.join(args.serialization_dir, "best.bin")
        if best_accuracy is None or evaluation_accuracy > best_accuracy:
            best_idx_epoch = idx_epoch
            best_accuracy = evaluation_accuracy
            torch.save(model.state_dict(), best_model_filename)
            patience_count = 0
        else:
            patience_count += 1
            # Stop once the accuracy has failed to improve for
            # `patience` consecutive epochs.
            if patience_count >= args.patience:
                break

    return


if __name__ == "__main__":
    main()