Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from collections import defaultdict | |
import json | |
import logging | |
from logging.handlers import TimedRotatingFileHandler | |
import os | |
import platform | |
from pathlib import Path | |
import random | |
import sys | |
import shutil | |
import tempfile | |
from typing import List | |
import zipfile | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../../")) | |
import numpy as np | |
import torch | |
from torch.utils.data.dataloader import DataLoader | |
from tqdm import tqdm | |
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear | |
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy | |
from toolbox.torch.utils.data.vocabulary import Vocabulary | |
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset | |
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel | |
from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig | |
def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is the historical
            behaviour of this function.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()

    # Data locations.
    parser.add_argument(
        "--vocabulary_dir", default="vocabulary", type=str,
        help="directory the vocabulary is loaded from",
    )
    parser.add_argument(
        "--train_dataset", default="train.xlsx", type=str,
        help="excel file with the training samples",
    )
    parser.add_argument(
        "--valid_dataset", default="valid.xlsx", type=str,
        help="excel file with the validation samples",
    )

    # Optimisation hyper-parameters.
    parser.add_argument("--max_epochs", default=100, type=int, help="maximum number of epochs")
    parser.add_argument("--batch_size", default=64, type=int, help="samples per batch")
    parser.add_argument("--learning_rate", default=1e-3, type=float, help="initial learning rate")

    # Checkpointing and early stopping.
    parser.add_argument(
        "--num_serialized_models_to_keep", default=10, type=int,
        help="upper bound on the number of per-epoch checkpoint directories kept on disk",
    )
    parser.add_argument(
        "--patience", default=5, type=int,
        help="epochs without a new best validation accuracy before training stops",
    )
    parser.add_argument(
        "--serialization_dir", default="serialization_dir", type=str,
        help="output directory for logs, checkpoints and metrics",
    )

    parser.add_argument("--seed", default=0, type=int, help="random seed")
    parser.add_argument(
        "--config_file", default="conv2d_classifier.yaml", type=str,
        help="model configuration file",
    )
    parser.add_argument(
        "--pretrained_model",
        # "null" is a placeholder path that does not exist; the caller only
        # loads weights when this path exists on disk.
        default="null",
        type=str,
        help="zip archive with pretrained weights; ignored when the path does not exist",
    )

    return parser.parse_args(argv)
def logging_config(file_dir: str):
    """Configure console logging and a daily-rotating ``main.log`` file.

    The root logger is configured through ``logging.basicConfig`` at DEBUG
    level; in addition this module's logger gets a file handler at INFO
    level that rotates once per day and keeps seven backups.

    The function is idempotent per log file: calling it again with the same
    directory returns the logger without attaching a second handler, which
    would otherwise write every record to the file twice.

    Args:
        file_dir: Existing directory that will contain ``main.log``.
            ``os.PathLike`` objects such as ``pathlib.Path`` are accepted
            as well as plain strings.

    Returns:
        logging.Logger: this module's logger.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)

    logger = logging.getLogger(__name__)

    # FileHandler stores the absolute path in ``baseFilename``; normalise the
    # same way so the duplicate check below compares like with like.
    log_file = os.path.abspath(os.path.join(file_dir, "main.log"))
    for handler in logger.handlers:
        if isinstance(handler, TimedRotatingFileHandler) and handler.baseFilename == log_file:
            return logger

    file_handler = TimedRotatingFileHandler(
        filename=log_file,
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(file_handler)

    return logger
class CollateFunction(object):
    """Collate waveform/label samples into fixed-length batch tensors.

    Every waveform is zero-padded at the end, or truncated, to exactly
    ``max_length`` samples so the batch can be stacked.

    Args:
        max_length: Number of samples each waveform is padded/cropped to.
            Defaults to 16000, the value that used to be hard-coded.
    """

    def __init__(self, max_length: int = 16000):
        self.max_length = max_length

    def __call__(self, batch: List[dict]):
        """Build the batch tensors.

        Args:
            batch: Samples, each a dict with a 1-D ``"waveform"`` and a
                ``"label"``. Tensors, NumPy arrays and plain numbers are
                accepted; everything is converted with ``torch.as_tensor``.

        Returns:
            A ``(waveforms, labels)`` tuple of stacked tensors; ``waveforms``
            has shape ``(len(batch), max_length)``.
        """
        waveforms = []
        labels = []
        for sample in batch:
            waveform = torch.as_tensor(sample["waveform"])
            length = waveform.shape[0]

            if length < self.max_length:
                # Pad inside torch: the previous np.concatenate produced a
                # NumPy array, which torch.stack below rejects.
                waveform = torch.nn.functional.pad(waveform, (0, self.max_length - length))
            elif length > self.max_length:
                waveform = waveform[:self.max_length]

            waveforms.append(waveform)
            labels.append(torch.as_tensor(sample["label"]))

        return torch.stack(waveforms), torch.stack(labels)


# Shared instance handed to the data loaders.
collate_fn = CollateFunction()
def _build_data_loader(dataset, batch_size: int) -> DataLoader:
    """Wrap ``dataset`` in a shuffling DataLoader that uses ``collate_fn``."""
    # Multiple worker subprocesses can be used to load data on Linux, but not
    # on Windows. os.cpu_count() may return None, hence the ``or 1``.
    num_workers = 0 if platform.system() == "Windows" else (os.cpu_count() or 1) // 2
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )


def _load_pretrained_weights(model, archive: Path) -> None:
    """Load ``<archive stem>/model.pt`` from the zip ``archive`` into ``model``."""
    # A private temporary directory avoids clobbering a shared fixed path and
    # is removed again once the weights are in memory.
    with tempfile.TemporaryDirectory(prefix="vm_sound_classification_") as out_root:
        with zipfile.ZipFile(archive.as_posix(), "r") as f_zip:
            f_zip.extractall(path=out_root)

        model_pt_file = Path(out_root) / archive.stem / "model.pt"
        with open(model_pt_file, "rb") as f:
            # NOTE(security): torch.load unpickles the file; only point
            # --pretrained_model at archives from a trusted source.
            state_dict = torch.load(f, map_location="cpu")

    model.load_state_dict(state_dict=state_dict)


def _run_epoch(model, data_loader, loss_fn, accuracy_fn, device,
               desc: str, metric_prefix: str,
               optimizer=None, lr_scheduler=None):
    """Run one pass over ``data_loader`` and return ``(loss, accuracy)``.

    With an ``optimizer`` the pass trains the model (train mode, gradient
    step and scheduler step after every batch); without one it only
    evaluates (eval mode, gradients disabled). Both returned values are
    rounded to four decimals.
    """
    training = optimizer is not None
    # Switch mode explicitly on every pass so evaluation never runs with
    # training-time layer behaviour and training resumes in train mode.
    model.train(training)
    accuracy_fn.reset()

    total_loss = 0.
    total_examples = 0
    # Values reported when the loader yields no batch at all.
    mean_loss = 10000000000
    accuracy = 0.

    with tqdm(total=len(data_loader), desc=desc) as progress_bar:
        for input_ids, label_ids in data_loader:
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            with torch.set_grad_enabled(training):
                logits = model.forward(input_ids)
                loss = loss_fn.forward(logits, label_ids.view(-1))
            accuracy_fn(logits, label_ids)

            # The loss is built with reduction="mean" (assumed to be the
            # per-batch mean), so weight it by the batch size before
            # dividing by the number of examples seen.
            batch_size = input_ids.size(0)
            total_loss += loss.item() * batch_size
            total_examples += batch_size

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

            mean_loss = round(total_loss / total_examples, 4)
            accuracy = round(accuracy_fn.get_metric()["accuracy"], 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "{}_loss".format(metric_prefix): mean_loss,
                "{}_accuracy".format(metric_prefix): accuracy,
            })

    return mean_loss, accuracy


def main():
    """Train the wave classifier and write checkpoints and metrics.

    Per epoch: train, evaluate, save the model to ``epoch-<n>`` inside the
    serialization directory together with ``metrics_epoch.json``, mirror the
    best epoch (by evaluation accuracy) to ``best``, prune old checkpoints
    and stop early once ``--patience`` epochs pass without a new best.

    Raises:
        AssertionError: if the configured number of labels differs from the
            vocabulary size.
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)
    num_labels = vocabulary.get_vocab_size(namespace="labels")

    # datasets
    logger.info("prepare datasets")
    dataset_kwargs = dict(
        vocab=vocabulary,
        category=None,
        category_field="category",
        label_field="labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_dataset = WaveClassifierExcelDataset(excel_file=args.train_dataset, **dataset_kwargs)
    valid_dataset = WaveClassifierExcelDataset(excel_file=args.valid_dataset, **dataset_kwargs)

    train_data_loader = _build_data_loader(train_dataset, args.batch_size)
    valid_data_loader = _build_data_loader(valid_dataset, args.batch_size)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    config = CnnAudioClassifierConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )
    if config.cls_head_param["num_labels"] != num_labels:
        raise AssertionError("expected num labels: {} instead of {}.".format(
            num_labels,
            config.cls_head_param["num_labels"],
        ))
    model = WaveClassifierPretrainedModel(
        config=config,
    )

    if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
        logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
        _load_pretrained_weights(model, Path(args.pretrained_model))

    model.to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.learning_rate,
    )
    # The scheduler is stepped once per batch, so milestones count batches.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )
    focal_loss = FocalLoss(
        num_classes=num_labels,
        reduction="mean",
    )
    categorical_accuracy = CategoricalAccuracy()

    # training loop
    logger.info("training")

    # Never keep a negative number of checkpoints.
    max_kept_models = max(args.num_serialized_models_to_keep, 0)
    model_list = list()
    best_idx_epoch = None
    best_accuracy = None
    patience_count = 0

    for idx_epoch in range(args.max_epochs):
        training_loss, training_accuracy = _run_epoch(
            model, train_data_loader, focal_loss, categorical_accuracy, device,
            desc="Training; epoch: {}".format(idx_epoch),
            metric_prefix="training",
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        evaluation_loss, evaluation_accuracy = _run_epoch(
            model, valid_data_loader, focal_loss, categorical_accuracy, device,
            desc="Evaluation; epoch: {}".format(idx_epoch),
            metric_prefix="evaluation",
        )

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())
        model_list.append(epoch_dir)

        # save metric
        if best_accuracy is None or evaluation_accuracy > best_accuracy:
            best_idx_epoch = idx_epoch
            best_accuracy = evaluation_accuracy

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "best_accuracy": best_accuracy,
            "training_loss": training_loss,
            "training_accuracy": training_accuracy,
            "evaluation_loss": evaluation_loss,
            "evaluation_accuracy": evaluation_accuracy,
            "learning_rate": optimizer.param_groups[0]['lr'],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        is_best = best_idx_epoch == idx_epoch
        if is_best:
            best_dir = serialization_dir / "best"
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # Prune only after the metrics file and the "best" copy are written,
        # so the current epoch directory is never removed while still in
        # use, and use a strict comparison so that exactly the requested
        # number of checkpoints remains.
        while len(model_list) > max_kept_models:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # early stop
        if is_best:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                break

    return


if __name__ == "__main__":
    main()