Tanusree88 commited on
Commit
6bd55a4
·
verified ·
1 Parent(s): 7ec0f50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -97,7 +97,8 @@ class CustomImageDataset(Dataset):
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
- model.classifier = torch.nn.Linear(model.config.num_features, 3) # Update classifier for 3 labels
 
101
  model.train()
102
 
103
  optimizer = AdamW(model.parameters(), lr=1e-4)
 
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
+ model.fc = torch.nn.Linear(model.fc.in_features, 3) # Assuming 3 output classes
101
+
102
  model.train()
103
 
104
  optimizer = AdamW(model.parameters(), lr=1e-4)