#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from pathlib import Path
import platform
import random
import shutil
import sys
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.modules.loss import FocalLoss
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--config_file", default="conv2d_classifier.yaml", type=str)
    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        array_list = list()
        label_list = list()
        for sample in batch:
            array = sample["waveform"]
            label = sample["label"]

            # pad or truncate every waveform to a fixed 16000 samples
            length = len(array)
            if length < 16000:
                delta = int(16000 - length)
                array = np.concatenate([array, np.zeros(shape=(delta,), dtype=np.float32)], axis=-1)
            elif length > 16000:
                array = array[:16000]

            # convert to tensors so the lists can be stacked into batch tensors
            array_list.append(torch.tensor(array, dtype=torch.float32))
            label_list.append(torch.tensor(label, dtype=torch.long))

        array_list = torch.stack(array_list)
        label_list = torch.stack(label_list)
        return array_list, label_list


collate_fn = CollateFunction()
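# The collate function above builds rectangular batches by padding or truncating
# every waveform to 16000 samples (2 seconds at the 8 kHz sample rate the
# datasets below expect). Illustrative usage with hypothetical values:
#
#   batch = [{"waveform": np.zeros(12000, dtype=np.float32), "label": 0}]
#   arrays, labels = collate_fn(batch)
#   # arrays.shape == (1, 16000); labels.shape == (1,)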
def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir.as_posix())

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    # datasets
    logger.info("prepare datasets")
    train_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.train_dataset,
        category=None,
        category_field="category",
        label_field="labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.valid_dataset,
        category=None,
        category_field="category",
        label_field="labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )

    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # multiple worker subprocesses can load data on Linux, but not on Windows
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # multiple worker subprocesses can load data on Linux, but not on Windows
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info("prepare models")
    config = CnnAudioClassifierConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
        # num_labels=vocabulary.get_vocab_size(namespace="labels")
    )
    if config.cls_head_param["num_labels"] != vocabulary.get_vocab_size(namespace="labels"):
        raise AssertionError("expected num labels: {} instead of {}.".format(
            vocabulary.get_vocab_size(namespace="labels"),
            config.cls_head_param["num_labels"],
        ))
    model = WaveClassifierPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    param_optimizer = model.parameters()
    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=2000,
    # )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000],
        gamma=0.5,
    )
    focal_loss = FocalLoss(
        num_classes=vocabulary.get_vocab_size(namespace="labels"),
        reduction="mean",
    )
    categorical_accuracy = CategoricalAccuracy()

    # training loop
    logger.info("training")

    training_loss = 10000000000.
    training_accuracy = 0.
    evaluation_loss = 10000000000.
    evaluation_accuracy = 0.

    model_list = list()
    best_idx_epoch = None
    best_accuracy = None
    patience_count = 0
    for idx_epoch in range(args.max_epochs):
        # training
        model.train()
        categorical_accuracy.reset()
        total_loss = 0.
        total_examples = 0
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            logits = model.forward(input_ids)
            loss = focal_loss.forward(logits, label_ids.view(-1))
            categorical_accuracy(logits, label_ids)

            # the loss is a batch mean, so weight it by the batch size to get
            # a true per-example average over the epoch
            total_loss += loss.item() * input_ids.size(0)
            total_examples += input_ids.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            training_loss = round(total_loss / total_examples, 4)
            training_accuracy = round(categorical_accuracy.get_metric()["accuracy"], 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "training_loss": training_loss,
                "training_accuracy": training_accuracy,
            })
        progress_bar.close()
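        # Note: the evaluation pass below mirrors the training pass, but it runs
        # under torch.no_grad() and switches the model to eval mode so that any
        # dropout or batch-norm layers in the classifier behave deterministically;
        # model.train() is restored at the top of the next epoch.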
        # evaluation
        model.eval()
        categorical_accuracy.reset()
        total_loss = 0.
        total_examples = 0
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        for batch in valid_data_loader:
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            with torch.no_grad():
                logits = model.forward(input_ids)
                loss = focal_loss.forward(logits, label_ids.view(-1))
                categorical_accuracy(logits, label_ids)

            total_loss += loss.item() * input_ids.size(0)
            total_examples += input_ids.size(0)

            evaluation_loss = round(total_loss / total_examples, 4)
            evaluation_accuracy = round(categorical_accuracy.get_metric()["accuracy"], 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "evaluation_loss": evaluation_loss,
                "evaluation_accuracy": evaluation_accuracy,
            })
        progress_bar.close()

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        # keep only the most recent checkpoints (">" rather than ">=", so that
        # exactly num_serialized_models_to_keep directories are retained)
        if len(model_list) > args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_accuracy is None or evaluation_accuracy > best_accuracy:
            best_idx_epoch = idx_epoch
            best_accuracy = evaluation_accuracy

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "best_accuracy": best_accuracy,
            "training_loss": training_loss,
            "training_accuracy": training_accuracy,
            "evaluation_loss": evaluation_loss,
            "evaluation_accuracy": evaluation_accuracy,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                logger.info("early stopping at epoch: {}".format(idx_epoch))
                break

    return


if __name__ == "__main__":
    main()
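# Example invocation (the script name and file paths are illustrative; the
# flags and their defaults come from get_args() above):
#
#   python3 train.py \
#       --vocabulary_dir vocabulary \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --config_file conv2d_classifier.yaml \
#       --serialization_dir serialization_dir \
#       --batch_size 64 \
#       --learning_rate 1e-3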