# RNRI / app.py
# (Hugging Face Spaces page header captured together with the file:
#  Barak1's picture — "added example" — commit 8f7b417 — raw / history / blame — 7.49 kB.
#  Kept as comments so the module parses.)
import gradio as gr
import torch
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.config import RunConfig
from src.editor import ImageEditorDemo
import spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
# if torch.cuda.is_available():
# torch.cuda.max_memory_allocated(device=device)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
# pipe.enable_xformers_memory_efficient_attention()
# pipe = pipe.to(device)
# else:
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
# pipe = pipe.to(device)
# css = """
# #col-container-1 {
# margin: 0 auto;
# max-width: 520px;
# }
# #col-container-2 {
# margin: 0 auto;
# max-width: 520px;
# }
# """
if torch.cuda.is_available():
power_device = "GPU"
else:
power_device = "CPU"
# with gr.Blocks(css=css) as demo:
with gr.Blocks(css="style.css") as demo:
gr.Markdown(f""" # Real Time Editing with RNRI Inversion 🍎⚡️
This is a demo for our [paper](https://arxiv.org/abs/2312.12540) **RNRI: Regularized Newton Raphson Inversion for Text-to-Image Diffusion Models**.
Image editing using our RNRI for inversion demonstrates significant speed-up and improved quality compared to previous state-of-the-art methods.
Take a look at our [project page](https://barakmam.github.io/rnri.github.io/).
""")
editor_state = gr.State()
@spaces.GPU
def set_pipe(input_image, description_prompt, edit_guidance_scale, num_inference_steps=4,
num_inversion_steps=4, inversion_max_step=0.6, rnri_iterations=2, rnri_alpha=0.1, rnri_lr=0.2):
scheduler_class = MyEulerAncestralDiscreteScheduler
print('\n################## 1')
pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True) # .to('cpu')
print('\n################## 2')
pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo",
use_safetensors=True) # .to('cpu')
print('\n################## 3')
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
print('\n################## 4')
config = RunConfig(num_inference_steps=num_inference_steps,
num_inversion_steps=num_inversion_steps,
edit_guidance_scale=edit_guidance_scale,
inversion_max_step=inversion_max_step)
image_editor = ImageEditorDemo(pipe_inversion, pipe_inference, input_image,
description_prompt, config, device,
[rnri_iterations, rnri_alpha, rnri_lr])
print('\n################## 5')
return image_editor, "Input has set!"
@spaces.GPU
def edit(editor, target_prompt):
if editor is None:
raise gr.Error("Set inputs before editing.")
# if device == "cuda":
# image = editor.to(device).edit(target_prompt)
# else:
image = editor.edit(target_prompt)
return image
gr.Markdown(f"""running on {power_device}""")
with gr.Row():
with gr.Column(elem_id="col-container-1"):
with gr.Row():
input_image = gr.Image(label="Input image", sources=['upload', 'webcam'], type="pil")
with gr.Row():
description_prompt = gr.Text(
label="Image description",
info="Enter your image description ",
show_label=False,
max_lines=1,
placeholder="a cake on a table",
container=False,
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
edit_guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.2,
)
num_inference_steps = gr.Slider(
label="Inference steps",
minimum=1,
maximum=12,
step=1,
value=4,
)
inversion_max_step = gr.Slider(
label="Inversion strength",
minimum=0.0,
maximum=1.0,
step=0.01,
value=0.6,
)
rnri_iterations = gr.Slider(
label="RNRI iterations",
minimum=0,
maximum=5,
step=1,
value=2,
)
rnri_alpha = gr.Slider(
label="RNRI alpha",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.1,
)
rnri_lr = gr.Slider(
label="RNRI learning rate",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.2,
)
with gr.Row():
set_button = gr.Button("Set input image & description & settings", scale=1)
is_set_text = gr.Text("", show_label=False)
with gr.Column(elem_id="col-container-2"):
result = gr.Image(label="Result")
with gr.Row():
target_prompt = gr.Text(
label="Edit prompt",
info="Enter your edit prompt",
show_label=False,
max_lines=1,
placeholder="an oreo cake on a table",
container=False,
)
with gr.Row():
run_button = gr.Button("Edit", scale=1)
with gr.Row():
gr.Examples(
examples='examples',
inputs=[input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps,
inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
)
set_button.click(
fn=set_pipe,
inputs=[input_image, description_prompt, edit_guidance_scale, num_inference_steps,
num_inference_steps, inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
outputs=[editor_state, is_set_text],
)
run_button.click(
fn=edit,
inputs=[editor_state, target_prompt],
outputs=[result]
)
demo.queue().launch()
# im = infer(input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps=4, num_inversion_steps=4,
# inversion_max_step=0.6)