import gradio as gr
import torch
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.config import RunConfig
from src.editor import ImageEditorDemo
import spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
# if torch.cuda.is_available():
# torch.cuda.max_memory_allocated(device=device)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
# pipe.enable_xformers_memory_efficient_attention()
# pipe = pipe.to(device)
# else:
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
# pipe = pipe.to(device)
# css = """
# #col-container-1 {
# margin: 0 auto;
# max-width: 520px;
# }
# #col-container-2 {
# margin: 0 auto;
# max-width: 520px;
# }
# """
if device == "cuda":
    torch.cuda.max_memory_allocated(device=device)
scheduler_class = MyEulerAncestralDiscreteScheduler
pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)#.to(device)
pipe_inference = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/sdxl-turbo", use_safetensors=True).to(device)
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
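# xformers memory-efficient attention is only available when running on CUDA.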
if device == "cuda":
    pipe_inference.enable_xformers_memory_efficient_attention()
    pipe_inversion.enable_xformers_memory_efficient_attention()
# with gr.Blocks(css=css) as demo:
# with gr.Blocks(css="style.css") as demo:
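# Gradio UI: input image and settings in the left column, edited result in the right column.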
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f""" # Real Time Editing with GNRI Inversion 🍎⚡️
This is a demo for our [paper](https://arxiv.org/abs/2312.12540) **GNRI: Lightning-fast Image Inversion and Editing for Text-to-Image Diffusion Models**.
Image editing using GNRI for inversion demonstrates significant speed-up and improved quality compared to previous state-of-the-art methods.
Take a look at the [project page](https://barakmam.github.io/rnri.github.io/).
""")
inv_state = gr.State()
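
    # Invert the input image into SDXL-Turbo latents; re-run whenever any input changes.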
    @spaces.GPU
    def set_pipe(input_image, description_prompt, edit_guidance_scale, num_inference_steps=4,
                 num_inversion_steps=4, inversion_max_step=0.6, rnri_iterations=2, rnri_alpha=0.1, rnri_lr=0.2):
        if input_image is None or not description_prompt:
            return None, "Please set all inputs."

        # Values may arrive as strings (e.g. from examples); coerce them to numbers.
        if isinstance(num_inference_steps, str): num_inference_steps = int(num_inference_steps)
        if isinstance(num_inversion_steps, str): num_inversion_steps = int(num_inversion_steps)
        if isinstance(edit_guidance_scale, str): edit_guidance_scale = float(edit_guidance_scale)
        if isinstance(inversion_max_step, str): inversion_max_step = float(inversion_max_step)
        if isinstance(rnri_iterations, str): rnri_iterations = int(rnri_iterations)
        if isinstance(rnri_alpha, str): rnri_alpha = float(rnri_alpha)
        if isinstance(rnri_lr, str): rnri_lr = float(rnri_lr)

        config = RunConfig(num_inference_steps=num_inference_steps,
                           num_inversion_steps=num_inversion_steps,
                           edit_guidance_scale=edit_guidance_scale,
                           inversion_max_step=inversion_max_step)

        # Move the inference pipeline off the GPU before running inversion to free memory.
        if device == 'cuda':
            pipe_inference.to('cpu')
            torch.cuda.empty_cache()
        inversion_state = ImageEditorDemo.invert(pipe_inversion.to(device), input_image, description_prompt, config,
                                                 [rnri_iterations, rnri_alpha, rnri_lr], device)
        if device == 'cuda':
            pipe_inversion.to('cpu')
            torch.cuda.empty_cache()
        pipe_inference.to(device)
        gr.Info('Input has been set!')
        return inversion_state, "Input has been set!"
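
    # Generate the edited image from the cached inversion state and the target prompt.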
    @spaces.GPU
    def edit(inversion_state, target_prompt):
        if inversion_state is None:
            raise gr.Error("Set the inputs before editing. Progress is indicated below.")
        image = ImageEditorDemo.edit(pipe_inference, target_prompt, inversion_state['latent'],
                                     inversion_state['noise'], inversion_state['cfg'],
                                     inversion_state['cfg'].edit_guidance_scale)
        return image

    with gr.Row():
        with gr.Column(elem_id="col-container-1"):
            with gr.Row():
                input_image = gr.Image(label="Input image", sources=['upload', 'webcam'], type="pil")

            with gr.Row():
                description_prompt = gr.Text(
                    label="Image description",
                    info="Enter your image description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Example: a cake on a table",
                    container=False,
                )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    edit_guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.2,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference steps",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=4,
                    )
                    inversion_max_step = gr.Slider(
                        label="Inversion strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.6,
                    )
                    rnri_iterations = gr.Slider(
                        label="RNRI iterations",
                        minimum=0,
                        maximum=5,
                        step=1,
                        value=2,
                    )
                    rnri_alpha = gr.Slider(
                        label="RNRI alpha",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.1,
                    )
                    rnri_lr = gr.Slider(
                        label="RNRI learning rate",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.2,
                    )

            with gr.Row():
                is_set_text = gr.Text("", show_label=False)

        with gr.Column(elem_id="col-container-2"):
            result = gr.Image(label="Result")

            with gr.Row():
                target_prompt = gr.Text(
                    label="Edit prompt",
                    info="Enter your edit prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Example: an oreo cake on a table",
                    container=False,
                )

            with gr.Row():
                run_button = gr.Button("Edit", scale=1)

            with gr.Row():
                gr.Examples(
                    examples='examples',
                    inputs=[input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps,
                            inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
                    cache_examples=False
                )
gr.Markdown(f"""Disclaimer: Performance may be inferior to the reported in the paper due to hardware limitation.""")
    set_pipe_inputs = [input_image, description_prompt, edit_guidance_scale, num_inference_steps,
                       num_inference_steps, inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr]
    for component in (input_image, description_prompt, edit_guidance_scale, num_inference_steps,
                      inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr):
        component.change(set_pipe, inputs=set_pipe_inputs,
                         outputs=[inv_state, is_set_text], trigger_mode='once')

    # set_button.click(
    #     fn=set_pipe,
    #     inputs=[inv_state, input_image, description_prompt, edit_guidance_scale, num_inference_steps,
    #             num_inference_steps, inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
    #     outputs=[inv_state, is_set_text],
    # )
    run_button.click(
        fn=edit,
        inputs=[inv_state, target_prompt],
        outputs=[result]
    )

demo.queue().launch()