# File size: 5,385 Bytes (extraction artifact)
# commit: bbfa6f6 (extraction artifact)
from operator import truediv
import os
import re
import json
import sys
import argparse
# from nltk.stem import *
# import nltk
import openai
from abc import ABC, abstractmethod
# from pattern3.en import singularize
# from nltk.stem import WordNetLemmatizer
# from call_dino_service import
from tqdm import tqdm
from llava.eval.masp_eval.utils import GPTAPIWrapper
import time
class BaseAPIWrapper(ABC):
    """Abstract interface for chat-completion API clients."""

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Return a completion for ``user_prompt``, optionally conditioned on ``system_prompt``."""
        ...
class CHAIR():
    """Caption -> structured-info extractor backed by GPT-4.

    Runs a two-step prompt chain per caption:
      1. the ``cap2info.txt`` prompt turns a video caption into a JSON info blob;
      2. the ``refine_json.txt`` prompt asks the model to deduplicate/refine it.
    The refined reply is then parsed into a Python dict.
    """

    # Prompt template locations (relative to the repo root — TODO confirm cwd assumption).
    _CAP2INFO_PROMPT = 'llava/eval/masp_eval/video_chair/prompts/cap2info.txt'
    _REFINE_PROMPT = 'llava/eval/masp_eval/video_chair/prompts/refine_json.txt'

    def __init__(self, ak=None) -> None:
        """Build the GPT client and load both prompt templates from disk.

        Args:
            ak: API key for ``GPTAPIWrapper``. When omitted, the ``OPENAI_AK``
                environment variable is used, falling back to the previously
                hard-coded key for backward compatibility.
        """
        super().__init__()
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # SECURITY(review): an API key was committed in this file. Prefer the
        # OPENAI_AK environment variable; the literal fallback is kept only so
        # existing callers keep working, and the key should be rotated.
        if ak is None:
            ak = os.environ.get('OPENAI_AK', "GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj")
        self.openai_obj = GPTAPIWrapper(ak=ak)
        with open(self._CAP2INFO_PROMPT, 'r') as file:
            self.cap_user_prompt = file.read()
        with open(self._REFINE_PROMPT, 'r') as file:
            self.cap_user_prompt_deduplicate = file.read()

    def cap2info_gpt4(self, cap):
        """Convert caption ``cap`` into a structured-info dict.

        Returns the parsed dict, or ``None`` when the model reply cannot be
        parsed as JSON (the raw replies are printed for debugging in that case).
        """
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        gpt_ret1, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt)
        user_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', gpt_ret1)
        gpt_ret2, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt, previous_msgs=msgs, last_answer=gpt_ret1)
        # Prefer an explicit ```json fenced block; otherwise fall back to the
        # outermost {...} span of the reply. (The original duplicated the
        # parse/except logic across both branches.)
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", gpt_ret2)
        if match:
            candidate = match.group(1)
        else:
            start = gpt_ret2.find('{')
            end = gpt_ret2.rfind('}')
            candidate = gpt_ret2[start:end + 1]
        try:
            return json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            print(gpt_ret1)
            print(gpt_ret2)
            return None
def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially run caption->info extraction over an annotations JSON file.

    Reads the caption from 'refine_caption' when ``gt`` is true, otherwise
    from 'masp_inference', attaches the extracted info under 'cap_info', and
    returns the augmented annotation dicts.
    """
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)
    processed = []
    for entry in tqdm(annotations):
        caption = entry['refine_caption'] if gt else entry['masp_inference']
        entry['cap_info'] = evaluator.cap2info_gpt4(caption)
        processed.append(entry)
    return processed
from multiprocessing import Pool
# NOTE(review): constructing CHAIR at import time reads two prompt files and
# builds the GPT client as a module-level side effect; pool workers rely on
# inheriting this global (presumably via fork) — confirm on non-fork platforms.
evaluator = CHAIR()
def process_data(data, gt):
    """Worker: extract structured info for one annotation dict.

    Uses the module-level ``evaluator``; picks the caption field by ``gt``
    ('refine_caption' for ground truth, 'masp_inference' otherwise), stores
    the result under 'cap_info', and returns the mutated dict.
    """
    caption = data['refine_caption'] if gt else data['masp_inference']
    data['cap_info'] = evaluator.cap2info_gpt4(caption)
    return data
def process_annotations(annotations_file, gt=False, num_workers=32):
    """Run caption->info extraction over an annotations file in parallel.

    Args:
        annotations_file: path to a JSON list of annotation dicts.
        gt: when True read 'refine_caption', otherwise 'masp_inference'.
        num_workers: worker-process count (default keeps the original 32).

    Returns:
        List of annotation dicts augmented with 'cap_info'. Order is NOT
        preserved (``imap_unordered``).
    """
    from functools import partial

    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    worker = partial(process_data, gt=gt)
    results = []
    # Context manager guarantees pool teardown even if a worker raises; the
    # original leaked the pool (close/join were skipped) on any exception.
    with Pool(processes=num_workers) as pool:
        for item in tqdm(pool.imap_unordered(worker, annotations), total=len(annotations)):
            results.append(item)
    return results
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json')
    parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json')
    # BUG FIX: argparse's type=bool treats ANY non-empty string as True, so
    # "--gt False" used to enable ground-truth mode. Parse truthy spellings
    # explicitly; the default (False) and "--gt True" usage are unchanged.
    parser.add_argument("--gt", type=lambda s: s.strip().lower() in ('true', '1', 'yes'), default=False)
    args = parser.parse_args()
    post_anno = process_annotations(args.cap_file, args.gt)
    with open(args.output_file, "w") as out_f:
        json.dump(post_anno, out_f, indent=4)