Spaces:
Running
Running
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| from collections import defaultdict | |
| import json | |
| import logging | |
| from logging.handlers import TimedRotatingFileHandler | |
| import os | |
| import platform | |
| from pathlib import Path | |
| import sys | |
| import shutil | |
| from typing import List | |
# Append the directory two levels above this script to sys.path so the
# project-local ``toolbox`` imports below resolve (presumably the repo root).
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
| import pandas as pd | |
| from scipy.io import wavfile | |
| import torch | |
| from tqdm import tqdm | |
| from toolbox.torch.utils.data.vocabulary import Vocabulary | |
| from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel | |
def get_args():
    """Parse the evaluation script's command-line options.

    Returns:
        argparse.Namespace with string attributes ``dataset``,
        ``vocabulary_dir``, ``model_dir`` and ``output_file``.
    """
    parser = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a plain string.
    option_defaults = (
        ("--dataset", "dataset.xlsx"),
        ("--vocabulary_dir", "vocabulary"),
        ("--model_dir", "best"),
        ("--output_file", "evaluation.xlsx"),
    )
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()
def logging_config():
    """Configure root logging and return this module's logger.

    Calls ``logging.basicConfig`` with a timestamped format at DEBUG level.
    Per the stdlib docs, that call does nothing if the root logger already
    has handlers.

    Returns:
        logging.Logger: the logger named after this module.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    # Only the basicConfig root handler is installed. A second StreamHandler
    # on this logger would duplicate every record via propagation to root.
    return logging.getLogger(__name__)
def _normalize_waveform(waveform):
    """Convert samples from ``scipy.io.wavfile.read`` to floating point.

    ``wavfile.read`` returns the file's native dtype:
    - signed integer PCM (int16; int32 for 24/32-bit data);
    - unsigned 8-bit PCM centred on 128;
    - float PCM, conventionally already in [-1.0, 1.0].

    Integer data is scaled by its full-scale value. For int16 this is exactly
    the division by ``1 << 15`` used before, and a fixed ``1 << 15`` is wrong
    for every other dtype. Float data is returned unchanged.

    Args:
        waveform: NumPy array as returned by ``wavfile.read``.

    Returns:
        A NumPy array of the same shape (float dtype for integer input).

    Raises:
        ValueError: if the sample dtype is neither integer nor float.
    """
    kind = waveform.dtype.kind
    if kind == "f":
        return waveform
    full_scale = 1 << (waveform.dtype.itemsize * 8 - 1)
    if kind == "i":
        return waveform / full_scale
    if kind == "u":
        # Cast first so subtracting the offset cannot wrap around in uint8.
        return (waveform.astype("float64") - full_scale) / full_scale
    raise ValueError("unsupported wav sample dtype: {}".format(waveform.dtype))


def main():
    """Evaluate a pretrained wave classifier on an Excel dataset.

    Reads ``args.dataset``, which must contain ``filename`` and ``labels``
    columns, and classifies each WAV file. Each input row is written to
    ``args.output_file`` with extra ``predict``, ``prob`` and ``correct``
    columns. Running accuracy is shown on the progress bar.
    """
    args = get_args()
    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)
    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    model.to(device)
    model.eval()

    logger.info("read excel")
    df = pd.read_excel(args.dataset)

    result = list()
    total_correct = 0
    total_examples = 0
    # ``with`` closes the progress bar even if a file fails to load/classify.
    with tqdm(total=len(df), desc="Evaluation") as progress_bar, torch.no_grad():
        for _, row in df.iterrows():
            filename = row["filename"]
            ground_true = row["labels"]

            # NOTE(review): the sample rate is ignored; confirm every file
            # uses the rate the model was trained on.
            _, waveform = wavfile.read(filename)
            waveform = _normalize_waveform(waveform)
            waveform = torch.tensor(waveform, dtype=torch.float32)
            # Add a leading batch dimension of 1. NOTE(review): assumes mono
            # files; wavfile.read returns (samples, channels) for multi-channel.
            waveform = torch.unsqueeze(waveform, dim=0).to(device)

            logits = model.forward(waveform)
            probs = torch.nn.functional.softmax(logits, dim=-1).cpu()
            label_idx = int(torch.argmax(probs, dim=-1)[0])
            label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
            # Store a plain Python float, not a 0-d array, so pandas and the
            # Excel writer receive an ordinary scalar.
            prob = float(probs[0][label_idx])

            correct = 1 if label_str == ground_true else 0
            row_ = dict(row)
            row_["predict"] = label_str
            row_["prob"] = prob
            row_["correct"] = correct
            result.append(row_)

            total_examples += 1
            total_correct += correct
            accuracy = total_correct / total_examples
            progress_bar.update(1)
            progress_bar.set_postfix({
                "accuracy": accuracy,
            })

    result = pd.DataFrame(result)
    result.to_excel(
        args.output_file,
        index=False
    )
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()