Update app.py
Browse files
app.py
CHANGED
@@ -105,6 +105,7 @@ def fine_tune_classification_model(train_loader):
|
|
105 |
model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # Assuming 3 output classes
|
106 |
else:
|
107 |
# Access the linear layer differently if 'classifier' does not exist
|
|
|
108 |
|
109 |
|
110 |
model.train()
|
|
|
105 |
model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # Assuming 3 output classes
|
106 |
else:
|
107 |
# Access the linear layer differently if 'classifier' does not exist
|
108 |
+
model.classifier = torch.nn.Linear(model.config.num_labels, 3) # Update according to available layers
|
109 |
|
110 |
|
111 |
model.train()
|