File size: 4,126 Bytes
f71d177
 
 
962febd
f71d177
cb92b08
 
ce2de6c
cb92b08
 
f71d177
 
cb92b08
6e1e5e8
 
 
 
 
ce2de6c
f71d177
 
 
 
 
 
 
e12b6d4
f71d177
e12b6d4
f71d177
cb92b08
f71d177
 
 
 
 
 
e12b6d4
 
 
f71d177
 
e12b6d4
 
 
 
 
 
 
 
 
 
 
 
 
f71d177
 
 
e12b6d4
 
f71d177
 
e12b6d4
f71d177
e12b6d4
417c33d
e12b6d4
 
 
f71d177
 
e12b6d4
 
f71d177
e12b6d4
f71d177
e12b6d4
8894709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375075f
 
 
 
d8ec44f
375075f
 
417c33d
e12b6d4
6e1e5e8
e12b6d4
 
375075f
 
cb92b08
375075f
cb92b08
e12b6d4
cb92b08
375075f
e12b6d4
 
cb92b08
 
417c33d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
from PIL import Image
import json
import os
import gradio as gr

# Default input locations, read by start_training_gradio() below.
# Both are relative paths, so they resolve against the process's current
# working directory.
image_folder = "Images/"  # directory holding the training images
metadata_file = "descriptions.json"  # JSON file loaded by load_metadata()

# Define the function to load metadata
def load_metadata(metadata_file):
    """Load the image metadata from a JSON file.

    Args:
        metadata_file: Path to the JSON file to read.

    Returns:
        The decoded JSON document. The rest of this file treats it as a
        mapping of image filename to description.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; relying on the platform default
    # encoding breaks non-ASCII descriptions on some systems.
    with open(metadata_file, "r", encoding="utf-8") as f:
        return json.load(f)

# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    """Dataset that pairs each image file with its metadata description.

    Every key of ``metadata`` is used as a filename relative to
    ``image_folder``; the corresponding value is returned unchanged as the
    description.
    """

    def __init__(self, image_folder, metadata):
        """Store the image location and build the preprocessing pipeline.

        Args:
            image_folder: Directory that contains the image files.
            metadata: Mapping of image filename to description.
        """
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata)  # image filenames, in metadata order
        # Resize to a fixed 512x512, convert to a tensor, then shift/scale
        # each of the three channels with mean 0.5 and std 0.5.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Return the number of images listed in the metadata."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(transformed_image, description)`` for the given index."""
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # Use a context manager so the underlying file handle is released
        # right away; convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        description = self.metadata[image_name]
        return self.transform(image), description

# LoRA Layer Implementation
class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) residual branch.

    The forward pass computes
    ``original_layer(x) + lora_down(lora_up(x))``.

    Args:
        original_layer: Layer to adapt; must expose ``in_features`` and
            ``out_features`` (e.g. ``nn.Linear``). Its parameters are frozen
            so that only the low-rank branch is trained.
        rank: Inner dimension of the low-rank branch.
    """

    def __init__(self, original_layer, rank=4):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        # Attribute names are kept as-is so existing checkpoints still load:
        # lora_up projects in_features -> rank, lora_down projects
        # rank -> out_features.
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)

        # Zero-initialise the output projection so the adapter starts as an
        # exact no-op and the wrapped layer's pretrained behaviour is kept
        # at step 0.
        nn.init.zeros_(self.lora_down.weight)

        # LoRA trains only the low-rank matrices; the base weights stay fixed.
        for param in self.original_layer.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the base layer output plus the low-rank update."""
        return self.original_layer(x) + self.lora_down(self.lora_up(x))

# LoRA Model Class
class LoRAModel(nn.Module):
    """ImageNet-pretrained ResNet-18 whose final ``fc`` layer is LoRA-wrapped."""

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in newer torchvision releases in
        # favour of the `weights=` API. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` resolves to; fall back to the legacy flag when
        # the weights enum is not available (older torchvision).
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.backbone = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.backbone = models.resnet18(pretrained=True)
        # Replace the final classification layer with its LoRA-wrapped version.
        self.backbone.fc = LoRALayer(self.backbone.fc)

    def forward(self, x):
        """Run the input batch through the adapted backbone."""
        return self.backbone(x)

# Training Function
def train_lora(image_folder, metadata, *, num_epochs=5, batch_size=8,
               learning_rate=0.001, model_path="lora_model.pth", log_every=10):
    """Train a ``LoRAModel`` on the images listed in ``metadata`` and save it.

    Args:
        image_folder: Directory that contains the image files.
        metadata: Mapping of image filename to description.
        num_epochs: Number of passes over the dataset.
        batch_size: Number of images per optimisation step.
        learning_rate: Learning rate passed to Adam.
        model_path: Where the trained ``state_dict`` is written.
        log_every: Print the loss every this many batches.

    Returns:
        The path of the saved model file (``model_path``).

    Raises:
        ValueError: If ``metadata`` is empty, since there is nothing to
            train on.
    """
    # Fail early with a clear message instead of an opaque sampler error
    # from DataLoader when the dataset has no samples.
    if not metadata:
        raise ValueError("metadata is empty: no images to train on")

    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = LoRAModel()
    model.train()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    # Only hand the optimizer parameters that can actually receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # NOTE(review): these labels are random placeholders in
            # [0, 100); `descriptions` is not used yet, so the loss does
            # not reflect the metadata. Replace with a real
            # description -> target encoding before relying on the result.
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_every == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # Save the trained model
    torch.save(model.state_dict(), model_path)
    print(f"Model saved as {model_path}")

    print("Training completed.")
    return model_path  # Return the path of the saved model

# Gradio App
def start_training_gradio():
    """Gradio callback: load the metadata, train, and return the model path.

    Reads the module-level ``metadata_file`` and ``image_folder`` settings.
    The returned path is what the interface offers as the output file.
    """
    print("Loading metadata and preparing dataset...")
    descriptions = load_metadata(metadata_file)
    saved_model_path = train_lora(image_folder, descriptions)
    return saved_model_path

# Gradio interface: exposes start_training_gradio() as a web UI.
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,  # the callback takes no arguments
    outputs=gr.File(),  # the callback returns the saved model's file path
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation."
)

# NOTE: there is no `if __name__ == "__main__"` guard, so the app is launched
# whenever this module is executed or imported.
demo.launch()