""" Evaluate generated images using Mask2Former (or other object detector model) """ import argparse import json import os import re import sys import time import warnings from pathlib import Path current_file_path = Path(__file__).resolve() sys.path.insert(0, str(current_file_path.parent.parent.parent.parent.parent)) warnings.filterwarnings("ignore") import mmdet import numpy as np import open_clip import pandas as pd import torch from clip_benchmark.metrics import zeroshot_classification as zsc from mmdet.apis import inference_detector, init_detector from PIL import Image, ImageOps from tqdm import tqdm zsc.tqdm = lambda it, *args, **kwargs: it from tools.metrics.utils import tracker # Get directory path DEVICE = "cuda" if torch.cuda.is_available() else "cpu" assert DEVICE == "cuda" def timed(fn): def wrapper(*args, **kwargs): startt = time.time() result = fn(*args, **kwargs) endt = time.time() print(f"Function {fn.__name__!r} executed in {endt - startt:.3f}s", file=sys.stderr) return result return wrapper # Load models @timed def load_models(args): CONFIG_PATH = args.model_config OBJECT_DETECTOR = args.options.get("model", "mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco") CKPT_PATH = os.path.join(args.model_path, f"{OBJECT_DETECTOR}.pth") object_detector = init_detector(CONFIG_PATH, CKPT_PATH, device=DEVICE) clip_arch = args.options.get("clip_model", "ViT-L-14") clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained="openai", device=DEVICE) tokenizer = open_clip.get_tokenizer(clip_arch) with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file: classnames = [line.strip() for line in cls_file] return object_detector, (clip_model, transform, tokenizer), classnames COLORS = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"] COLOR_CLASSIFIERS = {} # Evaluation parts class ImageCrops(torch.utils.data.Dataset): def __init__(self, image: Image.Image, objects): self._image = image.convert("RGB") bgcolor = args.options.get("bgcolor", "#999") if bgcolor == "original": self._blank = self._image.copy() else: self._blank = Image.new("RGB", image.size, color=bgcolor) self._objects = objects def __len__(self): return len(self._objects) def __getitem__(self, index): box, mask = self._objects[index] if mask is not None: assert tuple(self._image.size[::-1]) == tuple(mask.shape), (index, self._image.size[::-1], mask.shape) image = Image.composite(self._image, self._blank, Image.fromarray(mask)) else: image = self._image if args.options.get("crop", "1") == "1": image = image.crop(box[:4]) # if args.save: # base_count = len(os.listdir(args.save)) # image.save(os.path.join(args.save, f"cropped_{base_count:05}.png")) return (transform(image), 0) def color_classification(image, bboxes, classname): if classname not in COLOR_CLASSIFIERS: COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier( clip_model, tokenizer, COLORS, [ f"a photo of a {{c}} {classname}", f"a photo of a {{c}}-colored {classname}", f"a photo of a {{c}} object", ], DEVICE, ) clf = COLOR_CLASSIFIERS[classname] dataloader = torch.utils.data.DataLoader(ImageCrops(image, bboxes), batch_size=16, num_workers=4) with torch.no_grad(): pred, _ = zsc.run_classification(clip_model, clf, dataloader, DEVICE) return [COLORS[index.item()] for index in pred.argmax(1)] def compute_iou(box_a, box_b): area_fn = lambda box: max(box[2] - box[0] + 1, 0) * max(box[3] - box[1] + 1, 0) i_area = area_fn( [max(box_a[0], box_b[0]), max(box_a[1], box_b[1]), min(box_a[2], 


def relative_position(obj_a, obj_b):
    """Give position of A relative to B, factoring in object dimensions"""
    boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
    center_a, center_b = boxes.mean(axis=-2)
    dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
    offset = center_a - center_b
    #
    revised_offset = np.maximum(np.abs(offset) - POSITION_THRESHOLD * (dim_a + dim_b), 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()
    #
    dx, dy = revised_offset / np.linalg.norm(offset)
    relations = set()
    if dx < -0.5:
        relations.add("left of")
    if dx > 0.5:
        relations.add("right of")
    if dy < -0.5:
        relations.add("above")
    if dy > 0.5:
        relations.add("below")
    return relations


def evaluate(image, objects, metadata):
    """
    Evaluate given image using detected objects on the global metadata specifications.
    Assumptions:
    * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
    * All clauses are independent, i.e., duplicating a clause has no effect on the correctness
    * CHANGED: Color and position will only be evaluated on the most confidently predicted objects;
      therefore, objects are expected to appear in sorted order
    """
    correct = True
    reason = []
    matched_groups = []
    # Check for expected objects
    for req in metadata.get("include", []):
        classname = req["class"]
        matched = True
        found_objects = objects.get(classname, [])[: req["count"]]
        if len(found_objects) < req["count"]:
            correct = matched = False
            reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
        else:
            if "color" in req:
                # Color check
                colors = color_classification(image, found_objects, classname)
                if colors.count(req["color"]) < req["count"]:
                    correct = matched = False
                    reason.append(
                        f"expected {req['color']} {classname}>={req['count']}, found "
                        + f"{colors.count(req['color'])} {req['color']}; and "
                        + ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                    )
            if "position" in req and matched:
                # Relative position check
                expected_rel, target_group = req["position"]
                if matched_groups[target_group] is None:
                    correct = matched = False
                    reason.append(f"no target for {classname} to be {expected_rel}")
                else:
                    for obj in found_objects:
                        for target_obj in matched_groups[target_group]:
                            true_rels = relative_position(obj, target_obj)
                            if expected_rel not in true_rels:
                                correct = matched = False
                                reason.append(
                                    f"expected {classname} {expected_rel} target, found "
                                    + f"{' and '.join(true_rels)} target"
                                )
                                break
                        if not matched:
                            break
        if matched:
            matched_groups.append(found_objects)
        else:
            matched_groups.append(None)
    # Check for non-expected objects
    for req in metadata.get("exclude", []):
        classname = req["class"]
        if len(objects.get(classname, [])) >= req["count"]:
            correct = False
            reason.append(f"expected {classname}<{req['count']}, found {len(objects[classname])}")
    return correct, "\n".join(reason)


def evaluate_image(filepath, metadata):
    result = inference_detector(object_detector, filepath)
    bbox = result[0] if isinstance(result, tuple) else result
    segm = result[1] if isinstance(result, tuple) and len(result) > 1 else None
    image = ImageOps.exif_transpose(Image.open(filepath))
    detected = {}
    # Determine bounding boxes to keep
    confidence_threshold = THRESHOLD if metadata["tag"] != "counting" else COUNTING_THRESHOLD
    for index, classname in enumerate(classnames):
        ordering = np.argsort(bbox[index][:, 4])[::-1]
        ordering = ordering[bbox[index][ordering, 4] > confidence_threshold]  # Threshold
        ordering = ordering[:MAX_OBJECTS].tolist()  # Limit number of detected objects per class
        detected[classname] = []
        while ordering:
            max_obj = ordering.pop(0)
            detected[classname].append((bbox[index][max_obj], None if segm is None else segm[index][max_obj]))
            ordering = [
                obj
                for obj in ordering
                if NMS_THRESHOLD == 1 or compute_iou(bbox[index][max_obj], bbox[index][obj]) < NMS_THRESHOLD
            ]
        if not detected[classname]:
            del detected[classname]
    # Evaluate
    is_correct, reason = evaluate(image, detected, metadata)
    return {
        "filename": filepath,
        "tag": metadata["tag"],
        "prompt": metadata["prompt"],
        "correct": is_correct,
        "reason": reason,
        "metadata": json.dumps(metadata),
        "details": json.dumps({key: [box.tolist() for box, _ in value] for key, value in detected.items()}),
    }
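
# Illustrative shape of the per-prompt metadata consumed above, inferred from the
# fields accessed in evaluate()/evaluate_image(); the concrete values are made up:
#   {
#       "tag": "color_attr",
#       "prompt": "a photo of a red car and a blue bench",
#       "include": [
#           {"class": "car", "count": 1, "color": "red"},
#           {"class": "bench", "count": 1, "color": "blue", "position": ["left of", 0]},
#       ],
#       "exclude": [{"class": "person", "count": 1}],
#   }
# A "position" entry pairs a relation with the index of an earlier "include" group;
# an "exclude" entry fails the image once at least `count` such objects are detected.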


def main(args):
    full_results = []
    image_dir = str(os.path.join(args.img_path, args.exp_name))
    args.outfile = f"{image_dir}_geneval.jsonl"
    if os.path.exists(args.outfile):
        df = pd.read_json(args.outfile, orient="records", lines=True)
        return {args.exp_name: df}

    for subfolder in tqdm(os.listdir(image_dir), f"Detecting on {args.gpu_id}"):
        folderpath = os.path.join(image_dir, subfolder)
        if not os.path.isdir(folderpath) or not subfolder.isdigit():
            continue
        with open(os.path.join(folderpath, "metadata.jsonl")) as fp:
            metadata = json.load(fp)
        # Evaluate each image
        for imagename in os.listdir(os.path.join(folderpath, "samples")):
            imagepath = os.path.join(folderpath, "samples", imagename)
            if not os.path.isfile(imagepath) or not re.match(r"\d+\.png", imagename):
                continue
            result = evaluate_image(imagepath, metadata)
            full_results.append(result)

    # Save results
    if os.path.dirname(args.outfile):
        os.makedirs(os.path.dirname(args.outfile), exist_ok=True)
    with open(args.outfile, "w") as fp:
        pd.DataFrame(full_results).to_json(fp, orient="records", lines=True)

    df = pd.read_json(args.outfile, orient="records", lines=True)
    return {args.exp_name: df}


def tracker_ori(df_dict, label=""):
    if args.report_to == "wandb":
        import wandb

        wandb_name = f"[{args.log_metric}]_[{args.name}]"
        wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags="metrics")
        run = wandb.run
        run.define_metric("custom_step")
        run.define_metric(f"GenEval_Overall_Score({label})", step_metric="custom_step")

        for exp_name, df in df_dict.items():
            steps = []
            # Initialize the wandb table inside the function
            wandb_table = wandb.Table(columns=["Metric", "Value"])

            # Compute total images, total prompts, % correct images, and % correct prompts
            total_images = len(df)
            total_prompts = len(df.groupby("metadata"))
            percentage_correct_images = df["correct"].mean()
            percentage_correct_prompts = df.groupby("metadata")["correct"].any().mean()

            wandb_table.add_data("Total images", total_images)
            wandb_table.add_data("Total prompts", total_prompts)
            wandb_table.add_data("% correct images", f"{percentage_correct_images:.2%}")
            wandb_table.add_data("% correct prompts", f"{percentage_correct_prompts:.2%}")

            task_scores = []
            for tag, task_df in df.groupby("tag", sort=False):
                task_score = task_df["correct"].mean()
                task_scores.append(task_score)
                task_result = f"{tag:<16} = {task_score:.2%} ({task_df['correct'].sum()} / {len(task_df)})"
                print(task_result)
                # Add the task score to the table
                wandb_table.add_data(tag, f"{task_score:.2%} ({task_df['correct'].sum()} / {len(task_df)})")

            # Compute the overall score
            overall_score = np.mean(task_scores)
            print(f"Overall score (avg. over tasks): {overall_score:.5f}")

            # Parse the training step out of exp_name
            match = re.search(r".*epoch(\d+)_step(\d+).*", exp_name)
            if match:
                epoch_name, step_name = match.groups()
                step = int(step_name)
                steps.append(step)
                # Log each step with its overall score
                run.log({"custom_step": step, f"GenEval_Overall_Score({label})": overall_score})

            # Log the table to wandb
            run.log({"Metrics Table": wandb_table})
    else:
        print(f"{args.report_to} is not supported")
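
# Note: tracker_ori() derives the wandb "custom_step" from exp_name via the regex
# r".*epoch(\d+)_step(\d+).*", so an experiment named e.g. "sana_epoch3_step15000"
# (an illustrative name) is logged at custom_step=15000; names that do not match
# skip the per-step score logging.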


def log_results(df_dict):
    # Measure overall success
    for exp_name, df in df_dict.items():
        print("Summary")
        print("=======")
        print(f"Total images: {len(df)}")
        print(f"Total prompts: {len(df.groupby('metadata'))}")
        print(f"% correct images: {df['correct'].mean():.2%}")
        print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}")
        print()
        # By group
        task_scores = []
        print("Task breakdown")
        print("==============")
        for tag, task_df in df.groupby("tag", sort=False):
            task_scores.append(task_df["correct"].mean())
            print(f"{tag:<16} = {task_df['correct'].mean():.2%} ({task_df['correct'].sum()} / {len(task_df)})")
        print()
        print(f"Overall score (avg. over tasks): {np.mean(task_scores):.5f}")

    return {exp_name: np.mean(task_scores)}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", type=str, default=None)
    parser.add_argument("--exp_name", type=str, default="Sana")
    parser.add_argument("--outfile", type=str, default="results.jsonl")
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--gpu_id", type=int, default=0)
    # Other arguments
    parser.add_argument("--options", nargs="*", type=str, default=[])
    # wandb report
    parser.add_argument("--log_geneval", action="store_true")
    parser.add_argument("--log_metric", type=str, default="metric")
    parser.add_argument("--suffix_label", type=str, default="", help="used for clip_score online log")
    parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for GenEval online log")
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="t2i-evit-baseline",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers; for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--name",
        type=str,
        default="baseline",
        help="wandb run name",
    )
    args = parser.parse_args()

    args.options = dict(opt.split("=", 1) for opt in args.options)
    if args.model_config is None:
        args.model_config = os.path.join(
            os.path.dirname(mmdet.__file__), "../configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
        )
    return args
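
# Example invocation (the script name, paths, and experiment name are placeholders):
#   python evaluate_geneval.py \
#       --img_path output/geneval_samples --exp_name sana_epoch3_step15000 \
#       --model-path checkpoints/mask2former \
#       --options threshold=0.3 max_objects=16 \
#       --report_to wandb --log_geneval
# Recognized --options keys (parsed as key=value pairs in __main__, load_models, and
# ImageCrops): model, clip_model, bgcolor, crop, threshold, counting_threshold,
# max_objects, max_overlap, position_threshold.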


if __name__ == "__main__":
    args = parse_args()

    object_detector, (clip_model, transform, tokenizer), classnames = load_models(args)

    THRESHOLD = float(args.options.get("threshold", 0.3))
    COUNTING_THRESHOLD = float(args.options.get("counting_threshold", 0.9))
    MAX_OBJECTS = int(args.options.get("max_objects", 16))
    NMS_THRESHOLD = float(args.options.get("max_overlap", 1.0))
    POSITION_THRESHOLD = float(args.options.get("position_threshold", 0.1))

    args.exp_name = os.path.basename(args.exp_name) or os.path.dirname(args.exp_name)
    df_dict = main(args)
    geneval_result = log_results(df_dict)
    if args.log_geneval:
        # tracker_ori(df_dict, args.suffix_label)
        tracker(args, geneval_result, args.suffix_label, pattern=args.tracker_pattern, metric="GenEval")
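
# Results are written to "{img_path}/{exp_name}_geneval.jsonl", one JSON record per
# image (filename, tag, prompt, correct, reason, metadata, details); if that file
# already exists, main() reuses it instead of re-running detection.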