DonImages commited on
Commit
35ff6e4
·
verified ·
1 Parent(s): ce9a2bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -44,16 +44,18 @@ class ImageDescriptionDataset(Dataset):
44
  class LoRAModel(nn.Module):
45
  def __init__(self):
46
  super(LoRAModel, self).__init__()
47
- self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Use the updated weights argument
48
- # Adjust the final fully connected layer to match the output of the backbone
49
- self.fc = nn.Linear(self.backbone.fc.in_features, 100) # Placeholder output layer
 
 
 
50
 
51
  def forward(self, x):
52
- x = self.backbone(x) # Pass through the ResNet18 backbone
53
  x = self.fc(x) # Apply the final fully connected layer
54
  return x
55
 
56
-
57
  # Function to train LoRA
58
  def train_lora(image_folder, metadata):
59
  print("Starting training process...")
 
44
  class LoRAModel(nn.Module):
45
  def __init__(self):
46
  super(LoRAModel, self).__init__()
47
+ self.backbone = models.resnet18(pretrained=True) # Using a pre-trained ResNet18
48
+
49
+ # Fixing the shape mismatch: Input size to the fc layer should match ResNet output
50
+ self.fc = nn.Linear(self.backbone.fc.in_features, 100) # 100 is a placeholder for your output
51
+
52
+ # If you want to use LoRA, you will implement the low-rank adaptation mechanism here
53
 
54
  def forward(self, x):
55
+ x = self.backbone(x) # Extract features using the ResNet18 backbone
56
  x = self.fc(x) # Apply the final fully connected layer
57
  return x
58
 
 
59
  # Function to train LoRA
60
  def train_lora(image_folder, metadata):
61
  print("Starting training process...")