#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Evaluate a pretrained CNN audio classifier on a labelled dataset.

Reads an Excel sheet with ``filename`` and ``labels`` columns, classifies each
wav file with the pretrained model, and writes the original rows plus the
``predict``, ``prob`` and ``correct`` columns to another Excel file.
"""
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
from scipy.io import wavfile
import torch
from tqdm import tqdm

from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel


def get_args():
    """Parse command-line options (dataset, vocabulary, model and output paths)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="dataset.xlsx", type=str)
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--model_dir", default="best", type=str)
    parser.add_argument("--output_file", default="evaluation.xlsx", type=str)
    args = parser.parse_args()
    return args


def logging_config():
    """Configure root logging and return this module's logger.

    ``logging.basicConfig`` installs a stream handler on the root logger at
    DEBUG level using the format below.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.DEBUG)
    return logging.getLogger(__name__)


def _normalize_waveform(waveform):
    """Scale samples returned by ``wavfile.read`` to floats of magnitude <= 1.

    Signed integer PCM is divided by its full-scale value (``1 << 15`` for
    int16); unsigned PCM (e.g. 8-bit wav) is re-centred around zero first;
    floating-point samples are assumed to be normalized already.

    Raises:
        ValueError: if the sample dtype is not integer or floating point.
    """
    kind = waveform.dtype.kind
    full_scale = float(1 << (8 * waveform.dtype.itemsize - 1))
    if kind == "i":
        return waveform / full_scale
    if kind == "u":
        return (waveform - full_scale) / full_scale
    if kind == "f":
        return waveform
    raise ValueError("unsupported wav sample dtype: {}".format(waveform.dtype))


def _classify(model, vocabulary, filename, device):
    """Classify one wav file and return ``(predicted_label, probability)``."""
    # NOTE(review): the sample rate is ignored; assumes every file already uses
    # the rate the model expects — TODO confirm.
    _, waveform = wavfile.read(filename)
    waveform = torch.tensor(_normalize_waveform(waveform), dtype=torch.float32)
    # Add a batch dimension of size 1. Assumes mono audio (1-D samples);
    # multi-channel files would become 3-D here — TODO confirm.
    waveform = torch.unsqueeze(waveform, dim=0).to(device)

    with torch.no_grad():
        # Call the module (not .forward) so registered hooks still run.
        logits = model(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)

    label_idx = int(torch.argmax(probs, dim=-1)[0].item())
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    # .item() yields a plain Python float, which the Excel writer accepts.
    prob = probs[0, label_idx].item()
    return label_str, prob


def main():
    """Evaluate the model on ``--dataset`` and write results to ``--output_file``."""
    args = get_args()
    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)
    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    model.to(device)
    model.eval()

    logger.info("read excel")
    df = pd.read_excel(args.dataset)

    result = list()
    total_correct = 0
    total_examples = 0
    # The context manager closes the progress bar even if a file fails to load.
    with tqdm(total=len(df), desc="Evaluation") as progress_bar:
        for _, row in df.iterrows():
            ground_true = row["labels"]
            label_str, prob = _classify(model, vocabulary, row["filename"], device)
            # NOTE(review): compares the vocabulary token with the raw cell value;
            # if the labels column is read as numbers this never matches — verify.
            correct = 1 if label_str == ground_true else 0

            row_ = dict(row)
            row_["predict"] = label_str
            row_["prob"] = prob
            row_["correct"] = correct
            result.append(row_)

            total_examples += 1
            total_correct += correct
            progress_bar.update(1)
            progress_bar.set_postfix({
                "accuracy": total_correct / total_examples,
            })

    # Guard against an empty dataset so the final report cannot divide by zero.
    accuracy = total_correct / total_examples if total_examples else 0.0
    logger.info("accuracy: {} ({}/{})".format(accuracy, total_correct, total_examples))

    result = pd.DataFrame(result)
    result.to_excel(
        args.output_file,
        index=False
    )
    return


if __name__ == '__main__':
    main()