Update app.py
Browse files
app.py
CHANGED
@@ -44,14 +44,16 @@ class ImageDescriptionDataset(Dataset):
|
|
44 |
class LoRAModel(nn.Module):
|
45 |
def __init__(self):
|
46 |
super(LoRAModel, self).__init__()
|
47 |
-
self.backbone = models.resnet18(
|
|
|
48 |
self.fc = nn.Linear(self.backbone.fc.in_features, 100) # Placeholder output layer
|
49 |
|
50 |
def forward(self, x):
|
51 |
-
x = self.backbone(x)
|
52 |
-
x = self.fc(x)
|
53 |
return x
|
54 |
|
|
|
55 |
# Function to train LoRA
|
56 |
def train_lora(image_folder, metadata):
|
57 |
print("Starting training process...")
|
|
|
44 |
class LoRAModel(nn.Module):
|
45 |
def __init__(self):
|
46 |
super(LoRAModel, self).__init__()
|
47 |
+
self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Use the updated weights argument
|
48 |
+
# Adjust the final fully connected layer to match the output of the backbone
|
49 |
self.fc = nn.Linear(self.backbone.fc.in_features, 100) # Placeholder output layer
|
50 |
|
51 |
def forward(self, x):
|
52 |
+
x = self.backbone(x) # Pass through the ResNet18 backbone
|
53 |
+
x = self.fc(x) # Apply the final fully connected layer
|
54 |
return x
|
55 |
|
56 |
+
|
57 |
# Function to train LoRA
|
58 |
def train_lora(image_folder, metadata):
|
59 |
print("Starting training process...")
|