# select-subset / app.py
import os
import json
import csv
import gradio as gr
import random
import time
from collections import Counter
from numpy.random import choice
from datasets import load_dataset, Dataset
from PIL import PngImagePlugin, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
PngImagePlugin.MAX_TEXT_CHUNK = 1048576 * 10  # 10x PIL's default 1 MiB text-chunk limit.
"""
This code is designed to read in the ImageNet 1K ILSVRC dataset from the Hugging Face Hub,
then create a new version of this dataset with {percentage} lines with random labels based on the observed frequencies,
then upload this new version of the Hugging Face Hub, in the Data Composition organization:
https://huggingface.co/datasets/datacomp
"""
# The number of examples/instances in this dataset is copied from the model card:
# https://huggingface.co/datasets/ILSVRC/imagenet-1k
NUM_EXAMPLES = 1281167
DEV = False
# Arbitrary small number of dataset examples to look at, only used when developing.
DEV_AMOUNT = 10
if DEV:
    NUM_EXAMPLES = DEV_AMOUNT
# Whether to read in the distribution over labels from an external text file.
READ_DISTRO = False
GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
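# Note (assumption about deployment): ILSVRC/imagenet-1k is a gated dataset on the Hub,
# so GATED_IMAGENET is expected to hold a Hugging Face access token with permission to
# read it (e.g., set as a Space secret or exported in the environment before launching).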
LABELS_FILE = "label_frequencies_full.csv"
def read_label_frequencies():
    label_counts_dict = {}
    header_row = ['Label', 'Frequency']
    with open(LABELS_FILE) as csvfile:
        label_reader = csv.DictReader(csvfile)
        assert label_reader.fieldnames == header_row
        for row in label_reader:
            assert row['Label'] not in label_counts_dict
            label_counts_dict[row['Label']] = int(row['Frequency'])
    # TODO: Can we just do this instead of the fractions? Do they really need to be normalized?
    # label_list, label_counts = zip(*label_counts_dict.items())
    return label_counts_dict
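# For reference, read_label_frequencies() expects LABELS_FILE to be a two-column CSV
# with the exact header "Label,Frequency" and one row per label. A minimal illustrative
# sketch of the format (the values below are made up, not real ImageNet counts):
#
#   Label,Frequency
#   0,1300
#   1,1300
#   2,732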

def get_label_fractions(label_counts_dict):
    print("Getting label proportions.")
    label_list = list(label_counts_dict.keys())
    denom = sum(label_counts_dict.values())
    label_fractions = [label_counts_dict[key] / denom for key in label_counts_dict]
    return label_list, label_fractions
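# Illustrative usage sketch (not executed here): the returned fractions form a
# probability distribution over the labels, which is what numpy.random.choice's
# `p` argument needs in main() below.
#
#   label_list, label_fractions = get_label_fractions(read_label_frequencies())
#   assert abs(sum(label_fractions) - 1.0) < 1e-6
#   sampled = choice(a=label_list, size=5, p=label_fractions)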

def randomize_labels(examples, indices, new_random_labels):
    # What set of examples should be randomized in this batch?
    # This is the intersection of the batch indices and the indices we randomly selected to change the labels of.
    batch_subset = list(set(indices) & randomize_subset)
    # If this batch has indices that we're changing the label of...
    if batch_subset != []:
        # Replace the label with the next pre-sampled random label.
        for n in range(len(indices)):
            index = indices[n]
            examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
    return examples
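# Illustrative sketch of how this is invoked (values made up): datasets calls
# randomize_labels once per batch via .map(..., batched=True, with_indices=True),
# passing the batch dict plus the global indices of its examples, e.g.
#
#   batch = {"image": [img0, img1, img2], "label": [3, 7, 7]}
#   randomize_labels(batch, indices=[100, 101, 102], new_random_labels=[...])
#
# Only positions whose global index is in randomize_subset receive a new label;
# the rest keep their original label.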

def main(percentage=10):
    global randomize_subset
    # Just for timing how long this takes.
    start = time.time()
    percentage = float(percentage)
    print("Randomizing %d percent of the data." % percentage)
    # Set the random seed, based on the percentage, so that our random changes are reproducible.
    random.seed(percentage)
    # Load the dataset from the HF hub. Use streaming so as not to load the entire dataset at once.
    # Use the .take(DEV_AMOUNT) to only grab a small chunk of instances to develop with.
    if DEV:
        dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
                               trust_remote_code=True, token=GATED_IMAGENET).take(DEV_AMOUNT)
    else:
        dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
                               trust_remote_code=True, token=GATED_IMAGENET)
    label_list, label_fractions = get_label_fractions(read_label_frequencies())
    # How many new random labels are we creating? Convert the percentage into a
    # fraction of the dataset size.
    num_new_labels = int(round(NUM_EXAMPLES * percentage / 100.0))
    # Create a set of indices that are randomly chosen, to change their labels.
    # Specifically, randomly choose num_new_labels indices.
    randomize_subset = set(random.sample(range(0, NUM_EXAMPLES), num_new_labels))
    # Randomly choose what the new label values are, following the observed label frequencies.
    new_random_labels = list(choice(a=label_list, size=num_new_labels, p=label_fractions))
    # Update the dataset so that the labels are randomized.
    updated_dataset = dataset.map(randomize_labels, with_indices=True,
                                  fn_kwargs={'new_random_labels': new_random_labels},
                                  features=dataset.features, batched=True)
    # Upload the new version of the dataset (this will take a while).
    if DEV:
        Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
            "datacomp/imagenet-1k-random-debug" + str(DEV_AMOUNT) + "-" + str(percentage), token=GATED_IMAGENET)
    else:
        Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
            "datacomp/imagenet-1k-random" + str(percentage), token=GATED_IMAGENET)
    end = time.time()
    print("That took %d seconds" % (end - start))

demo = gr.Interface(fn=main, inputs="text", outputs="text")
demo.launch()
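# Usage sketch (assumption about typical use): launch the app and type the percentage
# of labels to randomize (e.g. "10") into the text box; the Gradio text input arrives
# as a string, which main() converts with float(percentage) before sampling new labels
# and pushing the modified dataset version to the Hub.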