import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import json
import os
import gradio as gr

# Paths
image_folder = "Images/"
metadata_file = "descriptions.json"

# Load the image -> description metadata from a JSON file
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        description = self.metadata[image_name]
        image = self.transform(image)
        return image, description

# LoRA Layer Implementation
class LoRALayer(nn.Module):
    def __init__(self, original_layer, rank=4):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        # Low-rank adapter: project the input down to `rank`, then back up to the output size
        self.lora_down = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, original_layer.out_features, bias=False)
        # Zero-init the up-projection so the wrapped layer initially behaves like the original layer
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x):
        return self.original_layer(x) + self.lora_up(self.lora_down(x))

# LoRA Model Class
class LoRAModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # Pretrained base model (torchvision >= 0.13)
        # Freeze the pretrained weights; only the LoRA matrices added below are trained
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.fc = LoRALayer(self.backbone.fc)  # Wrap the final layer with LoRA

    def forward(self, x):
        return self.backbone(x)

# Training Function
def train_lora(image_folder, metadata):
    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss function, and optimizer (only the trainable LoRA parameters)
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

    # Training loop
    num_epochs = 5  # Adjust the number of epochs based on your needs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Convert descriptions to a numerical format; placeholder labels until real targets exist
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # Save the trained model
    torch.save(model.state_dict(), "lora_model.pth")
    print("Model saved as lora_model.pth")
    print("Training completed.")

# Gradio App
def start_training_gradio():
    print("Loading metadata and preparing dataset...")
    metadata = load_metadata(metadata_file)
    train_lora(image_folder, metadata)
    return "Training completed. Check the model outputs!"

demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation."
)

demo.launch()
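
# A minimal sketch (not part of the original script) of how the saved adapter weights
# could be reloaded for inference. It assumes the LoRAModel class defined above and the
# "lora_model.pth" file written by train_lora(). Shown as comments because demo.launch()
# blocks, so this would normally live in a separate script:
#
#     model = LoRAModel()
#     model.load_state_dict(torch.load("lora_model.pth"))
#     model.eval()
#     with torch.no_grad():
#         logits = model(image_tensor)  # `image_tensor`: a hypothetical (1, 3, 512, 512) input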