|
""" |
|
Generate prompts for evaluation |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
import yaml |
|
|
|
|
|
|
|
# Object class names, one per line, read from "object_names.txt" in the current
# working directory (the script must be run from the folder that contains it).
# Surrounding whitespace is stripped and blank lines (e.g. a doubled trailing
# newline) are skipped so an empty string can never be sampled as a class.
with open("object_names.txt", encoding="utf-8") as cls_file:
    classnames = [name for name in (line.strip() for line in cls_file) if name]
|
|
|
|
|
|
|
|
|
def with_article(name: str) -> str:
    """Prefix *name* with the indefinite article "a" or "an".

    The article is chosen only by whether the first letter is a vowel
    (case-insensitive); pronunciation-based exceptions such as "an hour" or
    "a unicorn" are not handled.

    Args:
        name: Noun phrase to prefix, e.g. "apple" or "red".

    Returns:
        ``"an <name>"`` when *name* starts with a vowel, else ``"a <name>"``.

    Raises:
        ValueError: If *name* is empty.
    """
    if not name:
        raise ValueError("with_article() requires a non-empty name")
    # Lower-case the first character so capitalised names ("Apple") get "an" too.
    article = "an" if name[0].lower() in "aeiou" else "a"
    return f"{article} {name}"
|
|
|
|
|
|
|
|
|
|
|
def make_plural(name: str) -> str:
    """Return a regular English plural of *name*.

    Rules applied (irregular nouns such as "person" or "sheep" are NOT handled,
    and names that are already plural are pluralised again):

    * sibilant endings "s", "x", "z", "ch", "sh" take "es"
      (bus -> buses, bench -> benches, toothbrush -> toothbrushes);
    * a consonant followed by "y" becomes "ies" (butterfly -> butterflies);
    * everything else takes "s" (dog -> dogs, toy -> toys).

    Args:
        name: Singular noun phrase; only its ending is inspected.

    Returns:
        The pluralised phrase.

    Raises:
        ValueError: If *name* is empty.
    """
    if not name:
        raise ValueError("make_plural() requires a non-empty name")
    lowered = name.lower()
    if lowered.endswith(("s", "x", "z", "ch", "sh")):
        return f"{name}es"
    # Consonant + "y" -> "ies"; vowel + "y" (toy, key) just takes "s".
    if (
        len(lowered) > 1
        and lowered[-1] == "y"
        and lowered[-2].isalpha()
        and lowered[-2] not in "aeiou"
    ):
        return f"{name[:-1]}ies"
    return f"{name}s"
|
|
|
|
|
|
|
|
|
|
|
def generate_single_object_sample(rng: np.random.Generator, size: "int | None" = None):
    """Create "a photo of a <class>" samples for distinct, randomly chosen classes.

    Args:
        rng: Random generator used to pick the classes.
        size: Number of samples to draw, without replacement. ``None`` (the
            default) draws one sample and returns it as a dict instead of a
            list. Values larger than ``len(classnames)`` are clamped (with a
            printed notice); ``0`` yields an empty list.

    Returns:
        A single sample dict when *size* is ``None``, otherwise a list of
        sample dicts with keys ``tag``, ``include`` and ``prompt``.
    """
    TAG = "single_object"
    # Resolve the default before any comparison: `None > int` raises TypeError.
    return_scalar = size is None
    if return_scalar:
        size = 1
    elif size > len(classnames):
        size = len(classnames)
        print(f"Not enough distinct classes, generating only {size} samples")
    idxs = rng.choice(len(classnames), size=size, replace=False)
    samples = [
        dict(
            tag=TAG,
            include=[{"class": classnames[idx], "count": 1}],
            prompt=f"a photo of {with_article(classnames[idx])}",
        )
        for idx in idxs
    ]
    if return_scalar:
        return samples[0]
    return samples
|
|
|
|
|
|
|
|
|
|
|
def generate_two_object_sample(rng: np.random.Generator):
    """Create one "a photo of a <A> and a <B>" sample with two distinct classes.

    Args:
        rng: Random generator used to draw the two class indices (without
            replacement, so A and B always differ).

    Returns:
        A sample dict tagged "two_object" that requires one instance of each class.
    """
    first, second = (
        classnames[i] for i in rng.choice(len(classnames), size=2, replace=False)
    )
    return {
        "tag": "two_object",
        "include": [
            {"class": first, "count": 1},
            {"class": second, "count": 1},
        ],
        "prompt": f"a photo of {with_article(first)} and {with_article(second)}",
    }
|
|
|
|
|
|
|
|
|
# English number words indexed by value; bounds the largest count a prompt can spell out.
numbers = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"]


def generate_counting_sample(rng: np.random.Generator, max_count=4):
    """Create one "a photo of <N> <class plural>" counting sample.

    Args:
        rng: Random generator used to pick the class and the count.
        max_count: Largest count to draw (inclusive); N is drawn uniformly
            from ``[2, max_count]``. Must lie in ``[2, len(numbers) - 1]`` so
            N can be written as a word.

    Returns:
        A sample dict tagged "counting" with N in ``include`` and N + 1 in
        ``exclude`` (presumably read by the evaluator as "no more than N";
        confirm against the evaluation script).

    Raises:
        ValueError: If *max_count* is outside ``[2, len(numbers) - 1]``.
    """
    # Validate before drawing so a bad argument neither consumes rng state nor
    # surfaces later as an IndexError on `numbers`.
    if not 2 <= max_count < len(numbers):
        raise ValueError(
            f"max_count must be between 2 and {len(numbers) - 1}, got {max_count}"
        )
    TAG = "counting"
    idx = rng.choice(len(classnames))
    # Cast to a plain int so the value serialises cleanly to JSON/YAML.
    num = int(rng.integers(2, max_count, endpoint=True))
    return dict(
        tag=TAG,
        include=[{"class": classnames[idx], "count": num}],
        exclude=[{"class": classnames[idx], "count": num + 1}],
        prompt=f"a photo of {numbers[num]} {make_plural(classnames[idx])}",
    )
|
|
|
|
|
|
|
|
|
# Colour vocabulary shared by the colour and colour-attribution tasks.
colors = [
    "red", "orange", "yellow", "green", "blue",
    "purple", "pink", "brown", "black", "white",
]


def generate_color_sample(rng: np.random.Generator):
    """Create one "a photo of a <color> <class>" sample.

    The class is drawn uniformly from every class except "person"; the colour
    is drawn uniformly from ``colors``.

    Args:
        rng: Random generator used to pick the class, then the colour.

    Returns:
        A sample dict tagged "colors" requiring one instance of the class in
        the chosen colour.
    """
    n_classes = len(classnames)
    # Draw an offset in 1..n_classes-1 and rotate it around the index of
    # "person"; offset 0 (which would be "person" itself) can never occur.
    offset = rng.choice(n_classes - 1) + 1
    target = classnames[(offset + classnames.index("person")) % n_classes]
    shade = colors[rng.choice(len(colors))]
    return {
        "tag": "colors",
        "include": [{"class": target, "count": 1, "color": shade}],
        "prompt": f"a photo of {with_article(shade)} {target}",
    }
|
|
|
|
|
|
|
|
|
# Spatial relations; each reads as "<subject> <relation> <reference>".
positions = ["left of", "right of", "above", "below"]


def generate_position_sample(rng: np.random.Generator):
    """Create one "a photo of a <A> <relation> a <B>" spatial-relation sample.

    Args:
        rng: Random generator used to pick two distinct classes, then the relation.

    Returns:
        A sample dict tagged "position". The reference object B is listed
        first in ``include``; the subject A carries ``"position": (relation, 0)``
        (the 0 presumably indexes the reference entry in ``include`` — confirm
        against the evaluator).
    """
    pick_a, pick_b = rng.choice(len(classnames), size=2, replace=False)
    relation = positions[rng.choice(len(positions))]
    subject, reference = classnames[pick_a], classnames[pick_b]
    return {
        "tag": "position",
        "include": [
            {"class": reference, "count": 1},
            {"class": subject, "count": 1, "position": (relation, 0)},
        ],
        "prompt": f"a photo of {with_article(subject)} {relation} {with_article(reference)}",
    }
|
|
|
|
|
|
|
|
|
|
|
def generate_color_attribution_sample(rng: np.random.Generator):
    """Create one "a photo of a <color A> <class A> and a <color B> <class B>" sample.

    Both classes are distinct and never "person"; both colours are distinct.

    Args:
        rng: Random generator used to pick the two classes, then the two colours.

    Returns:
        A sample dict tagged "color_attr" requiring one instance of each class
        in its assigned colour.
    """
    n_classes = len(classnames)
    # Offsets in 1..n_classes-1 rotated around "person" exclude that class.
    offsets = rng.choice(n_classes - 1, size=2, replace=False) + 1
    first_idx, second_idx = (offsets + classnames.index("person")) % n_classes
    shade_a, shade_b = (colors[i] for i in rng.choice(len(colors), size=2, replace=False))
    name_a, name_b = classnames[first_idx], classnames[second_idx]
    return {
        "tag": "color_attr",
        "include": [
            {"class": name_a, "count": 1, "color": shade_a},
            {"class": name_b, "count": 1, "color": shade_b},
        ],
        "prompt": f"a photo of {with_article(shade_a)} {name_a} and {with_article(shade_b)} {name_b}",
    }
|
|
|
|
|
|
|
|
|
|
|
def _dedupe_samples(samples):
    """Return *samples* without exact duplicates, keeping first occurrences in order."""
    unique_samples, seen = [], set()
    for sample in samples:
        # safe_dump sorts mapping keys, so equal samples serialise identically.
        key = yaml.safe_dump(sample)
        if key not in seen:
            seen.add(key)
            unique_samples.append(sample)
    return unique_samples


def _write_suite(samples, output_path):
    """Write prompts (one per line) and matching JSON-lines metadata into *output_path*."""
    # os.makedirs("") raises FileNotFoundError; an empty path means the
    # current directory, which already exists.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    prompts_file = os.path.join(output_path, "generation_prompts.txt")
    with open(prompts_file, "w", encoding="utf-8") as fp:
        for sample in samples:
            print(sample["prompt"], file=fp)
    metadata_file = os.path.join(output_path, "evaluation_metadata.jsonl")
    with open(metadata_file, "w", encoding="utf-8") as fp:
        for sample in samples:
            print(json.dumps(sample), file=fp)


def generate_suite(rng: np.random.Generator, n: int = 100, output_path: str = ""):
    """Generate every evaluation task, de-duplicate, and write the results to disk.

    Args:
        rng: Random generator shared by all tasks.
        n: Samples drawn per task (before de-duplication). The single-object
            task ignores *n* and emits one sample per class.
        output_path: Directory for ``generation_prompts.txt`` and
            ``evaluation_metadata.jsonl``; created if missing. The empty
            string writes into the current directory.
    """
    samples = list(generate_single_object_sample(rng, size=len(classnames)))
    # Order matters: all tasks share one rng, so reordering them would change
    # every prompt produced for a given seed.
    task_generators = (
        generate_two_object_sample,
        lambda r: generate_counting_sample(r, max_count=4),
        generate_color_sample,
        generate_position_sample,
        generate_color_attribution_sample,
    )
    for make_sample in task_generators:
        samples.extend(make_sample(rng) for _ in range(n))
    _write_suite(_dedupe_samples(samples), output_path)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, seed the generator, build the suite.
    cli = argparse.ArgumentParser()
    cli.add_argument("--seed", type=int, default=43, help="generation seed (default: 43)")
    cli.add_argument(
        "--num-prompts", "-n",
        type=int,
        default=100,
        help="number of prompts per task (default: 100)",
    )
    cli.add_argument(
        "--output-path", "-o",
        type=str,
        default="prompts",
        help="output folder for prompts and metadata (default: 'prompts/')",
    )
    opts = cli.parse_args()
    generate_suite(np.random.default_rng(opts.seed), opts.num_prompts, opts.output_path)
|
|