# Spaces: Running
import json
import os
from collections import Counter

from torchvision import transforms
from torchvision.datasets import ImageFolder
# Sanity-check script: verifies that the class folders under `dataset_path`
# agree with the categories declared in domain_config.json, and reports how
# many loadable images each category contains.

# Load domain configuration
config_path = "domain_config.json"
# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(config_path, "r", encoding="utf-8") as f:
    domain_config = json.load(f)

# Extract category names from domain_config.json (its top-level keys)
config_categories = list(domain_config.keys())

# Path to categorized images folder (ImageFolder expects one sub-folder per class)
dataset_path = "categorized_images"

# Apply data augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),  # Flip images randomly
    transforms.RandomRotation(degrees=15),  # Rotate images by up to 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Adjust colors
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # Random crop
    transforms.ToTensor(),
])

# Load dataset with augmentation
dataset = ImageFolder(root=dataset_path, transform=transform)

# Extract dataset categories
dataset_categories = dataset.classes

# Check for inconsistencies
print("\n✅ Dataset Classes from ImageFolder:", dataset_categories)
print("\n✅ Categories in domain_config.json:", config_categories)

# Compute both directions of the difference so the warning says *what* is wrong.
missing_on_disk = sorted(set(config_categories) - set(dataset_categories))
missing_in_config = sorted(set(dataset_categories) - set(config_categories))
if missing_on_disk or missing_in_config:
    print("\n⚠️ WARNING: Mismatch between dataset classes and domain_config.json!")
    if missing_on_disk:
        print(f" - In domain_config.json but not in dataset: {missing_on_disk}")
    if missing_in_config:
        print(f" - In dataset but not in domain_config.json: {missing_in_config}")

# Count images per category.
# `dataset.targets` holds one class index per sample ImageFolder will actually
# load, so this ignores stray non-image files and includes nested images,
# unlike a raw directory listing.
images_per_class = Counter(dataset.targets)
empty_categories = []
print("\n📊 Image Count Per Category:")
for category, idx in dataset.class_to_idx.items():
    num_images = images_per_class[idx]
    print(f" - {category}: {num_images} images")
    if num_images == 0:
        empty_categories.append(category)

# Only claim success when it is true: configured categories that have no folder
# on disk have no images either.
categories_without_images = sorted(set(empty_categories) | set(missing_on_disk))
if categories_without_images:
    print(f"\n⚠️ WARNING: Categories without images: {categories_without_images}")
else:
    print("\n✅ All categories have images!")