Tanusree88 commited on
Commit
0c1799d
·
verified ·
1 Parent(s): dd5fc17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -95,8 +95,11 @@ class CustomImageDataset(Dataset):
95
 
96
  # Training function for classification
97
  def fine_tune_classification_model(train_loader):
98
- model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3)
 
 
99
  model.train()
 
100
  optimizer = AdamW(model.parameters(), lr=1e-4)
101
  criterion = torch.nn.CrossEntropyLoss()
102
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -112,6 +115,7 @@ def fine_tune_classification_model(train_loader):
112
  loss.backward()
113
  optimizer.step()
114
  running_loss += loss.item()
 
115
  return running_loss / len(train_loader)
116
 
117
  # Streamlit UI for Fine-tuning
@@ -140,8 +144,10 @@ if st.button("Start Training"):
140
 
141
  # Segmentation function (using SegFormer)
142
  def fine_tune_segmentation_model(train_loader):
143
- model = SegformerForImageSegmentation.from_pretrained('nvidia/segformer-b0', num_labels=3)
 
144
  model.train()
 
145
  optimizer = AdamW(model.parameters(), lr=1e-4)
146
  criterion = torch.nn.CrossEntropyLoss()
147
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -157,6 +163,7 @@ def fine_tune_segmentation_model(train_loader):
157
  loss.backward()
158
  optimizer.step()
159
  running_loss += loss.item()
 
160
  return running_loss / len(train_loader)
161
 
162
  # Add a button for segmentation training
 
95
 
96
  # Training function for classification
97
  def fine_tune_classification_model(train_loader):
98
+ # Load the ResNet model with ignore_mismatched_sizes
99
+ model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
+ model.classifier = torch.nn.Linear(model.config.hidden_size, 3) # Update classifier for 3 labels
101
  model.train()
102
+
103
  optimizer = AdamW(model.parameters(), lr=1e-4)
104
  criterion = torch.nn.CrossEntropyLoss()
105
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
115
  loss.backward()
116
  optimizer.step()
117
  running_loss += loss.item()
118
+
119
  return running_loss / len(train_loader)
120
 
121
  # Streamlit UI for Fine-tuning
 
144
 
145
  # Segmentation function (using SegFormer)
146
  def fine_tune_segmentation_model(train_loader):
147
+ # Load the Segformer model with ignore_mismatched_sizes
148
+ model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0', num_labels=3, ignore_mismatched_sizes=True)
149
  model.train()
150
+
151
  optimizer = AdamW(model.parameters(), lr=1e-4)
152
  criterion = torch.nn.CrossEntropyLoss()
153
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
163
  loss.backward()
164
  optimizer.step()
165
  running_loss += loss.item()
166
+
167
  return running_loss / len(train_loader)
168
 
169
  # Add a button for segmentation training