File size: 4,957 Bytes
a324479
 
 
 
 
 
169ec0c
a324479
 
e6915e1
a324479
 
 
 
 
 
 
 
 
 
 
 
 
 
169ec0c
 
892096a
 
 
 
 
 
 
 
e6915e1
892096a
 
169ec0c
a324479
 
892096a
e6915e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a324479
 
e6915e1
 
 
 
 
 
 
 
 
 
 
169ec0c
 
e6915e1
 
 
 
 
 
 
 
 
 
 
 
 
169ec0c
 
 
892096a
e6915e1
892096a
e6915e1
169ec0c
 
 
 
e6915e1
 
169ec0c
 
 
 
892096a
169ec0c
 
892096a
 
169ec0c
 
 
 
 
a324479
169ec0c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import load_image
import jax.numpy as jnp
import numpy as np
import gc


# Segment Anything-conditioned ControlNet weights, loaded in float32.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything",
    dtype=jnp.float32,
)

# Stable Diffusion v1.5 (Flax revision) pipeline wired to the ControlNet above.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="flax",
    controlnet=controlnet,
    dtype=jnp.float32,
)

# Attach the ControlNet weights under the "controlnet" key of the pipeline's
# parameter tree, then replicate the whole tree across local devices so it can
# be passed to the pipeline together with sharded inputs.
params["controlnet"] = controlnet_params
p_params = replicate(params)

# Description
title = "# 🧨 ControlNet on Segment Anything 🤗"

# Markdown shown next to the title. Every line is kept flush-left on purpose:
# in Markdown, a paragraph indented by four or more spaces after a blank line
# is rendered as a preformatted code block rather than as text.
description = """This is a demo of 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).

Upload a Segment Anything segmentation map, write a prompt, and generate images 🤗 This demo is still a work in progress, so don't expect it to work well for now!

Test some of the examples below to give it a try ⬇️
"""

# Each example is [prompt, negative prompt, path to a conditioning image].
examples = [
    ["contemporary living room of a house", "low quality", "examples/condition_image_1.png"],
    ["new york buildings,  Vincent Van Gogh starry night ", "low quality, monochrome", "examples/condition_image_2.png"],
    ["contemporary living room,  high quality, 4k, realistic", "low quality, monochrome, low res", "examples/condition_image_3.png"],
]


def _round_up_samples(requested, num_devices):
    """Return the smallest positive multiple of ``num_devices`` that is >= ``requested``.

    ``shard`` splits the leading batch axis evenly across devices, so the batch
    size must be divisible by the device count. Non-numeric input falls back to
    one sample per device.
    """
    try:
        requested = int(requested)
    except (TypeError, ValueError):
        requested = 1
    requested = max(requested, 1)
    # Ceiling division, then scale back up to a whole number of per-device batches.
    return -(-requested // num_devices) * num_devices


# Inference Function
def infer(prompts, negative_prompts, image, num_inference_steps = 50, seed = 4, num_samples = 4):
    """Generate images from a prompt, conditioned on a Segment Anything map.

    Args:
        prompts: Text prompt, repeated once per generated sample.
        negative_prompts: Negative text prompt, repeated once per sample.
        image: Conditioning image as an array accepted by ``PIL.Image.fromarray``
            (as delivered by the ``gr.Image`` input), or ``None`` if nothing was uploaded.
        num_inference_steps: Number of denoising steps; coerced with ``int``.
        seed: Seed for ``jax.random.PRNGKey``; coerced with ``int``.
        num_samples: Requested number of images. Rounded up to a multiple of
            ``jax.device_count()`` because the batch is sharded across devices.

    Returns:
        A list of ``uint8`` image arrays, one per generated sample. On any
        ``Exception`` the error is printed and the same number of black
        512x512 RGB images is returned instead, so the gallery still renders.
    """
    num_devices = jax.device_count()
    # Resolved before the try so the error fallback below always has a valid int.
    num_samples = _round_up_samples(num_samples, num_devices)
    try:
        if image is None:
            raise ValueError("no conditioning image provided")
        rng = jax.random.PRNGKey(int(seed))
        num_inference_steps = int(num_inference_steps)
        # convert("RGB") also normalizes grayscale or RGBA uploads to 3 channels.
        image = Image.fromarray(np.asarray(image)).convert("RGB")
        # One PRNG key per device, matching the sharded inputs below.
        p_rng = jax.random.split(rng, num_devices)

        prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
        negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
        processed_image = shard(pipe.prepare_image_inputs([image] * num_samples))

        output = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=p_rng,
            num_inference_steps=num_inference_steps,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images

        # Drop references to the large inputs before building the result.
        del negative_prompt_ids, processed_image, prompt_ids

        # Flatten the leading (device, per-device batch) axes into one batch axis.
        output = np.asarray(output)
        output = output.reshape((num_samples,) + output.shape[-3:])
        # Assumes pixel values in [0, 1] (as the original `x*255` implies);
        # clip and round instead of truncating when converting to uint8.
        final_image = [np.clip(x * 255, 0, 255).round().astype(np.uint8) for x in output]
        del output
    except Exception as e:
        print("Error: " + str(e))
        final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
    finally:
        gc.collect()
    # Returning outside `finally` so BaseExceptions (e.g. KeyboardInterrupt) are not swallowed.
    return final_image
    

# The third example pre-populates the UI: [prompt, negative prompt, condition image path].
default_example = examples[2]

# These components are created outside the Blocks context and placed into the
# layout later through their `.render()` calls.
cond_img = gr.Image(
    label="Input",
    shape=(512, 512),
    value=default_example[2],
).style(height=200)

output = gr.Gallery(label="Generated images").style(
    height=200,
    rows=[2],
    columns=[1, 2],
    object_fit="contain",
)

prompt = gr.Textbox(label="Prompt", lines=1, value=default_example[0])
negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, value=default_example[1])


# Page layout. Components created earlier are placed wherever their `.render()`
# call appears below, so the statement order here determines the layout.
with gr.Blocks(css="h1 { text-align: center }") as demo:
    # Top row: title and description on the left, clickable examples on the right.
    with gr.Row():
        with gr.Column():
            # Title
            gr.Markdown(title)
            # Description
            gr.Markdown(description)

        with gr.Column():
            # Examples
            # Selecting an example fills prompt, negative_prompt and cond_img.
            # NOTE(review): fn/outputs are only used by Gradio when examples are
            # cached or run on click; neither is enabled explicitly here — confirm intent.
            gr.Examples(examples=examples,
                    inputs=[prompt, negative_prompt, cond_img],
                    outputs=output,
                    fn=infer)

    # Images
    # Conditioning image (scale=2) side by side with the output gallery (scale=1).
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            cond_img.render()
        with gr.Column(scale=1):
            output.render()
        
    # Prompts and generation controls (no clear button is defined here).
    with gr.Row():
        with gr.Column():
            prompt.render()
            negative_prompt.render()

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                # gr.Slider(minimum, maximum, value): steps 10-60 (default 50),
                # seed 0-1024 (default 4), samples 1-4 (default 4).
                num_steps = gr.Slider(10, 60, 50, step=1, label="Steps")
                seed = gr.Slider(0, 1024, 4, step=1, label="Seed")
                num_samples = gr.Slider(1, 4, 4, step=1, label="Nº Samples")
                
            submit = gr.Button("Generate")
            # TODO: Download Button

    
    # Inputs map positionally onto infer(prompts, negative_prompts, image,
    # num_inference_steps, seed, num_samples); the returned list fills the gallery.
    submit.click(infer, 
                 inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
                 outputs = output)
    
# Route events through Gradio's request queue, then start the server.
# Both run at import time (there is no `if __name__ == "__main__"` guard).
demo.queue()
demo.launch()