# RNRI / app.py
import gradio as gr
import torch
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.config import RunConfig
from src.editor import ImageEditorDemo
import spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
# if torch.cuda.is_available():
# torch.cuda.max_memory_allocated(device=device)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
# pipe.enable_xformers_memory_efficient_attention()
# pipe = pipe.to(device)
# else:
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
# pipe = pipe.to(device)
# css = """
# #col-container-1 {
# margin: 0 auto;
# max-width: 520px;
# }
# #col-container-2 {
# margin: 0 auto;
# max-width: 520px;
# }
# """
if device == "cuda":
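    # Note: torch.cuda.max_memory_allocated only queries the peak allocation; its return value is unused here.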
torch.cuda.max_memory_allocated(device=device)
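# Both pipelines load the same SDXL-Turbo weights: SDXLDDIMPipeline performs the inversion and
# AutoPipelineForImage2Image performs the editing pass, both with the custom Euler Ancestral scheduler.
# The inversion pipeline is intentionally kept on the CPU here; set_pipe moves it to the GPU on demand.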
scheduler_class = MyEulerAncestralDiscreteScheduler
pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)#.to(device)
pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo",
use_safetensors=True).to(device)
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
if device == "cuda":
pipe_inference.enable_xformers_memory_efficient_attention()
pipe_inversion.enable_xformers_memory_efficient_attention()
# with gr.Blocks(css=css) as demo:
# with gr.Blocks(css="style.css") as demo:
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f""" # Real Time Editing with GNRI Inversion 🍎⚡️
This is a demo for our [paper](https://arxiv.org/abs/2312.12540) **GNRI: Lightning-fast Image Inversion and Editing for Text-to-Image Diffusion Models**.
Image editing with GNRI inversion delivers a significant speed-up and improved quality compared to previous state-of-the-art methods.
Take a look at the [project page](https://barakmam.github.io/rnri.github.io/).
""")
inv_state = gr.State()
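    # Shared state holding the inversion output (keys 'latent', 'noise', 'cfg'),
    # produced by set_pipe and consumed by edit.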
@spaces.GPU
def set_pipe(input_image, description_prompt, edit_guidance_scale, num_inference_steps=4,
num_inversion_steps=4, inversion_max_step=0.6, rnri_iterations=2, rnri_alpha=0.1, rnri_lr=0.2):
if input_image is None or not description_prompt:
return None, "Please set all inputs."
if isinstance(num_inference_steps, str): num_inference_steps = int(num_inference_steps)
if isinstance(num_inversion_steps, str): num_inversion_steps = int(num_inversion_steps)
if isinstance(edit_guidance_scale, str): edit_guidance_scale = float(edit_guidance_scale)
if isinstance(inversion_max_step, str): inversion_max_step = float(inversion_max_step)
if isinstance(rnri_iterations, str): rnri_iterations = int(rnri_iterations)
if isinstance(rnri_alpha, str): rnri_alpha = float(rnri_alpha)
if isinstance(rnri_lr, str): rnri_lr = float(rnri_lr)
config = RunConfig(num_inference_steps=num_inference_steps,
num_inversion_steps=num_inversion_steps,
edit_guidance_scale=edit_guidance_scale,
inversion_max_step=inversion_max_step)
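        # Keep only one pipeline on the GPU at a time to limit peak memory:
        # offload the inference pipeline while inverting, then swap the two back afterwards.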
if device == 'cuda':
pipe_inference.to('cpu')
torch.cuda.empty_cache()
inversion_state = ImageEditorDemo.invert(pipe_inversion.to(device), input_image, description_prompt, config,
[rnri_iterations, rnri_alpha, rnri_lr], device)
if device == 'cuda':
pipe_inversion.to('cpu')
torch.cuda.empty_cache()
pipe_inference.to(device)
        gr.Info('Inputs are set!')
        return inversion_state, "Inputs are set!"
@spaces.GPU
def edit(inversion_state, target_prompt):
if inversion_state is None:
raise gr.Error("Set inputs before editing. Progress indication below")
image = ImageEditorDemo.edit(pipe_inference, target_prompt, inversion_state['latent'], inversion_state['noise'],
inversion_state['cfg'], inversion_state['cfg'].edit_guidance_scale)
return image
with gr.Row():
with gr.Column(elem_id="col-container-1"):
with gr.Row():
input_image = gr.Image(label="Input image", sources=['upload', 'webcam'], type="pil")
with gr.Row():
description_prompt = gr.Text(
label="Image description",
info="Enter your image description ",
show_label=False,
max_lines=1,
placeholder="Example: a cake on a table",
container=False,
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
edit_guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.2,
)
num_inference_steps = gr.Slider(
label="Inference steps",
minimum=1,
maximum=12,
step=1,
value=4,
)
inversion_max_step = gr.Slider(
label="Inversion strength",
minimum=0.0,
maximum=1.0,
step=0.01,
value=0.6,
)
rnri_iterations = gr.Slider(
label="RNRI iterations",
minimum=0,
maximum=5,
step=1,
value=2,
)
rnri_alpha = gr.Slider(
label="RNRI alpha",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.1,
)
rnri_lr = gr.Slider(
label="RNRI learning rate",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.2,
)
with gr.Row():
is_set_text = gr.Text("", show_label=False)
with gr.Column(elem_id="col-container-2"):
result = gr.Image(label="Result")
with gr.Row():
target_prompt = gr.Text(
label="Edit prompt",
info="Enter your edit prompt",
show_label=False,
max_lines=1,
placeholder="Example: an oreo cake on a table",
container=False,
)
with gr.Row():
run_button = gr.Button("Edit", scale=1)
with gr.Row():
gr.Examples(
examples='examples',
inputs=[input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps,
inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
cache_examples=False
)
gr.Markdown(f"""Disclaimer: Performance may be inferior to the reported in the paper due to hardware limitation.""")
    set_pipe_inputs = [input_image, description_prompt, edit_guidance_scale, num_inference_steps,
                       num_inference_steps, inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr]
    for component in [input_image, description_prompt, edit_guidance_scale, num_inference_steps,
                      inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr]:
        component.change(set_pipe, inputs=set_pipe_inputs, outputs=[inv_state, is_set_text],
                         trigger_mode='once')
# set_button.click(
# fn=set_pipe,
# inputs=[inv_state, input_image, description_prompt, edit_guidance_scale, num_inference_steps,
# num_inference_steps, inversion_max_step, rnri_iterations, rnri_alpha, rnri_lr],
# outputs=[inv_state, is_set_text],
# )
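    # The edit itself runs only on the button press, reusing the cached inversion state.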
run_button.click(
fn=edit,
inputs=[inv_state, target_prompt],
outputs=[result]
)
demo.queue().launch()