Tanusree88 commited on
Commit
7ec0f50
·
verified ·
1 Parent(s): 0c1799d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -97,7 +97,7 @@ class CustomImageDataset(Dataset):
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
- model.classifier = torch.nn.Linear(model.config.hidden_size, 3) # Update classifier for 3 labels
101
  model.train()
102
 
103
  optimizer = AdamW(model.parameters(), lr=1e-4)
 
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
+ model.classifier = torch.nn.Linear(model.config.num_features, 3) # Update classifier for 3 labels
101
  model.train()
102
 
103
  optimizer = AdamW(model.parameters(), lr=1e-4)