# model1/llava/eval/masp_eval/video_chair/info_extract_from_caption.py
# Author: Wangpeng An
# Uploaded via huggingface_hub (commit bbfa6f6, verified)
from operator import truediv
import os
import re
import json
import sys
import argparse
# from nltk.stem import *
# import nltk
import openai
from abc import ABC, abstractmethod
# from pattern3.en import singularize
# from nltk.stem import WordNetLemmatizer
# from call_dino_service import
from tqdm import tqdm
from llava.eval.masp_eval.utils import GPTAPIWrapper
import time
class BaseAPIWrapper(ABC):
    """Abstract interface for chat-completion API clients.

    Concrete wrappers (e.g. a GPT client) implement ``get_completion`` to send a
    prompt pair to the backing model and return its reply.
    """

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Return the model's completion for *user_prompt*.

        system_prompt, when given, primes the conversation before the user turn.
        """
        ...
class CHAIR():
    """Turns a free-form video caption into structured JSON info via two GPT calls.

    Round 1 fills the caption into the ``cap2info`` prompt and asks the model for
    a JSON description; round 2 feeds that JSON back through the ``refine_json``
    prompt (same conversation) to deduplicate/clean it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # SECURITY: hardcoded API credential committed to source — rotate this key
        # and load it from an environment variable or secret store instead.
        self.openai_obj = GPTAPIWrapper(ak="GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj")
        # Prompt templates; paths are relative to the repo root the script runs from.
        with open('llava/eval/masp_eval/video_chair/prompts/cap2info.txt', 'r') as file:
            self.cap_user_prompt = file.read()
        with open('llava/eval/masp_eval/video_chair/prompts/refine_json.txt', 'r') as file:
            self.cap_user_prompt_deduplicate = file.read()

    def cap2info_gpt4(self, cap):
        """Extract structured info from caption *cap*.

        Returns the parsed JSON object (dict/list), or None when neither a fenced
        ```json block nor a brace-delimited span in the reply parses as JSON.
        """
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        gpt_ret1, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt)
        user_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', gpt_ret1)
        gpt_ret2, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt, previous_msgs=msgs, last_answer=gpt_ret1)
        # Preferred path: the model wrapped its answer in a ```json fenced block.
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", gpt_ret2)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                # Fenced block exists but is malformed; surface it for debugging.
                print(match.group(1))
                return None
        # Fallback: take the outermost {...} span of the raw reply.
        start = gpt_ret2.find('{')
        end = gpt_ret2.rfind('}')
        if start == -1 or end < start:
            # No brace-delimited JSON at all (original code would have sliced garbage).
            print(gpt_ret1)
            print(gpt_ret2)
            return None
        try:
            return json.loads(gpt_ret2[start:end + 1])
        except json.JSONDecodeError:
            print(gpt_ret1)
            print(gpt_ret2)
            return None
def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially attach GPT-extracted info to every annotation.

    Loads the JSON list at *annotations_file*, runs ``evaluator.cap2info_gpt4``
    on each item's caption ('refine_caption' when *gt* is true, otherwise
    'masp_inference'), stores the result under 'cap_info', and returns the
    augmented list.
    """
    with open(annotations_file, 'r', encoding='utf-8') as fp:
        items = json.load(fp)
    processed = []
    for item in tqdm(items):
        caption = item['refine_caption'] if gt else item['masp_inference']
        item['cap_info'] = evaluator.cap2info_gpt4(caption)
        processed.append(item)
    return processed
from multiprocessing import Pool
# Module-level singleton: built at import time (reads prompt files, constructs the
# GPT client) so that Pool worker processes inherit a ready-to-use instance.
# NOTE(review): on spawn-based platforms each worker re-imports this module and
# re-runs this construction — confirm workers run where the prompt paths resolve.
evaluator = CHAIR()
# Function to process a single data item
def process_data(data, gt):
    """Worker: extract info for one annotation dict and store it under 'cap_info'.

    Reads 'refine_caption' when *gt* is true, 'masp_inference' otherwise, and
    returns the (mutated) dict so the parent process collects the result.
    """
    caption_key = 'refine_caption' if gt else 'masp_inference'
    data['cap_info'] = evaluator.cap2info_gpt4(data[caption_key])
    return data
def process_annotations(annotations_file, gt=False):
    """Run caption→info extraction over *annotations_file* with a process pool.

    Loads the JSON list of annotations, fans ``process_data`` out across 32
    worker processes, and returns the augmented items (in completion order,
    since ``imap_unordered`` is used).
    """
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)
    from functools import partial
    # Bind the gt flag so the pool only has to pass each annotation.
    worker = partial(process_data, gt=gt)
    # Context manager guarantees the pool is torn down even if a worker raises
    # (the original leaked the pool on the exception path).
    with Pool(processes=32) as pool:
        results = [
            item
            for item in tqdm(pool.imap_unordered(worker, annotations), total=len(annotations))
        ]
        pool.close()
        pool.join()
    return results
if __name__ == '__main__':
    def _str2bool(value):
        """argparse converter: map 'true'/'1'/'yes' (any case) to True, else False."""
        return str(value).strip().lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json')
    parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json')
    # BUG FIX: `type=bool` makes any non-empty value truthy, so `--gt False`
    # parsed as True; parse the string explicitly instead.
    parser.add_argument("--gt", type=_str2bool, default=False)
    args = parser.parse_args()

    post_anno = process_annotations(args.cap_file, args.gt)
    with open(args.output_file, "w") as file:
        json.dump(post_anno, file, indent=4)