import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Make sibling modules importable when this file is loaded outside the package.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension
from .comfyui_utils import (
    CheckpointLoaderSimple,
    ControlNetLoader,
    ControlNetApplyAdvanced,
    CLIPTextEncode,
    KSampler,
    VAEDecode,
    GrowMask,
    PIDINET_Preprocessor,
    LineArt_Preprocessor,
    Color_Preprocessor,
)


class ScribbleColorEditModel():
    """Scribble/color-guided inpainting pipeline built from ComfyUI-style nodes.

    Wires together a Stable Diffusion checkpoint, two ControlNets (edge and
    color) and a BrushNet inpainting model, then exposes a single
    :meth:`process` entry point that edits an image according to user
    add/remove masks and stroke settings.
    """

    def __init__(self):
        # Instantiate all node wrappers once; they are reused across calls.
        self.checkpoint_loader = CheckpointLoaderSimple()
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()
        self.lineart_processor = LineArt_Preprocessor()
        self.color_processor = Color_Preprocessor()
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()
        # Default checkpoint; process() reloads if the caller passes another.
        self.ckpt_name = "SD1.5/realisticVisionV60B1_v51VAE.safetensors"
        with torch.no_grad():
            self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(self.ckpt_name)
        self.load_models('SD1.5', 'float16')

    def load_models(self, base_model_version="SD1.5", dtype='float16'):
        """Load the edge/color ControlNets and the BrushNet weights.

        Args:
            base_model_version: Base SD family; only "SD1.5" is supported.
            dtype: Torch dtype name forwarded to the BrushNet loader.

        Raises:
            ValueError: If ``base_model_version`` is not a supported family.
        """
        if base_model_version == "SD1.5":
            edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
            color_controlnet_name = "color_finetune.safetensors"
            brushnet_name = "brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors"
        # elif base_model_version == "SDXL":
        #     edge_controlnet_name = "controlnet-scribble-sdxl-1.0.safetensors"
        #     color_controlnet_name = "colorGridControlnet_v10.safetensors"
        #     brushnet_name = "brushnet_xl/random_mask_brushnet_ckpt_sdxl_v0/diffusion_pytorch_model.safetensors"
        else:
            raise ValueError("Invalid base_model_version, not supported yet!!!: {}".format(base_model_version))
        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]
        self.brushnet_loader.inpaint_files = get_files_with_extension('inpaint')
        # Reuse the value just stored instead of scanning the directory twice.
        print("self.brushnet_loader.inpaint_files: ", self.brushnet_loader.inpaint_files)
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]

    def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt, mask, add_mask,
                remove_mask, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength,
                inpaint_strength, seed, steps, cfg, sampler_name, scheduler,
                base_model_version='SD1.5', dtype='float16', palette_resolution=2048):
        """Run the full scribble/color edit pipeline on one image.

        Args:
            ckpt_name: Checkpoint to use; triggers a reload if it differs
                from the currently loaded one.
            image: Input image tensor (ComfyUI image layout — assumed
                (batch, H, W, C) in [0, 1]; TODO confirm against callers).
            colored_image: Image carrying the user's color hints; when it
                equals ``image`` the color ControlNet is skipped.
            positive_prompt / negative_prompt: Text prompts for CLIP encoding.
            mask: Inpainting mask, grown by ``grow_size`` before use.
            add_mask / remove_mask: User stroke masks for adding/removing edges.
            grow_size: Dilation amount for the inpainting mask.
            stroke_as_edge: "enable"/"disable" — whether strokes are painted
                directly into the lineart map.
            fine_edge: "enable" selects the lineart extractor, otherwise the
                PiDiNet scribble extractor.
            edge_strength / color_strength / inpaint_strength: ControlNet and
                BrushNet conditioning scales.
            seed, steps, cfg, sampler_name, scheduler: KSampler settings.
            base_model_version, dtype: Forwarded to :meth:`load_models` if the
                ControlNet/BrushNet models are not loaded yet.
            palette_resolution: Resolution for the color-palette preprocessor.

        Returns:
            Tuple ``(latent_samples, final_image, lineart_output, color_output)``.
        """
        # Reload the base checkpoint only when the caller asks for a new one.
        if ckpt_name != self.ckpt_name:
            self.ckpt_name = ckpt_name
            with torch.no_grad():
                self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
        if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
            # Load the ControlNet & BrushNet models matching the base model version.
            self.load_models(base_model_version, dtype)

        positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]

        # Grow mask for color editing.
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]

        # Realistic lineart: neutralize the add-stroke region by painting it
        # with the opposite of its mean brightness.
        # NOTE(review): image_copy is mutated here but never read afterwards —
        # the extractors below operate on `image`. Looks like a latent bug;
        # confirm whether image_copy was meant to feed the lineart processor.
        image_copy = image.clone()
        if stroke_as_edge == "disable":
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_copy[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_copy[bool_add_mask] = 0.0
            else:
                image_copy[bool_add_mask] = 1.0

        if not torch.equal(image, colored_image):
            # Color hints present: condition on both color and edge maps.
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=palette_resolution)[0]
            lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0)
            # NOTE(review): edge strength is hard-coded to 0.8 here while the
            # other branch uses the edge_strength parameter — confirm intent.
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0)
        else:
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(image, resolution=palette_resolution)[0]
            if fine_edge == "enable":
                lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            else:
                lineart_output = self.scribble_processor.execute(image, resolution=512)[0]
            # Resize the stroke masks to match the edge map's spatial size.
            add_mask_resized = F.interpolate(add_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)
            remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)
            bool_add_mask_resized = (add_mask_resized > 0.5)
            bool_remove_mask_resized = (remove_mask_resized > 0.5)
            if stroke_as_edge == "enable":
                # Set pixels in the remove_mask region to black (erase edges).
                lineart_output[bool_remove_mask_resized] = 0.0
                # Set pixels in the add_mask region to white (add edges).
                lineart_output[bool_add_mask_resized] = 1.0
            else:
                # Only erase where the user removed and did not also add.
                lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0)

        # BrushNet: patch the base model for mask-aware inpainting.
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=self.model,
            vae=self.vae,  # the VAE model must be supplied as appropriate
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000,
        )

        # Sample, decode, and blend the result back into the source image.
        latent_samples = self.ksampler.sample(
            model=model,
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name=sampler_name,
            scheduler=scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
        )[0]

        final_image = self.vae_decoder.decode(self.vae, latent_samples)[0]
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]

        # Return the final image together with intermediate conditioning maps.
        return (latent_samples, final_image, lineart_output, color_output)