Update app.py
Browse files
app.py
CHANGED
@@ -97,7 +97,7 @@ class CustomImageDataset(Dataset):
|
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
# Load the ResNet model with ignore_mismatched_sizes
|
99 |
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
|
100 |
-
model.classifier = torch.nn.Linear(model.config.
|
101 |
model.train()
|
102 |
|
103 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|
|
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
# Load the ResNet model with ignore_mismatched_sizes
|
99 |
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
|
100 |
+
model.classifier = torch.nn.Linear(model.config.num_features, 3) # Update classifier for 3 labels
|
101 |
model.train()
|
102 |
|
103 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|