from operator import truediv
import os
import re
import json
import sys
import argparse
import time
from abc import ABC, abstractmethod
from multiprocessing import Pool

import openai
from tqdm import tqdm

from llava.eval.masp_eval.utils import GPTAPIWrapper


class BaseAPIWrapper(ABC):
    """Minimal interface for chat-completion API wrappers."""

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Return a completion for ``user_prompt`` (optionally with a system prompt)."""


class CHAIR():
    """Convert a free-form video caption into a structured info dict via two GPT-4 passes.

    Pass 1 applies the ``cap2info`` prompt to extract a JSON structure from the
    caption; pass 2 applies the ``refine_json`` prompt to deduplicate/refine it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # SECURITY(review): the access key was hard-coded here (and a second key
        # pool sat in commented-out code). Prefer the GPT_API_KEY environment
        # variable; the old literal is kept only as a backward-compatible
        # fallback and should be rotated/removed.
        self.openai_obj = GPTAPIWrapper(
            ak=os.environ.get("GPT_API_KEY", "GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj"))
        with open('llava/eval/masp_eval/video_chair/prompts/cap2info.txt', 'r') as file:
            self.cap_user_prompt = file.read()
        with open('llava/eval/masp_eval/video_chair/prompts/refine_json.txt', 'r') as file:
            self.cap_user_prompt_deduplicate = file.read()

    def cap2info_gpt4(self, cap):
        """Run the two-pass caption→info extraction for one caption string.

        Returns the parsed info dict, or ``None`` when neither GPT response
        contains parseable JSON.
        """
        # Pass 1: caption -> raw JSON-ish answer.
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        gpt_ret1, msgs = self.openai_obj.get_completion(
            user_prompt=user_prompt, system_prompt=self.system_prompt)
        # Pass 2: feed the first answer back for deduplication/refinement,
        # continuing the same conversation.
        user_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', gpt_ret1)
        gpt_ret2, msgs = self.openai_obj.get_completion(
            user_prompt=user_prompt, system_prompt=self.system_prompt,
            previous_msgs=msgs, last_answer=gpt_ret1)
        # Preferred path: JSON inside a ```json fenced block.
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", gpt_ret2)
        if match:
            try:
                info = json.loads(match.group(1))
            except Exception as e:
                print(match.group(1))
                info = None
            return info
        # Fallback: take the outermost {...} span of the raw answer.
        try:
            start = gpt_ret2.find('{')
            end = gpt_ret2.rfind('}')
            info = json.loads(gpt_ret2[start:end + 1])
            return info
        except Exception as e:
            print(gpt_ret1)
            print(gpt_ret2)
            return None


def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially annotate every record in ``annotations_file`` with ``cap_info``.

    ``gt`` selects which caption field to use: ``refine_caption`` (ground truth)
    vs ``masp_inference`` (model output). Kept for reference; the parallel
    ``process_annotations`` below supersedes it.
    """
    results = []
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)
    for data in tqdm(annotations):
        caption = data['refine_caption'] if gt else data['masp_inference']
        data['cap_info'] = evaluator.cap2info_gpt4(caption)
        results.append(data)
    return results


# Shared evaluator instance; created at import time so that multiprocessing
# workers inherit it via fork.
evaluator = CHAIR()


def process_data(data, gt):
    """Worker: annotate a single record dict with its GPT-derived ``cap_info``."""
    caption = data['refine_caption'] if gt else data['masp_inference']
    data['cap_info'] = evaluator.cap2info_gpt4(caption)
    return data
initialize the multiprocessing pool and process the data def process_annotations(annotations_file, gt=False): # Load annotations with open(annotations_file, 'r', encoding='utf-8') as f: annotations = json.load(f) # Create a pool of workers equal to the number of available CPU cores pool = Pool(processes=32) # None means use all available cores # Use a partial function to fix the gt and evaluator arguments from functools import partial process_data_partial = partial(process_data, gt=gt) # Map the data processing function over the annotations using the pool # pool.map(process_data_partial, annotations) res = [] for data in tqdm(pool.imap_unordered(process_data_partial, annotations), total=len(annotations)): res.append(data) # Close the pool and wait for the work to finish pool.close() pool.join() return res if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json') parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json') parser.add_argument("--gt", type=bool, default=False) args = parser.parse_args() # post_anno = post_process_masp_cap_label(evaluator, args.cap_file, args.gt) post_anno = process_annotations(args.cap_file, args.gt) with open(f"{args.output_file}", "w") as file: json.dump(post_anno, file, indent=4)