import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension
from .comfyui_utils import CheckpointLoaderSimple, ControlNetLoader, ControlNetApplyAdvanced, CLIPTextEncode, KSampler, VAEDecode, GrowMask, PIDINET_Preprocessor, LineArt_Preprocessor, Color_Preprocessor

class ScribbleColorEditModel:
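    """Scribble- and color-guided image editing pipeline built from ComfyUI nodes.

    Chains a Stable Diffusion checkpoint with a scribble/lineart ControlNet, a
    color ControlNet, and BrushNet inpainting, then samples with KSampler and
    blends the result back into the source image.
    """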
    def __init__(self):
        self.checkpoint_loader = CheckpointLoaderSimple()
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()
        self.lineart_processor = LineArt_Preprocessor()
        self.color_processor = Color_Preprocessor()
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()
        self.ckpt_name = os.path.normpath("SD1.5/realisticVisionV60B1_v51VAE.safetensors")
        with torch.no_grad():
            self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(self.ckpt_name)
        self.load_models('SD1.5', 'float16')

    def load_models(self, base_model_version="SD1.5", dtype='float16'):
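        """Load the edge ControlNet, color ControlNet, and BrushNet weights.

        Only SD1.5 weights are currently wired up; any other base model
        version raises ValueError.
        """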
        if base_model_version == "SD1.5":
            edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
            color_controlnet_name = "color_finetune.safetensors"
            brushnet_name = os.path.normpath("brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors")
        else:
            raise ValueError(f"Unsupported base_model_version: {base_model_version}")
        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]
        self.brushnet_loader.inpaint_files = get_files_with_extension('inpaint')
        print("self.brushnet_loader.inpaint_files: ", get_files_with_extension('inpaint'))
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]
    
    def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt,
                mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge,
                edge_strength, color_strength, inpaint_strength, seed, steps, cfg,
                sampler_name, scheduler, base_model_version='SD1.5', dtype='float16',
                palette_resolution=2048):
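        """Run one edit pass and return (latent, blended image, edge map, color map).

        Reloads the checkpoint if ckpt_name changed, encodes the prompts,
        builds ControlNet hints from the strokes, inpaints the masked region
        with BrushNet, samples, decodes, and blends the result back in.
        """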
        if ckpt_name != self.ckpt_name:
            self.ckpt_name = ckpt_name
            with torch.no_grad():
                self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
        if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
            self.load_models(base_model_version, dtype)
        positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]

        # Grow the mask so the edit region extends slightly past the strokes
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]

        # When strokes should NOT act as edges, push the stroked pixels to the
        # brightness extreme opposite their mean so the lineart extractors
        # below do not trace them; image_copy feeds those extractors.
        image_copy = image.clone()
        if stroke_as_edge == "disable":
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_copy[bool_add_mask].mean()
            image_copy[bool_add_mask] = 0.0 if mean_brightness > 0.8 else 1.0

        if not torch.equal(image, colored_image):
            # A color hint was provided: condition on the color map plus a
            # fixed-strength (0.8) lineart edge map.
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=palette_resolution)[0]
            lineart_output = self.lineart_processor.execute(image_copy, resolution=512, coarse=False)[0]
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0)
        else:
            # No color hint: extract edges from the (possibly stroke-neutralized)
            # image and let the user's strokes edit the edge map directly.
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(image, resolution=palette_resolution)[0]
            if fine_edge == "enable":
                lineart_output = self.lineart_processor.execute(image_copy, resolution=512, coarse=False)[0]
            else:
                lineart_output = self.scribble_processor.execute(image_copy, resolution=512)[0]
            # Resize the stroke masks to lineart_output's spatial size
            # (lineart_output is B x H x W x C) so they can index it directly.
            add_mask_resized = F.interpolate(add_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)
            remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)

            bool_add_mask_resized = add_mask_resized > 0.5
            bool_remove_mask_resized = remove_mask_resized > 0.5

            if stroke_as_edge == "enable":
                # Strokes override the extracted edges: removals erase lines,
                # additions draw new ones.
                lineart_output[bool_remove_mask_resized] = 0.0
                lineart_output[bool_add_mask_resized] = 1.0
            else:
                # Only erase lines where strokes were removed and none added.
                lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0)

        # BrushNet: update the model so sampling inpaints the masked region
        # while staying conditioned on the unmasked pixels
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=self.model,
            vae=self.vae,
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000
        )

        # Denoise the BrushNet latent with the requested sampler/scheduler
        latent_samples = self.ksampler.sample(
            model=model, 
            seed=seed, 
            steps=steps, 
            cfg=cfg, 
            sampler_name=sampler_name, 
            scheduler=scheduler, 
            positive=positive, 
            negative=negative, 
            latent_image=latent,
        )[0]

        final_image = self.vae_decoder.decode(self.vae, latent_samples)[0]
        # Blend the inpainted result back into the original with a feathered mask
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]

        return (latent_samples, final_image, lineart_output, color_output)
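

# Minimal usage sketch, not part of the node pipeline. Assumptions: the file
# names and the square mask below are hypothetical placeholders, the weights
# referenced in load_models() are installed in the ComfyUI model folders, and
# the module is run as a package module (python -m <package>.<module>) so the
# relative imports above resolve.
if __name__ == "__main__":
    def _load_image(path):
        # HWC uint8 image file -> 1xHxWxC float tensor in [0, 1], the layout
        # the ComfyUI-style nodes above expect.
        arr = np.asarray(Image.open(path).convert("RGB"), dtype=np.float32) / 255.0
        return torch.from_numpy(arr)[None]

    editor = ScribbleColorEditModel()
    image = _load_image("input.png")        # placeholder path
    colored_image = image.clone()           # identical image -> edge-only branch
    mask = torch.zeros(image.shape[:3])     # 1xHxW mask in [0, 1]
    mask[:, 100:200, 100:200] = 1.0         # toy inpaint region
    add_mask = torch.zeros_like(mask)       # no drawn strokes
    remove_mask = torch.zeros_like(mask)    # no erased strokes

    latent, result, lineart, color = editor.process(
        editor.ckpt_name, image, colored_image,
        "a high quality photo", "blurry, low quality",
        mask, add_mask, remove_mask,
        grow_size=16, stroke_as_edge="enable", fine_edge="enable",
        edge_strength=0.8, color_strength=0.8, inpaint_strength=1.0,
        seed=0, steps=20, cfg=7.0, sampler_name="euler", scheduler="normal",
    )
    Image.fromarray((result[0].cpu().numpy() * 255).astype(np.uint8)).save("edited.png")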