#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The previous version of this code reached an accuracy of 0.8423.
This version reaches an accuracy of 0.8379.
This script works as intended.
"""
import argparse
import copy
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from pathlib import Path
import platform
import sys
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="global_classifier", type=str)
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG,
    )
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)
    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        array_list = list()
        label_list = list()
        for sample in batch:
            array = sample["waveform"]
            label = sample["label"]

            array_list.append(array)
            label_list.append(label)

        array_list = torch.stack(array_list)
        label_list = torch.stack(label_list)
        return array_list, label_list


collate_fn = CollateFunction()


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(args.serialization_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    # datasets
    train_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.train_dataset,
        category=None,
        category_field="category",
        label_field="global_labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = WaveClassifierExcelDataset(
        vocab=vocabulary,
        excel_file=args.valid_dataset,
        category=None,
        category_field="category",
        label_field="global_labels",
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models - classifier
    wave_encoder = WaveEncoder(
        conv1d_block_param_list=[
            {
                'batch_norm': True,
                'in_channels': 80,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
            {
                # 'batch_norm': True,
                'in_channels': 16,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
            {
                # 'batch_norm': True,
                'in_channels': 16,
                'out_channels': 16,
                'kernel_size': 3,
                'stride': 3,
                # 'padding': 'same',
                'activation': 'relu',
                'dropout': 0.1,
            },
        ],
        mel_spectrogram_param={
            "sample_rate": 8000,
            "n_fft": 512,
            "win_length": 200,
            "hop_length": 80,
            "f_min": 10,
            "f_max": 3800,
            "window_fn": "hamming",
            "n_mels": 80,
        }
    )
    cls_head = ClsHead(
        input_dim=16,
        num_layers=2,
        hidden_dims=[32, 16],
        activations="relu",
        dropout=0.1,
        num_labels=vocabulary.get_vocab_size(namespace="global_labels")
    )
    model = WaveClassifier(
        wave_encoder=wave_encoder,
        cls_head=cls_head,
    )
    model.to(device)

    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.learning_rate
    )
    # decay the learning rate once every 30000 optimizer steps
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=30000
    )
    focal_loss = FocalLoss(
        num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
        reduction="mean",
    )
    categorical_accuracy = CategoricalAccuracy()

    # training
    best_idx_epoch: int = None
    best_accuracy: float = None
    patience_count = 0

    global_step = 0
    model_filename_list = list()
    for idx_epoch in range(args.max_epochs):
        # training
        model.train()
        total_loss = 0
        total_examples = 0
        for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            logits = model.forward(input_ids)
            loss = focal_loss.forward(logits, label_ids.view(-1))
            categorical_accuracy(logits, label_ids)

            total_loss += loss.item()
            total_examples += input_ids.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            global_step += 1

        training_loss = total_loss / total_examples
        training_loss = round(training_loss, 4)
        training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
        training_accuracy = round(training_accuracy, 4)
        logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
            idx_epoch, training_loss, training_accuracy
        ))

        # evaluation
        model.eval()
        total_loss = 0
        total_examples = 0
        for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
            input_ids, label_ids = batch
            input_ids = input_ids.to(device)
            label_ids: torch.LongTensor = label_ids.to(device).long()

            with torch.no_grad():
                logits = model.forward(input_ids)
                loss = focal_loss.forward(logits, label_ids.view(-1))
                categorical_accuracy(logits, label_ids)

            total_loss += loss.item()
            total_examples += input_ids.size(0)

        evaluation_loss = total_loss / total_examples
        evaluation_loss = round(evaluation_loss, 4)
        evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
        evaluation_accuracy = round(evaluation_accuracy, 4)
        logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
            idx_epoch, evaluation_loss, evaluation_accuracy
        ))
        # save metrics
        metrics = {
            "training_loss": training_loss,
            "training_accuracy": training_accuracy,
            "evaluation_loss": evaluation_loss,
            "evaluation_accuracy": evaluation_accuracy,
            "best_idx_epoch": best_idx_epoch,
            "best_accuracy": best_accuracy,
        }
        metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save model
        model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
        model_filename_list.append(model_filename)
        if len(model_filename_list) >= args.num_serialized_models_to_keep:
            model_filename_to_delete = model_filename_list.pop(0)
            os.remove(model_filename_to_delete)
        torch.save(model.state_dict(), model_filename)

        # early stop
        best_model_filename = os.path.join(args.serialization_dir, "best.bin")
        if best_accuracy is None:
            best_idx_epoch = idx_epoch
            best_accuracy = evaluation_accuracy
            torch.save(model.state_dict(), best_model_filename)
        elif evaluation_accuracy > best_accuracy:
            best_idx_epoch = idx_epoch
            best_accuracy = evaluation_accuracy
            torch.save(model.state_dict(), best_model_filename)
            patience_count = 0
        elif patience_count >= args.patience:
            break
        else:
            patience_count += 1

    return


if __name__ == "__main__":
    main()
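
# ---------------------------------------------------------------------------
# Inference sketch (not part of the original script, kept commented out so the
# file behaves exactly as before). It shows how the "best.bin" state dict saved
# by the loop above could be reloaded; the paths and the assumption that
# wave_encoder/cls_head are rebuilt with the same hyper-parameters as in main()
# are illustrative, not guarantees about the toolbox API beyond what this
# script itself uses.
#
#   vocabulary = Vocabulary.from_files("vocabulary")
#   model = WaveClassifier(wave_encoder=wave_encoder, cls_head=cls_head)
#   state_dict = torch.load("global_classifier/best.bin", map_location="cpu")
#   model.load_state_dict(state_dict)
#   model.eval()
#   with torch.no_grad():
#       logits = model.forward(waveform_batch)          # waveform_batch: hypothetical input tensor
#       predicted_label_ids = torch.argmax(logits, dim=-1)
# ---------------------------------------------------------------------------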