meg HF Staff commited on
Commit
3d85131
·
verified ·
1 Parent(s): 3096e07

Adding new code to randomly sample from the observed distribution over labels.

Browse files
Files changed (1) hide show
  1. app.py +39 -9
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import os
 
 
2
  import gradio as gr
3
  import random
4
  import time
 
 
 
5
  from datasets import load_dataset, Dataset
 
6
  from PIL import PngImagePlugin, ImageFile
7
  ImageFile.LOAD_TRUNCATED_IMAGES = True
8
  PngImagePlugin.MAX_TEXT_CHUNK = 1048576 * 10 # this is 10x the amount.
@@ -16,9 +22,11 @@ https://huggingface.co/datasets/datacomp
16
  # The number of examples/instances in this dataset is copied from the model card:
17
  # https://huggingface.co/datasets/ILSVRC/imagenet-1k
18
  NUM_EXAMPLES = 1281167
19
- # Arbitrary small number, only using in devv'ing (uncomment #.take(DEV_AMOUNT) below to use it).
20
- DEV = False
21
  DEV_AMOUNT = 10
 
 
22
  GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
23
 
24
 
@@ -41,10 +49,17 @@ def main(percentage=10):
41
  dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
42
  trust_remote_code=True, token=GATED_IMAGENET)
43
 
 
 
 
 
 
44
  # Create a set of indices that are randomly chosen, to change their labels.
45
- # Specifically, randomly choose NUM_EXAMPLES/percentage indices.
46
- randomize_subset = set(random.sample(range(0, NUM_EXAMPLES), round(
47
- NUM_EXAMPLES / float(percentage))))
 
 
48
 
49
  # Update the dataset so that the labels are randomized
50
  updated_dataset = dataset.map(randomize_labels, with_indices=True,
@@ -53,7 +68,7 @@ def main(percentage=10):
53
  # Upload the new version of the dataset (this will take awhile)
54
  if DEV:
55
  Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
56
- "datacomp/imagenet-1k-random-debug" + str(percentage), token=GATED_IMAGENET)
57
  else:
58
  Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
59
  "datacomp/imagenet-1k-random" + str(percentage), token=GATED_IMAGENET)
@@ -63,6 +78,23 @@ def main(percentage=10):
63
  print("That took %d seconds" % (end - start))
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def randomize_labels(examples, indices):
67
  # What set of examples should be randomized in this batch?
68
  # This is the intersection of the batch indices and the indices we randomly selected to change the labels of.
@@ -72,9 +104,7 @@ def randomize_labels(examples, indices):
72
  # Change the label to a random integer between 0 and 9
73
  for n in range(len(indices)):
74
  index = indices[n]
75
- examples["label"][n] = random.randint(0,
76
- 9) if index in batch_subset else \
77
- examples["label"][n]
78
  return examples
79
 
80
  demo = gr.Interface(fn=main, inputs="text", outputs="text")
 
1
  import os
2
+ import json
3
+
4
  import gradio as gr
5
  import random
6
  import time
7
+
8
+ from collections import Counter
9
+ from numpy.random import choice
10
  from datasets import load_dataset, Dataset
11
+
12
  from PIL import PngImagePlugin, ImageFile
13
  ImageFile.LOAD_TRUNCATED_IMAGES = True
14
  PngImagePlugin.MAX_TEXT_CHUNK = 1048576 * 10 # this is 10x the amount.
 
22
  # The number of examples/instances in this dataset is copied from the model card:
23
  # https://huggingface.co/datasets/ILSVRC/imagenet-1k
24
  NUM_EXAMPLES = 1281167
25
+ # Arbitrary small number of dataset examples to look at, only using in devv'ing.
26
+ DEV = True
27
  DEV_AMOUNT = 10
28
+ # Whether to read in the distribution over labels from an external text file.
29
+ READ_DISTRO = False
30
  GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
31
 
32
 
 
49
  dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
50
  trust_remote_code=True, token=GATED_IMAGENET)
51
 
52
+ label_list, label_fractions = get_label_fractions()
53
+
54
+ # How many new random labels are we creating?
55
+ num_new_labels = int(round(NUM_EXAMPLES/float(percentage)))
56
+
57
  # Create a set of indices that are randomly chosen, to change their labels.
58
+ # Specifically, randomly choose num_new_labels indices.
59
+ randomize_subset = set(random.sample(range(0, NUM_EXAMPLES), num_new_labels))
60
+
61
+ # Randomly choose what the new label values are, following the observed label frequencies.
62
+ new_random_labels = list(choice(a=label_keys, size=num_new_labels, p=label_fractions))
63
 
64
  # Update the dataset so that the labels are randomized
65
  updated_dataset = dataset.map(randomize_labels, with_indices=True,
 
68
  # Upload the new version of the dataset (this will take awhile)
69
  if DEV:
70
  Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
71
+ "datacomp/imagenet-1k-random-debug" + str(DEV_AMOUNT) + "-" + str(percentage), token=GATED_IMAGENET)
72
  else:
73
  Dataset.from_generator(updated_dataset.__iter__).push_to_hub(
74
  "datacomp/imagenet-1k-random" + str(percentage), token=GATED_IMAGENET)
 
78
  print("That took %d seconds" % (end - start))
79
 
80
 
81
+ def get_label_fractions(dataset):
82
+ print("Getting label proportions.")
83
+ if READ_DISTRO:
84
+ with open("label_distro.json", "r+") as f:
85
+ label_counts = json.loads(f.read())
86
+ else:
87
+ label_counts = Counter([example['label'] for example in dataset])
88
+ # Don't overrwrite the distribution when devving.
89
+ if not DEV:
90
+ with open("label_distro.json", "w+") as f:
91
+ f.write(json.dumps(label_counts))
92
+ label_list = list(label_counts.keys())
93
+ denom = sum(label_counts.values())
94
+ label_fractions = [label_counts[key]/denom for key in label_keys]
95
+ return label_list, label_fractions
96
+
97
+
98
  def randomize_labels(examples, indices):
99
  # What set of examples should be randomized in this batch?
100
  # This is the intersection of the batch indices and the indices we randomly selected to change the labels of.
 
104
  # Change the label to a random integer between 0 and 9
105
  for n in range(len(indices)):
106
  index = indices[n]
107
+ examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
 
 
108
  return examples
109
 
110
  demo = gr.Interface(fn=main, inputs="text", outputs="text")