File size: 3,838 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# compute chair for each video
import json
import collections 
import argparse 
from pathlib import Path

def eval_video_chair(file_name, metric):
    """Load per-video tag-match counts and compute one score list per bucket.

    Args:
        file_name: ``pathlib.Path`` to a JSON file mapping object_id ->
            {bucket: [matched_count, total_count], ...}.
        metric: either ``"coverage"`` or ``"hallucination"``.

    Returns:
        (items, scores) where ``items`` maps object_id -> row index, and
        ``scores`` maps each bucket name to a list of percentages aligned
        with those indices. A value of -100 marks N/A: the ground truth has
        no such tag (coverage) or the prediction has no such tag
        (hallucination).
    """
    with file_name.open("r") as json_file:
        data = json.load(json_file)

    buckets = ['subjects', 'attributes', 'activities', 'locations', 'text_overlays']
    items = {}
    scores = collections.defaultdict(list)
    for index, (object_id, tag_info) in enumerate(data.items()):
        items[object_id] = index
        for tag in buckets:
            if tag not in tag_info:
                # "-100" means gt has no such tag for coverage and pred has no
                # such tag for hallucination, leading to N/A value.
                scores[tag].append(-100)
                continue
            # Percentage of matched tags out of the total for this bucket.
            # NOTE(review): assumes tag_info[tag][1] is never 0 — confirm
            # upstream, otherwise this raises ZeroDivisionError (as the
            # original did).
            cvg = round(tag_info[tag][0] * 100 / tag_info[tag][1], 2)
            scores[tag].append(cvg if metric == "coverage" else round(100 - cvg, 2))
    return items, scores


def get_dict_val(inputs, items, key):
    """Return the ``key`` tag list from the cap_info of object id ``items``.

    ``inputs`` is a list of records shaped like {"object_id": ..,
    "cap_info": {..}}. Object ids are compared as strings. Returns [] when
    the id is not present or its cap_info lacks ``key``.
    """
    wanted = str(items)
    for record in inputs:
        if str(record["object_id"]) != wanted:
            continue
        return record["cap_info"].get(key, [])
    return []


def get_instance_result(pred_file, gt_file, coverage_file, hallucination_file, save_file):
    """Merge per-video coverage/hallucination scores with the predicted and
    ground-truth tags, producing one record per video.

    Args:
        pred_file: path to the prediction JSON (list of records with
            object_id, cap_info, masp_inference). Its parent directory is
            used as the working directory for the detail and output files.
        gt_file: path to the ground-truth JSON (records with object_id,
            cap_info, refine_caption).
        coverage_file: filename (relative to the pred dir) of the per-video
            coverage detail JSON.
        hallucination_file: filename of the per-video hallucination detail JSON.
        save_file: output filename (written into the pred dir).
    """
    buckets = ['subjects', 'attributes', 'activities', 'locations', 'text_overlays']
    # Use context managers so the file handles are closed deterministically.
    with open(pred_file, "r") as f:
        pred = json.load(f)
    with open(gt_file, "r") as f:
        gt = json.load(f)
    output_dir = Path(pred_file).parent

    items1, coverages = eval_video_chair(output_dir / coverage_file, "coverage")
    items2, hallucinations = eval_video_chair(output_dir / hallucination_file, "hallucination")

    gt_map = {str(item['object_id']): item for item in gt}
    pred_map = {str(item['object_id']): item for item in pred}

    out = []
    for obj_id, idx_1 in items1.items():
        # Only report videos that were scored by both metrics.
        if obj_id not in items2:
            continue
        idx_2 = items2[obj_id]
        # Per-video invariants, hoisted out of the bucket loop (the original
        # reassigned them on every iteration).
        # NOTE(review): a KeyError here means the detail files reference an
        # object_id absent from pred/gt — same failure mode as the original.
        res = {
            "object_id": obj_id,
            "masp_inference": pred_map[obj_id]['masp_inference'],
            "refine_caption": gt_map[obj_id]['refine_caption'],
        }
        for key in buckets:
            # -100 marks N/A (tag absent on the relevant side).
            res["coverage_" + key] = coverages[key][idx_1] if coverages[key][idx_1] != -100 else "N/A"
            res["hallucination_" + key] = hallucinations[key][idx_2] if hallucinations[key][idx_2] != -100 else "N/A"
            if key == "attributes":  # skip attributes which are combined in subjects
                continue
            res["pred_" + key] = get_dict_val(pred, obj_id, key)
            res["gt_" + key] = get_dict_val(gt, obj_id, key)
        out.append(res)

    with (output_dir / save_file).open("w") as json_data:
        json.dump(out, json_data, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Merge per-video CHAIR coverage/hallucination details into one record per video."
    )
    parser.add_argument("--pred_file", type=str,
                        default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_public800k_unfreeze_qformer/video_chair/video_chair_1k_res_info.json',
                        help="Prediction JSON; its parent dir holds the detail/output files.")
    parser.add_argument("--gt_file", type=str,
                        default='/mnt/bn/algo-masp-nas-2/kaili.zhao/data/masp_data/eval/eval_v1.0/eval_benchmark_pos_diverse_1k_11policies_gt.json',
                        help="Ground-truth JSON.")
    parser.add_argument("--coverage_file", type=str, default='each_video_coverage_detail.json',
                        help="Per-video coverage detail filename (relative to pred dir).")
    parser.add_argument("--hallucination_file", type=str, default='each_video_halluciantion_detail.json',
                        help="Per-video hallucination detail filename (relative to pred dir).")
    parser.add_argument("--save_file", type=str, default='video_chair_final.json',
                        help="Output filename (written into pred dir).")
    args = parser.parse_args()
    get_instance_result(args.pred_file, args.gt_file, args.coverage_file, args.hallucination_file, args.save_file)
    # Plain string: the original used an f-string with no placeholders (F541).
    print("===== Completed video chair for each individual computation! =====")