{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Figure: BigGAN edit transferability between classes\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "\n",
    "# Random seed in [0, int32 max) drawn from NumPy's global RNG; handy for exploring new seeds\n",
    "def rand():\n",
    "    return np.random.randint(np.iinfo(np.int32).max)\n",
    "\n",
    "# Output directory for the saved figure and its individual frames\n",
    "outdir = Path('out') / 'figures' / 'edit_transferability'\n",
    "makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Model, layer and class shared by the instrumented model and the PCA config below,\n",
    "# defined once so the two cannot drift apart.\n",
    "MODEL_NAME = 'BigGAN-512'\n",
    "LAYER = 'generator.gen_z'\n",
    "OUTPUT_CLASS = 'husky'\n",
    "\n",
    "# Pass along any existing `inst` (presumably reused by get_instrumented_model on re-runs - confirm);\n",
    "# globals().get avoids a NameError on a fresh kernel if notebook_init does not define it.\n",
    "inst = get_instrumented_model(MODEL_NAME, OUTPUT_CLASS, LAYER, device, inst=globals().get('inst'))\n",
    "model = inst.model\n",
    "model.truncation = 0.7\n",
    "\n",
    "pc_config = Config(components=80, n=1_000_000,\n",
    "    layer=LAYER, model=MODEL_NAME, output_class=OUTPUT_CLASS)\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "    lat_std = data['lat_stdev']\n",
    "\n",
    "# name: (component_idx, layer_start, layer_end (exclusive), strength as a multiple of lat_std)\n",
    "edits = {\n",
    "    'translate_x': ( 0, 0, 15, -3.0),\n",
    "    'zoom':        ( 6, 0, 15, 2.0),\n",
    "    'clouds':      (54, 7, 10, 15.0),\n",
    "    #'dark_fg':     (51, 7, 10, 20.0),\n",
    "    'sunlight':    (33, 7, 10, 25.0),\n",
    "    #'silhouette':  (13, 7, 10, -20.0),\n",
    "    #'grass_bg':    (69, 3,  7, -20.0),\n",
    "}\n",
    "\n",
    "def apply_offset(z, idx, start, end, sigma):\n",
    "    # Return a list of per-layer latents with component lat_comp[idx] added to layers [start, end).\n",
    "    # z: one latent (repeated model.get_max_latents() times) or an existing list of per-layer latents.\n",
    "    # sigma: edit strength, as a multiple of lat_std[idx].\n",
    "    # Copy list inputs: assigning into `z` itself would silently modify the caller's latents.\n",
    "    lat = list(z) if isinstance(z, list) else [z]*model.get_max_latents()\n",
    "    offset = lat_comp[idx]*lat_std[idx]*sigma  # loop-invariant\n",
    "    for i in range(start, end):\n",
    "        lat[i] = lat[i] + offset\n",
    "    return lat\n",
    "\n",
    "show = True  # True: display the grid inline; False: save the grid and individual frames to outdir\n",
    "\n",
    "# good geom seeds: 2145371585\n",
    "# good style seeds: 337336281, 2075156369, 311784160\n",
    "\n",
    "def render_edits(z, class_name, edit_names):\n",
    "    # Switch the model to class_name, then render z unedited followed by each named edit\n",
    "    # from `edits`, returning the images in that order.\n",
    "    model.set_output_class(class_name)\n",
    "    return [model.sample_np(z)] + [model.sample_np(apply_offset(z, *edits[name])) for name in edit_names]\n",
    "\n",
    "def to_uint8(img):\n",
    "    # Scale a float image to uint8. Clip first: casting out-of-range floats to uint8 does\n",
    "    # not saturate. Assumes sample_np returns values in [0, 1] - TODO confirm.\n",
    "    return (255*np.clip(img, 0, 1)).astype(np.uint8)\n",
    "\n",
    "# Type 1: geometric edit - transfers well\n",
    "seed1_geom = 2145371585\n",
    "seed2_geom = 2046317118\n",
    "print('Seeds geom:', [seed1_geom, seed2_geom])\n",
    "z1 = model.sample_latent(1, seed=seed1_geom).cpu().numpy()\n",
    "z2 = model.sample_latent(1, seed=seed2_geom).cpu().numpy()\n",
    "geom_husky = render_edits(z1, 'husky', ['zoom', 'translate_x'])\n",
    "geom_castle = render_edits(z2, 'castle', ['zoom', 'translate_x'])\n",
    "\n",
    "# Type 2: style edit - often transfers\n",
    "seed1_style = 417482011  # replace with rand() to explore new seeds\n",
    "seed2_style = 1026291813\n",
    "print('Seeds style:', [seed1_style, seed2_style])\n",
    "z1 = model.sample_latent(1, seed=seed1_style).cpu().numpy()\n",
    "z2 = model.sample_latent(1, seed=seed2_style).cpu().numpy()\n",
    "style_lighthouse = render_edits(z2, 'lighthouse', ['clouds', 'sunlight'])\n",
    "style_barn = render_edits(z1, 'barn', ['clouds', 'sunlight'])\n",
    "\n",
    "# Rows: husky, castle (geometric edits), then barn, lighthouse (style edits)\n",
    "grid = np.vstack([np.hstack(row) for row in [geom_husky, geom_castle, style_barn, style_lighthouse]])\n",
    "\n",
    "if show:\n",
    "    plt.figure(figsize=(12,12))\n",
    "    plt.imshow(grid)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "else:\n",
    "    Image.fromarray(to_uint8(grid)).save(outdir / f'{seed1_geom}_{seed2_geom}_{seed1_style}_{seed2_style}_transf.jpg')\n",
    "\n",
    "    # Save individual frames as <row>_<column>.png, columns numbered from 1\n",
    "    frames = {\n",
    "        'geom_husky': geom_husky,\n",
    "        'geom_castle': geom_castle,\n",
    "        'style_lighthouse': style_lighthouse,\n",
    "        'style_barn': style_barn,\n",
    "    }\n",
    "    for prefix, row in frames.items():\n",
    "        for col, img in enumerate(row, start=1):\n",
    "            Image.fromarray(to_uint8(img)).save(outdir / f'{prefix}_{col}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}