import json
import os
from collections import Counter

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset


# Maps the requested number of image embeddings to the (height, width) output
# size of the adaptive average pool below, so that height * width equals the
# number of embeddings.
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}


class ImageEncoder(nn.Module):
    """Encodes an image into ``args.num_image_embeds`` feature vectors with a ResNet-152 trunk."""

    def __init__(self, args):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision; the weights enum
        # below loads the same ImageNet checkpoint.
        model = torchvision.models.resnet152(weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V1)
        # Drop the final average pool and classifier, keeping only the conv trunk.
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[args.num_image_embeds])

    def forward(self, x):
        # (B, 3, H, W) -> (B, 2048, h, w) -> (B, 2048, h*w) -> (B, h*w, 2048)
        out = self.pool(self.model(x))
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()
        return out


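# A minimal usage sketch for ImageEncoder (illustrative only: the choice of
# num_image_embeds=3 and the 224x224 input size are assumptions, not
# requirements of the class):
#
#     from argparse import Namespace
#
#     encoder = ImageEncoder(Namespace(num_image_embeds=3))
#     images = torch.randn(4, 3, 224, 224)  # (batch, channels, height, width)
#     features = encoder(images)            # -> shape (4, 3, 2048)

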
class JsonlDataset(Dataset):
    """Multimodal dataset backed by a JSONL file with "text", "img", and "label" fields."""

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        with open(data_path) as f:
            self.data = [json.loads(line) for line in f]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
        # Peel off the special tokens; they are returned separately so they can
        # bracket the image embeddings downstream.
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[: self.max_seq_length]

        # Multi-hot target vector over the label vocabulary.
        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1

        # Image paths are stored relative to the JSONL file's directory.
        image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs


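# Each line of the JSONL file is expected to carry the three keys read above.
# A sketch of one record (the concrete values are illustrative, not taken from
# the real dataset):
#
#     {"text": "A detective chases a serial killer across ...",
#      "img": "dataset/0001.jpeg",
#      "label": ["Crime", "Thriller"]}
#
# "img" is resolved relative to the directory containing the JSONL file, and
# every entry of "label" must occur in the `labels` list passed to __init__.

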
def collate_fn(batch):
    lens = [len(row["sentence"]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    # Right-pad every sentence to the longest one in the batch and build the
    # matching attention mask.
    mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)

    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        text_tensor[i_batch, :length] = input_row["sentence"]
        mask_tensor[i_batch, :length] = 1

    img_tensor = torch.stack([row["image"] for row in batch])
    tgt_tensor = torch.stack([row["label"] for row in batch])
    img_start_token = torch.stack([row["image_start_token"] for row in batch])
    img_end_token = torch.stack([row["image_end_token"] for row in batch])

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor


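# A minimal wiring sketch for the pieces above (illustrative: the tokenizer
# checkpoint, file name, and batch size are assumptions, not fixed by this
# module):
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     dataset = JsonlDataset(
#         "train.jsonl", tokenizer, get_image_transforms(), get_mmimdb_labels(), max_seq_length=512
#     )
#     loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
#     text, mask, image, img_start, img_end, labels = next(iter(loader))

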
def get_mmimdb_labels():
    """Returns the 23 genre labels of the MM-IMDB dataset."""
    return [
        "Crime",
        "Drama",
        "Thriller",
        "Action",
        "Comedy",
        "Romance",
        "Documentary",
        "Short",
        "Mystery",
        "History",
        "Family",
        "Adventure",
        "Fantasy",
        "Sci-Fi",
        "Western",
        "Horror",
        "Sport",
        "War",
        "Music",
        "Musical",
        "Animation",
        "Biography",
        "Film-Noir",
    ]


def get_image_transforms():
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Dataset-specific channel statistics; note these differ from the
            # standard ImageNet values.
            transforms.Normalize(
                mean=[0.46777044, 0.44531429, 0.40661017],
                std=[0.12221994, 0.12145835, 0.14380469],
            ),
        ]
    )
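

# Quick sanity check of the transform pipeline (illustrative: "poster.jpg" is a
# placeholder path):
#
#     transform = get_image_transforms()
#     image = Image.open("poster.jpg").convert("RGB")
#     tensor = transform(image)  # -> shape (3, 224, 224), normalized per channel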