Gabor Cselle committed on
Commit
ea56d2d
·
1 Parent(s): 99f802a

U gotta normalize. (+ cleanup)

Browse files
Files changed (1) hide show
  1. train_font_identifier.py +58 -90
train_font_identifier.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import time
4
  import torch
5
  import torch.optim as optim
 
6
  from torch.optim import lr_scheduler
7
  from torchvision import datasets, models, transforms
8
  from tqdm import tqdm
@@ -10,113 +11,80 @@ from tqdm import tqdm
10
  # Directory with organized font images
11
  data_dir = './train_test_images'
12
 
13
- # Define transformations for the image data
14
# Define transformations for the image data.
# Both splits use the same steps: resize to the model's expected input size,
# convert to a tensor, then normalize with the ImageNet standards.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    for split in ('train', 'test')
}

# Create datasets
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'test']
}

# Create dataloaders
dataloaders = {
    # Reshuffle the training split every epoch so batches are not always
    # drawn in the dataset's stored order.
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4, shuffle=True),
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=4),
}

# Define the model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Size the final fully connected layer to the number of classes found in the
# training folder instead of keeping the pretrained head's output size.
model.fc = torch.nn.Linear(model.fc.in_features, len(image_datasets['train'].classes))

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Optimizer (you can replace 'model.parameters()' with specific parameters to optimize if needed)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Number of epochs to train for
num_epochs = 25
54
 
55
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best-scoring weights.

    Every epoch runs a 'train' pass followed by a 'test' pass over the
    module-level ``dataloaders``. Loss and accuracy for each phase are
    averaged over ``len(image_datasets[phase])``. Whenever the test accuracy
    improves, the model's state dict is snapshotted; the best snapshot is
    loaded back into the model before returning.

    Args:
        model: Network to train; invoked as ``model(inputs)``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training pass.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object with the best test-accuracy weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0  # kept as a plain int for the whole phase

            # Wrap the dataloader with tqdm for a progress bar
            for inputs, labels in tqdm(dataloaders[phase], desc=f"Epoch {epoch} - {phase}"):
                optimizer.zero_grad()

                # Track gradient history only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by batch size so the epoch figure is
                # a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                scheduler.step()

            dataset_size = len(image_datasets[phase])
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Snapshot the weights whenever test accuracy improves
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # '.4f' (precision), matching the per-phase report above; the previous
    # '4f' only set a minimum field width.
    print('Best test Acc: {:.4f}'.format(best_acc))

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model
120
-
121
- # Train the model
122
- model = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)
 
3
  import time
4
  import torch
5
  import torch.optim as optim
6
+ import torch.nn as nn
7
  from torch.optim import lr_scheduler
8
  from torchvision import datasets, models, transforms
9
  from tqdm import tqdm
 
11
  # Directory with organized font images
12
  data_dir = './train_test_images'
13
 
14
+ # Transformations for the image data
15
# Transformations for the image data (shared by the train and test splits)
data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert images to grayscale with 3 channels
    transforms.Resize((224, 224)),  # Resize images to the expected input size of the model
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet stats
])

# Create datasets
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms)
    for x in ['train', 'test']
}

# Create dataloaders
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4, shuffle=True),
    # Evaluation only sums per-batch loss and correct counts, so the order of
    # the test samples does not need to be randomized.
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=4, shuffle=False),
}

# Define the model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the last fully connected layer to match the number of font classes you have
num_classes = len(image_datasets['train'].classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
 
 
47
 
48
  # Number of epochs to train for
49
  num_epochs = 25
50
 
51
+ # Function to perform a training step with progress bar
52
def train_step(model, data_loader, criterion, optimizer):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode and invoked as
            ``model(inputs)``.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.

    Returns:
        The sum of the per-batch loss values divided by the number of batches
        processed, or ``0.0`` when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(data_loader, desc='Training', leave=True) as progress_bar:
        for inputs, targets in progress_bar:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()  # read the scalar once, reuse below
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix(loss=batch_loss)
    # Batches are counted while iterating, so the loader does not have to
    # support len() and an empty loader does not divide by zero.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches
66
+
67
+ # Function to perform a validation step with progress bar
68
def validate(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode and invoked
            as ``model(inputs)``.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.

    Returns:
        A ``(mean_loss, accuracy)`` tuple. ``mean_loss`` is the sum of the
        per-batch loss values divided by the number of batches processed;
        ``accuracy`` is the fraction of evaluated samples whose highest-scoring
        output index equals the target. Both are ``0.0`` when the loader
        yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_batches = 0
    num_samples = 0
    # The context manager closes the progress bar even if a batch raises.
    with torch.no_grad(), tqdm(data_loader, desc='Validation', leave=False) as progress_bar:
        for inputs, targets in progress_bar:
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets).item()
            total_loss += batch_loss
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            correct += (predicted == targets).sum().item()
            # Count the samples actually evaluated so the accuracy stays
            # correct when the loader covers only part of its dataset.
            num_samples += targets.size(0)
            progress_bar.set_postfix(loss=batch_loss)
    if num_batches == 0:
        return 0.0, 0.0
    accuracy = correct / num_samples if num_samples else 0.0
    return total_loss / num_batches, accuracy
83
+
84
+ # Training loop with progress bar for epochs
85
# Drive training: one training pass, then one evaluation pass on the test
# split, with a one-line summary per epoch.
num_epochs = 25  # Replace with the number of epochs you'd like to train for
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
    val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
    epoch_summary = (
        f"Train Loss: {train_loss:.4f}, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Accuracy: {val_accuracy:.4f}"
    )
    print(epoch_summary)