# Installation

In [None]:
#!pip install git+https://github.com/huggingface/transformers/
#!pip install git+https://github.com/google/flax

In [None]:
# Reload imported modules before each cell executes, so edits to local code
# (e.g. modeling_flax_vqgan) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Work from the vqgan-jax checkout; presumably needed so that the later
# `from modeling_flax_vqgan import VQModel` resolves to the local module — confirm.
%cd ../../vqgan-jax

# Custom BART Model

In [None]:
# TODO: set those args in a config file
BASE_MODEL = 'facebook/bart-large'

# The BOS token takes the id right after the 16384 encoded image tokens,
# so the decoder vocabulary is the image-token space plus one.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1

# Target sequences: the BOS token followed by the 256 encoded image tokens.
OUTPUT_LENGTH = 1 + 256

In [None]:
import jax
import flax.linen as nn

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

class CustomFlaxBartModule(FlaxBartModule):
    """BART backbone whose decoder operates on image tokens instead of text tokens.

    The encoder keeps the text embedding (``shared``) so pretrained weights load
    directly. The decoder gets its own embedding table of ``OUTPUT_VOCAB_SIZE``
    entries and a copy of the model config with ``vocab_size`` and
    ``max_position_embeddings`` resized for the image-token sequence.
    """

    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder has a different config. Start from a full copy of the model
        # config: BartConfig's first positional parameter is `vocab_size`, so
        # `BartConfig(self.config.to_dict())` would drop every loaded setting and
        # silently fall back to BartConfig defaults for the decoder.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)

class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation module built on ``CustomFlaxBartModule``.

    Only ``setup`` is overridden: the LM head and its logits bias are sized to
    ``OUTPUT_VOCAB_SIZE`` (image tokens + BOS) rather than the text vocabulary.
    """

    def setup(self):
        # backbone: text encoder + separate image-token decoder
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        # projects decoder hidden states onto the image-token vocabulary
        self.lm_head = nn.Dense(
            OUTPUT_VOCAB_SIZE,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
        )
        # `bias_init` is not defined in this class; it comes from the parent module.
        # The bias shape must match the lm_head output width.
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))

class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Pretrained-model wrapper that uses the custom image-token module.

    Overriding ``module_class`` makes the wrapper (``from_pretrained``, ``generate``)
    instantiate ``CustomFlaxBartForConditionalGenerationModule`` instead of the stock
    BART conditional-generation module.
    """

    module_class = CustomFlaxBartForConditionalGenerationModule

In [None]:
# Fetch the trained model files from Weights & Biases.
import wandb
run = wandb.init()
# the artifact is a saved model directory (type 'bart_model'), pinned to its latest version
artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')
# local path of the downloaded files; passed to from_pretrained below
artifact_dir = artifact.download()

In [None]:
# load the trained model (config + weights) from the downloaded W&B artifact;
# the custom class rebuilds the image-token decoder so the checkpoint shapes match
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

In [None]:
# Stop generate() from forcing a fixed token id right after the decoder start token.
# Any such id is presumably inherited from the text (bart-large) config and would not
# be meaningful in the image-token vocabulary — confirm against the artifact config.
model.config.forced_bos_token_id = None

In [None]:
# The custom module creates final_logits_bias with shape (1, OUTPUT_VOCAB_SIZE);
# verify the loaded checkpoint kept that image-vocabulary shape, then display it.
assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE)
model.params['final_logits_bias'].shape

## Inference

In [None]:
# Text tokenizer of the base model; presumably matches the encoder's text vocabulary.
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)

In [None]:
# a batch of 8 identical prompts
input_text = ['I enjoy walking with my cute dog' for _ in range(8)]

In [None]:
# Tokenize the prompts into JAX arrays. Despite its name, `input_ids_test` is the
# tokenizer's dict-like output; the ids are read from its 'input_ids' key.
input_ids_test = tokenizer(input_text, return_tensors='jax')

In [None]:
# inspect the tokenized batch
input_ids_test

In [None]:
# Greedy decoding of a full image-token sequence (BOS + encoded tokens). The length
# comes from OUTPUT_LENGTH instead of a duplicated literal, and the attention mask is
# passed so that padding in prompts of different lengths is not attended to.
greedy_output = model.generate(
    input_ids_test['input_ids'],
    attention_mask=input_ids_test['attention_mask'],
    max_length=OUTPUT_LENGTH,
)

In [None]:
# shape of the generated token ids (first field of the generate() output);
# expected (batch, max_length), presumably (8, 257) here
greedy_output[0].shape

In [None]:
# all generated token-id sequences
greedy_output[0]

In [None]:
# first generated sequence; its leading token is expected to be the decoder start/BOS token
greedy_output[0][0]

# VQGAN JAX

In [None]:
import io

import requests
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

In [None]:
from modeling_flax_vqgan import VQModel

In [None]:
def custom_to_pil(x):
    """Convert a float image array into an RGB ``PIL.Image``.

    Values are clipped to [0, 1], scaled to 0..255 and truncated to ``uint8``.
    If the resulting image is not RGB (e.g. a single-channel array), it is
    converted to RGB. Assumes ``x`` is HxW or HxWxC — confirm with the decoder.
    """
    pixels = (255 * np.clip(x, 0., 1.)).astype(np.uint8)
    image = Image.fromarray(pixels)
    return image if image.mode == "RGB" else image.convert("RGB")

In [None]:
# Load the pretrained VQGAN (vqgan_f16_16384) used to turn image tokens back into pixels.
# NOTE(review): this rebinds `model`, replacing the BART model loaded earlier; re-running
# the BART inference cells after this one would call the VQGAN instead.
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

In [None]:
def get_images(indices, model):
    """Decode generated image-token sequences into images.

    The first column of ``indices`` is dropped (expected to be the BOS token),
    the shape of the remaining codes is printed, and the codes are passed to
    ``model.decode_code``.

    Args:
        indices: array of token ids with at least two dimensions, presumably
            (batch, sequence_length).
        model: object exposing ``decode_code``, e.g. a VQGAN ``VQModel``.

    Returns:
        Whatever ``model.decode_code`` returns for the trimmed codes.
    """
    codes = indices[:, 1:]  # skip the leading BOS column
    print(codes.shape)
    return model.decode_code(codes)

In [None]:
# Decode the first generated sequence and render it as an image.
# Slicing with [:1] keeps the batch dimension (shape (1, seq_len)), which is equivalent
# to expand_dims on the first row but does not need `jnp`, a name this notebook never
# imports explicitly.
custom_to_pil(np.asarray(get_images(greedy_output[0][:1], model)[0]))