Spaces:
Runtime error
Runtime error
File size: 4,957 Bytes
a324479 169ec0c a324479 e6915e1 a324479 169ec0c 892096a e6915e1 892096a 169ec0c a324479 892096a e6915e1 a324479 e6915e1 169ec0c e6915e1 169ec0c 892096a e6915e1 892096a e6915e1 169ec0c e6915e1 169ec0c 892096a 169ec0c 892096a 169ec0c a324479 169ec0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import load_image
import jax.numpy as jnp
import numpy as np
import gc
# Load the ControlNet model + weights from the "mfidabel/controlnet-segment-anything"
# checkpoint, then the Stable Diffusion v1.5 Flax pipeline wrapping that ControlNet.
# Both run in float32. This happens once, at module import time.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything", dtype=jnp.float32
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
)
# Add ControlNet params and Replicate
# The pipeline's params dict receives the ControlNet weights under the
# "controlnet" key. The combined tree is then passed through flax's `replicate`
# because `infer` calls the pipeline with jit=True and sharded inputs
# (presumably one copy per local device for the parallel call — confirm).
params["controlnet"] = controlnet_params
p_params = replicate(params)
# Page text (rendered as Markdown) and the example inputs offered to users.
# Each example row is [prompt, negative prompt, path to the condition image].
title = "# 🧨 ControlNet on Segment Anything 🤗"
description = (
    "This is a demo on 🧨 ControlNet based on Meta's "
    "[Segment Anything Model](https://segment-anything.com/).\n"
    "Upload a Segment Anything Segmentation Map, write a prompt, and generate images 🤗 "
    "This demo is still Work in Progress, so don't expect it to work well for now !!\n"
    "Test some of the examples below to give it a try ⬇️\n"
)
examples = [
    [
        "contemporary living room of a house",
        "low quality",
        "examples/condition_image_1.png",
    ],
    [
        "new york buildings, Vincent Van Gogh starry night ",
        "low quality, monochrome",
        "examples/condition_image_2.png",
    ],
    [
        "contemporary living room, high quality, 4k, realistic",
        "low quality, monochrome, low res",
        "examples/condition_image_3.png",
    ],
]
# Inference Function
def infer(prompts, negative_prompts, image, num_inference_steps=50, seed=4, num_samples=4):
    """Generate images with the ControlNet pipeline, conditioned on ``image``.

    Args:
        prompts: Text prompt; repeated once per generated sample.
        negative_prompts: Negative text prompt; repeated once per sample.
        image: Conditioning image as an array accepted by ``PIL.Image.fromarray``
            (presumably the HxWx3 uint8 array from the ``gr.Image`` input — confirm).
        num_inference_steps: Number of denoising steps; coerced with ``int``.
        seed: Seed for ``jax.random.PRNGKey``; coerced with ``int``.
        num_samples: Requested number of images (may arrive as a float from a
            gradio slider). It is raised to at least ``jax.device_count()`` and
            rounded up to a multiple of it, because the batch is sharded evenly
            across all devices.

    Returns:
        A list of ``num_samples`` uint8 image arrays. On any failure the error
        and its traceback are printed and a list of ``num_samples`` black
        512x512x3 images is returned instead, so the UI always gets a result.
    """
    import traceback

    n_devices = jax.device_count()
    # Resolve the batch size before the ``try`` so the error fallback below can
    # always build an integer-length list, even if the failure happens early.
    num_samples = max(int(num_samples), 1)
    # ``shard`` splits the leading axis evenly over devices, so round up to a
    # multiple of the device count (never fewer than one sample per device).
    num_samples = -(-num_samples // n_devices) * n_devices
    try:
        rng = jax.random.PRNGKey(int(seed))
        num_inference_steps = int(num_inference_steps)
        # convert() handles grayscale/RGBA arrays instead of misreading the
        # buffer, and avoids Pillow's deprecated ``mode`` argument.
        image = Image.fromarray(image).convert("RGB")
        p_rng = jax.random.split(rng, n_devices)

        prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
        negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
        processed_image = shard(pipe.prepare_image_inputs([image] * num_samples))

        output = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=p_rng,
            num_inference_steps=num_inference_steps,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
        # Release the sharded inputs before post-processing to lower peak memory.
        del negative_prompt_ids, processed_image, prompt_ids

        # Collapse the leading (device, per-device) axes into one batch axis,
        # keeping the trailing (H, W, C) — presumed layout of jit=True output.
        output = output.reshape((num_samples,) + output.shape[-3:])
        frames = np.asarray(output)
        # Pipeline images are presumably floats in [0, 1]; clip so stray
        # out-of-range values cannot wrap around in the uint8 cast.
        final_image = [(np.clip(x, 0.0, 1.0) * 255).astype(np.uint8) for x in frames]
        print(output.shape)
        del output, frames
    except Exception as e:
        # Deliberate best-effort fallback for the UI: report, then return black frames.
        print("Error: " + str(e))
        traceback.print_exc()
        final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
    finally:
        gc.collect()
    return final_image
# Build the input/output components up front (they are placed into the layout
# later via .render()), pre-filled from the third example row:
# [prompt, negative prompt, condition image path].
default_example = examples[2]
cond_img = gr.Image(
    label="Input", shape=(512, 512), value=default_example[2]
).style(height=200)
output = gr.Gallery(label="Generated images").style(
    height=200, rows=[2], columns=[1, 2], object_fit="contain"
)
prompt = gr.Textbox(
    label="Prompt", value=default_example[0], lines=1
)
negative_prompt = gr.Textbox(
    label="Negative Prompt", value=default_example[1], lines=1
)
# Assemble the page layout and wire the Generate button to `infer`.
# The components created earlier (cond_img, output, prompt, negative_prompt)
# are placed into the layout with .render().
# NOTE(review): indentation was lost in the copy this was reviewed from; the
# nesting of the `with` blocks below is a reconstruction — confirm against the
# original file.
# The css centers h1 elements; `title` is a Markdown "# " heading.
with gr.Blocks(css="h1 { text-align: center }") as demo:
    with gr.Row():
        with gr.Column():
            # Title
            gr.Markdown(title)
            # Description
            gr.Markdown(description)
        with gr.Column():
            # Examples
            # Only the first three `infer` inputs are listed here. Whenever
            # gradio calls fn for an example, steps/seed/num_samples fall back
            # to infer's defaults (50, 4, 4).
            gr.Examples(examples=examples,
                        inputs=[prompt, negative_prompt, cond_img],
                        outputs=output,
                        fn=infer)
    # Images: condition image (wider column) next to the generated gallery
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            cond_img.render()
        with gr.Column(scale=1):
            output.render()
    # Prompts, advanced options and the Generate button (no clear button yet)
    with gr.Row():
        with gr.Column():
            prompt.render()
            negative_prompt.render()
        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                # Positional Slider args are presumably (minimum, maximum, value)
                # per gradio's Slider signature — confirm for the pinned version.
                num_steps = gr.Slider(10, 60, 50, step=1, label="Steps")
                seed = gr.Slider(0, 1024, 4, step=1, label="Seed")
                # infer raises this to at least jax.device_count(), so more
                # images than selected can come back on multi-device hosts.
                num_samples = gr.Slider(1, 4, 4, step=1, label="Nº Samples")
            submit = gr.Button("Generate")
    # TODO: Download Button
    # The inputs list maps positionally onto infer's parameters
    # (prompts, negative_prompts, image, num_inference_steps, seed, num_samples).
    submit.click(infer,
                 inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
                 outputs = output)
# Start the server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module has no server side effect.
if __name__ == "__main__":
    # Enable gradio's request queue, then launch the app.
    demo.queue()
    demo.launch()