Spaces:
Running
Running
Reading in "label_frequencies_full.csv"
Browse files
app.py
CHANGED
@@ -25,21 +25,30 @@ NUM_EXAMPLES = 1281167
|
|
25 |
# Arbitrary small number of dataset examples to look at; only used when developing.
|
26 |
DEV = True
|
27 |
DEV_AMOUNT = 10
|
|
|
|
|
28 |
# Whether to read in the distribution over labels from an external text file.
|
29 |
READ_DISTRO = False
|
30 |
GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
|
|
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
print("Getting label proportions.")
|
34 |
-
if READ_DISTRO:
|
35 |
-
with open("label_distro.json", "r+") as f:
|
36 |
-
label_counts = json.loads(f.read())
|
37 |
-
else:
|
38 |
-
label_counts = Counter([example['label'] for example in dataset])
|
39 |
-
# Don't overwrite the distribution when developing.
|
40 |
-
if not DEV:
|
41 |
-
with open("label_distro.json", "w+") as f:
|
42 |
-
f.write(json.dumps(label_counts))
|
43 |
label_list = list(label_counts.keys())
|
44 |
denom = sum(label_counts.values())
|
45 |
label_fractions = [label_counts[key]/denom for key in label_counts]
|
@@ -58,6 +67,7 @@ def randomize_labels(examples, indices, new_random_labels):
|
|
58 |
examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
|
59 |
return examples
|
60 |
|
|
|
61 |
def main(percentage=10):
|
62 |
global randomize_subset
|
63 |
# Just for timing how long this takes.
|
@@ -77,7 +87,7 @@ def main(percentage=10):
|
|
77 |
dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
|
78 |
trust_remote_code=True, token=GATED_IMAGENET)
|
79 |
|
80 |
-
label_list, label_fractions = get_label_fractions(
|
81 |
|
82 |
# How many new random labels are we creating?
|
83 |
num_new_labels = int(round(NUM_EXAMPLES/float(percentage)))
|
|
|
25 |
# Arbitrary small number of dataset examples to look at; only used when developing.
|
26 |
DEV = True
|
27 |
DEV_AMOUNT = 10
|
28 |
+
if DEV:
|
29 |
+
NUM_EXAMPLES = DEV_AMOUNT
|
30 |
# Whether to read in the distribution over labels from an external text file.
|
31 |
READ_DISTRO = False
|
32 |
GATED_IMAGENET = os.environ.get("GATED_IMAGENET")
|
33 |
+
LABELS_FILE = "label_frequencies_full.csv"
|
34 |
|
35 |
+
|
36 |
+
def read_label_frequencies(labels_file=None):
    """Read per-label example counts from a two-column CSV file.

    The file must begin with the header row ``Label,Frequency``; every
    following row maps one label to an integer count.

    Args:
        labels_file: Path of the CSV file to read. When omitted, the
            module-level ``LABELS_FILE`` path is used.

    Returns:
        A dict mapping each label (the string exactly as read from the file)
        to its count as an ``int``, in file order.

    Raises:
        ValueError: If the header row is not ``Label,Frequency``, if a label
            appears more than once, or if a frequency string cannot be
            parsed as an integer.
    """
    if labels_file is None:
        labels_file = LABELS_FILE
    header_row = ['Label', 'Frequency']
    label_counts_dict = {}
    # newline="" leaves newline handling to the csv module, as its docs recommend.
    with open(labels_file, newline="") as csvfile:
        label_reader = csv.DictReader(csvfile)
        # Raise instead of assert: assert statements are removed under
        # ``python -O``, which would silently disable these checks.
        if label_reader.fieldnames != header_row:
            raise ValueError(
                f"Expected header {header_row} in {labels_file!r}, "
                f"got {label_reader.fieldnames}"
            )
        for row in label_reader:
            label = row['Label']
            if label in label_counts_dict:
                raise ValueError(f"Duplicate label {label!r} in {labels_file!r}")
            label_counts_dict[label] = int(row['Frequency'])
    # TODO: Could the raw counts be returned as (labels, counts) tuples instead
    # of fractions? Do the weights really need to be normalized?
    # label_list, label_counts = zip(*label_counts_dict.items())
    return label_counts_dict
|
48 |
+
|
49 |
+
|
50 |
+
def get_label_fractions(label_counts_dict):
|
51 |
print("Getting label proportions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
label_list = list(label_counts.keys())
|
53 |
denom = sum(label_counts.values())
|
54 |
label_fractions = [label_counts[key]/denom for key in label_counts]
|
|
|
67 |
examples["label"][n] = new_random_labels.pop() if index in batch_subset else examples["label"][n]
|
68 |
return examples
|
69 |
|
70 |
+
|
71 |
def main(percentage=10):
|
72 |
global randomize_subset
|
73 |
# Just for timing how long this takes.
|
|
|
87 |
dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True,
|
88 |
trust_remote_code=True, token=GATED_IMAGENET)
|
89 |
|
90 |
+
label_list, label_fractions = get_label_fractions(read_label_frequencies())
|
91 |
|
92 |
# How many new random labels are we creating?
|
93 |
num_new_labels = int(round(NUM_EXAMPLES/float(percentage)))
|