File size: 4,247 Bytes
f71d177
 
 
e12b6d4
f71d177
cb92b08
 
ce2de6c
cb92b08
 
f71d177
 
cb92b08
6e1e5e8
 
 
 
 
ce2de6c
f71d177
 
 
 
 
 
 
e12b6d4
f71d177
e12b6d4
f71d177
cb92b08
f71d177
 
 
 
 
 
e12b6d4
 
 
f71d177
 
e12b6d4
 
 
 
 
 
 
 
 
 
 
 
 
f71d177
 
 
e12b6d4
 
f71d177
 
e12b6d4
f71d177
e12b6d4
417c33d
e12b6d4
 
 
f71d177
 
e12b6d4
 
f71d177
e12b6d4
f71d177
e12b6d4
8894709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ec44f
 
 
 
 
 
 
 
8894709
 
417c33d
e12b6d4
6e1e5e8
e12b6d4
 
 
 
cb92b08
 
e12b6d4
cb92b08
 
e12b6d4
 
cb92b08
 
417c33d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
from PIL import Image
import json
import os
import gradio as gr

# Input locations (relative paths, resolved against the current working
# directory): the folder holding the training images, and the JSON file that
# maps each image filename to its description.
image_folder, metadata_file = "Images/", "descriptions.json"

# Define the function to load metadata
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # List of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        description = self.metadata[image_name]
        image = self.transform(image)
        return image, description

# LoRA Layer Implementation
class LoRALayer(nn.Module):
    def __init__(self, original_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)

    def forward(self, x):
        return self.original_layer(x) + self.lora_down(self.lora_up(x))

# LoRA Model Class
class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(pretrained=True)  # Base model
        self.backbone.fc = LoRALayer(self.backbone.fc)  # Replace the final layer with LoRA

    def forward(self, x):
        return self.backbone(x)

# Training Function
def train_lora(image_folder, metadata):
    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5  # Adjust the number of epochs based on your needs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Convert descriptions to a numerical format (if applicable)
            labels = torch.randint(0, 100, (images.size(0),))  # Placeholder labels

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # Define the folder where the model will be saved
    save_folder = "models"
    os.makedirs(save_folder, exist_ok=True)  # Create the folder if it doesn't exist

    # Save the trained model in the specified folder
    model_save_path = os.path.join(save_folder, "lora_model.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at {model_save_path}")

    print("Training completed.")

# Gradio App
def start_training_gradio():
    print("Loading metadata and preparing dataset...")
    metadata = load_metadata(metadata_file)
    train_lora(image_folder, metadata)
    return "Training completed. Check the model outputs!"

demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation."
)

demo.launch()