pytholic committed on
Commit fb340a7 · 1 Parent(s): 95ce0a4

model fixed

Files changed (2):
  1. model.py +39 -29
  2. models/checkpoint.ckpt +2 -2
model.py CHANGED
@@ -1,6 +1,7 @@
 import pytorch_lightning as pl
 import torch
 import torchmetrics
+import torchvision.models as models
 from simple_parsing import ArgumentParser
 from torch import nn
 from torch.nn import functional as F
@@ -14,46 +15,48 @@ args = args_namespace.options
 
 # Model class
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, input_shape, weights=args.weights):
         super().__init__()
 
-        self.conv1 = nn.Conv2d(3, 32, 5)
-        self.conv2 = nn.Conv2d(32, 64, 5)
-        self.conv3 = nn.Conv2d(64, 128, 3)
-        self.dropout1 = nn.Dropout2d(0.25)
-        self.dropout2 = nn.Dropout2d(0.5)
+        self.feature_extractor = models.resnet18(weights=weights)
 
-        x = torch.randn(3, 224, 224).view(-1, 3, 224, 224)
-        self._to_linear = None
-        self.convs(x)
+        if weights:
+            # layers are frozen by using eval()
+            self.feature_extractor.eval()
+            # freeze params
+            for param in self.feature_extractor.parameters():
+                param.requires_grad = False
 
-        self.fc1 = nn.Linear(self._to_linear, 128)
-        self.fc2 = nn.Linear(128, args.num_classes)
+        n_size = self._get_conv_output(input_shape)
+
+        self.classifier = nn.Linear(n_size, args.num_classes)
+
+    # returns the size of the output tensor going into the Linear layer from the conv block
+    def _get_conv_output(self, shape):
+        batch_size = 1
+        tmp_input = torch.autograd.Variable(torch.rand(batch_size, *shape))
+
+        output_feat = self.convs(tmp_input)
+        n_size = output_feat.data.view(batch_size, -1).size(1)
+        return n_size
 
     def convs(self, x):
-        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
-        x = self.dropout1(x)
-        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
-        x = self.dropout2(x)
-        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
-
-        if self._to_linear is None:
-            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
+        x = self.feature_extractor(x)
         return x
 
     def forward(self, x):
+
         x = self.convs(x)
-        x = x.view(-1, self._to_linear)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
 
 
 class Classifier(pl.LightningModule):
     def __init__(self):
         super().__init__()
 
-        self.model = Model()
+        self.model = Model(input_shape=args.input_shape)
         self.accuracy = torchmetrics.Accuracy(
             task="multiclass", num_classes=args.num_classes
         )
@@ -62,13 +65,13 @@ class Classifier(pl.LightningModule):
         x = self.model(x)
         return x
 
-    def nll_loss(self, logits, labels):
-        return F.nll_loss(logits, labels)
+    def ce_loss(self, logits, labels):
+        return F.cross_entropy(logits, labels)
 
     def training_step(self, train_batch, batch_idx):
         x, y = train_batch
         logits = self.model(x)
-        loss = self.nll_loss(logits, y)
+        loss = self.ce_loss(logits, y)
         acc = self.accuracy(logits, y)
         self.log("accuracy/train_accuracy", acc)
         self.log("loss/train_loss", loss)
@@ -77,11 +80,18 @@ class Classifier(pl.LightningModule):
     def validation_step(self, val_batch, batch_idx):
         x, y = val_batch
         logits = self.model(x)
-        loss = self.nll_loss(logits, y)
+        loss = self.ce_loss(logits, y)
         acc = self.accuracy(logits, y)
         self.log("accuracy/val_accuracy", acc)
         self.log("loss/val_loss", loss)
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=args.learning_rate)
-        return optimizer
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, mode="min", patience=7
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "loss/val_loss",
+        }
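Note: the net effect of this change is to swap the hand-rolled three-layer CNN for a frozen ResNet-18 backbone feeding a single linear head, with the head sized by a dummy forward pass. Below is a minimal standalone sketch of that pattern, assuming an input shape of (3, 224, 224), 10 classes, and the "IMAGENET1K_V1" weights string as placeholders; the committed code reads the real values from args.input_shape, args.num_classes, and args.weights.

import torch
from torch import nn
import torchvision.models as models

# Frozen ImageNet backbone (weights string assumed for illustration).
backbone = models.resnet18(weights="IMAGENET1K_V1")
backbone.eval()  # also freezes batch-norm running statistics
for param in backbone.parameters():
    param.requires_grad = False

# Size the linear head with a dummy batch, mirroring Model._get_conv_output.
with torch.no_grad():
    n_size = backbone(torch.rand(1, 3, 224, 224)).view(1, -1).size(1)

classifier = nn.Linear(n_size, 10)  # 10 classes is an assumed placeholder

x = torch.rand(4, 3, 224, 224)
with torch.no_grad():
    feats = backbone(x)
logits = classifier(feats.view(x.size(0), -1))
print(logits.shape)  # torch.Size([4, 10])

Because the commit uses the full resnet18 module (including its 1000-way fc layer) as the feature extractor, n_size resolves to 1000 here.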
models/checkpoint.ckpt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:85f9ff02a03ded56ff20903f0227f017ec351a9946b7fa0a1b7c33b7107427d6
-size 124442154
+oid sha256:293e1c26f4f6ec48d58824658c41e3a4509d1313b2bdf15177020893e0ed1df5
+size 140535623
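The checkpoint is stored as a Git LFS pointer, so only its hash and size change in this commit; the new size reflects the ResNet-18-based architecture. A hypothetical smoke test for restoring it, assuming the repository's model.py is importable and matches the saved architecture:

import torch
from model import Classifier  # the Classifier defined in this commit's model.py

# load_from_checkpoint is the standard LightningModule restore entry point;
# map_location keeps this runnable on CPU-only machines.
clf = Classifier.load_from_checkpoint("models/checkpoint.ckpt", map_location="cpu")
clf.eval()

with torch.no_grad():
    logits = clf(torch.rand(1, 3, 224, 224))  # input shape assumed, per args.input_shape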