{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "synthesize_demo",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qJDJLE3v0HNr"
   },
   "source": [
    "# Fetch Codebase and Install Environment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "qy1nwGJV5JuG"
   },
   "source": [
    "import os\n",
    "\n",
    "# The repository is cloned beneath `/content` and then becomes the working\n",
    "# directory, so that the `models` and `utils` packages are importable below.\n",
    "os.chdir('/content')\n",
    "CODE_DIR = 'GenForce'\n",
    "# Clone only once: re-running this cell must not fail on an existing checkout.\n",
    "if not os.path.isdir(CODE_DIR):\n",
    "    !git clone https://github.com/genforce/genforce.git $CODE_DIR\n",
    "os.chdir(f'./{CODE_DIR}')\n",
    "# pip's output is redirected to a file to keep the cell output short.\n",
    "!pip install -r requirements.txt > installation_output.txt"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qh5DFyyg0Ntm"
   },
   "source": [
    "# Define Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "qcSdJW5V0M-8"
   },
   "source": [
    "import os\n",
    "import subprocess\n",
    "import io\n",
    "import IPython.display\n",
    "import numpy as np\n",
    "import PIL.Image\n",
    "\n",
    "import torch\n",
    "\n",
    "from models import MODEL_ZOO\n",
    "from models import build_generator\n",
    "from utils.visualizer import fuse_images\n",
    "\n",
    "\n",
    "def postprocess(images):\n",
    "    \"\"\"Post-processes images from `torch.Tensor` to `numpy.ndarray`.\n",
    "\n",
    "    Values are mapped linearly so that -1 becomes 0 and 1 becomes 255, rounded\n",
    "    to the nearest integer and clipped to [0, 255].\n",
    "\n",
    "    Args:\n",
    "        images: A 4-D `torch.Tensor` (presumably NCHW -- confirm against the\n",
    "            generator's output layout).\n",
    "\n",
    "    Returns:\n",
    "        A `numpy.ndarray` of dtype `uint8` whose axes are the input axes\n",
    "        permuted as (0, 2, 3, 1).\n",
    "    \"\"\"\n",
    "    images = images.detach().cpu().numpy()\n",
    "    images = (images + 1) * 255 / 2\n",
    "    # Adding 0.5 before the truncating integer cast rounds to nearest.\n",
    "    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)\n",
    "    images = images.transpose(0, 2, 3, 1)\n",
    "    return images\n",
    "\n",
    "\n",
    "def build(model_name):\n",
    "    \"\"\"Builds generator and loads pre-trained weights.\n",
    "\n",
    "    The checkpoint is cached at `checkpoints/<model_name>.pth` and downloaded\n",
    "    with `wget` only when that file does not exist yet.\n",
    "\n",
    "    Args:\n",
    "        model_name: Key of `MODEL_ZOO`. The entry must contain a `url` item;\n",
    "            all remaining items are forwarded to `build_generator`.\n",
    "\n",
    "    Returns:\n",
    "        The generator with weights loaded, switched to evaluation mode, and\n",
    "        placed on the GPU when one is available (otherwise on the CPU).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `model_name` is not in `MODEL_ZOO`.\n",
    "        subprocess.CalledProcessError: If the checkpoint download fails.\n",
    "    \"\"\"\n",
    "    model_config = MODEL_ZOO[model_name].copy()\n",
    "    url = model_config.pop('url')  # URL to download model if needed.\n",
    "\n",
    "    # Build generator.\n",
    "    print(f'Building generator for model `{model_name}` ...')\n",
    "    generator = build_generator(**model_config)\n",
    "    print('Finish building generator.')\n",
    "\n",
    "    # Load pre-trained weights.\n",
    "    os.makedirs('checkpoints', exist_ok=True)\n",
    "    checkpoint_path = os.path.join('checkpoints', model_name + '.pth')\n",
    "    print(f'Loading checkpoint from `{checkpoint_path}` ...')\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f' Downloading checkpoint from `{url}` ...')\n",
    "        # Download to a temporary name and rename only on success. Otherwise a\n",
    "        # failed or interrupted download would leave a truncated file at\n",
    "        # `checkpoint_path`, which every later run would treat as a valid\n",
    "        # cached checkpoint and fail to load.\n",
    "        partial_path = checkpoint_path + '.part'\n",
    "        try:\n",
    "            subprocess.check_call(['wget', '--quiet', '-O', partial_path, url])\n",
    "            os.replace(partial_path, checkpoint_path)\n",
    "        finally:\n",
    "            if os.path.exists(partial_path):\n",
    "                os.remove(partial_path)\n",
    "        print(' Finish downloading checkpoint.')\n",
    "    # NOTE: `torch.load` unpickles the file, so only URLs listed in a trusted\n",
    "    # `MODEL_ZOO` should ever be loaded here.\n",
    "    checkpoint = torch.load(checkpoint_path, map_location='cpu')\n",
    "    # Prefer the `generator_smooth` weights when the checkpoint provides them.\n",
    "    if 'generator_smooth' in checkpoint:\n",
    "        generator.load_state_dict(checkpoint['generator_smooth'])\n",
    "    else:\n",
    "        generator.load_state_dict(checkpoint['generator'])\n",
    "    # Fall back to the CPU instead of crashing when no GPU is visible.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    generator = generator.to(device)\n",
    "    generator.eval()\n",
    "    print('Finish loading checkpoint.')\n",
    "    return generator\n",
    "\n",
    "\n",
    "def synthesize(generator, num, synthesis_kwargs=None, batch_size=1, seed=0):\n",
    "    \"\"\"Synthesizes images from randomly sampled latent codes.\n",
    "\n",
    "    Args:\n",
    "        generator: Module exposing `z_space_dim` and returning a dict with an\n",
    "            `image` entry when called with a batch of latent codes.\n",
    "        num: Total number of images to synthesize. Must be positive.\n",
    "        synthesis_kwargs: Optional dict of extra keyword arguments forwarded\n",
    "            to every `generator` call.\n",
    "        batch_size: Maximum number of images per forward pass. Must be\n",
    "            positive.\n",
    "        seed: Seed applied to both NumPy and PyTorch before sampling.\n",
    "\n",
    "    Returns:\n",
    "        A `numpy.ndarray` holding the post-processed images of all batches,\n",
    "        concatenated along the first axis.\n",
    "    \"\"\"\n",
    "    assert num > 0 and batch_size > 0\n",
    "    synthesis_kwargs = synthesis_kwargs or dict()\n",
    "\n",
    "    # Set random seed.\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Latent codes must live on the same device as the generator's weights.\n",
    "    device = next(generator.parameters()).device\n",
    "\n",
    "    # Sample and synthesize.\n",
    "    outputs = []\n",
    "    for idx in range(0, num, batch_size):\n",
    "        # The last batch may be smaller than `batch_size`.\n",
    "        batch = min(batch_size, num - idx)\n",
    "        # Codes are sampled on the CPU and then moved to `device`.\n",
    "        code = torch.randn(batch, generator.z_space_dim).to(device)\n",
    "        with torch.no_grad():\n",
    "            images = generator(code, **synthesis_kwargs)['image']\n",
    "            images = postprocess(images)\n",
    "        outputs.append(images)\n",
    "    return np.concatenate(outputs, axis=0)\n",
    "\n",
    "\n",
    "def imshow(images, viz_size=256, col=0, spacing=0):\n",
    "    \"\"\"Shows images in one figure.\n",
    "\n",
    "    The images are fused into a single picture with `fuse_images`, encoded as\n",
    "    JPEG in memory and displayed inline.\n",
    "\n",
    "    Args:\n",
    "        images: Images to show, forwarded to `fuse_images`.\n",
    "        viz_size: Forwarded to `fuse_images` as `image_size`.\n",
    "        col: Forwarded to `fuse_images` as `col` (the meaning of the default\n",
    "            0 is defined by `fuse_images` -- see its documentation).\n",
    "        spacing: Forwarded to `fuse_images` as both `row_spacing` and\n",
    "            `col_spacing`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `IPython.display.display` returns for the shown image.\n",
    "    \"\"\"\n",
    "    fused_image = fuse_images(\n",
    "        images,\n",
    "        col=col,\n",
    "        image_size=viz_size,\n",
    "        row_spacing=spacing,\n",
    "        col_spacing=spacing\n",
    "    )\n",
    "    fused_image = np.asarray(fused_image, dtype=np.uint8)\n",
    "    data = io.BytesIO()\n",
    "    PIL.Image.fromarray(fused_image).save(data, 'jpeg')\n",
    "    im_data = data.getvalue()\n",
    "    disp = IPython.display.display(IPython.display.Image(im_data))\n",
    "    return disp"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rIrseINa879H"
   },
   "source": [
    "# Select a Model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "RyoJmv-PtZo_"
   },
   "source": [
    "#@title { display-mode: \"form\", run: \"auto\" }\n",
    "\n",
    "model_name = \"stylegan_diningroom256\" #@param ['pggan_celebahq1024', 'pggan_bedroom256', 'pggan_livingroom256', 'pggan_diningroom256', 'pggan_kitchen256', 'pggan_churchoutdoor256', 'pggan_tower256', 'pggan_bridge256', 'pggan_restaurant256', 'pggan_classroom256', 'pggan_conferenceroom256', 'pggan_person256', 'pggan_cat256', 'pggan_dog256', 'pggan_bird256', 'pggan_horse256', 'pggan_sheep256', 'pggan_cow256', 'pggan_car256', 'pggan_bicycle256', 'pggan_motorbike256', 'pggan_bus256', 'pggan_train256', 'pggan_boat256', 'pggan_airplane256', 'pggan_bottle256', 'pggan_chair256', 'pggan_pottedplant256', 'pggan_tvmonitor256', 'pggan_diningtable256', 'pggan_sofa256', 'stylegan_ffhq1024', 'stylegan_celebahq1024', 'stylegan_bedroom256', 'stylegan_cat256', 'stylegan_car512', 'stylegan_celeba_partial256', 'stylegan_ffhq256', 'stylegan_ffhq512', 'stylegan_livingroom256', 'stylegan_diningroom256', 'stylegan_kitchen256', 'stylegan_apartment256', 'stylegan_church256', 'stylegan_tower256', 'stylegan_bridge256', 'stylegan_restaurant256', 'stylegan_classroom256', 'stylegan_conferenceroom256', 'stylegan_animeface512', 'stylegan_animeportrait512', 'stylegan_artface512', 'stylegan2_ffhq1024', 'stylegan2_church256', 'stylegan2_cat256', 'stylegan2_horse256', 'stylegan2_car512']\n",
    "\n",
    "generator = build(model_name)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RsGPMc5E8_jn"
   },
   "source": [
    "# Synthesize"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "jPkIXKxp4-7L"
   },
   "source": [
    "#@title { display-mode: \"form\", run: \"auto\" }\n",
    "\n",
    "num_samples = 8 #@param {type:\"slider\", min:1, max:20, step:1}\n",
    "noise_seed = 488 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
    "truncation = 1 #@param {type:\"slider\", min:0.0, max:1, step:0.02}\n",
    "truncation_layers = 3 #@param {type:\"slider\", min:0, max:18, step:1}\n",
    "randomize_noise = 'false' #@param ['true', 'false']\n",
    "\n",
    "# The dropdown above yields the *string* 'true' or 'false'. A non-empty string\n",
    "# is always truthy, so it is converted to a real bool before it is passed on;\n",
    "# otherwise selecting 'false' would still enable noise randomization.\n",
    "synthesis_kwargs = dict(trunc_psi=1 - truncation,\n",
    "                        trunc_layers=truncation_layers,\n",
    "                        randomize_noise=(randomize_noise == 'true'))\n",
    "images = synthesize(generator, num_samples, synthesis_kwargs, seed=noise_seed)\n",
    "imshow(images)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}