File size: 3,899 Bytes
f71d177
 
 
e12b6d4
f71d177
cb92b08
 
ce2de6c
cb92b08
 
f71d177
 
cb92b08
6e1e5e8
 
 
 
 
ce2de6c
f71d177
 
 
 
 
 
 
e12b6d4
f71d177
e12b6d4
f71d177
cb92b08
f71d177
 
 
 
 
 
e12b6d4
 
 
f71d177
 
e12b6d4
 
 
 
 
 
 
 
 
 
 
 
 
f71d177
 
 
e12b6d4
 
f71d177
 
e12b6d4
f71d177
e12b6d4
417c33d
e12b6d4
 
 
f71d177
 
e12b6d4
 
f71d177
e12b6d4
f71d177
e12b6d4
b940d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f71d177
33b1db2
 
 
b940d83
 
417c33d
e12b6d4
6e1e5e8
e12b6d4
 
 
 
cb92b08
 
e12b6d4
cb92b08
 
e12b6d4
 
cb92b08
 
417c33d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
from PIL import Image
import json
import os
import gradio as gr

# Input locations; both are relative paths, resolved against the current
# working directory when the script runs.
image_folder: str = "Images/"  # directory holding the training images
metadata_file: str = "descriptions.json"  # JSON object: image filename -> description

# Define the function to load metadata
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        description = self.metadata[image_name]
        image = self.transform(image)
        return image, description

# LoRA Layer Implementation
class LoRALayer(nn.Module):
    def __init__(self, original_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)

    def forward(self, x):
        return self.original_layer(x) + self.lora_down(self.lora_up(x))

# LoRA Model Class
class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(pretrained=True)  # Base model
        self.backbone.fc = LoRALayer(self.backbone.fc)  # Replace the final layer with LoRA

    def forward(self, x):
        return self.backbone(x)

# Training Function
def train_lora(image_folder, metadata):
    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5  # Adjust the number of epochs based on your needs
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for batch_idx, (images, descriptions) in enumerate(dataloader):
        # Convert descriptions to a numerical format (if applicable)
        labels = torch.randint(0, 100, (images.size(0),))  # Placeholder labels

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:  # Log every 10 batches
            print(f"Batch {batch_idx}, Loss: {loss.item()}")

# Save the trained model
torch.save(model.state_dict(), "lora_model.pth")
print("Model saved as lora_model.pth")

print("Training completed.")

# Gradio App
def start_training_gradio():
    print("Loading metadata and preparing dataset...")
    metadata = load_metadata(metadata_file)
    train_lora(image_folder, metadata)
    return "Training completed. Check the model outputs!"

demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation."
)

demo.launch()