DonImages committed on
Commit
e12b6d4
·
verified ·
1 Parent(s): 35ff6e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -40
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from torch import nn, optim
3
- from torchvision import transforms, datasets, models
4
  from torch.utils.data import DataLoader, Dataset
 
5
  from PIL import Image
6
  import json
7
  import os
@@ -24,9 +24,9 @@ class ImageDescriptionDataset(Dataset):
24
  self.metadata = metadata
25
  self.image_names = list(metadata.keys()) # List of image filenames
26
  self.transform = transforms.Compose([
27
- transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
30
  ])
31
 
32
  def __len__(self):
@@ -35,76 +35,81 @@ class ImageDescriptionDataset(Dataset):
35
  def __getitem__(self, idx):
36
  image_name = self.image_names[idx]
37
  image_path = os.path.join(self.image_folder, image_name)
38
- image = Image.open(image_path).convert("RGB") # Open the image and convert to RGB
39
- description = self.metadata[image_name] # Get description for the image
40
- image = self.transform(image) # Apply transformations
41
  return image, description
42
 
43
- # LoRA Model Class (This is a placeholder, you'll need to implement the actual LoRA model)
 
 
 
 
 
 
 
 
 
 
 
 
44
  class LoRAModel(nn.Module):
45
  def __init__(self):
46
  super(LoRAModel, self).__init__()
47
- self.backbone = models.resnet18(pretrained=True) # Using a pre-trained ResNet18
48
-
49
- # Fixing the shape mismatch: Input size to the fc layer should match ResNet output
50
- self.fc = nn.Linear(self.backbone.fc.in_features, 100) # 100 is a placeholder for your output
51
-
52
- # If you want to use LoRA, you will implement the low-rank adaptation mechanism here
53
 
54
  def forward(self, x):
55
- x = self.backbone(x) # Extract features using the ResNet18 backbone
56
- x = self.fc(x) # Apply the final fully connected layer
57
- return x
58
 
59
- # Function to train LoRA
60
  def train_lora(image_folder, metadata):
61
- print("Starting training process...")
62
-
63
- # Create dataset and dataloaders
64
  dataset = ImageDescriptionDataset(image_folder, metadata)
65
  dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
66
-
67
- # Initialize model, loss, and optimizer
68
  model = LoRAModel()
69
- criterion = nn.CrossEntropyLoss() # Placeholder loss function, can be adjusted
70
  optimizer = optim.Adam(model.parameters(), lr=0.001)
71
-
72
  # Training loop
73
- num_epochs = 5 # Adjust the number of epochs based on your needs
74
  for epoch in range(num_epochs):
75
  print(f"Epoch {epoch + 1}/{num_epochs}")
76
  for batch_idx, (images, descriptions) in enumerate(dataloader):
77
- # Here we would convert descriptions to a numerical format
78
- # Since it's a placeholder, we use random labels for descriptions
79
- labels = torch.randint(0, 100, (images.size(0),)) # Random labels as a placeholder
80
 
81
  # Forward pass
82
  outputs = model(images)
83
  loss = criterion(outputs, labels)
84
-
85
  # Backward pass
86
  optimizer.zero_grad()
87
  loss.backward()
88
  optimizer.step()
89
-
90
- if batch_idx % 10 == 0: # Log every 10 batches
91
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
92
 
93
- print("Training completed.")
94
 
95
- # Gradio app function to load metadata and start training
96
  def start_training_gradio():
97
- print("Preparing dataset...")
98
- metadata = load_metadata(metadata_file) # Load metadata
99
- return train_lora(image_folder, metadata)
 
100
 
101
- # Gradio interface
102
  demo = gr.Interface(
103
- fn=start_training_gradio, # Use the new function name here
104
  inputs=None,
105
  outputs="text",
106
- title="Train LoRA on Your Dataset",
107
- description="Click below to start training with the uploaded images and metadata."
108
  )
109
 
110
  demo.launch()
 
1
  import torch
2
  from torch import nn, optim
 
3
  from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms, datasets, models
5
  from PIL import Image
6
  import json
7
  import os
 
24
  self.metadata = metadata
25
  self.image_names = list(metadata.keys()) # List of image filenames
26
  self.transform = transforms.Compose([
27
+ transforms.Resize((512, 512)),
28
  transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
30
  ])
31
 
32
  def __len__(self):
 
35
  def __getitem__(self, idx):
36
  image_name = self.image_names[idx]
37
  image_path = os.path.join(self.image_folder, image_name)
38
+ image = Image.open(image_path).convert("RGB")
39
+ description = self.metadata[image_name]
40
+ image = self.transform(image)
41
  return image, description
42
 
43
+ # LoRA Layer Implementation
44
class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) residual branch.

    The forward pass returns
    ``original_layer(x) + lora_down(lora_up(x))``. The wrapped layer is frozen
    and the residual branch starts at exactly zero, so a freshly wrapped layer
    reproduces the wrapped layer's output until training updates the adapter.

    Args:
        original_layer: Layer to adapt. It must expose ``in_features`` and
            ``out_features`` (as ``nn.Linear`` does). Its parameters are
            switched to ``requires_grad=False`` in place.
        rank: Width of the low-rank bottleneck between the two projections.
    """

    def __init__(self, original_layer, rank=4):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank

        # LoRA only learns the low-rank update; the wrapped weights stay fixed.
        for param in self.original_layer.parameters():
            param.requires_grad_(False)

        # NOTE: the attribute names are kept as-is so existing state_dict keys
        # still load. Here ``lora_up`` maps the input into the rank-sized
        # space and ``lora_down`` maps it back to the output width.
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)

        # Zero the second projection so the adapter contributes nothing at
        # initialisation. ``lora_up`` keeps its random init, which is what
        # lets gradients reach ``lora_down`` on the first step.
        nn.init.zeros_(self.lora_down.weight)

    def forward(self, x):
        """Return the wrapped layer's output plus the low-rank correction."""
        return self.original_layer(x) + self.lora_down(self.lora_up(x))
54
+
55
+ # LoRA Model Class
56
class LoRAModel(nn.Module):
    """ResNet-18 backbone whose final ``fc`` layer is wrapped in a LoRA adapter.

    The backbone's own parameters are frozen before the adapter is attached,
    so the only parameters left with ``requires_grad=True`` are the ones the
    ``LoRALayer`` wrapper creates. The output width is that of the backbone's
    original ``fc`` layer, because the adapter preserves ``out_features``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favour of ``weights=...``. It is kept here because the
        # installed torchvision version is not visible -- confirm and migrate.
        self.backbone = models.resnet18(pretrained=True)

        # Freeze first, wrap second: parameters created by LoRALayer after
        # this loop remain trainable, which is the point of low-rank
        # adaptation (train a small update, not the full network).
        for param in self.backbone.parameters():
            param.requires_grad_(False)

        self.backbone.fc = LoRALayer(self.backbone.fc)

    def forward(self, x):
        """Run ``x`` through the adapted backbone and return its output."""
        return self.backbone(x)
 
 
64
 
65
+ # Training Function
66
def train_lora(image_folder, metadata):
    """Run the placeholder training loop over an image/description dataset.

    Args:
        image_folder: Forwarded to ``ImageDescriptionDataset`` as the image
            directory.
        metadata: Forwarded to ``ImageDescriptionDataset``.

    Returns:
        None. Progress is reported with ``print``.

    Note:
        The descriptions yielded by the loader are not used yet. Each batch
        is trained against random integer labels in ``[0, 100)`` as a
        stand-in.
    """
    print("Starting LoRA training process...")

    # Data pipeline: shuffled mini-batches of eight samples.
    loader = DataLoader(
        ImageDescriptionDataset(image_folder, metadata),
        batch_size=8,
        shuffle=True,
    )

    # Model, objective and optimiser.
    net = LoRAModel()
    loss_fn = nn.CrossEntropyLoss()  # Update this if your task changes
    adam = optim.Adam(net.parameters(), lr=0.001)

    num_epochs = 5
    for epoch_number in range(1, num_epochs + 1):
        print(f"Epoch {epoch_number}/{num_epochs}")
        for batch_idx, (images, _descriptions) in enumerate(loader):
            # Placeholder: random labels, one per sample in the batch.
            samples_in_batch = images.size(0)
            labels = torch.randint(0, 100, (samples_in_batch,))

            # Forward pass.
            loss = loss_fn(net(images), labels)

            # Backward pass and parameter update.
            adam.zero_grad()
            loss.backward()
            adam.step()

            # Log the first batch and every tenth one after it.
            if not batch_idx % 10:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("LoRA training completed.")
99
 
100
+ # Gradio App
101
def start_training_gradio():
    """Gradio callback: load the metadata, run training, return a status line.

    Reads ``metadata_file`` and ``image_folder`` from module scope, hands the
    loaded metadata to ``train_lora``, and returns a fixed completion message
    for the text output.
    """
    print("Loading metadata and preparing dataset...")
    loaded_metadata = load_metadata(metadata_file)
    train_lora(image_folder, loaded_metadata)

    status_message = "Training completed. Check the model outputs!"
    return status_message
106
 
 
107
# Gradio front end: no input widgets, and a single text output that shows the
# string returned by the training callback.
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation.",
)

# Start serving the interface as soon as the module is executed.
demo.launch()