# RNRI/src/editor.py
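"""Inversion-and-edit utilities for the RNRI demo.

`ImageEditorDemo.invert` maps an input image back to latents with an SDXL
DDIM-style inversion pipeline, recording the per-step noise;
`ImageEditorDemo.edit` replays that noise through an image-to-image pipeline
to synthesize the edited image.
"""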
import PIL.Image  # `import PIL` alone does not reliably expose PIL.Image
import torch
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from diffusers.utils.torch_utils import randn_tensor

from src.config import RunConfig
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline


def inversion_callback(pipe, step, timestep, callback_kwargs):
    # No-op per-step hook for the inversion pipeline.
    return callback_kwargs


def inference_callback(pipe, step, timestep, callback_kwargs):
    # No-op per-step hook for the inference pipeline.
    return callback_kwargs


def center_crop(im):
    # Crop the largest centered square out of the image.
    width, height = im.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    return im.crop((left, top, left + min_dim, top + min_dim))


def load_im_into_format_from_path(im_path):
    # Accept either a filesystem path or an already-loaded PIL image.
    if isinstance(im_path, str):
        im_path = PIL.Image.open(im_path)
    return center_crop(im_path).resize((512, 512))


class ImageEditorDemo:
    def __init__(self, pipe_inversion, pipe_inference, input_image, description_prompt, cfg, device, inv_hp):
        # Only the image is prepared eagerly; inversion and editing run through
        # the static methods below. The remaining arguments are kept for the
        # commented-out eager initialization path.
        self.original_image = load_im_into_format_from_path(input_image).convert("RGB")
        # self.pipe_inversion = pipe_inversion.to(device)
        # self.last_latent = self.invert(pipe_inversion, self.original_image, description_prompt, cfg, inv_hp, device)
        # if device == 'cuda':
        #     # After the inversion, the inversion model can be moved to the CPU.
        #     self.pipe_inversion = self.pipe_inversion.to('cpu')
        # self.pipe_inference = pipe_inference.to(device)

    @staticmethod
    def invert(pipe_inversion, init_image, base_prompt, cfg, inv_hp, device):
        init_image = load_im_into_format_from_path(init_image).convert("RGB")
        # Fixed seed so the pre-sampled noise (and hence the inversion) is reproducible.
        g_cpu = torch.Generator().manual_seed(7865)
        img_size = (512, 512)
        VQAE_SCALE = 8  # the SDXL VAE downsamples by a factor of 8
        latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)
        # One noise tensor per inversion step, shared by the inversion and
        # inference schedulers so both replay exactly the same noise.
        noise = [randn_tensor(latents_size, dtype=torch.float16,
                              device=torch.device(device), generator=g_cpu)
                 for _ in range(cfg.num_inversion_steps)]
        pipe_inversion.cfg = cfg
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)
        pipe_inversion.set_progress_bar_config(disable=True)
        res = pipe_inversion(prompt=base_prompt,
                             num_inversion_steps=cfg.num_inversion_steps,
                             num_inference_steps=cfg.num_inference_steps,
                             image=init_image,
                             guidance_scale=cfg.inversion_guidance_scale,
                             strength=cfg.inversion_max_step,
                             denoising_start=1.0 - cfg.inversion_max_step,
                             inv_hp=inv_hp)[0][0]
        return {"latent": res, "noise": noise, "cfg": cfg}

    @staticmethod
    def edit(pipe_inference, target_prompt, last_latent, noise, cfg, edit_cfg):
        # Replay the noise recorded during inversion so the edit stays anchored
        # to the source image wherever the prompt does not request a change.
        pipe_inference.cfg = cfg
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inference.set_progress_bar_config(disable=True)
        image = pipe_inference(prompt=target_prompt,
                               num_inference_steps=cfg.num_inference_steps,
                               negative_prompt="",
                               image=last_latent,
                               strength=cfg.inversion_max_step,
                               denoising_start=1.0 - cfg.inversion_max_step,
                               guidance_scale=edit_cfg).images[0]
        return image

    # def to(self, device):
    #     self.pipe_inference = self.pipe_inference.to(device)
    #     self.pipe_inversion = self.pipe_inversion.to(device)
    #     self.last_latent = self.last_latent.to(device)
    #
    #     self.pipe_inversion.scheduler.set_noise_list_device(device)
    #     self.pipe_inference.scheduler.set_noise_list_device(device)
    #     self.pipe_inversion.scheduler_inference.set_noise_list_device(device)
    #     return self
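

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module's API. Assumptions: an SDXL
# Turbo checkpoint loadable via `from_pretrained` for both pipeline classes,
# that MyEulerAncestralDiscreteScheduler is a drop-in `from_config`
# replacement, and that `inv_hp` / `edit_cfg` below are illustrative
# placeholders rather than tuned settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "stabilityai/sdxl-turbo"  # assumed checkpoint

    pipe_inference = AutoPipelineForImage2Image.from_pretrained(
        model_id, torch_dtype=torch.float16).to(device)
    pipe_inversion = SDXLDDIMPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16).to(device)

    # Swap in the noise-replaying scheduler on both pipelines.
    pipe_inference.scheduler = MyEulerAncestralDiscreteScheduler.from_config(
        pipe_inference.scheduler.config)
    pipe_inversion.scheduler = MyEulerAncestralDiscreteScheduler.from_config(
        pipe_inversion.scheduler.config)
    pipe_inversion.scheduler_inference = MyEulerAncestralDiscreteScheduler.from_config(
        pipe_inversion.scheduler.config)

    cfg = RunConfig()  # assumes the repo's default run configuration
    inv = ImageEditorDemo.invert(pipe_inversion, "input.png",
                                 "a photo of a cat", cfg,
                                 inv_hp=None,  # placeholder hyper-parameters
                                 device=device)
    edited = ImageEditorDemo.edit(pipe_inference, "a photo of a dog",
                                  inv["latent"], inv["noise"], inv["cfg"],
                                  edit_cfg=1.2)  # placeholder guidance scale
    edited.save("edited.png")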