import os
import torch
from openai import OpenAI
from termcolor import colored

import transformers
# from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# environment variables and paths
from .env_utils import get_device, low_vram_mode

device = get_device()

class GPT:
    def __init__(self, model="gpt-4o-mini", api_key=None):
        self.prices = {
            # [input, output] price in USD per token; check at https://openai.com/api/pricing/
            "gpt-3.5-turbo-0125": [0.0000005, 0.0000015],
            "gpt-4o-mini"       : [0.00000015, 0.00000060],
            "gpt-4-1106-preview": [0.00001, 0.00003],
            "gpt-4-0125-preview": [0.00001, 0.00003],
            "gpt-4-turbo"       : [0.00001, 0.00003],
            "gpt-4o"            : [0.000005, 0.000015],
        }
        self.cheaper_model = "gpt-4o-mini"
        assert model in self.prices, f"Invalid model, please choose from: {list(self.prices)}, or add new models in the code."
        self.model = model
        print(f"Using {model}")
        self.client = OpenAI(api_key=api_key)
        self.total_cost = 0.0

    def _update(self, response, price):
        # price = [input cost, output cost] per token: prompt tokens are billed at
        # the input rate, completion tokens at the output rate
        current_cost = response.usage.prompt_tokens * price[0] + response.usage.completion_tokens * price[1]
        self.total_cost += current_cost
        # print costs in 4 decimal places
        print(
            colored(
                f"Current Tokens: {response.usage.completion_tokens + response.usage.prompt_tokens:d}, "
                f"Current cost: {current_cost:.4f} $, "
                f"Total cost: {self.total_cost:.4f} $",
                "yellow",
            )
        )

    def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
        # temperature=0.0 gives more deterministic results;
        # post=True routes to the cheaper model for post-refinement to save costs, since that task is simpler
        model = self.cheaper_model if post else self.model
        response = self.client.chat.completions.create(
            model=model, messages=messages, temperature=temperature, max_tokens=max_tokens
        )
        self._update(response, self.prices[model])
        return response.choices[0].message.content
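
# Example usage (a minimal sketch; OPENAI_API_KEY as the env var name is an assumption,
# since this module does not read the key itself):
#   llm = GPT(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY"))
#   reply = llm.chat([{"role": "user", "content": "Say hello."}], max_tokens=10)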


class Llama3:
    def __init__(self, model="Meta-Llama-3-8B-Instruct"):
        login(token=os.getenv("HF_TOKEN"))  # Llama 3 weights are gated; requires a Hugging Face token
        model = "meta-llama/{}".format(model)  # or replace with your local model path
        print(f"Using {model}")
        # ZeroGPU does not support quantization.
        # tokenizer = AutoTokenizer.from_pretrained(model)
        # if low_vram_mode:
        #     model = AutoModelForCausalLM.from_pretrained(
        #         model, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
        #     ).eval()
        self.pipeline = transformers.pipeline(
            "text-generation",
            model        = model,
            # tokenizer    = tokenizer,
            model_kwargs = {"torch_dtype": torch.bfloat16},
            device_map   = "auto",
        )
        # stop generation at either the tokenizer's EOS token or Llama 3's end-of-turn token
        self.terminators = [self.pipeline.tokenizer.eos_token_id, self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    def _update(self):
        print(colored("Using Llama-3, Free", "green"))

    def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
        prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        generated_text = self.pipeline(
            prompt,
            max_new_tokens = max_tokens,
            eos_token_id   = self.terminators,
            pad_token_id   = 128001,  # Llama 3's <|end_of_text|>; set explicitly to avoid a pad-token warning
            do_sample      = True,
            temperature    = max(temperature, 0.01), # 0.0 is not supported
            top_p          = 0.9,
        )
        self._update()
        generated_text = generated_text[0]["generated_text"][len(prompt) :]
        return generated_text


def timeout_handler(signum, frame):
    """Signal handler that converts an alarm signal into a TimeoutError."""
    raise TimeoutError()
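
# A sketch of how the handler above could bound a slow LLM call. Nothing in this
# module installs it itself; using signal.SIGALRM here is an assumption, and it
# is POSIX-only:
#   import signal
#   signal.signal(signal.SIGALRM, timeout_handler)
#   signal.alarm(30)      # deliver SIGALRM (-> TimeoutError) after 30 seconds
#   try:
#       reply = llm.chat(messages)
#   finally:
#       signal.alarm(0)   # always cancel the pending alarm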


def init_model(model, api_key=None):
    if "gpt" in model:
        return GPT(model=model, api_key=api_key)
    elif "Llama" in model:
        return Llama3(model=model)
    else:
        raise ValueError("Invalid model")


def _generate_example_prompt(examples, llm=None):
    # system prompt
    system_prompt = """
    Task Description:
    - You will provide detailed explanations for example inputs and outputs within the context of the task.

    Please adhere to the following rules:
    - Exclude terms that appear in both lists.
    - Detail the relevance of unmatched terms from input to output, focusing on indirect relationships.
    - Identify and explain terms common to all output lists but rarely present in input lists; include these at the end of the output labeled 'Recommend Include Labels'.
    - Each explanation should be concise, around 50 words.

    Output Format:
    - '1. Input... Output... Explanation... n. Input... Output... Explanation... \n Recommend Include Labels: label1, labeln, ...'
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"Here are the input and output lists for which you need to provide detailed explanations:{examples.strip()}",
        },
    ]
    generated_example = llm.chat(messages, temperature=0.0, max_tokens=1000)
    return generated_example
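
# Example (a sketch; the exact `examples` format is an assumption based on the
# prompt above, which expects paired input/output lists):
#   explained = _generate_example_prompt("1. Input: girl, bike Output: person, rider", llm=llm)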


def _make_prompt(label_list, example=None):
    # Heuristic: treat label lists containing "sidewalk" as the Cityscapes label set
    is_cityscapes = "sidewalk" in label_list
    if is_cityscapes:
        add_text = f'contain at least {len(label_list.split(", "))} labels, '
    else:
        add_text = ""
    # Task description and instructions for processing the input to generate output
    system_prompt = f"""
    Task Description:
    - You will receive a list of caption tags accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".

    Instructions:
    Step 1. Visualize the scene suggested by the input caption tags and text.
    Step 2. Analyze each term within the overall scene to predict relevant labels from the predefined list, ensuring no term is overlooked.
    Step 3. Now forget the input list and focus on the scene as a whole, expanding upon the labels to include any contextually relevant labels that complete the scene or setting.
    Step 4. Compile all identified labels into a comma-separated list, adhering strictly to the specified format.

    Contextually Relevant Tips:
    - Equivalencies include converting "girl, man" to "person" and "flower, vase" to "potted plant", while "bicycle, motorcycle" suggest "rider".
    - An outdoor scene may include labels like "sky", "tree", "clouds", "terrain".
    - An urban scene may imply "bus", "bicycle", "road", "sidewalk", "building", "pole", "traffic-light", "traffic-sign".

    Output:
    - Do not output any explanations other than the final label list.
    - The final output should {add_text}strictly adhere to the specified format: label1, label2, ... labeln
    """.strip()
    if example:
        system_prompt += f"""
        Additional Examples with Detailed Explanations:
        {example}
        """
    print("system_prompt: ", system_prompt)
    return system_prompt


def make_prompt(label_list):
    # Build the system prompt from the label list and wrap it as a chat message
    system_prompt = _make_prompt(label_list)
    system_prompt = {"role": "system", "content": system_prompt.strip()}
    return system_prompt


def _call_llm(system_prompt, llm, user_input):
    messages = [system_prompt, {"role": "user", "content": "Here are input caption tags and text: " + user_input}]
    converted_label = llm.chat(messages=messages, temperature=0.0, max_tokens=200)
    return converted_label


def pre_refinement(user_input_list, system_prompt, llm=None):
    llm_outputs = [_call_llm(system_prompt, llm, user_input) for user_input in user_input_list]
    converted_labels = [f"{user_input_}, {converted_label}" for user_input_, converted_label in zip(user_input_list, llm_outputs)]
    return converted_labels, llm_outputs


def post_refinement(label_list, detected_label, llm=None):
    if detected_label == "":
        return ""
    system_input = f"""
    Task Description:
    - You will receive a specific phrase and must assign an appropriate label from the predefined label list: "{label_list}".

    Please adhere to the following rules:
    - Select and return only one relevant label from the predefined label list that corresponds to the given phrase.
    - Do not include any additional information or context beyond the label itself.
    - Format is purely the label itself, without any additional punctuation or formatting.
    """
    system_input = {"role": "system", "content": system_input}
    messages = [system_input, {"role": "user", "content": detected_label}]
    generated_label = None
    for count in range(3):
        # retry with a slightly higher temperature whenever the model returns an empty string
        generated_label = llm.chat(messages=messages, temperature=0.0 if count == 0 else 0.1 * count, post=True)
        if generated_label != "":
            break
    return generated_label
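
# Example (a sketch): snap a free-form detected phrase onto the closest predefined label.
#   label = post_refinement("person, car, road", "a man riding a bike", llm=llm)
#   # expected to return a single label such as "person"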


if __name__ == "__main__":
    # test the functions
    llm = Llama3(model="Meta-Llama-3-8B-Instruct")

    system_prompt = make_prompt("person, car, tree, sky, road, building, sidewalk, traffic-light, traffic-sign")

    converted_labels, llm_outputs = pre_refinement(["person, car, road, traffic-light"], system_prompt, llm=llm)
    print("converted_labels: ", converted_labels)
    print("llm_outputs: ", llm_outputs)