|
""" |
|
Data Loader Utilities |
|
|
|
This module provides data loading utilities for different domains |
|
(satellite, fashion, robotics) with support for few-shot and zero-shot learning. |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
import json |
|
from typing import List, Dict, Tuple, Optional |
|
import random |
|
from torch.utils.data import Dataset, DataLoader |
|
import torchvision.transforms as transforms |
|
from torchvision.transforms import functional as F |
|
import cv2 |
|
|
|
|
|
class BaseDataLoader:
    """Base class for domain-specific segmentation data loaders.

    Holds the shared image and mask preprocessing pipelines. Subclasses are
    responsible for discovering files and implementing the sampling methods.

    Args:
        data_dir: Root directory of the dataset.
        image_size: Target size passed to ``transforms.Resize`` for both
            images and masks.
    """

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        self.data_dir = data_dir
        self.image_size = image_size

        # RGB images: resize, convert to a float tensor, then normalize with
        # the ImageNet per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Masks: nearest-neighbour resizing so label values are never blended
        # together by interpolation.
        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def load_image(self, image_path: str) -> torch.Tensor:
        """Load an image file and return it as a normalized RGB tensor."""
        # The context manager closes the file handle deterministically; the
        # converted image is an independent in-memory copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.transform(image)

    def load_mask(self, mask_path: str) -> torch.Tensor:
        """Load a mask file as a single-channel ('L') float tensor.

        Note: ``ToTensor`` rescales 8-bit pixel values to ``pixel / 255``, so
        the returned values are NOT class indices; callers must undo that
        scaling (and any storage scaling) to recover labels.
        """
        with Image.open(mask_path) as img:
            mask = img.convert('L')
        return self.mask_transform(mask)

    def get_random_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return a random ``(image, per-class masks)`` sample.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Return up to ``num_examples`` ``(image, mask)`` pairs for ``class_name``.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError
|
|
|
|
|
class SatelliteDataLoader(BaseDataLoader):
    """Data loader for satellite imagery segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is written there instead (see
    ``create_dummy_data``).

    Mask label convention: 0 is background and label ``k + 1`` is
    ``self.classes[k]``. Labels may be stored either as raw ids or multiplied
    by ``_MASK_SCALE``, which is the format ``create_dummy_data`` writes.
    ``class_to_id`` keeps its 0-based index into ``self.classes``.
    """

    # create_dummy_data writes label k as pixel value k * _MASK_SCALE.
    _MASK_SCALE = 85
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["building", "road", "vegetation", "water"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Discover image/mask pairs on disk and bin them by dominant class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # No dataset on disk: synthesize one so the loader is still usable.
            self.create_dummy_data()
            return

        # Sorted so sample order does not depend on the filesystem.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask path for ``filename``, or None if none exists.

        Tries ``<stem>_mask.png`` first (the naming create_dummy_data uses for
        every extension), then a mask with the same file name as the image.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (stem + '_mask.png', filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def _random_line_pixels(size: int) -> List[Tuple[int, int]]:
        """Return in-bounds (x, y) pixels of a random 1-px straight line."""
        x, y = np.random.randint(0, size), np.random.randint(0, size)
        length = np.random.randint(50, 150)
        angle = np.random.uniform(0, 2 * np.pi)
        pixels = []
        for j in range(length):
            px = int(x + j * np.cos(angle))
            py = int(y + j * np.sin(angle))
            if 0 <= px < size and 0 <= py < size:
                pixels.append((px, py))
        return pixels

    def create_dummy_data(self):
        """Write 100 synthetic image/mask pairs and register them.

        Rectangles are labelled 1 ("building") and lines 2 ("road"). Each
        shape is drawn into both the image and the mask, so the annotation
        matches the pixels.
        """
        print("Creating dummy satellite data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        size = 512
        for i in range(100):
            image = np.random.randint(50, 200, (size, size, 3), dtype=np.uint8)
            mask = np.zeros((size, size), dtype=np.uint8)

            # Buildings: axis-aligned rectangles.
            for _ in range(5):
                x, y = np.random.randint(0, 400), np.random.randint(0, 400)
                w, h = np.random.randint(20, 80), np.random.randint(20, 80)
                image[y:y + h, x:x + w] = np.random.randint(100, 150, 3)
                mask[y:y + h, x:x + w] = 1

            # Roads: drawn after buildings, so they overwrite them in both arrays.
            for _ in range(3):
                for px, py in self._random_line_pixels(size):
                    image[py, px] = [80, 80, 80]
                    mask[py, px] = 2

            image_path = os.path.join(images_dir, f"satellite_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask_path = os.path.join(masks_dir, f"satellite_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def _mask_to_labels(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values into integer class labels.

        Masks whose maximum exceeds the number of classes and whose values
        are all multiples of ``_MASK_SCALE`` are divided back down. Anything
        else is assumed to already hold raw label ids.
        """
        raw = np.asarray(raw).astype(np.int64)
        if raw.size and raw.max() > len(self.classes) and not np.any(raw % self._MASK_SCALE):
            return raw // self._MASK_SCALE
        return raw

    def _dominant_class(self, labels: np.ndarray) -> Optional[str]:
        """Return the class with the most pixels, or None if only background.

        Ties go to the class listed first in ``self.classes``.
        """
        counts = [int(np.count_nonzero(labels == class_id + 1)) for class_id in range(len(self.classes))]
        best = max(range(len(counts)), key=counts.__getitem__)
        return self.classes[best] if counts[best] > 0 else None

    def categorize_sample(self, image_path: str, mask_path: str):
        """Bin a sample under its dominant class.

        Background-only samples are not binned.
        """
        with Image.open(mask_path) as img:
            # 'L' matches how load_mask reads the same file.
            mask = np.array(img.convert('L'))
        dominant_class = self._dominant_class(self._mask_to_labels(mask))
        if dominant_class is not None:
            self.class_samples[dominant_class].append((image_path, mask_path))

    def get_random_query(self, class_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a random ``(image, mask)`` pair, preferring ``class_name``.

        If the class is unknown or has no binned samples, any sample is used.

        Raises:
            ValueError: if the loader holds no samples at all.
        """
        candidates = self.class_samples.get(class_name)
        if candidates:
            image_path, mask_path = random.choice(candidates)
        else:
            if not self.images:
                raise ValueError("SatelliteDataLoader has no samples to draw a query from")
            idx = random.randrange(len(self.images))
            image_path, mask_path = self.images[idx], self.masks[idx]
        return self.load_image(image_path), self.load_mask(mask_path)

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Return up to ``num_examples`` distinct ``(image, mask)`` pairs for a class.

        Unknown classes and non-positive ``num_examples`` yield an empty list.
        """
        available = self.class_samples.get(class_name, [])
        count = max(0, min(num_examples, len(available)))
        selected = random.sample(available, count)
        return [(self.load_image(image_path), self.load_mask(mask_path)) for image_path, mask_path in selected]
|
|
|
|
|
class FashionDataLoader(BaseDataLoader):
    """Data loader for fashion segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is written there instead (see
    ``create_dummy_data``).

    Mask label convention: 0 is background and label ``k + 1`` is
    ``self.classes[k]``. Labels may be stored either as raw ids or multiplied
    by ``_MASK_SCALE``, which is the format ``create_dummy_data`` writes.
    ``class_to_id`` keeps its 0-based index into ``self.classes``.
    """

    # create_dummy_data writes label k as pixel value k * _MASK_SCALE.
    _MASK_SCALE = 51
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    # Synthetic shape per class, in self.classes order:
    # (center_x, center_y, width, height, RGB colour).
    _DUMMY_SHAPES = [
        (256, 256, 150, 200, (100, 150, 200)),  # shirt
        (256, 300, 120, 180, (50, 100, 150)),   # pants
        (256, 250, 140, 220, (200, 100, 150)),  # dress
        (256, 400, 100, 60, (80, 80, 80)),      # shoes
    ]

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["shirt", "pants", "dress", "shoes"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Discover image/mask pairs on disk and bin them by dominant class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # No dataset on disk: synthesize one so the loader is still usable.
            self.create_dummy_data()
            return

        # Sorted so sample order does not depend on the filesystem.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask path for ``filename``, or None if none exists.

        Tries ``<stem>_mask.png`` first, then a mask with the same file name
        as the image.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (stem + '_mask.png', filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    def create_dummy_data(self):
        """Write 100 synthetic image/mask pairs and register them.

        The classes are cycled, so each of the 4 classes gets 25 samples.
        Each sample has one rectangle from ``_DUMMY_SHAPES`` on a light
        noisy background.
        """
        print("Creating dummy fashion data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for i in range(100):
            class_id = i % len(self.classes)
            center_x, center_y, width, height, color = self._DUMMY_SHAPES[class_id]
            rows = slice(center_y - height // 2, center_y + height // 2)
            cols = slice(center_x - width // 2, center_x + width // 2)

            image = np.random.randint(200, 255, (512, 512, 3), dtype=np.uint8)
            image[rows, cols] = color
            image_path = os.path.join(images_dir, f"fashion_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[rows, cols] = class_id + 1
            mask_path = os.path.join(masks_dir, f"fashion_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def _mask_to_labels(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values into integer class labels.

        Masks whose maximum exceeds the number of classes and whose values
        are all multiples of ``_MASK_SCALE`` are divided back down. Anything
        else is assumed to already hold raw label ids.
        """
        raw = np.asarray(raw).astype(np.int64)
        if raw.size and raw.max() > len(self.classes) and not np.any(raw % self._MASK_SCALE):
            return raw // self._MASK_SCALE
        return raw

    def _dominant_class(self, labels: np.ndarray) -> Optional[str]:
        """Return the class with the most pixels, or None if only background.

        Ties go to the class listed first in ``self.classes``.
        """
        counts = [int(np.count_nonzero(labels == class_id + 1)) for class_id in range(len(self.classes))]
        best = max(range(len(counts)), key=counts.__getitem__)
        return self.classes[best] if counts[best] > 0 else None

    def categorize_sample(self, image_path: str, mask_path: str):
        """Bin a sample under its dominant class.

        Background-only samples are not binned.
        """
        with Image.open(mask_path) as img:
            # 'L' matches how load_mask reads the same file.
            mask = np.array(img.convert('L'))
        dominant_class = self._dominant_class(self._mask_to_labels(mask))
        if dominant_class is not None:
            self.class_samples[dominant_class].append((image_path, mask_path))

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return a random image and a float 0/1 mask per class name.

        Raises:
            ValueError: if the loader holds no samples.
        """
        if not self.images:
            raise ValueError("FashionDataLoader has no samples to draw a test sample from")
        idx = random.randrange(len(self.images))
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask yields pixel / 255 floats. Undo that scaling to get the
        # stored pixel values, then decode them into class labels.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        labels = torch.from_numpy(self._mask_to_labels(raw))

        ground_truth = {
            class_name: (labels == class_id + 1).float()
            for class_id, class_name in enumerate(self.classes)
        }
        return image, ground_truth
|
|
|
|
|
class RoboticsDataLoader(BaseDataLoader):
    """Data loader for robotics segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is written there instead (see
    ``create_dummy_data``).

    Mask label convention: 0 is background and label ``k + 1`` is
    ``self.classes[k]``. Labels may be stored either as raw ids or multiplied
    by ``_MASK_SCALE``, which is the format ``create_dummy_data`` writes.
    ``class_to_id`` keeps its 0-based index into ``self.classes``.
    """

    # create_dummy_data writes label k as pixel value k * _MASK_SCALE.
    _MASK_SCALE = 85
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    # Synthetic shape per class, in self.classes order:
    # (center_x, center_y, width, height, RGB colour).
    _DUMMY_SHAPES = [
        (256, 256, 120, 160, (100, 100, 100)),  # robot
        (256, 256, 80, 120, (150, 100, 50)),    # tool
        (256, 256, 100, 100, (200, 200, 50)),   # safety
    ]

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["robot", "tool", "safety"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Discover image/mask pairs on disk and bin them by dominant class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # No dataset on disk: synthesize one so the loader is still usable.
            self.create_dummy_data()
            return

        # Sorted so sample order does not depend on the filesystem.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask path for ``filename``, or None if none exists.

        Tries ``<stem>_mask.png`` first, then a mask with the same file name
        as the image.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (stem + '_mask.png', filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    def create_dummy_data(self):
        """Write 100 synthetic image/mask pairs and register them.

        The classes are cycled across samples. Each sample has one centred
        rectangle from ``_DUMMY_SHAPES`` on a noisy background.
        """
        print("Creating dummy robotics data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for i in range(100):
            class_id = i % len(self.classes)
            center_x, center_y, width, height, color = self._DUMMY_SHAPES[class_id]
            rows = slice(center_y - height // 2, center_y + height // 2)
            cols = slice(center_x - width // 2, center_x + width // 2)

            image = np.random.randint(50, 150, (512, 512, 3), dtype=np.uint8)
            image[rows, cols] = color
            image_path = os.path.join(images_dir, f"robotics_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[rows, cols] = class_id + 1
            mask_path = os.path.join(masks_dir, f"robotics_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def _mask_to_labels(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values into integer class labels.

        Masks whose maximum exceeds the number of classes and whose values
        are all multiples of ``_MASK_SCALE`` are divided back down. Anything
        else is assumed to already hold raw label ids.
        """
        raw = np.asarray(raw).astype(np.int64)
        if raw.size and raw.max() > len(self.classes) and not np.any(raw % self._MASK_SCALE):
            return raw // self._MASK_SCALE
        return raw

    def _dominant_class(self, labels: np.ndarray) -> Optional[str]:
        """Return the class with the most pixels, or None if only background.

        Ties go to the class listed first in ``self.classes``.
        """
        counts = [int(np.count_nonzero(labels == class_id + 1)) for class_id in range(len(self.classes))]
        best = max(range(len(counts)), key=counts.__getitem__)
        return self.classes[best] if counts[best] > 0 else None

    def categorize_sample(self, image_path: str, mask_path: str):
        """Bin a sample under its dominant class.

        Background-only samples are not binned.
        """
        with Image.open(mask_path) as img:
            # 'L' matches how load_mask reads the same file.
            mask = np.array(img.convert('L'))
        dominant_class = self._dominant_class(self._mask_to_labels(mask))
        if dominant_class is not None:
            self.class_samples[dominant_class].append((image_path, mask_path))

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return a random image and a float 0/1 mask per class name.

        Raises:
            ValueError: if the loader holds no samples.
        """
        if not self.images:
            raise ValueError("RoboticsDataLoader has no samples to draw a test sample from")
        idx = random.randrange(len(self.images))
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask yields pixel / 255 floats. Undo that scaling to get the
        # stored pixel values, then decode them into class labels.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        labels = torch.from_numpy(self._mask_to_labels(raw))

        ground_truth = {
            class_name: (labels == class_id + 1).float()
            for class_id, class_name in enumerate(self.classes)
        }
        return image, ground_truth