Spaces:
Sleeping
Sleeping
| import glob | |
| import json | |
| import webdataset as wds | |
def split_dataset(path, n_train, n_val, n_test, label, domain_label):
    """Re-shard the tar files in *path* into train/val/test WebDataset shards.

    Samples are streamed from the ``*.tar`` files directly inside ``path``.
    The first ``n_train`` samples are written to ``train_<i>.tar``, the next
    ``n_val`` to ``val_<i>.tar`` and the next ``n_test`` to ``test_<i>.tar``,
    all inside ``path``, with at most 1000 samples per output shard.  Only the
    ``__key__`` and ``jpg`` fields of each source sample are kept; ``label``
    and ``domain_label`` are attached to every written sample under the keys
    ``label.cls`` and ``domain_label.cls``.

    Args:
        path: Directory containing the source ``*.tar`` shards; also the
            directory the split shards are written to.
        n_train: Number of samples for the train split.
        n_val: Number of samples for the validation split.
        n_test: Number of samples for the test split.
        label: Value stored as ``label.cls`` on every sample.
        domain_label: Value stored as ``domain_label.cls`` on every sample.
    """
    import itertools

    max_file_size = 1000  # maximum number of samples per output shard

    train_path_prefix = path + "/train"
    val_path_prefix = path + "/val"
    test_path_prefix = path + "/test"

    # Output shards are written next to the inputs, so a second run would
    # otherwise pick up (and simultaneously overwrite) its own previous
    # output.  Exclude anything that matches the output naming scheme.
    previous_outputs = set()
    for prefix in (train_path_prefix, val_path_prefix, test_path_prefix):
        previous_outputs.update(glob.glob(prefix + "_*.tar"))

    # glob order is arbitrary; sort so the same samples land in the same
    # split on every run.
    input_files = sorted(
        f for f in glob.glob(path + "/*.tar") if f not in previous_outputs
    )
    # NOTE(review): the three splits are disjoint only if the dataset yields
    # samples in the same order on every pass (i.e. no shard shuffling) --
    # confirm for the installed webdataset version.
    src = wds.WebDataset(input_files)

    def write_split(dataset, prefix, start, end):
        """Write samples ``[start, end)`` of *dataset* to ``<prefix>_<i>.tar``."""
        n_split = end - start
        # Ceiling division: an exact multiple of max_file_size (or an empty
        # split) must not produce a trailing empty shard.
        n_shards = -(-n_split // max_file_size)
        # A single pass over the requested range is shared by all shards, so
        # the source is not re-read from the beginning for every output file.
        samples = iter(dataset.slice(start, end))
        for i in range(n_shards):
            output_file = f"{prefix}_{i}.tar"
            print(f"Writing {output_file}")
            with wds.TarWriter(output_file) as dst:
                for sample in itertools.islice(samples, max_file_size):
                    new_sample = {
                        "__key__": sample["__key__"],
                        "jpg": sample["jpg"],
                        "label.cls": label,
                        "domain_label.cls": domain_label,
                    }
                    dst.write(new_sample)

    write_split(src, train_path_prefix, 0, n_train)
    write_split(src, val_path_prefix, n_train, n_train + n_val)
    write_split(
        src,
        test_path_prefix,
        n_train + n_val,
        n_train + n_val + n_test,
    )
def calculate_sizes(path):
    """Derive train/val/test sample counts from the stats files in *path*.

    Reads the ``"successes"`` field of every ``*_stats.json`` file directly
    inside ``path`` and splits their sum 80% / 10% / remainder.

    Returns:
        A ``(n_train, n_val, n_test)`` tuple whose entries add up to the
        summed ``"successes"`` value.
    """
    success_counts = []
    for stats_path in glob.glob(path + "/*_stats.json"):
        with open(stats_path) as handle:
            success_counts.append(json.load(handle)["successes"])

    n_total = sum(success_counts)
    train_count = int(n_total * 0.8)
    val_count = int(n_total * 0.1)
    # The test split takes whatever is left after truncation so that no
    # sample is dropped.
    return train_count, val_count, n_total - train_count - val_count
if __name__ == "__main__":
    # (dataset directory, domain label) pairs.  Keeping each label next to
    # its directory avoids two parallel lists that only line up by index.
    datasets = [
        ("./data/laion400m_data", 0),
        ("./data/genai-images/StableDiffusion", 1),
        ("./data/genai-images/midjourney", 4),
        ("./data/genai-images/dalle2", 2),
        ("./data/genai-images/dalle3", 3),
    ]

    # Compute every split size up front so a missing or malformed stats file
    # fails before any shard is written, then reuse those sizes below rather
    # than re-reading the stats files a second time.
    sizes = [calculate_sizes(p) for p, _ in datasets]

    for i, ((p, domain_label), split_sizes) in enumerate(zip(datasets, sizes)):
        print(f"{p}: {split_sizes}")
        # Only the first dataset in the list gets label 0; all others get 1.
        label = 0 if i == 0 else 1
        print(label, domain_label)
        split_dataset(p, *split_sizes, label, domain_label)