import json
import os
import pickle

import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor

# Import only load_dataset here: `from datasets import Dataset` would shadow
# torch.utils.data.Dataset, which both dataset classes below subclass.
from datasets import load_dataset


class ClipDataset(Dataset):
    '''Phase-1 dataset: yields (pixel_values, caption token ids) pairs from COCO-style records'''

    def __init__(self, coco_data, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.caption_dataset = coco_data

    def __len__(self):
        return len(self.caption_dataset)

    def __getitem__(self, idx):
        img_url = self.caption_dataset[idx]["image_url"]
        caption = self.caption_dataset[idx]["caption"]

        image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

        # Resize so the shorter side is 224 while preserving the aspect ratio;
        # the processor then crops/normalizes to the model's input size.
        width, height = image.size
        if width < height:
            new_width = 224
            new_height = new_width * height // width
        else:
            new_height = 224
            new_width = new_height * width // height
        image = image.resize((new_width, new_height), Image.LANCZOS)

        image_processed = self.processor(images=image, return_tensors="pt")["pixel_values"]
        image_squeezed = image_processed.squeeze(0)

        tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False)
        tokenized_caption_ids = tokenized_caption["input_ids"].squeeze(0)

        return image_squeezed, tokenized_caption_ids


def collate_fn_phase1(batch):
    image_embeddings, captions = zip(*batch)

    image_embeddings_stacked = torch.stack(image_embeddings, dim=0)

    # 50256 is GPT-2's <|endoftext|> id, used here as the pad token.
    captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256)

    return image_embeddings_stacked, captions_padded


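# A minimal illustration of the padding behaviour above (values are arbitrary;
# 50256 as pad id assumes a GPT-2-style tokenizer):
#
#   caps = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
#   torch.nn.utils.rnn.pad_sequence(caps, batch_first=True, padding_value=50256)
#   # -> tensor([[    1,     2,     3],
#   #            [    4,     5, 50256]])

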
def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers):
    with open(os.path.join(data_dir, "coco_train.pkl"), "rb") as fp:
        train_pkl = pickle.load(fp)
    with open(os.path.join(data_dir, "coco_val.pkl"), "rb") as fp:
        val_pkl = pickle.load(fp)

    train_dataloader = DataLoader(
        ClipDataset(train_pkl, clip_model_name, tokenizer),
        collate_fn=collate_fn_phase1,
        batch_size=train_batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        ClipDataset(val_pkl, clip_model_name, tokenizer),
        collate_fn=collate_fn_phase1,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )
    return train_dataloader, val_dataloader


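# Example usage (a sketch, not part of the pipeline; assumes the pickles hold
# records with "image_url" and "caption" keys, a GPT-2 tokenizer to match the
# 50256 pad id, and "openai/clip-vit-base-patch32" as an assumed model name):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   train_dl, val_dl = get_data_loaders_phase1(
#       "data", "openai/clip-vit-base-patch32", tokenizer,
#       train_batch_size=32, val_batch_size=32, num_workers=4)
#   pixel_values, caption_ids = next(iter(train_dl))
#   # pixel_values: (B, 3, 224, 224); caption_ids: (B, T) padded with 50256

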
class ClipDatasetPhase2(Dataset):
    '''Phase-2 dataset: yields (pixel_values, question ids, answer ids); text-only rows yield None for the image'''

    def __init__(self, data_frame, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.df = data_frame

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_url = self.df.ImageUrl.iloc[idx]
        que = self.df.Question.iloc[idx]
        ans = self.df.Answer.iloc[idx]

        if img_url is None:
            # Text-only sample (e.g. OASST rows): there is no image to encode.
            image_squeezed = None
        else:
            image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

            # Resize so the shorter side is 224 while preserving the aspect ratio.
            width, height = image.size
            if width < height:
                new_width = 224
                new_height = new_width * height // width
            else:
                new_height = 224
                new_width = new_height * width // height
            image = image.resize((new_width, new_height), Image.LANCZOS)

            image_processed = self.processor(images=image, return_tensors="pt")["pixel_values"]
            image_squeezed = image_processed.squeeze(0)

        que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)["input_ids"].squeeze(0)
        ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)["input_ids"].squeeze(0)

        return image_squeezed, que_ids, ans_ids


def collate_fn_phase2(batch):
    image_embeddings, ques, ans = zip(*batch)

    # The first sample decides the batch type; this assumes a batch never
    # mixes image rows with text-only rows.
    if image_embeddings[0] is None:
        image_embeddings_stacked = None
    else:
        image_embeddings_stacked = torch.stack(image_embeddings, dim=0)

    # 50256 is GPT-2's <|endoftext|> id, used here as the pad token.
    ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256)
    ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256)

    return image_embeddings_stacked, ques_padded, ans_padded


def prep_data(df):
    # Keep only top-ranked (rank == 0.0) assistant replies and join each one
    # to its parent prompter message to form (Question, Answer) pairs.
    df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy()
    df_prompter = df[(df.role == "prompter")].copy()
    df_prompter = df_prompter.set_index("message_id")
    df_assistant["Answer"] = df_assistant["text"].values

    questions = []
    for _, row in df_assistant.iterrows():
        prompter_row = df_prompter.loc[row.parent_id]
        questions.append(prompter_row.text)

    df_assistant["Question"] = questions
    # These rows are text-only, so there is no image URL to attach.
    df_assistant["ImageUrl"] = None

    df_assistant = df_assistant[df_assistant.lang == "en"]

    df_assistant = df_assistant[
        ["ImageUrl", "Question", "Answer", "message_id"]
    ].rename(columns={"message_id": "Ids"})

    return df_assistant


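# Sketch of the pairing prep_data performs (column names from the code above;
# the values are purely illustrative):
#
#   message_id  parent_id  role       rank  lang  text
#   q1          -          prompter   NaN   en    "What is CLIP?"
#   a1          q1         assistant  0.0   en    "CLIP is a vision-language model."
#
# -> one output row: ImageUrl=None, Question="What is CLIP?",
#    Answer="CLIP is a vision-language model.", Ids="a1"

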
def get_i150_df(config):
    with open(config.get("i150k_json"), "r") as fp:
        i150k_json_read = json.load(fp)

    max_chars = 100  # character-length cap on both questions and answers

    image_urls = []
    ques_list = []
    ans_list = []
    id_list = []
    for data in i150k_json_read:
        image = data["image"]
        image_url = "http://images.cocodataset.org/train2017/" + image
        id_ = data["id"]

        # Conversations alternate human/gpt turns, so consume them in pairs.
        iterator = iter(data["conversations"])
        for ques in iterator:
            ans = next(iterator)
            if len(ques["value"]) > max_chars or len(ans["value"]) > max_chars:
                continue
            if ques["from"] == "human" and ans["from"] == "gpt":
                image_urls.append(image_url)
                ques_list.append(ques["value"].replace("<image>\n", "").replace("<image>", ""))
                ans_list.append(ans["value"])
                id_list.append(id_)

    df_i150k = pd.DataFrame(
        list(zip(image_urls, ques_list, ans_list, id_list)),
        columns=["ImageUrl", "Question", "Answer", "Ids"],
    )

    # Random ~96/4 train/validation split.
    msk = np.random.rand(len(df_i150k)) < 0.96
    train_df = df_i150k[msk]
    test_df = df_i150k[~msk]
    return train_df, test_df


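# Sketch of one record in the i150k JSON as the loop above expects it (field
# names from the code; the values are illustrative):
#
#   {
#     "id": "000000442786",
#     "image": "000000442786.jpg",
#     "conversations": [
#       {"from": "human", "value": "<image>\nWhat is in the photo?"},
#       {"from": "gpt", "value": "A dog riding a skateboard."}
#     ]
#   }

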
def get_oas_df(config):
    train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"])
    train_df = prep_data(train_ds.to_pandas())
    test_df = prep_data(val_ds.to_pandas())
    return train_df, test_df


def get_data_loaders_phase2(tokenizer, config):
    train_i150k, test_i150k = get_i150_df(config)
    train_oas, test_oas = get_oas_df(config)

    train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True)
    val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True)

    train_dataloader = DataLoader(
        ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer),
        collate_fn=collate_fn_phase2,
        batch_size=config.get("train_batch_size"),
        num_workers=config.get("num_workers"),
        shuffle=True,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer),
        collate_fn=collate_fn_phase2,
        batch_size=config.get("val_batch_size"),
        num_workers=config.get("num_workers"),
        shuffle=False,
        pin_memory=True,
    )
    return train_dataloader, val_dataloader


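# A hedged smoke test, not part of the training pipeline. The config keys
# mirror the config.get(...) calls above; the path, dataset name, and model
# names are assumptions and must point at real resources before running.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = {
        "i150k_json": "data/llava_instruct_150k.json",  # assumed local path
        "QA_datasetName": "OpenAssistant/oasst1",  # assumed dataset id
        "clip_model_name": "openai/clip-vit-base-patch32",  # assumed model
        "train_batch_size": 8,
        "val_batch_size": 8,
        "num_workers": 2,
    }
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # matches the 50256 pad id
    train_dl, val_dl = get_data_loaders_phase2(tokenizer, config)
    images, questions, answers = next(iter(train_dl))
    print(None if images is None else images.shape, questions.shape, answers.shape)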