File size: 3,446 Bytes
f71d177 cb92b08 f71d177 cb92b08 f71d177 cb92b08 f71d177 417c33d f71d177 cb92b08 f71d177 417c33d f71d177 417c33d f71d177 417c33d cb92b08 417c33d cb92b08 417c33d cb92b08 417c33d cb92b08 417c33d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import json
import os

import gradio as gr
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
# Input locations (relative paths, so resolved against the current working
# directory): the folder of training images and the JSON file that holds
# their descriptions.
image_folder, metadata_file = "Images/", "descriptions.json"
# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    """Pair each image file in a folder with its text description.

    Args:
        image_folder: Directory containing the image files.
        metadata: Mapping of image filename (relative to ``image_folder``) to
            its description. The mapping's key order fixes the index order.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
            When omitted, a 224x224 resize, ``ToTensor`` and normalization with
            the standard ImageNet mean/std (as used by torchvision's pretrained
            models) is applied.
    """

    def __init__(self, image_folder, metadata, transform=None):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # index -> image filename
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the metadata."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(transformed_image, description)`` for the item at ``idx``.

        Raises:
            FileNotFoundError: if the image file listed in the metadata is missing.
        """
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        description = self.metadata[image_name]
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new, fully loaded image that
        # remains valid after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.transform(image), description
# LoRA Model Class (This is a placeholder, you'll need to implement the actual LoRA model)
class LoRAModel(nn.Module):
    """Placeholder classifier: a ResNet-18 feature extractor plus a linear head.

    NOTE: despite the name, no LoRA adapters are implemented yet.

    Args:
        num_classes: Output width of the linear head (default 100).
        pretrained: Passed to ``torchvision.models.resnet18``. torchvision
            >= 0.13 deprecates this flag in favour of ``weights=``.
    """

    def __init__(self, num_classes=100, pretrained=True):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)
        in_features = self.backbone.fc.in_features
        # resnet18 ends with its own classification layer. Replace it with
        # Identity so the backbone emits pooled features whose width matches
        # the input of self.fc; otherwise forward() hits a shape mismatch.
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Map an image batch to ``num_classes`` logits per image."""
        features = self.backbone(x)
        return self.fc(features)
# Function to train LoRA
def train_lora(image_folder, metadata, num_epochs=5, batch_size=8, lr=0.001, log_every=10):
    """Run the placeholder training loop over the image/description dataset.

    Args:
        image_folder: Directory holding the images referenced by ``metadata``.
        metadata: Mapping of image filename to description.
        num_epochs: Number of passes over the dataset.
        batch_size: Images per optimisation step.
        lr: Adam learning rate.
        log_every: Print the loss every ``log_every`` batches (<= 0 disables).

    Returns:
        A human-readable status message, suitable for a text UI output.
    """
    print("Starting training process...")
    dataset = ImageDescriptionDataset(image_folder, metadata)
    if len(dataset) == 0:
        # DataLoader(shuffle=True) raises on an empty dataset; report instead.
        message = "No images listed in metadata; nothing to train."
        print(message)
        return message
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = LoRAModel()
    model.train()
    criterion = nn.CrossEntropyLoss()  # Placeholder objective, can be adjusted
    optimizer = optim.Adam(model.parameters(), lr=lr)

    last_loss = None
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, _descriptions) in enumerate(dataloader):
            # TODO: encode the descriptions into real targets. Until then the
            # labels are random class indices, so the loss carries no signal.
            labels = torch.randint(0, 100, (images.size(0),))
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            if log_every > 0 and batch_idx % log_every == 0:
                print(f"Batch {batch_idx}, Loss: {last_loss}")
    print("Training completed.")
    if last_loss is None:
        return "Training completed (no batches were processed)."
    return f"Training completed. Final batch loss: {last_loss:.4f}"
# Define Gradio app
def start_training():
    """Load the description metadata, run training, and return a status message.

    Reads the JSON file named by the module-level ``metadata_file``, which must
    be an object mapping image filename to description, and trains on the
    images under ``image_folder``. Load errors are returned as a message rather
    than raised, so the UI shows them instead of failing.
    """
    print("Preparing dataset...")
    try:
        with open(metadata_file, encoding="utf-8") as f:
            metadata = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        return f"Could not load metadata from {metadata_file!r}: {err}"
    if not isinstance(metadata, dict):
        return f"Metadata in {metadata_file!r} must be a JSON object of filename -> description."
    result = train_lora(image_folder, metadata)
    return result if result is not None else "Training completed."
# Gradio interface: no inputs; the text output shows the status returned by start_training.
demo = gr.Interface(
    fn=start_training,
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata.",
)

# Launch only when run as a script, so importing this module (e.g. from tests
# or Gradio's reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()
|