import sys
import os
import argparse
import pathlib
from tqdm import tqdm
import json

import torch
import torch.nn as nn

# Keep these imports together and in this order to work around an import-order issue:
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl; import clip

from transfer_experiments.train import LinearClassifier
from load_aokvqa import load_aokvqa
from evaluation.remap_predictions import map_to_choices


parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--features', type=pathlib.Path, required=True)
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
# Model weights and inputs: a trained checkpoint, or CLIP zero-shot (which also needs --inputs)
parser_weights = parser.add_mutually_exclusive_group(required=True)

parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')

parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv))
# Answer space: open-ended vocab (with precomputed text features) or the question's multiple choices (--mc)
parser.add_argument('--vocab', type=argparse.FileType('r'))
parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
parser.add_argument('--mc', action='store_true', dest='multiple_choice')

parser.add_argument('--clip-model-type', type=str,
                    choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                    dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv))
#
args = parser.parse_args()
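
# Example invocation (illustrative only; the script name and paths are placeholders):
#   python predict.py --split val --aokvqa-dir ./datasets/aokvqa \
#       --features ./features/clip-ViT-B-32_val.pt \
#       --zero-shot --inputs question image --mc --clip-model-type ViT-B/32 \
#       --out predictions_val.json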


## Load dataset

aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)

## Load models

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.checkpoint_path is not None:
    # Trained transfer head: hyperparameters are restored from the checkpoint
    classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
    classifier.to(device)
    hp = classifier.hparams
elif args.clip_zero_shot:
    # Zero-shot: identity "head", so CLIP features are compared directly against text features
    classifier = nn.Identity().to(device)
    hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)

# Load input features

embeddings = torch.load(args.features)
if hp.backbone == 'clip':
    # L2-normalize CLIP features so downstream dot products behave as cosine similarities
    for q in embeddings.keys():
        embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True)
        embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True)

# Load vocab, vocab features, clip

# The answer vocabulary is needed whenever predictions range over the full vocab
# (classifier objective, or contrastive/zero-shot without --mc)
if (hp.objective == 'classifier') or \
   (hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False):
    vocab = args.vocab.read().splitlines()

if hp.objective in ['contrastive', 'zero-shot']:
    if args.multiple_choice is False:
        # Open-ended: use precomputed (and normalized) CLIP text features for the whole vocab
        vocab_features = torch.load(args.vocab_features).cpu()
        vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
    else:
        # Multiple choice: each question's choices are encoded on the fly, so load CLIP itself
        clip_model = clip.load(hp.clip_model_type, device=device)[0]
        logit_scale = clip_model.logit_scale.exp().cpu()

## Prediction loop

predictions = {}

with torch.no_grad():
    for o in tqdm(aokvqa_set):
        q = o['question_id']

        # Load input embedding (from question / image)
        if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs):
            # Zero-shot stays in CLIP space, so the two modalities are summed rather than concatenated
            e = embeddings[q]['question'] + embeddings[q]['image']
        elif 'question' in hp.inputs and 'image' in hp.inputs:
            e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
        elif 'question' in hp.inputs:
            e = embeddings[q]['question']
        elif 'image' in hp.inputs:
            e = embeddings[q]['image']

        # Pass inputs through model
        e = e.unsqueeze(0).to(device)
        x = classifier(e)[0].cpu()

        # Predict
        if hp.objective in ['contrastive', 'zero-shot']:
            if args.multiple_choice:
                vocab = o['choices']
                # Encode choices
                vocab_features = clip.tokenize(vocab).to(device)
                vocab_features = torch.stack([
                    clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features
                ], dim=1)[0]
                vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
                vocab_features = vocab_features.float().cpu()

            x = logit_scale * x @ vocab_features.t()
            x = x.softmax(dim=-1)

        predictions[q] = vocab[x.argmax().item()]

## Save and evaluate predictions

# Map prediction to nearest neighbor choice (by word embeddings)
if args.multiple_choice and hp.objective == 'classifier':
    predictions = map_to_choices(aokvqa_set, predictions)

json.dump(predictions, args.output_file)
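
# The saved JSON maps each question_id to a single predicted answer string.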