Update app.py
Browse files
app.py
CHANGED
@@ -97,7 +97,15 @@ class CustomImageDataset(Dataset):
|
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
# Load the ResNet model with ignore_mismatched_sizes
|
99 |
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
model.train()
|
103 |
|
|
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
# Load the ResNet model with ignore_mismatched_sizes
|
99 |
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
|
100 |
+
# Print model architecture to identify the classifier layer
|
101 |
+
print(model) # Inspect the model structure
|
102 |
+
|
103 |
+
# Update the classifier layer to match the number of labels
|
104 |
+
if hasattr(model, 'classifier'):
|
105 |
+
model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # Assuming 3 output classes
|
106 |
+
else:
|
107 |
+
# Access the linear layer differently if 'classifier' does not exist
|
108 |
+
|
109 |
|
110 |
model.train()
|
111 |
|