"""Query an OpenAI Completions model with few-shot A-OKVQA prompts and write its predicted answers to a JSON file."""

import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib

import openai

# OpenAI credentials are read from the environment.
openai.organization = os.getenv('OPENAI_ORG')
openai.api_key = os.getenv('OPENAI_API_KEY')

from load_aokvqa import load_aokvqa


random.seed(0)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--n', type=int, default=10, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    args = parser.parse_args()


    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Optional external context (e.g., captions or rationales) keyed by
    # question_id; when --context is given, --train-context is expected too.
    train_context = {}
    context = {}
    if args.context_file is not None:
        train_context = json.load(args.train_context_file)
        context = json.load(args.context_file)
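    # For reference, a hypothetical context file entry (illustrative only, not
    # an actual dataset record):
    #   {"<question_id>": "A cat lies on a gray couch next to a laptop."}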

    predictions = {}

    for d in tqdm(eval_set):
        q = d['question_id']

        # Build a few-shot prompt: an optional prefix, num_examples answered
        # training examples, then the unanswered eval question.
        prompt = args.prompt_prefix
        for e in random.sample(train_set, args.num_examples):
            prompt += prompt_element(e,
                context=train_context.get(e['question_id'], None),
                include_choices=args.include_choices,
                answer=True
            )
            prompt += '\n\n'

        prompt += prompt_element(d,
            context=context.get(q, None),
            include_choices=args.include_choices,
            answer=False
        )

        # Greedy decoding (temperature 0) with a short completion budget;
        # uses the legacy (pre-v1) openai Completion API.
        response = openai.Completion.create(
            engine="text-curie-001",
            prompt=prompt,
            temperature=0.0,
            max_tokens=10,
        )

        predictions[q] = response.choices[0].text.strip()

    json.dump(predictions, args.output_file)


def prompt_element(d, context=None, include_choices=False, answer=False):
    """Format one A-OKVQA example as a prompt block, optionally with context,
    answer choices, and (for in-context examples) the gold answer."""
    return (f"Context: {context}\n" if context is not None else '') + \
            f"Q: {d['question']}\n" + \
           (f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
            "A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')

if __name__ == '__main__':
    main()
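
# Example invocation (hypothetical script name, paths, and filenames; adjust
# to your setup):
#   python query_gpt3.py --aokvqa-dir ./datasets/aokvqa --split val --n 10 \
#       --train-context ./context/captions_train.json \
#       --context ./context/captions_val.json \
#       --include-choices --out ./predictions_val.json
# OPENAI_API_KEY (and optionally OPENAI_ORG) must be set in the environment.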