# NOTE(review): "Spaces: Sleeping" lines were page-status artifacts from the
# hosted (Hugging Face Spaces) view this file was scraped from, not code.
| import argparse | |
| import collections | |
| import glob | |
| import os | |
| import random | |
| from typing import Iterator | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchdata as td | |
| import torchdata.datapipes as dp | |
| from imwatermark import WatermarkEncoder | |
| from PIL import ( | |
| Image, | |
| ImageFile, | |
| ) | |
| from torch.utils.data import ( | |
| DataLoader, | |
| RandomSampler, | |
| ) | |
| from torchdata.dataloader2 import ( | |
| DataLoader2, | |
| MultiProcessingReadingService, | |
| ) | |
| from torchdata.datapipes.iter import ( | |
| Concater, | |
| FileLister, | |
| FileOpener, | |
| SampleMultiplexer, | |
| ) | |
| from torchvision.transforms import v2 | |
| from tqdm import tqdm | |
| from utils_sampling import UnderSamplerIterDataPipe | |
# Tolerate truncated image files from web scrapes instead of raising on load.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Raise PIL's decompression-bomb pixel limit so very large images decode.
Image.MAX_IMAGE_PIXELS = 1000000000

# Shared invisible-watermark encoder used by random_invisible_watermark below.
encoder = WatermarkEncoder()
encoder.set_watermark("bytes", b"test")

# Integer domain id -> source name; 0 ("laion") is the real-image source,
# the rest are generative models.
DOMAIN_LABELS = {
    0: "laion",
    1: "StableDiffusion",
    2: "dalle2",
    3: "dalle3",
    4: "midjourney",
}

# Domain id -> (train, val, test) sample counts.
# NOTE(review): presumably these mirror the shard contents on disk — confirm
# against the tar files before relying on them.
N_SAMPLES = {
    0: (115346, 14418, 14419),
    1: (22060, 2757, 2758),
    4: (21096, 2637, 2637),
    2: (13582, 1697, 1699),
    3: (12027, 1503, 1504),
}
class WorkerResultCollector(dp.iter.IterDataPipe):
    """Pass-through datapipe that forces its samples back to the main process.

    Wrapping a pipe in this class and returning ``False`` from
    :meth:`is_replicable` tells the reading service not to replicate the
    pipe across workers, so results are collected centrally.
    """

    def __init__(self, source: dp.iter.IterDataPipe):
        # Fix: initialize the IterDataPipe base class (the original skipped
        # super().__init__(), leaving base-class state unset).
        super().__init__()
        self.source = source

    def __iter__(self) -> Iterator:
        # Yield the wrapped pipe's samples unchanged.
        yield from self.source

    def is_replicable(self) -> bool:
        """Method to force data back to main process"""
        return False
def crop_bottom(image, cutoff=16):
    """Drop the bottom `cutoff` rows of a channels-first (C, H, W) image.

    Assumes the height axis is dimension 1 — TODO confirm against the
    transform pipeline's input layout.
    """
    trimmed = image[:, :-cutoff, :]
    return trimmed
def random_gaussian_blur(image, p=0.01):
    """With probability `p`, apply a 5x5 Gaussian blur; otherwise pass through."""
    apply_blur = random.random() < p
    if not apply_blur:
        return image
    return v2.functional.gaussian_blur(image, kernel_size=5)
def random_invisible_watermark(image, p=0.2):
    """Convert `image` to an HWC BGR uint8-style array of at least 256x256 and,
    with probability `p`, embed an invisible DWT-DCT watermark.

    Accepts a channels-first (C, H, W) array/tensor, a 2-D grayscale array,
    or an RGBA input; returns a numpy array (watermarked or not).
    """
    image_np = np.array(image)
    # Fix: only transpose 3-D (C, H, W) inputs.  The original transposed
    # unconditionally, so a 2-D grayscale input raised ValueError in
    # np.transpose before the ndim == 2 branch could ever run.
    if image_np.ndim == 3:
        image_np = np.transpose(image_np, (1, 2, 0))  # CHW -> HWC
    if image_np.ndim == 2:  # Grayscale image: replicate into 3 BGR channels
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    elif image_np.shape[2] == 4:  # RGBA image: drop alpha, reorder to BGR
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
    # The watermark encoder needs at least 256x256 input; upscale small images.
    if image_np.shape[0] < 256 or image_np.shape[1] < 256:
        image_np = cv2.resize(
            image_np, (256, 256), interpolation=cv2.INTER_AREA
        )
    if random.random() < p:
        return encoder.encode(image_np, method="dwtDct")
    return image_np
def build_transform(split: str):
    """Return the torchvision v2 transform pipeline for `split`.

    "train" gets the augmentation stack (bottom crop, random crop, blur,
    grayscale, invisible watermark, tensor conversion); every other split
    gets only a deterministic 256x256 center crop.
    """
    if split == "train":
        return v2.Compose(
            [
                v2.Lambda(crop_bottom),
                v2.RandomCrop((256, 256), pad_if_needed=True),
                v2.Lambda(random_gaussian_blur),
                v2.RandomGrayscale(p=0.05),
                v2.Lambda(random_invisible_watermark),
                v2.ToImage(),
            ]
        )
    return v2.Compose([v2.CenterCrop((256, 256))])
def dp_to_tuple_train(input_dict):
    """Map a decoded webdataset sample to (image, class label, domain label)
    using the training transform."""
    transform = build_transform("train")
    image = transform(input_dict[".jpg"])
    return image, input_dict[".label.cls"], input_dict[".domain_label.cls"]
def dp_to_tuple_eval(input_dict):
    """Map a decoded webdataset sample to (image, class label, domain label)
    using the evaluation transform."""
    transform = build_transform("eval")
    image = transform(input_dict[".jpg"])
    return image, input_dict[".label.cls"], input_dict[".domain_label.cls"]
def load_dataset(domains: list[int], split: str):
    """Build an IterDataPipe over laion and generative-AI image tar shards.

    For "train", each non-laion domain is cycled and multiplexed against
    laion with equal total weight (laion 1, genai domains 1/len each); for
    other splits the shards are simply concatenated.  Samples are mapped to
    (image, class label, domain label) tuples.
    """
    laion_lister = FileLister("./data/laion400m_data", f"{split}*.tar")
    genai_lister = {
        d: FileLister(
            f"./data/genai-images/{DOMAIN_LABELS[d]}", f"{split}*.tar"
        )
        for d in domains
        if DOMAIN_LABELS[d] != "laion"
    }

    def open_lister(lister):
        # Open tar shards in binary mode and decode as webdataset samples.
        opener = FileOpener(lister, mode="b")
        return opener.load_from_tar().routed_decode().webdataset()

    buffer_size1 = 100 if split == "train" else 10
    buffer_size2 = 100 if split == "train" else 10
    if split != "train":
        all_lister = [laion_lister] + list(genai_lister.values())
        dp = open_lister(Concater(*all_lister)).sharding_filter()
    else:
        laion_dp = (
            open_lister(laion_lister.shuffle())
            .cycle()
            .sharding_filter()
            .shuffle(buffer_size=buffer_size1)
        )
        if genai_lister:
            # Fix: compute the per-domain weight only when genai domains
            # exist — the original divided by len(genai_lister) up front and
            # raised ZeroDivisionError for laion-only `domains` (e.g. [0]).
            weight_genai = 1 / len(genai_lister)
            genai_dp = {
                open_lister(genai_lister[d].shuffle())
                .cycle()
                .sharding_filter()
                .shuffle(
                    buffer_size=buffer_size1,
                ): weight_genai
                for d in domains
                if DOMAIN_LABELS[d] != "laion"
            }
            dp = SampleMultiplexer({laion_dp: 1, **genai_dp}).shuffle(
                buffer_size=buffer_size2
            )
        else:
            # Only laion requested: nothing to multiplex against.
            dp = laion_dp.shuffle(buffer_size=buffer_size2)
    if split == "train":
        dp = dp.map(dp_to_tuple_train)
    else:
        dp = dp.map(dp_to_tuple_eval)
    return dp
def load_dataloader(
    domains: list[int], split: str, batch_size: int = 32, num_workers: int = 4
):
    """Create a DataLoader over the multiplexed datapipe for `domains`/`split`.

    Batching/collation happen inside the datapipe, so the DataLoader itself
    runs with batch_size=None.
    """
    pipe = load_dataset(domains, split)
    # NOTE: undersampling experiment kept for reference:
    # pipe = UnderSamplerIterDataPipe(pipe, {0: 0.5, 1: 0.5}, seed=42)
    batched = pipe.batch(batch_size).collate()
    return DataLoader(
        batched, batch_size=None, num_workers=num_workers, pin_memory=True
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    def report(title: str, counts: collections.Counter) -> None:
        # Print each label's relative frequency and raw count.
        print(title)
        total = sum(counts.values())
        for label in sorted(counts):
            print(f"• {label}: {counts[label] / total:.2%} ({counts[label]})")

    # Smoke test: stream training batches and tally class/domain labels.
    dl = load_dataloader([0, 1], "train", num_workers=8)
    y_dist = collections.Counter()
    d_dist = collections.Counter()
    for i, (img, y, d) in tqdm(enumerate(dl)):
        if i % 100 == 0:
            print(y, d)
        if i == 400:
            break
        y_dist.update(y.numpy())
        d_dist.update(d.numpy())
    report("class label", y_dist)
    report("domain label", d_dist)