|
from operator import truediv |
|
import os |
|
import re |
|
import json |
|
import sys |
|
import argparse |
|
|
|
|
|
import openai |
|
from abc import ABC, abstractmethod |
|
|
|
|
|
|
|
from tqdm import tqdm |
|
from llava.eval.masp_eval.utils import GPTAPIWrapper |
|
|
|
import time |
|
|
|
class BaseAPIWrapper(ABC):
    """Minimal interface every chat-completion API client must satisfy."""

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Send *user_prompt* (optionally preceded by *system_prompt*) and return the model's reply."""
|
|
|
class CHAIR():
    """Extract structured object/attribute info from video captions via GPT.

    Two-step pipeline per caption:
      1. the ``cap2info`` prompt turns the caption into a JSON description;
      2. the ``refine_json`` prompt deduplicates/cleans that JSON in the same
         conversation.
    The final answer is parsed out of the model's reply (fenced ```json block
    preferred, outermost braces as a fallback).
    """

    def __init__(self) -> None:
        super().__init__()
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # SECURITY: an API key was hard-coded here.  Prefer the MASP_GPT_AK
        # environment variable; the literal fallback is kept only for
        # backward compatibility and the committed key should be rotated.
        ak = os.environ.get('MASP_GPT_AK', 'GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj')
        self.openai_obj = GPTAPIWrapper(ak=ak)
        # Prompt templates contain the placeholders '/video caption/' and
        # '/json file/' that get substituted per request.
        with open('llava/eval/masp_eval/video_chair/prompts/cap2info.txt', 'r', encoding='utf-8') as file:
            self.cap_user_prompt = file.read()
        with open('llava/eval/masp_eval/video_chair/prompts/refine_json.txt', 'r', encoding='utf-8') as file:
            self.cap_user_prompt_deduplicate = file.read()

    def cap2info_gpt4(self, cap):
        """Return the parsed JSON info dict for caption *cap*, or None on parse failure.

        Failures are logged to stdout (the raw model replies) rather than
        raised, so callers can keep batch-processing.
        """
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        gpt_ret1, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt)
        user_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', gpt_ret1)
        gpt_ret2, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt, previous_msgs=msgs, last_answer=gpt_ret1)
        # Preferred shape: a fenced ```json ... ``` block in the second reply.
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", gpt_ret2)
        if match:
            try:
                return json.loads(match.group(1))
            except ValueError:  # json.JSONDecodeError
                print(match.group(1))
                return None
        # Fallback: parse from the first '{' to the last '}' of the raw reply.
        try:
            start = gpt_ret2.find('{')
            end = gpt_ret2.rfind('}')
            return json.loads(gpt_ret2[start:end + 1])
        except ValueError:  # includes the empty-slice case when no braces exist
            print(gpt_ret1)
            print(gpt_ret2)
            return None
|
|
|
|
|
def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially attach GPT-extracted 'cap_info' to every annotation.

    Loads *annotations_file* (a JSON list of dicts), picks the ground-truth
    ('refine_caption') or model ('masp_inference') caption per *gt*, runs it
    through *evaluator*, and returns the annotated records in input order.
    """
    with open(annotations_file, 'r', encoding='utf-8') as fh:
        records = json.load(fh)
    processed = []
    for record in tqdm(records):
        caption = record['refine_caption'] if gt else record['masp_inference']
        record['cap_info'] = evaluator.cap2info_gpt4(caption)
        processed.append(record)
    return processed
|
|
|
|
|
from multiprocessing import Pool


# Built once at module scope so that Pool workers (forked from this process)
# inherit the evaluator instead of each re-reading the prompt files.
# NOTE(review): this makes CHAIR() (prompt-file reads + API client setup) an
# import-time side effect of this module — confirm that is intended.
evaluator = CHAIR()
|
|
|
|
|
def process_data(data, gt):
    """Pool worker: attach GPT-extracted 'cap_info' to one annotation dict.

    Uses the module-level ``evaluator`` (inherited by forked workers).
    """
    key = 'refine_caption' if gt else 'masp_inference'
    data['cap_info'] = evaluator.cap2info_gpt4(data[key])
    return data
|
|
|
|
|
def process_annotations(annotations_file, gt=False):
    """Annotate every record in *annotations_file* with 'cap_info', in parallel.

    Fans the records out over a 32-worker process pool.  The returned list is
    ordered by completion (imap_unordered), not by input order.

    Fix: the pool is now managed by a ``with`` block so it is torn down even
    when loading or a worker raises (the original leaked workers on error).
    """
    from functools import partial

    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    worker = partial(process_data, gt=gt)
    res = []
    # All results are consumed inside the block, so terminate-on-exit is safe.
    with Pool(processes=32) as pool:
        for data in tqdm(pool.imap_unordered(worker, annotations), total=len(annotations)):
            res.append(data)
    return res
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json') |
|
parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json') |
|
parser.add_argument("--gt", type=bool, default=False) |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
post_anno = process_annotations(args.cap_file, args.gt) |
|
with open(f"{args.output_file}", "w") as file: |
|
json.dump(post_anno, file, indent=4) |
|
|
|
|