import os
import sys
import json
import argparse
import pathlib
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import sentencepiece
import pytorch_lightning as pl

import torchmetrics.functional as MF

from load_aokvqa import load_aokvqa
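# Example invocation (a sketch; the script name and file paths here are
# hypothetical and depend on where your extracted features and vocabulary live):
#
#   python train.py \
#       --aokvqa-dir ./datasets/aokvqa/ --vocab ./vocab.txt --log-dir ./logs/ \
#       --backbone clip --clip-model-type ViT-B/32 \
#       --train-features ./features/clip_train.pt \
#       --val-features ./features/clip_val.pt \
#       --objective classifier --inputs question image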


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
    parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)

    parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
    # `required=('clip' in sys.argv)` scans the raw command line: --clip-model-type is
    # only mandatory when the clip backbone was requested, and likewise --vocab-features
    # below is only mandatory for the contrastive objective.
    parser.add_argument('--clip-model-type', type=str,
                        choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101',
                                 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                        dest='clip_model_type', required=('clip' in sys.argv))
    parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
    parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
    parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')

    parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
    parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)

    parser.add_argument('--bs', type=int, default=128, dest='batch_size')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--gpus', type=int, default=1)
    args = parser.parse_args()

    pl.seed_everything(1)
    vocab = args.vocab.read().splitlines()

    dm = AokvqaEmbeddingsDataModule(
        args.aokvqa_dir,
        args.train_features,
        args.val_features,
        args.objective,
        args.backbone,
        args.inputs,
        vocab,
        args.vocab_features,
        batch_size=args.batch_size,
        num_workers=16
    )
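
    # Both objectives train a single linear layer over frozen, precomputed features:
    # 'classifier' predicts a multi-hot distribution over the answer vocabulary,
    # while 'contrastive' projects inputs into the answer-embedding space and is
    # scored against precomputed vocabulary embeddings.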
    model = LinearClassifier(
        args.objective,
        args.backbone,
        args.clip_model_type,
        args.inputs,
        len(vocab),
        args.lr
    )

    logger = pl.loggers.TensorBoardLogger(
        args.log_dir,
        name=f'{args.backbone}-{args.objective}',
        version=f"inputs:{'+'.join(args.inputs)}"
    )

    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,  # PyTorch Lightning 1.x API; 2.x replaces this with accelerator=/devices=
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",
                filename="{epoch:02d}-{val_acc:.2f}",
                mode="max"
            )
        ],
    )

    trainer.fit(model, dm)
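

# Feature files (as consumed below) map question_id -> {'question': Tensor, 'image': Tensor};
# the vocabulary features are assumed to be a single [vocab_len, dim] tensor
# indexable by answer index.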
class AokvqaEmbeddingsDataset(Dataset):
    def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):

        aokvqa_set = load_aokvqa(aokvqa_dir, split)

        # Only a limited set of backbone/input/objective combinations is supported.
        assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
            or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
            or ( backbone == 'clip' )

        embeddings = torch.load(input_features)
        if backbone == 'clip':
            # L2-normalize CLIP features so dot products act as cosine similarities
            for q in embeddings.keys():
                embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
                embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
        if objective == 'contrastive':
            vocab_embeddings = torch.load(vocab_features)
            vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)

        self.objective = objective
        self.vocab_len = len(vocab)

        self.embeddings = []
        self.answers = []
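        # Keep only answers that appear in the vocabulary: the correct multiple-choice
        # answer plus any direct answers. Under 'contrastive', targets become the
        # answers' normalized embeddings; under 'classifier', they stay as indices.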
        for o in aokvqa_set:
            correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
            correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
            if self.objective == 'contrastive':
                correct_answers = [vocab_embeddings[a] for a in correct_answers]
            if len(correct_answers) == 0:
                continue  # skip questions with no in-vocabulary answer
            self.answers.append(correct_answers)

            q = o['question_id']
            if 'question' in inputs and 'image' in inputs:
                e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
            elif 'question' in inputs and 'image' not in inputs:
                e = embeddings[q]['question']
            elif 'question' not in inputs and 'image' in inputs:
                e = embeddings[q]['image']
            self.embeddings.append(e)

    def __getitem__(self, index):
        e = self.embeddings[index]
        a = self.answers[index]
        if self.objective == 'classifier':
            # multi-hot target over the vocabulary
            a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
        elif self.objective == 'contrastive':
            # sample one correct answer embedding per access
            a = random.sample(a, 1)[0]
        return e, a

    def __len__(self):
        return len(self.embeddings)
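

# Standard LightningDataModule wrapper: builds the train/val datasets in setup()
# and exposes the corresponding dataloaders.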
class AokvqaEmbeddingsDataModule(pl.LightningDataModule):

    def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
        super().__init__()
        self.aokvqa_dir = aokvqa_dir
        self.train_features = train_features
        self.val_features = val_features
        self.objective = objective
        self.backbone = backbone
        self.inputs = inputs
        self.vocab = vocab
        self.vocab_features = vocab_features
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'train', self.train_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )
        self.val_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'val', self.val_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=int(0.8 * self.num_workers)  # worker budget split 80/20 between train and val loaders
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=int(0.2 * self.num_workers)
        )
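
# LinearClassifier is a linear probe over the precomputed features.
# With the 'classifier' objective it outputs vocab-sized multi-label scores;
# with the 'contrastive' objective it outputs a CLIP-dimensional embedding
# (clip_dim is only defined for the clip backbone, so 'contrastive' implicitly
# requires --backbone clip, matching the dataset assertion above).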
class LinearClassifier(pl.LightningModule):
    def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
        super().__init__()
        self.save_hyperparameters(ignore=['lr'])
        self.lr = lr

        if self.hparams.backbone == 'clip':
            # embedding dimensionality of each CLIP variant
            clip_dim = {
                'RN50': 1024,
                'RN50x4': 640,
                'RN50x16': 768,
                'RN50x64': 1024,
                'RN101': 512,
                'ViT-B/32': 512,
                'ViT-B/16': 512,
                'ViT-L/14': 768,
                'ViT-L/14@336px': 768,
            }[clip_model_type]
            emb_dim = clip_dim * len(inputs)  # question and image features are concatenated
        elif self.hparams.backbone == 'resnet':
            emb_dim = 2048
        elif self.hparams.backbone == 'bert':
            emb_dim = 768

        if self.hparams.objective == 'classifier':
            out_dim = vocab_len
        elif self.hparams.objective == 'contrastive':
            out_dim = clip_dim

        self.linear = nn.Linear(emb_dim, out_dim)

    def forward(self, x):
        x = self.linear(x)
        if self.hparams.objective == 'classifier':
            x = torch.sigmoid(x)
        return x
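
    # The contrastive branch below computes an in-batch cross-entropy: each
    # projected input should match its own sampled answer embedding, with the
    # rest of the batch acting as negatives. Note that sim is already
    # softmax-normalized before F.cross_entropy, which applies its own
    # log-softmax; the extra softmax rescales but does not reorder the logits.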
    def compute_loss(self, batch):
        x, y = batch

        y_pred = self.forward(x)

        if self.hparams.objective == 'classifier':
            loss = F.binary_cross_entropy(y_pred, y.float())
        elif self.hparams.objective == 'contrastive':
            indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
            sim = (y_pred @ y.T).softmax(dim=-1)
            loss = F.cross_entropy(sim, indices)

        if self.hparams.objective == 'classifier':
            # f1_score signature follows older torchmetrics; recent releases
            # (>= 0.11) additionally require a task argument, e.g. task='multilabel'
            acc = MF.f1_score(y_pred, y)
        elif self.hparams.objective == 'contrastive':
            # mean probability assigned to the correct (diagonal) pairing
            acc = torch.mean(sim[indices, indices])

        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


if __name__ == '__main__':
    main()