Spaces:
Runtime error
Runtime error
""" | |
Evaluate generated images using Mask2Former (or other object detector model) | |
""" | |
import argparse | |
import json | |
import os | |
import re | |
import sys | |
import time | |
import warnings | |
from pathlib import Path | |
current_file_path = Path(__file__).resolve() | |
sys.path.insert(0, str(current_file_path.parent.parent.parent.parent.parent)) | |
warnings.filterwarnings("ignore") | |
import mmdet | |
import numpy as np | |
import open_clip | |
import pandas as pd | |
import torch | |
from clip_benchmark.metrics import zeroshot_classification as zsc | |
from mmdet.apis import inference_detector, init_detector | |
from PIL import Image, ImageOps | |
from tqdm import tqdm | |
zsc.tqdm = lambda it, *args, **kwargs: it | |
from tools.metrics.utils import tracker | |
# Select the compute device. This script requires CUDA, so fail fast with an
# explicit error (an ``assert`` would be stripped when Python runs with -O).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    raise RuntimeError("This evaluation script requires a CUDA-capable GPU, but torch.cuda.is_available() is False.")
def timed(fn):
    """Decorate ``fn`` so every call reports its elapsed time on stderr.

    The wrapped function's return value is passed through unchanged, and its
    metadata (``__name__``, ``__doc__``, ``__wrapped__``...) is preserved via
    ``functools.wraps`` so introspection and logging still see the original.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        startt = time.perf_counter()
        result = fn(*args, **kwargs)
        endt = time.perf_counter()
        print(f"Function {fn.__name__!r} executed in {endt - startt:.3f}s", file=sys.stderr)
        return result

    return wrapper
# Load models
def load_models(args):
    """Load the object detector, the CLIP model used for color checks, and the class names.

    Args:
        args: Parsed CLI namespace. Reads ``args.model_config`` (detector config
            path), ``args.model_path`` (directory containing ``<model>.pth``) and
            ``args.options`` (optional ``model`` and ``clip_model`` overrides).

    Returns:
        ``(object_detector, (clip_model, transform, tokenizer), classnames)``, where
        ``classnames`` are the stripped lines of ``object_names.txt`` located next
        to this file.
    """
    config_path = args.model_config
    detector_name = args.options.get("model", "mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco")
    ckpt_path = os.path.join(args.model_path, f"{detector_name}.pth")
    object_detector = init_detector(config_path, ckpt_path, device=DEVICE)

    clip_arch = args.options.get("clip_model", "ViT-L-14")
    clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained="openai", device=DEVICE)
    tokenizer = open_clip.get_tokenizer(clip_arch)

    # evaluate_image indexes the detector output by position in this list, so the
    # file's line order must match the detector's class order.
    names_path = os.path.join(os.path.dirname(__file__), "object_names.txt")
    with open(names_path, encoding="utf-8") as cls_file:
        classnames = [line.strip() for line in cls_file]
    return object_detector, (clip_model, transform, tokenizer), classnames
# Color vocabulary for CLIP zero-shot color classification. The classifier's
# argmax index is mapped back to a name through this list.
COLORS = [
    "red",
    "orange",
    "yellow",
    "green",
    "blue",
    "purple",
    "pink",
    "brown",
    "black",
    "white",
]
# Lazily-populated cache: object class name -> zero-shot color classifier.
COLOR_CLASSIFIERS = {}
# Evaluation parts
class ImageCrops(torch.utils.data.Dataset):
    """Dataset of per-object views of one image, used as CLIP input for color checks.

    Each item corresponds to one ``(box, mask)`` object. When the mask is present,
    pixels outside it are replaced by a background (a solid color, or the original
    image when ``bgcolor == "original"``). The view is then optionally cropped to
    ``box[:4]`` and passed through ``preprocess``.

    Args:
        image: PIL image the objects were detected in.
        objects: Sequence of ``(box, mask)`` pairs. ``mask`` may be ``None``;
            otherwise its shape must equal ``(height, width)`` of ``image``.
        bgcolor: Background color, or ``"original"``. ``None`` (default) falls
            back to the ``bgcolor`` CLI option (``"#999"`` if unset).
        crop: Whether to crop each view to its box. ``None`` (default) falls back
            to the ``crop`` CLI option being ``"1"`` (the default).
        preprocess: Callable applied to each PIL view. ``None`` (default) falls
            back to the module-level CLIP ``transform``.
    """

    def __init__(self, image: "Image.Image", objects, bgcolor=None, crop=None, preprocess=None):
        # The fallbacks read script-level globals that exist only when this file
        # runs as __main__; explicit arguments make the class usable without them.
        if bgcolor is None:
            bgcolor = args.options.get("bgcolor", "#999")
        if crop is None:
            crop = args.options.get("crop", "1") == "1"
        if preprocess is None:
            preprocess = transform
        self._image = image.convert("RGB")
        if bgcolor == "original":
            self._blank = self._image.copy()
        else:
            self._blank = Image.new("RGB", image.size, color=bgcolor)
        self._objects = objects
        self._crop = crop
        self._preprocess = preprocess

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, index):
        box, mask = self._objects[index]
        if mask is not None:
            # PIL sizes are (width, height); the mask is expected as (height, width).
            assert tuple(self._image.size[::-1]) == tuple(mask.shape), (index, self._image.size[::-1], mask.shape)
            image = Image.composite(self._image, self._blank, Image.fromarray(mask))
        else:
            image = self._image
        if self._crop:
            image = image.crop(box[:4])
        # Dummy target: color_classification discards the targets returned by
        # run_classification.
        return (self._preprocess(image), 0)
def color_classification(image, bboxes, classname):
    """Predict one color name from ``COLORS`` for each detected object of a class.

    A CLIP zero-shot classifier over ``COLORS`` is built for ``classname`` on first
    use and cached in ``COLOR_CLASSIFIERS``. Every ``(box, mask)`` pair in
    ``bboxes`` is turned into a CLIP input by ``ImageCrops`` and classified; the
    highest-scoring color is returned for each object, in input order.
    """
    if classname not in COLOR_CLASSIFIERS:
        prompt_templates = [
            f"a photo of a {{c}} {classname}",
            f"a photo of a {{c}}-colored {classname}",
            f"a photo of a {{c}} object",
        ]
        COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier(
            clip_model, tokenizer, COLORS, prompt_templates, DEVICE
        )
    crops = ImageCrops(image, bboxes)
    loader = torch.utils.data.DataLoader(crops, batch_size=16, num_workers=4)
    with torch.no_grad():
        scores, _ = zsc.run_classification(clip_model, COLOR_CLASSIFIERS[classname], loader, DEVICE)
    return [COLORS[best.item()] for best in scores.argmax(1)]
def compute_iou(box_a, box_b):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2, ...]`` boxes.

    Coordinates are treated as inclusive pixel indices (hence the ``+ 1`` on the
    width and height). Returns 0 when the union area is zero.
    """

    def _area(box):
        width = max(box[2] - box[0] + 1, 0)
        height = max(box[3] - box[1] + 1, 0)
        return width * height

    overlap = [
        max(box_a[0], box_b[0]),
        max(box_a[1], box_b[1]),
        min(box_a[2], box_b[2]),
        min(box_a[3], box_b[3]),
    ]
    inter = _area(overlap)
    union = _area(box_a) + _area(box_b) - inter
    if not union:
        return 0
    return inter / union
def relative_position(obj_a, obj_b, threshold=None):
    """Give position of A relative to B, factoring in object dimensions.

    Along each axis the center offset is shrunk by ``threshold * (size_a + size_b)``
    (and clamped at zero), so objects whose centers are close relative to their
    sizes report no relation on that axis.

    Args:
        obj_a, obj_b: ``(box, mask)`` pairs; only ``box[:4]`` (``x1, y1, x2, y2``)
            is used.
        threshold: Fraction of the summed object sizes discounted from the offset.
            ``None`` (default) uses the module-level ``POSITION_THRESHOLD``.

    Returns:
        A subset of ``{"left of", "right of", "above", "below"}``; empty when the
        objects are too close to call.
    """
    if threshold is None:
        threshold = POSITION_THRESHOLD
    # Shape (2 objects, 2 corners, 2 coordinates).
    boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
    center_a, center_b = boxes.mean(axis=-2)
    dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
    offset = center_a - center_b
    revised_offset = np.maximum(np.abs(offset) - threshold * (dim_a + dim_b), 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()
    # Normalize by the raw center-offset length; an axis yields a relation when
    # its normalized component exceeds 0.5 in magnitude.
    dx, dy = revised_offset / np.linalg.norm(offset)
    relations = set()
    if dx < -0.5:
        relations.add("left of")
    if dx > 0.5:
        relations.add("right of")
    if dy < -0.5:
        relations.add("above")
    if dy > 0.5:
        relations.add("below")
    return relations
def evaluate(image, objects, metadata):
    """
    Evaluate given image using detected objects on the global metadata specifications.

    Assumptions:
    * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
    * All clauses are independent, i.e., duplicating a clause has no effect on the correctness
    * CHANGED: Color and position will only be evaluated on the most confidently predicted objects;
      therefore, objects are expected to appear in sorted order

    Args:
        image: PIL image the objects were detected in (only used for color checks).
        objects: Mapping of class name to detected ``(box, mask)`` pairs, most
            confident first (as built by ``evaluate_image``).
        metadata: Prompt specification. ``include`` entries have ``class`` and
            ``count`` plus optional ``color`` and ``position`` (``[relation,
            index into the preceding include clauses]``); ``exclude`` entries have
            ``class`` and ``count`` (the image fails if at least ``count`` are found).

    Returns:
        ``(correct, reason)``: whether every clause passed, and the failure
        descriptions joined by newlines (empty string when correct).
    """
    correct = True
    reason = []
    # One entry per include clause: the matched objects, or None if it failed.
    matched_groups = []
    # Check for expected objects
    for req in metadata.get("include", []):
        classname = req["class"]
        matched = True
        found_objects = objects.get(classname, [])[: req["count"]]
        if len(found_objects) < req["count"]:
            correct = matched = False
            reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
        else:
            if "color" in req:
                # Color check
                colors = color_classification(image, found_objects, classname)
                if colors.count(req["color"]) < req["count"]:
                    correct = matched = False
                    reason.append(
                        f"expected {req['color']} {classname}>={req['count']}, found "
                        + f"{colors.count(req['color'])} {req['color']}; and "
                        + ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                    )
            if "position" in req and matched:
                # Relative position check
                expected_rel, target_group = req["position"]
                if matched_groups[target_group] is None:
                    correct = matched = False
                    reason.append(f"no target for {classname} to be {expected_rel}")
                else:
                    for obj in found_objects:
                        for target_obj in matched_groups[target_group]:
                            true_rels = relative_position(obj, target_obj)
                            if expected_rel not in true_rels:
                                correct = matched = False
                                # Sort the set so the message does not depend on
                                # hash-randomized iteration order.
                                reason.append(
                                    f"expected {classname} {expected_rel} target, found "
                                    + f"{' and '.join(sorted(true_rels))} target"
                                )
                                break
                        if not matched:
                            break
        matched_groups.append(found_objects if matched else None)
    # Check for non-expected objects
    for req in metadata.get("exclude", []):
        classname = req["class"]
        # Use .get for the message too: with count == 0 an undetected class passes
        # the >= test and indexing objects[classname] would raise KeyError.
        found_count = len(objects.get(classname, []))
        if found_count >= req["count"]:
            correct = False
            reason.append(f"expected {classname}<{req['count']}, found {found_count}")
    return correct, "\n".join(reason)
def evaluate_image(filepath, metadata):
    """Run the detector on one generated image and score it against its prompt metadata.

    Per class, detections are kept when their confidence exceeds
    ``COUNTING_THRESHOLD`` (for the "counting" tag) or ``THRESHOLD`` (otherwise),
    capped at ``MAX_OBJECTS`` and greedily de-duplicated by IoU-based NMS
    (disabled when ``NMS_THRESHOLD == 1``). Kept detections are stored in
    descending confidence order, which ``evaluate`` relies on.

    Returns:
        A JSON-serializable record with the filename, tag, prompt, correctness,
        failure reason, metadata and the kept boxes per class.
    """
    result = inference_detector(object_detector, filepath)
    # NOTE(review): assumes an mmdet result of (bbox_results, segm_results) for
    # mask models, or bbox_results alone; confirm against the installed mmdet.
    bbox = result[0] if isinstance(result, tuple) else result
    segm = result[1] if isinstance(result, tuple) and len(result) > 1 else None
    # Close the file handle promptly; exif_transpose returns a new image (a
    # transposed copy or a plain copy), so it stays usable after the file closes.
    with Image.open(filepath) as raw_image:
        image = ImageOps.exif_transpose(raw_image)
    detected = {}
    # Determine bounding boxes to keep
    confidence_threshold = THRESHOLD if metadata["tag"] != "counting" else COUNTING_THRESHOLD
    for index, classname in enumerate(classnames):
        class_boxes = bbox[index]
        # Column 4 holds the confidence score; order candidates most-confident first.
        ordering = np.argsort(class_boxes[:, 4])[::-1]
        ordering = ordering[class_boxes[ordering, 4] > confidence_threshold]  # Threshold
        ordering = ordering[:MAX_OBJECTS].tolist()  # Limit number of detected objects per class
        kept = []
        while ordering:
            max_obj = ordering.pop(0)
            kept.append((class_boxes[max_obj], None if segm is None else segm[index][max_obj]))
            # Greedy NMS: drop remaining candidates that overlap the kept box too much.
            ordering = [
                obj
                for obj in ordering
                if NMS_THRESHOLD == 1 or compute_iou(class_boxes[max_obj], class_boxes[obj]) < NMS_THRESHOLD
            ]
        if kept:
            detected[classname] = kept
    # Evaluate
    is_correct, reason = evaluate(image, detected, metadata)
    return {
        "filename": filepath,
        "tag": metadata["tag"],
        "prompt": metadata["prompt"],
        "correct": is_correct,
        "reason": reason,
        "metadata": json.dumps(metadata),
        "details": json.dumps({key: [box.tolist() for box, _ in value] for key, value in detected.items()}),
    }
def main(args):
    """Evaluate every generated sample of one experiment and cache the results.

    Layout read below: ``<img_path>/<exp_name>/<digits>/metadata.jsonl`` (a single
    JSON object) and ``<img_path>/<exp_name>/<digits>/samples/<digits>.png``.
    Results go to ``<img_path>/<exp_name>_geneval.jsonl``, which overrides
    ``args.outfile``. If that file already exists it is loaded instead of
    re-running detection.

    Returns:
        ``{args.exp_name: DataFrame}`` with one row per evaluated image.
    """
    full_results = []
    image_dir = str(os.path.join(args.img_path, args.exp_name))
    args.outfile = f"{image_dir}_geneval.jsonl"
    if os.path.exists(args.outfile):
        df = pd.read_json(args.outfile, orient="records", lines=True)
        return {args.exp_name: df}
    # os.listdir order is arbitrary; sort so the output file is reproducible.
    for subfolder in tqdm(sorted(os.listdir(image_dir)), f"Detecting on {args.gpu_id}"):
        folderpath = os.path.join(image_dir, subfolder)
        if not os.path.isdir(folderpath) or not subfolder.isdigit():
            continue
        with open(os.path.join(folderpath, "metadata.jsonl"), encoding="utf-8") as fp:
            metadata = json.load(fp)
        # Evaluate each image
        samples_dir = os.path.join(folderpath, "samples")
        for imagename in sorted(os.listdir(samples_dir)):
            imagepath = os.path.join(samples_dir, imagename)
            # fullmatch: re.match only anchors at the start and would accept
            # names such as "0.png.bak".
            if not os.path.isfile(imagepath) or not re.fullmatch(r"\d+\.png", imagename):
                continue
            full_results.append(evaluate_image(imagepath, metadata))
    # Save results
    if os.path.dirname(args.outfile):
        os.makedirs(os.path.dirname(args.outfile), exist_ok=True)
    with open(args.outfile, "w", encoding="utf-8") as fp:
        pd.DataFrame(full_results).to_json(fp, orient="records", lines=True)
    df = pd.read_json(args.outfile, orient="records", lines=True)
    return {args.exp_name: df}
def tracker_ori(df_dict, label=""):
    """Log GenEval summary metrics for each experiment to Weights & Biases.

    Legacy logger; its call in the ``__main__`` block is commented out in favor of
    ``tools.metrics.utils.tracker``. Reads the module-level ``args``
    (``report_to``, ``log_metric``, ``name``, ``tracker_project_name``); only
    ``report_to == "wandb"`` is supported. The overall score is logged against
    the step parsed from an ``...epoch<E>_step<S>...`` experiment name.
    Experiments whose name does not match only get their metrics table logged.
    """
    if args.report_to != "wandb":
        print(f"{args.report_to} is not supported")
        return
    import wandb

    wandb_name = f"[{args.log_metric}]_[{args.name}]"
    # wandb.init documents ``tags`` as a list of strings, so pass a list rather
    # than a bare string.
    wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags=["metrics"])
    run = wandb.run
    run.define_metric("custom_step")
    run.define_metric(f"GenEval_Overall_Score({label})", step_metric="custom_step")
    for exp_name, df in df_dict.items():
        # Fresh metrics table for this experiment.
        wandb_table = wandb.Table(columns=["Metric", "Value"])
        # Total images, total prompts, and the share of correct images / prompts
        # (a prompt counts as correct when any of its images is correct).
        total_images = len(df)
        total_prompts = len(df.groupby("metadata"))
        percentage_correct_images = df["correct"].mean()
        percentage_correct_prompts = df.groupby("metadata")["correct"].any().mean()
        wandb_table.add_data("Total images", total_images)
        wandb_table.add_data("Total prompts", total_prompts)
        wandb_table.add_data("% correct images", f"{percentage_correct_images:.2%}")
        wandb_table.add_data("% correct prompts", f"{percentage_correct_prompts:.2%}")
        task_scores = []
        for tag, task_df in df.groupby("tag", sort=False):
            task_score = task_df["correct"].mean()
            task_scores.append(task_score)
            task_summary = f"{task_score:.2%} ({task_df['correct'].sum()} / {len(task_df)})"
            print(f"{tag:<16} = {task_summary}")
            # Add the per-task score to the table.
            wandb_table.add_data(tag, task_summary)
        # Overall score: unweighted mean of the per-task scores.
        overall_score = np.mean(task_scores)
        print(f"Overall score (avg. over tasks): {overall_score:.5f}")
        # Extract the training step from the experiment name, if present.
        match = re.search(r".*epoch(\d+)_step(\d+).*", exp_name)
        if match:
            step = int(match.group(2))
            run.log({"custom_step": step, f"GenEval_Overall_Score({label})": overall_score})
        # Log the metrics table to wandb.
        run.log({"Metrics Table": wandb_table})
def log_results(df_dict):
    """Print a GenEval summary and per-task breakdown for each experiment.

    Args:
        df_dict: Mapping of experiment name to a results DataFrame with
            ``correct``, ``tag`` and ``metadata`` columns (as produced by ``main``).

    Returns:
        Mapping of every experiment name to its overall score, the unweighted mean
        of the per-tag accuracies. Empty when ``df_dict`` is empty.
    """
    scores = {}
    # Measure overall success
    for exp_name, df in df_dict.items():
        print("Summary")
        print("=======")
        print(f"Total images: {len(df)}")
        print(f"Total prompts: {len(df.groupby('metadata'))}")
        print(f"% correct images: {df['correct'].mean():.2%}")
        # A prompt counts as correct when any of its images is correct.
        print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}")
        print()
        # By group
        task_scores = []
        print("Task breakdown")
        print("==============")
        for tag, task_df in df.groupby("tag", sort=False):
            task_accuracy = task_df["correct"].mean()
            task_scores.append(task_accuracy)
            print(f"{tag:<16} = {task_accuracy:.2%} ({task_df['correct'].sum()} / {len(task_df)})")
        print()
        # Record every experiment (previously only the last one was returned).
        scores[exp_name] = np.mean(task_scores)
        print(f"Overall score (avg. over tasks): {scores[exp_name]:.5f}")
    return scores
def parse_args(argv=None):
    """Parse the command-line arguments of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (default) uses ``sys.argv[1:]``.

    Returns:
        The argparse namespace. ``options`` is converted from ``key=value``
        strings into a dict (split on the first ``=``). ``model_config`` defaults
        to the Mask2Former config path relative to the installed ``mmdet`` package.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", type=str, default=None)
    parser.add_argument("--exp_name", type=str, default="Sana")
    parser.add_argument("--outfile", type=str, default="results.jsonl")
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--gpu_id", type=int, default=0)
    # Other arguments
    parser.add_argument("--options", nargs="*", type=str, default=[])
    # wandb report
    parser.add_argument("--log_geneval", action="store_true")
    parser.add_argument("--log_metric", type=str, default="metric")
    parser.add_argument("--suffix_label", type=str, default="", help="used for clip_score online log")
    parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for GenEval online log")
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="t2i-evit-baseline",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--name",
        type=str,
        default="baseline",
        help=("Wandb Project Name"),
    )
    args = parser.parse_args(argv)
    # Convert "key=value" strings into a dict, reporting malformed entries as a
    # normal CLI usage error instead of an opaque ValueError from dict().
    options = {}
    for opt in args.options:
        key, sep, value = opt.partition("=")
        if not sep:
            parser.error(f"--options entries must be key=value, got {opt!r}")
        options[key] = value
    args.options = options
    if args.model_config is None:
        args.model_config = os.path.join(
            os.path.dirname(mmdet.__file__), "../configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
        )
    return args
if __name__ == "__main__":
    args = parse_args()
    object_detector, (clip_model, transform, tokenizer), classnames = load_models(args)
    # Module-level hyper-parameters read by evaluate_image / relative_position;
    # each can be overridden with --options key=value.
    THRESHOLD = float(args.options.get("threshold", 0.3))
    COUNTING_THRESHOLD = float(args.options.get("counting_threshold", 0.9))
    MAX_OBJECTS = int(args.options.get("max_objects", 16))
    NMS_THRESHOLD = float(args.options.get("max_overlap", 1.0))
    POSITION_THRESHOLD = float(args.options.get("position_threshold", 0.1))
    # Keep only the last path component as the experiment name. normpath drops
    # trailing separators first, so "runs/exp/" yields "exp" just like "runs/exp"
    # (the former basename-or-dirname fallback returned "runs/exp").
    args.exp_name = os.path.basename(os.path.normpath(args.exp_name))
    df_dict = main(args)
    geneval_result = log_results(df_dict)
    if args.log_geneval:
        tracker(args, geneval_result, args.suffix_label, pattern=args.tracker_pattern, metric="GenEval")