{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c423d2a1-475e-482e-b759-f16456fd6707",
   "metadata": {},
   "source": [
    "# Install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0440d6a7-78b9-49e9-98a2-9a5ed75e1a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone the kopyl/PixArt-alpha repository into the current working directory.\n",
    "!git clone https://github.com/kopyl/PixArt-alpha.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0abadf51-a7e3-4091-bb02-0bdd8d28fb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the repository root: later cells use repo-relative paths\n",
    "# (requirements.txt, tools/download.py, train_scripts/train.py).\n",
    "%cd PixArt-alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4df1af24-f439-485d-a946-966dbf16c49b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel,\n",
    "# so the packages are importable from this notebook.\n",
    "# torch / torchvision / torchaudio are pinned and pulled from the CUDA 11.7 wheel index.\n",
    "%pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117\n",
    "# Remaining dependencies come from the repository's requirements file (cwd is the repo root).\n",
    "%pip install -r requirements.txt\n",
    "# wandb is needed because the training cell passes --report_to=\"wandb\".\n",
    "%pip install wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d44474fd-0b92-48fc-b4cf-142b59d3917c",
   "metadata": {},
   "source": [
    "## Download model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b1c1c9-f8b1-4719-8564-2383eac9ff28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the PixArt-XL-2-512x512.pth checkpoint using the repository's download script.\n",
    "# NOTE(review): presumably saved under output/pretrained_models, the directory the\n",
    "# feature-extraction cell passes as --pretrained_models_dir; confirm in tools/download.py.\n",
    "!python tools/download.py --model_names \"PixArt-XL-2-512x512.pth\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f298a89c-d2a5-4da7-8304-c1390da0ba58",
   "metadata": {},
   "source": [
    "## Make dataset out of Hugging Face dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e17b8883-0a5c-4fa3-a7d0-e8ee95e42027",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "from datasets import load_dataset\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92957b2c-6765-48ee-9296-d6739066d74d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the lambdalabs/pokemon-blip-captions dataset via `datasets.load_dataset`.\n",
    "# The export cell below iterates its \"train\" split and reads each item's \"image\" and \"text\".\n",
    "dataset = load_dataset(\"lambdalabs/pokemon-blip-captions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0095cdda-c31a-48ee-a115-076a5fc393c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the Hugging Face dataset to the on-disk layout used by the later cells:\n",
    "#   <root_dir>/images/<n>.png           one image per training item\n",
    "#   <root_dir>/captions/<n>.txt         the matching caption text\n",
    "#   <root_dir>/partition/data_info.json list of per-image metadata records\n",
    "root_dir = \"/workspace/pixart-pokemon\"\n",
    "images_dir = \"images\"\n",
    "captions_dir = \"captions\"\n",
    "image_format = \"png\"\n",
    "json_name = \"partition/data_info.json\"\n",
    "\n",
    "images_dir_absolute = os.path.join(root_dir, images_dir)\n",
    "captions_dir_absolute = os.path.join(root_dir, captions_dir)\n",
    "absolute_json_name = os.path.join(root_dir, json_name)\n",
    "\n",
    "# exist_ok=True keeps the cell re-runnable; makedirs also creates root_dir itself.\n",
    "for directory in (\n",
    "    images_dir_absolute,\n",
    "    captions_dir_absolute,\n",
    "    os.path.dirname(absolute_json_name),\n",
    "):\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "# Every record is registered as a 512x512 square (ratio 1) regardless of the size of\n",
    "# the saved image. NOTE(review): presumably chosen to match --img_size 512 in the\n",
    "# feature-extraction cell; confirm before using a different resolution.\n",
    "width, height = 512, 512\n",
    "ratio = 1\n",
    "\n",
    "data_info = []\n",
    "for order, item in enumerate(tqdm(dataset[\"train\"])):\n",
    "    caption = item[\"text\"]\n",
    "    image_name = f\"{order}.{image_format}\"\n",
    "\n",
    "    item[\"image\"].save(os.path.join(images_dir_absolute, image_name))\n",
    "    # Explicit UTF-8 so the caption files do not depend on the platform's default encoding.\n",
    "    with open(os.path.join(captions_dir_absolute, f\"{order}.txt\"), \"w\", encoding=\"utf-8\") as text_file:\n",
    "        text_file.write(caption)\n",
    "\n",
    "    data_info.append({\n",
    "        \"height\": height,\n",
    "        \"width\": width,\n",
    "        \"ratio\": ratio,\n",
    "        # Stored relative to root_dir (the training tools receive root_dir as --dataset_root).\n",
    "        \"path\": f\"{images_dir}/{image_name}\",\n",
    "        \"prompt\": caption,\n",
    "    })\n",
    "\n",
    "with open(absolute_json_name, \"w\", encoding=\"utf-8\") as json_file:\n",
    "    json.dump(data_info, json_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25be1c03",
   "metadata": {},
   "source": [
    "## Extract features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f07a4f5-1873-48bf-86d0-9304942de5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the repository's feature-extraction script over the records in data_info.json.\n",
    "# Going by the flag names, T5 caption features and VAE image features are written under\n",
    "# the dataset root (TODO confirm in tools/extract_features.py).\n",
    "# All paths are absolute, so this assumes the repo was cloned into /workspace and the\n",
    "# dataset was exported to /workspace/pixart-pokemon by the cell above.\n",
    "!python /workspace/PixArt-alpha/tools/extract_features.py \\\n",
    "    --img_size 512 \\\n",
    "    --json_path \"/workspace/pixart-pokemon/partition/data_info.json\" \\\n",
    "    --t5_save_root \"/workspace/pixart-pokemon/caption_feature_wmask\" \\\n",
    "    --vae_save_root \"/workspace/pixart-pokemon/img_vae_features\" \\\n",
    "    --pretrained_models_dir \"/workspace/PixArt-alpha/output/pretrained_models\" \\\n",
    "    --dataset_root \"/workspace/pixart-pokemon\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fc653d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log in to Weights & Biases before training (the training cell passes --report_to=\"wandb\").\n",
    "# Replace the placeholder with your own token, and do not commit the notebook with a real token in it.\n",
    "!wandb login REPLACE_THIS_WITH_YOUR_AUTH_TOKEN_OF_WANDB"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cf1fd1a",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0e9dab-17bc-45ed-9c81-b670bbb8de47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch train_scripts/train.py through torch.distributed.launch with the Pokemon sample config.\n",
    "# --work-dir is relative, so outputs land in output/trained_model under the current directory\n",
    "# (the repo root, set by the earlier %cd cell); the config path is absolute under /workspace.\n",
    "# NOTE(review): torch.distributed.launch is deprecated in favor of torchrun; check how\n",
    "# train.py receives the local rank before switching launchers.\n",
    "!python -m torch.distributed.launch \\\n",
    "    train_scripts/train.py \\\n",
    "    /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \\\n",
    "    --work-dir output/trained_model \\\n",
    "    --report_to=\"wandb\" \\\n",
    "    --loss_report_name=\"train_loss\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}