|
import torch |
|
from torch import nn, optim |
|
from torch.utils.data import DataLoader, Dataset |
|
from torchvision import transforms, datasets, models |
|
from PIL import Image |
|
import json |
|
import os |
|
import gradio as gr |
|
|
|
|
|
# Directory containing the training images referenced by the metadata file.
image_folder = "Images/"
# JSON file mapping image filename -> text description.
metadata_file = "descriptions.json"
|
|
|
|
|
def load_metadata(metadata_file):
    """Load the image-filename -> description mapping from a JSON file.

    Args:
        metadata_file: Path to a JSON file mapping image names to text
            descriptions.

    Returns:
        dict: The parsed metadata mapping.
    """
    # Explicit encoding: JSON is UTF-8 by spec; without it the platform
    # default (e.g. cp1252 on Windows) can corrupt non-ASCII descriptions.
    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)
    return metadata
|
|
|
|
|
class ImageDescriptionDataset(Dataset):
    """Dataset yielding (transformed image tensor, description string) pairs.

    Images are looked up in ``image_folder`` by the keys of ``metadata``;
    each is resized to 512x512, converted to a tensor, and normalized to
    roughly [-1, 1].
    """

    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())
        pipeline = [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        # One sample per metadata entry.
        return len(self.image_names)

    def __getitem__(self, idx):
        name = self.image_names[idx]
        # Force RGB so grayscale/RGBA files still produce 3-channel tensors.
        pixels = Image.open(os.path.join(self.image_folder, name)).convert("RGB")
        return self.transform(pixels), self.metadata[name]
|
|
|
|
|
class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) residual.

    Computes ``original(x) + lora_down(lora_up(x))``, where the two small
    linear maps form a rank-``rank`` bottleneck.

    Args:
        original_layer: The ``nn.Linear`` to adapt (left untouched).
        rank: Bottleneck dimension of the low-rank update.
    """

    def __init__(self, original_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        # NOTE: names kept from the original code; "lora_up" is the
        # down-projection (in_features -> rank) and "lora_down" the
        # up-projection (rank -> out_features) in usual LoRA terminology.
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)
        # Per the LoRA paper, the output projection starts at zero so that
        # at initialization the wrapped layer behaves exactly like the
        # original pretrained layer (the adapter contributes nothing yet).
        nn.init.zeros_(self.lora_down.weight)

    def forward(self, x):
        # Frozen base output plus the learned low-rank correction.
        return self.original_layer(x) + self.lora_down(self.lora_up(x))
|
|
|
|
|
class LoRAModel(nn.Module):
    """ResNet-18 backbone whose final fc layer is wrapped in a LoRA adapter."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Replace the classification head with a LoRA-augmented version.
        backbone.fc = LoRALayer(backbone.fc)
        self.backbone = backbone

    def forward(self, x):
        # Delegate entirely to the (adapted) backbone.
        return self.backbone(x)
|
|
|
|
|
def train_lora(image_folder, metadata):
    """Run a short LoRA fine-tuning loop and save the resulting weights.

    Args:
        image_folder: Directory containing the training images.
        metadata: Mapping of image filename -> text description.

    Returns:
        str: Path of the saved ``state_dict`` checkpoint file.
    """
    print("Starting LoRA training process...")

    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    model = LoRAModel()

    # LoRA fine-tuning trains ONLY the injected low-rank adapters; the
    # pretrained backbone stays frozen. (Optimizing model.parameters()
    # wholesale would update every backbone weight and defeat the point
    # of using LoRA.)
    for param in model.parameters():
        param.requires_grad = False
    lora_params = []
    for module in model.modules():
        if isinstance(module, LoRALayer):
            adapters = list(module.lora_up.parameters()) + list(module.lora_down.parameters())
            for p in adapters:
                p.requires_grad = True
                lora_params.append(p)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(lora_params, lr=0.001)

    model.train()
    num_epochs = 5
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):

            # NOTE(review): placeholder objective — labels are random and the
            # text `descriptions` are unused, so the loss carries no signal.
            # Replace with a real target derived from the metadata before
            # relying on the trained weights.
            labels = torch.randint(0, 100, (images.size(0),))

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    model_path = "lora_model.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved as {model_path}")

    print("Training completed.")
    return model_path
|
|
|
|
|
def start_training_gradio():
    """Gradio callback: load metadata, run LoRA training, return the checkpoint path."""
    print("Loading metadata and preparing dataset...")
    meta = load_metadata(metadata_file)
    return train_lora(image_folder, meta)
|
|
|
|
|
# Minimal Gradio UI: no inputs — triggering the interface runs the whole
# training job and returns the saved checkpoint as a downloadable file.
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs=gr.File(),
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation."
)

# NOTE(review): launches at import time — consider an
# `if __name__ == "__main__":` guard so importing this module is side-effect free.
demo.launch()
|
|