File size: 4,136 Bytes
45efc4b
 
 
 
 
 
 
 
 
 
 
 
 
e3bc468
45efc4b
 
 
e3bc468
45efc4b
 
 
 
 
 
 
 
 
 
 
 
 
 
e11e948
 
 
 
45efc4b
 
 
6711456
45efc4b
e3bc468
 
 
 
 
 
 
 
 
 
 
 
 
 
45efc4b
e3bc468
45efc4b
 
32cf151
45efc4b
e3bc468
45efc4b
 
 
f04fc45
e3bc468
 
 
 
 
 
 
 
 
 
 
 
 
c121b44
f04fc45
e3bc468
 
 
 
 
 
 
45efc4b
 
e3bc468
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from src.config import RunConfig
import PIL
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline

from diffusers.utils.torch_utils import randn_tensor


def inversion_callback(pipe, step, timestep, callback_kwargs):
    """Pass-through step callback for the inversion pipeline.

    ``pipe``, ``step`` and ``timestep`` are accepted but not used; the
    ``callback_kwargs`` object is handed back unchanged.
    """
    unchanged_kwargs = callback_kwargs
    return unchanged_kwargs


def inference_callback(pipe, step, timestep, callback_kwargs):
    """Pass-through step callback for the inference pipeline.

    ``pipe``, ``step`` and ``timestep`` are accepted but not used; the
    ``callback_kwargs`` object is handed back unchanged.
    """
    unchanged_kwargs = callback_kwargs
    return unchanged_kwargs


def center_crop(im):
    """Crop the largest centered square out of ``im``.

    Args:
        im: Image-like object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a Pillow image).

    Returns:
        Whatever ``im.crop`` returns for the centered square box.
    """
    width, height = im.size
    side = min(width, height)

    # Integer offsets keep the box exactly ``side`` x ``side``. With float
    # midpoints (e.g. 1.5 and 4.5, when the size difference is odd) the final
    # box depends on how the imaging library rounds each edge independently,
    # which can produce a box that is off by one pixel or even empty.
    left = (width - side) // 2
    top = (height - side) // 2

    return im.crop((left, top, left + side, top + side))


def load_im_into_format_from_path(im_path):
    """Return a center-cropped, 512x512 version of an image.

    Args:
        im_path: Either a filesystem path (``str`` or ``os.PathLike``) that is
            opened with Pillow, or an already-loaded image object, which is
            passed to ``center_crop`` as-is.

    Returns:
        The result of ``center_crop(image).resize((512, 512))``.
    """
    import os

    if isinstance(im_path, (str, os.PathLike)):
        # Import the submodule explicitly: a bare ``import PIL`` does not by
        # itself make ``PIL.Image`` available as an attribute, so relying on
        # it only works when another module has already imported it.
        from PIL import Image

        image = Image.open(im_path)
    else:
        image = im_path

    return center_crop(image).resize((512, 512))


class ImageEditorDemo:
    """Helpers for inverting an image and re-rendering it from a new prompt.

    ``invert`` and ``edit`` are static methods that receive the pipelines as
    explicit arguments. Constructing an instance only stores the preprocessed
    input image.
    """

    def __init__(self, pipe_inversion, pipe_inference, input_image, description_prompt, cfg, device, inv_hp):
        """Store the preprocessed input image as ``self.original_image``.

        Only ``input_image`` is used here; the remaining parameters are
        accepted but ignored by the current implementation.
        """
        prepared = load_im_into_format_from_path(input_image)
        self.original_image = prepared.convert("RGB")

        # Disabled pipeline wiring, kept for reference:
        # self.pipe_inversion = self.pipe_inversion.to(device)
        # self.last_latent = self.invert(pipe_inversion, self.original_image, description_prompt)

        # if device  == 'cuda':
        # after the inversion, we can move the inversion model to the CPU
        # self.pipe_inversion = self.pipe_inversion.to('cpu')

        # self.pipe_inference = self.pipe_inference.to(device)

    @staticmethod
    def invert(pipe_inversion, init_image, base_prompt, cfg, inv_hp, device):
        """Run the inversion pipeline on ``init_image``.

        The pipeline's ``cfg`` attribute and the noise lists of both of its
        schedulers are overwritten, and its progress bar is disabled, before
        the pipeline is called.

        Returns:
            A dict with the pipeline result (``[0][0]`` of its output) under
            ``"latent"``, the list of generated noise tensors under
            ``"noise"`` and the given ``cfg`` under ``"cfg"``.
        """
        rgb_image = load_im_into_format_from_path(init_image).convert("RGB")

        # A freshly seeded generator per call yields the same noise sequence
        # every time.
        seeded_generator = torch.Generator().manual_seed(7865)
        height, width = 512, 512
        downscale = 8
        latent_shape = (1, 4, height // downscale, width // downscale)

        # One noise tensor per inversion step, drawn in order from the same generator.
        noise = []
        for _ in range(cfg.num_inversion_steps):
            noise.append(
                randn_tensor(
                    latent_shape,
                    dtype=torch.float16,
                    device=torch.device(device),
                    generator=seeded_generator,
                )
            )

        pipe_inversion.cfg = cfg
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)
        pipe_inversion.set_progress_bar_config(disable=True)

        call_kwargs = {
            "prompt": base_prompt,
            "num_inversion_steps": cfg.num_inversion_steps,
            "num_inference_steps": cfg.num_inference_steps,
            "image": rgb_image,
            "guidance_scale": cfg.inversion_guidance_scale,
            "strength": cfg.inversion_max_step,
            "denoising_start": 1.0 - cfg.inversion_max_step,
            "inv_hp": inv_hp,
        }
        output = pipe_inversion(**call_kwargs)
        return {"latent": output[0][0], "noise": noise, "cfg": cfg}

    @staticmethod
    def edit(pipe_inference, target_prompt, last_latent, noise, cfg, edit_cfg):
        """Run the inference pipeline from ``last_latent`` with ``target_prompt``.

        The pipeline's ``cfg`` attribute and its scheduler's noise list are
        overwritten, and its progress bar is disabled, before the pipeline is
        called. ``edit_cfg`` is passed as the pipeline's ``guidance_scale``.

        Returns:
            The first entry of the pipeline output's ``images``.
        """
        pipe_inference.cfg = cfg
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inference.set_progress_bar_config(disable=True)

        call_kwargs = {
            "prompt": target_prompt,
            "num_inference_steps": cfg.num_inference_steps,
            "negative_prompt": "",
            "image": last_latent,
            "strength": cfg.inversion_max_step,
            "denoising_start": 1.0 - cfg.inversion_max_step,
            "guidance_scale": edit_cfg,
        }
        output = pipe_inference(**call_kwargs)
        return output.images[0]

    # Disabled device-transfer helper, kept for reference:
    # def to(self, device):
    #     self.pipe_inference = self.pipe_inference.to(device)
    #     self.pipe_inversion = self.pipe_inversion.to(device)
    #     self.last_latent = self.last_latent.to(device)
    #
    #     self.pipe_inversion.scheduler.set_noise_list_device(device)
    #     self.pipe_inference.scheduler.set_noise_list_device(device)
    #     self.pipe_inversion.scheduler_inference.set_noise_list_device(device)
    #     return self