Segmentation / utils /data_loader.py
Edwin Salguero
Initial commit: SAM 2 Few-Shot/Zero-Shot Segmentation Research Framework
12fa055
"""
Data Loader Utilities
This module provides data loading utilities for different domains
(satellite, fashion, robotics) with support for few-shot and zero-shot learning.
"""
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import json
from typing import List, Dict, Tuple, Optional
import random
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
import cv2
class BaseDataLoader:
"""Base class for domain-specific data loaders."""
def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
self.data_dir = data_dir
self.image_size = image_size
# Standard transforms
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.mask_transform = transforms.Compose([
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
transforms.ToTensor()
])
def load_image(self, image_path: str) -> torch.Tensor:
"""Load and preprocess image."""
image = Image.open(image_path).convert('RGB')
return self.transform(image)
def load_mask(self, mask_path: str) -> torch.Tensor:
"""Load and preprocess mask."""
mask = Image.open(mask_path).convert('L')
return self.mask_transform(mask)
def get_random_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Get a random sample from the dataset."""
raise NotImplementedError
def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Get examples for a specific class."""
raise NotImplementedError
class SatelliteDataLoader(BaseDataLoader):
"""Data loader for satellite imagery segmentation."""
def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
super().__init__(data_dir, image_size)
# Satellite-specific classes
self.classes = ["building", "road", "vegetation", "water"]
self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
# Load dataset structure
self.load_dataset_structure()
def load_dataset_structure(self):
"""Load dataset structure and file paths."""
self.images = []
self.masks = []
self.class_samples = {cls: [] for cls in self.classes}
# Assuming structure: data_dir/images/ and data_dir/masks/
images_dir = os.path.join(self.data_dir, "images")
masks_dir = os.path.join(self.data_dir, "masks")
if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
# Create dummy data for demonstration
self.create_dummy_data()
return
# Load real data
for filename in os.listdir(images_dir):
if filename.endswith(('.jpg', '.png', '.tif')):
image_path = os.path.join(images_dir, filename)
mask_path = os.path.join(masks_dir, filename.replace('.jpg', '_mask.png'))
if os.path.exists(mask_path):
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize by class (simplified)
self.categorize_sample(image_path, mask_path)
def create_dummy_data(self):
"""Create dummy satellite data for demonstration."""
print("Creating dummy satellite data...")
# Create dummy directory structure
os.makedirs(os.path.join(self.data_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(self.data_dir, "masks"), exist_ok=True)
# Generate dummy images and masks
for i in range(100):
# Create dummy image (satellite-like)
image = np.random.randint(50, 200, (512, 512, 3), dtype=np.uint8)
# Add some structure to make it look like satellite imagery
# Buildings (rectangular shapes)
for _ in range(5):
x, y = np.random.randint(0, 400), np.random.randint(0, 400)
w, h = np.random.randint(20, 80), np.random.randint(20, 80)
image[y:y+h, x:x+w] = np.random.randint(100, 150, 3)
# Roads (linear structures)
for _ in range(3):
x, y = np.random.randint(0, 512), np.random.randint(0, 512)
length = np.random.randint(50, 150)
angle = np.random.uniform(0, 2*np.pi)
for j in range(length):
px = int(x + j * np.cos(angle))
py = int(y + j * np.sin(angle))
if 0 <= px < 512 and 0 <= py < 512:
image[py, px] = [80, 80, 80]
# Save image
image_path = os.path.join(self.data_dir, "images", f"satellite_{i:03d}.jpg")
Image.fromarray(image).save(image_path)
# Create corresponding mask
mask = np.zeros((512, 512), dtype=np.uint8)
# Add building masks
for _ in range(3):
x, y = np.random.randint(0, 400), np.random.randint(0, 400)
w, h = np.random.randint(20, 80), np.random.randint(20, 80)
mask[y:y+h, x:x+w] = 1 # Building class
# Add road masks
for _ in range(2):
x, y = np.random.randint(0, 512), np.random.randint(0, 512)
length = np.random.randint(50, 150)
angle = np.random.uniform(0, 2*np.pi)
for j in range(length):
px = int(x + j * np.cos(angle))
py = int(y + j * np.sin(angle))
if 0 <= px < 512 and 0 <= py < 512:
mask[py, px] = 2 # Road class
# Save mask
mask_path = os.path.join(self.data_dir, "masks", f"satellite_{i:03d}_mask.png")
Image.fromarray(mask * 85).save(mask_path) # Scale for visibility
# Add to lists
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize
self.categorize_sample(image_path, mask_path)
def categorize_sample(self, image_path: str, mask_path: str):
"""Categorize sample by dominant class."""
mask = np.array(Image.open(mask_path))
# Count pixels for each class
class_counts = {}
for i, class_name in enumerate(self.classes):
class_counts[class_name] = np.sum(mask == i)
# Find dominant class
dominant_class = max(class_counts.items(), key=lambda x: x[1])[0]
self.class_samples[dominant_class].append((image_path, mask_path))
def get_random_query(self, class_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get a random query image and mask for a specific class."""
if class_name not in self.class_samples or not self.class_samples[class_name]:
# Fallback to any available sample
idx = random.randint(0, len(self.images) - 1)
image = self.load_image(self.images[idx])
mask = self.load_mask(self.masks[idx])
return image, mask
# Get random sample from specified class
image_path, mask_path = random.choice(self.class_samples[class_name])
image = self.load_image(image_path)
mask = self.load_mask(mask_path)
return image, mask
def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Get examples for a specific class."""
examples = []
if class_name in self.class_samples:
available_samples = self.class_samples[class_name]
selected_samples = random.sample(available_samples, min(num_examples, len(available_samples)))
for image_path, mask_path in selected_samples:
image = self.load_image(image_path)
mask = self.load_mask(mask_path)
examples.append((image, mask))
return examples
class FashionDataLoader(BaseDataLoader):
"""Data loader for fashion segmentation."""
def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
super().__init__(data_dir, image_size)
# Fashion-specific classes
self.classes = ["shirt", "pants", "dress", "shoes"]
self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
# Load dataset structure
self.load_dataset_structure()
def load_dataset_structure(self):
"""Load dataset structure and file paths."""
self.images = []
self.masks = []
self.class_samples = {cls: [] for cls in self.classes}
# Assuming structure: data_dir/images/ and data_dir/masks/
images_dir = os.path.join(self.data_dir, "images")
masks_dir = os.path.join(self.data_dir, "masks")
if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
# Create dummy data for demonstration
self.create_dummy_data()
return
# Load real data
for filename in os.listdir(images_dir):
if filename.endswith(('.jpg', '.png')):
image_path = os.path.join(images_dir, filename)
mask_path = os.path.join(masks_dir, filename.replace('.jpg', '_mask.png'))
if os.path.exists(mask_path):
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize by class
self.categorize_sample(image_path, mask_path)
def create_dummy_data(self):
"""Create dummy fashion data for demonstration."""
print("Creating dummy fashion data...")
# Create dummy directory structure
os.makedirs(os.path.join(self.data_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(self.data_dir, "masks"), exist_ok=True)
# Generate dummy images and masks
for i in range(100):
# Create dummy image (fashion-like)
image = np.random.randint(200, 255, (512, 512, 3), dtype=np.uint8)
# Add fashion items
class_id = i % len(self.classes)
if class_id == 0: # Shirt
# Create shirt-like shape
center_x, center_y = 256, 256
width, height = 150, 200
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [100, 150, 200]
elif class_id == 1: # Pants
# Create pants-like shape
center_x, center_y = 256, 300
width, height = 120, 180
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [50, 100, 150]
elif class_id == 2: # Dress
# Create dress-like shape
center_x, center_y = 256, 250
width, height = 140, 220
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [200, 100, 150]
else: # Shoes
# Create shoes-like shape
center_x, center_y = 256, 400
width, height = 100, 60
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [80, 80, 80]
# Save image
image_path = os.path.join(self.data_dir, "images", f"fashion_{i:03d}.jpg")
Image.fromarray(image).save(image_path)
# Create corresponding mask
mask = np.zeros((512, 512), dtype=np.uint8)
# Add mask for the fashion item
if class_id == 0: # Shirt
center_x, center_y = 256, 256
width, height = 150, 200
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 1
elif class_id == 1: # Pants
center_x, center_y = 256, 300
width, height = 120, 180
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 2
elif class_id == 2: # Dress
center_x, center_y = 256, 250
width, height = 140, 220
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 3
else: # Shoes
center_x, center_y = 256, 400
width, height = 100, 60
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 4
# Save mask
mask_path = os.path.join(self.data_dir, "masks", f"fashion_{i:03d}_mask.png")
Image.fromarray(mask * 51).save(mask_path) # Scale for visibility
# Add to lists
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize
self.categorize_sample(image_path, mask_path)
def categorize_sample(self, image_path: str, mask_path: str):
"""Categorize sample by dominant class."""
mask = np.array(Image.open(mask_path))
# Count pixels for each class
class_counts = {}
for i, class_name in enumerate(self.classes):
class_counts[class_name] = np.sum(mask == (i + 1)) # +1 because 0 is background
# Find dominant class
dominant_class = max(class_counts.items(), key=lambda x: x[1])[0]
self.class_samples[dominant_class].append((image_path, mask_path))
def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Get a random test sample with ground truth masks."""
idx = random.randint(0, len(self.images) - 1)
image = self.load_image(self.images[idx])
mask = self.load_mask(self.masks[idx])
# Convert single mask to multi-class dictionary
ground_truth = {}
for i, class_name in enumerate(self.classes):
class_mask = (mask == (i + 1)).float() # +1 because 0 is background
ground_truth[class_name] = class_mask
return image, ground_truth
class RoboticsDataLoader(BaseDataLoader):
"""Data loader for robotics segmentation."""
def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
super().__init__(data_dir, image_size)
# Robotics-specific classes
self.classes = ["robot", "tool", "safety"]
self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
# Load dataset structure
self.load_dataset_structure()
def load_dataset_structure(self):
"""Load dataset structure and file paths."""
self.images = []
self.masks = []
self.class_samples = {cls: [] for cls in self.classes}
# Assuming structure: data_dir/images/ and data_dir/masks/
images_dir = os.path.join(self.data_dir, "images")
masks_dir = os.path.join(self.data_dir, "masks")
if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
# Create dummy data for demonstration
self.create_dummy_data()
return
# Load real data
for filename in os.listdir(images_dir):
if filename.endswith(('.jpg', '.png')):
image_path = os.path.join(images_dir, filename)
mask_path = os.path.join(masks_dir, filename.replace('.jpg', '_mask.png'))
if os.path.exists(mask_path):
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize by class
self.categorize_sample(image_path, mask_path)
def create_dummy_data(self):
"""Create dummy robotics data for demonstration."""
print("Creating dummy robotics data...")
# Create dummy directory structure
os.makedirs(os.path.join(self.data_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(self.data_dir, "masks"), exist_ok=True)
# Generate dummy images and masks
for i in range(100):
# Create dummy image (robotics-like)
image = np.random.randint(50, 150, (512, 512, 3), dtype=np.uint8)
# Add robotics elements
class_id = i % len(self.classes)
if class_id == 0: # Robot
# Create robot-like shape
center_x, center_y = 256, 256
width, height = 120, 160
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [100, 100, 100]
elif class_id == 1: # Tool
# Create tool-like shape
center_x, center_y = 256, 256
width, height = 80, 120
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [150, 100, 50]
else: # Safety equipment
# Create safety equipment-like shape
center_x, center_y = 256, 256
width, height = 100, 100
image[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = [200, 200, 50]
# Save image
image_path = os.path.join(self.data_dir, "images", f"robotics_{i:03d}.jpg")
Image.fromarray(image).save(image_path)
# Create corresponding mask
mask = np.zeros((512, 512), dtype=np.uint8)
# Add mask for the robotics element
if class_id == 0: # Robot
center_x, center_y = 256, 256
width, height = 120, 160
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 1
elif class_id == 1: # Tool
center_x, center_y = 256, 256
width, height = 80, 120
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 2
else: # Safety equipment
center_x, center_y = 256, 256
width, height = 100, 100
mask[center_y-height//2:center_y+height//2, center_x-width//2:center_x+width//2] = 3
# Save mask
mask_path = os.path.join(self.data_dir, "masks", f"robotics_{i:03d}_mask.png")
Image.fromarray(mask * 85).save(mask_path) # Scale for visibility
# Add to lists
self.images.append(image_path)
self.masks.append(mask_path)
# Categorize
self.categorize_sample(image_path, mask_path)
def categorize_sample(self, image_path: str, mask_path: str):
"""Categorize sample by dominant class."""
mask = np.array(Image.open(mask_path))
# Count pixels for each class
class_counts = {}
for i, class_name in enumerate(self.classes):
class_counts[class_name] = np.sum(mask == (i + 1)) # +1 because 0 is background
# Find dominant class
dominant_class = max(class_counts.items(), key=lambda x: x[1])[0]
self.class_samples[dominant_class].append((image_path, mask_path))
def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Get a random test sample with ground truth masks."""
idx = random.randint(0, len(self.images) - 1)
image = self.load_image(self.images[idx])
mask = self.load_mask(self.masks[idx])
# Convert single mask to multi-class dictionary
ground_truth = {}
for i, class_name in enumerate(self.classes):
class_mask = (mask == (i + 1)).float() # +1 because 0 is background
ground_truth[class_name] = class_mask
return image, ground_truth