File size: 3,596 Bytes
d643072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import ImageReward as RM
from tqdm import tqdm

from tools.metrics.utils import tracker


def parse_args(argv=None):
    """Parse the command-line options of the ImageReward evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse falls
            back to ``sys.argv[1:]``, so calling ``parse_args()`` with no
            arguments behaves exactly as before.

    Returns:
        argparse.Namespace: the parsed options (``json_path``, ``img_path``,
        ``exp_name``, ``txt_path``, ``sample_nums``, ``sample_per_prompt`` and
        the online-logging settings declared below).
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

    # Inputs and bookkeeping for the evaluation itself.
    parser.add_argument("--json_path", type=str, default="./benchmark-prompts-dict.json")
    parser.add_argument("--img_path", type=str, default=None)
    parser.add_argument("--exp_name", type=str, default="Sana")
    parser.add_argument("--txt_path", type=str, default=None)
    parser.add_argument("--sample_nums", type=int, default=100)
    parser.add_argument("--sample_per_prompt", default=10, type=int)

    # online logging setting
    parser.add_argument("--log_metric", type=str, default="metric")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--log_image_reward", action="store_true")
    parser.add_argument("--suffix_label", type=str, default="", help="used for image-reward online log")
    parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for image-reward online log")
    # NOTE(review): the help text below calls tensorboard the default, but the
    # actual argparse default is None -- confirm how the tracker treats None.
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="t2i-evit-baseline",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--name",
        type=str,
        default="baseline",
        help=("Wandb Project Name"),
    )
    return parser.parse_args(argv)


def main():
    """Compute the mean ImageReward score for one experiment and cache it on disk.

    Reads the module-level globals ``args``, ``sample_nums``, ``model`` and
    ``prompt_json`` that the ``__main__`` section sets up before calling this
    function.

    The result is cached in
    ``<txt_path>/<exp_name>_sample<sample_nums>_image_reward.txt`` where
    ``txt_path`` is ``args.txt_path`` or, when that is ``None``,
    ``args.img_path``. If that file already exists and its first line is
    non-empty, the cached value is returned without scoring anything.

    Otherwise, for every key ``k`` in ``prompt_json`` and every index ``i`` in
    ``range(args.sample_per_prompt)``, the image
    ``<img_path>/<exp_name>/<k>_<i>.jpg`` is scored against
    ``prompt_json[k]["prompt"]`` with ``model.score`` and the scores are averaged.

    Returns:
        dict: ``{args.exp_name: mean_score}`` with ``mean_score`` as a number.

    Raises:
        ValueError: if no output directory can be determined, if scoring is
            needed but ``args.img_path`` is missing, or if there is nothing
            to score (which would otherwise be a division by zero).
    """
    txt_path = args.txt_path if args.txt_path is not None else args.img_path
    if txt_path is None:
        # os.path.join(None, ...) would fail with an opaque TypeError.
        raise ValueError("Either --txt_path or --img_path must be provided.")
    # NOTE(review): sample_nums only appears in the cache file name; the loop
    # below scores every entry of prompt_json regardless of its value.
    save_txt_path = os.path.join(txt_path, f"{args.exp_name}_sample{sample_nums}_image_reward.txt")

    if os.path.exists(save_txt_path):
        with open(save_txt_path) as f:
            cached_value = f.readline().strip()
        # An empty cache file (e.g. left behind by an interrupted write) is
        # ignored and the score is recomputed instead of crashing.
        if cached_value:
            print(f"Image Reward {cached_value}: {args.exp_name}")
            return {args.exp_name: float(cached_value)}

    if args.img_path is None:
        raise ValueError("--img_path is required to score images.")

    # Loop-invariant: all images of this experiment live in one directory.
    image_dir = os.path.join(args.img_path, args.exp_name)

    total_scores = 0
    count = 0
    for k, v in tqdm(
        prompt_json.items(), desc=f"ImageReward {args.sample_per_prompt} images / prompt: {args.exp_name}"
    ):
        for i in range(args.sample_per_prompt):
            img_path = os.path.join(image_dir, f"{k}_{i}.jpg")
            total_scores += model.score(v["prompt"], img_path)
            count += 1

    if count == 0:
        raise ValueError(
            "No images were scored: the prompt file is empty or --sample_per_prompt is not positive."
        )

    image_reward_value = total_scores / count
    print(f"Image Reward {image_reward_value}: {args.exp_name}")
    with open(save_txt_path, "w") as file:
        file.write(str(image_reward_value))

    return {args.exp_name: image_reward_value}


if __name__ == "__main__":
    # Script entry point: parse options, load the reward model and the prompt
    # dictionary into module-level globals that main() reads, run the
    # evaluation, then optionally report the result through the tracker.
    args = parse_args()
    sample_nums = args.sample_nums

    model = RM.load("ImageReward-v1.0")
    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 (the JSON interchange encoding) rather than relying on
    # the platform's locale encoding.
    with open(args.json_path, encoding="utf-8") as json_file:
        prompt_json = json.load(json_file)
    print(args.img_path, args.exp_name)
    # Reduce exp_name to its last path component; when it ends with a path
    # separator the basename is empty, so fall back to the directory part.
    args.exp_name = os.path.basename(args.exp_name) or os.path.dirname(args.exp_name)

    image_reward_result = main()

    if args.log_image_reward:
        tracker(args, image_reward_result, args.suffix_label, pattern=args.tracker_pattern, metric="ImageReward")