# vqgan-jax-encoding-yfcc100m

Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.

This dataset was prepared by @borisdayma in JSON Lines format.

In [92]:
import io

import requests
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
import os

import jax
from jax import pmap

## VQGAN-JAX model

In [93]:
from vqgan_jax.modeling_flax_vqgan import VQModel

We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model.

In [167]:
# Load the pretrained VQGAN identified by "flax-community/vqgan_f16_16384".
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


## Dataset

In [94]:
import pandas as pd
from pathlib import Path

In [134]:
# Root folder of the local YFCC100M (OpenAI subset) copy.
yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')
# Images are 'sharded' into sub-folders below the following directory.
yfcc100m_images = yfcc100m.joinpath('data', 'data', 'images')
# Input metadata (one JSON record per line) and destination of the encoded output.
yfcc100m_metadata = yfcc100m.joinpath('metadata_YFCC100M.jsonl')
yfcc100m_output = yfcc100m.joinpath('metadata_encoded.tsv')

### Cleanup

We need to select the entries whose images exist. Otherwise we can't build batches, because `DataLoader` does not support `None` in batches. We use Hugging Face Datasets because, as I understand it, it supports threaded reading of JSON Lines files, and I was running out of memory when using pandas.

In [96]:
import datasets
from datasets import Dataset, load_dataset

In [10]:
# The metadata is too big to load into memory at once, so it is read in chunks
# of `chunk_size` records and each chunk is written out as its own TSV file.
chunk_size=1000000
# Create the output folder up front: `to_csv` does not create missing directories.
os.makedirs('./chunks', exist_ok=True)
# Chunk files are numbered from 1: ./chunks/chunk1.tsv, ./chunks/chunk2.tsv, ...
for batch_no, chunk in enumerate(
    pd.read_json(yfcc100m_metadata, orient="records", lines=True, chunksize=chunk_size),
    start=1,
):
    chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep="\t", index=False)

tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @  0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332
tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @  0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332
tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @  0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 

In [25]:
# Take a look at the first chunk written to disk (all columns).
pd.read_csv("./chunks/chunk1.tsv", sep="\t")

Unnamed: 0,photoid,uid,unickname,datetaken,dateuploaded,capturedevice,title,description,usertags,machinetags,...,licenseurl,serverid,farmid,secret,secretoriginal,ext,marker,key,title_clean,description_clean
0,137943,48600072071@N01,doctor+paradox,2004-08-01 18:13:06.0,1091409186,,A+Picture+Share%21,Antenna,"cameraphone,cayugaheights,green,hydrant,ithaca...",,...,http://creativecommons.org/licenses/by-nc-sa/2.0/,1,1,1650c7cdc6,1650c7cdc6,jpg,0,d29e7c6a3028418c64eb15e3cf577c2,A Picture Share!,Antenna
1,1246361,44124324682@N01,mharrsch,2004-11-03 23:04:02.0,1099523042,,An+ornate+Roman+urn,Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...,"ancient,baltimore,burial,death,empire,funeral,...",,...,http://creativecommons.org/licenses/by-nc-sa/2.0/,1,1,cf37054610,cf37054610,jpg,0,d29f01b149167d683f9ddde464bb3db,An ornate Roman urn,"Photographed at the Walters Art Museum, Baltim..."
2,1251599,51035803024@N01,bmitd67,2004-10-30 17:09:32.0,1099538888,Canon+PowerShot+S30,Jai+%26+Tara+on+the+Cumberland,Another+trip+for+the+happy+couple.,"blue+heron,cumberland+river,jai,tara,tennessee",,...,http://creativecommons.org/licenses/by-nc-sa/2.0/,1,1,4a4234e32c,4a4234e32c,jpg,0,d296e9e34bdae41edb6c679ff824ab2a,Jai & Tara on the Cumberland,Another trip for the happy couple.
3,2348587,73621375@N00,Thom+Watson,2004-12-18 21:08:09.0,1103497228,SONY+DSC-W1,Castle+gate+-+%22lite-brited%22,Taken+at+the+Miracle+of+Lights+display+in+Cent...,"bullrunpark,castle,centreville,christmas,decor...",,...,http://creativecommons.org/licenses/by-nc-sa/2.0/,2,1,7162c974c3,7162c974c3,jpg,0,d29ce96395848478b1e8396e44899,"Castle gate - ""lite-brited""",Taken at the Miracle of Lights display in Cent...
4,3516047,48600072071@N01,doctor+paradox,2005-01-18 16:44:18.0,1106084658,,A+Picture+Share%21,Tabular,"cameraphone,moblog,unfound",,...,http://creativecommons.org/licenses/by-nc-sa/2.0/,3,1,663e0d8b3d,663e0d8b3d,jpg,0,d29abf32c4e12ff881f975b70e0cec0,A Picture Share!,Tabular
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999995,4648651054,24511045@N04,mtfrazier,2010-05-02 15:47:45.0,1275083371,Canon+EOS+50D,U.S.+Navy+Blue+Angels%3A+2010,2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri,,,...,http://creativecommons.org/licenses/by-nc-nd/2.0/,4072,5,2d12d73fb0,dd5856ea42,jpg,0,60fa2911cb81eb25b356e9fee978aef,U.S. Navy Blue Angels: 2010,"2 May 2010 Sunday St. Joseph, Missouri"
999996,4652130996,21963865@N04,GRAB1.0,2010-05-29 19:23:10.0,1275200833,SONY+DSLR-A230,Attempts+on+Her+Life,BAPA+1+production+of+Martin+Crimp%27s+Attempts...,,,...,http://creativecommons.org/licenses/by-nc-nd/2.0/,4003,5,8889121579,2f46599456,jpg,0,60f5ef5ce4c2d24566226abebd67d4,Attempts on Her Life,BAPA 1 production of Martin Crimp's Attempts o...
999997,4652568339,64025277@N00,1Sock,2010-05-13 15:38:37.0,1275234267,Canon+EOS+DIGITAL+REBEL+XT,Carlsbad+Caverns+3,%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...,"carlsbad,carlsbad+caverns,cave,faa,new+mexico,...",,...,http://creativecommons.org/licenses/by-nc-nd/2.0/,4010,5,0a1808a69e,cf6d348e3d,jpg,0,60f029482d1d1028fda5281daf498f,Carlsbad Caverns 3,♥♥♥♥♥♥♥ Interested in purchasing this photogra...
999998,4653110895,20483509@N00,subberculture,2010-05-30 15:37:05.0,1275245596,Canon+DIGITAL+IXUS+40,Want,Isn%27t+that+gorgeous%3F,"2010,edinburgh+museum,may,phonebox,wood",,...,http://creativecommons.org/licenses/by-sa/2.0/,4066,5,77c3b3a254,c4697e1511,jpg,0,60f72775f433cf8de3efaeb431866153,Want,Isn't that gorgeous?


In [98]:
# Look at the same chunk restricted to the columns we need: the image key,
# the cleaned title and description, and the file extension.
df = pd.read_csv("./chunks/chunk1.tsv", sep="\t")[["key", "title_clean", "description_clean", "ext"]]
df.head()

Unnamed: 0,key,title_clean,description_clean,ext
0,d29e7c6a3028418c64eb15e3cf577c2,A Picture Share!,Antenna,jpg
1,d29f01b149167d683f9ddde464bb3db,An ornate Roman urn,"Photographed at the Walters Art Museum, Baltim...",jpg
2,d296e9e34bdae41edb6c679ff824ab2a,Jai & Tara on the Cumberland,Another trip for the happy couple.,jpg
3,d29ce96395848478b1e8396e44899,"Castle gate - ""lite-brited""",Taken at the Miracle of Lights display in Cent...,jpg
4,d29abf32c4e12ff881f975b70e0cec0,A Picture Share!,Tabular,jpg


### Reading each chunk from the folder, cleaning it up, keeping only the entries whose image exists, and appending it to the global DataFrame

In [None]:
def image_exists(item):
    """Tell whether the image referenced by a metadata row is present in storage.

    `item` is unpacked into exactly five fields: the first is used as the image
    name and the fourth as its file extension; the others are ignored. The file
    is looked up below `yfcc100m_images`, inside the sub-folders
    `name[0:3]/name[3:6]`.

    Returns True when the file exists and None (not False) otherwise, so that
    rows without an image show up as missing values and can be dropped with
    `dropna()`.
    """
    name, _, _, ext, _ = item
    shard_dir = Path(str(yfcc100m_images)) / name[0:3] / name[3:6]
    candidate = (shard_dir / name).with_suffix("." + ext)
    return True if candidate.exists() else None

In [86]:
# This cell does it all: it reads every chunk, builds the caption, keeps only the
# rows whose image exists in storage and collects everything into `global_df`.
chunks_dir = "./chunks"
columns = ["key", "title_clean", "description_clean", "ext"]
clean_chunks = []
# Only match the `chunk<N>.tsv` files. The cleaned TSV is saved into this same
# folder, so listing every file would pick it up on a re-run and fail on the
# column selection below.
for chunk_path in sorted(Path(chunks_dir).glob("chunk*.tsv")):
    # `usecols` avoids loading the unused columns; the trailing selection restores
    # the column order that `image_exists` relies on when unpacking a row.
    df = pd.read_csv(chunk_path, sep="\t", usecols=columns)[columns]
    # The caption is NaN whenever the title or the description is missing.
    df['caption'] = df["title_clean"]+". "+df['description_clean']
    # `image_exists` yields True or None, so missing images become missing values.
    df['is_exist'] = df.apply(image_exists, axis=1)
    # Drops rows without an image as well as rows with any other missing field.
    df = df.dropna()[["key", "caption"]]
    df.columns = ['image_file', 'caption']
    clean_chunks.append(df)

# Concatenate once at the end: growing a DataFrame inside the loop is quadratic,
# and `DataFrame.append` no longer exists in pandas 2.x.
if clean_chunks:
    global_df = pd.concat(clean_chunks, ignore_index=True)
else:
    global_df = pd.DataFrame(columns=['image_file', 'caption'])

In [89]:
# Save the cleaned (image_file, caption) table to disk as a tab-separated file.
global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep="\t", index=False)

In [101]:
# Reload the cleaned TSV from disk (for explicitness; also, my electricity went out, and I'm glad it happened after I had saved to disk).

dataset = pd.read_csv(f"./chunks/YFCC_subset_clean.tsv", sep="\t")

In [153]:
"""
Luke Melas-Kyriazi's dataset.py's modified version for YFCC
"""
import warnings
from typing import Optional, Callable
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from PIL import ImageFile
from PIL.Image import DecompressionBombWarning
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DecompressionBombWarning)


class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks. Note that this dataset 
    returns the raw text rather than tokens. This is done on purpose, because
    it's easy to tokenize a batch of text after loading it from this dataset.
    """

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, 
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to the tab-separated file that maps image filenames
            (`image_file` column) to captions (`caption` column)
        :param image_transform: image transform pipeline
        :param text_transform: optional callable applied to the raw caption
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
        :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
        """

        # Base path for images
        self.images_root = Path(images_root)

        # Load captions as DataFrame from the path given by the caller
        self.captions = pd.read_csv(captions_path, sep="\t")
        self.captions['image_file'] = self.captions['image_file'].astype(str)

        # PyTorch transformation pipeline for the image (normalizing, etc.)
        self.text_transform = text_transform
        self.image_transform = image_transform
        self.image_transform_type = image_transform_type.lower()
        assert self.image_transform_type in ['torchvision', 'albumentations']

        # Total number of datapoints
        self.size = len(self.captions)

        # Return image+captions or just images
        self.include_captions = include_captions

    def _image_path(self, name) -> Path:
        """Build the on-disk path of an image from its name.

        Images live in `<images_root>/<name[0:3]>/<name[3:6]>/<name>.jpg`.
        The extension is always `.jpg` here — NOTE(review): the captions file
        does not carry the extension, so other formats would not be found.
        """
        name = str(name)
        return (self.images_root/name[0:3]/name[3:6]/name).with_suffix(".jpg")

    def image_exists(self, item) -> bool:
        """Return whether the image of an `(image_file, caption)` pair exists on disk."""
        name, _caption = item
        return self._image_path(name).exists()

    def verify_that_all_images_exist(self):
        """Print every image listed in the captions that is missing on disk.

        :return: list with the `image_file` values whose image was not found
        """
        missing = []
        for image_file in self.captions['image_file']:
            image_path = self._image_path(image_file)
            if not image_path.exists():
                print(f'file does not exist: {image_path}')
                missing.append(image_file)
        return missing

    def _get_raw_image(self, i):
        """Load the i-th image from disk with torchvision's `default_loader`."""
        name = self.captions.iloc[i]['image_file']
        return default_loader(self._image_path(name))

    def _get_raw_text(self, i):
        """Return the untransformed caption of the i-th entry."""
        return self.captions.iloc[i]['caption']

    def __getitem__(self, i):
        """Return the i-th sample: the image alone, or `{'image', 'text'}` when captions are included."""
        image = self._get_raw_image(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # albumentations pipelines take and return keyword-addressed arrays
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")

        if not self.include_captions:
            return image

        caption = self._get_raw_text(i)
        if self.text_transform is not None:
            caption = self.text_transform(caption)
        return {'image': image, 'text': caption}

    def __len__(self):
        """Number of entries in the captions table."""
        return self.size


if __name__ == "__main__":
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer
    

    images_root = "/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images"
    # Same file that the cleanup step saved to disk.
    captions_path = './chunks/YFCC_subset_clean.tsv'
    image_size = 256
    
    # Create transforms
    def image_transform(image):
        """Resize the shorter side to `image_size`, center-crop to a square and
        return a channels-last numpy array with a leading singleton dimension."""
        s = min(image.size)
        r = image_size / s
        # PIL reports (width, height); `TF.resize` expects (height, width).
        s = (round(r * image.size[1]), round(r * image.size[0]))
        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
        image = TF.center_crop(image, output_size = 2 * [image_size])
        image = torch.unsqueeze(T.ToTensor()(image), 0)
        image = image.permute(0, 2, 3, 1).numpy()
        return image
    
    # Create dataset
    dataset = CaptionDataset(
        images_root=images_root,
        captions_path=captions_path,
        image_transform=image_transform,
        image_transform_type='torchvision',
        include_captions=False
    )

In [155]:
# Number of entries in the dataset.
len(dataset)

2483316

In [156]:
# DataLoader for inspecting batches; `encode_captioned_dataset` builds its own.
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

In [1]:
# Fetch the first batch to check that loading and collation work.
next(iter(dataloader))

In [None]:
# import matplotlib.pyplot as plt
# for tensor_image, _ in dataloader:
#     print(tensor_image)
#     plt.imshow(tensor_image.permute(1, 2, 0))
#     break

## Encoding

In [158]:
def encode(model, batch):
    """Run `model.encode` on `batch` and return the indices it produces.

    `model.encode` is expected to return a pair; its first element is
    discarded and the second one is returned unchanged.
    """
    _ignored, code_indices = model.encode(batch)
    return code_indices

In [160]:
def superbatch_generator(dataloader, num_tpus):
    """Group consecutive batches of `dataloader` into stacked "superbatches".

    Each batch has its dimension 1 squeezed (removed if it has size 1), and up
    to `num_tpus` batches are stacked along a new leading dimension.

    Batches whose first dimension differs from `dataloader.batch_size` (the
    incomplete last batch) are always skipped, wherever they fall inside a
    superbatch, so every stacked batch has the same size. If the data runs out
    before `num_tpus` batches were collected, the remaining ones are yielded
    as a smaller superbatch; nothing is yielded when none are left.

    :param dataloader: iterable of tensors exposing a `batch_size` attribute
    :param num_tpus: number of batches per superbatch
    :return: generator of tensors with the batches stacked on dimension 0
    """
    full_batch_size = dataloader.batch_size
    # A non-positive count behaves like 1: one batch per superbatch.
    group_size = max(1, num_tpus)

    pending = []
    for batch in dataloader:
        # Skip missing or incomplete batches
        if batch is None or batch.shape[0] != full_batch_size:
            continue
        pending.append(batch.squeeze(1))
        if len(pending) == group_size:
            yield torch.stack(pending, dim=0)
            pending = []

    # Leftover full batches that did not fill a whole superbatch
    if pending:
        yield torch.stack(pending, dim=0)

In [170]:
import os

def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16, num_tpus=8):
    """Encode the images of `dataset` and write the result to `output_tsv`.

    The output is a tab-separated file with the columns `image_file`, `caption`
    and `encoding`, where `encoding` is the comma-separated string form of the
    indices returned by `encode` for that image.

    Only `len(dataset) // (batch_size * num_tpus)` full superbatches are
    processed; any trailing items beyond that are not written. Rows are matched
    to `dataset.captions` by position, so the DataLoader must keep the dataset
    order (it is created without shuffling).

    :param dataset: dataset yielding images, with a `captions` DataFrame that
        has `image_file` and `caption` columns
    :param output_tsv: destination path; nothing is done if the file already exists
    :param batch_size: number of images per device batch
    :param num_workers: number of DataLoader worker processes
    :param num_tpus: number of batches encoded in parallel by `pmap` (one per device)
    """
    if os.path.isfile(output_tsv):
        print(f"Destination file {output_tsv} already exists, please move away.")
        return

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)

    p_encoder = pmap(lambda batch: encode(model, batch))

    superbatch_size = batch_size * num_tpus
    iterations = len(dataset) // superbatch_size

    # We save each superbatch to avoid reallocation of buffers as we process them.
    # We keep the file open to prevent excessive file seeks.
    # `newline=""` is what pandas asks for when `to_csv` receives a file object.
    with open(output_tsv, "w", newline="") as file:
        for n in tqdm(range(iterations)):
            superbatch = next(superbatches)
            # Bring the result to the host once, then flatten the device and
            # batch dimensions into one row per image.
            encoded = np.asarray(p_encoder(superbatch.numpy()))
            encoded = encoded.reshape(-1, encoded.shape[-1])

            # Extract fields from the dataset internal `captions` property, and save to disk
            start_index = n * superbatch_size
            end_index = start_index + superbatch_size
            rows = dataset.captions.iloc[start_index:end_index]
            encoded_as_string = [
                np.array2string(item, separator=',', max_line_width=50000, formatter={'int': lambda x: str(x)})
                for item in encoded
            ]
            batch_df = pd.DataFrame.from_dict({
                "image_file": rows["image_file"].values,
                "caption": rows["caption"].values,
                "encoding": encoded_as_string,
            })
            # The header is written only once, with the first superbatch.
            batch_df.to_csv(file, sep='\t', header=(n==0), index=None)

In [171]:
# Encode the dataset and write the result to `yfcc100m_output`.
encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00,  1.83s/it]


----