|
import torch |
|
import numpy as np |
|
import os |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
import time |
|
import random |
|
|
|
|
|
def generate_random_mask(
    batch_size,
    height=256,
    width=256,
    device='cuda',
    min_coverage=0.2,
    max_coverage=0.8,
    num_blobs_range=(1, 3)
):
    """
    Generate random blob masks for a batch of images.

    Fast GPU version with smooth, non-circular blob shapes. Each mask is the
    union of a random number of blobs. A blob starts as a randomly scaled and
    rotated ellipse whose edge is perturbed by a sum of low-frequency sine
    waves. It is then cleaned up with a coarse and a fine Gaussian
    blur + threshold pass. After merging, the mask is dilated once if it is
    well below a randomly drawn target coverage, or eroded once if it is well
    above it. Finally it is smoothed again. The target coverage is only
    nudged toward, not enforced.

    Randomness comes from NumPy's global RNG (``np.random``); seed it for
    reproducible masks.

    Args:
        batch_size (int): Number of masks to generate.
        height (int): Height of the mask. Must be at least 8 because the
            blur passes use reflect padding of up to 7 pixels.
        width (int): Width of the mask. Must be at least 8 for the same reason.
        device (str): Device to run the computation on ('cuda' or 'cpu').
        min_coverage (float): Lower bound of the random target coverage (0-1).
        max_coverage (float): Upper bound of the random target coverage (0-1).
        num_blobs_range (tuple): Inclusive range (min, max) of the number of
            blobs per mask.

    Returns:
        torch.Tensor: Float masks with values in {0, 1} and shape
        (batch_size, 1, height, width).

    Raises:
        ValueError: If ``height`` or ``width`` is smaller than 8.
    """
    if height < 8 or width < 8:
        raise ValueError(
            f"height and width must be at least 8, got {height}x{width}")

    masks = torch.zeros((batch_size, 1, height, width), device=device)

    # Pixel coordinate grids, each of shape (height, width).
    y_indices = torch.arange(height, device=device).view(
        height, 1).expand(height, width)
    x_indices = torch.arange(width, device=device).view(
        1, width).expand(height, width)

    # Fine and coarse smoothing kernels, shaped for F.conv2d.
    small_kernel = get_gaussian_kernel(7, 1.0).to(device).view(1, 1, 7, 7)
    large_kernel = get_gaussian_kernel(15, 2.5).to(device).view(1, 1, 15, 15)

    max_radius = min(height, width) // 3
    min_radius = min(height, width) // 8

    for b in range(batch_size):
        num_blobs = np.random.randint(
            num_blobs_range[0], num_blobs_range[1] + 1)
        target_coverage = np.random.uniform(min_coverage, max_coverage)

        mask = torch.zeros(1, 1, height, width, device=device)

        for _ in range(num_blobs):
            # Low-frequency interference pattern used to wobble the blob edge.
            noise_field = torch.zeros(height, width, device=device)
            num_waves = np.random.randint(2, 5)
            for i in range(num_waves):
                freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
                freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
                phase_x = np.random.uniform(0, 2 * np.pi)
                phase_y = np.random.uniform(0, 2 * np.pi)
                # Later waves get progressively smaller amplitudes.
                amp = np.random.uniform(0.5, 1.0) * max_radius / (i + 1.5)
                noise_field += (torch.sin(x_indices * freq_x + phase_x)
                                * torch.sin(y_indices * freq_y + phase_y)
                                * amp)

            center_y = np.random.randint(height // 4, 3 * height // 4)
            center_x = np.random.randint(width // 4, 3 * width // 4)
            radius = np.random.randint(min_radius, max_radius)

            scale_y = np.random.uniform(0.6, 1.4)
            scale_x = np.random.uniform(0.6, 1.4)

            theta = np.random.uniform(0, 2 * np.pi)
            cos_theta, sin_theta = np.cos(theta), np.sin(theta)

            # Rotate the offsets first, then scale along the rotated axes.
            # Scaling first and rotating afterwards leaves the Euclidean
            # distance unchanged, so theta would have no effect and every
            # ellipse would stay axis-aligned.
            dy = y_indices - center_y
            dx = x_indices - center_x
            rotated_y = dy * cos_theta - dx * sin_theta
            rotated_x = dy * sin_theta + dx * cos_theta
            distances = torch.sqrt((rotated_y * scale_y) ** 2
                                   + (rotated_x * scale_x) ** 2)

            perturbed_distances = distances + noise_field
            blob = (perturbed_distances < radius).float()[None, None]

            # Coarse blur + random threshold: rounds off jagged edges and
            # randomly grows or shrinks the blob a little.
            blob = F.pad(blob, (7, 7, 7, 7), mode='reflect')
            blob = F.conv2d(blob, large_kernel, padding=0)
            rand_threshold = np.random.uniform(0.3, 0.6)
            blob = (blob > rand_threshold).float()

            # Fine blur + 0.5 threshold: removes remaining pixel-level noise.
            blob = F.pad(blob, (3, 3, 3, 3), mode='reflect')
            blob = F.conv2d(blob, small_kernel, padding=0)
            blob = (blob > 0.5).float()

            mask = torch.maximum(mask, blob)

        current_coverage = mask.mean().item()

        if current_coverage > 0:
            if current_coverage < target_coverage * 0.7:
                # Too small: dilate with a 5x5 max filter.
                mask = F.pad(mask, (2, 2, 2, 2), mode='reflect')
                mask = F.max_pool2d(mask, kernel_size=5, stride=1, padding=0)
            elif current_coverage > target_coverage * 1.3:
                # Too large: erode. A pixel survives only when more than 70%
                # of its 3x3 neighborhood is set (at least 7 of 9 pixels).
                mask = F.pad(mask, (1, 1, 1, 1), mode='reflect')
                mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=0)
                mask = (mask > 0.7).float()

        # Final smoothing pass.
        mask = F.pad(mask, (3, 3, 3, 3), mode='reflect')
        mask = F.conv2d(mask, small_kernel, padding=0)
        mask = (mask > 0.5).float()

        masks[b] = mask[0]

    return masks
|
|
|
|
|
def get_gaussian_kernel(kernel_size=5, sigma=1.0):
    """
    Returns a normalized 2D Gaussian kernel.

    Taps sit at integer pixel offsets centred on the kernel, so ``sigma`` is
    the standard deviation in pixels. Spreading the taps over a fixed
    multiple of sigma (e.g. ``linspace(-2*sigma, 2*sigma)``) would make the
    normalized kernel independent of ``sigma``.

    Args:
        kernel_size (int): Side length of the square kernel (>= 1).
        sigma (float): Standard deviation in pixels (> 0).

    Returns:
        torch.Tensor: float32 tensor of shape (kernel_size, kernel_size)
        whose entries sum to 1.

    Raises:
        ValueError: If ``kernel_size < 1`` or ``sigma <= 0``.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # Offsets from the kernel centre, e.g. [-2, -1, 0, 1, 2] for size 5.
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    profile = torch.exp(-(coords ** 2) / (2 * sigma ** 2))

    # The 2D Gaussian is separable: outer product of the 1D profile.
    gaussian = profile[:, None] * profile[None, :]
    gaussian /= gaussian.sum()

    return gaussian
|
|
|
|
|
def save_masks_as_images(masks, suffix="", output_dir="output"):
    """
    Save generated masks as RGB JPG images using PIL.

    Each mask ``masks[i, 0]`` is written to
    ``<output_dir>/mask_{i:03d}{suffix}.jpg``, with its single channel
    replicated into R, G and B. ``output_dir`` is created if it is missing.

    Args:
        masks (torch.Tensor): Batch indexed as ``masks[i, 0]`` -> 2D mask;
            values are expected in [0, 1] (out-of-range values are clipped).
        suffix (str): Text appended to each file name before the extension.
        output_dir (str): Directory the images are written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    for i in range(masks.shape[0]):
        # detach() so masks that are part of an autograd graph can be saved.
        mask = masks[i, 0].detach().cpu().numpy()

        # Clip and round so soft or out-of-range values map to valid pixel
        # intensities instead of being truncated or wrapping around in uint8.
        mask_255 = np.rint(np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        rgb_mask = np.stack([mask_255, mask_255, mask_255], axis=2)

        img = Image.fromarray(rgb_mask)
        img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95)
|
|
|
|
|
def random_dialate_mask(mask, max_percent=0.05):
    """
    Randomly dilates each mask in a batch with a box filter of random size.

    For every sample, an integer is drawn from ``[1, max_size)``, where
    ``max_size = max(int(width * max_percent), 3)``. A draw of 1 returns the
    sample unchanged; even draws are bumped to the next odd number. Otherwise
    each pixel is replaced by the mean of its zero-padded k x k neighborhood
    and thresholded at a random value drawn from ``[0.2, 0.8)``.

    NOTE(review): low thresholds grow the mask. Thresholds near the top of the
    range can remove pixels along mask edges (e.g. for k=3 an inside edge
    pixel averages 2/3), so this acts as a random dilate/erode. Confirm that
    this is intended.

    Args:
        mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width]
        max_percent (float): Maximum kernel size as a percentage of the mask width

    Returns:
        torch.Tensor: Dilated mask with the same shape as input. Processed
        samples are float32 {0, 1}; untouched samples keep their values.
    """
    size = mask.shape[-1]
    max_size = max(int(size * max_percent), 3)

    if mask.shape[0] == 0:
        # Nothing to process; torch.chunk/torch.cat would reject an empty batch.
        return mask.clone()

    channels = mask.shape[1]
    # conv2d needs a floating input whose dtype matches the kernel.
    work_dtype = mask.dtype if mask.is_floating_point() else torch.float32

    out_chunks = []
    for chunk in mask.split(1, dim=0):
        kernel_size = np.random.randint(1, max_size)

        if kernel_size < 2:
            out_chunks.append(chunk)
            continue

        # Odd kernel so the filter is centred on each pixel.
        if kernel_size % 2 == 0:
            kernel_size += 1

        # One box filter per channel (grouped conv) so multi-channel masks work.
        kernel = torch.ones(
            (channels, 1, kernel_size, kernel_size),
            device=mask.device, dtype=work_dtype,
        ) / (kernel_size * kernel_size)

        padding = kernel_size // 2
        averaged = F.conv2d(
            chunk.to(work_dtype), kernel, padding=padding, groups=channels)

        threshold = np.random.uniform(0.2, 0.8)
        out_chunks.append((averaged > threshold).float())

    return torch.cat(out_chunks, dim=0)
|
|
|
|
|
if __name__ == "__main__":
    batch_size = 20
    height = 256
    width = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Generating {batch_size} random blob masks on {device}...")

    # Time several runs; only the masks from the final run are saved.
    for _ in range(5):
        start = time.perf_counter()
        masks = generate_random_mask(
            batch_size=batch_size,
            height=height,
            width=width,
            device=device,
            min_coverage=0.2,
            max_coverage=0.8,
            num_blobs_range=(1, 3)
        )
        dilation = random_dialate_mask(masks)
        if device == 'cuda':
            # CUDA kernels launch asynchronously; wait for them so the
            # measured time covers the actual work.
            torch.cuda.synchronize()
        end = time.perf_counter()

        print(f"Generated {batch_size} masks with shape: {masks.shape}")
        print(f"Time taken: {(end - start)*1000:.2f} ms")

    print("Saving masks to 'output' directory...")
    save_masks_as_images(masks)
    save_masks_as_images(dilation, suffix="_dilated")

    print("Done!")
|
|