#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Only the cls_head parameters are trained; the resulting accuracy will be lower (compared with training the full model).
"""
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List, Optional
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
parser.add_argument("--country", default="en-US", type=str)
parser.add_argument("--shared_encoder", default="file_dir/global_model/best.bin", type=str)
parser.add_argument("--max_epochs", default=100, type=int)
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--learning_rate", default=1e-3, type=float)
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
parser.add_argument("--patience", default=5, type=int)
parser.add_argument("--serialization_dir", default="country_models", type=str)
parser.add_argument("--seed", default=0, type=int)
args = parser.parse_args()
return args
def logging_config(file_dir: str):
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
logging.basicConfig(format=fmt,
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.DEBUG)
file_handler = TimedRotatingFileHandler(
filename=os.path.join(file_dir, "main.log"),
encoding="utf-8",
when="D",
interval=1,
backupCount=7
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(fmt))
logger = logging.getLogger(__name__)
logger.addHandler(file_handler)
return logger
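# Collates a list of dataset samples into a batch by stacking the per-sample
# waveform tensors and label tensors produced by WaveClassifierExcelDataset.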
class CollateFunction(object):
def __init__(self):
pass
def __call__(self, batch: List[dict]):
array_list = list()
label_list = list()
for sample in batch:
array = sample['waveform']
label = sample['label']
array_list.append(array)
label_list.append(label)
array_list = torch.stack(array_list)
label_list = torch.stack(label_list)
return array_list, label_list
collate_fn = CollateFunction()
def main():
args = get_args()
serialization_dir = Path(args.serialization_dir)
serialization_dir.mkdir(parents=True, exist_ok=True)
logger = logging_config(args.serialization_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("GPU available: {}; device: {}".format(n_gpu, device))
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
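    # The vocabulary supplies the label-to-index mapping used by the datasets
    # below and the label count used to size the classification head.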
# datasets
logger.info("prepare datasets")
train_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.train_dataset,
category=args.country,
category_field="category",
label_field="country_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
valid_dataset = WaveClassifierExcelDataset(
vocab=vocabulary,
excel_file=args.valid_dataset,
category=args.country,
category_field="category",
label_field="country_labels",
expected_sample_rate=8000,
max_wave_value=32768.0,
)
train_data_loader = DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
        # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
collate_fn=collate_fn,
pin_memory=False,
# prefetch_factor=64,
)
valid_data_loader = DataLoader(
dataset=valid_dataset,
batch_size=args.batch_size,
shuffle=True,
        # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
collate_fn=collate_fn,
pin_memory=False,
# prefetch_factor=64,
)
# models - classifier
wave_encoder = WaveEncoder(
conv1d_block_param_list=[
{
'batch_norm': True,
'in_channels': 80,
'out_channels': 16,
'kernel_size': 3,
'stride': 3,
# 'padding': 'same',
'activation': 'relu',
'dropout': 0.1,
},
{
# 'batch_norm': True,
'in_channels': 16,
'out_channels': 16,
'kernel_size': 3,
'stride': 3,
# 'padding': 'same',
'activation': 'relu',
'dropout': 0.1,
},
{
# 'batch_norm': True,
'in_channels': 16,
'out_channels': 16,
'kernel_size': 3,
'stride': 3,
# 'padding': 'same',
'activation': 'relu',
'dropout': 0.1,
},
],
mel_spectrogram_param={
"sample_rate": 8000,
"n_fft": 512,
"win_length": 200,
"hop_length": 80,
"f_min": 10,
"f_max": 3800,
"window_fn": "hamming",
"n_mels": 80,
}
)
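    # Load the shared (global) encoder weights. The checkpoint contains a full
    # WaveClassifier state dict, so only keys under the "wave_encoder." prefix
    # are kept, with the prefix stripped before loading into the encoder.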
with open(args.shared_encoder, "rb") as f:
state_dict = torch.load(f, map_location=device)
processed_state_dict = dict()
prefix = "wave_encoder."
for k, v in state_dict.items():
if not str(k).startswith(prefix):
continue
k = k[len(prefix):]
processed_state_dict[k] = v
wave_encoder.load_state_dict(
state_dict=processed_state_dict,
strict=True,
)
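    # The classification head is newly initialized (it is not loaded from the
    # shared checkpoint) and is the only component trained below.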
cls_head = ClsHead(
input_dim=16,
num_layers=2,
hidden_dims=[32, 16],
activations="relu",
dropout=0.1,
num_labels=vocabulary.get_vocab_size(namespace="global_labels")
)
model = WaveClassifier(
wave_encoder=wave_encoder,
cls_head=cls_head,
)
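    # Freeze the shared encoder; only the cls_head parameters receive gradients
    # and are passed to the optimizer below.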
model.wave_encoder.requires_grad_(requires_grad=False)
model.cls_head.requires_grad_(requires_grad=True)
model.to(device)
# optimizer
logger.info("prepare optimizer")
optimizer = torch.optim.Adam(
model.cls_head.parameters(),
lr=args.learning_rate,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=2000
)
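    # StepLR decays the learning rate every 2000 optimizer steps; note that the
    # scheduler is stepped once per batch in the training loop, not per epoch.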
focal_loss = FocalLoss(
num_classes=vocabulary.get_vocab_size(namespace=args.country),
reduction="mean",
)
categorical_accuracy = CategoricalAccuracy()
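    # FocalLoss (from the local toolbox) is the training objective, down-weighting
    # easy examples; CategoricalAccuracy accumulates accuracy across batches and
    # is reset once per training/evaluation phase via get_metric(reset=True).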
# training loop
    best_idx_epoch: Optional[int] = None
    best_accuracy: Optional[float] = None
patience_count = 0
global_step = 0
model_filename_list = list()
for idx_epoch in range(args.max_epochs):
# training
model.train()
total_loss = 0
total_examples = 0
for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
input_ids, label_ids = batch
input_ids = input_ids.to(device)
label_ids: torch.LongTensor = label_ids.to(device).long()
logits = model.forward(input_ids)
loss = focal_loss.forward(logits, label_ids.view(-1))
categorical_accuracy(logits, label_ids)
total_loss += loss.item()
total_examples += input_ids.size(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
global_step += 1
training_loss = total_loss / total_examples
training_loss = round(training_loss, 4)
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
training_accuracy = round(training_accuracy, 4)
logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
idx_epoch, training_loss, training_accuracy
))
# evaluation
model.eval()
total_loss = 0
total_examples = 0
for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
input_ids, label_ids = batch
input_ids = input_ids.to(device)
label_ids: torch.LongTensor = label_ids.to(device).long()
with torch.no_grad():
logits = model.forward(input_ids)
loss = focal_loss.forward(logits, label_ids.view(-1))
categorical_accuracy(logits, label_ids)
total_loss += loss.item()
total_examples += input_ids.size(0)
evaluation_loss = total_loss / total_examples
evaluation_loss = round(evaluation_loss, 4)
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
evaluation_accuracy = round(evaluation_accuracy, 4)
logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
idx_epoch, evaluation_loss, evaluation_accuracy
))
# save metric
metrics = {
"training_loss": training_loss,
"training_accuracy": training_accuracy,
"evaluation_loss": evaluation_loss,
"evaluation_accuracy": evaluation_accuracy,
"best_idx_epoch": best_idx_epoch,
"best_accuracy": best_accuracy,
}
metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
with open(metrics_filename, "w", encoding="utf-8") as f:
json.dump(metrics, f, indent=4, ensure_ascii=False)
# save model
model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
model_filename_list.append(model_filename)
if len(model_filename_list) >= args.num_serialized_models_to_keep:
model_filename_to_delete = model_filename_list.pop(0)
os.remove(model_filename_to_delete)
torch.save(model.state_dict(), model_filename)
# early stop
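        # Keep the checkpoint with the best validation accuracy in best.bin and
        # stop once accuracy has not improved for more than args.patience epochs.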
best_model_filename = os.path.join(args.serialization_dir, "best.bin")
if best_accuracy is None:
best_idx_epoch = idx_epoch
best_accuracy = evaluation_accuracy
torch.save(model.state_dict(), best_model_filename)
elif evaluation_accuracy > best_accuracy:
best_idx_epoch = idx_epoch
best_accuracy = evaluation_accuracy
torch.save(model.state_dict(), best_model_filename)
patience_count = 0
elif patience_count >= args.patience:
break
else:
patience_count += 1
return
if __name__ == "__main__":
main()