#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Multi-task training of a wave classifier: a shared WaveEncoder feeds one ClsHead
per vocabulary namespace, and the per-namespace sampling ratios are re-balanced
from the evaluation accuracies at every logging step.
"""
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_steps", default=100000, type=int)
    parser.add_argument("--save_steps", default=30, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="union", type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=0, type=int)
    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)
    return logger


class CollateFunction(object):
    """Stacks the per-sample waveforms and labels of a batch into two tensors."""
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        array_list = list()
        label_list = list()
        for sample in batch:
            array = sample['waveform']
            label = sample['label']

            array_list.append(array)
            label_list.append(label)

        array_list = torch.stack(array_list)
        label_list = torch.stack(label_list)
        return array_list, label_list


collate_fn = CollateFunction()


class DatasetIterator(object):
    """Wraps a DataLoader as an endless iterator: once exhausted, it restarts."""
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.data_loader_iter = iter(self.data_loader)

    def next(self):
        try:
            result = next(self.data_loader_iter)
        except StopIteration:
            self.data_loader_iter = iter(self.data_loader)
            result = next(self.data_loader_iter)
        return result


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(args.serialization_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)
    namespaces = vocabulary._token_to_index.keys()

    # namespace_to_ratio
    max_ratio = (len(namespaces) - 1) * 3
    namespace_to_ratio = {n: 1 for n in namespaces}
    namespace_to_ratio["global_labels"] = max_ratio

    # datasets
    logger.info("prepare datasets")
    namespace_to_datasets = dict()
    for namespace in namespaces:
        logger.info("prepare datasets - {}".format(namespace))
        if namespace == "global_labels":
            train_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.train_dataset,
                category=None,
                category_field="category",
                label_field="global_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
            valid_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.valid_dataset,
                category=None,
                category_field="category",
                label_field="global_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
        else:
            train_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.train_dataset,
                category=namespace,
                category_field="category",
                label_field="country_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )
            valid_dataset = WaveClassifierExcelDataset(
                vocab=vocabulary,
                excel_file=args.valid_dataset,
                category=namespace,
                category_field="category",
                label_field="country_labels",
                expected_sample_rate=8000,
                max_wave_value=32768.0,
            )

        train_data_loader = DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
            # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            # prefetch_factor=64,
        )
        valid_data_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
            # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            # prefetch_factor=64,
        )
        namespace_to_datasets[namespace] = {
            "train_data_loader": train_data_loader,
            "valid_data_loader": valid_data_loader,
        }

    # datasets iterator
    logger.info("prepare datasets iterator")
    namespace_to_datasets_iter = dict()
    for namespace in namespaces:
        logger.info("prepare datasets iterator - {}".format(namespace))
        train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
        valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
        namespace_to_datasets_iter[namespace] = {
            "train_data_loader_iter": DatasetIterator(train_data_loader),
            "valid_data_loader_iter": DatasetIterator(valid_data_loader),
        }

    # models - encoder
    logger.info("prepare models - encoder")
    wave_encoder = WaveEncoder(
        conv2d_block_param_list=[
            {
                "batch_norm": True,
                "in_channels": 1,
                "out_channels": 4,
                "kernel_size": 3,
                "stride": 1,
                # "padding": "same",
                "dilation": 3,
                "activation": "relu",
                "dropout": 0.1,
            },
            {
                # "batch_norm": True,
                "in_channels": 4,
                "out_channels": 4,
                "kernel_size": 5,
                "stride": 2,
                # "padding": "same",
                "dilation": 3,
                "activation": "relu",
                "dropout": 0.1,
            },
            {
                # "batch_norm": True,
                "in_channels": 4,
                "out_channels": 4,
                "kernel_size": 3,
                "stride": 1,
                # "padding": "same",
                "dilation": 2,
                "activation": "relu",
                "dropout": 0.1,
            },
        ],
        mel_spectrogram_param={
            'sample_rate': 8000,
            'n_fft': 512,
            'win_length': 200,
            'hop_length': 80,
            'f_min': 10,
            'f_max': 3800,
            'window_fn': 'hamming',
            'n_mels': 80,
        }
    )

    # models - cls_head
    logger.info("prepare models - cls_head")
    namespace_to_cls_heads = dict()
    for namespace in namespaces:
        logger.info("prepare models - cls_head - {}".format(namespace))
        cls_head = ClsHead(
            input_dim=352,
            num_layers=2,
            hidden_dims=[128, 32],
            activations="relu",
            dropout=0.1,
            num_labels=vocabulary.get_vocab_size(namespace=namespace)
        )
        namespace_to_cls_heads[namespace] = cls_head

    # models - classifier
    logger.info("prepare models - classifier")
    namespace_to_classifier = dict()
    for namespace in namespaces:
        logger.info("prepare models - classifier - {}".format(namespace))
        cls_head = namespace_to_cls_heads[namespace]
        wave_classifier = WaveClassifier(
            wave_encoder=wave_encoder,
            cls_head=cls_head,
        )
        wave_classifier.to(device)
        namespace_to_classifier[namespace] = wave_classifier

    # optimizer
    logger.info("prepare optimizer")
    param_optimizer = list()
    param_optimizer.extend(wave_encoder.parameters())
    for _, cls_head in namespace_to_cls_heads.items():
        param_optimizer.extend(cls_head.parameters())

    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=10000
    )
    focal_loss = FocalLoss(
        num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
        reduction="mean",
    )

    # categorical_accuracy
    logger.info("prepare categorical_accuracy")
    namespace_to_categorical_accuracy = dict()
    for namespace in namespaces:
        categorical_accuracy = CategoricalAccuracy()
        namespace_to_categorical_accuracy[namespace] = categorical_accuracy

    # training loop
    logger.info("prepare training loop")

    model_list = list()
    best_idx_step = None
    best_accuracy = None
    patience_count = 0

    namespace_to_total_loss = defaultdict(float)
    namespace_to_total_examples = defaultdict(int)
    for idx_step in tqdm(range(args.max_steps)):
        # training one step
        loss: torch.Tensor = None
        for namespace in namespaces:
            train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
            ratio = namespace_to_ratio[namespace]
            model = namespace_to_classifier[namespace]
            categorical_accuracy = namespace_to_categorical_accuracy[namespace]

            model.train()
            for _ in range(ratio):
                batch = train_data_loader_iter.next()
                input_ids, label_ids = batch

                input_ids = input_ids.to(device)
                label_ids: torch.LongTensor = label_ids.to(device).long()

                logits = model.forward(input_ids)
                task_loss = focal_loss.forward(logits, label_ids.view(-1))
                categorical_accuracy(logits, label_ids)

                if loss is None:
                    loss = task_loss
                else:
                    loss += task_loss

                namespace_to_total_loss[namespace] += task_loss.item()
                namespace_to_total_examples[namespace] += input_ids.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # logging
        if (idx_step + 1) % args.save_steps == 0:
            metrics = dict()

            # training
            for namespace in namespaces:
                total_loss = namespace_to_total_loss[namespace]
                total_examples = namespace_to_total_examples[namespace]

                training_loss = total_loss / total_examples
                training_loss = round(training_loss, 4)

                categorical_accuracy = namespace_to_categorical_accuracy[namespace]
                training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
                training_accuracy = round(training_accuracy, 4)

                logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
                    idx_step, namespace, training_loss, training_accuracy
                ))
                metrics[namespace] = {
                    "training_loss": training_loss,
                    "training_accuracy": training_accuracy,
                }
            namespace_to_total_loss = defaultdict(float)
            namespace_to_total_examples = defaultdict(int)

            # evaluation
            for namespace in namespaces:
                valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
                model = namespace_to_classifier[namespace]
                categorical_accuracy = namespace_to_categorical_accuracy[namespace]

                model.eval()
                total_loss = 0
                total_examples = 0
                for step, batch in enumerate(valid_data_loader):
                    input_ids, label_ids = batch
                    input_ids = input_ids.to(device)
                    label_ids: torch.LongTensor = label_ids.to(device).long()

                    with torch.no_grad():
                        logits = model.forward(input_ids)
                        loss = focal_loss.forward(logits, label_ids.view(-1))
                        categorical_accuracy(logits, label_ids)

                    total_loss += loss.item()
                    total_examples += input_ids.size(0)

                evaluation_loss = total_loss / total_examples
                evaluation_loss = round(evaluation_loss, 4)

                evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
                evaluation_accuracy = round(evaluation_accuracy, 4)

                logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
                    idx_step, namespace, evaluation_loss, evaluation_accuracy
                ))
                metrics[namespace] = {
                    "evaluation_loss": evaluation_loss,
                    "evaluation_accuracy": evaluation_accuracy,
                }

            # update ratio
            min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
            max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
            width = max_accuracy - min_accuracy

            for namespace, metric in metrics.items():
                evaluation_accuracy = metric["evaluation_accuracy"]
                if width == 0:
                    # All tasks scored the same; keep the sampling uniform instead of dividing by zero.
                    ratio = 1
                else:
                    ratio = (max_accuracy - evaluation_accuracy) / width * max_ratio
                    # Sample every task at least once per step; otherwise the best-scoring
                    # task gets ratio 0 and its training-loss average would divide by zero
                    # at the next logging window.
                    ratio = max(int(ratio), 1)
                namespace_to_ratio[namespace] = ratio

            msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
            logger.info("namespace to ratio: {}".format(msg))

            # save path
            step_dir = serialization_dir / "step-{}".format(idx_step)
            step_dir.mkdir(parents=True, exist_ok=False)

            # save models
            wave_encoder_filename = step_dir / "wave_encoder.pt"
            torch.save(wave_encoder.state_dict(), wave_encoder_filename)
            for namespace in namespaces:
                cls_head_filename = step_dir / "{}.pt".format(namespace)
                cls_head = namespace_to_cls_heads[namespace]
                torch.save(cls_head.state_dict(), cls_head_filename)

            model_list.append(step_dir)
            if len(model_list) >= args.num_serialized_models_to_keep:
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())

            # save metric
            this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
            if best_accuracy is None:
                best_idx_step = idx_step
                best_accuracy = this_accuracy
            elif this_accuracy > best_accuracy:
                best_idx_step = idx_step
                best_accuracy = this_accuracy
            else:
                pass

            metrics_filename = step_dir / "metrics_epoch.json"
            metrics.update({
                "idx_step": idx_step,
                "best_idx_step": best_idx_step,
            })
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)

            # save best
            best_dir = serialization_dir / "best"
            if best_idx_step == idx_step:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(step_dir, best_dir)

            # early stop
            early_stop_flag = False
            if best_idx_step == idx_step:
                patience_count = 0
            else:
                patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

            # early stop
            if early_stop_flag:
                break
    return


if __name__ == "__main__":
    main()
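

# Example invocation, as a minimal sketch: the script filename below is an
# assumption (not part of the repository), and the vocabulary/spreadsheet paths
# must match your own data layout. All flags correspond to get_args() above.
#
#   python3 step_2_train_model.py \
#       --vocabulary_dir vocabulary \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --serialization_dir union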