Spaces:
Runtime error
Runtime error
| import datetime | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import numpy as np | |
| from torch.utils.data import random_split, SubsetRandomSampler | |
| import json | |
| from sklearn.model_selection import KFold | |
| from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score | |
| from preprocessing.dataset import SongDataset | |
| from preprocessing.preprocess import get_examples | |
| from dancer_net.dancer_net import ShortChunkCNN | |
# Torch device used by train_model(); "mps" is PyTorch's Apple Metal backend.
DEVICE = "mps"
# Seed for the generator that drives the train/validation split in train_model().
SEED = 42
def get_timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    now = datetime.datetime.now()
    # datetime's format protocol delegates to strftime with the given spec.
    return f"{now:%Y-%m-%d_%H:%M:%S}"
class EarlyStopping:
    """Signal a stop once a monitored value has failed to decrease too often.

    The tracker is written for lower-is-better measures (e.g. a loss): every
    call to :meth:`step` compares the new value with the one from the
    previous call, and a value that is greater than or equal to its
    predecessor counts as a non-improvement.
    """

    def __init__(self, patience=0):
        # Number of consecutive non-improvements that are still tolerated.
        self.patience = patience
        # Value seen on the previous step; +inf so the first step never counts.
        self.last_measure = np.inf
        # Length of the current run of non-improving steps.
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record ``val`` and return True when training should stop.

        Returns True once the run of consecutive non-decreasing values is
        strictly longer than ``patience``.
        """
        did_not_decrease = self.last_measure <= val
        self.consecutive_increase = (
            self.consecutive_increase + 1 if did_not_decrease else 0
        )
        # Always compare against the most recent value, not the best one.
        self.last_measure = val
        return self.consecutive_increase > self.patience
def calculate_metrics(pred, target, threshold=0.5, prefix=""):
    """Compute macro precision/recall/F1 and accuracy for thresholded predictions.

    Args:
        pred: Tensor of model scores; values strictly greater than
            ``threshold`` are treated as positive predictions.
        target: Tensor of ground-truth labels.
            (presumably multi-hot with the same shape as ``pred`` - confirm)
        threshold: Decision threshold applied to ``pred``.
        prefix: String prepended to every metric name; no renaming happens
            when it is the empty string.

    Returns:
        Dict with the keys ``precision``, ``recall``, ``f1`` and ``accuracy``
        (each prefixed with ``prefix``).
    """
    y_true = target.detach().cpu().numpy()
    scores = pred.detach().cpu().numpy()
    y_pred = (scores > threshold).astype(float)

    # Shared keyword arguments for the macro-averaged scorers.
    macro_kwargs = {"average": "macro", "zero_division": 0}
    results = {
        'precision': precision_score(y_true=y_true, y_pred=y_pred, **macro_kwargs),
        'recall': recall_score(y_true=y_true, y_pred=y_pred, **macro_kwargs),
        'f1': f1_score(y_true=y_true, y_pred=y_pred, **macro_kwargs),
        'accuracy': accuracy_score(y_true=y_true, y_pred=y_pred),
    }
    if prefix == "":
        return results
    return {prefix + name: value for name, value in results.items()}
def evaluate(model: nn.Module, data_loader: DataLoader, criterion, device="mps") -> pd.Series:
    """Run ``model`` over ``data_loader`` without gradients and average the metrics.

    The model is switched to evaluation mode for the duration of the pass so
    that mode-dependent layers (dropout, batch normalisation, ... if the model
    has any) do not behave as in training, and its previous train/eval mode is
    restored before returning so a surrounding training loop is unaffected.

    Args:
        model: Network to evaluate; it must already live on ``device``.
        data_loader: Yields ``(features, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Series holding the unweighted mean over batches of ``val_precision``,
        ``val_recall``, ``val_f1``, ``val_accuracy`` and ``val_loss``.
    """
    was_training = model.training
    model.eval()
    val_metrics = []
    try:
        with torch.no_grad():
            for features, labels in (prog_bar := tqdm(data_loader)):
                features = features.to(device)
                labels = labels.to(device)
                outputs = model(features)
                loss = criterion(outputs, labels)
                batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
                batch_metrics["val_loss"] = loss.item()
                prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
                val_metrics.append(batch_metrics)
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    return pd.DataFrame(val_metrics).mean()
def train(
        model: nn.Module,
        data_loader: DataLoader,
        val_loader=None,
        epochs=3,
        lr=1e-3,
        device="mps"):
    """Train ``model`` with Adam, a one-cycle LR schedule and binary cross-entropy.

    Args:
        model: Network to optimise; it must already live on ``device``.
            ``nn.BCELoss`` expects its outputs to be probabilities in [0, 1].
        data_loader: Yields ``(features, labels)`` training batches.
        val_loader: Optional loader of validation batches. When given, the
            model is evaluated after every epoch and training stops early
            once validation F1 has failed to improve on the previous epoch
            for more than one consecutive epoch.
        epochs: Maximum number of passes over ``data_loader``.
        lr: Peak learning rate of the one-cycle schedule.
        device: Device the batches are moved to.

    Returns:
        ``(model, metrics)`` where ``metrics`` is a list with one dict per
        completed epoch (including the epoch that triggered early stopping)
        holding the batch-averaged training metrics and, when ``val_loader``
        is given, the ``val_``-prefixed validation metrics.
    """
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stop = EarlyStopping(1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        steps_per_epoch=len(data_loader),
        epochs=epochs,
        anneal_strategy='linear',
    )
    metrics = []
    for epoch in range(1, epochs + 1):
        # Make sure mode-dependent layers are in training behaviour, whatever
        # state a previous evaluation pass left the model in.
        model.train()
        train_metrics = []
        prog_bar = tqdm(data_loader)
        for features, labels in prog_bar:
            features = features.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # OneCycleLR is stepped once per batch, not once per epoch.
            scheduler.step()
            batch_metrics = calculate_metrics(outputs, labels)
            batch_metrics["loss"] = loss.item()
            train_metrics.append(batch_metrics)
            prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
        train_metrics = pd.DataFrame(train_metrics).mean()

        if val_loader is None:
            metrics.append(dict(train_metrics))
            continue

        # Evaluate on the same device the model is being trained on.
        val_metrics = evaluate(model, val_loader, criterion, device=device)
        epoch_metrics = pd.concat([train_metrics, val_metrics], axis=0)
        # Record the epoch before deciding to stop so its numbers are not lost.
        metrics.append(dict(epoch_metrics))
        # EarlyStopping counts consecutive non-decreasing values, i.e. it
        # treats lower as better; F1 is higher-is-better, so feed it negated.
        if early_stop.step(-val_metrics["val_f1"]):
            break
    return model, metrics
def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
    """Run k-fold cross-validation and write the per-fold metrics to disk.

    For every fold a fresh ``ShortChunkCNN`` is trained for two epochs on the
    training indices and evaluated on the held-out indices. The resulting
    table (one row per fold) is saved to ``logs/<timestamp>/cross_val.csv``.

    Args:
        seed: ``random_state`` passed to ``KFold`` for the shuffled split.
        batch_size: Batch size of both the training and the held-out loader.
        k: Number of folds.
        device: Device used for the model, training and evaluation.
    """
    target_classes = [
        'ATN',
        'BBA',
        'BCH',
        'BLU',
        'CHA',
        'CMB',
        'CSG',
        'ECS',
        'HST',
        'JIV',
        'LHP',
        'QST',
        'RMB',
        'SFT',
        'SLS',
        'SMB',
        'SWZ',
        'TGO',
        'VWZ',
        'WCS',
    ]
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)
    dataset = SongDataset(x, y)
    # Width of one label vector; identical for every fold.
    n_classes = len(y[0])
    splits = KFold(n_splits=k, shuffle=True, random_state=seed)
    metrics = []
    for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
        print(f"Fold {fold+1}")
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
        model = ShortChunkCNN(n_class=n_classes).to(device)
        model, _ = train(model, train_loader, epochs=2, device=device)
        # Evaluate on the requested device instead of evaluate()'s default.
        val_metrics = evaluate(model, test_loader, nn.BCELoss(), device=device)
        metrics.append(val_metrics)
    metrics = pd.DataFrame(metrics)
    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)
    # The first positional argument of to_csv is the destination path.
    metrics.to_csv(os.path.join(log_dir, "cross_val.csv"))
def train_model():
    """Train a ``ShortChunkCNN`` on a 90/10 split and save the run artefacts.

    Reads ``data/songs.csv``, builds the dataset, trains for up to three
    epochs with validation, and writes into ``logs/<timestamp>/``:

    * ``dancer_net.pt``  - the model's ``state_dict``
    * ``metrics.csv``    - per-epoch training/validation metrics
    * ``config.json``    - the list of target classes
    """
    target_classes = [
        'ATN',
        'BBA',
        'BCH',
        'BLU',
        'CHA',
        'CMB',
        'CSG',
        'ECS',
        'HST',
        'JIV',
        'LHP',
        'QST',
        'RMB',
        'SFT',
        'SLS',
        'SMB',
        'SWZ',
        'TGO',
        'VWZ',
        'WCS',
    ]
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)
    dataset = SongDataset(x, y)
    train_count = int(len(dataset) * 0.9)
    # Seeded generator so the train/validation split is reproducible.
    datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
    data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
    train_data, val_data = data_loaders
    # Only the label of the first example is needed, to size the output layer.
    _, example_label = dataset[0]
    n_classes = len(example_label)
    model = ShortChunkCNN(n_class=n_classes).to(DEVICE)
    model, metrics = train(model, train_data, val_data, epochs=3, device=DEVICE)

    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(log_dir, "dancer_net.pt"))
    metrics = pd.DataFrame(metrics)
    metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
    config = {
        "classes": target_classes
    }
    # The file does not exist yet, so it must be opened for writing.
    with open(os.path.join(log_dir, "config.json"), "w") as f:
        json.dump(config, f)
    print("Training information saved!")
# Running this file as a script performs cross-validation with the default arguments.
if __name__ == "__main__":
    cross_validation()