DonImages commited on
Commit
ce9a2bf
·
verified ·
1 Parent(s): 6e1e5e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -44,14 +44,16 @@ class ImageDescriptionDataset(Dataset):
44
  class LoRAModel(nn.Module):
45
  def __init__(self):
46
  super(LoRAModel, self).__init__()
47
- self.backbone = models.resnet18(pretrained=True) # Using a pre-trained ResNet18
 
48
  self.fc = nn.Linear(self.backbone.fc.in_features, 100) # Placeholder output layer
49
 
50
  def forward(self, x):
51
- x = self.backbone(x)
52
- x = self.fc(x)
53
  return x
54
 
 
55
  # Function to train LoRA
56
  def train_lora(image_folder, metadata):
57
  print("Starting training process...")
 
44
  class LoRAModel(nn.Module):
45
  def __init__(self):
46
  super(LoRAModel, self).__init__()
47
+ self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Use the updated weights argument
48
+ # Adjust the final fully connected layer to match the output of the backbone
49
  self.fc = nn.Linear(self.backbone.fc.in_features, 100) # Placeholder output layer
50
 
51
  def forward(self, x):
52
+ x = self.backbone(x) # Pass through the ResNet18 backbone
53
+ x = self.fc(x) # Apply the final fully connected layer
54
  return x
55
 
56
+
57
  # Function to train LoRA
58
  def train_lora(image_folder, metadata):
59
  print("Starting training process...")