|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import subprocess |
|
from pathlib import Path |
|
|
|
import einops |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
from torch import nn |
|
from torchvision.utils import save_image |
|
|
|
|
|
class Generator(nn.Module):
    """DCGAN-style generator that decodes latent vectors into small images.

    Four transposed-convolution stages upsample a ``(N, nz, 1, 1)`` latent
    batch to ``(N, nc, 24, 24)`` (spatial size 1 -> 3 -> 5 -> 12 -> 24).
    Every stage except the last is followed by BatchNorm and ReLU; the last
    is followed by Tanh, so outputs lie in [-1, 1].

    All layers live in one ``nn.Sequential`` called ``network``, in a fixed
    order, because pretrained state dicts address them by index
    (``network.0.weight``, ``network.1.running_mean``, ...).

    Args:
        nc: Number of output channels.
        nz: Length of the latent vector (input channels of the first stage).
        ngf: Base feature width; the hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()

        def upsample(in_ch, out_ch, kernel, stride, padding):
            # Bias-free transposed conv (BatchNorm's affine shift makes a conv
            # bias redundant), then BatchNorm and an in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        self.network = nn.Sequential(
            *upsample(nz, ngf * 4, 3, 1, 0),        # 1x1   -> 3x3
            *upsample(ngf * 4, ngf * 2, 3, 2, 1),   # 3x3   -> 5x5
            *upsample(ngf * 2, ngf, 4, 2, 0),       # 5x5   -> 12x12
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 12x12 -> 24x24
            nn.Tanh(),
        )

    def forward(self, input):
        """Decode the latent batch ``input`` through ``network``."""
        return self.network(input)
|
|
|
|
|
# Build the generator once at import time and load the pretrained weights
# fetched from the Hugging Face Hub repo 'nateraw/cryptopunks-gan'.
model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
# NOTE(review): torch.load unpickles the checkpoint, and unpickling can run
# code embedded in the file; this one comes from a third-party Hub repo.
# On torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
# The model is only used for inference, so put its BatchNorm layers in eval
# mode to use the running statistics learned during training. In the default
# train mode, each generated grid would be normalised with its own batch
# statistics (every punk would depend on the others in the batch), and each
# forward pass would overwrite the loaded running statistics.
model.eval()
|
|
|
|
|
@torch.no_grad()
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8, out_path='out.gif'):
    """Render a looping GIF that morphs one random grid of punks into another.

    Two batches of ``rows * cols`` latent vectors are drawn from torch's global
    RNG, so seeding it beforehand (as ``predict`` does) makes the result
    reproducible. ``frames`` evenly spaced linear blends, from the first batch
    to exactly the second, are decoded by the module-level ``model``. Each
    batch is tiled into a ``rows`` x ``cols`` grid and saved as a PNG in
    ``save_dir``. The sequence is then played forward and back so the GIF
    loops smoothly. The PNGs are assembled with ImageMagick's ``convert``,
    which must be on PATH.

    Args:
        save_dir: Directory for the intermediate PNG frames; created if missing.
        frames: Number of frames from the first grid to the second (>= 1).
        rows: Number of punk rows per frame.
        cols: Number of punk columns per frame.
        out_path: Path the GIF is written to.

    Returns:
        ``out_path``.

    Raises:
        ValueError: If ``frames`` is less than 1.
        FileNotFoundError: If the ``convert`` executable cannot be found.
        subprocess.CalledProcessError: If ``convert`` exits with an error.
    """
    if frames < 1:
        raise ValueError(f"frames must be at least 1, got {frames}")

    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    # Latents for the grids at the start and at the end of the morph.
    z1 = torch.randn(rows * cols, 100, 1, 1)
    z2 = torch.randn(rows * cols, 100, 1, 1)

    # Divide by (frames - 1) so the last frame lands exactly on z2.
    steps = max(frames - 1, 1)
    zs = [(1 - i / steps) * z1 + (i / steps) * z2 for i in range(frames)]
    # Play back in reverse for a seamless loop. Both endpoints are left out
    # so the turnaround frames are not shown twice in a row.
    zs += zs[-2:0:-1]

    frame_paths = []
    for i, z in enumerate(zs):
        imgs = model(z)
        # The generator ends in Tanh; map [-1, 1] to [0, 1], then to uint8 NHWC.
        imgs = (imgs + 1) / 2
        imgs = (imgs.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # Tile the rows*cols batch into a single (rows*H, cols*W, C) image.
        grid = einops.rearrange(imgs, "(b1 b2) h w c -> (b1 h) (b2 w) c", b1=rows, b2=cols)
        frame_path = save_dir / f"{i:03}.png"
        Image.fromarray(grid).save(frame_path)
        frame_paths.append(str(frame_path))

    # Pass the frames explicitly rather than a "*.png" glob. This means:
    # - frames left in save_dir by an earlier, longer run cannot leak in;
    # - playback order does not depend on filename sorting;
    # - paths with spaces are not split apart.
    # check=True makes a failed conversion raise instead of silently leaving
    # a stale GIF from a previous call at out_path.
    subprocess.run(
        ["convert", "-dispose", "previous", "-delay", "10", "-loop", "0",
         *frame_paths, str(out_path)],
        check=True,
    )
    return out_path
|
|
|
|
|
def predict(choice, seed):
    """Gradio callback: produce either an interpolation GIF or a still grid of punks.

    Args:
        choice: ``'interpolation'`` to render the looping GIF via ``interpolate``;
            any other value generates a single image of 64 random punks.
        seed: Seed for torch's global RNG, so the same seed reproduces the same output.

    Returns:
        The path of the file that was written: ``'out.gif'`` or ``'punks.png'``.
    """
    torch.manual_seed(seed)

    if choice == 'interpolation':
        interpolate()
        return 'out.gif'

    z = torch.randn(64, 100, 1, 1)
    # Inference only: don't build an autograd graph. This matches
    # interpolate, which is wrapped in @torch.no_grad().
    with torch.no_grad():
        punks = model(z)
    # normalize=True rescales the images to [0, 1] using the batch's min/max
    # before saving them as a tiled grid.
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'
|
|
|
|
|
# Input widgets: what kind of output to generate, and the RNG seed that makes it reproducible.
_inputs = [
    gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
    gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
]

_description = (
    "These CryptoPunks do not exist. You have the choice of either generating "
    "random punks, or a gif showing the interpolation between two random punk grids."
)

# Footer with links to the DCGAN paper and the training repository.
_article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning "
    "with Deep Convolutional Generative Adversarial Networks</a>"
    " | "
    "<a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a>"
    "</p>"
)

# Each example is [choice, seed], matching predict's parameters.
_examples = [["interpolation", 123], ["interpolation", 42], ["image", 456], ["image", 42]]

demo = gr.Interface(
    predict,
    inputs=_inputs,
    outputs="image",
    title="Cryptopunks GAN",
    description=_description,
    article=_article,
    examples=_examples,
)
demo.launch(cache_examples=True)
|
|