import json
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import gradio as gr


image_folder = "Images/"
metadata_file = "descriptions.json"

# Load the image-name -> description mapping used by the dataset below.
with open(metadata_file, "r") as f:
    metadata = json.load(f)


class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        description = self.metadata[image_name]
        image = self.transform(image)
        return image, description


class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(pretrained=True)
        # Drop the 1000-class ImageNet head so the backbone exposes its
        # 512-dimensional features, then attach a fresh 100-class classifier.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(in_features, 100)

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        return x


def train_lora(image_folder, metadata):
    print("Starting training process...")

    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 5
    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Placeholder targets: the text descriptions are not used yet,
            # so each batch gets random class labels in [0, 100).
            labels = torch.randint(0, 100, (images.size(0),))

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")
    # Return a status string so the Gradio text output has something to show.
    return "Training completed."


def start_training():
    print("Preparing dataset...")
    return train_lora(image_folder, metadata)


demo = gr.Interface(
    fn=start_training,
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata.",
)

demo.launch()