|
from PIL import Image |
|
import numpy as np |
|
import gradio as gr |
|
import spaces |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from controlnet import QRControlNet |
|
from game_of_life import GameOfLife |
|
from utils import resize_image, generate_image_from_grid |
|
|
|
|
|
# The `spaces` package exposes its decorator as `GPU` (the same spelling used
# on `generate`); the lowercase `spaces.gpu` attribute does not exist.
@spaces.GPU
def init_controlnet(device: str) -> QRControlNet:
    """Construct the ``QRControlNet`` wrapper for the given device.

    Args:
        device: Device identifier, forwarded unchanged to ``QRControlNet``.

    Returns:
        A newly created ``QRControlNet`` instance.
    """
    return QRControlNet(device=device)
|
|
|
|
|
def generate_all_images(
    gol_grids: list[np.ndarray],
    source_image: Image.Image,
    num_inference_steps: int,
    controlnet_conditioning_scale: float,
    strength: float,
    prompt: str,
    negative_prompt: str,
    seed: int,
    guidance_scale: float,
    img_size: int,
) -> list:
    """Generate one ControlNet image for every Game of Life grid.

    Relies on the module-level ``controlnet`` object, which must be
    initialised before this function is called.

    Args:
        gol_grids: Grids to render, in order. Each grid is inverted with
            ``1 - grid`` before being turned into the control image.
        source_image: Base image; resized via ``resize_image`` to
            ``img_size`` before generation.
        num_inference_steps: Forwarded to ``controlnet.generate_image``.
        controlnet_conditioning_scale: Converted to ``float`` and forwarded.
        strength: Forwarded to ``controlnet.generate_image``.
        prompt: Forwarded to ``controlnet.generate_image``.
        negative_prompt: Forwarded to ``controlnet.generate_image``.
        seed: Forwarded to ``controlnet.generate_image``; the same seed is
            used for every grid.
        guidance_scale: Forwarded to ``controlnet.generate_image``.
        img_size: Resolution used for both the resized source image and the
            control image built from each grid.

    Returns:
        The results of ``controlnet.generate_image``, one per grid, in the
        same order as ``gol_grids`` (presumably PIL images -- confirm
        against ``QRControlNet.generate_image``).
    """
    controlnet_conditioning_scale = float(controlnet_conditioning_scale)
    source_image = resize_image(source_image, resolution=img_size)

    images = []
    for grid in tqdm(gol_grids):
        # The control image is built from the inverted grid.
        grid_inverse = 1 - grid
        grid_inverse_image = generate_image_from_grid(grid_inverse, img_size=img_size)

        image = controlnet.generate_image(
            source_image=source_image,
            control_image=grid_inverse_image,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            strength=strength,
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            guidance_scale=guidance_scale,
            img_size=img_size,
        )
        images.append(image)

    return images
|
|
|
|
|
def make_gif(
    images: "list[Image.Image]",
    gif_path,
    *,
    duration: int = 200,
    loop: int = 0,
):
    """Write a sequence of frames to an animated GIF and return its path.

    Args:
        images: Frames in playback order. Any iterable is accepted; it is
            materialised into a list first. The first frame's ``save``
            method writes the file and the rest are passed as
            ``append_images``.
        gif_path: Destination handed to ``save``.
        duration: Per-frame display time forwarded to ``save`` (Pillow's GIF
            writer interprets it as milliseconds). Defaults to 200.
        loop: Loop count forwarded to ``save`` (for Pillow's GIF writer,
            0 means loop forever). Defaults to 0.

    Returns:
        ``gif_path``, unchanged.

    Raises:
        ValueError: If ``images`` contains no frames.
    """
    frames = list(images)
    if not frames:
        # Fail with an explicit message instead of an IndexError on frames[0].
        raise ValueError("make_gif() needs at least one image to write a GIF")

    first_frame, *remaining_frames = frames
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=loop,
    )
    return gif_path
|
|
|
|
|
@spaces.GPU(duration=120)
def generate(
    source_image,
    prompt,
    negative_prompt,
    seed,
    num_inference_steps,
    num_gol_steps,
    gol_grid_dim,
    img_size,
    controlnet_conditioning_scale,
    strength,
    guidance_scale,
):
    """Run a Game of Life simulation and render it as two animated GIFs.

    One GIF shows the plain Game of Life frames; the other shows the
    ControlNet images produced by ``generate_all_images`` for the same grids.

    Args:
        source_image: Base image for the ControlNet generation.
        prompt: Text prompt forwarded to ``generate_all_images``.
        negative_prompt: Negative prompt forwarded to ``generate_all_images``.
        seed: Seed used for both the random initial grid and image generation.
        num_inference_steps: Forwarded to ``generate_all_images``.
        num_gol_steps: Number of Game of Life steps to simulate.
        gol_grid_dim: Side length of the square Game of Life grid.
        img_size: Size passed to ``generate_image_from_grid`` and
            ``generate_all_images``.
        controlnet_conditioning_scale: Forwarded to ``generate_all_images``.
        strength: Forwarded to ``generate_all_images``.
        guidance_scale: Forwarded to ``generate_all_images``.

    Returns:
        A tuple ``(controlnet_gif_path, game_of_life_gif_path)``.

    Raises:
        gr.Error: If no source image was supplied.
    """
    if source_image is None:
        # Surface a readable message in the UI instead of failing deep inside
        # the image-resizing helper.
        raise gr.Error("Please provide a source image.")

    # Numeric UI widgets may deliver floats; these values are used as grid
    # dimensions, step counts, a seed and a pixel size, so make them ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    num_gol_steps = int(num_gol_steps)
    gol_grid_dim = int(gol_grid_dim)
    img_size = int(img_size)

    # Simulate the Game of Life from a seeded random square grid.
    # p=0.5 is presumably the initial alive-probability per cell -- confirm
    # against GameOfLife.set_random_state.
    gol = GameOfLife()
    gol.set_random_state(dim=(gol_grid_dim, gol_grid_dim), p=0.5, seed=seed)
    gol.generate_n_steps(n=num_gol_steps)
    gol_grids = gol.game_history

    # Plain rendering of each grid.
    gol_images = [
        generate_image_from_grid(grid, img_size=img_size) for grid in gol_grids
    ]
    path_gol_gif = make_gif(gol_images, "gol_original.gif")

    # ControlNet rendering of the same grids.
    controlnet_images = generate_all_images(
        gol_grids=gol_grids,
        source_image=source_image,
        num_inference_steps=num_inference_steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        guidance_scale=guidance_scale,
        img_size=img_size,
    )
    path_gol_controlnet = make_gif(controlnet_images, "gol_controlnet.gif")

    return path_gol_controlnet, path_gol_gif
|
|
|
|
|
# Device string handed to the ControlNet wrapper.
device = "cuda"

print(f"Using {device=}")

# init_controlnet() declares `device` as a required parameter; calling it
# with no arguments raises TypeError, so pass the device explicitly.
controlnet = init_controlnet(device)
|
|
|
|
|
# --- Input / output components for the Gradio interface -------------------

source_image = gr.Image(label="Source Image", type="pil", value="sky-gol-image.jpeg")

output_controlnet = gr.Image(label="ControlNet Game of Life")
output_gol = gr.Image(label="Original Game of Life")

prompt = gr.Textbox(
    label="Prompt", value="clear sky with clouds, high quality, background 4k"
)
negative_prompt = gr.Textbox(
    label="Negative Prompt",
    value="ugly, disfigured, low quality, blurry, nsfw, qr code",
)

# precision=0 makes gr.Number round to the nearest integer and return an int;
# these values are used as a seed, step count, grid dimension and pixel size.
seed = gr.Number(label="Seed", value=42, precision=0)
num_inference_steps = gr.Number(
    label="Controlnet Inference Steps", value=50, precision=0
)
num_gol_steps = gr.Slider(
    label="Number of Game of Life Steps",
    minimum=2,
    maximum=100,
    step=1,
    value=40,
)
gol_grid_dim = gr.Number(
    label="Game of Life Grid Dimension",
    value=10,
    precision=0,
)

img_size = gr.Number(label="Image Size (pixels)", value=512, precision=0)
controlnet_conditioning_scale = gr.Slider(
    label="Controlnet Conditioning Scale", minimum=0.1, maximum=10.0, value=2.0
)
strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.9)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=100, value=20)
|
|
|
|
|
# Components are listed in the same order as the positional parameters of
# `generate`, because gr.Interface passes input values positionally.
_interface_inputs = [
    source_image,
    prompt,
    negative_prompt,
    seed,
    num_inference_steps,
    num_gol_steps,
    gol_grid_dim,
    img_size,
    controlnet_conditioning_scale,
    strength,
    guidance_scale,
]

# `generate` returns (controlnet_gif_path, game_of_life_gif_path).
_interface_outputs = [output_controlnet, output_gol]

demo = gr.Interface(
    fn=generate,
    inputs=_interface_inputs,
    outputs=_interface_outputs,
)

demo.launch()
|
|