DonImages committed on
Commit
f71d177
·
verified ·
1 Parent(s): 417c33d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -28
app.py CHANGED
@@ -1,42 +1,89 @@
1
- import gradio as gr
 
 
 
 
2
  import json
3
  import os
4
 
5
  # Paths
6
- image_folder = "Images/" # Folder containing the images
7
- metadata_file = "descriptions.json" # JSON file with image descriptions
8
 
9
- # Load metadata
10
- with open(metadata_file, "r") as f:
11
- metadata = json.load(f)
12
- print(f"Loaded metadata: {len(metadata)} items") # Print the number of descriptions
 
 
 
 
 
 
 
13
 
14
- # Placeholder function for training LoRA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def train_lora(image_folder, metadata):
16
- print("Starting training process...") # Log the start of the training
17
- # Prepare a dataset of image paths and descriptions
18
- dataset = []
19
- for image_name, description in metadata.items():
20
- image_path = os.path.join(image_folder, image_name)
21
- if os.path.exists(image_path): # Ensure the image file exists
22
- dataset.append({"image": image_path, "description": description})
23
- print(f"Added {image_name} to dataset") # Log each added image
24
- else:
25
- print(f"Warning: {image_name} not found in {image_folder}") # Log missing images
26
-
27
- # Log how many images were successfully added
28
- num_images = len(dataset)
29
- print(f"Dataset prepared with {num_images} images.")
30
 
31
- # Placeholder for training logic
32
- # Replace this with your actual training code
33
- print("Training LoRA with the prepared dataset...")
34
 
35
- # For now, just return a message
36
- return f"Training LoRA with {num_images} images and their descriptions."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Define Gradio app
39
  def start_training():
 
40
  return train_lora(image_folder, metadata)
41
 
42
  # Gradio interface
@@ -48,5 +95,4 @@ demo = gr.Interface(
48
  description="Click below to start training with the uploaded images and metadata."
49
  )
50
 
51
- # Launch the Gradio interface
52
  demo.launch()
 
1
+ import torch
2
+ from torch import nn, optim
3
+ from torchvision import transforms, datasets, models
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from PIL import Image
6
  import json
7
  import os
8
 
9
  # Paths
10
+ image_folder = "Images/"
11
+ metadata_file = "descriptions.json"
12
 
13
+ # Custom Dataset Class
14
class ImageDescriptionDataset(Dataset):
    """Dataset yielding ``(image_tensor, description)`` pairs.

    Args:
        image_folder: Directory that contains the image files.
        metadata: Mapping of image filename to its text description.

    Filenames listed in ``metadata`` that are not present in ``image_folder``
    are skipped with a warning at construction time, so a single missing file
    cannot abort training in the middle of an epoch.
    """

    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        # Only keep entries whose image file can actually be opened later.
        self.image_names = self._existing_image_names(image_folder, metadata)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Channel statistics conventionally used with torchvision's
            # ImageNet-pretrained models.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _existing_image_names(image_folder, metadata):
        """Return the filenames from ``metadata`` that exist in ``image_folder``."""
        present = []
        for image_name in metadata:
            if os.path.isfile(os.path.join(image_folder, image_name)):
                present.append(image_name)
            else:
                print(f"Warning: {image_name} not found in {image_folder}")
        return present

    def __len__(self):
        """Return the number of usable images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Returns:
            A ``(transformed_image, description)`` tuple.
        """
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # Image.open is lazy and keeps the file open; convert() inside the
        # context manager produces an independent copy and the handle is
        # released as soon as the block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        description = self.metadata[image_name]
        return self.transform(image), description
35
+
36
+ # LoRA Model Class (This is a placeholder, you'll need to implement the actual LoRA model)
37
class LoRAModel(nn.Module):
    """Placeholder model: pretrained ResNet18 features plus a linear head.

    NOTE: no low-rank adaptation is implemented yet; this only provides a
    trainable network with a 100-way output so the training loop can run.
    """

    def __init__(self):
        super().__init__()
        self.backbone = models.resnet18(pretrained=True)  # Pre-trained ResNet18
        feature_dim = self.backbone.fc.in_features
        # The stock ResNet head maps features to 1000 logits, which cannot be
        # fed into a layer sized for `feature_dim`. Replace it with identity
        # so the backbone returns the pooled features the new head expects.
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(feature_dim, 100)  # Placeholder output layer

    def forward(self, x):
        """Return the 100-way logits for a batch of images ``x``."""
        features = self.backbone(x)
        return self.fc(features)
47
+
48
# Function to train LoRA
def train_lora(image_folder, metadata):
    """Run the placeholder training loop and return a status message.

    Args:
        image_folder: Directory that contains the training images.
        metadata: Mapping of image filename to its text description.

    Returns:
        A human-readable summary string, suitable for display in the UI.
    """
    print("Starting training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    num_images = len(dataset)
    if num_images == 0:
        # A shuffling DataLoader refuses to sample from an empty dataset, so
        # report the situation instead of failing with an opaque error.
        message = "No images available for training."
        print(message)
        return message
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Placeholder loss function, can be adjusted
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5  # Adjust the number of epochs based on your needs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, _descriptions) in enumerate(dataloader):
            # Descriptions are not yet converted to a numerical target;
            # random labels stand in for them as a placeholder.
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")
    # Return a message so the caller (the Gradio callback) has something to show.
    return f"Training completed: {num_epochs} epochs on {num_images} images."


# Define Gradio app
def start_training():
    """Load the description metadata and start training.

    Returns:
        The status message produced by ``train_lora``.
    """
    print("Preparing dataset...")
    # Load the metadata here: nothing at module level defines `metadata`, so
    # relying on a global would raise NameError when the button is clicked.
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    print(f"Loaded metadata: {len(metadata)} items")
    return train_lora(image_folder, metadata)
88
 
89
  # Gradio interface
 
95
  description="Click below to start training with the uploaded images and metadata."
96
  )
97
 
 
98
  demo.launch()