import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

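# Expected dataset layout (an assumption inferred from ChordDataset below): ds/train,
# ds/valid and ds/test each hold a flat set of image files whose names start with the
# chord label followed by an underscore, e.g. "Am_0001.jpg" -> class "Am".
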
class ChordDataset(Dataset):
    """Chord images in a flat directory; the label is the file-name prefix before '_'."""

    def __init__(self, root_dir, transform=None, class_to_idx=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Reuse an existing mapping (e.g. the training set's) so that label indices
        # stay consistent across the train/valid/test splits.
        self.class_to_idx = class_to_idx if class_to_idx is not None else {}

        for img_name in sorted(os.listdir(root_dir)):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                chord = img_name.split('_')[0]
                if chord not in self.class_to_idx:
                    self.class_to_idx[chord] = len(self.class_to_idx)

                self.images.append(os.path.join(root_dir, img_name))
                self.labels.append(self.class_to_idx[chord])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

class ChordCNN(nn.Module):
    def __init__(self, num_classes):
        super(ChordCNN, self).__init__()

        # Five conv blocks, each halving the spatial size: 224 -> 112 -> 56 -> 28 -> 14 -> 7.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # 512 feature maps at 7x7 after the conv stack (for 224x224 inputs).
        self.fc_layers = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    accuracy = 100. * correct / total
    return epoch_loss, accuracy

def evaluate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100. * correct / total
    return epoch_loss, accuracy, all_predictions, all_labels

def train_and_evaluate():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Resize to 224x224 and normalize with ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Build the class mapping from the training split and reuse it for the other
    # splits so that label indices agree across train/valid/test.
    train_dataset = ChordDataset(root_dir='ds/train', transform=transform)
    valid_dataset = ChordDataset(root_dir='ds/valid', transform=transform,
                                 class_to_idx=train_dataset.class_to_idx)
    test_dataset = ChordDataset(root_dir='ds/test', transform=transform,
                                class_to_idx=train_dataset.class_to_idx)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    num_classes = len(train_dataset.class_to_idx)
    model = ChordCNN(num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    num_epochs = 30
    best_valid_loss = float('inf')
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        valid_loss, valid_acc, _, _ = evaluate(model, valid_loader, criterion, device)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.2f}%")

        # Reduce the learning rate when the validation loss plateaus.
        scheduler.step(valid_loss)

        # Keep the checkpoint with the lowest validation loss.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best_chord_cnn.pth')

    # Evaluate the best checkpoint on the held-out test set.
    model.load_state_dict(torch.load('best_chord_cnn.pth', map_location=device))
    test_loss, test_acc, test_predictions, test_labels = evaluate(model, test_loader, criterion, device)
    print("\nTest Set Performance:")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    class_names = sorted(train_dataset.class_to_idx, key=train_dataset.class_to_idx.get)
    print(classification_report(test_labels, test_predictions,
                                labels=list(range(num_classes)),
                                target_names=class_names))

    # Plot loss and accuracy curves.
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(valid_losses, label='Valid Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(valid_accuracies, label='Valid Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

    return model, train_dataset.class_to_idx

if __name__ == "__main__":
    model, class_mapping = train_and_evaluate()
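
# Inference sketch (an assumption, not part of the original script): load the returned
# model and class mapping and classify a single image. 'example.jpg' is a placeholder
# path; the preprocessing mirrors the transform used in train_and_evaluate().
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   idx_to_class = {v: k for k, v in class_mapping.items()}
#   preprocess = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#   ])
#   model.eval()
#   image = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0).to(device)
#   with torch.no_grad():
#       pred = model(image).argmax(dim=1).item()
#   print(idx_to_class[pred])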