# Hugging Face Space file listing (Space status: Running).
# File size: 4,631 Bytes
# Commit hashes from the blame view: ea2f505 4a6a9ac c9000d8 ea2f505 c293bd2
# ea2f505 c9000d8 4a6a9ac c9000d8 4a6a9ac c9000d8 ea2f505
from PIL import Image
import numpy as np
import gradio as gr
import spaces
import torch
from tqdm import tqdm
from controlnet import QRControlNet
from game_of_life import GameOfLife
from utils import resize_image, generate_image_from_grid
@spaces.GPU
def init_controlnet(device: str) -> QRControlNet:
    """Build the QRControlNet wrapper on ``device`` and hand it back.

    The ``spaces.GPU`` decorator requests GPU access for this call when
    running on Hugging Face Spaces.
    """
    model = QRControlNet(device=device)
    return model
def generate_all_images(
    gol_grids: list[np.ndarray],
    source_image: Image.Image,
    num_inference_steps: int,
    controlnet_conditioning_scale: float,
    strength: float,
    prompt: str,
    negative_prompt: str,
    seed: int,
    guidance_scale: float,
    img_size: int,
) -> list[Image.Image]:
    """Run the ControlNet once per Game of Life grid and collect the frames.

    Each grid is inverted (``1 - grid``), rendered to an image with
    ``generate_image_from_grid`` and used as the control image for
    ``controlnet.generate_image``. Every frame starts from the same resized
    ``source_image`` and the same ``seed``.

    Args:
        gol_grids: Game of Life states, one per output frame. Presumably 0/1
            numpy arrays (the code computes ``1 - grid``); TODO confirm.
        source_image: Source image for the model. It is resized with
            ``resize_image(..., resolution=img_size)`` before use.
        num_inference_steps: Inference steps per frame. Coerced to ``int``.
        controlnet_conditioning_scale: ControlNet conditioning scale.
            Coerced to ``float``.
        strength: Passed through to ``controlnet.generate_image``.
        prompt: Text prompt for every frame.
        negative_prompt: Negative text prompt for every frame.
        seed: Seed used for every frame. Coerced to ``int``; ``None`` is
            passed through unchanged.
        guidance_scale: Passed through to ``controlnet.generate_image``.
        img_size: Resolution in pixels for the resized source image and the
            rendered control images. Coerced to ``int``.

    Returns:
        One generated image per entry of ``gol_grids``, in the same order.

    Note:
        Uses the module-level ``controlnet`` model. It must be initialised
        before this function is called.
    """
    # UI widgets can hand numbers over as floats. Coerce to the types the
    # model call expects, mirroring the float() cast of the conditioning scale.
    controlnet_conditioning_scale = float(controlnet_conditioning_scale)
    num_inference_steps = int(num_inference_steps)
    img_size = int(img_size)
    seed = None if seed is None else int(seed)

    source_image = resize_image(source_image, resolution=img_size)

    images = []
    for grid in tqdm(gol_grids):
        # Invert the grid (1 - grid) before rendering it as the control image.
        grid_inverse = 1 - grid
        grid_inverse_image = generate_image_from_grid(grid_inverse, img_size=img_size)
        image = controlnet.generate_image(
            source_image=source_image,
            control_image=grid_inverse_image,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            strength=strength,
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            guidance_scale=guidance_scale,
            img_size=img_size,
        )
        images.append(image)
    return images
def make_gif(
    images: "list[Image.Image]",
    gif_path,
    *,
    duration: int = 200,
    loop: int = 0,
):
    """Save ``images`` as an animated GIF at ``gif_path`` and return the path.

    Args:
        images: Frames in display order. The first frame's ``save`` writes
            the file and the remaining frames are appended to it.
        gif_path: Destination passed to ``Image.save``.
        duration: Display time of each frame in milliseconds. The default of
            200 is the value this function always used.
        loop: Loop count written to the GIF; 0 means loop forever.

    Returns:
        ``gif_path``, unchanged, so the result can be handed straight to a
        Gradio output.

    Raises:
        ValueError: If ``images`` is empty, because there is no first frame
            to write.
    """
    if not images:
        raise ValueError("make_gif needs at least one image")
    first, *rest = images
    first.save(
        gif_path,
        save_all=True,  # write every frame, not just the first one
        append_images=rest,
        duration=duration,
        loop=loop,
    )
    return gif_path
@spaces.GPU(duration=120)
def generate(
    source_image,
    prompt,
    negative_prompt,
    seed,
    num_inference_steps,
    num_gol_steps,
    gol_grid_dim,
    img_size,
    controlnet_conditioning_scale,
    strength,
    guidance_scale,
):
    """Gradio callback: simulate Game of Life and render it as two GIFs.

    Steps:
      1. Seed a random ``gol_grid_dim`` x ``gol_grid_dim`` Game of Life board
         and advance it ``num_gol_steps`` steps. ``p=0.5`` is presumably the
         initial live-cell probability; confirm in ``GameOfLife``.
      2. Render every recorded state directly and save the frames as
         ``gol_original.gif``.
      3. Use every state as a ControlNet control image via
         ``generate_all_images`` and save the results as
         ``gol_controlnet.gif``.

    Both GIFs are written to fixed relative paths, so each call overwrites
    the files from the previous call.

    Returns:
        ``(path_gol_controlnet, path_gol_gif)``, in the same order as the
        Interface outputs ``[output_controlnet, output_gol]``.
    """
    # Number widgets can deliver integer-valued fields as floats. These
    # values are used as a seed, step counts and sizes, so make them ints.
    seed = None if seed is None else int(seed)
    num_inference_steps = int(num_inference_steps)
    num_gol_steps = int(num_gol_steps)
    gol_grid_dim = int(gol_grid_dim)
    img_size = int(img_size)

    # Compute the Game of Life first.
    gol = GameOfLife()
    gol.set_random_state(dim=(gol_grid_dim, gol_grid_dim), p=0.5, seed=seed)
    gol.generate_n_steps(n=num_gol_steps)
    gol_grids = gol.game_history

    # Generate the GIF for the original Game of Life.
    gol_images = [
        generate_image_from_grid(grid, img_size=img_size) for grid in gol_grids
    ]
    path_gol_gif = make_gif(gol_images, "gol_original.gif")

    # Generate the GIF for the ControlNet-rendered Game of Life.
    controlnet_images = generate_all_images(
        gol_grids=gol_grids,
        source_image=source_image,
        num_inference_steps=num_inference_steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        guidance_scale=guidance_scale,
        img_size=img_size,
    )
    path_gol_controlnet = make_gif(controlnet_images, "gol_controlnet.gif")
    return path_gol_controlnet, path_gol_gif
# Device for the ControlNet pipeline. Alternatives used during development:
# "mps" (PyTorch's Apple Metal backend) or "cpu".
device = "cuda"
print(f"Using {device=}")
# Module-level model instance; generate_all_images reads this global.
controlnet = init_controlnet(device=device)

# --- Input / output widgets -------------------------------------------------
source_image = gr.Image(label="Source Image", type="pil", value="sky-gol-image.jpeg")
output_controlnet = gr.Image(label="ControlNet Game of Life")
output_gol = gr.Image(label="Original Game of Life")
prompt = gr.Textbox(
    label="Prompt", value="clear sky with clouds, high quality, background 4k"
)
negative_prompt = gr.Textbox(
    label="Negative Prompt",
    value="ugly, disfigured, low quality, blurry, nsfw, qr code",
)
# precision=0 makes Gradio round these fields and deliver them as ints, since
# they are used as a seed, step counts and pixel/grid sizes.
seed = gr.Number(label="Seed", value=42, precision=0)
num_inference_steps = gr.Number(
    label="Controlnet Inference Steps", value=50, precision=0
)
num_gol_steps = gr.Slider(
    label="Number of Game of Life Steps",
    minimum=2,
    maximum=100,
    step=1,
    value=40,
)
gol_grid_dim = gr.Number(
    label="Game of Life Grid Dimension",
    value=10,
    precision=0,
)
img_size = gr.Number(label="Image Size (pixels)", value=512, precision=0)
controlnet_conditioning_scale = gr.Slider(
    label="Controlnet Conditioning Scale", minimum=0.1, maximum=10.0, value=2.0
)
strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.9)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=100, value=20)

# --- App ----------------------------------------------------------------------
# The order of `inputs` must match generate()'s positional parameters, and
# `outputs` must match the order of its returned tuple.
demo = gr.Interface(
    fn=generate,
    inputs=[
        source_image,
        prompt,
        negative_prompt,
        seed,
        num_inference_steps,
        num_gol_steps,
        gol_grid_dim,
        img_size,
        controlnet_conditioning_scale,
        strength,
        guidance_scale,
    ],
    outputs=[output_controlnet, output_gol],
)
demo.launch()
# (end of file listing)