meg HF Staff commited on
Commit
66357ea
·
verified ·
1 Parent(s): 03e9604

Reading in "label_frequencies_full.csv"

Browse files
Files changed (1) hide show
  1. app.py +21 -11
app.py CHANGED
@@ -25,21 +25,30 @@ NUM_EXAMPLES = 1281167
25
  # Arbitrary small number of dataset examples to look at, only using in devv'ing.
26
  DEV = True
27
  DEV_AMOUNT = 10
 
 
28
  # Whether to read in the distribution over labels from an external text file.
29
  READ_DISTRO = False
30
  GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
 
31
 
32
- def get_label_fractions(dataset):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  print("Getting label proportions.")
34
- if READ_DISTRO:
35
- with open("label_distro.json", "r+") as f:
36
- label_counts = json.loads(f.read())
37
- else:
38
- label_counts = Counter([example['label'] for example in dataset])
39
- # Don't overrwrite the distribution when devving.
40
- if not DEV:
41
- with open("label_distro.json", "w+") as f:
42
- f.write(json.dumps(label_counts))
43
  label_list = list(label_counts.keys())
44
  denom = sum(label_counts.values())
45
  label_fractions = [label_counts[key]/denom for key in label_counts]
@@ -58,6 +67,7 @@ def randomize_labels(examples, indices, new_random_labels):
58
  examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
59
  return examples
60
 
 
61
  def main(percentage=10):
62
  global randomize_subset
63
  # Just for timing how long this takes.
@@ -77,7 +87,7 @@ def main(percentage=10):
77
  dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
78
  trust_remote_code=True, token=GATED_IMAGENET)
79
 
80
- label_list, label_fractions = get_label_fractions(dataset)
81
 
82
  # How many new random labels are we creating?
83
  num_new_labels = int(round(NUM_EXAMPLES/float(percentage)))
 
25
  # Arbitrary small number of dataset examples to look at, only using in devv'ing.
26
  DEV = True
27
  DEV_AMOUNT = 10
28
+ if DEV:
29
+ NUM_EXAMPLES = DEV_AMOUNT
30
  # Whether to read in the distribution over labels from an external text file.
31
  READ_DISTRO = False
32
  GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
33
+ LABELS_FILE = "label_frequencies_full.csv"
34
 
35
+
36
+ def read_label_frequencies():
37
+ label_counts_dict = {}
38
+ header_row = ['Label', 'Frequency']
39
+ with open(LABELS_FILE) as csvfile:
40
+ label_reader = csv.DictReader(csvfile)
41
+ assert label_reader.fieldnames == header_row
42
+ for row in label_reader:
43
+ assert row['Label'] not in label_counts_dict
44
+ label_counts_dict[row['Label']] = int(row['Frequency'])
45
+ # TODO: Can we just do this instead of the fractions? Do they really need to be normalized?
46
+ # label_list, label_counts = zip(*label_counts_dict.items())
47
+ return label_counts_dict
48
+
49
+
50
+ def get_label_fractions(label_counts_dict):
51
  print("Getting label proportions.")
 
 
 
 
 
 
 
 
 
52
  label_list = list(label_counts.keys())
53
  denom = sum(label_counts.values())
54
  label_fractions = [label_counts[key]/denom for key in label_counts]
 
67
  examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
68
  return examples
69
 
70
+
71
  def main(percentage=10):
72
  global randomize_subset
73
  # Just for timing how long this takes.
 
87
  dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
88
  trust_remote_code=True, token=GATED_IMAGENET)
89
 
90
+ label_list, label_fractions = get_label_fractions(read_label_frequencies())
91
 
92
  # How many new random labels are we creating?
93
  num_new_labels = int(round(NUM_EXAMPLES/float(percentage)))