File size: 3,779 Bytes
f71d177
 
 
 
 
cb92b08
 
ce2de6c
cb92b08
 
f71d177
 
cb92b08
6e1e5e8
 
 
 
 
ce2de6c
f71d177
 
 
 
 
 
 
 
 
 
 
cb92b08
f71d177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417c33d
f71d177
cb92b08
f71d177
 
 
417c33d
f71d177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417c33d
6e1e5e8
 
f71d177
6e1e5e8
417c33d
cb92b08
417c33d
cb92b08
6e1e5e8
cb92b08
 
417c33d
 
cb92b08
 
417c33d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import json
import os
import gradio as gr

# Default data locations (relative paths, resolved against the working directory).
image_folder: str = "Images/"  # directory holding the training images
metadata_file: str = "descriptions.json"  # JSON file: image filename -> description

# Define the function to load metadata
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")  # Open the image and convert to RGB
        description = self.metadata[image_name]  # Get description for the image
        image = self.transform(image)  # Apply transformations
        return image, description

# LoRA Model Class (This is a placeholder, you'll need to implement the actual LoRA model)
class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(pretrained=True)  # Using a pre-trained ResNet18
        self.fc = nn.Linear(self.backbone.fc.in_features, 100)  # Placeholder output layer

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        return x

# Function to train LoRA
def train_lora(image_folder, metadata):
    print("Starting training process...")
    
    # Create dataset and dataloaders
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    
    # Initialize model, loss, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Placeholder loss function, can be adjusted
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    num_epochs = 5  # Adjust the number of epochs based on your needs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Here we would convert descriptions to a numerical format
            # Since it's a placeholder, we use random labels for descriptions
            labels = torch.randint(0, 100, (images.size(0),))  # Random labels as a placeholder

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")

# Gradio app function to load metadata and start training
def start_training_gradio():
    print("Preparing dataset...")
    metadata = load_metadata(metadata_file)  # Load metadata
    return train_lora(image_folder, metadata)

# Gradio interface
demo = gr.Interface(
    fn=start_training_gradio,  # Use the new function name here
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata."
)

demo.launch()