{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ClothingGAN-Demo.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bm8iDDKC1LZo"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfrashad/ClothingGAN/blob/master/ClothingGAN_Demo.ipynb)\n",
        "# Clothing GAN demo\n",
        "Notebook by [@mfrashad](https://mfrashad.com)\n",
        "\n",
        "\n",
        "<br>\n",
        "Make sure the runtime type is set to GPU (Runtime > Change runtime type)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 200
        },
        "cellView": "form",
        "id": "Kj8mGkmH0xgA",
        "outputId": "6a793110-884d-4f59-89ec-9c5eced9b98a"
      },
      "source": [
        "#@title Install dependencies (restart runtime after installing)\n",
        "# Cap the height of this cell's output frame so the pip log stays scrollable.\n",
        "from IPython.display import Javascript\n",
        "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n",
        "# gradio is pinned to 2.1.0, the release recorded in this notebook's saved output:\n",
        "# the Demo UI cell uses the legacy `gr.inputs.*` components, which newer major\n",
        "# releases of gradio no longer provide.\n",
        "!pip install ninja gradio==2.1.0 fbpca boto3 requests==2.23.0 urllib3==1.25.11"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "google.colab.output.setIframeHeight(0, true, {maxHeight: 200})"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Collecting ninja\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)\n",
            "\r\u001b[K     |███                             | 10kB 19.1MB/s eta 0:00:01\r\u001b[K     |██████                          | 20kB 17.8MB/s eta 0:00:01\r\u001b[K     |█████████▏                      | 30kB 10.3MB/s eta 0:00:01\r\u001b[K     |████████████▏                   | 40kB 8.4MB/s eta 0:00:01\r\u001b[K     |███████████████▎                | 51kB 5.3MB/s eta 0:00:01\r\u001b[K     |██████████████████▎             | 61kB 5.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████▍          | 71kB 6.0MB/s eta 0:00:01\r\u001b[K     |████████████████████████▍       | 81kB 6.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 92kB 6.5MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 102kB 6.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 112kB 6.8MB/s \n",
            "\u001b[?25hCollecting gradio\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/59/97/b7210489b201409175e63afa3307f8e067fe1289cc19a68003dfeef03f06/gradio-2.1.0-py3-none-any.whl (2.5MB)\n",
            "\u001b[K     |████████████████████████████████| 2.5MB 8.5MB/s \n",
            "\u001b[?25hCollecting fbpca\n",
            "  Downloading https://files.pythonhosted.org/packages/a7/a5/2085d0645a4bb4f0b606251b0b7466c61326e4a471d445c1c3761a2d07bc/fbpca-1.0.tar.gz\n",
            "Collecting boto3\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/2c/e1/2c6c374f043c3f22829563b7fb2bf28fe3dca7ce5994bc5ceeff0959d6c9/boto3-1.17.105-py2.py3-none-any.whl (131kB)\n",
            "\u001b[K     |████████████████████████████████| 133kB 26.6MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (2.23.0)\n",
            "Collecting urllib3==1.25.11\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/56/aa/4ef5aa67a9a62505db124a5cb5262332d1d4153462eb8fd89c9fa41e5d92/urllib3-1.25.11-py2.py3-none-any.whl (127kB)\n",
            "\u001b[K     |████████████████████████████████| 133kB 22.3MB/s \n",
            "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from gradio) (7.1.2)\n",
            "Collecting ffmpy\n",
            "  Downloading https://files.pythonhosted.org/packages/bf/e2/947df4b3d666bfdd2b0c6355d215c45d2d40f929451cb29a8a2995b29788/ffmpy-0.3.0.tar.gz\n",
            "Requirement already satisfied: Flask>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.4)\n",
            "Collecting flask-cachebuster\n",
            "  Downloading https://files.pythonhosted.org/packages/74/47/f3e1fedfaad965c81c2f17234636d72f71450f1b4522ca26d2b7eb4a0a74/Flask-CacheBuster-1.0.0.tar.gz\n",
            "Collecting pycryptodome\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/ad/16/9627ab0493894a11c68e46000dbcc82f578c8ff06bc2980dcd016aea9bd3/pycryptodome-3.10.1-cp35-abi3-manylinux2010_x86_64.whl (1.9MB)\n",
            "\u001b[K     |████████████████████████████████| 1.9MB 26.5MB/s \n",
            "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.5)\n",
            "Collecting Flask-Cors>=3.0.8\n",
            "  Downloading https://files.pythonhosted.org/packages/db/84/901e700de86604b1c4ef4b57110d4e947c218b9997adf5d38fa7da493bce/Flask_Cors-3.0.10-py2.py3-none-any.whl\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.4.1)\n",
            "Collecting paramiko\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/95/19/124e9287b43e6ff3ebb9cdea3e5e8e88475a873c05ccdf8b7e20d2c4201e/paramiko-2.7.2-py2.py3-none-any.whl (206kB)\n",
            "\u001b[K     |████████████████████████████████| 215kB 49.0MB/s \n",
            "\u001b[?25hCollecting analytics-python\n",
            "  Downloading https://files.pythonhosted.org/packages/30/81/2f447982f8d5dec5b56c10ca9ac53e5de2b2e9e2bdf7e091a05731f21379/analytics_python-1.3.1-py2.py3-none-any.whl\n",
            "Collecting markdown2\n",
            "  Downloading https://files.pythonhosted.org/packages/5d/be/3924cc1c0e12030b5225de2b4521f1dc729730773861475de26be64a0d2b/markdown2-2.4.0-py2.py3-none-any.whl\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.19.5)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from gradio) (3.2.2)\n",
            "Collecting Flask-Login\n",
            "  Downloading https://files.pythonhosted.org/packages/2b/83/ac5bf3279f969704fc1e63f050c50e10985e50fd340e6069ec7e09df5442/Flask_Login-0.5.0-py2.py3-none-any.whl\n",
            "Collecting jmespath<1.0.0,>=0.7.1\n",
            "  Downloading https://files.pythonhosted.org/packages/07/cb/5f001272b6faeb23c1c9e0acc04d48eaaf5c862c17709d20e3469c6e0139/jmespath-0.10.0-py2.py3-none-any.whl\n",
            "Collecting s3transfer<0.5.0,>=0.4.0\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/63/d0/693477c688348654ddc21dcdce0817653a294aa43f41771084c25e7ff9c7/s3transfer-0.4.2-py2.py3-none-any.whl (79kB)\n",
            "\u001b[K     |████████████████████████████████| 81kB 12.5MB/s \n",
            "\u001b[?25hCollecting botocore<1.21.0,>=1.20.105\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/95/da/3417300f85ba5173e8dba9248b9ae8bcb74a8aac1c92fa3d257f99073b9e/botocore-1.20.105-py2.py3-none-any.whl (7.7MB)\n",
            "\u001b[K     |████████████████████████████████| 7.7MB 45.1MB/s \n",
            "\u001b[?25hRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2021.5.30)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2.10)\n",
            "Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (7.1.2)\n",
            "Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (2.11.3)\n",
            "Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.0.1)\n",
            "Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.1.0)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2.8.1)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2018.9)\n",
            "Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from Flask-Cors>=3.0.8->gradio) (1.15.0)\n",
            "Collecting pynacl>=1.0.1\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/9d/57/2f5e6226a674b2bcb6db531e8b383079b678df5b10cdaa610d6cf20d77ba/PyNaCl-1.4.0-cp35-abi3-manylinux1_x86_64.whl (961kB)\n",
            "\u001b[K     |████████████████████████████████| 962kB 53.6MB/s \n",
            "\u001b[?25hCollecting cryptography>=2.5\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n",
            "\u001b[K     |████████████████████████████████| 3.2MB 49.0MB/s \n",
            "\u001b[?25hCollecting bcrypt>=3.1.3\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/26/70/6d218afbe4c73538053c1016dd631e8f25fffc10cd01f5c272d7acf3c03d/bcrypt-3.2.0-cp36-abi3-manylinux2010_x86_64.whl (63kB)\n",
            "\u001b[K     |████████████████████████████████| 71kB 11.2MB/s \n",
            "\u001b[?25hCollecting backoff==1.10.0\n",
            "  Downloading https://files.pythonhosted.org/packages/f0/32/c5dd4f4b0746e9ec05ace2a5045c1fc375ae67ee94355344ad6c7005fd87/backoff-1.10.0-py2.py3-none-any.whl\n",
            "Collecting monotonic>=1.5\n",
            "  Downloading https://files.pythonhosted.org/packages/9a/67/7e8406a29b6c45be7af7740456f7f37025f0506ae2e05fb9009a53946860/monotonic-1.6-py2.py3-none-any.whl\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (0.10.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (1.3.1)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (2.4.7)\n",
            "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->Flask>=1.1.1->gradio) (2.0.1)\n",
            "Requirement already satisfied: cffi>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from pynacl>=1.0.1->paramiko->gradio) (1.14.5)\n",
            "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.4.1->pynacl>=1.0.1->paramiko->gradio) (2.20)\n",
            "Building wheels for collected packages: fbpca, ffmpy, flask-cachebuster\n",
            "  Building wheel for fbpca (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for fbpca: filename=fbpca-1.0-cp37-none-any.whl size=11376 sha256=4b14cfd952b104d56a985021caf774f906fed7ca5fae1d1f41c570c6c0ea121c\n",
            "  Stored in directory: /root/.cache/pip/wheels/53/a2/dd/9b66cf53dbc58cec1e613d216689e5fa946d3e7805c30f60dc\n",
            "  Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for ffmpy: filename=ffmpy-0.3.0-cp37-none-any.whl size=4710 sha256=571d113a8f5d748045ade970eca5e1a4bab4ed32cfb262851649d238839f682d\n",
            "  Stored in directory: /root/.cache/pip/wheels/cc/ac/c4/bef572cb7e52bfca170046f567e64858632daf77e0f34e5a74\n",
            "  Building wheel for flask-cachebuster (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for flask-cachebuster: filename=Flask_CacheBuster-1.0.0-cp37-none-any.whl size=3372 sha256=37a3576c476072ff54679fb45e5e3c1150e6a1790d23da3873cc2394f3be7741\n",
            "  Stored in directory: /root/.cache/pip/wheels/9f/fc/a7/ab5712c3ace9a8f97276465cc2937316ab8063c1fea488ea77\n",
            "Successfully built fbpca ffmpy flask-cachebuster\n",
            "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n",
            "Installing collected packages: ninja, ffmpy, flask-cachebuster, pycryptodome, Flask-Cors, pynacl, cryptography, bcrypt, paramiko, backoff, monotonic, analytics-python, markdown2, Flask-Login, gradio, fbpca, jmespath, urllib3, botocore, s3transfer, boto3\n",
            "  Found existing installation: urllib3 1.24.3\n",
            "    Uninstalling urllib3-1.24.3:\n",
            "      Successfully uninstalled urllib3-1.24.3\n",
            "Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.3.1 backoff-1.10.0 bcrypt-3.2.0 boto3-1.17.105 botocore-1.20.105 cryptography-3.4.7 fbpca-1.0 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.1.0 jmespath-0.10.0 markdown2-2.4.0 monotonic-1.6 ninja-1.10.0.post2 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0 s3transfer-0.4.2 urllib3-1.25.11\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qwf-gggHtA_t",
        "outputId": "68f7afee-6889-4f9c-97d0-de9ba931c206"
      },
      "source": [
        "# Fetch the ClothingGAN sources and switch the working directory into the checkout.\n",
        "# Later cells import `models`, `decomposition` and `config`, presumably from this\n",
        "# checkout - so this cell has to run before them.\n",
        "!git clone https://github.com/mfrashad/ClothingGAN.git\n",
        "%cd ClothingGAN/"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'ClothingGAN'...\n",
            "remote: Enumerating objects: 333, done.\u001b[K\n",
            "remote: Counting objects: 100% (60/60), done.\u001b[K\n",
            "remote: Compressing objects: 100% (37/37), done.\u001b[K\n",
            "remote: Total 333 (delta 38), reused 22 (delta 22), pack-reused 273\u001b[K\n",
            "Receiving objects: 100% (333/333), 47.08 MiB | 51.89 MiB/s, done.\n",
            "Resolving deltas: 100% (108/108), done.\n",
            "/content/ClothingGAN\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153
        },
        "cellView": "form",
        "id": "2BYsIETVtGnF",
        "outputId": "e3b5c2ac-9c18-4917-de0f-1f2c126aa696"
      },
      "source": [
        "#@title Install other dependencies\n",
        "# Cap the height of this cell's output frame (Colab-specific call).\n",
        "from IPython.display import Javascript\n",
        "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n",
        "# Check out the git submodules registered by the repository.\n",
        "!git submodule update --init --recursive\n",
        "# Download the NLTK 'wordnet' corpus.\n",
        "!python -c \"import nltk; nltk.download('wordnet')\""
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "google.colab.output.setIframeHeight(0, true, {maxHeight: 200})"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Submodule 'stylegan/stylegan_tf' (https://github.com/NVlabs/stylegan.git) registered for path 'models/stylegan/stylegan_tf'\n",
            "Submodule 'stylegan2/stylegan2-pytorch' (https://github.com/harskish/stylegan2-pytorch.git) registered for path 'models/stylegan2/stylegan2-pytorch'\n",
            "Cloning into '/content/ClothingGAN/models/stylegan/stylegan_tf'...\n",
            "Cloning into '/content/ClothingGAN/models/stylegan2/stylegan2-pytorch'...\n",
            "Submodule path 'models/stylegan/stylegan_tf': checked out '66813a32aac5045fcde72751522a0c0ba963f6f2'\n",
            "Submodule path 'models/stylegan2/stylegan2-pytorch': checked out '91ea2a7a4320701535466cce942c9e099d65670e'\n",
            "[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
            "[nltk_data]   Unzipping corpora/wordnet.zip.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "cellView": "form",
        "id": "dJg91yvSwKi3",
        "outputId": "15c2285c-e8da-45b9-bdc2-9bab3709b270"
      },
      "source": [
        "#@title Load Model\n",
        "# Forwarded to Config below as `output_class`.\n",
        "selected_model = 'lookbook'\n",
        "\n",
        "# Load model\n",
        "from IPython.utils import io\n",
        "import torch\n",
        "import PIL\n",
        "import numpy as np\n",
        "import ipywidgets as widgets\n",
        "from PIL import Image\n",
        "import imageio\n",
        "from models import get_instrumented_model\n",
        "from decomposition import get_or_compute\n",
        "from config import Config\n",
        "from skimage import img_as_ubyte\n",
        "\n",
        "# Speed up computation\n",
        "# Gradient tracking is switched off globally, so everything below is inference only.\n",
        "torch.autograd.set_grad_enabled(False)\n",
        "torch.backends.cudnn.benchmark = True\n",
        "\n",
        "# Specify model to use\n",
        "config = Config(\n",
        "  model='StyleGAN2',\n",
        "  layer='style',\n",
        "  output_class=selected_model,\n",
        "  components=80,\n",
        "  use_w=True,\n",
        "  batch_size=5_000, # style layer quite small\n",
        ")\n",
        "\n",
        "# The device is hard-coded to CUDA, so this cell needs a GPU runtime.\n",
        "inst = get_instrumented_model(config.model, config.output_class,\n",
        "                              config.layer, torch.device('cuda'), use_w=config.use_w)\n",
        "\n",
        "# Path of the component archive for this config; it is opened with np.load below.\n",
        "# NOTE(review): presumably computed on first use and cached afterwards - confirm in\n",
        "# `decomposition.get_or_compute`.\n",
        "path_to_components = get_or_compute(config, inst)\n",
        "\n",
        "# Generator object used by the sampling helpers in the following cells.\n",
        "model = inst.model\n",
        "\n",
        "# The archive is opened lazily: every `comps[key]` lookup reads that array from the\n",
        "# file again, so each array is fetched exactly once below instead of once per row.\n",
        "comps = np.load(path_to_components)\n",
        "lst = comps.files\n",
        "\n",
        "# False -> use the 'lat_comp' / 'lat_stdev' entries; True -> 'act_comp' / 'act_stdev'.\n",
        "load_activations = False\n",
        "\n",
        "_prefix = 'act' if load_activations else 'lat'\n",
        "_comp_key, _stdev_key = f'{_prefix}_comp', f'{_prefix}_stdev'\n",
        "\n",
        "# One list entry per row of the stored array. A key that is missing from the archive\n",
        "# leaves the corresponding list empty, exactly as the original loops did.\n",
        "latent_dirs = list(comps[_comp_key]) if _comp_key in lst else []\n",
        "latent_stdevs = list(comps[_stdev_key]) if _stdev_key in lst else []"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
            "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n",
            "Downloading https://drive.google.com/uc?export=download&id=1-F-RMkbHUv_S_k-_olh43mu5rDUMGYKe\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "uCR_3Ghos2kK"
      },
      "source": [
        "#@title Define functions\n",
        "from ipywidgets import fixed\n",
        "\n",
        "# Taken from https://github.com/alexanderkuk/log-progress\n",
        "def log_progress(sequence, every=1, size=None, name='Items'):\n",
        "    \"\"\"Yield the items of `sequence` while updating an ipywidgets progress bar.\n",
        "\n",
        "    Args:\n",
        "        sequence: iterable to walk. If `size` is None and `len(sequence)` raises\n",
        "            TypeError, the bar is shown in the 'info' style and the label reads\n",
        "            'name: index / ?'.\n",
        "        every: the widgets are refreshed on the first item and on every item whose\n",
        "            1-based index is a multiple of `every`. Pass None to derive it from the\n",
        "            size (1 for sizes up to 200, otherwise int(size / 200)); None is rejected\n",
        "            by an assert when the size is unknown.\n",
        "        size: total number of items; defaults to `len(sequence)` when available.\n",
        "        name: prefix used in the text label.\n",
        "\n",
        "    Yields:\n",
        "        Each record of `sequence`, unchanged and in order.\n",
        "    \"\"\"\n",
        "    from ipywidgets import IntProgress, HTML, VBox\n",
        "    from IPython.display import display\n",
        "\n",
        "    is_iterator = False\n",
        "    if size is None:\n",
        "        try:\n",
        "            size = len(sequence)\n",
        "        except TypeError:\n",
        "            is_iterator = True\n",
        "    if size is not None:\n",
        "        if every is None:\n",
        "            if size <= 200:\n",
        "                every = 1\n",
        "            else:\n",
        "                every = int(size / 200)     # every 0.5%\n",
        "    else:\n",
        "        assert every is not None, 'sequence is iterator, set every'\n",
        "\n",
        "    if is_iterator:\n",
        "        # Unknown length: show a full bar in the 'info' style instead of a fraction.\n",
        "        progress = IntProgress(min=0, max=1, value=1)\n",
        "        progress.bar_style = 'info'\n",
        "    else:\n",
        "        progress = IntProgress(min=0, max=size, value=0)\n",
        "    label = HTML()\n",
        "    box = VBox(children=[label, progress])\n",
        "    display(box)\n",
        "\n",
        "    index = 0\n",
        "    try:\n",
        "        # Indices are 1-based so the label shows how many items have been reached.\n",
        "        for index, record in enumerate(sequence, 1):\n",
        "            if index == 1 or index % every == 0:\n",
        "                if is_iterator:\n",
        "                    label.value = '{name}: {index} / ?'.format(\n",
        "                        name=name,\n",
        "                        index=index\n",
        "                    )\n",
        "                else:\n",
        "                    progress.value = index\n",
        "                    label.value = u'{name}: {index} / {size}'.format(\n",
        "                        name=name,\n",
        "                        index=index,\n",
        "                        size=size\n",
        "                    )\n",
        "            yield record\n",
        "    except:\n",
        "        # Bare except on purpose: anything raised while iterating (this also covers\n",
        "        # the generator being closed early) marks the bar as failed, then propagates.\n",
        "        progress.bar_style = 'danger'\n",
        "        raise\n",
        "    else:\n",
        "        # Normal exhaustion: show the final count ('?' when nothing was yielded).\n",
        "        progress.bar_style = 'success'\n",
        "        progress.value = index\n",
        "        label.value = \"{name}: {index}\".format(\n",
        "            name=name,\n",
        "            index=str(index or '?')\n",
        "        )\n",
        "\n",
        "def name_direction(sender):\n",
        "  \"\"\"Store the current direction under the name held in `text.value` and save it.\n",
        "\n",
        "  `sender` is accepted but not used (presumably the widget that fired the callback).\n",
        "  The function relies on notebook globals that this cell does not define: `text`,\n",
        "  `num`, `named_directions`, `start_layer`, `end_layer` and `random_dir`. They must\n",
        "  exist before it is invoked.\n",
        "\n",
        "  Entries of `named_directions` are written as `[num, start_layer, end_layer]`, so an\n",
        "  earlier name for the same direction is detected by comparing the first element of\n",
        "  each entry with `num`.\n",
        "  \"\"\"\n",
        "  if not text.value:\n",
        "    print('Please name the direction before saving')\n",
        "    return\n",
        "\n",
        "  # The stored values are lists, so the former `num in named_directions.values()`\n",
        "  # test could never succeed and old names were never replaced. Match on entry[0].\n",
        "  stale_keys = [key for key, entry in named_directions.items() if entry[0] == num]\n",
        "  for target_key in stale_keys:\n",
        "    print(f'Direction already named: {target_key}')\n",
        "    print(f'Overwriting... ')\n",
        "    del named_directions[target_key]\n",
        "  named_directions[text.value] = [num, start_layer.value, end_layer.value]\n",
        "  save_direction(random_dir, text.value)\n",
        "  for item in named_directions:\n",
        "    print(item, named_directions[item])\n",
        "\n",
        "def save_direction(direction, filename):\n",
        "  \"\"\"Write `direction` to `<filename>.npy` via np.save and print the resulting path.\"\"\"\n",
        "  target_path = filename + \".npy\"\n",
        "  np.save(target_path, direction, allow_pickle=True, fix_imports=True)\n",
        "  print(f'Latent direction saved as {target_path}')\n",
        "\n",
        "def mix_w(w1, w2, content, style, content_end=5, style_end=16):\n",
        "    \"\"\"Blend two per-layer latent sequences and return the result as a new list.\n",
        "\n",
        "    For each blended layer the output is `w1[i] * (1 - t) + w2[i] * t`, where `t` is\n",
        "    `content` for layers `[0, content_end)` and `style` for layers\n",
        "    `[content_end, style_end)`. Layers from `style_end` onwards are copied from `w2`.\n",
        "\n",
        "    Args:\n",
        "        w1, w2: sequences holding one latent per layer.\n",
        "        content: blend weight for the first layer group (0 -> w1, 1 -> w2).\n",
        "        style: blend weight for the second layer group (0 -> w1, 1 -> w2).\n",
        "        content_end: index of the first layer of the second group (default 5).\n",
        "        style_end: index one past the last blended layer (default 16).\n",
        "\n",
        "    Returns:\n",
        "        A new list of per-layer latents; neither `w1` nor `w2` is modified.\n",
        "    \"\"\"\n",
        "    # Work on a copy: the previous version overwrote the caller's `w2` in place.\n",
        "    mixed = list(w2)\n",
        "    # Clamp to the layers that actually exist so inputs with fewer than `style_end`\n",
        "    # layers no longer raise IndexError.\n",
        "    n_blend = min(style_end, len(w1), len(mixed))\n",
        "    for i in range(n_blend):\n",
        "        weight = content if i < content_end else style\n",
        "        mixed[i] = w1[i] * (1 - weight) + mixed[i] * weight\n",
        "\n",
        "    return mixed\n",
        "\n",
        "def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):\n",
        "    \"\"\"Render one image from per-layer latents after moving them along `directions`.\n",
        "\n",
        "    Args:\n",
        "        seed: seed used to sample a latent when `w` is None; also part of the file\n",
        "            name when `save` is given.\n",
        "        truncation: value assigned to `model.truncation` before rendering.\n",
        "        directions: sequence of direction vectors added to the latents.\n",
        "        distances: per-direction multipliers, parallel to `directions`.\n",
        "        scale: extra multiplier applied to every direction.\n",
        "        start, end: the edit is applied to layers `[start, end)`; the range is clamped\n",
        "            to the layers that exist.\n",
        "        w: optional per-layer latents; each entry gets a leading axis added. When None,\n",
        "            one latent is sampled from `seed` and repeated for every layer.\n",
        "        disp: when true, the image is shown with `display`.\n",
        "        save: when not None, the image is written to `out/{seed}_{save:05}.png`.\n",
        "        noise_spec: unused; kept so existing calls keep working.\n",
        "\n",
        "    Returns:\n",
        "        The rendered PIL image, resized to 500x500.\n",
        "    \"\"\"\n",
        "    import os\n",
        "\n",
        "    model.truncation = truncation\n",
        "    if w is None:\n",
        "        w = model.sample_latent(1, seed=seed).detach().cpu().numpy()\n",
        "        w = [w]*model.get_max_latents() # one per layer\n",
        "    else:\n",
        "        w = [np.expand_dims(x, 0) for x in w]\n",
        "\n",
        "    # The layer bounds come straight from the UI: clamp them so an `end` past the last\n",
        "    # layer does not raise IndexError and a negative `start` does not wrap around.\n",
        "    first_layer = max(0, start)\n",
        "    last_layer = min(end, len(w))\n",
        "    for l in range(first_layer, last_layer):\n",
        "      for i in range(len(directions)):\n",
        "        w[l] = w[l] + directions[i] * distances[i] * scale\n",
        "\n",
        "    torch.cuda.empty_cache()\n",
        "    # Render and convert to an 8-bit image.\n",
        "    out = model.sample_np(w)\n",
        "    final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)\n",
        "\n",
        "    if save is not None:\n",
        "      if disp == False:\n",
        "        print(save)\n",
        "      # Saving used to fail when the output folder had not been created yet.\n",
        "      os.makedirs('out', exist_ok=True)\n",
        "      final_im.save(f'out/{seed}_{save:05}.png')\n",
        "    if disp:\n",
        "      display(final_im)\n",
        "\n",
        "    return final_im\n",
        "\n",
        "def generate_mov(seed, truncation, direction_vec, scale, layers, n_frames, out_name = 'out', noise_spec = None, loop=True):\n",
        "  \"\"\"Generates a mov moving back and forth along the chosen direction vector\n",
        "\n",
        "  The offset starts at -10 and grows by `20 / n_frames` per frame; each frame adds\n",
        "  `direction_vec * offset * scale` to the latents of the selected layers. The frames\n",
        "  are written to `out/{out_name}.mp4`.\n",
        "\n",
        "  Args:\n",
        "      seed: seed passed to `model.sample_latent` for every frame.\n",
        "      truncation: value assigned to `model.truncation`.\n",
        "      direction_vec: direction added to the selected layers.\n",
        "      scale: multiplier applied to the direction.\n",
        "      layers: iterable of layer indices to edit; indices past the last layer are skipped.\n",
        "      n_frames: number of frames to render (must be non-zero).\n",
        "      out_name: base name of the output file inside `out/`.\n",
        "      noise_spec: unused; kept so existing calls keep working.\n",
        "      loop: when true, the frames are appended again in reverse order.\n",
        "  \"\"\"\n",
        "  import os\n",
        "\n",
        "  # Plain Python instead of the `%mkdir` magic: works outside IPython and does not\n",
        "  # complain when the folder already exists.\n",
        "  os.makedirs('out', exist_ok=True)\n",
        "  movieName = f'out/{out_name}.mp4'\n",
        "  offset = -10\n",
        "  step = 20 / n_frames\n",
        "  imgs = []\n",
        "  model.truncation = truncation\n",
        "  n_latents = model.get_max_latents()\n",
        "  for i in log_progress(range(n_frames), name = \"Generating frames\"):\n",
        "    print(f'\\r{i} / {n_frames}', end='')\n",
        "    w = model.sample_latent(1, seed=seed).cpu().numpy()\n",
        "\n",
        "    w = [w]*n_latents # one per layer\n",
        "    for l in layers:\n",
        "      # Valid indices are 0 .. n_latents - 1; the former `<=` let l == n_latents\n",
        "      # through and raised IndexError.\n",
        "      if l < n_latents:\n",
        "          w[l] = w[l] + direction_vec * offset * scale\n",
        "\n",
        "    # Render the frame and keep the raw array for the video writer.\n",
        "    out = model.sample_np(w)\n",
        "    imgs.append(out)\n",
        "    #increase offset\n",
        "    offset += step\n",
        "  if loop:\n",
        "    imgs += imgs[::-1]\n",
        "  with imageio.get_writer(movieName, mode='I') as writer:\n",
        "    for image in log_progress(list(imgs), name = \"Creating animation\"):\n",
        "        writer.append_data(img_as_ubyte(image))"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 640
        },
        "cellView": "form",
        "id": "jneXxZnNwHo5",
        "outputId": "c8e2b76e-3a00-47f5-ba2d-51606f09ee93"
      },
      "source": [
        "#@title Demo UI\n",
        "import gradio as gr\n",
        "import numpy as np\n",
        "\n",
        "def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):\n",
        "    \"\"\"Callback for the demo UI: preview two seeds and render their edited mix.\n",
        "\n",
        "    Args:\n",
        "        seed1, seed2: seeds of the two source latents (converted with int()).\n",
        "        content: blend weight passed to `mix_w` as `content`.\n",
        "        style: blend weight passed to `mix_w` as `style`.\n",
        "        truncation: value assigned to `model.truncation`.\n",
        "        c0..c6: distances along `latent_dirs[0]` .. `latent_dirs[6]`.\n",
        "        start_layer, end_layer: layer range the directions are applied to\n",
        "            (converted with int()).\n",
        "\n",
        "    Returns:\n",
        "        A pair `(input_im, output_im)`: the two source images side by side, and the\n",
        "        image produced by `display_sample_pytorch` for the mixed latents.\n",
        "    \"\"\"\n",
        "    seed1 = int(seed1)\n",
        "    seed2 = int(seed2)\n",
        "\n",
        "    scale = 1\n",
        "    # Slider cN drives latent_dirs[N].\n",
        "    distances = [c0, c1, c2, c3, c4, c5, c6]\n",
        "    directions = [latent_dirs[i] for i in range(len(distances))]\n",
        "\n",
        "    # Set the truncation before rendering the previews; they used to be drawn with\n",
        "    # whatever value the previous call had left on the model.\n",
        "    model.truncation = truncation\n",
        "\n",
        "    w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy()\n",
        "    w1 = [w1]*model.get_max_latents() # one per layer\n",
        "    im1 = model.sample_np(w1)\n",
        "\n",
        "    w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy()\n",
        "    w2 = [w2]*model.get_max_latents() # one per layer\n",
        "    im2 = model.sample_np(w2)\n",
        "    # Source images side by side (joined along axis 1).\n",
        "    combined_im = np.concatenate([im1, im2], axis=1)\n",
        "    input_im = Image.fromarray((combined_im * 255).astype(np.uint8))\n",
        "\n",
        "    mixed_w = mix_w(w1, w2, content, style)\n",
        "    return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)\n",
        "\n",
        "# Range shared by the seven direction sliders.\n",
        "slider_min_val, slider_max_val = -20, 20\n",
        "slider_step = 1\n",
        "\n",
        "def _unit_slider(caption):\n",
        "    # Slider over [0, 1] that starts in the middle.\n",
        "    return gr.inputs.Slider(label=caption, minimum=0, maximum=1, default=0.5)\n",
        "\n",
        "def _direction_slider(caption):\n",
        "    # Slider over the shared direction range that starts at 0 (no edit).\n",
        "    return gr.inputs.Slider(label=caption, minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
        "\n",
        "truncation = _unit_slider(\"Truncation\")\n",
        "start_layer = gr.inputs.Number(label=\"Start Layer\", default=0)\n",
        "end_layer = gr.inputs.Number(label=\"End Layer\", default=14)\n",
        "seed1 = gr.inputs.Number(label=\"Seed 1\", default=0)\n",
        "seed2 = gr.inputs.Number(label=\"Seed 2\", default=0)\n",
        "content = _unit_slider(\"Structure\")\n",
        "style = _unit_slider(\"Style\")\n",
        "\n",
        "# One slider per direction, in the order generate_image expects (c0 .. c6).\n",
        "c0, c1, c2, c3, c4, c5, c6 = [\n",
        "    _direction_slider(caption)\n",
        "    for caption in (\"Sleeve & Size\", \"Dress - Jacket\", \"Female Coat\", \"Coat\",\n",
        "                    \"Graphics\", \"Dark\", \"Less Cleavage\")\n",
        "]\n",
        "\n",
        "scale = 1\n",
        "\n",
        "# Order must match the parameter order of generate_image.\n",
        "inputs = [seed1, seed2, content, style, truncation,\n",
        "          c0, c1, c2, c3, c4, c5, c6,\n",
        "          start_layer, end_layer]\n",
        "\n",
        "demo = gr.Interface(generate_image, inputs, [\"image\", \"image\"], live=True, title=\"ClothingGAN\")\n",
        "demo.launch()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
            "This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n",
            "Running on External URL: https://10342.gradio.app\n",
            "Interface loading below...\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "        <iframe\n",
              "            width=\"900\"\n",
              "            height=\"500\"\n",
              "            src=\"https://10342.gradio.app\"\n",
              "            frameborder=\"0\"\n",
              "            allowfullscreen\n",
              "        ></iframe>\n",
              "        "
            ],
            "text/plain": [
              "<IPython.lib.display.IFrame at 0x7f3d693d3bd0>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<Flask 'gradio.networking'>,\n",
              " 'http://127.0.0.1:7860/',\n",
              " 'https://10342.gradio.app')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    }
  ]
}