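# Evaluate MiniGPT-v2 referring-expression grounding on RefCOCO / RefCOCO+ / RefCOCOg.
# The model is prompted with "[refer] give me the location of <expression>" and is expected
# to answer with a bounding box of the form {<x1><y1><x2><y2>}; accuracy is the fraction of
# expressions whose predicted box reaches IoU > 0.5 with the ground truth.
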
import os
import re
import json
import argparse
from collections import defaultdict

import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from minigpt4.common.config import Config
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData

def list_of_str(arg):
    return list(map(str, arg.split(',')))

parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="comma-separated list of datasets to evaluate")
parser.add_argument("--res", type=float, default=100.0, help="coordinate resolution used for refcoco bounding boxes")
parser.add_argument("--resample", action='store_true', help="re-query examples whose answer is not a well-formed bounding box")
args = parser.parse_args()

cfg = Config(args)

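# Splits evaluated for each referring-expression benchmark.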
eval_dict = {'refcoco':  ['val', 'testA', 'testB'],
             'refcoco+': ['val', 'testA', 'testB'],
             'refcocog': ['val', 'test']}

model, vis_processor = init_model(args)
model.eval()

CONV_VISION = CONV_VISION_minigptv2
conv_temp = CONV_VISION.copy()
conv_temp.system = ""

save_path = cfg.run_cfg.save_path

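# Run grounding evaluation for every requested dataset and split.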
for dataset in args.dataset:
    for split in eval_dict[dataset]:
        eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
        img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
        batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
        max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]

        with open(os.path.join(eval_file_path, f"{dataset}/{dataset}_{split}.json"), 'r') as f:
            refcoco = json.load(f)

        data = RefCOCOEvalData(refcoco, vis_processor, img_path)
        eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
        minigpt4_predict = defaultdict(list)
        resamples = []

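        # First pass: ask the model for a bounding box for every referring expression.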
        for images, questions, img_ids in tqdm(eval_dataloader):
            texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
            answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
            for answer, img_id, question in zip(answers, img_ids, questions):
                answer = answer.replace("<unk>", "").replace(" ", "").strip()
                # a well-formed answer looks like {<x1><y1><x2><y2>} on the model's 0..res coordinate grid
                pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
                if re.match(pattern, answer):
                    minigpt4_predict[img_id].append(answer)
                else:
                    resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of', '').strip()]})

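        # Optional second pass: re-query the examples whose answer did not match the box format.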
        if args.resample:
            for i in range(20):
                data = RefCOCOEvalData(resamples, vis_processor, img_path)
                resamples = []
                eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
                for images, questions, img_ids in tqdm(eval_dataloader):
                    texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
                    answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
                    for answer, img_id, question in zip(answers, img_ids, questions):
                        answer = answer.replace("<unk>", "").replace(" ", "").strip()
                        pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
                        # keep a valid box, or give up and keep whatever was generated on the fifth attempt
                        if re.match(pattern, answer) or i == 4:
                            minigpt4_predict[img_id].append(answer)
                        else:
                            resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of', '').strip()]})

                if len(resamples) == 0:
                    break

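        # Save the raw predictions for this dataset/split.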
        file_save_path = os.path.join(save_path, f"{dataset}_{split}.json")
        with open(file_save_path, 'w') as f:
            json.dump(minigpt4_predict, f)

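        # Compute accuracy: fraction of expressions whose predicted box has IoU > 0.5 with the ground truth.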
        count = 0
        total = len(refcoco)
        res = args.res
        refcoco_dict = defaultdict()
        for item in refcoco:
            refcoco_dict[item['img_id']] = item

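        # Predicted boxes are on the 0..res grid; rescale them to pixel coordinates and compare
        # against the ground-truth box, converted from [x, y, w, h] to [x1, y1, x2, y2].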
        for img_id in refcoco_dict:
            item = refcoco_dict[img_id]
            bbox = item['bbox']
            outputs = minigpt4_predict[img_id]
            for output in outputs:
                try:
                    integers = re.findall(r'\d+', output)
                    pred_bbox = [int(num) for num in integers]
                    height = item['height']
                    width = item['width']
                    pred_bbox[0] = pred_bbox[0] / res * width
                    pred_bbox[1] = pred_bbox[1] / res * height
                    pred_bbox[2] = pred_bbox[2] / res * width
                    pred_bbox[3] = pred_bbox[3] / res * height

                    gt_bbox = [0, 0, 0, 0]
                    gt_bbox[0] = bbox[0]
                    gt_bbox[1] = bbox[1]
                    gt_bbox[2] = bbox[0] + bbox[2]
                    gt_bbox[3] = bbox[1] + bbox[3]

                    iou_score = computeIoU(pred_bbox, gt_bbox)
                    if iou_score > 0.5:
                        count += 1
                except Exception:
                    # skip answers that do not contain a parsable box
                    continue

        print(f'{dataset} {split}:', count / total * 100, flush=True)