rom utils import download_url |
import argparse |
import numpy as np |
import PIL.Image |
import dnnlib |
import dnnlib.tflib as tflib |
import re |
import sys |
from io import BytesIO |
import IPython.display |
from math import ceil |
from PIL import Image, ImageDraw |
import os |
import pickle |
from utils import log_progress, imshow, create_image_grid, show_animation |
import imageio |
import glob |
import gdown |
import gradio as gr |
class Rasm:
    """Generate images from a pickled generator network fetched from Google Drive.

    The downloaded pickle is unpacked as a ``(G, D, Gs)`` triple; ``Gs`` is the
    network used for every generation method below.

    NOTE(review): the dnnlib/tflib calls look like the StyleGAN2 TensorFlow
    API -- confirm against the checkpoints actually referenced here.
    """

    def __init__(self, mode = 'calligraphy'):
        """Download the checkpoint selected by *mode* and load it.

        Args:
            mode: ``'calligraphy'`` selects the first checkpoint URL; any other
                value selects the second one.
        """
        if mode == 'calligraphy':
            url = 'https://drive.google.com/uc?id=138fdURGxdkOwZq7IWvnrGLcfo5VI8O1R'
        else:
            url = 'https://drive.google.com/uc?id=13h-alXGI0hbNOJy1qbmeoroXZSPBHEG2'
        output = 'model.pkl'
        print('Downloading networks from "%s"...' %url)
        gdown.download(url, output, quiet=False)
        dnnlib.tflib.init_tf()
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. This is only acceptable while the Drive URLs above are trusted.
        with dnnlib.util.open_url(output) as fp:
            self._G, self._D, self.Gs = pickle.load(fp)
        # Noise variables of the synthesis network; they are overwritten with
        # fixed-seed values before every Gs.run() call (see _run_generator).
        self.noise_vars = [
            var
            for name, var in self.Gs.components.synthesis.vars.items()
            if name.startswith('noise')
        ]

    @staticmethod
    def _base_run_kwargs():
        """Return the keyword arguments shared by every network ``run`` call."""
        Gs_kwargs = dnnlib.EasyDict()
        Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        Gs_kwargs.randomize_noise = False
        return Gs_kwargs

    @staticmethod
    def _per_image_psi(truncation_psi, count):
        """Return *truncation_psi* as a list, repeating a scalar *count* times."""
        if isinstance(truncation_psi, list):
            return truncation_psi
        return [truncation_psi] * count

    def _run_generator(self, z, label, Gs_kwargs):
        """Reset the noise variables, run ``Gs`` on one latent and return a PIL image."""
        # A fresh RandomState(1) per image gives every image the same noise.
        noise_rnd = np.random.RandomState(1)
        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars})
        images = self.Gs.run(z, label, **Gs_kwargs)
        return PIL.Image.fromarray(images[0], 'RGB')

    def generate_images_in_w_space(self, dlatents, truncation_psi):
        """Run the synthesis network on each entry of *dlatents*.

        Args:
            dlatents: iterable of inputs accepted by ``Gs.components.synthesis.run``.
            truncation_psi: value forwarded to the synthesis call as ``truncation_psi``.

        Returns:
            A list with one RGB ``PIL.Image`` per entry of *dlatents*.
        """
        Gs_kwargs = self._base_run_kwargs()
        Gs_kwargs.truncation_psi = truncation_psi
        imgs = []
        for _, dlatent in log_progress(enumerate(dlatents), name = "Generating images"):
            row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs)
            imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
        return imgs

    def generate_images(self, zs, truncation_psi, class_idx = None):
        """Generate one image per latent in *zs*.

        Args:
            zs: sized sequence of latents accepted by ``Gs.run``.
            truncation_psi: a single value applied to every latent, or a list
                indexed in step with *zs*.
            class_idx: if given, a one-hot label with this index set is passed
                to the network; otherwise the label is ``None``.

        Returns:
            A list with one RGB ``PIL.Image`` per latent.
        """
        Gs_kwargs = self._base_run_kwargs()
        truncation_psi = self._per_image_psi(truncation_psi, len(zs))
        label = None
        if class_idx is not None:
            # Only build the label array when it is actually used.
            label = np.zeros([1] + self.Gs.input_shapes[1][1:])
            label[:, class_idx] = 1
        imgs = []
        for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
            Gs_kwargs.truncation_psi = truncation_psi[z_idx]
            imgs.append(self._run_generator(z, label, Gs_kwargs))
        return imgs

    def generate_from_zs(self, zs, truncation_psi = 0.5):
        """Generate an image for each latent in *zs* and display it with ``imshow``.

        Images are shown one by one as they are produced; nothing is returned.
        """
        Gs_kwargs = self._base_run_kwargs()
        truncation_psi = self._per_image_psi(truncation_psi, len(zs))
        for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
            Gs_kwargs.truncation_psi = truncation_psi[z_idx]
            imshow(self._run_generator(z, None, Gs_kwargs))

    def generate_random_zs(self, size):
        """Return *size* latents, each drawn from a freshly sampled random seed."""
        seeds = np.random.randint(2**32, size=size)
        return self.generate_zs_from_seeds(seeds)

    def generate_zs_from_seeds(self, seeds):
        """Return one latent of shape ``(1, *Gs.input_shape[1:])`` per seed.

        Each latent is drawn with ``np.random.RandomState(seed).randn``, so the
        same seed always yields the same latent.
        """
        latent_dims = self.Gs.input_shape[1:]
        return [np.random.RandomState(seed).randn(1, *latent_dims) for seed in seeds]

    def generate_images_from_seeds(self, seeds, truncation_psi):
        """Generate images for *seeds* and display the first one.

        Returns:
            ``(image, imshow(image))`` where ``image`` is the first generated
            image only.
        """
        ima = self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)[0]
        return ima, imshow(ima)

    def generate_randomly(self, truncation_psi = 0.5):
        """Generate and display a single image from one random seed.

        Returns:
            The ``(image, imshow(image))`` pair from ``generate_images_from_seeds``.
        """
        ima, dis = self.generate_images_from_seeds(np.random.randint(4294967295, size=1), truncation_psi=truncation_psi)
        return ima, dis

    def generate_grid(self, truncation_psi = 0.7):
        """Generate nine random images and combine them with ``create_image_grid``.

        NOTE(review): the ``0.7`` and ``3`` passed to ``create_image_grid`` are
        presumably a scale factor and a row count -- confirm in ``utils``.
        """
        seeds = np.random.randint((2**32 - 1), size=9)
        images = self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)
        return create_image_grid(images, 0.7 , 3)

    def generate_animation(self, size = 9, steps = 10, trunc_psi = 0.5):
        """Write a looping interpolation between random latents to ``animation.mp4``.

        Args:
            size: number of random key latents.
            steps: number of interpolated frames between consecutive key latents.
            trunc_psi: truncation value used for every frame.

        Returns:
            Whatever ``show_animation('animation.mp4')`` returns.
        """
        seeds = list(np.random.randint((2**32) - 1, size=size))
        # Append the first seed again so the sequence ends where it started.
        seeds = seeds + [seeds[0]]
        zs = self.generate_zs_from_seeds(seeds)
        imgs = self.generate_images(self.interpolate(zs, steps = steps), trunc_psi)
        movie_name = 'animation.mp4'
        with imageio.get_writer(movie_name, mode='I') as writer:
            for image in log_progress(imgs, name = "Creating animation"):
                writer.append_data(np.array(image))
        return show_animation(movie_name)

    def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9):
        """Map *latent* through the mapping network and truncate the leading layers.

        The first ``truncation_cutoff`` layers of the mapped latent are pulled
        towards ``dlatent_avg`` by ``truncation_psi``; the rest are untouched.

        Args:
            latent: input passed to ``Gs.components.mapping.run``.
            truncation_psi: interpolation factor towards the average.
            truncation_cutoff: number of leading layers to truncate. Values
                larger than the layer count are clamped, negative values
                truncate nothing, and ``None`` truncates every layer.

        Returns:
            The mapped latent array, modified in place.
        """
        # assumes mapping.run returns an ndarray shaped (1, layers, dim) -- TODO confirm
        dlatent = self.Gs.components.mapping.run(latent, None)
        dlatent_avg = self.Gs.get_var('dlatent_avg')
        num_layers = dlatent.shape[1]
        if truncation_cutoff is None:
            cutoff = num_layers
        else:
            # Clamp so a cutoff beyond the layer count no longer raises
            # IndexError and a negative cutoff stays a no-op.
            cutoff = max(0, min(truncation_cutoff, num_layers))
        dlatent[0, :cutoff] = (dlatent[0, :cutoff] - dlatent_avg) * truncation_psi + dlatent_avg
        return dlatent

    def interpolate(self, zs, steps = 10):
        """Linearly interpolate between consecutive entries of *zs*.

        For every adjacent pair, *steps* frames are produced starting at the
        first entry of the pair; the second entry itself is not emitted for
        that pair, so the very last entry of *zs* never appears in the output.

        Returns:
            A flat list of ``(len(zs) - 1) * steps`` interpolated values.
        """
        out = []
        for start, end in zip(zs[:-1], zs[1:]):
            for index in range(steps):
                fraction = index/float(steps)
                out.append(end*fraction + start*(1-fraction))
        return out
def model(mode, output):
    """Build a ``Rasm`` generator for *mode* and produce the requested output.

    Args:
        mode: checkpoint selector forwarded to ``Rasm(mode=...)``.
        output: one of ``'Generate Art Randomly'``, ``'Generate Art Grid'`` or
            ``'Generate Art Animation'``.

    Returns:
        The image from ``generate_randomly``, the grid from ``generate_grid``,
        or the result of ``generate_animation(size=2, steps=20)``.

    Raises:
        ValueError: if *output* is not one of the supported options.
    """
    supported = ('Generate Art Randomly', 'Generate Art Grid', 'Generate Art Animation')
    # Validate before constructing Rasm, which downloads and loads a checkpoint.
    if output not in supported:
        raise ValueError('Unknown output option: %r' % (output,))
    # Rasm is defined in this file; there is no ``rasm`` module in scope.
    generator = Rasm(mode=mode)
    if output == 'Generate Art Randomly':
        ima, _ = generator.generate_randomly()
    elif output == 'Generate Art Grid':
        ima = generator.generate_grid()
    else:
        # NOTE(review): generate_animation returns show_animation(...)'s
        # result, which may not be an image -- confirm it suits an Image output.
        ima = generator.generate_animation(size = 2, steps = 20)
    return ima
# Gradio front-end: two radio inputs (art type, output style) feed ``model``
# and the result is rendered as an image.
# The original statement passed the Interface arguments to gr.outputs.Image,
# referenced ``imageout`` inside its own initialiser, and never defined
# ``demo``; the output component and the Interface are now built separately.
imageout = gr.outputs.Image()
demo = gr.Interface(
    model,
    [
        # NOTE(review): "Arbic" in the label looks like a typo for "Arabic";
        # left unchanged here because it is user-facing text.
        gr.Radio(["calligraphy", "mosaics"],label="Type of Arbic Art"),
        gr.Radio(["Generate Art Randomly", "Generate Art Grid", "Generate Art Animation"],label="How do you prefer the output visualization" ),
    ],
    outputs=imageout
)
demo.launch()