import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import json
import os
import gradio as gr

# Paths
image_folder = "Images/"
metadata_file = "descriptions.json"

# Load the image-to-description metadata from a JSON file
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

# Custom Dataset class that pairs each image with its description
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")  # Open the image and convert to RGB
        description = self.metadata[image_name]  # Get the description for this image
        image = self.transform(image)  # Apply transformations
        return image, description

# LoRA model class (placeholder: a plain classification head on a pre-trained
# ResNet18, not a real LoRA adapter yet; see the LoRALinear sketch at the end
# of this file for what an actual low-rank adapter could look like)
class LoRAModel(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        self.backbone = models.resnet18(pretrained=True)  # Pre-trained ResNet18 backbone
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # Drop the original 1000-class head so features reach our layer
        self.fc = nn.Linear(in_features, num_classes)  # Placeholder output layer

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        return x

# Function to train the placeholder model
def train_lora(image_folder, metadata):
    print("Starting training process...")

    # Create the dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Placeholder loss function; adjust as needed
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5  # Adjust the number of epochs to your needs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Descriptions would normally be converted to numerical targets here;
            # random labels are used as a placeholder.
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")
    return "Training completed."  # Returned so the Gradio text output shows a result

# Gradio app function: load metadata and start training
def start_training_gradio():
    print("Preparing dataset...")
    metadata = load_metadata(metadata_file)  # Load metadata
    return train_lora(image_folder, metadata)

# Gradio interface
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata.",
)

demo.launch()
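
# ---------------------------------------------------------------------------
# Illustrative sketch (not wired into the app above): since LoRAModel is only
# a placeholder, this shows one common way a real LoRA adapter wraps a frozen
# nn.Linear with trainable low-rank matrices A and B. The class name, rank,
# and scaling below are assumptions for illustration, not a fixed API.
# ---------------------------------------------------------------------------
class LoRALinear(nn.Module):
    def __init__(self, base_linear: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad_(False)  # Freeze the pre-trained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        # Low-rank update: delta_W = B @ A, with far fewer trainable parameters than W
        self.lora_A = nn.Parameter(torch.randn(rank, base_linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base_linear.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen base projection plus the trainable low-rank correction
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

# Hypothetical usage: wrap the classification head so only the LoRA parameters train.
# model = LoRAModel()
# model.fc = LoRALinear(model.fc, rank=4)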