File size: 7,324 Bytes
d4733f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension
from .comfyui_utils import CheckpointLoaderSimple, ControlNetLoader, ControlNetApplyAdvanced, CLIPTextEncode, KSampler, VAEDecode, GrowMask, PIDINET_Preprocessor, LineArt_Preprocessor, Color_Preprocessor

class ScribbleColorEditModel:
    """Scribble/colour-guided inpainting pipeline assembled from node wrappers.

    Combines a base checkpoint, an edge (scribble/lineart) ControlNet, a colour
    ControlNet and a BrushNet inpainting model (all loaded through the helpers
    imported from ``comfyui_utils`` / ``brushnet_nodes``) behind a single
    ``process`` entry point.
    """

    def __init__(self):
        self.checkpoint_loader = CheckpointLoaderSimple()
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()
        self.lineart_processor = LineArt_Preprocessor()
        self.color_processor = Color_Preprocessor()
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()
        self._load_checkpoint("SD1.5/realisticVisionV60B1_v51VAE.safetensors")
        self.load_models('SD1.5', 'float16')

    def _load_checkpoint(self, ckpt_name):
        """Load ``ckpt_name`` (model, CLIP, VAE) and record it as the active checkpoint."""
        self.ckpt_name = ckpt_name
        with torch.no_grad():
            self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)

    def load_models(self, base_model_version="SD1.5", dtype='float16'):
        """Load the edge/colour ControlNets and BrushNet for a base model family.

        Args:
            base_model_version: Base model family; only ``"SD1.5"`` is supported.
            dtype: Precision string forwarded to the BrushNet loader.

        Raises:
            ValueError: If ``base_model_version`` is not supported.
        """
        if base_model_version == "SD1.5":
            edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
            color_controlnet_name = "color_finetune.safetensors"
            brushnet_name = "brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors"
        # elif base_model_version == "SDXL":
        #     edge_controlnet_name = "controlnet-scribble-sdxl-1.0.safetensors"
        #     color_controlnet_name = "colorGridControlnet_v10.safetensors"
        #     brushnet_name = "brushnet_xl/random_mask_brushnet_ckpt_sdxl_v0/diffusion_pytorch_model.safetensors"
        else:
            raise ValueError("Invalid base_model_version, not supported yet!!!: {}".format(base_model_version))
        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]
        # The BrushNet loader resolves ``brushnet_name`` against this file list.
        inpaint_files = get_files_with_extension('inpaint')
        self.brushnet_loader.inpaint_files = inpaint_files
        print("self.brushnet_loader.inpaint_files: ", inpaint_files)
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]
        # Record what is loaded so ``process`` can reload when a call asks for
        # a different base model or dtype.
        self._loaded_models_key = (base_model_version, dtype)

    @staticmethod
    def _resize_mask_like(mask, reference):
        """Nearest-neighbour resize ``mask`` to the height/width of ``reference``.

        Height/width are read from ``reference.shape[1]`` / ``shape[2]``
        (presumably a channels-last (B, H, W, C) image — TODO confirm). The mask
        gains two leading singleton dims so that, assuming it is (1, H, W), the
        input is 5-D and interpolated as a depth-1 volume; the added dims are
        squeezed off again.
        """
        return F.interpolate(
            mask.unsqueeze(0).unsqueeze(0).float(),
            size=(1, reference.shape[1], reference.shape[2]),
            mode='nearest',
        ).squeeze(0).squeeze(0)

    def _build_edge_map(self, image, add_mask, remove_mask, stroke_as_edge, fine_edge):
        """Extract an edge map from ``image`` and apply the user's stroke masks.

        Uses the lineart preprocessor when ``fine_edge == "enable"``, otherwise
        the PIDINet scribble preprocessor. With ``stroke_as_edge == "enable"``
        remove strokes are blanked to 0.0 and add strokes painted as edges (1.0,
        applied last so it wins where the masks overlap); otherwise only pixels
        in the remove mask but not in the add mask are blanked.
        """
        if fine_edge == "enable":
            edges = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
        else:
            edges = self.scribble_processor.execute(image, resolution=512)[0]

        add_region = self._resize_mask_like(add_mask, edges) > 0.5
        remove_region = self._resize_mask_like(remove_mask, edges) > 0.5

        if stroke_as_edge == "enable":
            edges[remove_region] = 0.0
            edges[add_region] = 1.0
        else:
            edges[remove_region & ~add_region] = 0.0
        return edges

    def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt,
                mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge,
                edge_strength, color_strength, inpaint_strength, seed, steps, cfg,
                sampler_name, scheduler, base_model_version='SD1.5', dtype='float16',
                palette_resolution=2048):
        """Run one scribble/colour edit and blend the result back into ``image``.

        If ``colored_image`` differs from ``image``, the colour ControlNet (fed
        by ``colored_image``) and a lineart edge ControlNet at a fixed strength
        of 0.8 are applied. Otherwise only the edge ControlNet is applied, at
        ``edge_strength``, on an edge map edited with ``add_mask`` /
        ``remove_mask``. BrushNet then inpaints ``mask`` (grown by
        ``grow_size``). The sampled latent is decoded and blended into ``image``.

        The checkpoint is reloaded when ``ckpt_name`` changes. ControlNets and
        BrushNet are reloaded when ``base_model_version`` or ``dtype`` changes.

        Returns:
            tuple: ``(latent_samples, final_image, lineart_output, color_output)``.

        Raises:
            ValueError: If ``base_model_version`` is not supported.
        """
        if ckpt_name != self.ckpt_name:
            self._load_checkpoint(ckpt_name)
        # __init__ always loads the default models, so an existence check alone
        # would ignore later changes to base_model_version / dtype.
        if getattr(self, '_loaded_models_key', None) != (base_model_version, dtype):
            self.load_models(base_model_version, dtype)

        positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
        # Grow the inpainting mask for colour editing before BrushNet uses it.
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]

        if not torch.equal(image, colored_image):
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=palette_resolution)[0]
            lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0)
        else:
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(image, resolution=palette_resolution)[0]
            lineart_output = self._build_edge_map(image, add_mask, remove_mask, stroke_as_edge, fine_edge)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0)

        # BrushNet: condition the base model on the masked image.
        # start_at/end_at are passed through unchanged (presumably a step range
        # covering the whole schedule — confirm against BrushNet.model_update).
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=self.model,
            vae=self.vae,  # VAE from the currently loaded checkpoint
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000,
        )

        latent_samples = self.ksampler.sample(
            model=model,
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name=sampler_name,
            scheduler=scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
        )[0]

        final_image = self.vae_decoder.decode(self.vae, latent_samples)[0]
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]

        return (latent_samples, final_image, lineart_output, color_output)