import os
import logging

import torch

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.combined_loader import CombinedLoader

import hydra
from omegaconf import DictConfig

from data_loader import DataModule
from trainer import MusicClassifier

# Trade some float32 matmul precision for Tensor Core throughput on recent GPUs.
torch.set_float32_matmul_precision("medium")

# from utilities.custom_early_stopping import MultiMetricEarlyStopping


def get_latest_version(log_dir):
    """Return the highest-numbered 'version_*' subdirectory of log_dir, or None."""
    version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
    version_dirs.sort(key=lambda x: int(x.split('_')[-1]))  # numeric, not lexicographic, order
    return version_dirs[-1] if version_dirs else None
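

# Example: with version_0 and version_1 present under log_dir,
# get_latest_version(log_dir) returns "version_1".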

log = logging.getLogger(__name__)


@hydra.main(version_base=None, config_path="config", config_name="train_config")
def main(config: DictConfig):
    log_base_dir = 'tb_logs/train_audio_classification'

    # Multi-task ("mt") classifiers get separate mood and valence/arousal
    # checkpoints below.
    is_mt = "mt" in config.model.classifier

    logger = TensorBoardLogger("tb_logs", name="train_audio_classification")
    logger.log_hyperparams(config)
    train_log_dir = logger.log_dir
    print(f"Logging to {train_log_dir}")
    log.info("Training starts")

    data_module = DataModule(config)
    data_module.setup()

    # Build one dataloader per dataset, keyed by dataset name, for train and val.
    train_loaders = dict(zip(config.datasets, data_module.train_dataloader()))
    val_loaders = dict(zip(config.datasets, data_module.val_dataloader()))

    # "max_size" mode iterates until the longest loader is exhausted and yields
    # None for any loader that finishes earlier.
    combined_train_loader = CombinedLoader(train_loaders, mode="max_size")
    combined_val_loader = CombinedLoader(val_loaders, mode="max_size")

    # Resolve this run's version directory: log_hyperparams above has already
    # created it, so the latest version on disk belongs to the current run.
    latest_version = get_latest_version(log_base_dir)
    if latest_version is None:
        latest_version = f"version_{logger.version}"

    val_epoch_file = os.path.join(log_base_dir, latest_version, 'val_epoch.txt')

    model = MusicClassifier(config, output_file=val_epoch_file)

    if is_mt:
        checkpoint_callback_mood = ModelCheckpoint(**config.checkpoint_mood)
        checkpoint_callback_va = ModelCheckpoint(**config.checkpoint_va)
        early_stop_callback = EarlyStopping(**config.earlystopping)

        # With knowledge distillation some parameters receive no gradient on a
        # given step, so DDP must be told to expect unused parameters.
        trainer = pl.Trainer(
            **config.trainer,
            strategy=DDPStrategy(find_unused_parameters=bool(config.model.kd)),
            callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
            logger=logger,
            num_sanity_val_steps=0,
        )

    else:
        checkpoint_callback = ModelCheckpoint(**config.checkpoint)
        early_stop_callback = EarlyStopping(**config.earlystopping)
        trainer = pl.Trainer(
            **config.trainer,
            callbacks=[checkpoint_callback, early_stop_callback],
            logger=logger,
            num_sanity_val_steps=0,
        )

    trainer.fit(model, combined_train_loader, combined_val_loader)

    # Only the main process records the best checkpoint paths.
    if trainer.global_rank == 0:
        best_checkpoint_file = os.path.join(train_log_dir, 'best_checkpoint.txt')
        with open(best_checkpoint_file, 'w') as f:
            if is_mt:
                f.write(f"Best checkpoint (mood): {checkpoint_callback_mood.best_model_path}\n")
                f.write(f"Best checkpoint (va): {checkpoint_callback_va.best_model_path}\n")
            else:
                f.write(f"Best checkpoint: {checkpoint_callback.best_model_path}\n")
            f.write(f"Version: {logger.version}\n")

if __name__ == '__main__':
    main()
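
# For reference, a minimal sketch of the expected config/train_config.yaml.
# Key names are inferred from the accesses above; values are illustrative only:
#
#   datasets: [jamendo, deam]            # hypothetical dataset names
#   model:
#     classifier: mt_cnn                 # "mt" in the name selects the multi-task path
#     kd: false                          # enable knowledge distillation
#   trainer:                             # kwargs forwarded to pl.Trainer
#     max_epochs: 100
#     accelerator: gpu
#     devices: 2
#   checkpoint: {monitor: val_loss, mode: min}             # single-task
#   checkpoint_mood: {monitor: val_mood_loss, mode: min}   # multi-task, mood head
#   checkpoint_va: {monitor: val_va_loss, mode: min}       # multi-task, VA head
#   earlystopping: {monitor: val_loss, patience: 10}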