import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
from PIL import Image
import json
import os
import gradio as gr
# Paths
image_folder = "Images/"
metadata_file = "descriptions.json"
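# Assumed layout of descriptions.json, inferred from how it is used below:
# a JSON object mapping each filename in Images/ to its text description, e.g.
#   {"cat_01.png": "a grey cat sitting on a sofa", "cat_02.png": "..."}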
# Define the function to load metadata
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata
# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        description = self.metadata[image_name]
        image = self.transform(image)
        return image, description
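# Illustrative shape check: with the Resize/ToTensor pipeline above, each sample
# is a (float tensor of shape [3, 512, 512], description string) pair, e.g.
#   image, description = dataset[0]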
# LoRA Layer Implementation
class LoRALayer(nn.Module):
    def __init__(self, original_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        # Low-rank update: project down to `rank`, then back up to the output size.
        self.lora_down = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, original_layer.out_features, bias=False)
        # Zero-init the up-projection so the layer initially equals the original.
        nn.init.zeros_(self.lora_up.weight)
        # Freeze the original weights; only the LoRA matrices are trained.
        for param in self.original_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.original_layer(x) + self.lora_up(self.lora_down(x))
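# Illustrative parameter count (hypothetical layer, not part of this script):
# wrapping nn.Linear(512, 1000) at rank=4 trains 512*4 + 4*1000 = 6,048 weights
# instead of the layer's 512,000, and the zero-initialized up-projection means
# the wrapped layer starts out computing exactly the original output.
#   layer = LoRALayer(nn.Linear(512, 1000), rank=4)
#   layer(torch.randn(2, 512)).shape  # torch.Size([2, 1000])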
# LoRA Model Class
class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # Pretrained base model
        # Freeze the backbone so gradients flow only through the LoRA matrices.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.fc = LoRALayer(self.backbone.fc)  # Replace the final layer with LoRA

    def forward(self, x):
        return self.backbone(x)
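# Sanity check (sketch): with the backbone frozen, only the two LoRA matrices
# should report requires_grad=True.
#   model = LoRAModel()
#   [n for n, p in model.named_parameters() if p.requires_grad]
#   # -> ['backbone.fc.lora_down.weight', 'backbone.fc.lora_up.weight']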
# Training Function
def train_lora(image_folder, metadata):
    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    # Optimize only the parameters that still require gradients (the LoRA matrices)
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)

    # Training loop
    num_epochs = 5  # Adjust the number of epochs based on your needs
    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Convert descriptions to a numerical format (if applicable)
            labels = torch.randint(0, 100, (images.size(0),))  # Placeholder labels

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # Define the folder where the model will be saved
    save_folder = "models"
    os.makedirs(save_folder, exist_ok=True)  # Create the folder if it doesn't exist

    # Save the trained model in the specified folder
    model_save_path = os.path.join(save_folder, "lora_model.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at {model_save_path}")
    print("Training completed.")
# Gradio App
def start_training_gradio():
    print("Loading metadata and preparing dataset...")
    metadata = load_metadata(metadata_file)
    train_lora(image_folder, metadata)
    return "Training completed. Check the model outputs!"


demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation.",
)

demo.launch()
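# Running this script starts the Gradio UI (on http://127.0.0.1:7860 by default,
# assuming no server/port overrides) with a single button that kicks off
# training; progress is printed to the console.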