fffiloni's picture
Upload 164 files
2ada650 verified
import os
import sys
import json
import argparse
import pathlib
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl
import torchmetrics.functional as MF
from load_aokvqa import load_aokvqa
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
#
parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
parser.add_argument('--clip-model-type', type=str,
choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
dest='clip_model_type', required=('clip' in sys.argv))
parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
#
parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
# Defaults
parser.add_argument('--bs', type=int, default=128, dest='batch_size')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--gpus', type=int, default=1)
args = parser.parse_args()
pl.seed_everything(1)
vocab = args.vocab.read().splitlines()
## Data loading
dm = AokvqaEmbeddingsDataModule(
args.aokvqa_dir,
args.train_features,
args.val_features,
args.objective,
args.backbone,
args.inputs,
vocab,
args.vocab_features,
batch_size=args.batch_size,
num_workers=16
)
## Model definition
model = LinearClassifier(
args.objective,
args.backbone,
args.clip_model_type,
args.inputs,
len(vocab),
args.lr
)
## Training and testing loops
logger = pl.loggers.TensorBoardLogger(
args.log_dir,
name=f'{args.backbone}-{args.objective}',
version=f"inputs:{'+'.join(args.inputs)}"
)
trainer = pl.Trainer(
logger=logger,
gpus=args.gpus,
max_epochs=args.epochs,
callbacks=[
pl.callbacks.ModelCheckpoint(
monitor="val_acc",
filename="{epoch:02d}-{val_acc:.2f}",
mode="max"
)
],
)
trainer.fit(model, dm)
class AokvqaEmbeddingsDataset(Dataset):
def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):
aokvqa_set = load_aokvqa(aokvqa_dir, split)
assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
or ( backbone == 'clip' )
embeddings = torch.load(input_features)
if backbone == 'clip':
for q in embeddings.keys():
embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
if objective == 'contrastive':
vocab_embeddings = torch.load(vocab_features)
vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)
self.objective = objective
self.vocab_len = len(vocab)
self.embeddings = []
self.answers = []
for o in aokvqa_set:
correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
if self.objective == 'contrastive':
correct_answers = [vocab_embeddings[a] for a in correct_answers]
if len(correct_answers) == 0: continue
self.answers.append(correct_answers)
q = o['question_id']
if 'question' in inputs and 'image' in inputs:
e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
elif 'question' in inputs and 'image' not in inputs:
e = embeddings[q]['question']
elif 'question' not in inputs and 'image' in inputs:
e = embeddings[q]['image']
self.embeddings.append(e)
def __getitem__(self, index):
e = self.embeddings[index]
a = self.answers[index]
if self.objective == 'classifier':
a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
elif self.objective == 'contrastive':
a = random.sample(a, 1)[0]
return e, a
def __len__(self):
return len(self.embeddings)
class AokvqaEmbeddingsDataModule(pl.LightningDataModule):
def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
super().__init__()
self.aokvqa_dir = aokvqa_dir
self.train_features = train_features
self.val_features = val_features
self.objective = objective
self.backbone = backbone
self.inputs = inputs
self.vocab = vocab
self.vocab_features = vocab_features
self.batch_size = batch_size
self.num_workers = num_workers
def setup(self, stage=None):
self.train_dataset = AokvqaEmbeddingsDataset(
self.aokvqa_dir, 'train', self.train_features, self.objective,
self.backbone, self.inputs, self.vocab, self.vocab_features
)
self.val_dataset = AokvqaEmbeddingsDataset(
self.aokvqa_dir, 'val', self.val_features, self.objective,
self.backbone, self.inputs, self.vocab, self.vocab_features
)
def train_dataloader(self):
return DataLoader(
self.train_dataset, batch_size=self.batch_size, shuffle=True,
num_workers=int(0.8 * self.num_workers)
)
def val_dataloader(self):
return DataLoader(
self.val_dataset, batch_size=self.batch_size, shuffle=False,
num_workers=int(0.2 * self.num_workers)
)
class LinearClassifier(pl.LightningModule):
def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
super().__init__()
self.save_hyperparameters(ignore=['lr'])
self.lr = lr
if self.hparams.backbone == 'clip':
clip_dim = {
'RN50' : 1024,
'RN50x4' : 640,
'RN50x16' : 768,
'RN50x64' : 1024,
'RN101' : 512,
'ViT-B/32' : 512,
'ViT-B/16' : 512,
'ViT-L/14' : 768,
'ViT-L/14@336px' : 768,
}[clip_model_type]
emb_dim = clip_dim * len(inputs)
elif self.hparams.backbone == 'resnet':
emb_dim = 2048
elif self.hparams.backbone == 'bert':
emb_dim = 768
if self.hparams.objective == 'classifier':
out_dim = vocab_len
elif self.hparams.objective == 'contrastive':
out_dim = clip_dim
self.linear = nn.Linear(emb_dim, out_dim)
def forward(self, x):
x = self.linear(x)
if self.hparams.objective == 'classifier':
x = torch.sigmoid(x)
return x
def compute_loss(self, batch):
x, y = batch
y_pred = self.forward(x)
if self.hparams.objective == 'classifier':
loss = F.binary_cross_entropy(y_pred, y.float())
elif self.hparams.objective == 'contrastive':
indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
sim = (y_pred @ y.T).softmax(dim=-1)
loss = F.cross_entropy(sim, indices)
if self.hparams.objective == 'classifier':
acc = MF.f1_score(y_pred, y)
elif self.hparams.objective == 'contrastive':
acc = torch.mean(sim[indices, indices])
return loss, acc
def training_step(self, batch, batch_idx):
loss, acc = self.compute_loss(batch)
self.log("train_loss", loss)
self.log("train_acc", acc)
return loss
def validation_step(self, batch, batch_idx):
loss, acc = self.compute_loss(batch)
self.log("val_loss", loss)
self.log("val_acc", acc)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
if __name__ == '__main__':
main()