File size: 5,385 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from operator import truediv
import os
import re
import json
import sys
import argparse
# from nltk.stem import *
# import nltk
import openai
from abc import ABC, abstractmethod
# from pattern3.en import singularize
# from nltk.stem import WordNetLemmatizer
# from call_dino_service import 
from tqdm import tqdm
from llava.eval.masp_eval.utils import GPTAPIWrapper 

import time

class BaseAPIWrapper(ABC):
    """Abstract interface for chat-completion API clients.

    Concrete subclasses implement ``get_completion`` to send a prompt to a
    backing LLM service and return its reply.
    """

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Send ``user_prompt`` (optionally preceded by ``system_prompt``) and return the model's completion."""
        ...

class CHAIR():
    """Turns a free-text video caption into structured JSON info via GPT-4.

    Two-step prompt pipeline:
      1. ``cap2info.txt`` prompt extracts raw info from the caption;
      2. ``refine_json.txt`` prompt de-duplicates/refines the first answer.
    The final answer is parsed as JSON (fenced ```json block preferred,
    brace-delimited substring as fallback).
    """

    def __init__(self) -> None:
        super().__init__()
        # Fixed system prompt sent with every request (self-description of the assistant).
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # self.openai_obj = OpenAIAPIWrapper(key_pool=["VrJQmRwcwnRW3KVEDaE8D9gYZm2a0zPm", "GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj"])
        # NOTE(review): hardcoded API credential committed to source — should be
        # loaded from an environment variable or secret store and this key rotated.
        self.openai_obj = GPTAPIWrapper(ak="GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj")
        # Prompt template for step 1; contains the '/video caption/' placeholder.
        with open('llava/eval/masp_eval/video_chair/prompts/cap2info.txt', 'r') as file:
            content = file.read()
        self.cap_user_prompt = content
        # Prompt template for step 2 (refine/de-duplicate); contains '/json file/'.
        with open('llava/eval/masp_eval/video_chair/prompts/refine_json.txt', 'r') as file:
            content = file.read()
        self.cap_user_prompt_deduplicate = content     
    
    def cap2info_gpt4(self, cap):
        """Run the two-step GPT pipeline on caption ``cap``.

        Returns the parsed info dict, or None if neither the fenced-JSON
        match nor the brace-substring fallback yields valid JSON.
        """
        # Step 1: caption -> raw extracted info.
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        gpt_ret1, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt)
        # Step 2: feed step-1 answer back for refinement, continuing the same conversation.
        user_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', gpt_ret1)
        gpt_ret2, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt, previous_msgs=msgs, last_answer=gpt_ret1)
        # Prefer the content of a ```json ... ``` fenced block if present.
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", gpt_ret2)
        if match:
            try:
                info = json.loads(match.group(1))
            except Exception as e:
                # Malformed JSON inside the fence: log it and return None (best-effort).
                print(match.group(1))
                info = None
            # Split the string into a list of items
            return info
        else:
            # Fallback: take the outermost {...} span of the raw answer.
            try:
                start = gpt_ret2.find('{')
                end = gpt_ret2.rfind('}')
                info = json.loads(gpt_ret2[start:end+1])
                return info
            except Exception as e:
                # Both parses failed — dump both answers for debugging and give up.
                print(gpt_ret1)
                print(gpt_ret2)
                return None


def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially enrich every annotation with GPT-extracted caption info.

    Args:
        evaluator: CHAIR instance providing ``cap2info_gpt4``.
        annotations_file: path to a JSON list of annotation dicts.
        gt: read 'refine_caption' when True, else 'masp_inference'.

    Returns:
        The list of annotation dicts, each with a 'cap_info' field added.
    """
    with open(annotations_file, 'r', encoding='utf-8') as fp:
        entries = json.load(fp)
    processed = []
    for entry in tqdm(entries):
        source_caption = entry['refine_caption'] if gt else entry['masp_inference']
        entry['cap_info'] = evaluator.cap2info_gpt4(source_caption)
        processed.append(entry)
    return processed


from multiprocessing import Pool

# Module-level singleton so forked Pool workers inherit a ready-to-use client.
# NOTE(review): constructing CHAIR at import time opens prompt files and builds
# an API client as a side effect — importing this module requires those files
# to exist; consider lazy initialization.
evaluator = CHAIR()

# Worker function: enrich a single annotation dict (runs in pool processes).
def process_data(data, gt):
    """Return ``data`` with a 'cap_info' field computed from its caption.

    Uses the module-level ``evaluator``; reads 'refine_caption' when ``gt``
    is truthy, otherwise 'masp_inference'.
    """
    caption_key = 'refine_caption' if gt else 'masp_inference'
    data['cap_info'] = evaluator.cap2info_gpt4(data[caption_key])
    return data

# Function to initialize the multiprocessing pool and process the data
def process_annotations(annotations_file, gt=False):
    """Enrich every annotation in ``annotations_file`` with 'cap_info', in parallel.

    Args:
        annotations_file: path to a JSON list of annotation dicts.
        gt: forwarded to ``process_data`` — True reads 'refine_caption',
            False reads 'masp_inference'.

    Returns:
        List of enriched annotation dicts. Order is NOT preserved
        (``imap_unordered``), matching the original behavior.
    """
    from functools import partial

    # Load annotations
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    # Bind the gt flag so the pool only has to pass each data item.
    process_data_partial = partial(process_data, gt=gt)

    res = []
    # BUG FIX: the original created the pool without a context manager, so a
    # failure inside the loop leaked the worker processes. The context manager
    # guarantees termination on any exit path. 32 workers is a deliberate
    # fixed fan-out for the rate-limited API (not "all available cores").
    with Pool(processes=32) as pool:
        for data in tqdm(pool.imap_unordered(process_data_partial, annotations),
                         total=len(annotations)):
            res.append(data)
        # Graceful shutdown on the success path (terminate() on __exit__ is
        # then a no-op for already-finished workers).
        pool.close()
        pool.join()
    return res



if __name__ == '__main__':
    def _str2bool(value):
        # BUG FIX: argparse's type=bool treats ANY non-empty string as True,
        # so '--gt False' used to parse as True. Accept common spellings;
        # '--gt True' keeps working, '--gt False' now correctly yields False.
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json')
    parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json')
    parser.add_argument("--gt", type=_str2bool, default=False)

    args = parser.parse_args()

    # post_anno = post_process_masp_cap_label(evaluator, args.cap_file, args.gt)
    post_anno = process_annotations(args.cap_file, args.gt)
    with open(args.output_file, "w") as file:
        json.dump(post_anno, file, indent=4)