yigitkucuk committed on
Commit
47f412e
verified · 1 Parent(s): 3ea6c3c

Upload cnn.py

Files changed (1)
  1. cnn.py +244 -0
cnn.py ADDED
@@ -0,0 +1,244 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

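# Expected data layout (inferred from the filename parsing below): each split
# directory holds flat image files named "<chord>_<anything>.jpg/.jpeg/.png",
# so a hypothetical "Am_001.jpg" would be labeled as chord "Am".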
class ChordDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_to_idx = {}

        # Get all image files and their corresponding labels.
        # The listing is sorted so the chord -> index mapping is deterministic
        # and stays consistent across the train/valid/test splits.
        for img_name in sorted(os.listdir(root_dir)):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                chord = img_name.split('_')[0]
                if chord not in self.class_to_idx:
                    self.class_to_idx[chord] = len(self.class_to_idx)

                self.images.append(os.path.join(root_dir, img_name))
                self.labels.append(self.class_to_idx[chord])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

class ChordCNN(nn.Module):
    def __init__(self, num_classes):
        super(ChordCNN, self).__init__()

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Fourth conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Fifth conv block
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

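        # Note on the flattened size below: with 224x224 inputs, each of the
        # five MaxPool2d(2) layers halves the spatial resolution
        # (224 -> 112 -> 56 -> 28 -> 14 -> 7), so conv_layers outputs
        # 512 x 7 x 7 features per image.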
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    accuracy = 100. * correct / total
    return epoch_loss, accuracy

def evaluate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100. * correct / total
    return epoch_loss, accuracy, all_predictions, all_labels

def train_and_evaluate():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = ChordDataset(root_dir='ds/train', transform=transform)
    valid_dataset = ChordDataset(root_dir='ds/valid', transform=transform)
    test_dataset = ChordDataset(root_dir='ds/test', transform=transform)

    # Create dataloaders
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Initialize model
    num_classes = len(train_dataset.class_to_idx)
    model = ChordCNN(num_classes).to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    # Training parameters
    num_epochs = 30
    best_valid_loss = float('inf')
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        # Validate
        valid_loss, valid_acc, _, _ = evaluate(model, valid_loader, criterion, device)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_acc)

        # Print epoch statistics
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.2f}%")

        # Learning rate scheduling
        scheduler.step(valid_loss)

        # Save best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best_chord_cnn.pth')

    # Load best model and evaluate on test set
    model.load_state_dict(torch.load('best_chord_cnn.pth'))
    test_loss, test_acc, test_predictions, test_labels = evaluate(model, test_loader, criterion, device)
    print("\nTest Set Performance:")
    print(classification_report(test_labels, test_predictions))

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(valid_losses, label='Valid Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(valid_accuracies, label='Valid Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

    return model, train_dataset.class_to_idx

if __name__ == "__main__":
    model, class_mapping = train_and_evaluate()
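The script saves only the network weights (best_chord_cnn.pth); the chord-to-index mapping is returned at runtime and never written to disk. Below is a minimal single-image inference sketch, under the assumptions that it runs in the same module (so ChordCNN, ChordDataset, and the imports above are in scope), that ds/train is still available to rebuild the mapping, and that 'some_chord.jpg' is a placeholder path for the image to classify.

# Rebuild the chord -> index mapping exactly as training did (assumes ds/train
# is unchanged), then invert it for readable predictions.
train_mapping = ChordDataset(root_dir='ds/train').class_to_idx
idx_to_class = {idx: chord for chord, idx in train_mapping.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChordCNN(num_classes=len(train_mapping)).to(device)
model.load_state_dict(torch.load('best_chord_cnn.pth', map_location=device))
model.eval()

# Same preprocessing as in train_and_evaluate()
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 'some_chord.jpg' is a placeholder for the image to classify.
x = transform(Image.open('some_chord.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    predicted = model(x).argmax(dim=1).item()
print(f"Predicted chord: {idx_to_class[predicted]}")

Saving the mapping alongside the weights (for example, torch.save({'state_dict': model.state_dict(), 'class_to_idx': train_mapping}, ...)) would remove the dependency on ds/train at inference time.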