English
music
emotion
File size: 3,814 Bytes
2dfd92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import logging

import torch
torch.set_float32_matmul_precision("medium")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from data_loader import DataModule
from trainer import MusicClassifier
import yaml
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from hydra.core.hydra_config import HydraConfig
from pytorch_lightning.utilities.combined_loader import CombinedLoader


log = logging.getLogger(__name__)

def get_latest_version(log_dir):
    version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
    version_dirs.sort(key=lambda x: int(x.split('_')[-1]))  # Sort by version number
    return version_dirs[-1] if version_dirs else None

def save_metrics_and_checkpoint(metrics, checkpoint, output_file):
    data = {
        'checkpoint': checkpoint,
        'metrics': metrics
    }
    with open(output_file, 'w') as f:
        yaml.dump(data, f)

def read_best_checkpoint_info(file_path, dataset_type=None):
    """Read the best checkpoint file."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Checkpoint info file not found: {file_path}")
    
    with open(file_path, 'r') as f:
        lines = f.readlines()
    
    if dataset_type == "mood":
        checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (mood):")), None)
    elif dataset_type == "va":
        checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (va):")), None)
    else:
        checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint:")), None)

    if not checkpoint_line:
        raise ValueError(f"No checkpoint found for dataset type '{dataset_type}' in the file.")
    
    return checkpoint_line.split(": ")[-1].strip()

@hydra.main(version_base=None, config_path="config", config_name="test_config")
def main(config: DictConfig):
    log.info("Testing starts")
    log_base_dir = 'tb_logs/train_audio_classification'
    # log_base_dir = to_absolute_path('tb_logs/train_audio_classification')

    latest_version = get_latest_version(log_base_dir)
    if not latest_version:
        raise FileNotFoundError("No version directories found in log base directory.")
    version_log_dir = os.path.join(log_base_dir, latest_version)
    output_file = os.path.join(version_log_dir, 'test_metrics.txt')

    if config.checkpoint_latest:
        if config.multitask:
            dataset_type = config.dataset_type  # Expecting 'mood' or 'va'
            best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
            ckpt = read_best_checkpoint_info(best_checkpoint_file, dataset_type)
        else:
            best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
            ckpt = read_best_checkpoint_info(best_checkpoint_file)
    else:
        ckpt = config.checkpoint
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt}")
    
    log.info(f"Using checkpoint: {ckpt}")
    data_module = DataModule( config )
    data_module.setup()

    testloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.test_dataloader())}
    combined_test_loader = CombinedLoader(testloaders, mode="max_size")

    model = MusicClassifier.load_from_checkpoint(ckpt, cfg=config, output_file=output_file)
    logger = TensorBoardLogger(save_dir=log_base_dir,  
                                name="",  
                                version=latest_version)
    trainer = pl.Trainer(**config.trainer,
                        logger=logger)
    
    
    trainer.test(model, combined_test_loader)

if __name__ == '__main__':
    main()