import torch
from torch import nn, optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import json
import os
import gradio as gr
# Paths
image_folder = "Images/"
metadata_file = "descriptions.json"
# Define the function to load metadata
def load_metadata(metadata_file):
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    return metadata

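# The metadata file is assumed to map image filenames to free-text
# descriptions, based on how it is used below. Hypothetical contents
# of descriptions.json:
#
#   {
#       "cat_001.jpg": "a gray cat sitting on a windowsill",
#       "dog_014.jpg": "a golden retriever running on grass"
#   }
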
# Custom Dataset Class
class ImageDescriptionDataset(Dataset):
    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())  # list of image filenames
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")  # open the image and convert to RGB
        description = self.metadata[image_name]  # description for this image
        image = self.transform(image)  # apply transformations
        return image, description

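# One possible way to turn the free-text descriptions into class indices
# for the placeholder classifier below: give each distinct description its
# own integer label. This helper is a sketch and is not wired into the
# training loop, which still uses random labels (see train_lora); a real
# pipeline would more likely tokenize the text instead.
def build_label_map(metadata):
    descriptions = sorted(set(metadata.values()))
    return {desc: idx for idx, desc in enumerate(descriptions)}
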
# LoRA model class (placeholder: a real implementation would inject
# low-rank adapters into the backbone, e.g. the LoRALinear sketch below)
class LoRAModel(nn.Module):
    def __init__(self):
        super(LoRAModel, self).__init__()
        # Pre-trained ResNet18 backbone with its classifier head stripped,
        # so the custom output layer receives the 512-dim feature vector
        # (applying fc on top of the original 1000-class head would fail).
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(in_features, 100)  # placeholder output layer

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        return x

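# Minimal sketch of an actual low-rank adaptation (LoRA) layer, for
# reference. It wraps a frozen linear layer and trains only the low-rank
# update B @ A, scaled by alpha / rank, which starts at zero because B is
# zero-initialized. Class name and defaults here are illustrative, not
# part of the original code.
class LoRALinear(nn.Module):
    def __init__(self, base_linear, rank=4, alpha=16):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad_(False)  # freeze the pre-trained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.zeros(rank, base_linear.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base_linear.out_features, rank))
        nn.init.normal_(self.lora_a, std=0.02)  # A gets a small random init
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen base output plus the learned low-rank correction.
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
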
# Function to train LoRA
def train_lora(image_folder, metadata):
    print("Starting training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model, loss, and optimizer
    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()  # placeholder loss; adjust for the real task
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5  # adjust the number of epochs as needed
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            # Descriptions would need to be converted to a numerical format
            # (e.g. via build_label_map above); random labels stand in here.
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")
    return "Training completed."  # Gradio displays this in the text output

# Gradio app function to load metadata and start training
def start_training_gradio():
    print("Preparing dataset...")
    metadata = load_metadata(metadata_file)  # load metadata
    return train_lora(image_folder, metadata)

# Gradio interface
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata.",
)

demo.launch()