0?(this.length+=e.getTotalLength(),this.functions.push(e),s=[h[y][5]+s[0],h[y][6]+s[1]]):this.functions.push(new l(s[0],s[0],s[1],s[1]));else if("S"===h[y][0]){if(y>0&&["C","c","S","s"].indexOf(h[y-1][0])>-1){if(e){var p=e.getC();e=new j(s[0],s[1],2*s[0]-p.x,2*s[1]-p.y,h[y][1],h[y][2],h[y][3],h[y][4])}}else e=new j(s[0],s[1],s[0],s[1],h[y][1],h[y][2],h[y][3],h[y][4]);e&&(this.length+=e.getTotalLength(),s=[h[y][3],h[y][4]],this.functions.push(e))}else if("s"===h[y][0]){if(y>0&&["C","c","S","s"].indexOf(h[y-1][0])>-1){if(e){var x=e.getC(),v=e.getD();e=new j(s[0],s[1],s[0]+v.x-x.x,s[1]+v.y-x.y,s[0]+h[y][1],s[1]+h[y][2],s[0]+h[y][3],s[1]+h[y][4])}}else e=new j(s[0],s[1],s[0],s[1],s[0]+h[y][1],s[1]+h[y][2],s[0]+h[y][3],s[1]+h[y][4]);e&&(this.length+=e.getTotalLength(),s=[h[y][3]+s[0],h[y][4]+s[1]],this.functions.push(e))}else if("Q"===h[y][0]){if(s[0]==h[y][1]&&s[1]==h[y][2]){var M=new l(h[y][1],h[y][3],h[y][2],h[y][4]);this.length+=M.getTotalLength(),this.functions.push(M)}else e=new j(s[0],s[1],h[y][1],h[y][2],h[y][3],h[y][4],void 0,void 0),this.length+=e.getTotalLength(),this.functions.push(e);s=[h[y][3],h[y][4]],g=[h[y][1],h[y][2]]}else if("q"===h[y][0]){if(0!=h[y][1]||0!=h[y][2])e=new j(s[0],s[1],s[0]+h[y][1],s[1]+h[y][2],s[0]+h[y][3],s[1]+h[y][4],void 0,void 0),this.length+=e.getTotalLength(),this.functions.push(e);else{var w=new l(s[0]+h[y][1],s[0]+h[y][3],s[1]+h[y][2],s[1]+h[y][4]);this.length+=w.getTotalLength(),this.functions.push(w)}g=[s[0]+h[y][1],s[1]+h[y][2]],s=[h[y][3]+s[0],h[y][4]+s[1]]}else if("T"===h[y][0]){if(y>0&&["Q","q","T","t"].indexOf(h[y-1][0])>-1)e=new j(s[0],s[1],2*s[0]-g[0],2*s[1]-g[1],h[y][1],h[y][2],void 0,void 0),this.functions.push(e),this.length+=e.getTotalLength();else{var L=new l(s[0],h[y][1],s[1],h[y][2]);this.functions.push(L),this.length+=L.getTotalLength()}g=[2*s[0]-g[0],2*s[1]-g[1]],s=[h[y][1],h[y][2]]}else if("t"===h[y][0]){if(y>0&&["Q","q","T","t"].indexOf(h[y-1][0])>-1)e=new j(s[0],s[1],2*s[0]-g[0],2*s[1]-g[1],s[0]+h[y][1],s[1]+h[y][2],void 0,void 0),this.length+=e.getTotalLength(),this.functions.push(e);else{var d=new l(s[0],s[0]+h[y][1],s[1],s[1]+h[y][2]);this.length+=d.getTotalLength(),this.functions.push(d)}g=[2*s[0]-g[0],2*s[1]-g[1]],s=[h[y][1]+s[0],h[y][2]+s[1]]}else if("A"===h[y][0]){var A=new c(s[0],s[1],h[y][1],h[y][2],h[y][3],1===h[y][4],1===h[y][5],h[y][6],h[y][7]);this.length+=A.getTotalLength(),s=[h[y][6],h[y][7]],this.functions.push(A)}else if("a"===h[y][0]){var b=new c(s[0],s[1],h[y][1],h[y][2],h[y][3],1===h[y][4],1===h[y][5],s[0]+h[y][6],s[1]+h[y][7]);this.length+=b.getTotalLength(),s=[s[0]+h[y][6],s[1]+h[y][7]],this.functions.push(b)}this.partial_lengths.push(this.length)}})),E=e((function(t){var n=this;if(i(this,"inst",void 0),i(this,"getTotalLength",(function(){return n.inst.getTotalLength()})),i(this,"getPointAtLength",(function(t){return n.inst.getPointAtLength(t)})),i(this,"getTangentAtLength",(function(t){return n.inst.getTangentAtLength(t)})),i(this,"getPropertiesAtLength",(function(t){return n.inst.getPropertiesAtLength(t)})),i(this,"getParts",(function(){return n.inst.getParts()})),this.inst=new O(t),!(this instanceof E))return new E(t)}));t.svgPathProperties=E}));
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/audioscheduler_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/audioscheduler_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d0422e7da875298f87fe60a7f6d1494530dca2
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/audioscheduler_nodes.py
@@ -0,0 +1,251 @@
+# to be used with https://github.com/a1lazydog/ComfyUI-AudioScheduler
+import torch
+from torchvision.transforms import functional as TF
+from PIL import Image, ImageDraw
+import numpy as np
+from ..utility.utility import pil2tensor
+from nodes import MAX_RESOLUTION
+
+class NormalizedAmplitudeToMask:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+ "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_offset": ("INT", {"default": 0,"min": -255, "max": 255, "step": 1}),
+ "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+ "shape": (
+ [
+ 'none',
+ 'circle',
+ 'square',
+ 'triangle',
+ ],
+ {
+ "default": 'none'
+ }),
+ "color": (
+ [
+ 'white',
+ 'amplitude',
+ ],
+ {
+ "default": 'amplitude'
+ }),
+ },}
+
+ CATEGORY = "KJNodes/audio"
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "convert"
+ DESCRIPTION = """
+Works as a bridge to the AudioScheduler nodes:
+https://github.com/a1lazydog/ComfyUI-AudioScheduler
+Creates masks based on the normalized amplitude.
+"""
+
+ def convert(self, normalized_amp, width, height, frame_offset, shape, location_x, location_y, size, color):
+ # Ensure normalized_amp is an array and within the range [0, 1]
+ normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+
+ # Offset the amplitude values by rolling the array
+ normalized_amp = np.roll(normalized_amp, frame_offset)
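+        # e.g. frame_offset=1 turns [0.1, 0.5, 0.9] into [0.9, 0.1, 0.5]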
+
+ # Initialize an empty list to hold the image tensors
+ out = []
+ # Iterate over each amplitude value to create an image
+ for amp in normalized_amp:
+ # Scale the amplitude value to cover the full range of grayscale values
+ if color == 'amplitude':
+ grayscale_value = int(amp * 255)
+ elif color == 'white':
+ grayscale_value = 255
+ # Convert the grayscale value to an RGB format
+ gray_color = (grayscale_value, grayscale_value, grayscale_value)
+ finalsize = size * amp
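+            # the drawn shape's half-extent scales with the amplitude, so it collapses toward a point on silent frames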
+
+ if shape == 'none':
+ shapeimage = Image.new("RGB", (width, height), gray_color)
+ else:
+ shapeimage = Image.new("RGB", (width, height), "black")
+
+ draw = ImageDraw.Draw(shapeimage)
+ if shape == 'circle' or shape == 'square':
+ # Define the bounding box for the shape
+ left_up_point = (location_x - finalsize, location_y - finalsize)
+ right_down_point = (location_x + finalsize,location_y + finalsize)
+ two_points = [left_up_point, right_down_point]
+
+ if shape == 'circle':
+ draw.ellipse(two_points, fill=gray_color)
+ elif shape == 'square':
+ draw.rectangle(two_points, fill=gray_color)
+
+ elif shape == 'triangle':
+ # Define the points for the triangle
+ left_up_point = (location_x - finalsize, location_y + finalsize) # bottom left
+ right_down_point = (location_x + finalsize, location_y + finalsize) # bottom right
+ top_point = (location_x, location_y) # top point
+ draw.polygon([top_point, left_up_point, right_down_point], fill=gray_color)
+
+ shapeimage = pil2tensor(shapeimage)
+ mask = shapeimage[:, :, :, 0]
+ out.append(mask)
+
+ return (torch.cat(out, dim=0),)
+
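+# Illustrative only, not executed by ComfyUI: assuming the relative imports resolve,
+# the node above could be exercised directly with a synthetic amplitude array, e.g.
+#
+#   amps = np.array([0.2, 0.6, 1.0])
+#   masks, = NormalizedAmplitudeToMask().convert(amps, width=256, height=256, frame_offset=0,
+#                                                shape='circle', location_x=128, location_y=128,
+#                                                size=64, color='amplitude')
+#   # masks.shape -> torch.Size([3, 256, 256])
+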
+class NormalizedAmplitudeToFloatList:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+ },}
+
+ CATEGORY = "KJNodes/audio"
+ RETURN_TYPES = ("FLOAT",)
+ FUNCTION = "convert"
+ DESCRIPTION = """
+Works as a bridge to the AudioScheduler nodes:
+https://github.com/a1lazydog/ComfyUI-AudioScheduler
+Creates a list of floats from the normalized amplitude.
+"""
+
+ def convert(self, normalized_amp):
+ # Ensure normalized_amp is an array and within the range [0, 1]
+ normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+ return (normalized_amp.tolist(),)
+
+class OffsetMaskByNormalizedAmplitude:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+ "mask": ("MASK",),
+ "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "rotate": ("BOOLEAN", { "default": False }),
+ "angle_multiplier": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
+ }
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "offset"
+ CATEGORY = "KJNodes/audio"
+ DESCRIPTION = """
+Works as a bridge to the AudioScheduler nodes:
+https://github.com/a1lazydog/ComfyUI-AudioScheduler
+Offsets masks based on the normalized amplitude.
+"""
+
+ def offset(self, mask, x, y, angle_multiplier, rotate, normalized_amp):
+
+ # Ensure normalized_amp is an array and within the range [0, 1]
+ offsetmask = mask.clone()
+ normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+
+ batch_size, height, width = mask.shape
+
+ if rotate:
+ for i in range(batch_size):
+ rotation_amp = int(normalized_amp[i] * (360 * angle_multiplier))
+ rotation_angle = rotation_amp
+ offsetmask[i] = TF.rotate(offsetmask[i].unsqueeze(0), rotation_angle).squeeze(0)
+ if x != 0 or y != 0:
+ for i in range(batch_size):
+ offset_amp = normalized_amp[i] * 10
+ shift_x = min(x*offset_amp, width-1)
+ shift_y = min(y*offset_amp, height-1)
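+                # the requested offset is scaled by up to 10x the amplitude and capped to the frame size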
+ if shift_x != 0:
+ offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_x), dims=1)
+ if shift_y != 0:
+ offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_y), dims=0)
+
+ return offsetmask,
+
+class ImageTransformByNormalizedAmplitude:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+ "zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
+ "x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "cumulative": ("BOOLEAN", { "default": False }),
+ "image": ("IMAGE",),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "amptransform"
+ CATEGORY = "KJNodes/audio"
+ DESCRIPTION = """
+Works as a bridge to the AudioScheduler nodes:
+https://github.com/a1lazydog/ComfyUI-AudioScheduler
+Transforms image based on the normalized amplitude.
+"""
+
+ def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
+ # Ensure normalized_amp is an array and within the range [0, 1]
+ normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+ transformed_images = []
+
+ # Initialize the cumulative zoom factor
+ prev_amp = 0.0
+
+ for i in range(image.shape[0]):
+ img = image[i] # Get the i-th image in the batch
+ amp = normalized_amp[i] # Get the corresponding amplitude value
+
+ # Incrementally increase the cumulative zoom factor
+ if cumulative:
+ prev_amp += amp
+ amp += prev_amp
+
+ # Convert the image tensor from BxHxWxC to CxHxW format expected by torchvision
+ img = img.permute(2, 0, 1)
+
+ # Convert PyTorch tensor to PIL Image for processing
+ pil_img = TF.to_pil_image(img)
+
+ # Calculate the crop size based on the amplitude
+ width, height = pil_img.size
+ crop_size = int(min(width, height) * (1 - amp * zoom_scale))
+ crop_size = max(crop_size, 1)
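+            # e.g. amp=1.0 with zoom_scale=0.5 crops to half the shorter side, i.e. a 2x zoom after resizing back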
+
+ # Calculate the crop box coordinates (centered crop)
+ left = (width - crop_size) // 2
+ top = (height - crop_size) // 2
+ right = (width + crop_size) // 2
+ bottom = (height + crop_size) // 2
+
+ # Crop and resize back to original size
+ cropped_img = TF.crop(pil_img, top, left, crop_size, crop_size)
+ resized_img = TF.resize(cropped_img, (height, width))
+
+ # Convert back to tensor in CxHxW format
+ tensor_img = TF.to_tensor(resized_img)
+
+ # Convert the tensor back to BxHxWxC format
+ tensor_img = tensor_img.permute(1, 2, 0)
+
+ # Offset the image based on the amplitude
+ offset_amp = amp * 10 # Calculate the offset magnitude based on the amplitude
+            shift_x = min(x_offset * offset_amp, width - 1) # Calculate the shift in x direction, capped to the image width
+            shift_y = min(y_offset * offset_amp, height - 1) # Calculate the shift in y direction, capped to the image height
+
+ # Apply the offset to the image tensor
+ if shift_x != 0:
+ tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
+ if shift_y != 0:
+ tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
+
+ # Add to the list
+ transformed_images.append(tensor_img)
+
+ # Stack all transformed images into a batch
+ transformed_batch = torch.stack(transformed_images)
+
+ return (transformed_batch,)
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/batchcrop_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/batchcrop_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e7446f7567f74421d4b05742cc3a340fec73c9
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/batchcrop_nodes.py
@@ -0,0 +1,757 @@
+from ..utility.utility import tensor2pil, pil2tensor
+from PIL import Image, ImageDraw, ImageFilter
+import numpy as np
+import torch
+from torchvision.transforms import Resize, CenterCrop, InterpolationMode
+import math
+
+#based on nodes from mtb https://github.com/melMass/comfy_mtb
+
+def bbox_to_region(bbox, target_size=None):
+ bbox = bbox_check(bbox, target_size)
+ return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+def bbox_check(bbox, target_size=None):
+ if not target_size:
+ return bbox
+
+ new_bbox = (
+ bbox[0],
+ bbox[1],
+ min(target_size[0] - bbox[0], bbox[2]),
+ min(target_size[1] - bbox[1], bbox[3]),
+ )
+ return new_bbox
+
+class BatchCropFromMask:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "masks": ("MASK",),
+ "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
+ "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+ },
+ }
+
+ RETURN_TYPES = (
+ "IMAGE",
+ "IMAGE",
+ "BBOX",
+ "INT",
+ "INT",
+ )
+ RETURN_NAMES = (
+ "original_images",
+ "cropped_images",
+ "bboxes",
+ "width",
+ "height",
+ )
+ FUNCTION = "crop"
+ CATEGORY = "KJNodes/masking"
+
+ def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
+ if alpha == 0:
+ return prev_bbox_size
+ return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
+
+ def smooth_center(self, prev_center, curr_center, alpha=0.5):
+ if alpha == 0:
+ return prev_center
+ return (
+ round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
+ round(alpha * curr_center[1] + (1 - alpha) * prev_center[1])
+ )
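+    # Both helpers are exponential moving averages of the form
+    #   smoothed = alpha * current + (1 - alpha) * previous
+    # e.g. with alpha=0.5, a previous size of 100 and a current size of 140 give 120.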
+
+ def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
+
+ bounding_boxes = []
+ cropped_images = []
+
+ self.max_bbox_width = 0
+ self.max_bbox_height = 0
+
+ # First, calculate the maximum bounding box size across all masks
+ curr_max_bbox_width = 0
+ curr_max_bbox_height = 0
+ for mask in masks:
+ _mask = tensor2pil(mask)[0]
+ non_zero_indices = np.nonzero(np.array(_mask))
+ min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+ min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+ width = max_x - min_x
+ height = max_y - min_y
+ curr_max_bbox_width = max(curr_max_bbox_width, width)
+ curr_max_bbox_height = max(curr_max_bbox_height, height)
+
+ # Smooth the changes in the bounding box size
+ self.max_bbox_width = self.smooth_bbox_size(self.max_bbox_width, curr_max_bbox_width, bbox_smooth_alpha)
+ self.max_bbox_height = self.smooth_bbox_size(self.max_bbox_height, curr_max_bbox_height, bbox_smooth_alpha)
+
+ # Apply the crop size multiplier
+ self.max_bbox_width = round(self.max_bbox_width * crop_size_mult)
+ self.max_bbox_height = round(self.max_bbox_height * crop_size_mult)
+ bbox_aspect_ratio = self.max_bbox_width / self.max_bbox_height
+
+ # Then, for each mask and corresponding image...
+ for i, (mask, img) in enumerate(zip(masks, original_images)):
+ _mask = tensor2pil(mask)[0]
+ non_zero_indices = np.nonzero(np.array(_mask))
+ min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+ min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+ # Calculate center of bounding box
+ center_x = np.mean(non_zero_indices[1])
+ center_y = np.mean(non_zero_indices[0])
+ curr_center = (round(center_x), round(center_y))
+
+ # If this is the first frame, initialize prev_center with curr_center
+ if not hasattr(self, 'prev_center'):
+ self.prev_center = curr_center
+
+ # Smooth the changes in the center coordinates from the second frame onwards
+ if i > 0:
+ center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+ else:
+ center = curr_center
+
+ # Update prev_center for the next frame
+ self.prev_center = center
+
+ # Create bounding box using max_bbox_width and max_bbox_height
+ half_box_width = round(self.max_bbox_width / 2)
+ half_box_height = round(self.max_bbox_height / 2)
+ min_x = max(0, center[0] - half_box_width)
+ max_x = min(img.shape[1], center[0] + half_box_width)
+ min_y = max(0, center[1] - half_box_height)
+ max_y = min(img.shape[0], center[1] + half_box_height)
+
+ # Append bounding box coordinates
+ bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
+
+ # Crop the image from the bounding box
+ cropped_img = img[min_y:max_y, min_x:max_x, :]
+
+ # Calculate the new dimensions while maintaining the aspect ratio
+ new_height = min(cropped_img.shape[0], self.max_bbox_height)
+ new_width = round(new_height * bbox_aspect_ratio)
+
+ # Resize the image
+ resize_transform = Resize((new_height, new_width))
+ resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+
+ # Perform the center crop to the desired size
+ crop_transform = CenterCrop((self.max_bbox_height, self.max_bbox_width)) # swap the order here if necessary
+ cropped_resized_img = crop_transform(resized_img)
+
+ cropped_images.append(cropped_resized_img.permute(1, 2, 0))
+
+ cropped_out = torch.stack(cropped_images, dim=0)
+
+ return (original_images, cropped_out, bounding_boxes, self.max_bbox_width, self.max_bbox_height, )
+
+class BatchUncrop:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "cropped_images": ("IMAGE",),
+ "bboxes": ("BBOX",),
+ "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
+ "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "border_top": ("BOOLEAN", {"default": True}),
+ "border_bottom": ("BOOLEAN", {"default": True}),
+ "border_left": ("BOOLEAN", {"default": True}),
+ "border_right": ("BOOLEAN", {"default": True}),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "uncrop"
+
+ CATEGORY = "KJNodes/masking"
+
+ def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale, border_top, border_bottom, border_left, border_right):
+ def inset_border(image, border_width, border_color, border_top, border_bottom, border_left, border_right):
+ draw = ImageDraw.Draw(image)
+ width, height = image.size
+ if border_top:
+ draw.rectangle((0, 0, width, border_width), fill=border_color)
+ if border_bottom:
+ draw.rectangle((0, height - border_width, width, height), fill=border_color)
+ if border_left:
+ draw.rectangle((0, 0, border_width, height), fill=border_color)
+ if border_right:
+ draw.rectangle((width - border_width, 0, width, height), fill=border_color)
+ return image
+
+ if len(original_images) != len(cropped_images):
+ raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
+
+ # Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
+ if len(bboxes) > len(original_images):
+ print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
+ bboxes = bboxes[:len(original_images)]
+ elif len(bboxes) < len(original_images):
+ raise ValueError("There should be at least as many bboxes as there are original and cropped images")
+
+ input_images = tensor2pil(original_images)
+ crop_imgs = tensor2pil(cropped_images)
+
+ out_images = []
+ for i in range(len(input_images)):
+ img = input_images[i]
+ crop = crop_imgs[i]
+ bbox = bboxes[i]
+
+ # uncrop the image based on the bounding box
+ bb_x, bb_y, bb_width, bb_height = bbox
+
+ paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
+
+ # scale factors
+ scale_x = crop_rescale
+ scale_y = crop_rescale
+
+ # scaled paste_region
+ paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
+
+ # rescale the crop image to fit the paste_region
+ crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
+ crop_img = crop.convert("RGB")
+
+ if border_blending > 1.0:
+ border_blending = 1.0
+ elif border_blending < 0.0:
+ border_blending = 0.0
+
+ blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
+
+ blend = img.convert("RGBA")
+ mask = Image.new("L", img.size, 0)
+
+ mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
+ mask_block = inset_border(mask_block, round(blend_ratio / 2), (0), border_top, border_bottom, border_left, border_right)
+
+ mask.paste(mask_block, paste_region)
+ blend.paste(crop_img, paste_region)
+
+ mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
+
+ blend.putalpha(mask)
+ img = Image.alpha_composite(img.convert("RGBA"), blend)
+ out_images.append(img.convert("RGB"))
+
+ return (pil2tensor(out_images),)
+
+class BatchCropFromMaskAdvanced:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "masks": ("MASK",),
+ "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+ },
+ }
+
+ RETURN_TYPES = (
+ "IMAGE",
+ "IMAGE",
+ "MASK",
+ "IMAGE",
+ "MASK",
+ "BBOX",
+ "BBOX",
+ "INT",
+ "INT",
+ )
+ RETURN_NAMES = (
+ "original_images",
+ "cropped_images",
+ "cropped_masks",
+ "combined_crop_image",
+ "combined_crop_masks",
+ "bboxes",
+ "combined_bounding_box",
+ "bbox_width",
+ "bbox_height",
+ )
+ FUNCTION = "crop"
+ CATEGORY = "KJNodes/masking"
+
+ def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
+ return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
+
+ def smooth_center(self, prev_center, curr_center, alpha=0.5):
+ return (round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
+ round(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
+
+ def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
+ bounding_boxes = []
+ combined_bounding_box = []
+ cropped_images = []
+ cropped_masks = []
+ cropped_masks_out = []
+ combined_crop_out = []
+ combined_cropped_images = []
+ combined_cropped_masks = []
+
+ def calculate_bbox(mask):
+ non_zero_indices = np.nonzero(np.array(mask))
+
+ # handle empty masks
+ min_x, max_x, min_y, max_y = 0, 0, 0, 0
+ if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
+ min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+ min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+ width = max_x - min_x
+ height = max_y - min_y
+ bbox_size = max(width, height)
+ return min_x, max_x, min_y, max_y, bbox_size
+
+ combined_mask = torch.max(masks, dim=0)[0]
+ _mask = tensor2pil(combined_mask)[0]
+ new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
+ center_x = (new_min_x + new_max_x) / 2
+ center_y = (new_min_y + new_max_y) / 2
+ half_box_size = round(combined_bbox_size // 2)
+ new_min_x = max(0, round(center_x - half_box_size))
+ new_max_x = min(original_images[0].shape[1], round(center_x + half_box_size))
+ new_min_y = max(0, round(center_y - half_box_size))
+ new_max_y = min(original_images[0].shape[0], round(center_y + half_box_size))
+
+ combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))
+
+ self.max_bbox_size = 0
+
+ # First, calculate the maximum bounding box size across all masks
+ curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
+ # Smooth the changes in the bounding box size
+ self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
+ # Apply the crop size multiplier
+ self.max_bbox_size = round(self.max_bbox_size * crop_size_mult)
+ # Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
+ self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16
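+        # e.g. a raw size of 150 rounds up to 160, the next multiple of 16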
+
+ if self.max_bbox_size > original_images[0].shape[0] or self.max_bbox_size > original_images[0].shape[1]:
+ # max_bbox_size can only be as big as our input's width or height, and it has to be even
+ self.max_bbox_size = math.floor(min(original_images[0].shape[0], original_images[0].shape[1]) / 2) * 2
+
+ # Then, for each mask and corresponding image...
+ for i, (mask, img) in enumerate(zip(masks, original_images)):
+ _mask = tensor2pil(mask)[0]
+ non_zero_indices = np.nonzero(np.array(_mask))
+
+ # check for empty masks
+ if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
+ min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+ min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+ # Calculate center of bounding box
+ center_x = np.mean(non_zero_indices[1])
+ center_y = np.mean(non_zero_indices[0])
+ curr_center = (round(center_x), round(center_y))
+
+ # If this is the first frame, initialize prev_center with curr_center
+ if not hasattr(self, 'prev_center'):
+ self.prev_center = curr_center
+
+ # Smooth the changes in the center coordinates from the second frame onwards
+ if i > 0:
+ center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+ else:
+ center = curr_center
+
+ # Update prev_center for the next frame
+ self.prev_center = center
+
+ # Create bounding box using max_bbox_size
+ half_box_size = self.max_bbox_size // 2
+ min_x = max(0, center[0] - half_box_size)
+ max_x = min(img.shape[1], center[0] + half_box_size)
+ min_y = max(0, center[1] - half_box_size)
+ max_y = min(img.shape[0], center[1] + half_box_size)
+
+ # Append bounding box coordinates
+ bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
+
+ # Crop the image from the bounding box
+ cropped_img = img[min_y:max_y, min_x:max_x, :]
+ cropped_mask = mask[min_y:max_y, min_x:max_x]
+
+ # Resize the cropped image to a fixed size
+ new_size = max(cropped_img.shape[0], cropped_img.shape[1])
+ resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
+ resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+ resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+ # Perform the center crop to the desired size
+ # Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
+ crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))
+
+ cropped_resized_img = crop_transform(resized_img)
+ cropped_images.append(cropped_resized_img.permute(1, 2, 0))
+
+ cropped_resized_mask = crop_transform(resized_mask)
+ cropped_masks.append(cropped_resized_mask)
+
+ combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
+ combined_cropped_images.append(combined_cropped_img)
+
+ combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
+ combined_cropped_masks.append(combined_cropped_mask)
+ else:
+ bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
+ cropped_images.append(img)
+ cropped_masks.append(mask)
+ combined_cropped_images.append(img)
+ combined_cropped_masks.append(mask)
+
+ cropped_out = torch.stack(cropped_images, dim=0)
+ combined_crop_out = torch.stack(combined_cropped_images, dim=0)
+ cropped_masks_out = torch.stack(cropped_masks, dim=0)
+ combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
+
+ return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
+
+class FilterZeroMasksAndCorrespondingImages:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "masks": ("MASK",),
+ },
+ "optional": {
+ "original_images": ("IMAGE",),
+ },
+ }
+
+ RETURN_TYPES = ("MASK", "IMAGE", "IMAGE", "INDEXES",)
+ RETURN_NAMES = ("non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", "zero_mask_images_out_indexes",)
+ FUNCTION = "filter"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Filters out all empty (i.e. all-zero) masks in masks.
+If original_images is provided, the corresponding images are filtered by the same indexes.
+
+original_images (optional): if provided, must have the same length as masks.
+"""
+
+ def filter(self, masks, original_images=None):
+ non_zero_masks = []
+ non_zero_mask_images = []
+ zero_mask_images = []
+ zero_mask_images_indexes = []
+
+ masks_num = len(masks)
+ also_process_images = False
+ if original_images is not None:
+ imgs_num = len(original_images)
+ if len(original_images) == masks_num:
+ also_process_images = True
+ else:
+                print(f"[WARNING] ignoring input original_images: number of original_images ({imgs_num}) does not equal number of masks ({masks_num})")
+
+ for i in range(masks_num):
+ non_zero_num = np.count_nonzero(np.array(masks[i]))
+ if non_zero_num > 0:
+ non_zero_masks.append(masks[i])
+ if also_process_images:
+ non_zero_mask_images.append(original_images[i])
+ else:
+ zero_mask_images.append(original_images[i])
+ zero_mask_images_indexes.append(i)
+
+ non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
+ non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
+
+ if also_process_images:
+ non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
+ if len(zero_mask_images) > 0:
+ zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
+ zero_mask_images_out_indexes = zero_mask_images_indexes
+
+ return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
+
+class InsertImageBatchByIndexes:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "images_to_insert": ("IMAGE",),
+ "insert_indexes": ("INDEXES",),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE", )
+ RETURN_NAMES = ("images_after_insert", )
+ FUNCTION = "insert"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+This node is designed to be used with the FilterZeroMasksAndCorrespondingImages node.
+It inserts images_to_insert into images according to insert_indexes.
+
+Returns:
+    images_after_insert: the updated images, restored to their original sequence order
+"""
+
+ def insert(self, images, images_to_insert, insert_indexes):
+ images_after_insert = images
+
+ if images_to_insert is not None and insert_indexes is not None:
+ images_to_insert_num = len(images_to_insert)
+ insert_indexes_num = len(insert_indexes)
+ if images_to_insert_num == insert_indexes_num:
+ images_after_insert = []
+
+ i_images = 0
+ for i in range(len(images) + images_to_insert_num):
+ if i in insert_indexes:
+ images_after_insert.append(images_to_insert[insert_indexes.index(i)])
+ else:
+ images_after_insert.append(images[i_images])
+ i_images += 1
+
+ images_after_insert = torch.stack(images_after_insert, dim=0)
+
+ else:
+                print(f"[WARNING] skipping insert: number of images_to_insert ({images_to_insert_num}) does not equal number of insert_indexes ({insert_indexes_num})")
+
+
+ return (images_after_insert, )
+
+class BatchUncropAdvanced:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "cropped_images": ("IMAGE",),
+ "cropped_masks": ("MASK",),
+ "combined_crop_mask": ("MASK",),
+ "bboxes": ("BBOX",),
+ "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
+ "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "use_combined_mask": ("BOOLEAN", {"default": False}),
+ "use_square_mask": ("BOOLEAN", {"default": True}),
+ },
+ "optional": {
+ "combined_bounding_box": ("BBOX", {"default": None}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "uncrop"
+ CATEGORY = "KJNodes/masking"
+
+
+ def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
+
+ def inset_border(image, border_width=20, border_color=(0)):
+ width, height = image.size
+ bordered_image = Image.new(image.mode, (width, height), border_color)
+ bordered_image.paste(image, (0, 0))
+ draw = ImageDraw.Draw(bordered_image)
+ draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
+ return bordered_image
+
+ if len(original_images) != len(cropped_images):
+ raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
+
+ # Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
+ if len(bboxes) > len(original_images):
+ print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
+ bboxes = bboxes[:len(original_images)]
+ elif len(bboxes) < len(original_images):
+ raise ValueError("There should be at least as many bboxes as there are original and cropped images")
+
+ crop_imgs = tensor2pil(cropped_images)
+ input_images = tensor2pil(original_images)
+ out_images = []
+
+ for i in range(len(input_images)):
+ img = input_images[i]
+ crop = crop_imgs[i]
+ bbox = bboxes[i]
+
+ if use_combined_mask:
+ bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
+ paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
+ mask = combined_crop_mask[i]
+ else:
+ bb_x, bb_y, bb_width, bb_height = bbox
+ paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
+ mask = cropped_masks[i]
+
+ # scale paste_region
+ scale_x = scale_y = crop_rescale
+ paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
+
+ # rescale the crop image to fit the paste_region
+ crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
+ crop_img = crop.convert("RGB")
+
+ #border blending
+ if border_blending > 1.0:
+ border_blending = 1.0
+ elif border_blending < 0.0:
+ border_blending = 0.0
+
+ blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
+ blend = img.convert("RGBA")
+
+ if use_square_mask:
+ mask = Image.new("L", img.size, 0)
+ mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
+ mask_block = inset_border(mask_block, round(blend_ratio / 2), (0))
+ mask.paste(mask_block, paste_region)
+ else:
+ original_mask = tensor2pil(mask)[0]
+ original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
+ mask = Image.new("L", img.size, 0)
+ mask.paste(original_mask, paste_region)
+
+ mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
+
+ blend.paste(crop_img, paste_region)
+ blend.putalpha(mask)
+
+ img = Image.alpha_composite(img.convert("RGBA"), blend)
+ out_images.append(img.convert("RGB"))
+
+ return (pil2tensor(out_images),)
+
+class SplitBboxes:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "bboxes": ("BBOX",),
+ "index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
+ },
+ }
+
+ RETURN_TYPES = ("BBOX","BBOX",)
+ RETURN_NAMES = ("bboxes_a","bboxes_b",)
+ FUNCTION = "splitbbox"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Splits the specified bbox list at the given index into two lists.
+"""
+
+ def splitbbox(self, bboxes, index):
+ bboxes_a = bboxes[:index] # Sub-list from the start of bboxes up to (but not including) the index
+ bboxes_b = bboxes[index:] # Sub-list from the index to the end of bboxes
+
+ return (bboxes_a, bboxes_b,)
+
+class BboxToInt:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "bboxes": ("BBOX",),
+ "index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
+ },
+ }
+
+ RETURN_TYPES = ("INT","INT","INT","INT","INT","INT",)
+ RETURN_NAMES = ("x_min","y_min","width","height", "center_x","center_y",)
+ FUNCTION = "bboxtoint"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Returns the bounding box at the selected index as integers, along with its center coordinates.
+"""
+ def bboxtoint(self, bboxes, index):
+ x_min, y_min, width, height = bboxes[index]
+ center_x = int(x_min + width / 2)
+ center_y = int(y_min + height / 2)
+
+ return (x_min, y_min, width, height, center_x, center_y,)
+
+class BboxVisualize:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "bboxes": ("BBOX",),
+ "line_width": ("INT", {"default": 1,"min": 1, "max": 10, "step": 1}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("images",)
+ FUNCTION = "visualizebbox"
+ DESCRIPTION = """
+Visualizes the specified bbox on the image.
+"""
+
+ CATEGORY = "KJNodes/masking"
+
+ def visualizebbox(self, bboxes, images, line_width):
+ image_list = []
+ for image, bbox in zip(images, bboxes):
+ x_min, y_min, width, height = bbox
+
+ # Ensure bbox coordinates are integers
+ x_min = int(x_min)
+ y_min = int(y_min)
+ width = int(width)
+ height = int(height)
+
+ # Permute the image dimensions
+ image = image.permute(2, 0, 1)
+
+ # Clone the image to draw bounding boxes
+ img_with_bbox = image.clone()
+
+ # Define the color for the bbox, e.g., red
+ color = torch.tensor([1, 0, 0], dtype=torch.float32)
+
+ # Ensure color tensor matches the image channels
+ if color.shape[0] != img_with_bbox.shape[0]:
+ color = color.unsqueeze(1).expand(-1, line_width)
+
+ # Draw lines for each side of the bbox with the specified line width
+ for lw in range(line_width):
+ # Top horizontal line
+ if y_min + lw < img_with_bbox.shape[1]:
+ img_with_bbox[:, y_min + lw, x_min:x_min + width] = color[:, None]
+
+ # Bottom horizontal line
+ if y_min + height - lw < img_with_bbox.shape[1]:
+ img_with_bbox[:, y_min + height - lw, x_min:x_min + width] = color[:, None]
+
+ # Left vertical line
+ if x_min + lw < img_with_bbox.shape[2]:
+ img_with_bbox[:, y_min:y_min + height, x_min + lw] = color[:, None]
+
+ # Right vertical line
+ if x_min + width - lw < img_with_bbox.shape[2]:
+ img_with_bbox[:, y_min:y_min + height, x_min + width - lw] = color[:, None]
+
+ # Permute the image dimensions back
+ img_with_bbox = img_with_bbox.permute(1, 2, 0).unsqueeze(0)
+ image_list.append(img_with_bbox)
+
+        return (torch.cat(image_list, dim=0),)
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/curve_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/curve_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8552d0053a653bffe8cf8a9230b4a6529485daf8
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/curve_nodes.py
@@ -0,0 +1,1561 @@
+import torch
+from torchvision import transforms
+import json
+from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter, ImageChops
+import numpy as np
+from ..utility.utility import pil2tensor, tensor2pil
+import folder_paths
+import io
+import base64
+
+from comfy.utils import common_upscale
+
+def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
+ import matplotlib
+ matplotlib.use('Agg')
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+ text_color = '#999999'
+ bg_color = '#353535'
+ matplotlib.pyplot.rcParams['text.color'] = text_color
+ fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100)
+ fig.patch.set_facecolor(bg_color)
+ ax.set_facecolor(bg_color)
+ ax.grid(color=text_color, linestyle='-', linewidth=0.5)
+ ax.set_xlabel('x', color=text_color)
+ ax.set_ylabel('y', color=text_color)
+ for text in ax.get_xticklabels() + ax.get_yticklabels():
+ text.set_color(text_color)
+ ax.set_title('position for: ' + prompt)
+ ax.set_xlabel('X Coordinate')
+ ax.set_ylabel('Y Coordinate')
+ #ax.legend().remove()
+ ax.set_xlim(0, width) # Set the x-axis to match the input latent width
+ ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left
+ # Adjust the margins of the subplot
+ matplotlib.pyplot.subplots_adjust(left=0.08, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.2)
+
+ cmap = matplotlib.pyplot.get_cmap('rainbow')
+ image_batch = []
+ canvas = FigureCanvas(fig)
+ width, height = fig.get_size_inches() * fig.get_dpi()
+ # Draw a box at each coordinate
+ for i, ((x, y), size) in enumerate(zip(coordinates, size_multiplier)):
+ color_index = i / (len(coordinates) - 1)
+ color = cmap(color_index)
+ draw_height = bbox_height * size
+ draw_width = bbox_width * size
+ rect = matplotlib.patches.Rectangle((x - draw_width/2, y - draw_height/2), draw_width, draw_height,
+ linewidth=1, edgecolor=color, facecolor='none', alpha=0.5)
+ ax.add_patch(rect)
+
+ # Check if there is a next coordinate to draw an arrow to
+ if i < len(coordinates) - 1:
+ x1, y1 = coordinates[i]
+ x2, y2 = coordinates[i + 1]
+ ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
+ arrowprops=dict(arrowstyle="->",
+ linestyle="-",
+ lw=1,
+ color=color,
+ mutation_scale=20))
+ canvas.draw()
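+        # rasterize the figure and read the RGB buffer back as a float image tensor in the 0..1 range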
+ image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3).copy()
+ image_tensor = torch.from_numpy(image_np).float() / 255.0
+ image_tensor = image_tensor.unsqueeze(0)
+ image_batch.append(image_tensor)
+
+ matplotlib.pyplot.close(fig)
+ image_batch_tensor = torch.cat(image_batch, dim=0)
+
+ return image_batch_tensor
+
+class PlotCoordinates:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "coordinates": ("STRING", {"forceInput": True}),
+ "text": ("STRING", {"default": 'title', "multiline": False}),
+ "width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "bbox_width": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
+ "bbox_height": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
+ },
+ "optional": {"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True})},
+ }
+ RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT",)
+ RETURN_NAMES = ("images", "width", "height", "bbox_width", "bbox_height",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+Plots coordinates to a sequence of images using Matplotlib.
+
+"""
+
+ def append(self, coordinates, text, width, height, bbox_width, bbox_height, size_multiplier=[1.0]):
+ coordinates = json.loads(coordinates.replace("'", '"'))
+ coordinates = [(coord['x'], coord['y']) for coord in coordinates]
+ batch_size = len(coordinates)
+ if not size_multiplier or len(size_multiplier) != batch_size:
+ size_multiplier = [0] * batch_size
+ else:
+ size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
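+            # tile the multiplier list and trim it so its length matches the batch size exactly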
+
+ plot_image_tensor = plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, text)
+
+ return (plot_image_tensor, width, height, bbox_width, bbox_height)
+
+class SplineEditor:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "points_store": ("STRING", {"multiline": False}),
+ "coordinates": ("STRING", {"multiline": False}),
+ "mask_width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "mask_height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "points_to_sample": ("INT", {"default": 16, "min": 2, "max": 1000, "step": 1}),
+ "sampling_method": (
+ [
+ 'path',
+ 'time',
+ 'controlpoints'
+ ],
+ {
+ "default": 'time'
+ }),
+ "interpolation": (
+ [
+ 'cardinal',
+ 'monotone',
+ 'basis',
+ 'linear',
+ 'step-before',
+ 'step-after',
+ 'polar',
+ 'polar-reverse',
+ ],
+ {
+ "default": 'cardinal'
+ }),
+ "tension": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+ "repeat_output": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
+ "float_output_type": (
+ [
+ 'list',
+ 'pandas series',
+ 'tensor',
+ ],
+ {
+ "default": 'list'
+ }),
+ },
+ "optional": {
+ "min_value": ("FLOAT", {"default": 0.0, "min": -10000.0, "max": 10000.0, "step": 0.01}),
+ "max_value": ("FLOAT", {"default": 1.0, "min": -10000.0, "max": 10000.0, "step": 0.01}),
+ "bg_image": ("IMAGE", ),
+ }
+ }
+
+ RETURN_TYPES = ("MASK", "STRING", "FLOAT", "INT", "STRING",)
+ RETURN_NAMES = ("mask", "coord_str", "float", "count", "normalized_str",)
+ FUNCTION = "splinedata"
+ CATEGORY = "KJNodes/weights"
+ DESCRIPTION = """
+# WORK IN PROGRESS
+Do not count on this as part of your workflow yet,
+probably contains lots of bugs and stability is not
+guaranteed!!
+
+## Graphical editor to create values for various
+## schedules and/or mask batches.
+
+**Shift + click** to add control point at end.
+**Ctrl + click** to add control point (subdivide) between two points.
+**Right click on a point** to delete it.
+Note that you can't delete from start/end.
+
+Right click on the canvas for a context menu.
+These are purely visual options and do not affect the output:
+ - Toggle handles visibility
+ - Display sample points: display the points to be returned.
+
+**points_to_sample** sets the number of samples
+returned from the **drawn spline itself**; this is independent of the
+actual control points, so the interpolation type matters.
+sampling_method:
+ - time: samples along the time axis, used for schedules
+ - path: samples along the path itself, useful for coordinates
+
+output types:
+ - mask batch
+ example compatible nodes: anything that takes masks
+ - list of floats
+ example compatible nodes: IPAdapter weights
+ - pandas series
+      example compatible nodes: anything that takes FizzNodes'
+      Batch Value Schedule
+ - torch tensor
+ example compatible nodes: unknown
+"""
+
+ def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation,
+ points_to_sample, sampling_method, points_store, tension, repeat_output,
+ min_value=0.0, max_value=1.0, bg_image=None):
+
+ coordinates = json.loads(coordinates)
+ normalized = []
+ normalized_y_values = []
+ for coord in coordinates:
+ coord['x'] = int(round(coord['x']))
+ coord['y'] = int(round(coord['y']))
+ norm_x = (1.0 - (coord['x'] / mask_height) - 0.0) * (max_value - min_value) + min_value
+ norm_y = (1.0 - (coord['y'] / mask_height) - 0.0) * (max_value - min_value) + min_value
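+            # the editor's y axis points down, so values are flipped and remapped to the [min_value, max_value] range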
+ normalized_y_values.append(norm_y)
+ normalized.append({'x':norm_x, 'y':norm_y})
+ if float_output_type == 'list':
+ out_floats = normalized_y_values * repeat_output
+ elif float_output_type == 'pandas series':
+ try:
+ import pandas as pd
+            except ImportError:
+                raise Exception("SplineEditor: pandas is not installed. Please install pandas to use this output_type")
+            out_floats = pd.Series(normalized_y_values * repeat_output)
+ elif float_output_type == 'tensor':
+ out_floats = torch.tensor(normalized_y_values * repeat_output, dtype=torch.float32)
+ # Create a color map for grayscale intensities
+ color_map = lambda y: torch.full((mask_height, mask_width, 3), y, dtype=torch.float32)
+
+ # Create image tensors for each normalized y value
+ mask_tensors = [color_map(y) for y in normalized_y_values]
+ masks_out = torch.stack(mask_tensors)
+ masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
+ masks_out = masks_out.mean(dim=-1)
+ if bg_image is None:
+ return (masks_out, json.dumps(coordinates), out_floats, len(out_floats) , json.dumps(normalized))
+ else:
+ transform = transforms.ToPILImage()
+ image = transform(bg_image[0].permute(2, 0, 1))
+ buffered = io.BytesIO()
+ image.save(buffered, format="JPEG", quality=75)
+
+ # Step 3: Encode the image bytes to a Base64 string
+ img_bytes = buffered.getvalue()
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
+ return {
+ "ui": {"bg_image": [img_base64]},
+ "result":(masks_out, json.dumps(coordinates), out_floats, len(out_floats) , json.dumps(normalized))
+ }
+
+
+class CreateShapeMaskOnPath:
+
+ RETURN_TYPES = ("MASK", "MASK",)
+ RETURN_NAMES = ("mask", "mask_inverted",)
+ FUNCTION = "createshapemask"
+ CATEGORY = "KJNodes/masking/generate"
+ DESCRIPTION = """
+Creates a mask or batch of masks with the specified shape.
+Locations are center locations.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "shape": (
+ [ 'circle',
+ 'square',
+ 'triangle',
+ ],
+ {
+ "default": 'circle'
+ }),
+ "coordinates": ("STRING", {"forceInput": True}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+ "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+ },
+ "optional": {
+ "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
+ }
+ }
+
+ def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape, size_multiplier=[1.0]):
+ # Define the number of images in the batch
+ coordinates = coordinates.replace("'", '"')
+ coordinates = json.loads(coordinates)
+
+ batch_size = len(coordinates)
+ out = []
+ color = "white"
+ if not size_multiplier or len(size_multiplier) != batch_size:
+ size_multiplier = [0] * batch_size
+ else:
+ size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
+ for i, coord in enumerate(coordinates):
+ image = Image.new("RGB", (frame_width, frame_height), "black")
+ draw = ImageDraw.Draw(image)
+
+ # Calculate the size for this frame and ensure it's not less than 0
+ current_width = max(0, shape_width + i * size_multiplier[i])
+ current_height = max(0, shape_height + i * size_multiplier[i])
+
+ location_x = coord['x']
+ location_y = coord['y']
+
+ if shape == 'circle' or shape == 'square':
+ # Define the bounding box for the shape
+ left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
+ two_points = [left_up_point, right_down_point]
+
+ if shape == 'circle':
+ draw.ellipse(two_points, fill=color)
+ elif shape == 'square':
+ draw.rectangle(two_points, fill=color)
+
+ elif shape == 'triangle':
+ # Define the points for the triangle
+ left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
+ top_point = (location_x, location_y - current_height // 2) # top point
+ draw.polygon([top_point, left_up_point, right_down_point], fill=color)
+
+ image = pil2tensor(image)
+ mask = image[:, :, :, 0]
+ out.append(mask)
+ outstack = torch.cat(out, dim=0)
+ return (outstack, 1.0 - outstack,)
+
+class CreateShapeImageOnPath:
+
+ RETURN_TYPES = ("IMAGE", "MASK",)
+ RETURN_NAMES = ("image","mask", )
+ FUNCTION = "createshapemask"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates an image or batch of images with the specified shape.
+Locations are center locations.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "shape": (
+ [ 'circle',
+ 'square',
+ 'triangle',
+ ],
+ {
+ "default": 'circle'
+ }),
+ "coordinates": ("STRING", {"forceInput": True}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "shape_width": ("INT", {"default": 128,"min": 2, "max": 4096, "step": 1}),
+ "shape_height": ("INT", {"default": 128,"min": 2, "max": 4096, "step": 1}),
+ "shape_color": ("STRING", {"default": 'white'}),
+ "bg_color": ("STRING", {"default": 'black'}),
+ "blur_radius": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100, "step": 0.1}),
+ "intensity": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 100.0, "step": 0.01}),
+ },
+ "optional": {
+ "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
+ "trailing": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }
+ }
+
+ def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color,
+ bg_color, blur_radius, shape, intensity, size_multiplier=[1.0], accumulate=False, trailing=1.0):
+        # coordinates can arrive as a single JSON string or as a short list of JSON strings (one per path)
+ if len(coordinates) < 10:
+ coords_list = []
+ for coords in coordinates:
+ coords = json.loads(coords.replace("'", '"'))
+ coords_list.append(coords)
+ else:
+ coords = json.loads(coordinates.replace("'", '"'))
+ coords_list = [coords]
+
+ batch_size = len(coords_list[0])
+ images_list = []
+ masks_list = []
+
+ if not size_multiplier or len(size_multiplier) != batch_size:
+ size_multiplier = [0] * batch_size
+ else:
+ size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
+
+ previous_output = None
+
+ for i in range(batch_size):
+ image = Image.new("RGB", (frame_width, frame_height), bg_color)
+ draw = ImageDraw.Draw(image)
+
+ # Calculate the size for this frame and ensure it's not less than 0
+ current_width = max(0, shape_width + i * size_multiplier[i])
+ current_height = max(0, shape_height + i * size_multiplier[i])
+
+ for coords in coords_list:
+ location_x = coords[i]['x']
+ location_y = coords[i]['y']
+
+ if shape == 'circle' or shape == 'square':
+ # Define the bounding box for the shape
+ left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
+ two_points = [left_up_point, right_down_point]
+
+ if shape == 'circle':
+ draw.ellipse(two_points, fill=shape_color)
+ elif shape == 'square':
+ draw.rectangle(two_points, fill=shape_color)
+
+ elif shape == 'triangle':
+ # Define the points for the triangle
+ left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
+ top_point = (location_x, location_y - current_height // 2) # top point
+ draw.polygon([top_point, left_up_point, right_down_point], fill=shape_color)
+
+ if blur_radius != 0:
+ image = image.filter(ImageFilter.GaussianBlur(blur_radius))
+ # Blend the current image with the accumulated image
+
+ image = pil2tensor(image)
+ if trailing != 1.0 and previous_output is not None:
+ # Add the decayed previous output to the current frame
+ image += trailing * previous_output
+ image = image / image.max()
+ previous_output = image
+ image = image * intensity
+ mask = image[:, :, :, 0]
+ masks_list.append(mask)
+ images_list.append(image)
+ out_images = torch.cat(images_list, dim=0).cpu().float()
+ out_masks = torch.cat(masks_list, dim=0)
+ return (out_images, out_masks)
+
+class CreateTextOnPath:
+
+ RETURN_TYPES = ("IMAGE", "MASK", "MASK",)
+ RETURN_NAMES = ("image", "mask", "mask_inverted",)
+ FUNCTION = "createtextmask"
+ CATEGORY = "KJNodes/masking/generate"
+ DESCRIPTION = """
+Creates a mask or batch of masks with the specified text.
+Locations are center locations.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "coordinates": ("STRING", {"forceInput": True}),
+ "text": ("STRING", {"default": 'text', "multiline": True}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
+ "font_size": ("INT", {"default": 42}),
+ "alignment": (
+ [ 'left',
+ 'center',
+ 'right'
+ ],
+ {"default": 'center'}
+ ),
+ "text_color": ("STRING", {"default": 'white'}),
+ },
+ "optional": {
+ "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
+ }
+ }
+
+ def createtextmask(self, coordinates, frame_width, frame_height, font, font_size, text, text_color, alignment, size_multiplier=[1.0]):
+ coordinates = coordinates.replace("'", '"')
+ coordinates = json.loads(coordinates)
+
+ batch_size = len(coordinates)
+ mask_list = []
+ image_list = []
+ color = text_color
+ font_path = folder_paths.get_full_path("kjnodes_fonts", font)
+
+ if len(size_multiplier) != batch_size:
+ size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
+
+ for i, coord in enumerate(coordinates):
+ image = Image.new("RGB", (frame_width, frame_height), "black")
+ draw = ImageDraw.Draw(image)
+ lines = text.split('\n') # Split the text into lines
+ # Apply the size multiplier to the font size for this iteration
+ current_font_size = int(font_size * size_multiplier[i])
+ current_font = ImageFont.truetype(font_path, current_font_size)
+ line_heights = [current_font.getbbox(line)[3] for line in lines] # List of line heights
+ total_text_height = sum(line_heights) # Total height of text block
+
+ # Calculate the starting Y position to center the block of text
+ start_y = coord['y'] - total_text_height // 2
+ for j, line in enumerate(lines):
+ text_width, text_height = current_font.getbbox(line)[2], line_heights[j]
+ if alignment == 'left':
+ location_x = coord['x']
+ elif alignment == 'center':
+ location_x = int(coord['x'] - text_width // 2)
+ elif alignment == 'right':
+ location_x = int(coord['x'] - text_width)
+
+ location_y = int(start_y + sum(line_heights[:j]))
+ text_position = (location_x, location_y)
+ # Draw the text
+ try:
+ draw.text(text_position, line, fill=color, font=current_font, features=['-liga'])
+ except:
+ draw.text(text_position, line, fill=color, font=current_font)
+
+ image = pil2tensor(image)
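+            # Any pixel that is not pure black (the canvas color) becomes part of the text mask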
+ non_black_pixels = (image > 0).any(dim=-1)
+ mask = non_black_pixels.to(image.dtype)
+ mask_list.append(mask)
+ image_list.append(image)
+
+ out_images = torch.cat(image_list, dim=0).cpu().float()
+ out_masks = torch.cat(mask_list, dim=0)
+ return (out_images, out_masks, 1.0 - out_masks,)
+
+class CreateGradientFromCoords:
+
+ RETURN_TYPES = ("IMAGE", )
+ RETURN_NAMES = ("image", )
+ FUNCTION = "generate"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates a gradient image from coordinates.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "coordinates": ("STRING", {"forceInput": True}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "start_color": ("STRING", {"default": 'white'}),
+ "end_color": ("STRING", {"default": 'black'}),
+ "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 100.0, "step": 0.01}),
+ },
+ }
+
+ def generate(self, coordinates, frame_width, frame_height, start_color, end_color, multiplier):
+ # Parse the coordinates
+ coordinates = json.loads(coordinates.replace("'", '"'))
+
+ # Create an image
+ image = Image.new("RGB", (frame_width, frame_height))
+ draw = ImageDraw.Draw(image)
+
+ # Extract start and end points for the gradient
+ start_coord = coordinates[0]
+ end_coord = coordinates[1]
+
+ start_color = ImageColor.getrgb(start_color)
+ end_color = ImageColor.getrgb(end_color)
+
+ # Calculate the gradient direction (vector)
+ gradient_direction = (end_coord['x'] - start_coord['x'], end_coord['y'] - start_coord['y'])
+ gradient_length = (gradient_direction[0] ** 2 + gradient_direction[1] ** 2) ** 0.5
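+        # For each pixel P the blend factor is the scalar projection of (P - start) onto the
+        # gradient direction D, clamped and normalized:
+        #   blend = clamp((P - S) . D / |D|, 0, |D|) / |D| * multiplier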
+
+ # Iterate over each pixel in the image
+ for y in range(frame_height):
+ for x in range(frame_width):
+ # Calculate the projection of the point on the gradient line
+ point_vector = (x - start_coord['x'], y - start_coord['y'])
+ projection = (point_vector[0] * gradient_direction[0] + point_vector[1] * gradient_direction[1]) / gradient_length
+ projection = max(min(projection, gradient_length), 0) # Clamp the projection value
+
+ # Calculate the blend factor for the current pixel
+ blend = projection * multiplier / gradient_length
+
+ # Determine the color of the current pixel
+ color = (
+ int(start_color[0] + (end_color[0] - start_color[0]) * blend),
+ int(start_color[1] + (end_color[1] - start_color[1]) * blend),
+ int(start_color[2] + (end_color[2] - start_color[2]) * blend)
+ )
+
+ # Set the pixel color
+ draw.point((x, y), fill=color)
+
+ # Convert the PIL image to a tensor (assuming such a function exists in your context)
+ image_tensor = pil2tensor(image)
+
+ return (image_tensor,)
+
+class GradientToFloat:
+
+ RETURN_TYPES = ("FLOAT", "FLOAT",)
+ RETURN_NAMES = ("float_x", "float_y", )
+ FUNCTION = "sample"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Samples the image along its width and height and returns
+the mean pixel values as two lists of floats.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE", ),
+ "steps": ("INT", {"default": 10, "min": 2, "max": 10000, "step": 1}),
+ },
+ }
+
+ def sample(self, image, steps):
+ # Assuming image is a tensor with shape [B, H, W, C]
+ B, H, W, C = image.shape
+
+ # Sample along the width axis (W)
+ w_intervals = torch.linspace(0, W - 1, steps=steps, dtype=torch.int64)
+ # Assuming we're sampling from the first batch and the first channel
+ w_sampled = image[0, :, w_intervals, 0]
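+        # w_sampled has shape [H, steps]: one image column (first channel) per sampled x position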
+
+ # Sample along the height axis (H)
+ h_intervals = torch.linspace(0, H - 1, steps=steps, dtype=torch.int64)
+ # Assuming we're sampling from the first batch and the first channel
+ h_sampled = image[0, h_intervals, :, 0]
+
+ # Taking the mean across the height for width sampling, and across the width for height sampling
+ w_values = w_sampled.mean(dim=0).tolist()
+ h_values = h_sampled.mean(dim=1).tolist()
+
+ return (w_values, h_values)
+
+class MaskOrImageToWeight:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "output_type": (
+ [
+ 'list',
+ 'pandas series',
+ 'tensor',
+ 'string'
+ ],
+ {
+ "default": 'list'
+ }),
+ },
+ "optional": {
+ "images": ("IMAGE",),
+ "masks": ("MASK",),
+ },
+
+ }
+ RETURN_TYPES = ("FLOAT", "STRING",)
+ FUNCTION = "execute"
+ CATEGORY = "KJNodes/weights"
+ DESCRIPTION = """
+Gets the mean values from mask or image batch
+and returns them as the selected output type.
+"""
+
+ def execute(self, output_type, images=None, masks=None):
+ mean_values = []
+ if masks is not None and images is None:
+ for mask in masks:
+ mean_values.append(mask.mean().item())
+ elif masks is None and images is not None:
+ for image in images:
+ mean_values.append(image.mean().item())
+ elif masks is not None and images is not None:
+ raise Exception("MaskOrImageToWeight: Use either mask or image input only.")
+
+        # Convert mean_values to the specified output_type
+        if output_type == 'list':
+            out = mean_values
+        elif output_type == 'pandas series':
+            try:
+                import pandas as pd
+            except ImportError:
+                raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
+            out = pd.Series(mean_values)
+        elif output_type == 'tensor':
+            out = torch.tensor(mean_values, dtype=torch.float32)
+        else:
+            # 'string' output type: return the values as strings (same as the second output)
+            out = [str(value) for value in mean_values]
+        return (out, [str(value) for value in mean_values],)
+
+class WeightScheduleConvert:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input_values": ("FLOAT", {"default": 0.0, "forceInput": True}),
+ "output_type": (
+ [
+ 'match_input',
+ 'list',
+ 'pandas series',
+ 'tensor',
+ ],
+ {
+ "default": 'list'
+ }),
+ "invert": ("BOOLEAN", {"default": False}),
+ "repeat": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
+ },
+ "optional": {
+ "remap_to_frames": ("INT", {"default": 0}),
+ "interpolation_curve": ("FLOAT", {"forceInput": True}),
+ "remap_values": ("BOOLEAN", {"default": False}),
+ "remap_min": ("FLOAT", {"default": 0.0, "min": -100000, "max": 100000.0, "step": 0.01}),
+ "remap_max": ("FLOAT", {"default": 1.0, "min": -100000, "max": 100000.0, "step": 0.01}),
+ },
+
+ }
+ RETURN_TYPES = ("FLOAT", "STRING", "INT",)
+ FUNCTION = "execute"
+ CATEGORY = "KJNodes/weights"
+ DESCRIPTION = """
+Converts different value lists/series to another type.
+"""
+
+ def detect_input_type(self, input_values):
+ import pandas as pd
+ if isinstance(input_values, list):
+ return 'list'
+ elif isinstance(input_values, pd.Series):
+ return 'pandas series'
+ elif isinstance(input_values, torch.Tensor):
+ return 'tensor'
+ else:
+ raise ValueError("Unsupported input type")
+
+ def execute(self, input_values, output_type, invert, repeat, remap_to_frames=0, interpolation_curve=None, remap_min=0.0, remap_max=1.0, remap_values=False):
+ import pandas as pd
+ input_type = self.detect_input_type(input_values)
+
+ if input_type == 'pandas series':
+ float_values = input_values.tolist()
+ elif input_type == 'tensor':
+            float_values = input_values.tolist()
+ else:
+ float_values = input_values
+
+ if invert:
+ float_values = [1 - value for value in float_values]
+
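+        # When an interpolation_curve is given, each curve value v resamples the normalized
+        # input to int(remap_to_frames * v) frames and the chunks are concatenated,
+        # so remap_to_frames must be set above 0 for this branch to produce any output.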
+ if interpolation_curve is not None:
+ interpolated_pattern = []
+ orig_float_values = float_values
+ for value in interpolation_curve:
+ min_val = min(orig_float_values)
+ max_val = max(orig_float_values)
+ # Normalize the values to [0, 1]
+ normalized_values = [(value - min_val) / (max_val - min_val) for value in orig_float_values]
+ # Interpolate the normalized values to the new frame count
+ remapped_float_values = np.interp(np.linspace(0, 1, int(remap_to_frames * value)), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
+ interpolated_pattern.extend(remapped_float_values)
+ float_values = interpolated_pattern
+ else:
+ # Remap float_values to match target_frame_amount
+ if remap_to_frames > 0 and remap_to_frames != len(float_values):
+ min_val = min(float_values)
+ max_val = max(float_values)
+ # Normalize the values to [0, 1]
+ normalized_values = [(value - min_val) / (max_val - min_val) for value in float_values]
+ # Interpolate the normalized values to the new frame count
+ float_values = np.interp(np.linspace(0, 1, remap_to_frames), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
+
+ float_values = float_values * repeat
+ if remap_values:
+ float_values = self.remap_values(float_values, remap_min, remap_max)
+
+ if output_type == 'list':
+ out = float_values,
+ elif output_type == 'pandas series':
+ out = pd.Series(float_values),
+        elif output_type == 'tensor':
+            out = torch.tensor(float_values, dtype=torch.float32),
+ elif output_type == 'match_input':
+ out = float_values,
+ return (out, [str(value) for value in float_values], [int(value) for value in float_values])
+
+ def remap_values(self, values, target_min, target_max):
+ # Determine the current range
+ current_min = min(values)
+ current_max = max(values)
+ current_range = current_max - current_min
+
+ # Determine the target range
+ target_range = target_max - target_min
+
+ # Perform the linear interpolation for each value
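+        # e.g. values [0.0, 0.5, 1.0] remapped to the range [10, 20] become [10.0, 15.0, 20.0]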
+ remapped_values = [(value - current_min) / current_range * target_range + target_min for value in values]
+
+ return remapped_values
+
+
+class FloatToMask:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input_values": ("FLOAT", {"forceInput": True, "default": 0}),
+ "width": ("INT", {"default": 100, "min": 1}),
+ "height": ("INT", {"default": 100, "min": 1}),
+ },
+ }
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "execute"
+ CATEGORY = "KJNodes/masking/generate"
+ DESCRIPTION = """
+Generates a batch of masks based on the input float values.
+The batch size is determined by the length of the input float values.
+Each mask is generated with the specified width and height.
+"""
+
+ def execute(self, input_values, width, height):
+ import pandas as pd
+ # Ensure input_values is a list
+ if isinstance(input_values, (float, int)):
+ input_values = [input_values]
+ elif isinstance(input_values, pd.Series):
+ input_values = input_values.tolist()
+ elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values):
+ input_values = [item for sublist in input_values for item in sublist]
+
+ # Generate a batch of masks based on the input_values
+ masks = []
+ for value in input_values:
+ # Assuming value is a float between 0 and 1 representing the mask's intensity
+ mask = torch.ones((height, width), dtype=torch.float32) * value
+ masks.append(mask)
+ masks_out = torch.stack(masks, dim=0)
+
+ return(masks_out,)
+class WeightScheduleExtend:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input_values_1": ("FLOAT", {"default": 0.0, "forceInput": True}),
+ "input_values_2": ("FLOAT", {"default": 0.0, "forceInput": True}),
+ "output_type": (
+ [
+ 'match_input',
+ 'list',
+ 'pandas series',
+ 'tensor',
+ ],
+ {
+ "default": 'match_input'
+ }),
+ },
+
+ }
+ RETURN_TYPES = ("FLOAT",)
+ FUNCTION = "execute"
+ CATEGORY = "KJNodes/weights"
+ DESCRIPTION = """
+Extends, and converts if needed, different value lists/series
+"""
+
+ def detect_input_type(self, input_values):
+ import pandas as pd
+ if isinstance(input_values, list):
+ return 'list'
+ elif isinstance(input_values, pd.Series):
+ return 'pandas series'
+ elif isinstance(input_values, torch.Tensor):
+ return 'tensor'
+ else:
+ raise ValueError("Unsupported input type")
+
+ def execute(self, input_values_1, input_values_2, output_type):
+ import pandas as pd
+ input_type_1 = self.detect_input_type(input_values_1)
+ input_type_2 = self.detect_input_type(input_values_2)
+ # Convert input_values_2 to the same format as input_values_1 if they do not match
+        if input_type_1 != input_type_2:
+            print("Converting input_values_2 to the same format as input_values_1")
+            if input_type_1 == 'pandas series':
+                # Convert input_values_2 to a pandas Series
+                float_values_2 = pd.Series(input_values_2)
+            elif input_type_1 == 'tensor':
+                # Convert input_values_2 to a tensor
+                float_values_2 = torch.tensor(input_values_2, dtype=torch.float32)
+            else:
+                # input_values_1 is a plain list, so convert input_values_2 to a list as well
+                float_values_2 = input_values_2.tolist()
+        else:
+            print("Input types match, no conversion needed")
+            float_values_2 = input_values_2
+
+        # Extend (concatenate) the two inputs; a plain + would add tensors or series element-wise
+        if input_type_1 == 'tensor':
+            float_values = torch.cat((input_values_1, float_values_2), dim=0)
+        elif input_type_1 == 'pandas series':
+            float_values = pd.concat([input_values_1, float_values_2], ignore_index=True)
+        else:
+            float_values = input_values_1 + float_values_2
+
+ if output_type == 'list':
+ return float_values,
+ elif output_type == 'pandas series':
+ return pd.Series(float_values),
+ elif output_type == 'tensor':
+ if input_type_1 == 'pandas series':
+ return torch.tensor(float_values.values, dtype=torch.float32),
+ else:
+ return torch.tensor(float_values, dtype=torch.float32),
+ elif output_type == 'match_input':
+ return float_values,
+ else:
+ raise ValueError(f"Unsupported output_type: {output_type}")
+
+class FloatToSigmas:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "float_list": ("FLOAT", {"default": 0.0, "forceInput": True}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS",)
+ RETURN_NAMES = ("SIGMAS",)
+ CATEGORY = "KJNodes/noise"
+ FUNCTION = "customsigmas"
+ DESCRIPTION = """
+Creates a sigmas tensor from list of float values.
+
+"""
+ def customsigmas(self, float_list):
+ return torch.tensor(float_list, dtype=torch.float32),
+
+class SigmasToFloat:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "sigmas": ("SIGMAS",),
+ }
+ }
+ RETURN_TYPES = ("FLOAT",)
+ RETURN_NAMES = ("float",)
+ CATEGORY = "KJNodes/noise"
+ FUNCTION = "customsigmas"
+ DESCRIPTION = """
+Creates a float list from sigmas tensors.
+
+"""
+ def customsigmas(self, sigmas):
+ return sigmas.tolist(),
+
+class GLIGENTextBoxApplyBatchCoords:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
+ "latents": ("LATENT", ),
+ "clip": ("CLIP", ),
+ "gligen_textbox_model": ("GLIGEN", ),
+ "coordinates": ("STRING", {"forceInput": True}),
+ "text": ("STRING", {"multiline": True}),
+ "width": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
+ "height": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
+ },
+ "optional": {"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True})},
+ }
+ RETURN_TYPES = ("CONDITIONING", "IMAGE", )
+ RETURN_NAMES = ("conditioning", "coord_preview", )
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+This node allows scheduling GLIGEN text box positions in a batch,
+to be used with AnimateDiff-Evolved. Intended to pair with the
+Spline Editor node.
+
+The GLIGEN model can be downloaded through the ComfyUI Manager's "Install Models" menu.
+Or directly from here:
+https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/tree/main
+
+Inputs:
+- **latents** input is used to calculate batch size
+- **clip** is your standard text encoder, use same as for the main prompt
+- **gligen_textbox_model** connects to GLIGEN Loader
+- **coordinates** takes a json string of points, directly compatible
+with the spline editor node.
+- **text** is the part of the prompt to set position for
+- **width** and **height** are the size of the GLIGEN bounding box
+
+Outputs:
+- **conditioning** goes between the CLIP Text Encode node and the sampler
+- **coord_preview** is an optional preview of the coordinates and
+bounding boxes.
+
+"""
+
+ def append(self, latents, coordinates, conditioning_to, clip, gligen_textbox_model, text, width, height, size_multiplier=[1.0]):
+ coordinates = json.loads(coordinates.replace("'", '"'))
+ coordinates = [(coord['x'], coord['y']) for coord in coordinates]
+
+ batch_size = sum(tensor.size(0) for tensor in latents.values())
+ if len(coordinates) != batch_size:
+ print("GLIGENTextBoxApplyBatchCoords WARNING: The number of coordinates does not match the number of latents")
+
+ c = []
+ _, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
+
+ for t in conditioning_to:
+ n = [t[0], t[1].copy()]
+
+ position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item
+ if len(size_multiplier) != batch_size:
+ size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
+
+ for i in range(batch_size):
+ x_position, y_position = coordinates[i]
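+                # Position params follow ComfyUI's GLIGEN layout: (pooled text embedding, box height,
+                # box width, top-left y, top-left x) in latent units (1/8 pixel), with the box
+                # centered on the coordinate and its size scaled by the per-frame multiplier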
+ position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), (y_position - height // 2) // 8, (x_position - width // 2) // 8)
+ position_params_batch[i].append(position_param) # Append position_param to the correct sublist
+
+ prev = []
+ if "gligen" in n[1]:
+ prev = n[1]['gligen'][2]
+ else:
+ prev = [[] for _ in range(batch_size)]
+ # Concatenate prev and position_params_batch, ensuring both are lists of lists
+ # and each sublist corresponds to a batch item
+ combined_position_params = [prev_item + batch_item for prev_item, batch_item in zip(prev, position_params_batch)]
+ n[1]['gligen'] = ("position_batched", gligen_textbox_model, combined_position_params)
+ c.append(n)
+
+ image_height = latents['samples'].shape[-2] * 8
+ image_width = latents['samples'].shape[-1] * 8
+ plot_image_tensor = plot_coordinates_to_tensor(coordinates, image_height, image_width, height, width, size_multiplier, text)
+
+ return (c, plot_image_tensor,)
+
+class CreateInstanceDiffusionTracking:
+
+ RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
+ RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
+ FUNCTION = "tracking"
+ CATEGORY = "KJNodes/InstanceDiffusion"
+ DESCRIPTION = """
+Creates tracking data to be used with InstanceDiffusion:
+https://github.com/logtd/ComfyUI-InstanceDiffusion
+
+InstanceDiffusion prompt format:
+"class_id.class_name": "prompt",
+for example:
+"1.head": "((head))",
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "coordinates": ("STRING", {"forceInput": True}),
+ "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "bbox_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "bbox_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "class_name": ("STRING", {"default": "class_name"}),
+ "class_id": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "prompt": ("STRING", {"default": "prompt", "multiline": True}),
+ },
+ "optional": {
+ "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
+ "fit_in_frame": ("BOOLEAN", {"default": True}),
+ }
+ }
+
+ def tracking(self, coordinates, class_name, class_id, width, height, bbox_width, bbox_height, prompt, size_multiplier=[1.0], fit_in_frame=True):
+ # Define the number of images in the batch
+ coordinates = coordinates.replace("'", '"')
+ coordinates = json.loads(coordinates)
+
+ tracked = {}
+ tracked[class_name] = {}
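+        # tracked maps class_name -> {class_id: per-frame [x1, y1, x2, y2, frame_width, frame_height]}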
+ batch_size = len(coordinates)
+ # Initialize a list to hold the coordinates for the current ID
+ id_coordinates = []
+        if not size_multiplier:
+            size_multiplier = [1.0] * batch_size
+        elif len(size_multiplier) != batch_size:
+            # Tile the multiplier list so there is one value per coordinate
+            size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
+ for i, coord in enumerate(coordinates):
+ x = coord['x']
+ y = coord['y']
+ adjusted_bbox_width = bbox_width * size_multiplier[i]
+ adjusted_bbox_height = bbox_height * size_multiplier[i]
+ # Calculate the top left and bottom right coordinates
+ top_left_x = x - adjusted_bbox_width // 2
+ top_left_y = y - adjusted_bbox_height // 2
+ bottom_right_x = x + adjusted_bbox_width // 2
+ bottom_right_y = y + adjusted_bbox_height // 2
+
+ if fit_in_frame:
+ # Clip the coordinates to the frame boundaries
+ top_left_x = max(0, top_left_x)
+ top_left_y = max(0, top_left_y)
+ bottom_right_x = min(width, bottom_right_x)
+ bottom_right_y = min(height, bottom_right_y)
+ # Ensure width and height are positive
+ adjusted_bbox_width = max(1, bottom_right_x - top_left_x)
+ adjusted_bbox_height = max(1, bottom_right_y - top_left_y)
+
+ # Update the coordinates with the new width and height
+ bottom_right_x = top_left_x + adjusted_bbox_width
+ bottom_right_y = top_left_y + adjusted_bbox_height
+
+ # Append the top left and bottom right coordinates to the list for the current ID
+ id_coordinates.append([top_left_x, top_left_y, bottom_right_x, bottom_right_y, width, height])
+
+ class_id = int(class_id)
+ # Assign the list of coordinates to the specified ID within the class_id dictionary
+ tracked[class_name][class_id] = id_coordinates
+
+ prompt_string = ""
+ for class_name, class_data in tracked.items():
+ for class_id in class_data.keys():
+ class_id_str = str(class_id)
+ # Use the incoming prompt for each class name and ID
+ prompt_string += f'"{class_id_str}.{class_name}": "({prompt})",\n'
+
+ # Remove the last comma and newline
+ prompt_string = prompt_string.rstrip(",\n")
+
+ return (tracked, prompt_string, width, height, bbox_width, bbox_height)
+
+class AppendInstanceDiffusionTracking:
+
+ RETURN_TYPES = ("TRACKING", "STRING",)
+ RETURN_NAMES = ("tracking", "prompt",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/InstanceDiffusion"
+ DESCRIPTION = """
+Appends tracking data to be used with InstanceDiffusion:
+https://github.com/logtd/ComfyUI-InstanceDiffusion
+
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "tracking_1": ("TRACKING", {"forceInput": True}),
+ "tracking_2": ("TRACKING", {"forceInput": True}),
+ },
+ "optional": {
+ "prompt_1": ("STRING", {"default": "", "forceInput": True}),
+ "prompt_2": ("STRING", {"default": "", "forceInput": True}),
+ }
+ }
+
+ def append(self, tracking_1, tracking_2, prompt_1="", prompt_2=""):
+ tracking_copy = tracking_1.copy()
+ # Check for existing class names and class IDs, and raise an error if they exist
+ for class_name, class_data in tracking_2.items():
+ if class_name not in tracking_copy:
+ tracking_copy[class_name] = class_data
+ else:
+ # If the class name exists, merge the class data from tracking_2 into tracking_copy
+ # This will add new class IDs under the same class name without raising an error
+ tracking_copy[class_name].update(class_data)
+ prompt_string = prompt_1 + "," + prompt_2
+ return (tracking_copy, prompt_string)
+
+class InterpolateCoords:
+
+ RETURN_TYPES = ("STRING",)
+ RETURN_NAMES = ("coordinates",)
+ FUNCTION = "interpolate"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+Interpolates coordinates based on a curve.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "coordinates": ("STRING", {"forceInput": True}),
+ "interpolation_curve": ("FLOAT", {"forceInput": True}),
+
+ },
+ }
+
+ def interpolate(self, coordinates, interpolation_curve):
+ # Parse the JSON string to get the list of coordinates
+ coordinates = json.loads(coordinates.replace("'", '"'))
+
+ # Convert the list of dictionaries to a list of (x, y) tuples for easier processing
+ coordinates = [(coord['x'], coord['y']) for coord in coordinates]
+
+ # Calculate the total length of the original path
+ path_length = sum(np.linalg.norm(np.array(coordinates[i]) - np.array(coordinates[i-1]))
+ for i in range(1, len(coordinates)))
+
+ # Initialize variables for interpolation
+ interpolated_coords = []
+ current_length = 0
+ current_index = 0
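+        # For each normalized target length, walk the polyline segment by segment until the
+        # segment containing that arc length is found, then interpolate linearly within it.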
+
+ # Iterate over the normalized curve
+ for normalized_length in interpolation_curve:
+ target_length = normalized_length * path_length # Convert to the original scale
+ while current_index < len(coordinates) - 1:
+ segment_start, segment_end = np.array(coordinates[current_index]), np.array(coordinates[current_index + 1])
+ segment_length = np.linalg.norm(segment_end - segment_start)
+ if current_length + segment_length >= target_length:
+ break
+ current_length += segment_length
+ current_index += 1
+
+ # Interpolate between the last two points
+ if current_index < len(coordinates) - 1:
+ p1, p2 = np.array(coordinates[current_index]), np.array(coordinates[current_index + 1])
+ segment_length = np.linalg.norm(p2 - p1)
+ if segment_length > 0:
+ t = (target_length - current_length) / segment_length
+ interpolated_point = p1 + t * (p2 - p1)
+ interpolated_coords.append(interpolated_point.tolist())
+ else:
+ interpolated_coords.append(p1.tolist())
+ else:
+ # If the target_length is at or beyond the end of the path, add the last coordinate
+ interpolated_coords.append(coordinates[-1])
+
+ # Convert back to string format if necessary
+ interpolated_coords_str = "[" + ", ".join([f"{{'x': {round(coord[0])}, 'y': {round(coord[1])}}}" for coord in interpolated_coords]) + "]"
+ print(interpolated_coords_str)
+
+ return (interpolated_coords_str,)
+
+class DrawInstanceDiffusionTracking:
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image", )
+ FUNCTION = "draw"
+ CATEGORY = "KJNodes/InstanceDiffusion"
+ DESCRIPTION = """
+Draws the tracking data from
+CreateInstanceDiffusionTracking node.
+
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE", ),
+ "tracking": ("TRACKING", {"forceInput": True}),
+ "box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
+ "draw_text": ("BOOLEAN", {"default": True}),
+ "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
+ "font_size": ("INT", {"default": 20}),
+ },
+ }
+
+ def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
+ import matplotlib.cm as cm
+
+ modified_images = []
+
+ colormap = cm.get_cmap('rainbow', len(tracking))
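+        # One distinct rainbow color per tracked class name (sampled at j / len(tracking) below)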
+ if draw_text:
+ font_path = folder_paths.get_full_path("kjnodes_fonts", font)
+ font = ImageFont.truetype(font_path, font_size)
+
+ # Iterate over each image in the batch
+ for i in range(image.shape[0]):
+ # Extract the current image and convert it to a PIL image
+ current_image = image[i, :, :, :].permute(2, 0, 1)
+ pil_image = transforms.ToPILImage()(current_image)
+
+ draw = ImageDraw.Draw(pil_image)
+
+ # Iterate over the bounding boxes for the current image
+ for j, (class_name, class_data) in enumerate(tracking.items()):
+ for class_id, bbox_list in class_data.items():
+ # Check if the current index is within the bounds of the bbox_list
+ if i < len(bbox_list):
+ bbox = bbox_list[i]
+ # Ensure bbox is a list or tuple before unpacking
+ if isinstance(bbox, (list, tuple)):
+ x1, y1, x2, y2, _, _ = bbox
+ # Convert coordinates to integers
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ # Generate a color from the rainbow colormap
+ color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3]
+ # Draw the bounding box on the image with the generated color
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)
+ if draw_text:
+ # Draw the class name and ID as text above the box with the generated color
+ text = f"{class_id}.{class_name}"
+ # Calculate the width and height of the text
+ _, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font)
+ # Position the text above the top-left corner of the box
+ text_position = (x1, y1 - text_height)
+ draw.text(text_position, text, fill=color, font=font)
+ else:
+ print(f"Unexpected data type for bbox: {type(bbox)}")
+
+ # Convert the drawn image back to a torch tensor and adjust back to (H, W, C)
+ modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0)
+ modified_images.append(modified_image_tensor)
+
+ # Stack the modified images back into a batch
+ image_tensor_batch = torch.stack(modified_images).cpu().float()
+
+ return image_tensor_batch,
+
+class PointsEditor:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "points_store": ("STRING", {"multiline": False}),
+ "coordinates": ("STRING", {"multiline": False}),
+ "neg_coordinates": ("STRING", {"multiline": False}),
+ "bbox_store": ("STRING", {"multiline": False}),
+ "bboxes": ("STRING", {"multiline": False}),
+ "bbox_format": (
+ [
+ 'xyxy',
+ 'xywh',
+ ],
+ ),
+ "width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
+ "normalize": ("BOOLEAN", {"default": False}),
+ },
+ "optional": {
+ "bg_image": ("IMAGE", ),
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "STRING", "BBOX", "MASK", "IMAGE")
+ RETURN_NAMES = ("positive_coords", "negative_coords", "bbox", "bbox_mask", "cropped_image")
+ FUNCTION = "pointdata"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+# WORK IN PROGRESS
+Do not count on this as part of your workflow yet,
+probably contains lots of bugs and stability is not
+guaranteed!!
+
+## Graphical editor to create coordinates
+
+**Shift + click** to add a positive (green) point.
+**Shift + right click** to add a negative (red) point.
+**Ctrl + click** to draw a box.
+**Right click on a point** to delete it.
+Note that the first and last points of the array cannot be deleted.
+
+To add a background image, select the node and copy/paste or drag in the image,
+or connect the bg_image input (the first frame of the batch is used on queue).
+
+**THE IMAGE IS SAVED TO THE NODE AND WORKFLOW METADATA**
+you can clear the image from the context menu by right clicking on the canvas
+
+"""
+
+ def pointdata(self, points_store, bbox_store, width, height, coordinates, neg_coordinates, normalize, bboxes, bbox_format="xyxy", bg_image=None):
+ coordinates = json.loads(coordinates)
+ pos_coordinates = []
+ for coord in coordinates:
+ coord['x'] = int(round(coord['x']))
+ coord['y'] = int(round(coord['y']))
+ if normalize:
+ norm_x = coord['x'] / width
+ norm_y = coord['y'] / height
+ pos_coordinates.append({'x': norm_x, 'y': norm_y})
+ else:
+ pos_coordinates.append({'x': coord['x'], 'y': coord['y']})
+
+ if neg_coordinates:
+ coordinates = json.loads(neg_coordinates)
+ neg_coordinates = []
+ for coord in coordinates:
+ coord['x'] = int(round(coord['x']))
+ coord['y'] = int(round(coord['y']))
+ if normalize:
+ norm_x = coord['x'] / width
+ norm_y = coord['y'] / height
+ neg_coordinates.append({'x': norm_x, 'y': norm_y})
+ else:
+ neg_coordinates.append({'x': coord['x'], 'y': coord['y']})
+
+ # Create a blank mask
+ mask = np.zeros((height, width), dtype=np.uint8)
+ bboxes = json.loads(bboxes)
+ print(bboxes)
+ valid_bboxes = []
+ for bbox in bboxes:
+ if (bbox.get("startX") is None or
+ bbox.get("startY") is None or
+ bbox.get("endX") is None or
+ bbox.get("endY") is None):
+ continue # Skip this bounding box if any value is None
+ else:
+ # Ensure that endX and endY are greater than startX and startY
+ x_min = min(int(bbox["startX"]), int(bbox["endX"]))
+ y_min = min(int(bbox["startY"]), int(bbox["endY"]))
+ x_max = max(int(bbox["startX"]), int(bbox["endX"]))
+ y_max = max(int(bbox["startY"]), int(bbox["endY"]))
+
+ valid_bboxes.append((x_min, y_min, x_max, y_max))
+
+ bboxes_xyxy = []
+ for bbox in valid_bboxes:
+ x_min, y_min, x_max, y_max = bbox
+ bboxes_xyxy.append((x_min, y_min, x_max, y_max))
+ mask[y_min:y_max, x_min:x_max] = 1 # Fill the bounding box area with 1s
+
+ if bbox_format == "xywh":
+ bboxes_xywh = []
+ for bbox in valid_bboxes:
+ x_min, y_min, x_max, y_max = bbox
+ width = x_max - x_min
+ height = y_max - y_min
+ bboxes_xywh.append((x_min, y_min, width, height))
+ bboxes = bboxes_xywh
+ else:
+ bboxes = bboxes_xyxy
+
+ mask_tensor = torch.from_numpy(mask)
+ mask_tensor = mask_tensor.unsqueeze(0).float().cpu()
+
+ if bg_image is not None and len(valid_bboxes) > 0:
+ x_min, y_min, x_max, y_max = bboxes[0]
+ cropped_image = bg_image[:, y_min:y_max, x_min:x_max, :]
+
+ elif bg_image is not None:
+ cropped_image = bg_image
+
+ if bg_image is None:
+ return (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor)
+ else:
+ transform = transforms.ToPILImage()
+ image = transform(bg_image[0].permute(2, 0, 1))
+ buffered = io.BytesIO()
+ image.save(buffered, format="JPEG", quality=75)
+
+ # Step 3: Encode the image bytes to a Base64 string
+ img_bytes = buffered.getvalue()
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
+
+ return {
+ "ui": {"bg_image": [img_base64]},
+ "result": (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor, cropped_image)
+ }
+
+class CutAndDragOnPath:
+ RETURN_TYPES = ("IMAGE", "MASK",)
+ RETURN_NAMES = ("image","mask", )
+ FUNCTION = "cutanddrag"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Cuts the masked area from the image, and drags it along the path. If inpaint is enabled, and no bg_image is provided, the cut area is filled using cv2 TELEA algorithm.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ "coordinates": ("STRING", {"forceInput": True}),
+ "mask": ("MASK",),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "inpaint": ("BOOLEAN", {"default": True}),
+ },
+ "optional": {
+ "bg_image": ("IMAGE",),
+ }
+ }
+
+ def cutanddrag(self, image, coordinates, mask, frame_width, frame_height, inpaint, bg_image=None):
+ # Parse coordinates
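+        # coordinates can arrive either as one JSON string of points or as a list of such
+        # strings (one per path); a short len() is treated as the list case below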
+ if len(coordinates) < 10:
+ coords_list = []
+ for coords in coordinates:
+ coords = json.loads(coords.replace("'", '"'))
+ coords_list.append(coords)
+ else:
+ coords = json.loads(coordinates.replace("'", '"'))
+ coords_list = [coords]
+
+ batch_size = len(coords_list[0])
+ images_list = []
+ masks_list = []
+
+ # Convert input image and mask to PIL
+ input_image = tensor2pil(image)[0]
+ input_mask = tensor2pil(mask)[0]
+
+ # Find masked region bounds
+ mask_array = np.array(input_mask)
+ y_indices, x_indices = np.where(mask_array > 0)
+ if len(x_indices) == 0 or len(y_indices) == 0:
+ return (image, mask)
+
+ x_min, x_max = x_indices.min(), x_indices.max()
+ y_min, y_max = y_indices.min(), y_indices.max()
+
+ # Cut out the masked region
+ cut_width = x_max - x_min
+ cut_height = y_max - y_min
+ cut_image = input_image.crop((x_min, y_min, x_max, y_max))
+ cut_mask = input_mask.crop((x_min, y_min, x_max, y_max))
+
+ # Create inpainted background
+ if bg_image is None:
+ background = input_image.copy()
+ # Inpaint the cut area
+ if inpaint:
+ import cv2
+ border = 5 # Create small border around cut area for better inpainting
+ fill_mask = Image.new("L", background.size, 0)
+ draw = ImageDraw.Draw(fill_mask)
+ draw.rectangle([x_min-border, y_min-border, x_max+border, y_max+border], fill=255)
+ background = cv2.inpaint(
+ np.array(background),
+ np.array(fill_mask),
+ inpaintRadius=3,
+ flags=cv2.INPAINT_TELEA
+ )
+ background = Image.fromarray(background)
+ else:
+ background = tensor2pil(bg_image)[0]
+
+ # Create batch of images with cut region at different positions
+ for i in range(batch_size):
+ # Create new image
+ new_image = background.copy()
+ new_mask = Image.new("L", (frame_width, frame_height), 0)
+
+ # Get target position from coordinates
+ for coords in coords_list:
+ target_x = int(coords[i]['x'] - cut_width/2)
+ target_y = int(coords[i]['y'] - cut_height/2)
+
+ # Paste cut region at new position
+ new_image.paste(cut_image, (target_x, target_y), cut_mask)
+ new_mask.paste(cut_mask, (target_x, target_y))
+
+ # Convert to tensor and append
+ image_tensor = pil2tensor(new_image)
+ mask_tensor = pil2tensor(new_mask)
+
+ images_list.append(image_tensor)
+ masks_list.append(mask_tensor)
+
+ # Stack tensors into batches
+ out_images = torch.cat(images_list, dim=0).cpu().float()
+ out_masks = torch.cat(masks_list, dim=0)
+
+ return (out_images, out_masks)
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/image_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/image_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..74570e1ad071815a3d6cf4a9878e752d7a4196fe
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/image_nodes.py
@@ -0,0 +1,3157 @@
+import numpy as np
+import time
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import io
+import base64
+import random
+import math
+import os
+import re
+import json
+from PIL.PngImagePlugin import PngInfo
+try:
+ import cv2
+except:
+ print("OpenCV not installed")
+from PIL import ImageGrab, ImageDraw, ImageFont, Image, ImageSequence, ImageOps
+
+from nodes import MAX_RESOLUTION, SaveImage
+from comfy_extras.nodes_mask import ImageCompositeMasked
+from comfy.cli_args import args
+from comfy.utils import ProgressBar, common_upscale
+import folder_paths
+import model_management
+
+script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+class ImagePass:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ },
+ "optional": {
+ "image": ("IMAGE",),
+ },
+ }
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "passthrough"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Passes the image through without modifying it.
+"""
+
+ def passthrough(self, image=None):
+ return image,
+
+class ColorMatch:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "image_ref": ("IMAGE",),
+ "image_target": ("IMAGE",),
+ "method": (
+ [
+ 'mkl',
+ 'hm',
+ 'reinhard',
+ 'mvgd',
+ 'hm-mvgd-hm',
+ 'hm-mkl-hm',
+ ], {
+ "default": 'mkl'
+ }),
+ },
+ "optional": {
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }
+ }
+
+ CATEGORY = "KJNodes/image"
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "colormatch"
+ DESCRIPTION = """
+color-matcher enables color transfer across images which comes in handy for automatic
+color-grading of photographs, paintings and film sequences as well as light-field
+and stopmotion corrections.
+
+The methods behind the mappings are based on the approach from Reinhard et al.,
+the Monge-Kantorovich Linearization (MKL) as proposed by Pitie et al. and our analytical solution
+to a Multi-Variate Gaussian Distribution (MVGD) transfer in conjunction with classical histogram
+matching. According to the color-matcher authors, the HM-MVGD-HM compound outperforms the other methods.
+https://github.com/hahnec/color-matcher/
+
+"""
+
+ def colormatch(self, image_ref, image_target, method, strength=1.0):
+ try:
+ from color_matcher import ColorMatcher
+ except:
+ raise Exception("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher")
+ cm = ColorMatcher()
+ image_ref = image_ref.cpu()
+ image_target = image_target.cpu()
+ batch_size = image_target.size(0)
+ out = []
+ images_target = image_target.squeeze()
+ images_ref = image_ref.squeeze()
+
+ image_ref_np = images_ref.numpy()
+ images_target_np = images_target.numpy()
+
+ if image_ref.size(0) > 1 and image_ref.size(0) != batch_size:
+ raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.")
+
+ for i in range(batch_size):
+ image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy()
+ image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy()
+ try:
+ image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method)
+ except BaseException as e:
+ print(f"Error occurred during transfer: {e}")
+ break
+ # Apply the strength multiplier
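+            # result = target + strength * (matched - target): 0 keeps the original, 1 is the full match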
+ image_result = image_target_np + strength * (image_result - image_target_np)
+ out.append(torch.from_numpy(image_result))
+
+ out = torch.stack(out, dim=0).to(torch.float32)
+ out.clamp_(0, 1)
+ return (out,)
+
+class SaveImageWithAlpha:
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"images": ("IMAGE", ),
+ "mask": ("MASK", ),
+ "filename_prefix": ("STRING", {"default": "ComfyUI"})},
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "save_images_alpha"
+ OUTPUT_NODE = True
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Saves an image and mask as .PNG with the mask as the alpha channel.
+"""
+
+ def save_images_alpha(self, images, mask, filename_prefix="ComfyUI_image_with_alpha", prompt=None, extra_pnginfo=None):
+ from PIL.PngImagePlugin import PngInfo
+ filename_prefix += self.prefix_append
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
+ results = list()
+ if mask.dtype == torch.float16:
+ mask = mask.to(torch.float32)
+ def file_counter():
+ max_counter = 0
+ # Loop through the existing files
+ for existing_file in os.listdir(full_output_folder):
+ # Check if the file matches the expected format
+ match = re.fullmatch(fr"{filename}_(\d+)_?\.[a-zA-Z0-9]+", existing_file)
+ if match:
+ # Extract the numeric portion of the filename
+ file_counter = int(match.group(1))
+ # Update the maximum counter value if necessary
+ if file_counter > max_counter:
+ max_counter = file_counter
+ return max_counter
+
+ for image, alpha in zip(images, mask):
+ i = 255. * image.cpu().numpy()
+ a = 255. * alpha.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+
+ # Resize the mask to match the image size
+ a_resized = Image.fromarray(a).resize(img.size, Image.LANCZOS)
+ a_resized = np.clip(a_resized, 0, 255).astype(np.uint8)
+ img.putalpha(Image.fromarray(a_resized, mode='L'))
+ metadata = None
+ if not args.disable_metadata:
+ metadata = PngInfo()
+ if prompt is not None:
+ metadata.add_text("prompt", json.dumps(prompt))
+ if extra_pnginfo is not None:
+ for x in extra_pnginfo:
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+
+ # Increment the counter by 1 to get the next available value
+ counter = file_counter() + 1
+ file = f"{filename}_{counter:05}.png"
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
+ results.append({
+ "filename": file,
+ "subfolder": subfolder,
+ "type": self.type
+ })
+
+ return { "ui": { "images": results } }
+
+class ImageConcanate:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image1": ("IMAGE",),
+ "image2": ("IMAGE",),
+ "direction": (
+ [ 'right',
+ 'down',
+ 'left',
+ 'up',
+ ],
+ {
+ "default": 'right'
+ }),
+ "match_image_size": ("BOOLEAN", {"default": True}),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "concatenate"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Concatenates image2 to image1 in the specified direction.
+"""
+
+ def concatenate(self, image1, image2, direction, match_image_size, first_image_shape=None):
+ # Check if the batch sizes are different
+ batch_size1 = image1.shape[0]
+ batch_size2 = image2.shape[0]
+
+ if batch_size1 != batch_size2:
+ # Calculate the number of repetitions needed
+ max_batch_size = max(batch_size1, batch_size2)
+ repeats1 = max_batch_size - batch_size1
+ repeats2 = max_batch_size - batch_size2
+
+ # Repeat the last image to match the largest batch size
+ if repeats1 > 0:
+ last_image1 = image1[-1].unsqueeze(0).repeat(repeats1, 1, 1, 1)
+ image1 = torch.cat([image1.clone(), last_image1], dim=0)
+ if repeats2 > 0:
+ last_image2 = image2[-1].unsqueeze(0).repeat(repeats2, 1, 1, 1)
+ image2 = torch.cat([image2.clone(), last_image2], dim=0)
+
+ if match_image_size:
+ # Use first_image_shape if provided; otherwise, default to image1's shape
+ target_shape = first_image_shape if first_image_shape is not None else image1.shape
+
+ original_height = image2.shape[1]
+ original_width = image2.shape[2]
+ original_aspect_ratio = original_width / original_height
+
+ if direction in ['left', 'right']:
+ # Match the height and adjust the width to preserve aspect ratio
+ target_height = target_shape[1] # B, H, W, C format
+ target_width = int(target_height * original_aspect_ratio)
+ elif direction in ['up', 'down']:
+ # Match the width and adjust the height to preserve aspect ratio
+ target_width = target_shape[2] # B, H, W, C format
+ target_height = int(target_width / original_aspect_ratio)
+
+ # Adjust image2 to the expected format for common_upscale
+ image2_for_upscale = image2.movedim(-1, 1) # Move C to the second position (B, C, H, W)
+
+ # Resize image2 to match the target size while preserving aspect ratio
+ image2_resized = common_upscale(image2_for_upscale, target_width, target_height, "lanczos", "disabled")
+
+ # Adjust image2 back to the original format (B, H, W, C) after resizing
+ image2_resized = image2_resized.movedim(1, -1)
+ else:
+ image2_resized = image2
+
+ # Ensure both images have the same number of channels
+ channels_image1 = image1.shape[-1]
+ channels_image2 = image2_resized.shape[-1]
+
+ if channels_image1 != channels_image2:
+ if channels_image1 < channels_image2:
+ # Add alpha channel to image1 if image2 has it
+ alpha_channel = torch.ones((*image1.shape[:-1], channels_image2 - channels_image1), device=image1.device)
+ image1 = torch.cat((image1, alpha_channel), dim=-1)
+ else:
+ # Add alpha channel to image2 if image1 has it
+ alpha_channel = torch.ones((*image2_resized.shape[:-1], channels_image1 - channels_image2), device=image2_resized.device)
+ image2_resized = torch.cat((image2_resized, alpha_channel), dim=-1)
+
+
+ # Concatenate based on the specified direction
+ if direction == 'right':
+ concatenated_image = torch.cat((image1, image2_resized), dim=2) # Concatenate along width
+ elif direction == 'down':
+ concatenated_image = torch.cat((image1, image2_resized), dim=1) # Concatenate along height
+ elif direction == 'left':
+ concatenated_image = torch.cat((image2_resized, image1), dim=2) # Concatenate along width
+ elif direction == 'up':
+ concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height
+ return concatenated_image,
+
+
+class ImageConcatFromBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "images": ("IMAGE",),
+ "num_columns": ("INT", {"default": 3, "min": 1, "max": 255, "step": 1}),
+ "match_image_size": ("BOOLEAN", {"default": False}),
+ "max_resolution": ("INT", {"default": 4096}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "concat"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+ Concatenates images from a batch into a grid with a specified number of columns.
+ """
+
+ def concat(self, images, num_columns, match_image_size, max_resolution):
+ # Assuming images is a batch of images (B, H, W, C)
+ batch_size, height, width, channels = images.shape
+ num_rows = (batch_size + num_columns - 1) // num_columns # Calculate number of rows
+
+ print(f"Initial dimensions: batch_size={batch_size}, height={height}, width={width}, channels={channels}")
+ print(f"num_rows={num_rows}, num_columns={num_columns}")
+
+ if match_image_size:
+ target_shape = images[0].shape
+
+ resized_images = []
+ for image in images:
+ original_height = image.shape[0]
+ original_width = image.shape[1]
+ original_aspect_ratio = original_width / original_height
+
+ if original_aspect_ratio > 1:
+ target_height = target_shape[0]
+ target_width = int(target_height * original_aspect_ratio)
+ else:
+ target_width = target_shape[1]
+ target_height = int(target_width / original_aspect_ratio)
+
+ print(f"Resizing image from ({original_height}, {original_width}) to ({target_height}, {target_width})")
+
+ # Resize the image to match the target size while preserving aspect ratio
+ resized_image = common_upscale(image.movedim(-1, 0), target_width, target_height, "lanczos", "disabled")
+ resized_image = resized_image.movedim(0, -1) # Move channels back to the last dimension
+ resized_images.append(resized_image)
+
+ # Convert the list of resized images back to a tensor
+ images = torch.stack(resized_images)
+
+ height, width = target_shape[:2] # Update height and width
+
+ # Initialize an empty grid
+ grid_height = num_rows * height
+ grid_width = num_columns * width
+
+ print(f"Grid dimensions before scaling: grid_height={grid_height}, grid_width={grid_width}")
+
+ # Original scale factor calculation remains unchanged
+ scale_factor = min(max_resolution / grid_height, max_resolution / grid_width, 1.0)
+
+ # Apply scale factor to height and width
+ scaled_height = height * scale_factor
+ scaled_width = width * scale_factor
+
+ # Round scaled dimensions to the nearest number divisible by 8
+ height = max(1, int(round(scaled_height / 8) * 8))
+ width = max(1, int(round(scaled_width / 8) * 8))
+
+ if abs(scaled_height - height) > 4:
+ height = max(1, int(round((scaled_height + 4) / 8) * 8))
+ if abs(scaled_width - width) > 4:
+ width = max(1, int(round((scaled_width + 4) / 8) * 8))
+
+ # Recalculate grid dimensions with adjusted height and width
+ grid_height = num_rows * height
+ grid_width = num_columns * width
+ print(f"Grid dimensions after scaling: grid_height={grid_height}, grid_width={grid_width}")
+ print(f"Final image dimensions: height={height}, width={width}")
+
+ grid = torch.zeros((grid_height, grid_width, channels), dtype=images.dtype)
+
+ for idx, image in enumerate(images):
+ resized_image = torch.nn.functional.interpolate(image.unsqueeze(0).permute(0, 3, 1, 2), size=(height, width), mode="bilinear").squeeze().permute(1, 2, 0)
+ row = idx // num_columns
+ col = idx % num_columns
+ grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image
+
+ return grid.unsqueeze(0),
+
+
+class ImageGridComposite2x2:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image1": ("IMAGE",),
+ "image2": ("IMAGE",),
+ "image3": ("IMAGE",),
+ "image4": ("IMAGE",),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "compositegrid"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Concatenates the 4 input images into a 2x2 grid.
+"""
+
+ def compositegrid(self, image1, image2, image3, image4):
+ top_row = torch.cat((image1, image2), dim=2)
+ bottom_row = torch.cat((image3, image4), dim=2)
+ grid = torch.cat((top_row, bottom_row), dim=1)
+ return (grid,)
+
+class ImageGridComposite3x3:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image1": ("IMAGE",),
+ "image2": ("IMAGE",),
+ "image3": ("IMAGE",),
+ "image4": ("IMAGE",),
+ "image5": ("IMAGE",),
+ "image6": ("IMAGE",),
+ "image7": ("IMAGE",),
+ "image8": ("IMAGE",),
+ "image9": ("IMAGE",),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "compositegrid"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Concatenates the 9 input images into a 3x3 grid.
+"""
+
+ def compositegrid(self, image1, image2, image3, image4, image5, image6, image7, image8, image9):
+ top_row = torch.cat((image1, image2, image3), dim=2)
+ mid_row = torch.cat((image4, image5, image6), dim=2)
+ bottom_row = torch.cat((image7, image8, image9), dim=2)
+ grid = torch.cat((top_row, mid_row, bottom_row), dim=1)
+ return (grid,)
+
+class ImageBatchTestPattern:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "batch_size": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
+ "start_from": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "text_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "text_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
+ "font_size": ("INT", {"default": 255,"min": 8, "max": 4096, "step": 1}),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "generatetestpattern"
+ CATEGORY = "KJNodes/text"
+
+ def generatetestpattern(self, batch_size, font, font_size, start_from, width, height, text_x, text_y):
+ out = []
+ # Generate the sequential numbers for each image
+ numbers = np.arange(start_from, start_from + batch_size)
+ font_path = folder_paths.get_full_path("kjnodes_fonts", font)
+
+ for number in numbers:
+ # Create a black image with the number as a random color text
+ image = Image.new("RGB", (width, height), color='black')
+ draw = ImageDraw.Draw(image)
+
+ # Generate a random color for the text
+ font_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+ font = ImageFont.truetype(font_path, font_size)
+
+ # Get the size of the text and position it in the center
+ text = str(number)
+
+ try:
+ draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga'])
+ except:
+ draw.text((text_x, text_y), text, font=font, fill=font_color,)
+
+ # Convert the image to a numpy array and normalize the pixel values
+ image_np = np.array(image).astype(np.float32) / 255.0
+ image_tensor = torch.from_numpy(image_np).unsqueeze(0)
+ out.append(image_tensor)
+ out_tensor = torch.cat(out, dim=0)
+
+ return (out_tensor,)
+
+class ImageGrabPIL:
+
+ @classmethod
+ def IS_CHANGED(cls):
+
+ return
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "screencap"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Captures an area specified by screen coordinates.
+Can be used for realtime diffusion with autoqueue.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "width": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
+ "num_frames": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
+ "delay": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.01}),
+ },
+ }
+
+ def screencap(self, x, y, width, height, num_frames, delay):
+ start_time = time.time()
+ captures = []
+ bbox = (x, y, x + width, y + height)
+
+ for _ in range(num_frames):
+ # Capture screen
+ screen_capture = ImageGrab.grab(bbox=bbox)
+ screen_capture_torch = torch.from_numpy(np.array(screen_capture, dtype=np.float32) / 255.0).unsqueeze(0)
+ captures.append(screen_capture_torch)
+
+ # Wait for a short delay if more than one frame is to be captured
+ if num_frames > 1:
+ time.sleep(delay)
+
+ elapsed_time = time.time() - start_time
+ print(f"screengrab took {elapsed_time} seconds.")
+
+ return (torch.cat(captures, dim=0),)
+
+class Screencap_mss:
+
+ @classmethod
+ def IS_CHANGED(s, **kwargs):
+ return float("NaN")
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "screencap"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Captures an area specified by screen coordinates.
+Can be used for realtime diffusion with autoqueue.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "x": ("INT", {"default": 0,"min": 0, "max": 10000, "step": 1}),
+ "y": ("INT", {"default": 0,"min": 0, "max": 10000, "step": 1}),
+ "width": ("INT", {"default": 512,"min": 0, "max": 10000, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 0, "max": 10000, "step": 1}),
+ "num_frames": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
+ "delay": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.01}),
+ },
+ }
+
+ def screencap(self, x, y, width, height, num_frames, delay):
+ from mss import mss
+ captures = []
+ with mss() as sct:
+ bbox = {'top': y, 'left': x, 'width': width, 'height': height}
+
+ for _ in range(num_frames):
+ sct_img = sct.grab(bbox)
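+                # mss returns BGRA pixels; reorder the first three channels to RGB and drop alpha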
+ img_np = np.array(sct_img)
+ img_torch = torch.from_numpy(img_np[..., [2, 1, 0]]).float() / 255.0
+ captures.append(img_torch)
+
+ if num_frames > 1:
+ time.sleep(delay)
+
+ return (torch.stack(captures, 0),)
+
+class WebcamCaptureCV2:
+
+ @classmethod
+ def IS_CHANGED(cls):
+ return
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "capture"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+Captures a frame from a webcam using CV2.
+Can be used for realtime diffusion with autoqueue.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "width": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
+ "cam_index": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "release": ("BOOLEAN", {"default": False}),
+ },
+ }
+
+ def capture(self, x, y, cam_index, width, height, release):
+ # Check if the camera index has changed or the capture object doesn't exist
+ if not hasattr(self, "cap") or self.cap is None or self.current_cam_index != cam_index:
+ if hasattr(self, "cap") and self.cap is not None:
+ self.cap.release()
+ self.current_cam_index = cam_index
+ self.cap = cv2.VideoCapture(cam_index)
+ try:
+ self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+ self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+ except Exception:
+ pass
+ if not self.cap.isOpened():
+ raise Exception("Could not open webcam")
+
+ ret, frame = self.cap.read()
+ if not ret:
+ raise Exception("Failed to capture image from webcam")
+
+ # Crop the frame to the specified bbox
+ frame = frame[y:y+height, x:x+width]
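+ # OpenCV frames are BGR; reorder to RGB and scale to [0, 1]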
+ img_torch = torch.from_numpy(frame[..., [2, 1, 0]]).float() / 255.0
+
+ if release:
+ self.cap.release()
+ self.cap = None
+
+ return (img_torch.unsqueeze(0),)
+
+class AddLabel:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image":("IMAGE",),
+ "text_x": ("INT", {"default": 10, "min": 0, "max": 4096, "step": 1}),
+ "text_y": ("INT", {"default": 2, "min": 0, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 48, "min": -1, "max": 4096, "step": 1}),
+ "font_size": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}),
+ "font_color": ("STRING", {"default": "white"}),
+ "label_color": ("STRING", {"default": "black"}),
+ "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
+ "text": ("STRING", {"default": "Text"}),
+ "direction": (
+ [ 'up',
+ 'down',
+ 'left',
+ 'right',
+ 'overlay'
+ ],
+ {
+ "default": 'up'
+ }),
+ },
+ "optional":{
+ "caption": ("STRING", {"default": "", "forceInput": True}),
+ }
+ }
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "addlabel"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+Creates a new label image with the given text and concatenates it
+above, below, beside, or on top of the input image
+depending on the chosen direction.
+Note that this changes the output image's dimensions!
+Fonts are loaded from this folder:
+ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
+"""
+
+ def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction, caption=""):
+ batch_size = image.shape[0]
+ width = image.shape[2]
+
+ font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf") if font == "TTNorms-Black.otf" else folder_paths.get_full_path("kjnodes_fonts", font)
+
+ def process_image(input_image, caption_text):
+ font = ImageFont.truetype(font_path, font_size)
+ words = caption_text.split()
+ lines = []
+ current_line = []
+ current_line_width = 0
+
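+ # Greedy word wrap: add words to the current line until it would exceed the available width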
+ for word in words:
+ word_width = font.getbbox(word)[2]
+ if current_line_width + word_width <= width - 2 * text_x:
+ current_line.append(word)
+ current_line_width += word_width + font.getbbox(" ")[2] # Add space width
+ else:
+ lines.append(" ".join(current_line))
+ current_line = [word]
+ current_line_width = word_width
+
+ if current_line:
+ lines.append(" ".join(current_line))
+
+ if direction == 'overlay':
+ pil_image = Image.fromarray((input_image.cpu().numpy() * 255).astype(np.uint8))
+ else:
+ if height == -1:
+ # Adjust the image height automatically
+ margin = 8
+ required_height = (text_y + len(lines) * font_size) + margin # Calculate required height
+ pil_image = Image.new("RGB", (width, required_height), label_color)
+ else:
+ # Use the specified label height
+ label_image = Image.new("RGB", (width, height), label_color)
+ pil_image = label_image
+
+ draw = ImageDraw.Draw(pil_image)
+
+
+ y_offset = text_y
+ for line in lines:
+ try:
+ draw.text((text_x, y_offset), line, font=font, fill=font_color, features=['-liga'])
+ except Exception:
+ draw.text((text_x, y_offset), line, font=font, fill=font_color)
+ y_offset += font_size
+
+ processed_image = torch.from_numpy(np.array(pil_image).astype(np.float32) / 255.0).unsqueeze(0)
+ return processed_image
+
+ if caption == "":
+ processed_images = [process_image(img, text) for img in image]
+ else:
+ assert len(caption) == batch_size, f"Number of captions {len(caption)} does not match number of images {batch_size}"
+ processed_images = [process_image(img, cap) for img, cap in zip(image, caption)]
+ processed_batch = torch.cat(processed_images, dim=0)
+
+ # Combine images based on direction
+ if direction == 'down':
+ combined_images = torch.cat((image, processed_batch), dim=1)
+ elif direction == 'up':
+ combined_images = torch.cat((processed_batch, image), dim=1)
+ elif direction == 'left':
+ processed_batch = torch.rot90(processed_batch, 3, (2, 3)).permute(0, 3, 1, 2)
+ combined_images = torch.cat((processed_batch, image), dim=2)
+ elif direction == 'right':
+ processed_batch = torch.rot90(processed_batch, 3, (2, 3)).permute(0, 3, 1, 2)
+ combined_images = torch.cat((image, processed_batch), dim=2)
+ else:
+ combined_images = processed_batch
+
+ return (combined_images,)
+
+class GetImageSizeAndCount:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE",),
+ }}
+
+ RETURN_TYPES = ("IMAGE","INT", "INT", "INT",)
+ RETURN_NAMES = ("image", "width", "height", "count",)
+ FUNCTION = "getsize"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Returns width, height and batch size of the image,
+and passes it through unchanged.
+
+"""
+
+ def getsize(self, image):
+ width = image.shape[2]
+ height = image.shape[1]
+ count = image.shape[0]
+ return {"ui": {
+ "text": [f"{count}x{width}x{height}"]},
+ "result": (image, width, height, count)
+ }
+
+class ImageBatchRepeatInterleaving:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "repeat"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Repeats each image in a batch by the specified number of times.
+Example batch of 5 images: 0, 1 ,2, 3, 4
+with repeats 2 becomes batch of 10 images: 0, 0, 1, 1, 2, 2, 3, 3, 4, 4
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "repeats": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ },
+ }
+
+ def repeat(self, images, repeats):
+
+ repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0)
+ return (repeated_images, )
+
+class ImageUpscaleWithModelBatched:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "upscale_model": ("UPSCALE_MODEL",),
+ "images": ("IMAGE",),
+ "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "upscale"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Same as ComfyUI native model upscaling node,
+but allows setting sub-batches for reduced VRAM usage.
+"""
+ def upscale(self, upscale_model, images, per_batch):
+
+ device = model_management.get_torch_device()
+ upscale_model.to(device)
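+ # Move channels first (BHWC -> BCHW) as expected by the upscale model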
+ in_img = images.movedim(-1,-3)
+
+ steps = in_img.shape[0]
+ pbar = ProgressBar(steps)
+ t = []
+
+ for start_idx in range(0, in_img.shape[0], per_batch):
+ sub_images = upscale_model(in_img[start_idx:start_idx+per_batch].to(device))
+ t.append(sub_images.cpu())
+ # Calculate the number of images processed in this batch
+ batch_count = sub_images.shape[0]
+ # Update the progress bar by the number of images processed in this batch
+ pbar.update(batch_count)
+ upscale_model.cpu()
+
+ t = torch.cat(t, dim=0).permute(0, 2, 3, 1).cpu()
+
+ return (t,)
+
+class ImageNormalize_Neg1_To_1:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "images": ("IMAGE",),
+
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "normalize"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Normalize the images to be in the range [-1, 1]
+"""
+
+ def normalize(self,images):
+ images = images * 2.0 - 1.0
+ return (images,)
+
+class RemapImageRange:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE",),
+ "min": ("FLOAT", {"default": 0.0,"min": -10.0, "max": 1.0, "step": 0.01}),
+ "max": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 10.0, "step": 0.01}),
+ "clamp": ("BOOLEAN", {"default": True}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "remap"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Remaps the image values to the specified range.
+"""
+
+ def remap(self, image, min, max, clamp):
+ if image.dtype == torch.float16:
+ image = image.to(torch.float32)
+ image = min + image * (max - min)
+ if clamp:
+ image = torch.clamp(image, min=0.0, max=1.0)
+ return (image, )
+
+class SplitImageChannels:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE",),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK")
+ RETURN_NAMES = ("red", "green", "blue", "mask")
+ FUNCTION = "split"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Splits the image into per-channel images where each channel
+is repeated across all three channels, and returns the alpha channel as a mask.
+"""
+
+ def split(self, image):
+ red = image[:, :, :, 0:1] # Red channel
+ green = image[:, :, :, 1:2] # Green channel
+ blue = image[:, :, :, 2:3] # Blue channel
+ alpha = image[:, :, :, 3:4] # Alpha channel
+ alpha = alpha.squeeze(-1)
+
+ # Repeat the selected channel for all channels
+ red = torch.cat([red, red, red], dim=3)
+ green = torch.cat([green, green, green], dim=3)
+ blue = torch.cat([blue, blue, blue], dim=3)
+ return (red, green, blue, alpha)
+
+class MergeImageChannels:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "red": ("IMAGE",),
+ "green": ("IMAGE",),
+ "blue": ("IMAGE",),
+
+ },
+ "optional": {
+ "alpha": ("MASK", {"default": None}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "merge"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Merges channel data into an image.
+"""
+
+ def merge(self, red, green, blue, alpha=None):
+ image = torch.stack([
+ red[..., 0, None], # Red channel
+ green[..., 1, None], # Green channel
+ blue[..., 2, None] # Blue channel
+ ], dim=-1)
+ image = image.squeeze(-2)
+ if alpha is not None:
+ image = torch.cat([image, alpha.unsqueeze(-1)], dim=-1)
+ return (image,)
+
+class ImagePadForOutpaintMasked:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "feathering": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+ },
+ "optional": {
+ "mask": ("MASK",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ FUNCTION = "expand_image"
+
+ CATEGORY = "image"
+
+ def expand_image(self, image, left, top, right, bottom, feathering, mask=None):
+ if mask is not None:
+ if torch.allclose(mask, torch.zeros_like(mask)):
+ print("Warning: The incoming mask is fully black. Handling it as None.")
+ mask = None
+ B, H, W, C = image.size()
+
+ new_image = torch.ones(
+ (B, H + top + bottom, W + left + right, C),
+ dtype=torch.float32,
+ ) * 0.5
+
+ new_image[:, top:top + H, left:left + W, :] = image
+
+ if mask is None:
+ new_mask = torch.ones(
+ (B, H + top + bottom, W + left + right),
+ dtype=torch.float32,
+ )
+
+ t = torch.zeros(
+ (B, H, W),
+ dtype=torch.float32
+ )
+ else:
+ # If a mask is provided, pad it to fit the new image size
+ mask = F.pad(mask, (left, right, top, bottom), mode='constant', value=0)
+ mask = 1 - mask
+ t = torch.zeros_like(mask)
+
+ if feathering > 0 and feathering * 2 < H and feathering * 2 < W:
+
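+ # Feathering: within 'feathering' pixels of any padded edge, ramp the mask value up quadratically towards that edge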
+ for i in range(H):
+ for j in range(W):
+ dt = i if top != 0 else H
+ db = H - i if bottom != 0 else H
+
+ dl = j if left != 0 else W
+ dr = W - j if right != 0 else W
+
+ d = min(dt, db, dl, dr)
+
+ if d >= feathering:
+ continue
+
+ v = (feathering - d) / feathering
+
+ if mask is None:
+ t[:, i, j] = v * v
+ else:
+ t[:, top + i, left + j] = v * v
+
+ if mask is None:
+ new_mask[:, top:top + H, left:left + W] = t
+ return (new_image, new_mask,)
+ else:
+ return (new_image, mask,)
+
+class ImagePadForOutpaintTargetSize:
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ "target_width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "target_height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+ "feathering": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+ "upscale_method": (s.upscale_methods,),
+ },
+ "optional": {
+ "mask": ("MASK",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ FUNCTION = "expand_image"
+
+ CATEGORY = "image"
+
+ def expand_image(self, image, target_width, target_height, feathering, upscale_method, mask=None):
+ B, H, W, C = image.size()
+ new_height = H
+ new_width = W
+ # Calculate the scaling factor while maintaining aspect ratio
+ scaling_factor = min(target_width / W, target_height / H)
+
+ # Check if the image needs to be downscaled
+ if scaling_factor < 1:
+ image = image.movedim(-1,1)
+ # Calculate the new width and height after downscaling
+ new_width = int(W * scaling_factor)
+ new_height = int(H * scaling_factor)
+
+ # Downscale the image
+ image_scaled = common_upscale(image, new_width, new_height, upscale_method, "disabled").movedim(1,-1)
+ if mask is not None:
+ mask_scaled = mask.unsqueeze(0) # Add an extra dimension for batch size
+ mask_scaled = F.interpolate(mask_scaled, size=(new_height, new_width), mode="nearest")
+ mask_scaled = mask_scaled.squeeze(0) # Remove the extra dimension after interpolation
+ else:
+ mask_scaled = mask
+ else:
+ # If downscaling is not needed, use the original image dimensions
+ image_scaled = image
+ mask_scaled = mask
+
+ # Calculate how much padding is needed to reach the target dimensions
+ pad_top = max(0, (target_height - new_height) // 2)
+ pad_bottom = max(0, target_height - new_height - pad_top)
+ pad_left = max(0, (target_width - new_width) // 2)
+ pad_right = max(0, target_width - new_width - pad_left)
+
+ # Now call the original expand_image with the calculated padding
+ return ImagePadForOutpaintMasked.expand_image(self, image_scaled, pad_left, pad_top, pad_right, pad_bottom, feathering, mask_scaled)
+
+class ImagePrepForICLora:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "reference_image": ("IMAGE",),
+ "output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
+ "output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
+ "border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
+ },
+ "optional": {
+ "latent_image": ("IMAGE",),
+ "latent_mask": ("MASK",),
+ "reference_mask": ("MASK",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ FUNCTION = "expand_image"
+
+ CATEGORY = "image"
+
+ def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None, latent_mask=None):
+
+ if reference_mask is not None:
+ if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
+ print("Warning: The incoming mask is fully black. Handling it as None.")
+ reference_mask = None
+ image = reference_image
+ B, H, W, C = image.size()
+
+ # Handle mask
+ if reference_mask is not None:
+ resized_mask = torch.nn.functional.interpolate(
+ reference_mask.unsqueeze(1),
+ size=(H, W),
+ mode='nearest'
+ ).squeeze(1)
+ print(resized_mask.shape)
+ image = image * resized_mask.unsqueeze(-1)
+
+ # Calculate new width maintaining aspect ratio
+ new_width = int((W / H) * output_height)
+
+ # Resize image to new height while maintaining aspect ratio
+ resized_image = common_upscale(image.movedim(-1,1), new_width, output_height, "lanczos", "disabled").movedim(1,-1)
+
+ # Create padded image
+ if latent_image is None:
+ pad_image = torch.zeros((B, output_height, output_width, C), device=image.device)
+ else:
+ resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
+ pad_image = resized_latent_image
+ if latent_mask is not None:
+ resized_latent_mask = torch.nn.functional.interpolate(
+ latent_mask.unsqueeze(1),
+ size=(pad_image.shape[1], pad_image.shape[2]),
+ mode='nearest'
+ ).squeeze(1)
+
+ if border_width > 0:
+ border = torch.zeros((B, output_height, border_width, C), device=image.device)
+ padded_image = torch.cat((resized_image, border, pad_image), dim=2)
+ if latent_mask is not None:
+ padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+ padded_mask[:, :, (new_width + border_width):] = resized_latent_mask
+ else:
+ padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+ padded_mask[:, :, :new_width + border_width] = 0
+ else:
+ padded_image = torch.cat((resized_image, pad_image), dim=2)
+ if latent_mask is not None:
+ padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+ padded_mask[:, :, new_width:] = resized_latent_mask
+ else:
+ padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+ padded_mask[:, :, :new_width] = 0
+
+ return (padded_image, padded_mask)
+
+
+class ImageAndMaskPreview(SaveImage):
+ def __init__(self):
+ self.output_dir = folder_paths.get_temp_directory()
+ self.type = "temp"
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+ self.compress_level = 4
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "mask_opacity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+ "mask_color": ("STRING", {"default": "255, 255, 255"}),
+ "pass_through": ("BOOLEAN", {"default": False}),
+ },
+ "optional": {
+ "image": ("IMAGE",),
+ "mask": ("MASK",),
+ },
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+ }
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("composite",)
+ FUNCTION = "execute"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Previews an image or a mask; when both inputs are used,
+the mask is composited on top of the image.
+With pass_through enabled the preview is disabled and the
+composite is returned from the composite output instead;
+this allows the result to be passed on to video combine
+nodes, for example.
+"""
+
+ def execute(self, mask_opacity, mask_color, pass_through, filename_prefix="ComfyUI", image=None, mask=None, prompt=None, extra_pnginfo=None):
+ if mask is not None and image is None:
+ preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
+ elif mask is None and image is not None:
+ preview = image
+ elif mask is not None and image is not None:
+ mask_adjusted = mask * mask_opacity
+ mask_image = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3).clone()
+
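+ # mask_color can be given as a comma-separated RGB string or a hex value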
+ if ',' in mask_color:
+ color_list = np.clip([int(channel) for channel in mask_color.split(',')], 0, 255) # RGB format
+ else:
+ mask_color = mask_color.lstrip('#')
+ color_list = [int(mask_color[i:i+2], 16) for i in (0, 2, 4)] # Hex format
+ mask_image[:, :, :, 0] = color_list[0] / 255 # Red channel
+ mask_image[:, :, :, 1] = color_list[1] / 255 # Green channel
+ mask_image[:, :, :, 2] = color_list[2] / 255 # Blue channel
+
+ preview, = ImageCompositeMasked.composite(self, image, mask_image, 0, 0, True, mask_adjusted)
+ if pass_through:
+ return (preview, )
+ return(self.save_images(preview, filename_prefix, prompt, extra_pnginfo))
+
+class CrossFadeImages:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "crossfadeimages"
+ CATEGORY = "KJNodes/image"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images_1": ("IMAGE",),
+ "images_2": ("IMAGE",),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
+ "transition_start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ "start_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
+ "end_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
+ },
+ }
+
+ def crossfadeimages(self, images_1, images_2, transition_start_index, transitioning_frames, interpolation, start_level, end_level):
+
+ def crossfade(images_1, images_2, alpha):
+ crossfade = (1 - alpha) * images_1 + alpha * images_2
+ return crossfade
+ def ease_in(t):
+ return t * t
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+ def bounce(t):
+ if t < 0.5:
+ return ease_out(t * 2) * 0.5
+ else:
+ return ease_in((t - 0.5) * 2) * 0.5 + 0.5
+ def elastic(t):
+ return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1))
+ def glitchy(t):
+ return t + 0.1 * math.sin(40 * t)
+ def exponential_ease_out(t):
+ return 1 - (1 - t) ** 4
+
+ easing_functions = {
+ "linear": lambda t: t,
+ "ease_in": ease_in,
+ "ease_out": ease_out,
+ "ease_in_out": ease_in_out,
+ "bounce": bounce,
+ "elastic": elastic,
+ "glitchy": glitchy,
+ "exponential_ease_out": exponential_ease_out,
+ }
+
+ crossfade_images = []
+
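+ # Precompute the per-frame blend factors between start_level and end_level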
+ alphas = torch.linspace(start_level, end_level, transitioning_frames)
+ for i in range(transitioning_frames):
+ alpha = alphas[i]
+ image1 = images_1[i + transition_start_index]
+ image2 = images_2[i + transition_start_index]
+ easing_function = easing_functions.get(interpolation)
+ alpha = easing_function(alpha) # Apply the easing function to the alpha value
+
+ crossfade_image = crossfade(image1, image2, alpha)
+ crossfade_images.append(crossfade_image)
+
+ # Convert crossfade_images to tensor
+ crossfade_images = torch.stack(crossfade_images, dim=0)
+ # Get the last frame result of the interpolation
+ last_frame = crossfade_images[-1]
+ # Calculate the number of remaining frames from images_2
+ remaining_frames = len(images_2) - (transition_start_index + transitioning_frames)
+ # Crossfade the remaining frames with the last used alpha value
+ for i in range(remaining_frames):
+ alpha = alphas[-1]
+ image1 = images_1[i + transition_start_index + transitioning_frames]
+ image2 = images_2[i + transition_start_index + transitioning_frames]
+ easing_function = easing_functions.get(interpolation)
+ alpha = easing_function(alpha) # Apply the easing function to the alpha value
+
+ crossfade_image = crossfade(image1, image2, alpha)
+ crossfade_images = torch.cat([crossfade_images, crossfade_image.unsqueeze(0)], dim=0)
+ # Append the beginning of images_1
+ beginning_images_1 = images_1[:transition_start_index]
+ crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0)
+ return (crossfade_images, )
+
+class CrossFadeImagesMulti:
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "crossfadeimages"
+ CATEGORY = "KJNodes/image"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "image_1": ("IMAGE",),
+ "image_2": ("IMAGE",),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
+ "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ },
+ }
+
+ def crossfadeimages(self, inputcount, transitioning_frames, interpolation, **kwargs):
+
+ def crossfade(images_1, images_2, alpha):
+ crossfade = (1 - alpha) * images_1 + alpha * images_2
+ return crossfade
+ def ease_in(t):
+ return t * t
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+ def bounce(t):
+ if t < 0.5:
+ return ease_out(t * 2) * 0.5
+ else:
+ return ease_in((t - 0.5) * 2) * 0.5 + 0.5
+ def elastic(t):
+ return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1))
+ def glitchy(t):
+ return t + 0.1 * math.sin(40 * t)
+ def exponential_ease_out(t):
+ return 1 - (1 - t) ** 4
+
+ easing_functions = {
+ "linear": lambda t: t,
+ "ease_in": ease_in,
+ "ease_out": ease_out,
+ "ease_in_out": ease_in_out,
+ "bounce": bounce,
+ "elastic": elastic,
+ "glitchy": glitchy,
+ "exponential_ease_out": exponential_ease_out,
+ }
+
+ image_1 = kwargs["image_1"]
+ height = image_1.shape[1]
+ width = image_1.shape[2]
+
+ easing_function = easing_functions[interpolation]
+
+ for c in range(1, inputcount):
+ frames = []
+ new_image = kwargs[f"image_{c + 1}"]
+ new_image_height = new_image.shape[1]
+ new_image_width = new_image.shape[2]
+
+ if new_image_height != height or new_image_width != width:
+ new_image = common_upscale(new_image.movedim(-1, 1), width, height, "lanczos", "disabled")
+ new_image = new_image.movedim(1, -1) # Move channels back to the last dimension
+
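+ # Crossfade from the last frame of the accumulated batch into the first frame of the next input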
+ last_frame_image_1 = image_1[-1]
+ first_frame_image_2 = new_image[0]
+
+ for frame in range(transitioning_frames):
+ t = frame / max(transitioning_frames - 1, 1)  # avoid division by zero with a single transition frame
+ alpha = easing_function(t)
+ alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device)
+ frame_image = crossfade(last_frame_image_1, first_frame_image_2, alpha_tensor)
+ frames.append(frame_image)
+
+ frames = torch.stack(frames)
+ image_1 = torch.cat((image_1, frames, new_image), dim=0)
+
+ return image_1,
+
+def transition_images(images_1, images_2, alpha, transition_type, blur_radius, reverse):
+ width = images_1.shape[1]
+ height = images_1.shape[0]
+
+ mask = torch.zeros_like(images_1, device=images_1.device)
+
+ alpha = alpha.item()
+ if reverse:
+ alpha = 1 - alpha
+
+ #transitions from matteo's essential nodes
+ if "horizontal slide" in transition_type:
+ pos = round(width * alpha)
+ mask[:, :pos, :] = 1.0
+ elif "vertical slide" in transition_type:
+ pos = round(height * alpha)
+ mask[:pos, :, :] = 1.0
+ elif "box" in transition_type:
+ box_w = round(width * alpha)
+ box_h = round(height * alpha)
+ x1 = (width - box_w) // 2
+ y1 = (height - box_h) // 2
+ x2 = x1 + box_w
+ y2 = y1 + box_h
+ mask[y1:y2, x1:x2, :] = 1.0
+ elif "circle" in transition_type:
+ radius = math.ceil(math.sqrt(pow(width, 2) + pow(height, 2)) * alpha / 2)
+ c_x = width // 2
+ c_y = height // 2
+ x = torch.arange(0, width, dtype=torch.float32, device="cpu")
+ y = torch.arange(0, height, dtype=torch.float32, device="cpu")
+ y, x = torch.meshgrid((y, x), indexing="ij")
+ circle = ((x - c_x) ** 2 + (y - c_y) ** 2) <= (radius ** 2)
+ mask[circle] = 1.0
+ elif "horizontal door" in transition_type:
+ bar = math.ceil(height * alpha / 2)
+ if bar > 0:
+ mask[:bar, :, :] = 1.0
+ mask[-bar:,:, :] = 1.0
+ elif "vertical door" in transition_type:
+ bar = math.ceil(width * alpha / 2)
+ if bar > 0:
+ mask[:, :bar,:] = 1.0
+ mask[:, -bar:,:] = 1.0
+ elif "fade" in transition_type:
+ mask[:, :, :] = alpha
+
+ mask = gaussian_blur(mask, blur_radius)
+
+ return images_1 * (1 - mask) + images_2 * mask
+
+def ease_in(t):
+ return t * t
+def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+def bounce(t):
+ if t < 0.5:
+ return ease_out(t * 2) * 0.5
+ else:
+ return ease_in((t - 0.5) * 2) * 0.5 + 0.5
+def elastic(t):
+ return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1))
+def glitchy(t):
+ return t + 0.1 * math.sin(40 * t)
+def exponential_ease_out(t):
+ return 1 - (1 - t) ** 4
+
+def gaussian_blur(mask, blur_radius):
+ if blur_radius > 0:
+ kernel_size = int(blur_radius * 2) + 1
+ if kernel_size % 2 == 0:
+ kernel_size += 1 # Ensure kernel size is odd
+ sigma = blur_radius / 3
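+ # Build a normalized 1D Gaussian kernel and take its outer product to get the 2D kernel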
+ x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
+ x = torch.exp(-0.5 * (x / sigma) ** 2)
+ kernel1d = x / x.sum()
+ kernel2d = kernel1d[:, None] * kernel1d[None, :]
+ kernel2d = kernel2d.to(mask.device)
+ kernel2d = kernel2d.expand(mask.shape[2], 1, kernel2d.shape[0], kernel2d.shape[1])
+ mask = mask.permute(2, 0, 1).unsqueeze(0) # Change to [C, H, W] and add batch dimension
+ mask = F.conv2d(mask, kernel2d, padding=kernel_size // 2, groups=mask.shape[1])
+ mask = mask.squeeze(0).permute(1, 2, 0) # Change back to [H, W, C]
+ return mask
+
+easing_functions = {
+ "linear": lambda t: t,
+ "ease_in": ease_in,
+ "ease_out": ease_out,
+ "ease_in_out": ease_in_out,
+ "bounce": bounce,
+ "elastic": elastic,
+ "glitchy": glitchy,
+ "exponential_ease_out": exponential_ease_out,
+}
+
+class TransitionImagesMulti:
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "transition"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates transitions between images.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "image_1": ("IMAGE",),
+ "image_2": ("IMAGE",),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
+ "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"],),
+ "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ "blur_radius": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 100.0, "step": 0.1}),
+ "reverse": ("BOOLEAN", {"default": False}),
+ "device": (["CPU", "GPU"], {"default": "CPU"}),
+ },
+ }
+
+ def transition(self, inputcount, transitioning_frames, transition_type, interpolation, device, blur_radius, reverse, **kwargs):
+
+ gpu = model_management.get_torch_device()
+
+ image_1 = kwargs["image_1"]
+ height = image_1.shape[1]
+ width = image_1.shape[2]
+
+ easing_function = easing_functions[interpolation]
+
+ for c in range(1, inputcount):
+ frames = []
+ new_image = kwargs[f"image_{c + 1}"]
+ new_image_height = new_image.shape[1]
+ new_image_width = new_image.shape[2]
+
+ if new_image_height != height or new_image_width != width:
+ new_image = common_upscale(new_image.movedim(-1, 1), width, height, "lanczos", "disabled")
+ new_image = new_image.movedim(1, -1) # Move channels back to the last dimension
+
+ last_frame_image_1 = image_1[-1]
+ first_frame_image_2 = new_image[0]
+ if device == "GPU":
+ last_frame_image_1 = last_frame_image_1.to(gpu)
+ first_frame_image_2 = first_frame_image_2.to(gpu)
+
+ if reverse:
+ last_frame_image_1, first_frame_image_2 = first_frame_image_2, last_frame_image_1
+
+ for frame in range(transitioning_frames):
+ t = frame / max(transitioning_frames - 1, 1)  # avoid division by zero with a single transition frame
+ alpha = easing_function(t)
+ alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device)
+ frame_image = transition_images(last_frame_image_1, first_frame_image_2, alpha_tensor, transition_type, blur_radius, reverse)
+ frames.append(frame_image)
+
+ frames = torch.stack(frames).cpu()
+ image_1 = torch.cat((image_1, frames, new_image), dim=0)
+
+ return image_1.cpu(),
+
+class TransitionImagesInBatch:
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "transition"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates transitions between images in a batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
+ "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"],),
+ "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ "blur_radius": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 100.0, "step": 0.1}),
+ "reverse": ("BOOLEAN", {"default": False}),
+ "device": (["CPU", "GPU"], {"default": "CPU"}),
+ },
+ }
+
+ #transitions from matteo's essential nodes
+ def transition(self, images, transitioning_frames, transition_type, interpolation, device, blur_radius, reverse):
+ if images.shape[0] == 1:
+ return images,
+
+ gpu = model_management.get_torch_device()
+
+ easing_function = easing_functions[interpolation]
+
+ images_list = []
+ pbar = ProgressBar(images.shape[0] - 1)
+ for i in range(images.shape[0] - 1):
+ frames = []
+ image_1 = images[i]
+ image_2 = images[i + 1]
+
+ if device == "GPU":
+ image_1 = image_1.to(gpu)
+ image_2 = image_2.to(gpu)
+
+ if reverse:
+ image_1, image_2 = image_2, image_1
+
+ for frame in range(transitioning_frames):
+ t = frame / max(transitioning_frames - 1, 1)  # avoid division by zero with a single transition frame
+ alpha = easing_function(t)
+ alpha_tensor = torch.tensor(alpha, dtype=image_1.dtype, device=image_1.device)
+ frame_image = transition_images(image_1, image_2, alpha_tensor, transition_type, blur_radius, reverse)
+ frames.append(frame_image)
+ pbar.update(1)
+
+ frames = torch.stack(frames).cpu()
+ images_list.append(frames)
+ images = torch.cat(images_list, dim=0)
+
+ return images.cpu(),
+
+class ShuffleImageBatch:
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "shuffle"
+ CATEGORY = "KJNodes/image"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
+ },
+ }
+
+ def shuffle(self, images, seed):
+ torch.manual_seed(seed)
+ B, H, W, C = images.shape
+ indices = torch.randperm(B)
+ shuffled_images = images[indices]
+
+ return shuffled_images,
+
+class GetImageRangeFromBatch:
+
+ RETURN_TYPES = ("IMAGE", "MASK", )
+ FUNCTION = "imagesfrombatch"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Returns a range of images from a batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "start_index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
+ "num_frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
+ },
+ "optional": {
+ "images": ("IMAGE",),
+ "masks": ("MASK",),
+ }
+ }
+
+ def imagesfrombatch(self, start_index, num_frames, images=None, masks=None):
+ chosen_images = None
+ chosen_masks = None
+
+ # Process images if provided
+ if images is not None:
+ if start_index == -1:
+ start_index = max(0, len(images) - num_frames)
+ if start_index < 0 or start_index >= len(images):
+ raise ValueError("Start index is out of range")
+ end_index = min(start_index + num_frames, len(images))
+ chosen_images = images[start_index:end_index]
+
+ # Process masks if provided
+ if masks is not None:
+ if start_index == -1:
+ start_index = max(0, len(masks) - num_frames)
+ if start_index < 0 or start_index >= len(masks):
+ raise ValueError("Start index is out of range for masks")
+ end_index = min(start_index + num_frames, len(masks))
+ chosen_masks = masks[start_index:end_index]
+
+ return (chosen_images, chosen_masks,)
+
+class GetLatentRangeFromBatch:
+
+ RETURN_TYPES = ("LATENT", )
+ FUNCTION = "latentsfrombatch"
+ CATEGORY = "KJNodes/latents"
+ DESCRIPTION = """
+Returns a range of latents from a batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "latents": ("LATENT",),
+ "start_index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
+ "num_frames": ("INT", {"default": 1,"min": -1, "max": 4096, "step": 1}),
+ },
+ }
+
+ def latentsfrombatch(self, latents, start_index, num_frames):
+ chosen_latents = None
+ samples = latents["samples"]
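+ # 4D latents are image batches [B, C, H, W]; 5D latents are treated as [B, C, T, H, W] with frames on dim 2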
+ if len(samples.shape) == 4:
+ B, C, H, W = samples.shape
+ num_latents = B
+ elif len(samples.shape) == 5:
+ B, C, T, H, W = samples.shape
+ num_latents = T
+
+ if start_index == -1:
+ start_index = max(0, num_latents - num_frames)
+ if start_index < 0 or start_index >= num_latents:
+ raise ValueError("Start index is out of range")
+
+ end_index = num_latents if num_frames == -1 else min(start_index + num_frames, num_latents)
+
+ if len(samples.shape) == 4:
+ chosen_latents = samples[start_index:end_index]
+ elif len(samples.shape) == 5:
+ chosen_latents = samples[:, :, start_index:end_index]
+
+ return ({"samples": chosen_latents,},)
+
+class InsertLatentToIndex:
+
+ RETURN_TYPES = ("LATENT", )
+ FUNCTION = "insert"
+ CATEGORY = "KJNodes/latents"
+ DESCRIPTION = """
+Inserts a latent at the specified index into the original latent batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "source": ("LATENT",),
+ "destination": ("LATENT",),
+ "index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
+ },
+ }
+
+ def insert(self, source, destination, index):
+ samples_destination = destination["samples"]
+ samples_source = source["samples"].to(samples_destination)
+
+ if len(samples_source.shape) == 4:
+ B, C, H, W = samples_source.shape
+ num_latents = B
+ elif len(samples_source.shape) == 5:
+ B, C, T, H, W = samples_source.shape
+ num_latents = T
+
+ if index >= num_latents or index < 0:
+ raise ValueError(f"Index {index} out of bounds for tensor with {num_latents} latents")
+
+ if len(samples_source.shape) == 4:
+ joined_latents = torch.cat([
+ samples_destination[:index],
+ samples_source,
+ samples_destination[index+1:]
+ ], dim=0)
+ else:
+ joined_latents = torch.cat([
+ samples_destination[:, :, :index],
+ samples_source,
+ samples_destination[:, :, index+1:]
+ ], dim=2)
+
+ return ({"samples": joined_latents,},)
+
+class GetImagesFromBatchIndexed:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "indexedimagesfrombatch"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Selects and returns the images at the specified indices as an image batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
+ },
+ }
+
+ def indexedimagesfrombatch(self, images, indexes):
+
+ # Parse the indexes string into a list of integers
+ index_list = [int(index.strip()) for index in indexes.split(',')]
+
+ # Convert list of indices to a PyTorch tensor
+ indices_tensor = torch.tensor(index_list, dtype=torch.long)
+
+ # Select the images at the specified indices
+ chosen_images = images[indices_tensor]
+
+ return (chosen_images,)
+
+class InsertImagesToBatchIndexed:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "insertimagesfrombatch"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Inserts images at the specified indices into the original image batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "images_to_insert": ("IMAGE",),
+ "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
+ },
+ }
+
+ def insertimagesfrombatch(self, original_images, images_to_insert, indexes):
+
+ # Parse the indexes string into a list of integers
+ index_list = [int(index.strip()) for index in indexes.split(',')]
+
+ # Convert list of indices to a PyTorch tensor
+ indices_tensor = torch.tensor(index_list, dtype=torch.long)
+
+ # Ensure the images_to_insert is a tensor
+ if not isinstance(images_to_insert, torch.Tensor):
+ images_to_insert = torch.tensor(images_to_insert)
+
+ # Insert the images at the specified indices
+ for index, image in zip(indices_tensor, images_to_insert):
+ original_images[index] = image
+
+ return (original_images,)
+
+class ReplaceImagesInBatch:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "replace"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Replaces the images in a batch, starting from the specified start index,
+with the replacement images.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_images": ("IMAGE",),
+ "replacement_images": ("IMAGE",),
+ "start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
+ },
+ }
+
+ def replace(self, original_images, replacement_images, start_index):
+ images = None
+ if start_index >= len(original_images):
+ raise ValueError("GetImageRangeFromBatch: Start index is out of range")
+ end_index = start_index + len(replacement_images)
+ if end_index > len(original_images):
+ raise ValueError("GetImageRangeFromBatch: End index is out of range")
+ # Create a copy of the original_images tensor
+ original_images_copy = original_images.clone()
+ original_images_copy[start_index:end_index] = replacement_images
+ images = original_images_copy
+ return (images, )
+
+
+class ReverseImageBatch:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "reverseimagebatch"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Reverses the order of the images in a batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ },
+ }
+
+ def reverseimagebatch(self, images):
+ reversed_images = torch.flip(images, [0])
+ return (reversed_images, )
+
+class ImageBatchMulti:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "image_1": ("IMAGE", ),
+ "image_2": ("IMAGE", ),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("images",)
+ FUNCTION = "combine"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates an image batch from multiple images.
+You can set how many inputs the node has
+with the **inputcount** widget and by clicking update.
+"""
+
+ def combine(self, inputcount, **kwargs):
+ from nodes import ImageBatch
+ image_batch_node = ImageBatch()
+ image = kwargs["image_1"]
+ for c in range(1, inputcount):
+ new_image = kwargs[f"image_{c + 1}"]
+ image, = image_batch_node.batch(image, new_image)
+ return (image,)
+
+
+class ImageTensorList:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image1": ("IMAGE",),
+ "image2": ("IMAGE",),
+ }}
+
+ RETURN_TYPES = ("IMAGE",)
+ #OUTPUT_IS_LIST = (True,)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates an image list from the input images.
+"""
+
+ def append(self, image1, image2):
+ image_list = []
+ if isinstance(image1, torch.Tensor) and isinstance(image2, torch.Tensor):
+ image_list = [image1, image2]
+ elif isinstance(image1, list) and isinstance(image2, torch.Tensor):
+ image_list = image1 + [image2]
+ elif isinstance(image1, torch.Tensor) and isinstance(image2, list):
+ image_list = [image1] + image2
+ elif isinstance(image1, list) and isinstance(image2, list):
+ image_list = image1 + image2
+ return image_list,
+
+class ImageAddMulti:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "image_1": ("IMAGE", ),
+ "image_2": ("IMAGE", ),
+ "blending": (
+ [ 'add',
+ 'subtract',
+ 'multiply',
+ 'difference',
+ ],
+ {
+ "default": 'add'
+ }),
+ "blend_amount": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("images",)
+ FUNCTION = "add"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Blends multiple images together using the selected blending mode.
+You can set how many inputs the node has
+with the **inputcount** widget and by clicking update.
+"""
+
+ def add(self, inputcount, blending, blend_amount, **kwargs):
+ image = kwargs["image_1"]
+ for c in range(1, inputcount):
+ new_image = kwargs[f"image_{c + 1}"]
+ if blending == "add":
+ image = torch.add(image * blend_amount, new_image * blend_amount)
+ elif blending == "subtract":
+ image = torch.sub(image * blend_amount, new_image * blend_amount)
+ elif blending == "multiply":
+ image = torch.mul(image * blend_amount, new_image * blend_amount)
+ elif blending == "difference":
+ image = torch.sub(image, new_image)
+ return (image,)
+
+class ImageConcatMulti:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "image_1": ("IMAGE", ),
+ "image_2": ("IMAGE", ),
+ "direction": (
+ [ 'right',
+ 'down',
+ 'left',
+ 'up',
+ ],
+ {
+ "default": 'right'
+ }),
+ "match_image_size": ("BOOLEAN", {"default": False}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("images",)
+ FUNCTION = "combine"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Creates a single image by concatenating multiple images
+in the chosen direction.
+You can set how many inputs the node has
+with the **inputcount** widget and by clicking update.
+"""
+
+ def combine(self, inputcount, direction, match_image_size, **kwargs):
+ image = kwargs["image_1"]
+ first_image_shape = image.shape
+ for c in range(1, inputcount):
+ new_image = kwargs[f"image_{c + 1}"]
+ image, = ImageConcanate.concatenate(self, image, new_image, direction, match_image_size, first_image_shape=first_image_shape)
+ first_image_shape = None
+ return (image,)
+
+class PreviewAnimation:
+ def __init__(self):
+ self.output_dir = folder_paths.get_temp_directory()
+ self.type = "temp"
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+ self.compress_level = 1
+
+ methods = {"default": 4, "fastest": 0, "slowest": 6}
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "fps": ("FLOAT", {"default": 8.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
+ },
+ "optional": {
+ "images": ("IMAGE", ),
+ "masks": ("MASK", ),
+ },
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "preview"
+ OUTPUT_NODE = True
+ CATEGORY = "KJNodes/image"
+
+ def preview(self, fps, images=None, masks=None):
+ filename_prefix = "AnimPreview"
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+ results = list()
+
+ pil_images = []
+
+ if images is not None and masks is not None:
+ for image in images:
+ i = 255. * image.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+ pil_images.append(img)
+ for mask in masks:
+ if pil_images:
+ mask_np = mask.cpu().numpy()
+ mask_np = np.clip(mask_np * 255, 0, 255).astype(np.uint8) # Convert to values between 0 and 255
+ mask_img = Image.fromarray(mask_np, mode='L')
+ img = pil_images.pop(0) # Remove and get the first image
+ img = img.convert("RGBA") # Convert base image to RGBA
+
+ # Create a new RGBA image based on the grayscale mask
+ rgba_mask_img = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ rgba_mask_img.putalpha(mask_img) # Use the mask image as the alpha channel
+
+ # Composite the RGBA mask onto the base image
+ composited_img = Image.alpha_composite(img, rgba_mask_img)
+ pil_images.append(composited_img) # Add the composited image back
+
+ elif images is not None and masks is None:
+ for image in images:
+ i = 255. * image.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+ pil_images.append(img)
+
+ elif masks is not None and images is None:
+ for mask in masks:
+ mask_np = 255. * mask.cpu().numpy()
+ mask_img = Image.fromarray(np.clip(mask_np, 0, 255).astype(np.uint8))
+ pil_images.append(mask_img)
+ else:
+ print("PreviewAnimation: No images or masks provided")
+ return { "ui": { "images": results, "animated": (None,), "text": "empty" }}
+
+ num_frames = len(pil_images)
+
+ c = len(pil_images)
+ for i in range(0, c, num_frames):
+ file = f"{filename}_{counter:05}_.webp"
+ pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], lossless=False, quality=80, method=4)
+ results.append({
+ "filename": file,
+ "subfolder": subfolder,
+ "type": self.type
+ })
+ counter += 1
+
+ animated = num_frames != 1
+ return { "ui": { "images": results, "animated": (animated,), "text": [f"{num_frames}x{pil_images[0].size[0]}x{pil_images[0].size[1]}"] } }
+
+class ImageResizeKJ:
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "upscale_method": (s.upscale_methods,),
+ "keep_proportion": ("BOOLEAN", { "default": False }),
+ "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
+ },
+ "optional" : {
+ "width_input": ("INT", { "forceInput": True}),
+ "height_input": ("INT", { "forceInput": True}),
+ "get_image_size": ("IMAGE",),
+ "crop": (["disabled","center"],),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "INT", "INT",)
+ RETURN_NAMES = ("IMAGE", "width", "height",)
+ FUNCTION = "resize"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """
+Resizes the image to the specified width and height.
+Size can be retrieved from the inputs, and the final scale
+is determined in this order of importance:
+- get_image_size
+- width_input and height_input
+- width and height widgets
+
+keep_proportion preserves the aspect ratio of the image,
+scaling it to fit within the given dimensions.
+"""
+
+ def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by,
+ width_input=None, height_input=None, get_image_size=None, crop="disabled"):
+ B, H, W, C = image.shape
+
+ if width_input:
+ width = width_input
+ if height_input:
+ height = height_input
+ if get_image_size is not None:
+ _, height, width, _ = get_image_size.shape
+
+ if keep_proportion and get_image_size is None:
+ # If one of the dimensions is zero, calculate it to maintain the aspect ratio
+ if width == 0 and height != 0:
+ ratio = height / H
+ width = round(W * ratio)
+ elif height == 0 and width != 0:
+ ratio = width / W
+ height = round(H * ratio)
+ elif width != 0 and height != 0:
+ # Scale based on which dimension is smaller in proportion to the desired dimensions
+ ratio = min(width / W, height / H)
+ width = round(W * ratio)
+ height = round(H * ratio)
+ else:
+ if width == 0:
+ width = W
+ if height == 0:
+ height = H
+
+ if divisible_by > 1 and get_image_size is None:
+ width = width - (width % divisible_by)
+ height = height - (height % divisible_by)
+
+ image = image.movedim(-1,1)
+ image = common_upscale(image, width, height, upscale_method, crop)
+ image = image.movedim(1,-1)
+
+ return(image, image.shape[2], image.shape[1],)
+import pathlib
+class LoadAndResizeImage:
+ _color_channels = ["alpha", "red", "green", "blue"]
+ @classmethod
+ def INPUT_TYPES(s):
+ input_dir = folder_paths.get_input_directory()
+ files = [f.name for f in pathlib.Path(input_dir).iterdir() if f.is_file()]
+ return {"required":
+ {
+ "image": (sorted(files), {"image_upload": True}),
+ "resize": ("BOOLEAN", { "default": False }),
+ "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "repeat": ("INT", { "default": 1, "min": 1, "max": 4096, "step": 1, }),
+ "keep_proportion": ("BOOLEAN", { "default": False }),
+ "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
+ "mask_channel": (s._color_channels, {"tooltip": "Channel to use for the mask output"}),
+ "background_color": ("STRING", { "default": "", "tooltip": "Fills the alpha channel with the specified color."}),
+ },
+ }
+
+ CATEGORY = "KJNodes/image"
+ RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING",)
+ RETURN_NAMES = ("image", "mask", "width", "height","image_path",)
+ FUNCTION = "load_image"
+
+ def load_image(self, image, resize, width, height, repeat, keep_proportion, divisible_by, mask_channel, background_color):
+ from PIL import ImageColor, Image, ImageOps, ImageSequence
+ import numpy as np
+ import torch
+ image_path = folder_paths.get_annotated_filepath(image)
+
+ import node_helpers
+ img = node_helpers.pillow(Image.open, image_path)
+
+ # Process the background_color
+ if background_color:
+ try:
+ # Try to parse as RGB tuple
+ bg_color_rgba = tuple(int(x.strip()) for x in background_color.split(','))
+ except ValueError:
+ # If parsing fails, it might be a hex color or named color
+ if background_color.startswith('#') or background_color.lower() in ImageColor.colormap:
+ bg_color_rgba = ImageColor.getrgb(background_color)
+ else:
+ raise ValueError(f"Invalid background color: {background_color}")
+
+ bg_color_rgba += (255,) # Add alpha channel
+ else:
+ bg_color_rgba = None # No background color specified
+
+ output_images = []
+ output_masks = []
+ w, h = None, None
+
+ excluded_formats = ['MPO']
+
+ W, H = img.size
+ if resize:
+ if keep_proportion:
+ ratio = min(width / W, height / H)
+ width = round(W * ratio)
+ height = round(H * ratio)
+ else:
+ if width == 0:
+ width = W
+ if height == 0:
+ height = H
+
+ if divisible_by > 1:
+ width = width - (width % divisible_by)
+ height = height - (height % divisible_by)
+ else:
+ width, height = W, H
+
+ for frame in ImageSequence.Iterator(img):
+ frame = node_helpers.pillow(ImageOps.exif_transpose, frame)
+
+ if frame.mode == 'I':
+ frame = frame.point(lambda i: i * (1 / 255))
+
+ if frame.mode == 'P':
+ frame = frame.convert("RGBA")
+ elif 'A' in frame.getbands():
+ frame = frame.convert("RGBA")
+
+ # Extract alpha channel if it exists
+ if 'A' in frame.getbands() and bg_color_rgba:
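+ # Invert the alpha so fully transparent pixels become 1 in the mask (ComfyUI mask convention)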
+ alpha_mask = np.array(frame.getchannel('A')).astype(np.float32) / 255.0
+ alpha_mask = 1. - torch.from_numpy(alpha_mask)
+ bg_image = Image.new("RGBA", frame.size, bg_color_rgba)
+ # Composite the frame onto the background
+ frame = Image.alpha_composite(bg_image, frame)
+ else:
+ alpha_mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
+
+ image = frame.convert("RGB")
+
+ if len(output_images) == 0:
+ w = image.size[0]
+ h = image.size[1]
+
+ if image.size[0] != w or image.size[1] != h:
+ continue
+ if resize:
+ image = image.resize((width, height), Image.Resampling.BILINEAR)
+
+ image = np.array(image).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+
+ c = mask_channel[0].upper()
+ if c in frame.getbands():
+ if resize:
+ frame = frame.resize((width, height), Image.Resampling.BILINEAR)
+ mask = np.array(frame.getchannel(c)).astype(np.float32) / 255.0
+ mask = torch.from_numpy(mask)
+ if c == 'A' and bg_color_rgba:
+ mask = alpha_mask
+ elif c == 'A':
+ mask = 1. - mask
+ else:
+ mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
+
+ output_images.append(image)
+ output_masks.append(mask.unsqueeze(0))
+
+ if len(output_images) > 1 and img.format not in excluded_formats:
+ output_image = torch.cat(output_images, dim=0)
+ output_mask = torch.cat(output_masks, dim=0)
+ else:
+ output_image = output_images[0]
+ output_mask = output_masks[0]
+ if repeat > 1:
+ output_image = output_image.repeat(repeat, 1, 1, 1)
+ output_mask = output_mask.repeat(repeat, 1, 1)
+
+ return (output_image, output_mask, width, height, image_path)
+
+
+ # @classmethod
+ # def IS_CHANGED(s, image, **kwargs):
+ # image_path = folder_paths.get_annotated_filepath(image)
+ # m = hashlib.sha256()
+ # with open(image_path, 'rb') as f:
+ # m.update(f.read())
+ # return m.digest().hex()
+
+ @classmethod
+ def VALIDATE_INPUTS(s, image):
+ if not folder_paths.exists_annotated_filepath(image):
+ return "Invalid image file: {}".format(image)
+
+ return True
+
+class LoadImagesFromFolderKJ:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "folder": ("STRING", {"default": ""}),
+ "width": ("INT", {"default": 1024, "min": 64, "step": 1}),
+ "height": ("INT", {"default": 1024, "min": 64, "step": 1}),
+ "keep_aspect_ratio": (["crop", "pad", "stretch",],),
+ },
+ "optional": {
+ "image_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
+ "start_index": ("INT", {"default": 0, "min": 0, "step": 1}),
+ "include_subfolders": ("BOOLEAN", {"default": False}),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK", "INT", "STRING",)
+ RETURN_NAMES = ("image", "mask", "count", "image_path",)
+ FUNCTION = "load_images"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = """Loads images from a folder into a batch, images are resized and loaded into a batch."""
+
+ def load_images(self, folder, width, height, image_load_cap, start_index, keep_aspect_ratio, include_subfolders=False):
+ if not os.path.isdir(folder):
+ raise FileNotFoundError(f"Folder '{folder}' cannot be found.")
+
+ valid_extensions = ['.jpg', '.jpeg', '.png', '.webp']
+ image_paths = []
+ if include_subfolders:
+ for root, _, files in os.walk(folder):
+ for file in files:
+ if any(file.lower().endswith(ext) for ext in valid_extensions):
+ image_paths.append(os.path.join(root, file))
+ else:
+ for file in os.listdir(folder):
+ if any(file.lower().endswith(ext) for ext in valid_extensions):
+ image_paths.append(os.path.join(folder, file))
+
+ dir_files = sorted(image_paths)
+
+ if len(dir_files) == 0:
+ raise FileNotFoundError(f"No files in directory '{folder}'.")
+
+ # start at start_index
+ dir_files = dir_files[start_index:]
+
+ images = []
+ masks = []
+ image_path_list = []
+
+ limit_images = False
+ if image_load_cap > 0:
+ limit_images = True
+ image_count = 0
+
+ for image_path in dir_files:
+ if os.path.isdir(image_path):
+ continue
+ if limit_images and image_count >= image_load_cap:
+ break
+ i = Image.open(image_path)
+ i = ImageOps.exif_transpose(i)
+
+ # Resize image to maximum dimensions
+ if i.size != (width, height):
+ i = self.resize_with_aspect_ratio(i, width, height, keep_aspect_ratio)
+
+
+ image = i.convert("RGB")
+ image = np.array(image).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+
+ if 'A' in i.getbands():
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+ mask = 1. - torch.from_numpy(mask)
+ if mask.shape != (height, width):
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0),
+ size=(height, width),
+ mode='bilinear',
+ align_corners=False).squeeze()
+ else:
+ mask = torch.zeros((height, width), dtype=torch.float32, device="cpu")
+
+ images.append(image)
+ masks.append(mask)
+ image_path_list.append(image_path)
+ image_count += 1
+
+ if len(images) == 1:
+            return (images[0], masks[0].unsqueeze(0), 1, image_path_list)
+
+ elif len(images) > 1:
+ image1 = images[0]
+ mask1 = masks[0].unsqueeze(0)
+
+ for image2 in images[1:]:
+ image1 = torch.cat((image1, image2), dim=0)
+
+ for mask2 in masks[1:]:
+ mask1 = torch.cat((mask1, mask2.unsqueeze(0)), dim=0)
+
+ return (image1, mask1, len(images), image_path_list)
+
+    def resize_with_aspect_ratio(self, img, width, height, mode):
+ if mode == "stretch":
+ return img.resize((width, height), Image.Resampling.LANCZOS)
+
+ img_width, img_height = img.size
+ aspect_ratio = img_width / img_height
+ target_ratio = width / height
+
+ if mode == "crop":
+ # Calculate dimensions for center crop
+ if aspect_ratio > target_ratio:
+ # Image is wider - crop width
+ new_width = int(height * aspect_ratio)
+ img = img.resize((new_width, height), Image.Resampling.LANCZOS)
+ left = (new_width - width) // 2
+ return img.crop((left, 0, left + width, height))
+ else:
+ # Image is taller - crop height
+ new_height = int(width / aspect_ratio)
+ img = img.resize((width, new_height), Image.Resampling.LANCZOS)
+ top = (new_height - height) // 2
+ return img.crop((0, top, width, top + height))
+
+ elif mode == "pad":
+ pad_color = self.get_edge_color(img)
+ # Calculate dimensions for padding
+ if aspect_ratio > target_ratio:
+ # Image is wider - pad height
+ new_height = int(width / aspect_ratio)
+ img = img.resize((width, new_height), Image.Resampling.LANCZOS)
+ padding = (height - new_height) // 2
+ padded = Image.new('RGBA', (width, height), pad_color)
+ padded.paste(img, (0, padding))
+ return padded
+ else:
+ # Image is taller - pad width
+ new_width = int(height * aspect_ratio)
+ img = img.resize((new_width, height), Image.Resampling.LANCZOS)
+ padding = (width - new_width) // 2
+ padded = Image.new('RGBA', (width, height), pad_color)
+ padded.paste(img, (padding, 0))
+ return padded
+
+    def get_edge_color(self, img):
+        """Sample the image edges and return their median color."""
+        from PIL import ImageStat
+ width, height = img.size
+ img = img.convert('RGBA')
+
+ # Create 1-pixel high/wide images from edges
+ top = img.crop((0, 0, width, 1))
+ bottom = img.crop((0, height-1, width, height))
+ left = img.crop((0, 0, 1, height))
+ right = img.crop((width-1, 0, width, height))
+
+ # Combine edges into single image
+ edges = Image.new('RGBA', (width*2 + height*2, 1))
+ edges.paste(top, (0, 0))
+ edges.paste(bottom, (width, 0))
+ edges.paste(left.resize((height, 1)), (width*2, 0))
+ edges.paste(right.resize((height, 1)), (width*2 + height, 0))
+
+ # Get median color
+ stat = ImageStat.Stat(edges)
+ median = tuple(map(int, stat.median))
+ return median
+
+
+class ImageGridtoBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE", ),
+ "columns": ("INT", {"default": 3, "min": 1, "max": 8, "tooltip": "The number of columns in the grid."}),
+            "rows": ("INT", {"default": 0, "min": 0, "max": 8, "tooltip": "The number of rows in the grid. Set to 0 for automatic calculation."}),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "decompose"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = "Converts a grid of images to a batch of images."
+
+ def decompose(self, image, columns, rows):
+ B, H, W, C = image.shape
+ print("input size: ", image.shape)
+
+        # Calculate cell width, rounding down
+        cell_width = W // columns
+
+        if rows == 0:
+            # If rows is 0, assume square cells and calculate the number of full rows from the cell width
+            cell_height = cell_width
+            rows = H // cell_height
+        else:
+            # If rows is specified, derive cell_height from it
+            cell_height = H // rows
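+        # Worked example (hypothetical sizes): a 1x1024x1536x3 grid with columns=3 and rows=2
+        # yields 512x512 cells and a batch of 6 images per input frame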
+
+ # Crop the image to fit full cells
+ image = image[:, :rows*cell_height, :columns*cell_width, :]
+
+ # Reshape and permute the image to get the grid
+ image = image.view(B, rows, cell_height, columns, cell_width, C)
+ image = image.permute(0, 1, 3, 2, 4, 5).contiguous()
+ image = image.view(B, rows * columns, cell_height, cell_width, C)
+
+ # Reshape to the final batch tensor
+ img_tensor = image.view(-1, cell_height, cell_width, C)
+
+ return (img_tensor,)
+
+class SaveImageKJ:
+ def __init__(self):
+ self.type = "output"
+ self.prefix_append = ""
+ self.compress_level = 4
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE", {"tooltip": "The images to save."}),
+ "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
+ "output_folder": ("STRING", {"default": "output", "tooltip": "The folder to save the images to."}),
+ },
+ "optional": {
+ "caption_file_extension": ("STRING", {"default": ".txt", "tooltip": "The extension for the caption file."}),
+ "caption": ("STRING", {"forceInput": True, "tooltip": "string to save as .txt file"}),
+ },
+ "hidden": {
+ "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
+ },
+ }
+
+ RETURN_TYPES = ("STRING",)
+ RETURN_NAMES = ("filename",)
+ FUNCTION = "save_images"
+
+ OUTPUT_NODE = True
+
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = "Saves the input images to your ComfyUI output directory."
+
+ def save_images(self, images, output_folder, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, caption=None, caption_file_extension=".txt"):
+ filename_prefix += self.prefix_append
+
+ if os.path.isabs(output_folder):
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder, exist_ok=True)
+ full_output_folder = output_folder
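+            # get_save_image_path is still called so the filename/counter pattern stays consistent;
+            # its computed folder is discarded in favor of the absolute path given by the user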
+ _, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_folder, images[0].shape[1], images[0].shape[0])
+ else:
+ self.output_dir = folder_paths.get_output_directory()
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
+
+ results = list()
+ for (batch_number, image) in enumerate(images):
+ i = 255. * image.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+ metadata = None
+ if not args.disable_metadata:
+ metadata = PngInfo()
+ if prompt is not None:
+ metadata.add_text("prompt", json.dumps(prompt))
+ if extra_pnginfo is not None:
+ for x in extra_pnginfo:
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+ base_file_name = f"{filename_with_batch_num}_{counter:05}_"
+ file = f"{base_file_name}.png"
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
+ results.append({
+ "filename": file,
+ "subfolder": subfolder,
+ "type": self.type
+ })
+ if caption is not None:
+ txt_file = base_file_name + caption_file_extension
+ file_path = os.path.join(full_output_folder, txt_file)
+ with open(file_path, 'w') as f:
+ f.write(caption)
+
+ counter += 1
+
+ return file,
+
+class SaveStringKJ:
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
+ self.compress_level = 4
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "string": ("STRING", {"forceInput": True, "tooltip": "string to save as .txt file"}),
+ "filename_prefix": ("STRING", {"default": "text", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
+                "output_folder": ("STRING", {"default": "output", "tooltip": "The folder to save the text file to."}),
+ },
+ "optional": {
+                "file_extension": ("STRING", {"default": ".txt", "tooltip": "The extension for the saved text file."}),
+ },
+ }
+
+ RETURN_TYPES = ("STRING",)
+ RETURN_NAMES = ("filename",)
+ FUNCTION = "save_string"
+
+ OUTPUT_NODE = True
+
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = "Saves the input string to your ComfyUI output directory."
+
+ def save_string(self, string, output_folder, filename_prefix="text", file_extension=".txt"):
+ filename_prefix += self.prefix_append
+
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+ if output_folder != "output":
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder, exist_ok=True)
+ full_output_folder = output_folder
+
+ base_file_name = f"{filename_prefix}_{counter:05}_"
+        txt_file = base_file_name + file_extension
+        file_path = os.path.join(full_output_folder, txt_file)
+        with open(file_path, 'w') as f:
+            f.write(string)
+
+        return (txt_file,)
+
+to_pil_image = T.ToPILImage()
+
+class FastPreview:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "image": ("IMAGE", ),
+ "format": (["JPEG", "PNG", "WEBP"], {"default": "JPEG"}),
+ "quality" : ("INT", {"default": 75, "min": 1, "max": 100, "step": 1}),
+ },
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "preview"
+ CATEGORY = "KJNodes/experimental"
+ OUTPUT_NODE = True
+    DESCRIPTION = "Experimental node for faster image previews: the image is sent to the UI as base64 without saving it to disk."
+
+ def preview(self, image, format, quality):
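+        # IMAGE tensors are (B, H, W, C); ToPILImage expects (C, H, W), hence the permute on the first image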
+ pil_image = to_pil_image(image[0].permute(2, 0, 1))
+
+ with io.BytesIO() as buffered:
+ pil_image.save(buffered, format=format, quality=quality)
+ img_bytes = buffered.getvalue()
+
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
+
+ return {
+ "ui": {"bg_image": [img_base64]},
+ "result": ()
+ }
+
+class ImageCropByMaskAndResize:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE", ),
+ "mask": ("MASK", ),
+ "base_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "padding": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "min_crop_resolution": ("INT", { "default": 128, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "max_crop_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK", "BBOX", )
+ RETURN_NAMES = ("images", "masks", "bbox",)
+ FUNCTION = "crop"
+ CATEGORY = "KJNodes/image"
+
+ def crop_by_mask(self, mask, padding=0, min_crop_resolution=None, max_crop_resolution=None):
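+        # Returns (x0, y0, w, h): a crop window around the mask's nonzero area, clamped to the min/max
+        # crop resolution, expanded by the padding and shifted to stay within the image where possible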
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
+ h0, w0 = mask.shape
+
+ if iy.numel() == 0:
+ x_c = w0 / 2.0
+ y_c = h0 / 2.0
+ width = 0
+ height = 0
+ else:
+ x_min = ix.min().item()
+ x_max = ix.max().item()
+ y_min = iy.min().item()
+ y_max = iy.max().item()
+
+ width = x_max - x_min
+ height = y_max - y_min
+
+ if width > w0 or height > h0:
+ raise Exception("Masked area out of bounds")
+
+ x_c = (x_min + x_max) / 2.0
+ y_c = (y_min + y_max) / 2.0
+
+ if min_crop_resolution:
+ width = max(width, min_crop_resolution)
+ height = max(height, min_crop_resolution)
+
+ if max_crop_resolution:
+ width = min(width, max_crop_resolution)
+ height = min(height, max_crop_resolution)
+
+ if w0 <= width:
+ x0 = 0
+ w = w0
+ else:
+ x0 = max(0, x_c - width / 2 - padding)
+ w = width + 2 * padding
+ if x0 + w > w0:
+ x0 = w0 - w
+
+ if h0 <= height:
+ y0 = 0
+ h = h0
+ else:
+ y0 = max(0, y_c - height / 2 - padding)
+ h = height + 2 * padding
+ if y0 + h > h0:
+ y0 = h0 - h
+
+ return (int(x0), int(y0), int(w), int(h))
+
+ def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512):
+ mask = mask.round()
+ image_list = []
+ mask_list = []
+ bbox_list = []
+
+ # First, collect all bounding boxes
+ bbox_params = []
+ aspect_ratios = []
+ for i in range(image.shape[0]):
+ x0, y0, w, h = self.crop_by_mask(mask[i], padding, min_crop_resolution, max_crop_resolution)
+ bbox_params.append((x0, y0, w, h))
+ aspect_ratios.append(w / h)
+
+ # Find maximum width and height
+ max_w = max([w for x0, y0, w, h in bbox_params])
+ max_h = max([h for x0, y0, w, h in bbox_params])
+ max_aspect_ratio = max(aspect_ratios)
+
+ # Ensure dimensions are divisible by 16
+ max_w = (max_w + 15) // 16 * 16
+ max_h = (max_h + 15) // 16 * 16
+ # Calculate common target dimensions
+ if max_aspect_ratio > 1:
+ target_width = base_resolution
+ target_height = int(base_resolution / max_aspect_ratio)
+ else:
+ target_height = base_resolution
+ target_width = int(base_resolution * max_aspect_ratio)
+
+ for i in range(image.shape[0]):
+ x0, y0, w, h = bbox_params[i]
+
+ # Adjust cropping to use maximum width and height
+ x_center = x0 + w / 2
+ y_center = y0 + h / 2
+
+ x0_new = int(max(0, x_center - max_w / 2))
+ y0_new = int(max(0, y_center - max_h / 2))
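+            # Clamp the window to the image bounds, then shift it back so the crop keeps the full max_w x max_h size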
+ x1_new = int(min(x0_new + max_w, image.shape[2]))
+ y1_new = int(min(y0_new + max_h, image.shape[1]))
+ x0_new = x1_new - max_w
+ y0_new = y1_new - max_h
+
+ cropped_image = image[i][y0_new:y1_new, x0_new:x1_new, :]
+ cropped_mask = mask[i][y0_new:y1_new, x0_new:x1_new]
+
+ # Ensure dimensions are divisible by 16
+ target_width = (target_width + 15) // 16 * 16
+ target_height = (target_height + 15) // 16 * 16
+
+ cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W)
+ cropped_image = common_upscale(cropped_image, target_width, target_height, "lanczos", "disabled")
+ cropped_image = cropped_image.movedim(1, -1).squeeze(0)
+
+ cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0)
+ cropped_mask = common_upscale(cropped_mask, target_width, target_height, 'bilinear', "disabled")
+ cropped_mask = cropped_mask.squeeze(0).squeeze(0)
+
+ image_list.append(cropped_image)
+ mask_list.append(cropped_mask)
+ bbox_list.append((x0_new, y0_new, x1_new, y1_new))
+
+
+ return (torch.stack(image_list), torch.stack(mask_list), bbox_list)
+
+class ImageCropByMask:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE", ),
+ "mask": ("MASK", ),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE", )
+ RETURN_NAMES = ("image", )
+ FUNCTION = "crop"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = "Crops the input images based on the provided mask."
+
+ def crop(self, image, mask):
+ B, H, W, C = image.shape
+ mask = mask.round()
+
+ # Find bounding box for each batch
+ crops = []
+
+ for b in range(B):
+ # Get coordinates of non-zero elements
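+            # If there are fewer masks than images, the last mask is reused for the remaining images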
+ rows = torch.any(mask[min(b, mask.shape[0]-1)] > 0, dim=1)
+ cols = torch.any(mask[min(b, mask.shape[0]-1)] > 0, dim=0)
+
+ # Find boundaries
+ y_min, y_max = torch.where(rows)[0][[0, -1]]
+ x_min, x_max = torch.where(cols)[0][[0, -1]]
+
+ # Crop image and mask
+ crop = image[b:b+1, y_min:y_max+1, x_min:x_max+1, :]
+ crops.append(crop)
+
+ # Stack results back together
+ cropped_images = torch.cat(crops, dim=0)
+
+ return (cropped_images, )
+
+
+
+class ImageUncropByMask:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "destination": ("IMAGE",),
+ "source": ("IMAGE",),
+ "mask": ("MASK",),
+ "bbox": ("BBOX",),
+ },
+ }
+
+ CATEGORY = "KJNodes/image"
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ FUNCTION = "uncrop"
+
+ def uncrop(self, destination, source, mask, bbox=None):
+
+ output_list = []
+
+ B, H, W, C = destination.shape
+
+ for i in range(source.shape[0]):
+ x0, y0, x1, y1 = bbox[i]
+ bbox_height = y1 - y0
+ bbox_width = x1 - x0
+
+ # Resize source image to match the bounding box dimensions
+ #resized_source = F.interpolate(source[i].unsqueeze(0).movedim(-1, 1), size=(bbox_height, bbox_width), mode='bilinear', align_corners=False)
+ resized_source = common_upscale(source[i].unsqueeze(0).movedim(-1, 1), bbox_width, bbox_height, "lanczos", "disabled")
+ resized_source = resized_source.movedim(1, -1).squeeze(0)
+
+ # Resize mask to match the bounding box dimensions
+ resized_mask = common_upscale(mask[i].unsqueeze(0).unsqueeze(0), bbox_width, bbox_height, "bilinear", "disabled")
+ resized_mask = resized_mask.squeeze(0).squeeze(0)
+
+ # Calculate padding values
+ pad_left = x0
+ pad_right = W - x1
+ pad_top = y0
+ pad_bottom = H - y1
+
+ # Pad the resized source image and mask to fit the destination dimensions
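+            # F.pad orders padding from the last dimension backwards: (0, 0) leaves the channel dim alone,
+            # (pad_left, pad_right) pads width and (pad_top, pad_bottom) pads height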
+ padded_source = F.pad(resized_source, pad=(0, 0, pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
+ padded_mask = F.pad(resized_mask, pad=(pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
+
+ # Ensure the padded mask has the correct shape
+ padded_mask = padded_mask.unsqueeze(2).expand(-1, -1, destination[i].shape[2])
+ # Ensure the padded source has the correct shape
+ padded_source = padded_source.unsqueeze(2).expand(-1, -1, -1, destination[i].shape[2]).squeeze(2)
+
+ # Combine the destination and padded source images using the mask
+ result = destination[i] * (1.0 - padded_mask) + padded_source * padded_mask
+
+ output_list.append(result)
+
+
+ return (torch.stack(output_list),)
+
+class ImageCropByMaskBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE", ),
+ "masks": ("MASK", ),
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
+ "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1, }),
+ "preserve_size": ("BOOLEAN", {"default": False}),
+ "bg_color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK", )
+ RETURN_NAMES = ("images", "masks",)
+ FUNCTION = "crop"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = "Crops the input images based on the provided masks."
+
+ def crop(self, image, masks, width, height, bg_color, padding, preserve_size):
+ B, H, W, C = image.shape
+ BM, HM, WM = masks.shape
+ mask_count = BM
+ if HM != H or WM != W:
+ masks = F.interpolate(masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
+ print(masks.shape)
+ output_images = []
+ output_masks = []
+
+ bg_color = [int(x.strip())/255.0 for x in bg_color.split(",")]
+
+ # For each mask
+ for i in range(mask_count):
+ curr_mask = masks[i]
+
+ # Find bounds
+ y_indices, x_indices = torch.nonzero(curr_mask, as_tuple=True)
+ if len(y_indices) == 0 or len(x_indices) == 0:
+ continue
+
+ # Get exact bounds with padding
+ min_y = max(0, y_indices.min().item() - padding)
+ max_y = min(H, y_indices.max().item() + 1 + padding)
+ min_x = max(0, x_indices.min().item() - padding)
+ max_x = min(W, x_indices.max().item() + 1 + padding)
+
+ # Ensure mask has correct shape for multiplication
+ curr_mask = curr_mask.unsqueeze(-1).expand(-1, -1, C)
+
+ # Crop image and mask together
+ cropped_img = image[0, min_y:max_y, min_x:max_x, :]
+ cropped_mask = curr_mask[min_y:max_y, min_x:max_x, :]
+
+ crop_h, crop_w = cropped_img.shape[0:2]
+ new_w = crop_w
+ new_h = crop_h
+
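+            # Rescale unless preserve_size is enabled and the crop already fits inside the target size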
+ if not preserve_size or crop_w > width or crop_h > height:
+ scale = min(width/crop_w, height/crop_h)
+ new_w = int(crop_w * scale)
+ new_h = int(crop_h * scale)
+
+ # Resize RGB
+ resized_img = common_upscale(cropped_img.permute(2,0,1).unsqueeze(0), new_w, new_h, "lanczos", "disabled").squeeze(0).permute(1,2,0)
+ resized_mask = torch.nn.functional.interpolate(
+ cropped_mask.permute(2,0,1).unsqueeze(0),
+ size=(new_h, new_w),
+ mode='nearest'
+ ).squeeze(0).permute(1,2,0)
+ else:
+ resized_img = cropped_img
+ resized_mask = cropped_mask
+
+ # Create empty tensors
+ new_img = torch.zeros((height, width, 3), dtype=image.dtype)
+ new_mask = torch.zeros((height, width), dtype=image.dtype)
+
+ # Pad both
+ pad_x = (width - new_w) // 2
+ pad_y = (height - new_h) // 2
+ new_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w, :] = resized_img
+ if len(resized_mask.shape) == 3:
+ resized_mask = resized_mask[:,:,0] # Take first channel if 3D
+ new_mask[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = resized_mask
+
+ output_images.append(new_img)
+ output_masks.append(new_mask)
+
+        if not output_images:
+            return (torch.zeros((0, height, width, 3), dtype=image.dtype),
+                    torch.zeros((0, height, width), dtype=image.dtype))
+
+ out_rgb = torch.stack(output_images, dim=0)
+ out_masks = torch.stack(output_masks, dim=0)
+
+ # Apply mask to RGB
+ mask_expanded = out_masks.unsqueeze(-1).expand(-1, -1, -1, 3)
+ background_color = torch.tensor(bg_color, dtype=torch.float32, device=image.device)
+ out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)
+
+ return (out_rgb, out_masks)
+
+class ImagePadKJ:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE", ),
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+ "pad_mode": (["edge", "color"],),
+ "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
+ }
+ , "optional": {
+        "mask": ("MASK", ),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "MASK", )
+ RETURN_NAMES = ("images", "masks",)
+ FUNCTION = "pad"
+ CATEGORY = "KJNodes/image"
+ DESCRIPTION = "Pad the input image and optionally mask with the specified padding."
+
+ def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None):
+ B, H, W, C = image.shape
+
+ # Resize masks to image dimensions if necessary
+ if mask is not None:
+ BM, HM, WM = mask.shape
+ if HM != H or WM != W:
+ mask = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
+
+ # Parse background color
+ bg_color = [int(x.strip())/255.0 for x in color.split(",")]
+ if len(bg_color) == 1:
+ bg_color = bg_color * 3 # Grayscale to RGB
+ bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device)
+
+ # Calculate padding sizes with extra padding
+ pad_left = left + extra_padding
+ pad_right = right + extra_padding
+ pad_top = top + extra_padding
+ pad_bottom = bottom + extra_padding
+
+ padded_width = W + pad_left + pad_right
+ padded_height = H + pad_top + pad_bottom
+ out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)
+
+ # Fill padded areas
+ for b in range(B):
+ if pad_mode == "edge":
+ # Pad with edge color
+ # Define edge pixels
+ top_edge = image[b, 0, :, :]
+ bottom_edge = image[b, H-1, :, :]
+ left_edge = image[b, :, 0, :]
+ right_edge = image[b, :, W-1, :]
+
+ # Fill borders with edge colors
+ out_image[b, :pad_top, :, :] = top_edge.mean(dim=0)
+ out_image[b, pad_top+H:, :, :] = bottom_edge.mean(dim=0)
+ out_image[b, :, :pad_left, :] = left_edge.mean(dim=0)
+ out_image[b, :, pad_left+W:, :] = right_edge.mean(dim=0)
+ out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
+ else:
+ # Pad with specified background color
+ out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0) # Expand for H and W dimensions
+ out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
+
+ if mask is not None:
+ out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device)
+ for m in range(BM):
+ out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = mask[m]
+ else:
+ out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device)
+
+ return (out_image, out_masks)
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/intrinsic_lora_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/intrinsic_lora_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f125363836cc7721b4b61d100702594522d389
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/intrinsic_lora_nodes.py
@@ -0,0 +1,115 @@
+import folder_paths
+import os
+import torch
+import torch.nn.functional as F
+from comfy.utils import ProgressBar, load_torch_file
+import comfy.sample
+from nodes import CLIPTextEncode
+
+script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+folder_paths.add_model_folder_path("intrinsic_loras", os.path.join(script_directory, "intrinsic_loras"))
+
+class Intrinsic_lora_sampling:
+ def __init__(self):
+ self.loaded_lora = None
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "lora_name": (folder_paths.get_filename_list("intrinsic_loras"), ),
+ "task": (
+ [
+ 'depth map',
+ 'surface normals',
+ 'albedo',
+ 'shading',
+ ],
+ {
+ "default": 'depth map'
+ }),
+ "text": ("STRING", {"multiline": True, "default": ""}),
+ "clip": ("CLIP", ),
+ "vae": ("VAE", ),
+ "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
+ },
+ "optional": {
+ "image": ("IMAGE",),
+ "optional_latent": ("LATENT",),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE", "LATENT",)
+ FUNCTION = "onestepsample"
+ CATEGORY = "KJNodes"
+ DESCRIPTION = """
+Sampler to use the intrinsic loras:
+https://github.com/duxiaodan/intrinsic-lora
+These LoRAs are tiny and thus included
+with this node pack.
+"""
+
+ def onestepsample(self, model, lora_name, clip, vae, text, task, per_batch, image=None, optional_latent=None):
+ pbar = ProgressBar(3)
+
+ if optional_latent is None:
+ image_list = []
+ for start_idx in range(0, image.shape[0], per_batch):
+ sub_pixels = vae.vae_encode_crop_pixels(image[start_idx:start_idx+per_batch])
+ image_list.append(vae.encode(sub_pixels[:,:,:,:3]))
+ sample = torch.cat(image_list, dim=0)
+ else:
+ sample = optional_latent["samples"]
+ noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu")
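+        # Zero noise for a single deterministic step; the X0_PassThrough sampling below returns the model prediction directly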
+ prompt = task + "," + text
+ positive, = CLIPTextEncode.encode(self, clip, prompt)
+ negative = positive #negative shouldn't do anything in this scenario
+
+ pbar.update(1)
+
+ #custom model sampling to pass latent through as it is
+ class X0_PassThrough(comfy.model_sampling.EPS):
+ def calculate_denoised(self, sigma, model_output, model_input):
+ return model_output
+ def calculate_input(self, sigma, noise):
+ return noise
+ sampling_base = comfy.model_sampling.ModelSamplingDiscrete
+ sampling_type = X0_PassThrough
+
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
+ pass
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
+
+ #load lora
+ model_clone = model.clone()
+ lora_path = folder_paths.get_full_path("intrinsic_loras", lora_name)
+ lora = load_torch_file(lora_path, safe_load=True)
+ self.loaded_lora = (lora_path, lora)
+
+ model_clone_with_lora = comfy.sd.load_lora_for_models(model_clone, None, lora, 1.0, 0)[0]
+
+ model_clone_with_lora.add_object_patch("model_sampling", model_sampling)
+
+ samples = {"samples": comfy.sample.sample(model_clone_with_lora, noise, 1, 1.0, "euler", "simple", positive, negative, sample,
+ denoise=1.0, disable_noise=True, start_step=0, last_step=1,
+ force_full_denoise=True, noise_mask=None, callback=None, disable_pbar=True, seed=None)}
+ pbar.update(1)
+
+ decoded = []
+ for start_idx in range(0, samples["samples"].shape[0], per_batch):
+ decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
+ image_out = torch.cat(decoded, dim=0)
+
+ pbar.update(1)
+
+ if task == 'depth map':
+ imax = image_out.max()
+ imin = image_out.min()
+ image_out = (image_out-imin)/(imax-imin)
+ image_out = torch.max(image_out, dim=3, keepdim=True)[0].repeat(1, 1, 1, 3)
+ elif task == 'surface normals':
+ image_out = F.normalize(image_out * 2 - 1, dim=3) / 2 + 0.5
+ image_out = 1.0 - image_out
+ else:
+ image_out = image_out.clamp(-1.,1.)
+
+ return (image_out, samples,)
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/mask_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/mask_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8852d0662d0cd5ca2c4be6add22fe77e65ee7442
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/mask_nodes.py
@@ -0,0 +1,1397 @@
+import torch
+import torch.nn.functional as F
+from torchvision.transforms import functional as TF
+from PIL import Image, ImageDraw, ImageFilter, ImageFont
+import scipy.ndimage
+import numpy as np
+from contextlib import nullcontext
+import os
+
+import model_management
+from comfy.utils import ProgressBar
+from comfy.utils import common_upscale
+from nodes import MAX_RESOLUTION
+
+import folder_paths
+
+from ..utility.utility import tensor2pil, pil2tensor
+
+script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+class BatchCLIPSeg:
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(s):
+
+ return {"required":
+ {
+ "images": ("IMAGE",),
+ "text": ("STRING", {"multiline": False}),
+ "threshold": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 10.0, "step": 0.001}),
+ "binary_mask": ("BOOLEAN", {"default": True}),
+ "combine_mask": ("BOOLEAN", {"default": False}),
+ "use_cuda": ("BOOLEAN", {"default": True}),
+ },
+ "optional":
+ {
+ "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+ "opt_model": ("CLIPSEGMODEL", ),
+ "prev_mask": ("MASK", {"default": None}),
+ "image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+ "invert": ("BOOLEAN", {"default": False}),
+ }
+ }
+
+ CATEGORY = "KJNodes/masking"
+ RETURN_TYPES = ("MASK", "IMAGE", )
+ RETURN_NAMES = ("Mask", "Image", )
+ FUNCTION = "segment_image"
+ DESCRIPTION = """
+Segments an image or batch of images using CLIPSeg.
+"""
+
+ def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert= False, image_bg_level=0.5):
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ import torchvision.transforms as transforms
+ offload_device = model_management.unet_offload_device()
+ device = model_management.get_torch_device()
+ if not use_cuda:
+ device = torch.device("cpu")
+ dtype = model_management.unet_dtype()
+
+ if opt_model is None:
+ checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', 'clipseg-rd64-refined-fp16')
+ if not hasattr(self, "model"):
+ try:
+ if not os.path.exists(checkpoint_path):
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
+ self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+ except:
+ checkpoint_path = "CIDAS/clipseg-rd64-refined"
+ self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+ processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+ else:
+ self.model = opt_model['model']
+ processor = opt_model['processor']
+
+ self.model.to(dtype).to(device)
+
+ B, H, W, C = images.shape
+ images = images.to(device)
+
+ autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
+ with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
+
+ PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
+ prompt = [text] * len(images)
+ input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")
+
+ for key in input_prc:
+ input_prc[key] = input_prc[key].to(device)
+ outputs = self.model(**input_prc)
+
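+            # Sigmoid the logits, normalize them to the 0..1 range, then zero out values below the threshold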
+ mask_tensor = torch.sigmoid(outputs.logits)
+ mask_tensor = (mask_tensor - mask_tensor.min()) / (mask_tensor.max() - mask_tensor.min())
+ mask_tensor = torch.where(mask_tensor > (threshold), mask_tensor, torch.tensor(0, dtype=torch.float))
+ print(mask_tensor.shape)
+ if len(mask_tensor.shape) == 2:
+ mask_tensor = mask_tensor.unsqueeze(0)
+ mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest')
+ mask_tensor = mask_tensor.squeeze(1)
+
+ self.model.to(offload_device)
+
+ if binary_mask:
+ mask_tensor = (mask_tensor > 0).float()
+ if blur_sigma > 0:
+ kernel_size = int(6 * int(blur_sigma) + 1)
+ blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
+ mask_tensor = blur(mask_tensor)
+
+ if combine_mask:
+ mask_tensor = torch.max(mask_tensor, dim=0)[0]
+ mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images),1,1)
+
+ del outputs
+ model_management.soft_empty_cache()
+
+ if prev_mask is not None:
+ if prev_mask.shape != mask_tensor.shape:
+ prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
+ mask_tensor = mask_tensor + prev_mask.to(device)
+            mask_tensor = torch.clamp(mask_tensor, min=0.0, max=1.0)
+
+ if invert:
+ mask_tensor = 1 - mask_tensor
+
+ image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
+ image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()
+
+ mask_tensor = mask_tensor.cpu().float()
+
+ return mask_tensor, image_tensor,
+
+class DownloadAndLoadCLIPSeg:
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(s):
+
+ return {"required":
+ {
+ "model": (
+ [ 'Kijai/clipseg-rd64-refined-fp16',
+ 'CIDAS/clipseg-rd64-refined',
+ ],
+ ),
+ },
+ }
+
+ CATEGORY = "KJNodes/masking"
+ RETURN_TYPES = ("CLIPSEGMODEL",)
+ RETURN_NAMES = ("clipseg_model",)
+ FUNCTION = "segment_image"
+ DESCRIPTION = """
+Downloads and loads CLIPSeg model with huggingface_hub,
+to ComfyUI/models/clip_seg
+"""
+
+ def segment_image(self, model):
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', os.path.basename(model))
+ if not hasattr(self, "model"):
+ if not os.path.exists(checkpoint_path):
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
+ self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+
+ processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+ clipseg_model = {}
+ clipseg_model['model'] = self.model
+ clipseg_model['processor'] = processor
+
+ return clipseg_model,
+
+class CreateTextMask:
+
+ RETURN_TYPES = ("IMAGE", "MASK",)
+ FUNCTION = "createtextmask"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+Creates a text image and mask.
+Looks for fonts from this folder:
+ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
+
+If start_rotation and/or end_rotation are different values,
+creates animation between them.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
+ "text_x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "text_y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ "font_size": ("INT", {"default": 32,"min": 8, "max": 4096, "step": 1}),
+ "font_color": ("STRING", {"default": "white"}),
+ "text": ("STRING", {"default": "HELLO!", "multiline": True}),
+ "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
+ "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "start_rotation": ("INT", {"default": 0,"min": 0, "max": 359, "step": 1}),
+ "end_rotation": ("INT", {"default": 0,"min": -359, "max": 359, "step": 1}),
+ },
+ }
+
+ def createtextmask(self, frames, width, height, invert, text_x, text_y, text, font_size, font_color, font, start_rotation, end_rotation):
+ # Define the number of images in the batch
+ batch_size = frames
+ out = []
+ masks = []
+ rotation = start_rotation
+        if start_rotation != end_rotation:
+            rotation_increment = (end_rotation - start_rotation) / max(batch_size - 1, 1)
+
+ font_path = folder_paths.get_full_path("kjnodes_fonts", font)
+ # Generate the text
+ for i in range(batch_size):
+ image = Image.new("RGB", (width, height), "black")
+ draw = ImageDraw.Draw(image)
+ font = ImageFont.truetype(font_path, font_size)
+
+ # Split the text into words
+ words = text.split()
+
+ # Initialize variables for line creation
+ lines = []
+ current_line = []
+ current_line_width = 0
+ try: #new pillow
+ # Iterate through words to create lines
+ for word in words:
+ word_width = font.getbbox(word)[2]
+ if current_line_width + word_width <= width - 2 * text_x:
+ current_line.append(word)
+ current_line_width += word_width + font.getbbox(" ")[2] # Add space width
+ else:
+ lines.append(" ".join(current_line))
+ current_line = [word]
+ current_line_width = word_width
+ except: #old pillow
+ for word in words:
+ word_width = font.getsize(word)[0]
+ if current_line_width + word_width <= width - 2 * text_x:
+ current_line.append(word)
+ current_line_width += word_width + font.getsize(" ")[0] # Add space width
+ else:
+ lines.append(" ".join(current_line))
+ current_line = [word]
+ current_line_width = word_width
+
+ # Add the last line if it's not empty
+ if current_line:
+ lines.append(" ".join(current_line))
+
+ # Draw each line of text separately
+ y_offset = text_y
+ for line in lines:
+ text_width = font.getlength(line)
+ text_height = font_size
+ text_center_x = text_x + text_width / 2
+ text_center_y = y_offset + text_height / 2
+ try:
+ draw.text((text_x, y_offset), line, font=font, fill=font_color, features=['-liga'])
+ except:
+ draw.text((text_x, y_offset), line, font=font, fill=font_color)
+ y_offset += text_height # Move to the next line
+
+ if start_rotation != end_rotation:
+ image = image.rotate(rotation, center=(text_center_x, text_center_y))
+ rotation += rotation_increment
+
+ image = np.array(image).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+ mask = image[:, :, :, 0]
+ masks.append(mask)
+ out.append(image)
+
+ if invert:
+ return (1.0 - torch.cat(out, dim=0), 1.0 - torch.cat(masks, dim=0),)
+ return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
+
+class ColorToMask:
+
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "clip"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Converts chosen RGB value to a mask.
+With batch inputs, **per_batch**
+controls the number of images processed at once.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "images": ("IMAGE",),
+ "invert": ("BOOLEAN", {"default": False}),
+ "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
+ "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
+ },
+ }
+
+ def clip(self, images, red, green, blue, threshold, invert, per_batch):
+
+ color = torch.tensor([red, green, blue], dtype=torch.uint8)
+ black = torch.tensor([0, 0, 0], dtype=torch.uint8)
+ white = torch.tensor([255, 255, 255], dtype=torch.uint8)
+
+ if invert:
+ black, white = white, black
+
+ steps = images.shape[0]
+ pbar = ProgressBar(steps)
+ tensors_out = []
+
+ for start_idx in range(0, images.shape[0], per_batch):
+
+ # Calculate color distances
+ color_distances = torch.norm(images[start_idx:start_idx+per_batch] * 255 - color, dim=-1)
+
+ # Create a mask based on the threshold
+ mask = color_distances <= threshold
+
+ # Apply the mask to create new images
+ mask_out = torch.where(mask.unsqueeze(-1), white, black).float()
+ mask_out = mask_out.mean(dim=-1)
+
+ tensors_out.append(mask_out.cpu())
+ batch_count = mask_out.shape[0]
+ pbar.update(batch_count)
+
+ tensors_out = torch.cat(tensors_out, dim=0)
+ tensors_out = torch.clamp(tensors_out, min=0.0, max=1.0)
+ return tensors_out,
+
+class CreateFluidMask:
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ FUNCTION = "createfluidmask"
+ CATEGORY = "KJNodes/masking/generate"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
+ "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "inflow_count": ("INT", {"default": 3,"min": 0, "max": 255, "step": 1}),
+ "inflow_velocity": ("INT", {"default": 1,"min": 0, "max": 255, "step": 1}),
+ "inflow_radius": ("INT", {"default": 8,"min": 0, "max": 255, "step": 1}),
+ "inflow_padding": ("INT", {"default": 50,"min": 0, "max": 255, "step": 1}),
+ "inflow_duration": ("INT", {"default": 60,"min": 0, "max": 255, "step": 1}),
+ },
+ }
+ #using code from https://github.com/GregTJ/stable-fluids
+ def createfluidmask(self, frames, width, height, invert, inflow_count, inflow_velocity, inflow_radius, inflow_padding, inflow_duration):
+ from ..utility.fluid import Fluid
+        try:
+            from scipy.special import erf
+        except ImportError:
+            raise Exception("Cannot import scipy.special.erf. Install scipy with 'pip install scipy'")
+ out = []
+ masks = []
+ RESOLUTION = width, height
+ DURATION = frames
+
+ INFLOW_PADDING = inflow_padding
+ INFLOW_DURATION = inflow_duration
+ INFLOW_RADIUS = inflow_radius
+ INFLOW_VELOCITY = inflow_velocity
+ INFLOW_COUNT = inflow_count
+
+ print('Generating fluid solver, this may take some time.')
+ fluid = Fluid(RESOLUTION, 'dye')
+
+ center = np.floor_divide(RESOLUTION, 2)
+ r = np.min(center) - INFLOW_PADDING
+
+ points = np.linspace(-np.pi, np.pi, INFLOW_COUNT, endpoint=False)
+ points = tuple(np.array((np.cos(p), np.sin(p))) for p in points)
+ normals = tuple(-p for p in points)
+ points = tuple(r * p + center for p in points)
+
+ inflow_velocity = np.zeros_like(fluid.velocity)
+ inflow_dye = np.zeros(fluid.shape)
+ for p, n in zip(points, normals):
+ mask = np.linalg.norm(fluid.indices - p[:, None, None], axis=0) <= INFLOW_RADIUS
+ inflow_velocity[:, mask] += n[:, None] * INFLOW_VELOCITY
+ inflow_dye[mask] = 1
+
+
+ for f in range(DURATION):
+ print(f'Computing frame {f + 1} of {DURATION}.')
+ if f <= INFLOW_DURATION:
+ fluid.velocity += inflow_velocity
+ fluid.dye += inflow_dye
+
+ curl = fluid.step()[1]
+ # Using the error function to make the contrast a bit higher.
+ # Any other sigmoid function e.g. smoothstep would work.
+ curl = (erf(curl * 2) + 1) / 4
+
+ color = np.dstack((curl, np.ones(fluid.shape), fluid.dye))
+ color = (np.clip(color, 0, 1) * 255).astype('uint8')
+ image = np.array(color).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+ mask = image[:, :, :, 0]
+ masks.append(mask)
+ out.append(image)
+
+ if invert:
+ return (1.0 - torch.cat(out, dim=0),1.0 - torch.cat(masks, dim=0),)
+ return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
+
+class CreateAudioMask:
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "createaudiomask"
+ CATEGORY = "KJNodes/deprecated"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 16,"min": 1, "max": 255, "step": 1}),
+ "scale": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 2.0, "step": 0.01}),
+ "audio_path": ("STRING", {"default": "audio.wav"}),
+ "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ },
+ }
+
+ def createaudiomask(self, frames, width, height, invert, audio_path, scale):
+ try:
+ import librosa
+ except ImportError:
+            raise Exception("Cannot import librosa. Install it with 'pip install librosa'")
+ batch_size = frames
+ out = []
+ masks = []
+ if audio_path == "audio.wav": #I don't know why relative path won't work otherwise...
+ audio_path = os.path.join(script_directory, audio_path)
+ audio, sr = librosa.load(audio_path)
+ spectrogram = np.abs(librosa.stft(audio))
+
+ for i in range(batch_size):
+ image = Image.new("RGB", (width, height), "black")
+ draw = ImageDraw.Draw(image)
+ frame = spectrogram[:, i]
+ circle_radius = int(height * np.mean(frame))
+ circle_radius *= scale
+ circle_center = (width // 2, height // 2) # Calculate the center of the image
+
+ draw.ellipse([(circle_center[0] - circle_radius, circle_center[1] - circle_radius),
+ (circle_center[0] + circle_radius, circle_center[1] + circle_radius)],
+ fill='white')
+
+ image = np.array(image).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+ mask = image[:, :, :, 0]
+ masks.append(mask)
+ out.append(image)
+
+ if invert:
+ return (1.0 - torch.cat(out, dim=0),)
+        return (torch.cat(out, dim=0),)
+
+class CreateGradientMask:
+
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "createmask"
+ CATEGORY = "KJNodes/masking/generate"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+ "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ },
+ }
+ def createmask(self, frames, width, height, invert):
+ # Define the number of images in the batch
+ batch_size = frames
+ out = []
+ # Create an empty array to store the image batch
+ image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
+ # Generate the black to white gradient for each image
+ for i in range(batch_size):
+ gradient = np.linspace(1.0, 0.0, width, dtype=np.float32)
+ time = i / frames # Calculate the time variable
+ offset_gradient = gradient - time # Offset the gradient values based on time
+ image_batch[i] = offset_gradient.reshape(1, -1)
+ output = torch.from_numpy(image_batch)
+ mask = output
+ out.append(mask)
+ if invert:
+ return (1.0 - torch.cat(out, dim=0),)
+ return (torch.cat(out, dim=0),)
+
+class CreateFadeMask:
+
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "createfademask"
+ CATEGORY = "KJNodes/deprecated"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 2,"min": 2, "max": 10000, "step": 1}),
+ "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
+ "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
+ "midpoint_level": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
+ "end_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
+ "midpoint_frame": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
+ },
+ }
+
+ def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level, midpoint_frame):
+ def ease_in(t):
+ return t * t
+
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+
+ batch_size = frames
+ out = []
+ image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
+
+ if midpoint_frame == 0:
+ midpoint_frame = batch_size // 2
+
+ for i in range(batch_size):
+ if i <= midpoint_frame:
+ t = i / midpoint_frame
+ if interpolation == "ease_in":
+ t = ease_in(t)
+ elif interpolation == "ease_out":
+ t = ease_out(t)
+ elif interpolation == "ease_in_out":
+ t = ease_in_out(t)
+ color = start_level - t * (start_level - midpoint_level)
+ else:
+ t = (i - midpoint_frame) / (batch_size - midpoint_frame)
+ if interpolation == "ease_in":
+ t = ease_in(t)
+ elif interpolation == "ease_out":
+ t = ease_out(t)
+ elif interpolation == "ease_in_out":
+ t = ease_in_out(t)
+ color = midpoint_level - t * (midpoint_level - end_level)
+
+ color = np.clip(color, 0, 255)
+ image = np.full((height, width), color, dtype=np.float32)
+ image_batch[i] = image
+
+ output = torch.from_numpy(image_batch)
+ mask = output
+ out.append(mask)
+
+ if invert:
+ return (1.0 - torch.cat(out, dim=0),)
+ return (torch.cat(out, dim=0),)
+
+class CreateFadeMaskAdvanced:
+
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "createfademask"
+ CATEGORY = "KJNodes/masking/generate"
+ DESCRIPTION = """
+Create a batch of masks interpolated between given frames and values.
+Uses same syntax as Fizz' BatchValueSchedule.
+First value is the frame index (note that this starts from 0, not 1)
+and the second value inside the brackets is the float value of the mask in range 0.0 - 1.0
+
+For example the default values:
+0:(0.0)
+7:(1.0)
+15:(0.0)
+
+Would create a mask batch of 16 frames, starting from black,
+interpolating with the chosen curve to fully white at the 8th frame,
+and interpolating from that to fully black at the 16th frame.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
+ "invert": ("BOOLEAN", {"default": False}),
+ "frames": ("INT", {"default": 16,"min": 2, "max": 10000, "step": 1}),
+ "width": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
+ },
+ }
+
+ def createfademask(self, frames, width, height, invert, points_string, interpolation):
+ def ease_in(t):
+ return t * t
+
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+
+ # Parse the input string into a list of tuples
+ points = []
+ points_string = points_string.rstrip(',\n')
+ for point_str in points_string.split(','):
+ frame_str, color_str = point_str.split(':')
+ frame = int(frame_str.strip())
+ color = float(color_str.strip()[1:-1]) # Remove parentheses around color
+ points.append((frame, color))
+
+ # Check if the last frame is already in the points
+ if len(points) == 0 or points[-1][0] != frames - 1:
+ # If not, add it with the color of the last specified frame
+ points.append((frames - 1, points[-1][1] if points else 0))
+
+ # Sort the points by frame number
+ points.sort(key=lambda x: x[0])
+
+ batch_size = frames
+ out = []
+ image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
+
+ # Index of the next point to interpolate towards
+ next_point = 1
+
+ for i in range(batch_size):
+ while next_point < len(points) and i > points[next_point][0]:
+ next_point += 1
+
+ # Interpolate between the previous point and the next point
+ prev_point = next_point - 1
+ t = (i - points[prev_point][0]) / (points[next_point][0] - points[prev_point][0])
+ if interpolation == "ease_in":
+ t = ease_in(t)
+ elif interpolation == "ease_out":
+ t = ease_out(t)
+ elif interpolation == "ease_in_out":
+ t = ease_in_out(t)
+ elif interpolation == "linear":
+ pass # No need to modify `t` for linear interpolation
+
+ color = points[prev_point][1] - t * (points[prev_point][1] - points[next_point][1])
+ color = np.clip(color, 0, 255)
+ image = np.full((height, width), color, dtype=np.float32)
+ image_batch[i] = image
+
+ output = torch.from_numpy(image_batch)
+ mask = output
+ out.append(mask)
+
+ if invert:
+ return (1.0 - torch.cat(out, dim=0),)
+ return (torch.cat(out, dim=0),)
+
+class CreateMagicMask:
+
+ RETURN_TYPES = ("MASK", "MASK",)
+ RETURN_NAMES = ("mask", "mask_inverted",)
+ FUNCTION = "createmagicmask"
+ CATEGORY = "KJNodes/masking/generate"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
+ "depth": ("INT", {"default": 12,"min": 1, "max": 500, "step": 1}),
+ "distortion": ("FLOAT", {"default": 1.5,"min": 0.0, "max": 100.0, "step": 0.01}),
+ "seed": ("INT", {"default": 123,"min": 0, "max": 99999999, "step": 1}),
+ "transitions": ("INT", {"default": 1,"min": 1, "max": 20, "step": 1}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ },
+ }
+
+ def createmagicmask(self, frames, transitions, depth, distortion, seed, frame_width, frame_height):
+ from ..utility.magictex import coordinate_grid, random_transform, magic
+ import matplotlib.pyplot as plt
+ rng = np.random.default_rng(seed)
+ out = []
+ coords = coordinate_grid((frame_width, frame_height))
+
+ # Calculate the number of frames for each transition
+ frames_per_transition = frames // transitions
+
+ # Generate a base set of parameters
+ base_params = {
+ "coords": random_transform(coords, rng),
+ "depth": depth,
+ "distortion": distortion,
+ }
+ for t in range(transitions):
+ # Generate a second set of parameters that is at most max_diff away from the base parameters
+ params1 = base_params.copy()
+ params2 = base_params.copy()
+
+ params1['coords'] = random_transform(coords, rng)
+ params2['coords'] = random_transform(coords, rng)
+
+ for i in range(frames_per_transition):
+ # Compute the interpolation factor
+ alpha = i / frames_per_transition
+
+ # Interpolate between the two sets of parameters
+ params = params1.copy()
+ params['coords'] = (1 - alpha) * params1['coords'] + alpha * params2['coords']
+
+ tex = magic(**params)
+
+ dpi = frame_width / 10
+ fig = plt.figure(figsize=(10, 10), dpi=dpi)
+
+ ax = fig.add_subplot(111)
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+
+ ax.get_yaxis().set_ticks([])
+ ax.get_xaxis().set_ticks([])
+ ax.imshow(tex, aspect='auto')
+
+ fig.canvas.draw()
+ img = np.array(fig.canvas.renderer._renderer)
+
+ plt.close(fig)
+
+ pil_img = Image.fromarray(img).convert("L")
+ mask = torch.tensor(np.array(pil_img)) / 255.0
+
+ out.append(mask)
+
+ return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
+
+class CreateShapeMask:
+
+ RETURN_TYPES = ("MASK", "MASK",)
+ RETURN_NAMES = ("mask", "mask_inverted",)
+ FUNCTION = "createshapemask"
+ CATEGORY = "KJNodes/masking/generate"
+ DESCRIPTION = """
+Creates a mask or batch of masks with the specified shape.
+Locations are center locations.
+Grow value is the amount to grow the shape on each frame, creating animated masks.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "shape": (
+ [ 'circle',
+ 'square',
+ 'triangle',
+ ],
+ {
+ "default": 'circle'
+ }),
+ "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
+ "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+ "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+ "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+ },
+ }
+
+ def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, shape_width, shape_height, grow, shape):
+ # Define the number of images in the batch
+ batch_size = frames
+ out = []
+ color = "white"
+ for i in range(batch_size):
+ image = Image.new("RGB", (frame_width, frame_height), "black")
+ draw = ImageDraw.Draw(image)
+
+ # Calculate the size for this frame and ensure it's not less than 0
+ current_width = max(0, shape_width + i*grow)
+ current_height = max(0, shape_height + i*grow)
+
+ if shape == 'circle' or shape == 'square':
+ # Define the bounding box for the shape
+ left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
+ two_points = [left_up_point, right_down_point]
+
+ if shape == 'circle':
+ draw.ellipse(two_points, fill=color)
+ elif shape == 'square':
+ draw.rectangle(two_points, fill=color)
+
+ elif shape == 'triangle':
+ # Define the points for the triangle
+ left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
+ right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
+ top_point = (location_x, location_y - current_height // 2) # top point
+ draw.polygon([top_point, left_up_point, right_down_point], fill=color)
+
+ image = pil2tensor(image)
+ mask = image[:, :, :, 0]
+ out.append(mask)
+ outstack = torch.cat(out, dim=0)
+ return (outstack, 1.0 - outstack,)
+
+class CreateVoronoiMask:
+
+ RETURN_TYPES = ("MASK", "MASK",)
+ RETURN_NAMES = ("mask", "mask_inverted",)
+ FUNCTION = "createvoronoi"
+ CATEGORY = "KJNodes/masking/generate"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
+ "num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
+ "line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
+ "speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
+ "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ },
+ }
+
+ def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
+        from scipy.spatial import Voronoi
+        import matplotlib.pyplot as plt
+ # Define the number of images in the batch
+ batch_size = frames
+ out = []
+
+ # Calculate aspect ratio
+ aspect_ratio = frame_width / frame_height
+
+ # Create start and end points for each point, considering the aspect ratio
+ start_points = np.random.rand(num_points, 2)
+ start_points[:, 0] *= aspect_ratio
+
+ end_points = np.random.rand(num_points, 2)
+ end_points[:, 0] *= aspect_ratio
+
+ for i in range(batch_size):
+ # Interpolate the points' positions based on the current frame
+ t = (i * speed) / (batch_size - 1) # normalize to [0, 1] over the frames
+ t = np.clip(t, 0, 1) # ensure t is in [0, 1]
+ points = (1 - t) * start_points + t * end_points # lerp
+
+ # start_points and end_points are already scaled for the aspect ratio above, so no extra scaling is needed here
+
+ vor = Voronoi(points)
+
+ # Create a blank image with a white background
+ fig, ax = plt.subplots()
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+ ax.set_xlim([0, aspect_ratio]); ax.set_ylim([0, 1]) # adjust x limits
+ ax.axis('off')
+ ax.margins(0, 0)
+ fig.set_size_inches(aspect_ratio * frame_height/100, frame_height/100) # adjust figure size
+ ax.fill_between([0, 1], [0, 1], color='white')
+
+ # Plot each Voronoi ridge
+ for simplex in vor.ridge_vertices:
+ simplex = np.asarray(simplex)
+ if np.all(simplex >= 0):
+ plt.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)
+
+ fig.canvas.draw()
+ img = np.array(fig.canvas.renderer._renderer)
+
+ plt.close(fig)
+
+ pil_img = Image.fromarray(img).convert("L")
+ mask = torch.tensor(np.array(pil_img)) / 255.0
+
+ out.append(mask)
+
+ return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
+
+class GetMaskSizeAndCount:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "mask": ("MASK",),
+ }}
+
+ RETURN_TYPES = ("MASK","INT", "INT", "INT",)
+ RETURN_NAMES = ("mask", "width", "height", "count",)
+ FUNCTION = "getsize"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Returns the width, height and batch size of the mask,
+and passes it through unchanged.
+
+"""
+
+ def getsize(self, mask):
+ width = mask.shape[2]
+ height = mask.shape[1]
+ count = mask.shape[0]
+ return {"ui": {
+ "text": [f"{count}x{width}x{height}"]},
+ "result": (mask, width, height, count)
+ }
+
+class GrowMaskWithBlur:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask": ("MASK",),
+ "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
+ "incremental_expandrate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+ "tapered_corners": ("BOOLEAN", {"default": True}),
+ "flip_input": ("BOOLEAN", {"default": False}),
+ "blur_radius": ("FLOAT", {
+ "default": 0.0,
+ "min": 0.0,
+ "max": 100,
+ "step": 0.1
+ }),
+ "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+ "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+ },
+ "optional": {
+ "fill_holes": ("BOOLEAN", {"default": False}),
+ },
+ }
+
+ CATEGORY = "KJNodes/masking"
+ RETURN_TYPES = ("MASK", "MASK",)
+ RETURN_NAMES = ("mask", "mask_inverted",)
+ FUNCTION = "expand_mask"
+ DESCRIPTION = """
+# GrowMaskWithBlur
+- mask: Input mask or mask batch
+- expand: Expand or contract mask or mask batch by a given amount
+- incremental_expandrate: Increase the expand rate by a given amount per frame
+- tapered_corners: Use tapered corners
+- flip_input: Flip the input mask
+- blur_radius: Values higher than 0 will blur the mask
+- lerp_alpha: Alpha value for interpolation between frames
+- decay_factor: Decay value for interpolation between frames
+- fill_holes: Fill holes in the mask (slow)"""
+
+ def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
+ alpha = lerp_alpha
+ decay = decay_factor
+ if flip_input:
+ mask = 1.0 - mask
+ c = 0 if tapered_corners else 1
+ kernel = np.array([[c, 1, c],
+ [1, 1, 1],
+ [c, 1, c]])
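+ # when tapered_corners is enabled the corner weights are 0, giving a plus-shaped footprint for the grey dilation/erosion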
+ growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
+ out = []
+ previous_output = None
+ current_expand = expand
+ for m in growmask:
+ output = m.numpy().astype(np.float32)
+ for _ in range(abs(round(current_expand))):
+ if current_expand < 0:
+ output = scipy.ndimage.grey_erosion(output, footprint=kernel)
+ else:
+ output = scipy.ndimage.grey_dilation(output, footprint=kernel)
+ if current_expand < 0:
+ current_expand -= abs(incremental_expandrate)
+ else:
+ current_expand += abs(incremental_expandrate)
+ if fill_holes:
+ binary_mask = output > 0
+ output = scipy.ndimage.binary_fill_holes(binary_mask)
+ output = output.astype(np.float32) * 255
+ output = torch.from_numpy(output)
+ if alpha < 1.0 and previous_output is not None:
+ # Interpolate between the previous and current frame
+ output = alpha * output + (1 - alpha) * previous_output
+ if decay < 1.0 and previous_output is not None:
+ # Add the decayed previous output to the current frame
+ output += decay * previous_output
+ output = output / output.max()
+ previous_output = output
+ out.append(output)
+
+ if blur_radius != 0:
+ # Convert the tensor list to PIL images, apply blur, and convert back
+ for idx, tensor in enumerate(out):
+ # Convert tensor to PIL image
+ pil_image = tensor2pil(tensor.cpu().detach())[0]
+ # Apply Gaussian blur
+ pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
+ # Convert back to tensor
+ out[idx] = pil2tensor(pil_image)
+ blurred = torch.cat(out, dim=0)
+ return (blurred, 1.0 - blurred)
+ else:
+ return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
+
+class MaskBatchMulti:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "mask_1": ("MASK", ),
+ "mask_2": ("MASK", ),
+ },
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("masks",)
+ FUNCTION = "combine"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Creates a mask batch from multiple masks.
+You can set how many inputs the node has
+with the **inputcount** value and by clicking update.
+"""
+
+ def combine(self, inputcount, **kwargs):
+ mask = kwargs["mask_1"]
+ for c in range(1, inputcount):
+ new_mask = kwargs[f"mask_{c + 1}"]
+ if mask.shape[1:] != new_mask.shape[1:]:
+ new_mask = F.interpolate(new_mask.unsqueeze(1), size=(mask.shape[1], mask.shape[2]), mode="bicubic").squeeze(1)
+ mask = torch.cat((mask, new_mask), dim=0)
+ return (mask,)
+
+class OffsetMask:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "mask": ("MASK",),
+ "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
+ "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
+ "roll": ("BOOLEAN", { "default": False }),
+ "incremental": ("BOOLEAN", { "default": False }),
+ "padding_mode": (
+ [
+ 'empty',
+ 'border',
+ 'reflection',
+
+ ], {
+ "default": 'empty'
+ }),
+ }
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "offset"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Offsets the mask by the specified amount.
+ - mask: Input mask or mask batch
+ - x: Horizontal offset
+ - y: Vertical offset
+ - angle: Angle in degrees
+ - roll: Wrap the mask around the edges instead of padding
+ - incremental: Increase the offset and rotation per frame across the batch
+ - duplication_factor: Number of times to duplicate the mask to form a batch
+ - padding_mode: Padding mode for the mask
+"""
+
+ def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"):
+ # Create duplicates of the mask batch
+ mask = mask.repeat(duplication_factor, 1, 1).clone()
+
+ batch_size, height, width = mask.shape
+
+ if angle != 0 and incremental:
+ for i in range(batch_size):
+ rotation_angle = angle * (i+1)
+ mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)
+ elif angle != 0:
+ for i in range(batch_size):
+ mask[i] = TF.rotate(mask[i].unsqueeze(0), angle).squeeze(0)
+
+ if roll:
+ if incremental:
+ for i in range(batch_size):
+ shift_x = min(x*(i+1), width-1)
+ shift_y = min(y*(i+1), height-1)
+ if shift_x != 0:
+ mask[i] = torch.roll(mask[i], shifts=shift_x, dims=1)
+ if shift_y != 0:
+ mask[i] = torch.roll(mask[i], shifts=shift_y, dims=0)
+ else:
+ shift_x = min(x, width-1)
+ shift_y = min(y, height-1)
+ if shift_x != 0:
+ mask = torch.roll(mask, shifts=shift_x, dims=2)
+ if shift_y != 0:
+ mask = torch.roll(mask, shifts=shift_y, dims=1)
+ else:
+
+ for i in range(batch_size):
+ if incremental:
+ temp_x = min(x * (i+1), width-1)
+ temp_y = min(y * (i+1), height-1)
+ else:
+ temp_x = min(x, width-1)
+ temp_y = min(y, height-1)
+ if temp_x > 0:
+ if padding_mode == 'empty':
+ mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
+ elif padding_mode in ['replicate', 'reflect']:
+ mask[i] = F.pad(mask[i, :, :-temp_x], (0, temp_x), mode=padding_mode)
+ elif temp_x < 0:
+ if padding_mode == 'empty':
+ mask[i] = torch.cat([mask[i, :, :temp_x], torch.zeros((height, -temp_x))], dim=1)
+ elif padding_mode in ['replicate', 'reflect']:
+ mask[i] = F.pad(mask[i, :, -temp_x:], (temp_x, 0), mode=padding_mode)
+
+ if temp_y > 0:
+ if padding_mode == 'empty':
+ mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
+ elif padding_mode in ['replicate', 'reflect']:
+ mask[i] = F.pad(mask[i, :-temp_y, :], (0, temp_y), mode=padding_mode)
+ elif temp_y < 0:
+ if padding_mode == 'empty':
+ mask[i] = torch.cat([mask[i, :temp_y, :], torch.zeros((-temp_y, width))], dim=0)
+ elif padding_mode in ['replicate', 'reflect']:
+ mask[i] = F.pad(mask[i, -temp_y:, :], (temp_y, 0), mode=padding_mode)
+
+ return mask,
+
+class RoundMask:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "mask": ("MASK",),
+ }}
+
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "round"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Rounds the mask or batch of masks to a binary mask.
+
+
+"""
+
+ def round(self, mask):
+ mask = mask.round()
+ return (mask,)
+
+class ResizeMask:
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "mask": ("MASK",),
+ "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+ "keep_proportions": ("BOOLEAN", { "default": False }),
+ "upscale_method": (s.upscale_methods,),
+ "crop": (["disabled","center"],),
+ }
+ }
+
+ RETURN_TYPES = ("MASK", "INT", "INT",)
+ RETURN_NAMES = ("mask", "width", "height",)
+ FUNCTION = "resize"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Resizes the mask or batch of masks to the specified width and height.
+"""
+
+ def resize(self, mask, width, height, keep_proportions, upscale_method, crop):
+ if keep_proportions:
+ _, oh, ow = mask.shape
+ width = ow if width == 0 else width
+ height = oh if height == 0 else height
+ ratio = min(width / ow, height / oh)
+ width = round(ow*ratio)
+ height = round(oh*ratio)
+ outputs = mask.unsqueeze(1)
+ outputs = common_upscale(outputs, width, height, upscale_method, crop)
+ outputs = outputs.squeeze(1)
+
+ return (outputs, outputs.shape[2], outputs.shape[1],)
+
+class RemapMaskRange:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "mask": ("MASK",),
+ "min": ("FLOAT", {"default": 0.0,"min": -10.0, "max": 1.0, "step": 0.01}),
+ "max": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 10.0, "step": 0.01}),
+ }
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "remap"
+ CATEGORY = "KJNodes/masking"
+ DESCRIPTION = """
+Sets new min and max values for the mask.
+"""
+
+ def remap(self, mask, min, max):
+
+ # Find the maximum value in the mask
+ mask_max = torch.max(mask)
+
+ # If the maximum mask value is zero, avoid division by zero by setting it to 1
+ mask_max = mask_max if mask_max > 0 else 1
+
+ # Scale the mask values to the new range defined by min and max
+ # The highest pixel value in the mask will be scaled to max
+ scaled_mask = (mask / mask_max) * (max - min) + min
+
+ # Clamp the values to ensure they are within [0.0, 1.0]
+ scaled_mask = torch.clamp(scaled_mask, min=0.0, max=1.0)
+
+ return (scaled_mask, )
+
+
+def get_mask_polygon(self, mask_np):
+ """Helper function to get polygon points from mask"""
+ import cv2
+ # Find contours
+ contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ if not contours:
+ return None
+
+ # Get the largest contour
+ largest_contour = max(contours, key=cv2.contourArea)
+
+ # Approximate polygon
+ epsilon = 0.02 * cv2.arcLength(largest_contour, True)
+ polygon = cv2.approxPolyDP(largest_contour, epsilon, True)
+
+ return polygon.squeeze()
+
+import cv2
+class SeparateMasks:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask": ("MASK", ),
+ "size_threshold_width" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
+ "size_threshold_height" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
+ "mode": (["convex_polygons", "area"],),
+ "max_poly_points": ("INT", {"default": 8, "min": 3, "max": 32, "step": 1}),
+
+ },
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "separate"
+ CATEGORY = "KJNodes/masking"
+ OUTPUT_NODE = True
+ DESCRIPTION = "Separates a mask into multiple masks based on the size of the connected components."
+
+ def polygon_to_mask(self, polygon, shape):
+ mask = np.zeros((shape[0], shape[1]), dtype=np.uint8) # Fixed shape handling
+
+ if len(polygon.shape) == 2: # Check if polygon points are valid
+ polygon = polygon.astype(np.int32)
+ cv2.fillPoly(mask, [polygon], 1)
+ return mask
+
+ def get_mask_polygon(self, mask_np, max_points):
+ contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ if not contours:
+ return None
+
+ largest_contour = max(contours, key=cv2.contourArea)
+ hull = cv2.convexHull(largest_contour)
+
+ # Initialize with smaller epsilon for more points
+ perimeter = cv2.arcLength(hull, True)
+ epsilon = perimeter * 0.01 # Start smaller
+
+ min_eps = perimeter * 0.001 # Much smaller minimum
+ max_eps = perimeter * 0.2 # Smaller maximum
+
+ best_approx = None
+ best_diff = float('inf')
+ max_iterations = 20
+
+ #print(f"Target points: {max_points}, Perimeter: {perimeter}")
+
+ for i in range(max_iterations):
+ curr_eps = (min_eps + max_eps) / 2
+ approx = cv2.approxPolyDP(hull, curr_eps, True)
+ points_diff = len(approx) - max_points
+
+ #print(f"Iteration {i}: points={len(approx)}, eps={curr_eps:.4f}")
+
+ if abs(points_diff) < best_diff:
+ best_approx = approx
+ best_diff = abs(points_diff)
+
+ if len(approx) > max_points:
+ min_eps = curr_eps * 1.1 # More gradual adjustment
+ elif len(approx) < max_points:
+ max_eps = curr_eps * 0.9 # More gradual adjustment
+ else:
+ return approx.squeeze()
+
+ if abs(max_eps - min_eps) < perimeter * 0.0001: # Relative tolerance
+ break
+
+ # If we didn't find exact match, return best approximation
+ return best_approx.squeeze() if best_approx is not None else hull.squeeze()
+
+ def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
+ from scipy.ndimage import label, center_of_mass
+ import numpy as np
+
+ B, H, W = mask.shape
+ separated = []
+
+ mask = mask.round()
+
+ for b in range(B):
+ mask_np = mask[b].cpu().numpy().astype(np.uint8)
+ structure = np.ones((3, 3), dtype=np.int8)
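+ # full 3x3 structuring element = 8-connectivity, so diagonally touching pixels belong to the same component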
+ labeled, ncomponents = label(mask_np, structure=structure)
+ pbar = ProgressBar(ncomponents)
+
+ for component in range(1, ncomponents + 1):
+ component_mask_np = (labeled == component).astype(np.uint8)
+
+ rows = np.any(component_mask_np, axis=1)
+ cols = np.any(component_mask_np, axis=0)
+ y_min, y_max = np.where(rows)[0][[0, -1]]
+ x_min, x_max = np.where(cols)[0][[0, -1]]
+
+ width = x_max - x_min + 1
+ height = y_max - y_min + 1
+ centroid_x = (x_min + x_max) / 2 # Calculate x centroid
+ print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}")
+
+ if width >= size_threshold_width and height >= size_threshold_height:
+ if mode != "area":
+ polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
+ if polygon is not None:
+ poly_mask = self.polygon_to_mask(polygon, (H, W))
+ poly_mask = torch.tensor(poly_mask, device=mask.device)
+ separated.append((centroid_x, poly_mask))
+ else:
+ area_mask = torch.tensor(component_mask_np, device=mask.device)
+ separated.append((centroid_x, area_mask))
+ pbar.update(1)
+
+ if len(separated) > 0:
+ # Sort by x position and extract only the masks
+ separated.sort(key=lambda x: x[0])
+ separated = [x[1] for x in separated]
+ out_masks = torch.stack(separated, dim=0)
+ return out_masks,
+ else:
+ return torch.zeros((1, 64, 64), device=mask.device),
+
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/model_optimization_nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/model_optimization_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7c29632e2fd016e1f8f55932bdee53ba445780
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/model_optimization_nodes.py
@@ -0,0 +1,1179 @@
+from comfy.ldm.modules import attention as comfy_attention
+import logging
+import comfy.model_patcher
+import comfy.utils
+import comfy.sd
+import torch
+import folder_paths
+import comfy.model_management as mm
+from comfy.cli_args import args
+
+orig_attention = comfy_attention.optimized_attention
+original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
+original_load_lora_for_models = comfy.sd.load_lora_for_models
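+# keep references to the unpatched functions so the nodes below can restore default behavior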
+
+class BaseLoaderKJ:
+ original_linear = None
+ cublas_patched = False
+
+ def _patch_modules(self, patch_cublaslinear, sage_attention):
+ from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
+
+ if sage_attention != "disabled":
+ print("Patching comfy attention to use sageattn")
+ from sageattention import sageattn
+ def set_sage_func(sage_attention):
+ if sage_attention == "auto":
+ def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+ return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
+ return func
+ elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
+ from sageattention import sageattn_qk_int8_pv_fp16_cuda
+ def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+ return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
+ return func
+ elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
+ from sageattention import sageattn_qk_int8_pv_fp16_triton
+ def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+ return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
+ return func
+ elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
+ from sageattention import sageattn_qk_int8_pv_fp8_cuda
+ def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+ return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
+ return func
+
+ sage_func = set_sage_func(sage_attention)
+
+ @torch.compiler.disable()
+ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+ if skip_reshape:
+ b, _, _, dim_head = q.shape
+ tensor_layout="HND"
+ else:
+ b, _, dim_head = q.shape
+ dim_head //= heads
+ q, k, v = map(
+ lambda t: t.view(b, -1, heads, dim_head),
+ (q, k, v),
+ )
+ tensor_layout="NHD"
+ if mask is not None:
+ # add a batch dimension if there isn't already one
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0)
+ # add a heads dimension if there isn't already one
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(1)
+ out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+ if tensor_layout == "HND":
+ if not skip_output_reshape:
+ out = (
+ out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+ )
+ else:
+ if skip_output_reshape:
+ out = out.transpose(1, 2)
+ else:
+ out = out.reshape(b, -1, heads * dim_head)
+ return out
+
+ comfy_attention.optimized_attention = attention_sage
+ comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
+ comfy.ldm.flux.math.optimized_attention = attention_sage
+ comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
+ comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
+ comfy.ldm.wan.model.optimized_attention = attention_sage
+
+ else:
+ comfy_attention.optimized_attention = orig_attention
+ comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
+ comfy.ldm.flux.math.optimized_attention = orig_attention
+ comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
+ comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
+ comfy.ldm.wan.model.optimized_attention = orig_attention
+
+ if patch_cublaslinear:
+ if not BaseLoaderKJ.cublas_patched:
+ BaseLoaderKJ.original_linear = disable_weight_init.Linear
+ try:
+ from cublas_ops import CublasLinear
+ except ImportError:
+ raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+
+ class PatchedLinear(CublasLinear, CastWeightBiasOp):
+ def reset_parameters(self):
+ pass
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.linear(input, weight, bias)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+ disable_weight_init.Linear = PatchedLinear
+ BaseLoaderKJ.cublas_patched = True
+ else:
+ if BaseLoaderKJ.cublas_patched:
+ disable_weight_init.Linear = BaseLoaderKJ.original_linear
+ BaseLoaderKJ.cublas_patched = False
+
+class PathchSageAttentionKJ(BaseLoaderKJ):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model": ("MODEL",),
+ "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
+ }}
+
+ RETURN_TYPES = ("MODEL", )
+ FUNCTION = "patch"
+ DESCRIPTION = "Experimental node for patching attention mode. This doesn't use the model patching system and thus can't be disabled without running the node again with 'disabled' option."
+ EXPERIMENTAL = True
+ CATEGORY = "KJNodes/experimental"
+
+ def patch(self, model, sage_attention):
+ self._patch_modules(False, sage_attention)
+ return model,
+
+class CheckpointLoaderKJ(BaseLoaderKJ):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+ "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+ "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+ }}
+
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+ FUNCTION = "patch"
+ OUTPUT_NODE = True
+ DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
+ EXPERIMENTAL = True
+ CATEGORY = "KJNodes/experimental"
+
+ def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+ self._patch_modules(patch_cublaslinear, sage_attention)
+ from nodes import CheckpointLoaderSimple
+ model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
+ return model, clip, vae
+
+class DiffusionModelLoaderKJ(BaseLoaderKJ):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
+ "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
+ "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "fp16", "tooltip": "The compute dtype to use for the model."}),
+ "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+ "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+ "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
+ }}
+
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch_and_load"
+ OUTPUT_NODE = True
+ DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
+ EXPERIMENTAL = True
+ CATEGORY = "KJNodes/experimental"
+
+ def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
+ DTYPE_MAP = {
+ "fp8_e4m3fn": torch.float8_e4m3fn,
+ "fp8_e5m2": torch.float8_e5m2,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32
+ }
+ model_options = {}
+ if dtype := DTYPE_MAP.get(weight_dtype):
+ model_options["dtype"] = dtype
+ print(f"Setting {model_name} weight dtype to {dtype}")
+
+ if weight_dtype == "fp8_e4m3fn_fast":
+ model_options["dtype"] = torch.float8_e4m3fn
+ model_options["fp8_optimizations"] = True
+
+ if enable_fp16_accumulation:
+ if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+ torch.backends.cuda.matmul.allow_fp16_accumulation = True
+ else:
+ raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
+ else:
+ if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+ torch.backends.cuda.matmul.allow_fp16_accumulation = False
+
+ unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
+ model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
+ if dtype := DTYPE_MAP.get(compute_dtype):
+ model.set_model_compute_dtype(dtype)
+ model.force_cast_weights = False
+ print(f"Setting {model_name} compute dtype to {dtype}")
+ self._patch_modules(patch_cublaslinear, sage_attention)
+
+ return (model,)
+
+def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+ with self.use_ejected():
+
+ device_to = mm.get_torch_device()
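+ # note: the incoming device_to argument is ignored; the model is always loaded to the main torch device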
+
+ full_load_override = getattr(self.model, "full_load_override", "auto")
+ if full_load_override in ["enabled", "disabled"]:
+ full_load = full_load_override == "enabled"
+ else:
+ full_load = lowvram_model_memory == 0
+
+ self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+
+ for k in self.object_patches:
+ old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+ if k not in self.object_patches_backup:
+ self.object_patches_backup[k] = old
+
+ self.inject_model()
+ return self.model
+
+def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+
+ patch_keys = list(model.object_patches_backup.keys())
+ for k in patch_keys:
+ #print("backing up object patch: ", k)
+ comfy.utils.set_attr(model.model, k, model.object_patches_backup[k])
+
+ key_map = {}
+ if model is not None:
+ key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
+ if clip is not None:
+ key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+
+ lora = comfy.lora_convert.convert_lora(lora)
+ loaded = comfy.lora.load_lora(lora, key_map)
+ #print(temp_object_patches_backup)
+
+ if model is not None:
+ new_modelpatcher = model.clone()
+ k = new_modelpatcher.add_patches(loaded, strength_model)
+ else:
+ k = ()
+ new_modelpatcher = None
+
+ if clip is not None:
+ new_clip = clip.clone()
+ k1 = new_clip.add_patches(loaded, strength_clip)
+ else:
+ k1 = ()
+ new_clip = None
+ k = set(k)
+ k1 = set(k1)
+ for x in loaded:
+ if (x not in k) and (x not in k1):
+ print("NOT LOADED {}".format(x))
+
+ if patch_keys:
+ if hasattr(model.model, "compile_settings"):
+ compile_settings = getattr(model.model, "compile_settings")
+ print("compile_settings: ", compile_settings)
+ for k in patch_keys:
+ if "diffusion_model." in k:
+ # Remove the prefix to get the attribute path
+ key = k.replace('diffusion_model.', '')
+ attributes = key.split('.')
+ # Start with the diffusion_model object
+ block = model.get_model_object("diffusion_model")
+ # Navigate through the attributes to get to the block
+ for attr in attributes:
+ if attr.isdigit():
+ block = block[int(attr)]
+ else:
+ block = getattr(block, attr)
+ # Compile the block
+ compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
+ # Add the compiled block back as an object patch
+ model.add_object_patch(k, compiled_block)
+ return (new_modelpatcher, new_clip)
+
+class PatchModelPatcherOrder:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model": ("MODEL",),
+ "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
+ "full_load": (["enabled", "disabled", "auto"], {"default": "auto", "tooltip": "Disabling may help with memory issues when loading large models, when changing this you should probably force model reload to avoid issues!"}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = "Patch the comfy patch_model function patching order, useful for torch.compile (used as object_patch) as it should come last if you want to use LoRAs with compile"
+ EXPERIMENTAL = True
+
+ def patch(self, model, patch_order, full_load):
+ comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
+ setattr(model.model, "full_load_override", full_load)
+ if patch_order == "weight_patch_first":
+ comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
+ comfy.sd.load_lora_for_models = patched_load_lora_for_models
+ else:
+ comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
+ comfy.sd.load_lora_for_models = original_load_lora_for_models
+
+ return model,
+
+class TorchCompileModelFluxAdvanced:
+ def __init__(self):
+ self._compiled = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model": ("MODEL",),
+ "backend": (["inductor", "cudagraphs"],),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
+ "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
+ "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+ },
+ "optional": {
+ "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+ }
+ }
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def parse_blocks(self, blocks_str):
+ blocks = []
+ for part in blocks_str.split(','):
+ part = part.strip()
+ if '-' in part:
+ start, end = map(int, part.split('-'))
+ blocks.extend(range(start, end + 1))
+ else:
+ blocks.append(int(part))
+ return blocks
+
+ def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit):
+ single_block_list = self.parse_blocks(single_blocks)
+ double_block_list = self.parse_blocks(double_blocks)
+ m = model.clone()
+ diffusion_model = m.get_model_object("diffusion_model")
+ torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+
+ if not self._compiled:
+ try:
+ for i, block in enumerate(diffusion_model.double_blocks):
+ if i in double_block_list:
+ #print("Compiling double_block", i)
+ m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
+ for i, block in enumerate(diffusion_model.single_blocks):
+ if i in single_block_list:
+ #print("Compiling single block", i)
+ m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
+ self._compiled = True
+ compile_settings = {
+ "backend": backend,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ setattr(m.model, "compile_settings", compile_settings)
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+
+ return (m, )
+ # rest of the layers that are not patched
+ # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
+ # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
+ # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
+ # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
+ # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
+ # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
+
+class TorchCompileModelHyVideo:
+ def __init__(self):
+ self._compiled = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+ "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+ "compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
+ "compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
+ "compile_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Compile txt_in layers"}),
+ "compile_vector_in": ("BOOLEAN", {"default": False, "tooltip": "Compile vector_in layers"}),
+ "compile_final_layer": ("BOOLEAN", {"default": False, "tooltip": "Compile final layer"}),
+
+ },
+ }
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer):
+ m = model.clone()
+ diffusion_model = m.get_model_object("diffusion_model")
+ torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+ if not self._compiled:
+ try:
+ if compile_single_blocks:
+ for i, block in enumerate(diffusion_model.single_blocks):
+ compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch(f"diffusion_model.single_blocks.{i}", compiled_block)
+ if compile_double_blocks:
+ for i, block in enumerate(diffusion_model.double_blocks):
+ compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch(f"diffusion_model.double_blocks.{i}", compiled_block)
+ if compile_txt_in:
+ compiled_block = torch.compile(diffusion_model.txt_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch("diffusion_model.txt_in", compiled_block)
+ if compile_vector_in:
+ compiled_block = torch.compile(diffusion_model.vector_in, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch("diffusion_model.vector_in", compiled_block)
+ if compile_final_layer:
+ compiled_block = torch.compile(diffusion_model.final_layer, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch("diffusion_model.final_layer", compiled_block)
+ self._compiled = True
+ compile_settings = {
+ "backend": backend,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ setattr(m.model, "compile_settings", compile_settings)
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+ return (m, )
+
+class TorchCompileModelWanVideo:
+ def __init__(self):
+ self._compiled = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+ "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+ "compile_transformer_blocks_only": ("BOOLEAN", {"default": False, "tooltip": "Compile only transformer blocks"}),
+ },
+ }
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
+ m = model.clone()
+ diffusion_model = m.get_model_object("diffusion_model")
+ torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+ is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod")
+ if is_compiled:
+ logging.info("Transformer blocks already compiled, recompiling from the original modules")
+ else:
+ logging.info("Compiling transformer blocks")
+ try:
+ if compile_transformer_blocks_only:
+ for i, block in enumerate(diffusion_model.blocks):
+ if is_compiled:
+ compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ else:
+ compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
+ else:
+ compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+ m.add_object_patch("diffusion_model", compiled_model)
+
+ compile_settings = {
+ "backend": backend,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ setattr(m.model, "compile_settings", compile_settings)
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+ return (m, )
+
+class TorchCompileVAE:
+ def __init__(self):
+ self._compiled_encoder = False
+ self._compiled_decoder = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "vae": ("VAE",),
+ "backend": (["inductor", "cudagraphs"],),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}),
+ "compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}),
+ }}
+ RETURN_TYPES = ("VAE",)
+ FUNCTION = "compile"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
+ if compile_encoder:
+ if not self._compiled_encoder:
+ encoder_name = "encoder"
+ if hasattr(vae.first_stage_model, "taesd_encoder"):
+ encoder_name = "taesd_encoder"
+
+ try:
+ setattr(
+ vae.first_stage_model,
+ encoder_name,
+ torch.compile(
+ getattr(vae.first_stage_model, encoder_name),
+ mode=mode,
+ fullgraph=fullgraph,
+ backend=backend,
+ ),
+ )
+ self._compiled_encoder = True
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+ if compile_decoder:
+ if not self._compiled_decoder:
+ decoder_name = "decoder"
+ if hasattr(vae.first_stage_model, "taesd_decoder"):
+ decoder_name = "taesd_decoder"
+
+ try:
+ setattr(
+ vae.first_stage_model,
+ decoder_name,
+ torch.compile(
+ getattr(vae.first_stage_model, decoder_name),
+ mode=mode,
+ fullgraph=fullgraph,
+ backend=backend,
+ ),
+ )
+ self._compiled_decoder = True
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+ return (vae, )
+
+class TorchCompileControlNet:
+ def __init__(self):
+ self._compiled= False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "controlnet": ("CONTROL_NET",),
+ "backend": (["inductor", "cudagraphs"],),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ }}
+ RETURN_TYPES = ("CONTROL_NET",)
+ FUNCTION = "compile"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def compile(self, controlnet, backend, mode, fullgraph):
+ if not self._compiled:
+ try:
+ # for i, block in enumerate(controlnet.control_model.double_blocks):
+ # print("Compiling controlnet double_block", i)
+ # controlnet.control_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
+ controlnet.control_model = torch.compile(controlnet.control_model, mode=mode, fullgraph=fullgraph, backend=backend)
+ self._compiled = True
+ except Exception as e:
+ self._compiled = False
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+
+ return (controlnet, )
+
+class TorchCompileLTXModel:
+ def __init__(self):
+ self._compiled = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model": ("MODEL",),
+ "backend": (["inductor", "cudagraphs"],),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def patch(self, model, backend, mode, fullgraph, dynamic):
+ m = model.clone()
+ diffusion_model = m.get_model_object("diffusion_model")
+
+ if not self._compiled:
+ try:
+ for i, block in enumerate(diffusion_model.transformer_blocks):
+ compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
+ m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
+ self._compiled = True
+ compile_settings = {
+ "backend": backend,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ setattr(m.model, "compile_settings", compile_settings)
+
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+
+ return (m, )
+
+class TorchCompileCosmosModel:
+ def __init__(self):
+ self._compiled = False
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "model": ("MODEL",),
+ "backend": (["inductor", "cudagraphs"],),
+ "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+ "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+ "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+ "dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "KJNodes/torchcompile"
+ EXPERIMENTAL = True
+
+ def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit):
+
+ m = model.clone()
+ diffusion_model = m.get_model_object("diffusion_model")
+ torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+
+ if not self._compiled:
+ try:
+ for name, block in diffusion_model.blocks.items():
+ #print(f"Compiling block {name}")
+ compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
+ m.add_object_patch(f"diffusion_model.blocks.{name}", compiled_block)
+ #diffusion_model.blocks[name] = compiled_block
+
+ self._compiled = True
+ compile_settings = {
+ "backend": backend,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ setattr(m.model, "compile_settings", compile_settings)
+
+ except Exception as e:
+ raise RuntimeError(f"Failed to compile model: {e}") from e
+
+ return (m, )
+
+
+#teacache
+
+try:
+ from comfy.ldm.wan.model import sinusoidal_embedding_1d
+except ImportError:
+ pass
+from einops import repeat
+from unittest.mock import patch
+from contextlib import nullcontext
+import numpy as np
+
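+# TeaCache change metric: mean absolute difference normalized by the mean magnitude of the previous tensor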
+def relative_l1_distance(last_tensor, current_tensor):
+ l1_distance = torch.abs(last_tensor - current_tensor).mean()
+ norm = torch.abs(last_tensor).mean()
+ relative_l1_distance = l1_distance / norm
+ return relative_l1_distance.to(torch.float32)
+
+def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
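+ # e0 holds the per-block modulation parameters derived from the time embedding; TeaCache measures how much e0 (or e) changes between steps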
+
+ # context
+ context = self.text_embedding(context)
+ if clip_fea is not None and self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+
+ @torch.compiler.disable()
+ def tea_cache(x, e0, e, kwargs):
+ #teacache for cond and uncond separately
+ rel_l1_thresh = transformer_options["rel_l1_thresh"]
+
+ is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
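+ # cond and uncond passes keep separate caches so skipping one never reuses the other's residual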
+
+ should_calc = True
+ suffix = "cond" if is_cond else "uncond"
+
+ # Init cache dict if not exists
+ if not hasattr(self, 'teacache_state'):
+ self.teacache_state = {
+ 'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+ 'teacache_skipped_steps': 0, 'previous_residual': None},
+ 'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+ 'teacache_skipped_steps': 0, 'previous_residual': None}
+ }
+ logging.info("\nTeaCache: Initialized")
+
+ cache = self.teacache_state[suffix]
+
+ if cache['prev_input'] is not None:
+ if transformer_options["coefficients"] == []:
+ temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
+ curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
+ else:
+ rescale_func = np.poly1d(transformer_options["coefficients"])
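+ # rescale the raw relative change of the time embedding with the model-specific polynomial before accumulating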
+ curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
+ try:
+ if curr_acc_dist < rel_l1_thresh:
+ should_calc = False
+ cache['accumulated_rel_l1_distance'] = curr_acc_dist
+ else:
+ should_calc = True
+ cache['accumulated_rel_l1_distance'] = 0
+ except:
+ should_calc = True
+ cache['accumulated_rel_l1_distance'] = 0
+
+ if transformer_options["coefficients"] == []:
+ cache['prev_input'] = e0.clone().detach()
+ else:
+ cache['prev_input'] = e.clone().detach()
+
+ if not should_calc:
+ x += cache['previous_residual'].to(x.device)
+ cache['teacache_skipped_steps'] += 1
+ #print(f"TeaCache: Skipping {suffix} step")
+ return should_calc, cache
+
+ if not transformer_options:
+ raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
+
+ teacache_enabled = transformer_options.get("teacache_enabled", False)
+ if not teacache_enabled:
+ should_calc = True
+ else:
+ should_calc, cache = tea_cache(x, e0, e, kwargs)
+
+ if should_calc:
+ original_x = x.clone().detach()
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context)
+
+ if teacache_enabled:
+ cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+class WanVideoTeaCacheKJ:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "rel_l1_thresh": ("FLOAT", {"default": 0.275, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy. When using coefficients a good value range is something between 0.2-0.4 for all but 1.3B model, which should be about 10 times smaller, same as when not using coefficients."}),
+ "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
+ "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
+ "coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"], {"default": "i2v_480", "tooltip": "Coefficients for rescaling the relative l1 distance, if disabled the threshold value should be about 10 times smaller than the value used with coefficients."}),
+ }
+ }
+
+ RETURN_TYPES = ("MODEL",)
+ RETURN_NAMES = ("model",)
+ FUNCTION = "patch_teacache"
+ CATEGORY = "KJNodes/teacache"
+ DESCRIPTION = """
+Patches the WanVideo model to use TeaCache. Speeds up inference by caching the
+transformer output residual and applying it instead of computing the step.
+Best results are achieved by choosing the appropriate coefficients for the model.
+Early steps should never be skipped; with too aggressive values this can happen
+and the motion suffers, so starting later can help with that too.
+When NOT using coefficients, the threshold value should be
+about 10 times smaller than the value used with coefficients.
+
+Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4Wan2.1:
+
+
+
++-------------------+--------+---------+--------+
+| Model | Low | Medium | High |
++-------------------+--------+---------+--------+
+| Wan2.1 t2v 1.3B | 0.05 | 0.07 | 0.08 |
+| Wan2.1 t2v 14B | 0.14 | 0.15 | 0.20 |
+| Wan2.1 i2v 480P | 0.13 | 0.19 | 0.26 |
+| Wan2.1 i2v 720P | 0.18 | 0.20 | 0.30 |
++-------------------+--------+---------+--------+
+
+"""
+ EXPERIMENTAL = True
+
+ def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients):
+ if rel_l1_thresh == 0:
+ return (model,)
+
+ if coefficients == "disabled" and rel_l1_thresh > 0.1:
+ logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
+ if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients:
+ logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
+
+ # type_str = str(type(model.model.model_config).__name__)
+ #if model.model.diffusion_model.dim == 1536:
+ # model_type ="1.3B"
+ # else:
+ # if "WAN21_T2V" in type_str:
+ # model_type = "14B"
+ # elif "WAN21_I2V" in type_str:
+ # model_type = "i2v_480"
+ # else:
+ # model_type = "i2v_720" #how to detect this?
+
+
+ teacache_coefficients_map = {
+ "disabled": [],
+ "1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01],
+ "14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
+ "i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01],
+ "i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
+ }
+ coefficients = teacache_coefficients_map[coefficients]
+
+ teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
+
+ model_clone = model.clone()
+ if 'transformer_options' not in model_clone.model_options:
+ model_clone.model_options['transformer_options'] = {}
+ model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
+ model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
+ model_clone.model_options["transformer_options"]["coefficients"] = coefficients
+ diffusion_model = model_clone.get_model_object("diffusion_model")
+
+ def outer_wrapper(start_percent, end_percent):
+ def unet_wrapper_function(model_function, kwargs):
+ input = kwargs["input"]
+ timestep = kwargs["timestep"]
+ c = kwargs["c"]
+ sigmas = c["transformer_options"]["sample_sigmas"]
+ cond_or_uncond = kwargs["cond_or_uncond"]
+ last_step = (len(sigmas) - 1)
+
+ matched_step_index = (sigmas == timestep[0] ).nonzero()
+ if len(matched_step_index) > 0:
+ current_step_index = matched_step_index.item()
+ else:
+ for i in range(len(sigmas) - 1):
+ # walk from beginning of steps until crossing the timestep
+ if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
+ current_step_index = i
+ break
+ else:
+ current_step_index = 0
+
+ if current_step_index == 0:
+ if hasattr(diffusion_model, "teacache_state"):
+ delattr(diffusion_model, "teacache_state")
+ logging.info("\nResetting TeaCache state")
+
+ current_percent = current_step_index / (len(sigmas) - 1)
+ c["transformer_options"]["current_percent"] = current_percent
+ if start_percent <= current_percent <= end_percent:
+ c["transformer_options"]["teacache_enabled"] = True
+
+ context = patch.multiple(
+ diffusion_model,
+ forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
+ )
+
+ with context:
+ out = model_function(input, timestep, **c)
+ if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
+ if len(cond_or_uncond) == 1 and cond_or_uncond[0] == 0:
+ skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
+ skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
+ logging.info("-----------------------------------")
+ logging.info(f"TeaCache skipped:")
+ logging.info(f"{skipped_steps_cond} cond steps")
+ logging.info(f"{skipped_steps_uncond} uncond step")
+ logging.info(f"out of {last_step} steps")
+ logging.info("-----------------------------------")
+ elif len(cond_or_uncond) == 2:
+ skipped_steps_cond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
+ logging.info("-----------------------------------")
+ logging.info(f"TeaCache skipped:")
+ logging.info(f"{skipped_steps_cond} cond steps")
+ logging.info(f"out of {last_step} steps")
+ logging.info("-----------------------------------")
+
+ return out
+ return unet_wrapper_function
+
+ model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
+
+ return (model_clone,)
+
+
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope
+
+def modified_wan_self_attention_forward(self, x, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n * d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ q, k = apply_rope(q, k, freqs)
+
+ feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
+
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ k.view(b, s, n * d),
+ v,
+ heads=self.num_heads,
+ )
+
+ x = self.o(x)
+
+ x *= feta_scores
+
+ return x
+
+from einops import rearrange
+def get_feta_scores(query, key, num_frames, enhance_weight):
+ img_q, img_k = query, key #torch.Size([2, 9216, 12, 128])
+
+ _, ST, num_heads, head_dim = img_q.shape
+ spatial_dim = ST / num_frames
+ spatial_dim = int(spatial_dim)
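+ # regroup tokens as (frames x spatial) so attention can be measured across frames for each spatial location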
+
+ query_image = rearrange(
+ img_q, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+ )
+ key_image = rearrange(
+ img_k, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+ )
+
+ return feta_score(query_image, key_image, head_dim, num_frames, enhance_weight)
+
+def feta_score(query_image, key_image, head_dim, num_frames, enhance_weight):
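+ # Temporal score in the spirit of Enhance-A-Video: softmax the cross-frame attention,
+ # average the off-diagonal (frame-to-other-frame) weights, scale the mean by
+ # (num_frames + enhance_weight), and clamp the result to >= 1.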
+ scale = head_dim**-0.5
+ query_image = query_image * scale
+ attn_temp = query_image @ key_image.transpose(-2, -1)
+ attn_temp = attn_temp.to(torch.float32) # cast attention logits to float32 for a stable softmax
+ attn_temp = attn_temp.softmax(dim=-1)
+
+ # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+ attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
+
+ # Create a mask for diagonal elements
+ diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+ diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
+
+ # Zero out diagonal elements
+ attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
+
+ # Calculate mean for each token's attention matrix
+ # Number of off-diagonal elements per matrix is n*n - n
+ num_off_diag = num_frames * num_frames - num_frames
+ mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+ enhance_scores = mean_scores.mean() * (num_frames + enhance_weight)
+ enhance_scores = enhance_scores.clamp(min=1)
+ return enhance_scores
+
+import types
+class WanAttentionPatch:
+ def __init__(self, num_frames, weight):
+ self.num_frames = num_frames
+ self.enhance_weight = weight
+
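+ # Descriptor: __get__ returns a bound method, so this object can replace
+ # block.self_attn.forward via add_object_patch while carrying num_frames
+ # and enhance_weight along with it.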
+ def __get__(self, obj, objtype=None):
+ # Create bound method with stored parameters
+ def wrapped_attention(self_module, *args, **kwargs):
+ self_module.num_frames = self.num_frames
+ self_module.enhance_weight = self.enhance_weight
+ return modified_wan_self_attention_forward(self_module, *args, **kwargs)
+ return types.MethodType(wrapped_attention, obj)
+
+class WanVideoEnhanceAVideoKJ:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
+ "weight": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
+ }
+ }
+
+ RETURN_TYPES = ("MODEL",)
+ RETURN_NAMES = ("model",)
+ FUNCTION = "enhance"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
+ EXPERIMENTAL = True
+
+ def enhance(self, model, weight, latent):
+ if weight == 0:
+ return (model,)
+
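+ # Wan video latents are [B, C, T, H, W]; dim 2 is the latent frame count.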
+ num_frames = latent["samples"].shape[2]
+
+ model_clone = model.clone()
+ if 'transformer_options' not in model_clone.model_options:
+ model_clone.model_options['transformer_options'] = {}
+ model_clone.model_options["transformer_options"]["enhance_weight"] = weight
+ diffusion_model = model_clone.get_model_object("diffusion_model")
+
+ compile_settings = getattr(model.model, "compile_settings", None)
+ for idx, block in enumerate(diffusion_model.blocks):
+ patched_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
+ if compile_settings is not None:
+ patched_attn = torch.compile(patched_attn, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
+
+ model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
+
+ return (model_clone,)
+
+class SkipLayerGuidanceWanVideo:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"model": ("MODEL", ),
+ "blocks": ("STRING", {"default": "10", "multiline": False}),
+ "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "slg"
+ EXPERIMENTAL = True
+ DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
+
+ CATEGORY = "advanced/guidance"
+
+ def slg(self, model, start_percent, end_percent, blocks):
+ def skip(args, extra_args):
+ transformer_options = extra_args.get("transformer_options", {})
+ original_block = extra_args["original_block"]
+
+ if not transformer_options:
+ raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
+ if start_percent <= transformer_options["current_percent"] <= end_percent:
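+ # Inside the active window: when cond and uncond are batched together, run the
+ # block only on the cond half and carry the uncond half forward without applying the block.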
+ if args["img"].shape[0] == 2:
+ prev_img_uncond = args["img"][0].unsqueeze(0)
+
+ new_args = {
+ "img": args["img"][1],
+ "txt": args["txt"][1],
+ "vec": args["vec"][1],
+ "pe": args["pe"][1]
+ }
+
+ block_out = original_block(new_args)
+
+ out = {
+ "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
+ "txt": args["txt"],
+ "vec": args["vec"],
+ "pe": args["pe"]
+ }
+ else:
+ if transformer_options.get("cond_or_uncond") == [0]:
+ out = original_block(args)
+ else:
+ out = args
+ else:
+ out = original_block(args)
+ return out
+
+ blocks = [int(x.strip()) for x in blocks.split(",")]
+ logging.info(f"Selected blocks to skip uncond on: {blocks}")
+
+ m = model.clone()
+
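+ # Register the skip() replacement for each selected double block, copying the
+ # nested patches_replace dicts so the original model's options are not mutated.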
+ for b in blocks:
+ #m.set_model_patch_replace(skip, "dit", "double_block", b)
+ model_options = m.model_options["transformer_options"].copy()
+ if "patches_replace" not in model_options:
+ model_options["patches_replace"] = {}
+ else:
+ model_options["patches_replace"] = model_options["patches_replace"].copy()
+
+ if "dit" not in model_options["patches_replace"]:
+ model_options["patches_replace"]["dit"] = {}
+ else:
+ model_options["patches_replace"]["dit"] = model_options["patches_replace"]["dit"].copy()
+
+ block = ("double_block", b)
+
+ model_options["patches_replace"]["dit"][block] = skip
+ m.model_options["transformer_options"] = model_options
+
+
+ return (m, )
\ No newline at end of file
diff --git a/custom_nodes/ComfyUI-KJNodes-main/nodes/nodes.py b/custom_nodes/ComfyUI-KJNodes-main/nodes/nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..719ae40e03ded10d162e181f481b3dbe22f4a243
--- /dev/null
+++ b/custom_nodes/ComfyUI-KJNodes-main/nodes/nodes.py
@@ -0,0 +1,2728 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+from typing import Union
+ import json, re, os, io, time, platform
+import importlib
+
+import model_management
+import folder_paths
+from nodes import MAX_RESOLUTION
+from comfy.utils import common_upscale, ProgressBar, load_torch_file
+
+script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts"))
+
+class AnyType(str):
+ """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
+
+ def __ne__(self, __value: object) -> bool:
+ return False
+any = AnyType("*")
+
+class BOOLConstant:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "value": ("BOOLEAN", {"default": True}),
+ },
+ }
+ RETURN_TYPES = ("BOOLEAN",)
+ RETURN_NAMES = ("value",)
+ FUNCTION = "get_value"
+ CATEGORY = "KJNodes/constants"
+
+ def get_value(self, value):
+ return (value,)
+
+class INTConstant:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "value": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+ },
+ }
+ RETURN_TYPES = ("INT",)
+ RETURN_NAMES = ("value",)
+ FUNCTION = "get_value"
+ CATEGORY = "KJNodes/constants"
+
+ def get_value(self, value):
+ return (value,)
+
+class FloatConstant:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "value": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.00001}),
+ },
+ }
+
+ RETURN_TYPES = ("FLOAT",)
+ RETURN_NAMES = ("value",)
+ FUNCTION = "get_value"
+ CATEGORY = "KJNodes/constants"
+
+ def get_value(self, value):
+ return (value,)
+
+class StringConstant:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "string": ("STRING", {"default": '', "multiline": False}),
+ }
+ }
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "passtring"
+ CATEGORY = "KJNodes/constants"
+
+ def passtring(self, string):
+ return (string, )
+
+class StringConstantMultiline:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "string": ("STRING", {"default": "", "multiline": True}),
+ "strip_newlines": ("BOOLEAN", {"default": True}),
+ }
+ }
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "stringify"
+ CATEGORY = "KJNodes/constants"
+
+ def stringify(self, string, strip_newlines):
+ new_string = []
+ for line in io.StringIO(string):
+ if not line.strip().startswith("\n") and strip_newlines:
+ line = line.replace("\n", '')
+ new_string.append(line)
+ new_string = "\n".join(new_string)
+
+ return (new_string, )
+
+
+
+class ScaleBatchPromptSchedule:
+
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "scaleschedule"
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = """
+ Scales a batch schedule from the FizzNodes BatchPromptSchedule
+to a different frame count.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input_str": ("STRING", {"forceInput": True,"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n"}),
+ "old_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}),
+ "new_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}),
+
+ },
+ }
+
+ def scaleschedule(self, old_frame_count, input_str, new_frame_count):
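+ # Matches schedule entries of the form "frame": "prompt text", (quoted keys and values).
+ # Example: with old_frame_count=16 and new_frame_count=32, an entry "15": "b" is remapped to "31": "b".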
+ pattern = r'"(\d+)"\s*:\s*"(.*?)"(?:,|\Z)'
+ frame_strings = dict(re.findall(pattern, input_str))
+
+ # Calculate the scaling factor
+ scaling_factor = (new_frame_count - 1) / (old_frame_count - 1)
+
+ # Initialize a dictionary to store the new frame numbers and strings
+ new_frame_strings = {}
+
+ # Iterate over the frame numbers and strings
+ for old_frame, string in frame_strings.items():
+ # Calculate the new frame number
+ new_frame = int(round(int(old_frame) * scaling_factor))
+
+ # Store the new frame number and corresponding string
+ new_frame_strings[new_frame] = string
+
+ # Format the output string
+ output_str = ', '.join([f'"{k}":"{v}"' for k, v in sorted(new_frame_strings.items())])
+ return (output_str,)
+
+
+class GetLatentsFromBatchIndexed:
+
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "indexedlatentsfrombatch"
+ CATEGORY = "KJNodes/latents"
+ DESCRIPTION = """
+ Selects and returns the latents at the specified indices as a latent batch.
+"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "latents": ("LATENT",),
+ "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
+ "latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}),
+ },
+ }
+
+ def indexedlatentsfrombatch(self, latents, indexes, latent_format):
+
+ samples = latents.copy()
+ latent_samples = samples["samples"]
+
+ # Parse the indexes string into a list of integers
+ index_list = [int(index.strip()) for index in indexes.split(',')]
+
+ # Convert list of indices to a PyTorch tensor
+ indices_tensor = torch.tensor(index_list, dtype=torch.long)
+
+ # Select the latents at the specified indices
+ if latent_format == "BCHW":
+ chosen_latents = latent_samples[indices_tensor]
+ elif latent_format == "BTCHW":
+ chosen_latents = latent_samples[:, indices_tensor]
+ elif latent_format == "BCTHW":
+ chosen_latents = latent_samples[:, :, indices_tensor]
+
+ samples["samples"] = chosen_latents
+ return (samples,)
+
+
+class ConditioningMultiCombine:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 20, "step": 1}),
+ "operation": (["combine", "concat"], {"default": "combine"}),
+ "conditioning_1": ("CONDITIONING", ),
+ "conditioning_2": ("CONDITIONING", ),
+ },
+ }
+
+ RETURN_TYPES = ("CONDITIONING", "INT")
+ RETURN_NAMES = ("combined", "inputcount")
+ FUNCTION = "combine"
+ CATEGORY = "KJNodes/masking/conditioning"
+ DESCRIPTION = """
+Combines multiple conditioning nodes into one
+"""
+
+ def combine(self, inputcount, operation, **kwargs):
+ from nodes import ConditioningCombine
+ from nodes import ConditioningConcat
+ cond_combine_node = ConditioningCombine()
+ cond_concat_node = ConditioningConcat()
+ cond = kwargs["conditioning_1"]
+ for c in range(1, inputcount):
+ new_cond = kwargs[f"conditioning_{c + 1}"]
+ if operation == "combine":
+ cond = cond_combine_node.combine(new_cond, cond)[0]
+ elif operation == "concat":
+ cond = cond_concat_node.concat(cond, new_cond)[0]
+ return (cond, inputcount,)
+
+class AppendStringsToList:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "string1": ("STRING", {"default": '', "forceInput": True}),
+ "string2": ("STRING", {"default": '', "forceInput": True}),
+ }
+ }
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "joinstring"
+ CATEGORY = "KJNodes/text"
+
+ def joinstring(self, string1, string2):
+ if not isinstance(string1, list):
+ string1 = [string1]
+ if not isinstance(string2, list):
+ string2 = [string2]
+
+ joined_string = string1 + string2
+ return (joined_string, )
+
+class JoinStrings:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "string1": ("STRING", {"default": '', "forceInput": True}),
+ "string2": ("STRING", {"default": '', "forceInput": True}),
+ "delimiter": ("STRING", {"default": ' ', "multiline": False}),
+ }
+ }
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "joinstring"
+ CATEGORY = "KJNodes/text"
+
+ def joinstring(self, string1, string2, delimiter):
+ joined_string = string1 + delimiter + string2
+ return (joined_string, )
+
+class JoinStringMulti:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+ "string_1": ("STRING", {"default": '', "forceInput": True}),
+ "string_2": ("STRING", {"default": '', "forceInput": True}),
+ "delimiter": ("STRING", {"default": ' ', "multiline": False}),
+ "return_list": ("BOOLEAN", {"default": False}),
+ },
+ }
+
+ RETURN_TYPES = ("STRING",)
+ RETURN_NAMES = ("string",)
+ FUNCTION = "combine"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+ Creates a single string, or a list of strings, from
+ multiple input strings.
+ You can set how many inputs the node has
+ with the **inputcount** widget and by clicking update.
+"""
+
+ def combine(self, inputcount, delimiter, **kwargs):
+ string = kwargs["string_1"]
+ return_list = kwargs["return_list"]
+ strings = [string] # Initialize a list with the first string
+ for c in range(1, inputcount):
+ new_string = kwargs[f"string_{c + 1}"]
+ if return_list:
+ strings.append(new_string) # Add new string to the list
+ else:
+ string = string + delimiter + new_string
+ if return_list:
+ return (strings,) # Return the list of strings
+ else:
+ return (string,) # Return the combined string
+
+class CondPassThrough:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ },
+ "optional": {
+ "positive": ("CONDITIONING", ),
+ "negative": ("CONDITIONING", ),
+ },
+ }
+
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING",)
+ RETURN_NAMES = ("positive", "negative")
+ FUNCTION = "passthrough"
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = """
+ Simply passes through the positive and negative conditioning,
+ workaround for Set node not allowing bypassed inputs.
+"""
+
+ def passthrough(self, positive=None, negative=None):
+ return (positive, negative,)
+
+class ModelPassThrough:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ },
+ "optional": {
+ "model": ("MODEL", ),
+ },
+ }
+
+ RETURN_TYPES = ("MODEL", )
+ RETURN_NAMES = ("model",)
+ FUNCTION = "passthrough"
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = """
+ Simply passes through the model,
+ workaround for Set node not allowing bypassed inputs.
+"""
+
+ def passthrough(self, model=None):
+ return (model,)
+
+def append_helper(t, mask, c, set_area_to_bounds, strength):
+ n = [t[0], t[1].copy()]
+ _, h, w = mask.shape
+ n[1]['mask'] = mask
+ n[1]['set_area_to_bounds'] = set_area_to_bounds
+ n[1]['mask_strength'] = strength
+ c.append(n)
+
+class ConditioningSetMaskAndCombine:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "positive_1": ("CONDITIONING", ),
+ "negative_1": ("CONDITIONING", ),
+ "positive_2": ("CONDITIONING", ),
+ "negative_2": ("CONDITIONING", ),
+ "mask_1": ("MASK", ),
+ "mask_2": ("MASK", ),
+ "mask_1_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_2_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "set_cond_area": (["default", "mask bounds"],),
+ }
+ }
+
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
+ RETURN_NAMES = ("combined_positive", "combined_negative",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/masking/conditioning"
+ DESCRIPTION = """
+ Bundles multiple conditioning mask and combine nodes into one; functionality is identical to the ComfyUI native nodes
+"""
+
+ def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
+ c = []
+ c2 = []
+ set_area_to_bounds = False
+ if set_cond_area != "default":
+ set_area_to_bounds = True
+ if len(mask_1.shape) < 3:
+ mask_1 = mask_1.unsqueeze(0)
+ if len(mask_2.shape) < 3:
+ mask_2 = mask_2.unsqueeze(0)
+ for t in positive_1:
+ append_helper(t, mask_1, c, set_area_to_bounds, mask_1_strength)
+ for t in positive_2:
+ append_helper(t, mask_2, c, set_area_to_bounds, mask_2_strength)
+ for t in negative_1:
+ append_helper(t, mask_1, c2, set_area_to_bounds, mask_1_strength)
+ for t in negative_2:
+ append_helper(t, mask_2, c2, set_area_to_bounds, mask_2_strength)
+ return (c, c2)
+
+class ConditioningSetMaskAndCombine3:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "positive_1": ("CONDITIONING", ),
+ "negative_1": ("CONDITIONING", ),
+ "positive_2": ("CONDITIONING", ),
+ "negative_2": ("CONDITIONING", ),
+ "positive_3": ("CONDITIONING", ),
+ "negative_3": ("CONDITIONING", ),
+ "mask_1": ("MASK", ),
+ "mask_2": ("MASK", ),
+ "mask_3": ("MASK", ),
+ "mask_1_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_2_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_3_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "set_cond_area": (["default", "mask bounds"],),
+ }
+ }
+
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
+ RETURN_NAMES = ("combined_positive", "combined_negative",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/masking/conditioning"
+ DESCRIPTION = """
+ Bundles multiple conditioning mask and combine nodes into one; functionality is identical to the ComfyUI native nodes
+"""
+
+ def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
+ c = []
+ c2 = []
+ set_area_to_bounds = False
+ if set_cond_area != "default":
+ set_area_to_bounds = True
+ if len(mask_1.shape) < 3:
+ mask_1 = mask_1.unsqueeze(0)
+ if len(mask_2.shape) < 3:
+ mask_2 = mask_2.unsqueeze(0)
+ if len(mask_3.shape) < 3:
+ mask_3 = mask_3.unsqueeze(0)
+ for t in positive_1:
+ append_helper(t, mask_1, c, set_area_to_bounds, mask_1_strength)
+ for t in positive_2:
+ append_helper(t, mask_2, c, set_area_to_bounds, mask_2_strength)
+ for t in positive_3:
+ append_helper(t, mask_3, c, set_area_to_bounds, mask_3_strength)
+ for t in negative_1:
+ append_helper(t, mask_1, c2, set_area_to_bounds, mask_1_strength)
+ for t in negative_2:
+ append_helper(t, mask_2, c2, set_area_to_bounds, mask_2_strength)
+ for t in negative_3:
+ append_helper(t, mask_3, c2, set_area_to_bounds, mask_3_strength)
+ return (c, c2)
+
+class ConditioningSetMaskAndCombine4:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "positive_1": ("CONDITIONING", ),
+ "negative_1": ("CONDITIONING", ),
+ "positive_2": ("CONDITIONING", ),
+ "negative_2": ("CONDITIONING", ),
+ "positive_3": ("CONDITIONING", ),
+ "negative_3": ("CONDITIONING", ),
+ "positive_4": ("CONDITIONING", ),
+ "negative_4": ("CONDITIONING", ),
+ "mask_1": ("MASK", ),
+ "mask_2": ("MASK", ),
+ "mask_3": ("MASK", ),
+ "mask_4": ("MASK", ),
+ "mask_1_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_2_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_3_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_4_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "set_cond_area": (["default", "mask bounds"],),
+ }
+ }
+
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
+ RETURN_NAMES = ("combined_positive", "combined_negative",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/masking/conditioning"
+ DESCRIPTION = """
+ Bundles multiple conditioning mask and combine nodes into one; functionality is identical to the ComfyUI native nodes
+"""
+
+ def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
+ c = []
+ c2 = []
+ set_area_to_bounds = False
+ if set_cond_area != "default":
+ set_area_to_bounds = True
+ if len(mask_1.shape) < 3:
+ mask_1 = mask_1.unsqueeze(0)
+ if len(mask_2.shape) < 3:
+ mask_2 = mask_2.unsqueeze(0)
+ if len(mask_3.shape) < 3:
+ mask_3 = mask_3.unsqueeze(0)
+ if len(mask_4.shape) < 3:
+ mask_4 = mask_4.unsqueeze(0)
+ for t in positive_1:
+ append_helper(t, mask_1, c, set_area_to_bounds, mask_1_strength)
+ for t in positive_2:
+ append_helper(t, mask_2, c, set_area_to_bounds, mask_2_strength)
+ for t in positive_3:
+ append_helper(t, mask_3, c, set_area_to_bounds, mask_3_strength)
+ for t in positive_4:
+ append_helper(t, mask_4, c, set_area_to_bounds, mask_4_strength)
+ for t in negative_1:
+ append_helper(t, mask_1, c2, set_area_to_bounds, mask_1_strength)
+ for t in negative_2:
+ append_helper(t, mask_2, c2, set_area_to_bounds, mask_2_strength)
+ for t in negative_3:
+ append_helper(t, mask_3, c2, set_area_to_bounds, mask_3_strength)
+ for t in negative_4:
+ append_helper(t, mask_4, c2, set_area_to_bounds, mask_4_strength)
+ return (c, c2)
+
+class ConditioningSetMaskAndCombine5:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "positive_1": ("CONDITIONING", ),
+ "negative_1": ("CONDITIONING", ),
+ "positive_2": ("CONDITIONING", ),
+ "negative_2": ("CONDITIONING", ),
+ "positive_3": ("CONDITIONING", ),
+ "negative_3": ("CONDITIONING", ),
+ "positive_4": ("CONDITIONING", ),
+ "negative_4": ("CONDITIONING", ),
+ "positive_5": ("CONDITIONING", ),
+ "negative_5": ("CONDITIONING", ),
+ "mask_1": ("MASK", ),
+ "mask_2": ("MASK", ),
+ "mask_3": ("MASK", ),
+ "mask_4": ("MASK", ),
+ "mask_5": ("MASK", ),
+ "mask_1_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_2_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_3_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_4_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "mask_5_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "set_cond_area": (["default", "mask bounds"],),
+ }
+ }
+
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
+ RETURN_NAMES = ("combined_positive", "combined_negative",)
+ FUNCTION = "append"
+ CATEGORY = "KJNodes/masking/conditioning"
+ DESCRIPTION = """
+ Bundles multiple conditioning mask and combine nodes into one; functionality is identical to the ComfyUI native nodes
+"""
+
+ def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
+ c = []
+ c2 = []
+ set_area_to_bounds = False
+ if set_cond_area != "default":
+ set_area_to_bounds = True
+ if len(mask_1.shape) < 3:
+ mask_1 = mask_1.unsqueeze(0)
+ if len(mask_2.shape) < 3:
+ mask_2 = mask_2.unsqueeze(0)
+ if len(mask_3.shape) < 3:
+ mask_3 = mask_3.unsqueeze(0)
+ if len(mask_4.shape) < 3:
+ mask_4 = mask_4.unsqueeze(0)
+ if len(mask_5.shape) < 3:
+ mask_5 = mask_5.unsqueeze(0)
+ for t in positive_1:
+ append_helper(t, mask_1, c, set_area_to_bounds, mask_1_strength)
+ for t in positive_2:
+ append_helper(t, mask_2, c, set_area_to_bounds, mask_2_strength)
+ for t in positive_3:
+ append_helper(t, mask_3, c, set_area_to_bounds, mask_3_strength)
+ for t in positive_4:
+ append_helper(t, mask_4, c, set_area_to_bounds, mask_4_strength)
+ for t in positive_5:
+ append_helper(t, mask_5, c, set_area_to_bounds, mask_5_strength)
+ for t in negative_1:
+ append_helper(t, mask_1, c2, set_area_to_bounds, mask_1_strength)
+ for t in negative_2:
+ append_helper(t, mask_2, c2, set_area_to_bounds, mask_2_strength)
+ for t in negative_3:
+ append_helper(t, mask_3, c2, set_area_to_bounds, mask_3_strength)
+ for t in negative_4:
+ append_helper(t, mask_4, c2, set_area_to_bounds, mask_4_strength)
+ for t in negative_5:
+ append_helper(t, mask_5, c2, set_area_to_bounds, mask_5_strength)
+ return (c, c2)
+
+class VRAM_Debug:
+
+ @classmethod
+
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+
+ "empty_cache": ("BOOLEAN", {"default": True}),
+ "gc_collect": ("BOOLEAN", {"default": True}),
+ "unload_all_models": ("BOOLEAN", {"default": False}),
+ },
+ "optional": {
+ "any_input": (any, {}),
+ "image_pass": ("IMAGE",),
+ "model_pass": ("MODEL",),
+ }
+ }
+
+ RETURN_TYPES = (any, "IMAGE","MODEL","INT", "INT",)
+ RETURN_NAMES = ("any_output", "image_pass", "model_pass", "freemem_before", "freemem_after")
+ FUNCTION = "VRAMdebug"
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = """
+ Returns the inputs unchanged; they are only used as triggers.
+ Performs ComfyUI model management functions and garbage collection,
+ and reports free VRAM before and after the operations.
+"""
+
+ def VRAMdebug(self, gc_collect, empty_cache, unload_all_models, image_pass=None, model_pass=None, any_input=None):
+ freemem_before = model_management.get_free_memory()
+ print("VRAMdebug: free memory before: ", f"{freemem_before:,.0f}")
+ if empty_cache:
+ model_management.soft_empty_cache()
+ if unload_all_models:
+ model_management.unload_all_models()
+ if gc_collect:
+ import gc
+ gc.collect()
+ freemem_after = model_management.get_free_memory()
+ print("VRAMdebug: free memory after: ", f"{freemem_after:,.0f}")
+ print("VRAMdebug: freed memory: ", f"{freemem_after - freemem_before:,.0f}")
+ return {"ui": {
+ "text": [f"{freemem_before:,.0f}x{freemem_after:,.0f}"]},
+ "result": (any_input, image_pass, model_pass, freemem_before, freemem_after)
+ }
+
+class SomethingToString:
+ @classmethod
+
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input": (any, {}),
+ },
+ "optional": {
+ "prefix": ("STRING", {"default": ""}),
+ "suffix": ("STRING", {"default": ""}),
+ }
+ }
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "stringify"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+Converts any type to a string.
+"""
+
+ def stringify(self, input, prefix="", suffix=""):
+ if isinstance(input, (int, float, bool)):
+ stringified = str(input)
+ elif isinstance(input, list):
+ stringified = ', '.join(str(item) for item in input)
+ else:
+ stringified = str(input)
+ if prefix: # Check if prefix is not empty
+ stringified = prefix + stringified # Add the prefix
+ if suffix: # Check if suffix is not empty
+ stringified = stringified + suffix # Add the suffix
+
+ return (stringified,)
+
+class Sleep:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "input": (any, {}),
+ "minutes": ("INT", {"default": 0, "min": 0, "max": 1439}),
+ "seconds": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 59.99, "step": 0.01}),
+ },
+ }
+ RETURN_TYPES = (any,)
+ FUNCTION = "sleepdelay"
+ CATEGORY = "KJNodes/misc"
+ DESCRIPTION = """
+Delays the execution for the input amount of time.
+"""
+
+ def sleepdelay(self, input, minutes, seconds):
+ total_seconds = minutes * 60 + seconds
+ time.sleep(total_seconds)
+ return input,
+
+class EmptyLatentImagePresets:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "dimensions": (
+ [
+ '512 x 512 (1:1)',
+ '768 x 512 (1.5:1)',
+ '960 x 512 (1.875:1)',
+ '1024 x 512 (2:1)',
+ '1024 x 576 (1.778:1)',
+ '1536 x 640 (2.4:1)',
+ '1344 x 768 (1.75:1)',
+ '1216 x 832 (1.46:1)',
+ '1152 x 896 (1.286:1)',
+ '1024 x 1024 (1:1)',
+ ],
+ {
+ "default": '512 x 512 (1:1)'
+ }),
+
+ "invert": ("BOOLEAN", {"default": False}),
+ "batch_size": ("INT", {
+ "default": 1,
+ "min": 1,
+ "max": 4096
+ }),
+ },
+ }
+
+ RETURN_TYPES = ("LATENT", "INT", "INT")
+ RETURN_NAMES = ("Latent", "Width", "Height")
+ FUNCTION = "generate"
+ CATEGORY = "KJNodes/latents"
+
+ def generate(self, dimensions, invert, batch_size):
+ from nodes import EmptyLatentImage
+ result = [x.strip() for x in dimensions.split('x')]
+
+ # Remove the aspect ratio part
+ result[0] = result[0].split('(')[0].strip()
+ result[1] = result[1].split('(')[0].strip()
+
+ if invert:
+ width = int(result[1].split(' ')[0])
+ height = int(result[0])
+ else:
+ width = int(result[0])
+ height = int(result[1].split(' ')[0])
+ latent = EmptyLatentImage().generate(width, height, batch_size)[0]
+
+ return (latent, int(width), int(height),)
+
+class EmptyLatentImageCustomPresets:
+ @classmethod
+ def INPUT_TYPES(cls):
+ try:
+ with open(os.path.join(script_directory, 'custom_dimensions.json')) as f:
+ dimensions_dict = json.load(f)
+ except FileNotFoundError:
+ dimensions_dict = []
+ return {
+ "required": {
+ "dimensions": (
+ [f"{d['label']} - {d['value']}" for d in dimensions_dict],
+ ),
+
+ "invert": ("BOOLEAN", {"default": False}),
+ "batch_size": ("INT", {
+ "default": 1,
+ "min": 1,
+ "max": 4096
+ }),
+ },
+ }
+
+ RETURN_TYPES = ("LATENT", "INT", "INT")
+ RETURN_NAMES = ("Latent", "Width", "Height")
+ FUNCTION = "generate"
+ CATEGORY = "KJNodes/latents"
+ DESCRIPTION = """
+Generates an empty latent image with the specified dimensions.
+ The choices are loaded from 'custom_dimensions.json' in the node pack's root folder.
+"""
+
+ def generate(self, dimensions, invert, batch_size):
+ from nodes import EmptyLatentImage
+ # Split the string into label and value
+ label, value = dimensions.split(' - ')
+ # Split the value into width and height
+ width, height = [x.strip() for x in value.split('x')]
+
+ if invert:
+ width, height = height, width
+
+ latent = EmptyLatentImage().generate(int(width), int(height), batch_size)[0]
+
+ return (latent, int(width), int(height),)
+
+class WidgetToString:
+ @classmethod
+ def IS_CHANGED(cls, **kwargs):
+ return float("NaN")
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "id": ("INT", {"default": 0}),
+ "widget_name": ("STRING", {"multiline": False}),
+ "return_all": ("BOOLEAN", {"default": False}),
+ },
+ "optional": {
+ "any_input": (any, {}),
+ "node_title": ("STRING", {"multiline": False}),
+ "allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
+
+ },
+ "hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
+ "prompt": "PROMPT",
+ "unique_id": "UNIQUE_ID",},
+ }
+
+ RETURN_TYPES = ("STRING", )
+ FUNCTION = "get_widget_value"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+ Selects a node and its specified widget and outputs the value as a string.
+ If no node id or title is provided it will use the 'any_input' link and use that node.
+ To see node IDs, enable node ID display from the Manager badge menu.
+ Alternatively you can search by the node title. Node titles ONLY exist if they
+ are manually edited!
+ The 'any_input' is required to make sure the node you want the value from exists in the workflow.
+"""
+
+ def get_widget_value(self, id, widget_name, extra_pnginfo, prompt, unique_id, return_all=False, any_input=None, node_title="", allowed_float_decimals=2):
+ workflow = extra_pnginfo["workflow"]
+ #print(json.dumps(workflow, indent=4))
+ results = []
+ node_id = None # Initialize node_id to handle cases where no match is found
+ link_id = None
+ link_to_node_map = {}
+
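+ # Resolve the target node: by title if given, by explicit id, or by following
+ # the 'any_input' link back to whichever node feeds this WidgetToString instance.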
+ for node in workflow["nodes"]:
+ if node_title:
+ if "title" in node:
+ if node["title"] == node_title:
+ node_id = node["id"]
+ break
+ else:
+ print("Node title not found.")
+ elif id != 0:
+ if node["id"] == id:
+ node_id = id
+ break
+ elif any_input is not None:
+ if node["type"] == "WidgetToString" and node["id"] == int(unique_id) and not link_id:
+ for node_input in node["inputs"]:
+ if node_input["name"] == "any_input":
+ link_id = node_input["link"]
+
+ # Construct a map of links to node IDs for future reference
+ node_outputs = node.get("outputs", None)
+ if not node_outputs:
+ continue
+ for output in node_outputs:
+ node_links = output.get("links", None)
+ if not node_links:
+ continue
+ for link in node_links:
+ link_to_node_map[link] = node["id"]
+ if link_id and link == link_id:
+ break
+
+ if link_id:
+ node_id = link_to_node_map.get(link_id, None)
+
+ if node_id is None:
+ raise ValueError("No matching node found for the given title or id")
+
+ values = prompt[str(node_id)]
+ if "inputs" in values:
+ if return_all:
+ # Format items based on type
+ formatted_items = []
+ for k, v in values["inputs"].items():
+ if isinstance(v, float):
+ item = f"{k}: {v:.{allowed_float_decimals}f}"
+ else:
+ item = f"{k}: {str(v)}"
+ formatted_items.append(item)
+ results.append(', '.join(formatted_items))
+ elif widget_name in values["inputs"]:
+ v = values["inputs"][widget_name]
+ if isinstance(v, float):
+ v = f"{v:.{allowed_float_decimals}f}"
+ else:
+ v = str(v)
+ return (v, )
+ else:
+ raise NameError(f"Widget not found: {node_id}.{widget_name}")
+ return (', '.join(results).strip(', '), )
+
+class DummyOut:
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "any_input": (any, {}),
+ }
+ }
+
+ RETURN_TYPES = (any,)
+ FUNCTION = "dummy"
+ CATEGORY = "KJNodes/misc"
+ OUTPUT_NODE = True
+ DESCRIPTION = """
+Does nothing, used to trigger generic workflow output.
+A way to get previews in the UI without saving anything to disk.
+"""
+
+ def dummy(self, any_input):
+ return (any_input,)
+
+class FlipSigmasAdjusted:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"sigmas": ("SIGMAS", ),
+ "divide_by_last_sigma": ("BOOLEAN", {"default": False}),
+ "divide_by": ("FLOAT", {"default": 1,"min": 1, "max": 255, "step": 0.01}),
+ "offset_by": ("INT", {"default": 1,"min": -100, "max": 100, "step": 1}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS", "STRING",)
+ RETURN_NAMES = ("SIGMAS", "sigmas_string",)
+ CATEGORY = "KJNodes/noise"
+ FUNCTION = "get_sigmas_adjusted"
+
+ def get_sigmas_adjusted(self, sigmas, divide_by_last_sigma, divide_by, offset_by):
+
+ sigmas = sigmas.flip(0)
+ if sigmas[0] == 0:
+ sigmas[0] = 0.0001
+ adjusted_sigmas = sigmas.clone()
+ #offset sigma
+ for i in range(1, len(sigmas)):
+ offset_index = i - offset_by
+ if 0 <= offset_index < len(sigmas):
+ adjusted_sigmas[i] = sigmas[offset_index]
+ else:
+ adjusted_sigmas[i] = 0.0001
+ if adjusted_sigmas[0] == 0:
+ adjusted_sigmas[0] = 0.0001
+ if divide_by_last_sigma:
+ adjusted_sigmas = adjusted_sigmas / adjusted_sigmas[-1]
+
+ sigma_np_array = adjusted_sigmas.numpy()
+ array_string = np.array2string(sigma_np_array, precision=2, separator=', ', threshold=np.inf)
+ adjusted_sigmas = adjusted_sigmas / divide_by
+ return (adjusted_sigmas, array_string,)
+
+class CustomSigmas:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "sigmas_string" :("STRING", {"default": "14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029","multiline": True}),
+ "interpolate_to_steps": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS",)
+ RETURN_NAMES = ("SIGMAS",)
+ CATEGORY = "KJNodes/noise"
+ FUNCTION = "customsigmas"
+ DESCRIPTION = """
+Creates a sigmas tensor from a string of comma separated values.
+Examples:
+
+Nvidia's optimized AYS 10 step schedule for SD 1.5:
+14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029
+SDXL:
+14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029
+SVD:
+700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002
+"""
+ def customsigmas(self, sigmas_string, interpolate_to_steps):
+ sigmas_list = sigmas_string.split(', ')
+ sigmas_float_list = [float(sigma) for sigma in sigmas_list]
+ sigmas_tensor = torch.FloatTensor(sigmas_float_list)
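+ # Samplers expect steps + 1 sigma values ending in 0, so log-linearly interpolate
+ # whenever the parsed list does not already have that length.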
+ if len(sigmas_tensor) != interpolate_to_steps + 1:
+ sigmas_tensor = self.loglinear_interp(sigmas_tensor, interpolate_to_steps + 1)
+ sigmas_tensor[-1] = 0
+ return (sigmas_tensor.float(),)
+
+ def loglinear_interp(self, t_steps, num_steps):
+ """
+ Performs log-linear interpolation of a given array of decreasing numbers.
+ """
+ t_steps_np = t_steps.numpy()
+
+ xs = np.linspace(0, 1, len(t_steps_np))
+ ys = np.log(t_steps_np[::-1])
+
+ new_xs = np.linspace(0, 1, num_steps)
+ new_ys = np.interp(new_xs, xs, ys)
+
+ interped_ys = np.exp(new_ys)[::-1].copy()
+ interped_ys_tensor = torch.tensor(interped_ys)
+ return interped_ys_tensor
+
+class StringToFloatList:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "string" :("STRING", {"default": "1, 2, 3", "multiline": True}),
+ }
+ }
+ RETURN_TYPES = ("FLOAT",)
+ RETURN_NAMES = ("FLOAT",)
+ CATEGORY = "KJNodes/misc"
+ FUNCTION = "createlist"
+
+ def createlist(self, string):
+ float_list = [float(x.strip()) for x in string.split(',')]
+ return (float_list,)
+
+
+class InjectNoiseToLatent:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "latents":("LATENT",),
+ "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.0001}),
+ "noise": ("LATENT",),
+ "normalize": ("BOOLEAN", {"default": False}),
+ "average": ("BOOLEAN", {"default": False}),
+ },
+ "optional":{
+ "mask": ("MASK", ),
+ "mix_randn_amount": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.001}),
+ "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
+ }
+ }
+
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "injectnoise"
+ CATEGORY = "KJNodes/noise"
+
+ def injectnoise(self, latents, strength, noise, normalize, average, mix_randn_amount=0, seed=None, mask=None):
+ samples = latents["samples"].clone().cpu()
+ noise = noise["samples"].clone().cpu()
+ if samples.shape != noise.shape:
+ raise ValueError("InjectNoiseToLatent: Latent and noise must have the same shape")
+ if average:
+ noised = (samples + noise) / 2
+ else:
+ noised = samples + noise * strength
+ if normalize:
+ noised = noised / noised.std()
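+ # If a mask is provided, resize it to the latent resolution and blend:
+ # injected noise where the mask is 1, the original latents where it is 0.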
+ if mask is not None:
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(noised.shape[2], noised.shape[3]), mode="bilinear")
+ mask = mask.expand((-1,noised.shape[1],-1,-1))
+ if mask.shape[0] < noised.shape[0]:
+ mask = mask.repeat((noised.shape[0] -1) // mask.shape[0] + 1, 1, 1, 1)[:noised.shape[0]]
+ noised = mask * noised + (1-mask) * samples
+ if mix_randn_amount > 0:
+ if seed is not None:
+ generator = torch.manual_seed(seed)
+ rand_noise = torch.randn(noised.size(), dtype=noised.dtype, layout=noised.layout, generator=generator, device="cpu")
+ noised = noised + (mix_randn_amount * rand_noise)
+
+ return ({"samples":noised},)
+
+class SoundReactive:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "sound_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}),
+ "start_range_hz": ("INT", {"default": 150, "min": 0, "max": 9999, "step": 1}),
+ "end_range_hz": ("INT", {"default": 2000, "min": 0, "max": 9999, "step": 1}),
+ "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 99999, "step": 0.01}),
+ "smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+ "normalize": ("BOOLEAN", {"default": False}),
+ },
+ }
+
+ RETURN_TYPES = ("FLOAT","INT",)
+ RETURN_NAMES =("sound_level", "sound_level_int",)
+ FUNCTION = "react"
+ CATEGORY = "KJNodes/audio"
+ DESCRIPTION = """
+Reacts to the sound level of the input.
+ Uses your browser's sound input options.
+Meant to be used with realtime diffusion with autoqueue.
+"""
+
+ def react(self, sound_level, start_range_hz, end_range_hz, smoothing_factor, multiplier, normalize):
+
+ sound_level *= multiplier
+
+ if normalize:
+ sound_level /= 255
+
+ sound_level_int = int(sound_level)
+ return (sound_level, sound_level_int, )
+
+class GenerateNoise:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
+ "multiplier": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 4096, "step": 0.01}),
+ "constant_batch_noise": ("BOOLEAN", {"default": False}),
+ "normalize": ("BOOLEAN", {"default": False}),
+ },
+ "optional": {
+ "model": ("MODEL", ),
+ "sigmas": ("SIGMAS", ),
+ "latent_channels": (['4', '16', ],),
+ "shape": (["BCHW", "BCTHW","BTCHW",],),
+ }
+ }
+
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "generatenoise"
+ CATEGORY = "KJNodes/noise"
+ DESCRIPTION = """
+Generates noise for injection or to be used as empty latents on samplers with add_noise off.
+"""
+
+ def generatenoise(self, batch_size, width, height, seed, multiplier, constant_batch_noise, normalize, sigmas=None, model=None, latent_channels=4, shape="BCHW"):
+
+ generator = torch.manual_seed(seed)
+ if shape == "BCHW":
+ noise = torch.randn([batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
+ elif shape == "BCTHW":
+ noise = torch.randn([1, int(latent_channels), batch_size,height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
+ elif shape == "BTCHW":
+ noise = torch.randn([1, batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
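+ # Optionally scale the noise by the schedule's sigma range (converted to latent
+ # space via the model's latent_format scale factor) so it matches what a sampler would add.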
+ if sigmas is not None:
+ sigma = sigmas[0] - sigmas[-1]
+ sigma /= model.model.latent_format.scale_factor
+ noise *= sigma
+
+ noise *= multiplier
+
+ if normalize:
+ noise = noise / noise.std()
+ if constant_batch_noise:
+ noise = noise[0].repeat(batch_size, 1, 1, 1)
+
+
+ return ({"samples":noise}, )
+
+def camera_embeddings(elevation, azimuth):
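+ # Builds the 4-channel Zero123-style camera conditioning:
+ # [deg2rad((90 - elevation) - 90), sin(azimuth), cos(azimuth), deg2rad(90)], shape [1, 1, 4].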
+ elevation = torch.as_tensor([elevation])
+ azimuth = torch.as_tensor([azimuth])
+ embeddings = torch.stack(
+ [
+ torch.deg2rad(
+ (90 - elevation) - (90)
+ ), # Zero123 polar is 90-elevation
+ torch.sin(torch.deg2rad(azimuth)),
+ torch.cos(torch.deg2rad(azimuth)),
+ torch.deg2rad(
+ 90 - torch.full_like(elevation, 0)
+ ),
+ ], dim=-1).unsqueeze(1)
+
+ return embeddings
+
+def interpolate_angle(start, end, fraction):
+ # Calculate the difference in angles and adjust for wraparound if necessary
+ diff = (end - start + 540) % 360 - 180
+ # Apply fraction to the difference
+ interpolated = start + fraction * diff
+ # Normalize the result to be within the range of -180 to 180
+ return (interpolated + 180) % 360 - 180
+
+
+class StableZero123_BatchSchedule:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "clip_vision": ("CLIP_VISION",),
+ "init_image": ("IMAGE",),
+ "vae": ("VAE",),
+ "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
+ "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
+ "elevation_points_string": ("STRING", {"default": "0:(0.0),\n7:(0.0),\n15:(0.0)\n", "multiline": True}),
+ }}
+
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+ RETURN_NAMES = ("positive", "negative", "latent")
+ FUNCTION = "encode"
+ CATEGORY = "KJNodes/experimental"
+
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation):
+ output = clip_vision.encode_image(init_image)
+ pooled = output.image_embeds.unsqueeze(0)
+ pixels = common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+ encode_pixels = pixels[:,:,:,:3]
+ t = vae.encode(encode_pixels)
+
+ def ease_in(t):
+ return t * t
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+
+ # Parse the azimuth input string into a list of tuples
+ azimuth_points = []
+ azimuth_points_string = azimuth_points_string.rstrip(',\n')
+ for point_str in azimuth_points_string.split(','):
+ frame_str, azimuth_str = point_str.split(':')
+ frame = int(frame_str.strip())
+ azimuth = float(azimuth_str.strip()[1:-1])
+ azimuth_points.append((frame, azimuth))
+ # Sort the points by frame number
+ azimuth_points.sort(key=lambda x: x[0])
+
+ # Parse the elevation input string into a list of tuples
+ elevation_points = []
+ elevation_points_string = elevation_points_string.rstrip(',\n')
+ for point_str in elevation_points_string.split(','):
+ frame_str, elevation_str = point_str.split(':')
+ frame = int(frame_str.strip())
+ elevation_val = float(elevation_str.strip()[1:-1])
+ elevation_points.append((frame, elevation_val))
+ # Sort the points by frame number
+ elevation_points.sort(key=lambda x: x[0])
+
+ # Index of the next point to interpolate towards
+ next_point = 1
+ next_elevation_point = 1
+
+ positive_cond_out = []
+ positive_pooled_out = []
+ negative_cond_out = []
+ negative_pooled_out = []
+
+ #azimuth interpolation
+ for i in range(batch_size):
+ # Find the interpolated azimuth for the current frame
+ while next_point < len(azimuth_points) and i >= azimuth_points[next_point][0]:
+ next_point += 1
+ # If next_point is equal to the length of points, we've gone past the last point
+ if next_point == len(azimuth_points):
+ next_point -= 1 # Set next_point to the last index of points
+ prev_point = max(next_point - 1, 0) # Ensure prev_point is not less than 0
+
+ # Calculate fraction
+ if azimuth_points[next_point][0] != azimuth_points[prev_point][0]: # Prevent division by zero
+ fraction = (i - azimuth_points[prev_point][0]) / (azimuth_points[next_point][0] - azimuth_points[prev_point][0])
+ if interpolation == "ease_in":
+ fraction = ease_in(fraction)
+ elif interpolation == "ease_out":
+ fraction = ease_out(fraction)
+ elif interpolation == "ease_in_out":
+ fraction = ease_in_out(fraction)
+
+ # Use the new interpolate_angle function
+ interpolated_azimuth = interpolate_angle(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction)
+ else:
+ interpolated_azimuth = azimuth_points[prev_point][1]
+ # Interpolate the elevation
+ next_elevation_point = 1
+ while next_elevation_point < len(elevation_points) and i >= elevation_points[next_elevation_point][0]:
+ next_elevation_point += 1
+ if next_elevation_point == len(elevation_points):
+ next_elevation_point -= 1
+ prev_elevation_point = max(next_elevation_point - 1, 0)
+
+ if elevation_points[next_elevation_point][0] != elevation_points[prev_elevation_point][0]:
+ fraction = (i - elevation_points[prev_elevation_point][0]) / (elevation_points[next_elevation_point][0] - elevation_points[prev_elevation_point][0])
+ if interpolation == "ease_in":
+ fraction = ease_in(fraction)
+ elif interpolation == "ease_out":
+ fraction = ease_out(fraction)
+ elif interpolation == "ease_in_out":
+ fraction = ease_in_out(fraction)
+
+ interpolated_elevation = interpolate_angle(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction)
+ else:
+ interpolated_elevation = elevation_points[prev_elevation_point][1]
+
+ cam_embeds = camera_embeddings(interpolated_elevation, interpolated_azimuth)
+ cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
+
+ positive_pooled_out.append(t)
+ positive_cond_out.append(cond)
+ negative_pooled_out.append(torch.zeros_like(t))
+ negative_cond_out.append(torch.zeros_like(pooled))
+
+ # Concatenate the conditions and pooled outputs
+ final_positive_cond = torch.cat(positive_cond_out, dim=0)
+ final_positive_pooled = torch.cat(positive_pooled_out, dim=0)
+ final_negative_cond = torch.cat(negative_cond_out, dim=0)
+ final_negative_pooled = torch.cat(negative_pooled_out, dim=0)
+
+ # Structure the final output
+ final_positive = [[final_positive_cond, {"concat_latent_image": final_positive_pooled}]]
+ final_negative = [[final_negative_cond, {"concat_latent_image": final_negative_pooled}]]
+
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+ return (final_positive, final_negative, {"samples": latent})
+
+def linear_interpolate(start, end, fraction):
+ return start + (end - start) * fraction
+
+class SV3D_BatchSchedule:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "clip_vision": ("CLIP_VISION",),
+ "init_image": ("IMAGE",),
+ "vae": ("VAE",),
+ "width": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "batch_size": ("INT", {"default": 21, "min": 1, "max": 4096}),
+ "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
+ "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n9:(180.0),\n20:(360.0)\n", "multiline": True}),
+ "elevation_points_string": ("STRING", {"default": "0:(0.0),\n9:(0.0),\n20:(0.0)\n", "multiline": True}),
+ }}
+
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+ RETURN_NAMES = ("positive", "negative", "latent")
+ FUNCTION = "encode"
+ CATEGORY = "KJNodes/experimental"
+ DESCRIPTION = """
+ Allows scheduling of the azimuth and elevation conditions for SV3D.
+ Note that SV3D is still a video model and the schedule always needs to go forward.
+https://huggingface.co/stabilityai/sv3d
+"""
+
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation):
+ output = clip_vision.encode_image(init_image)
+ pooled = output.image_embeds.unsqueeze(0)
+ pixels = common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+ encode_pixels = pixels[:,:,:,:3]
+ t = vae.encode(encode_pixels)
+
+ def ease_in(t):
+ return t * t
+ def ease_out(t):
+ return 1 - (1 - t) * (1 - t)
+ def ease_in_out(t):
+ return 3 * t * t - 2 * t * t * t
+
+ # Parse the azimuth input string into a list of tuples
+ azimuth_points = []
+ azimuth_points_string = azimuth_points_string.rstrip(',\n')
+ for point_str in azimuth_points_string.split(','):
+ frame_str, azimuth_str = point_str.split(':')
+ frame = int(frame_str.strip())
+ azimuth = float(azimuth_str.strip()[1:-1])
+ azimuth_points.append((frame, azimuth))
+ # Sort the points by frame number
+ azimuth_points.sort(key=lambda x: x[0])
+
+ # Parse the elevation input string into a list of tuples
+ elevation_points = []
+ elevation_points_string = elevation_points_string.rstrip(',\n')
+ for point_str in elevation_points_string.split(','):
+ frame_str, elevation_str = point_str.split(':')
+ frame = int(frame_str.strip())
+ elevation_val = float(elevation_str.strip()[1:-1])
+ elevation_points.append((frame, elevation_val))
+ # Sort the points by frame number
+ elevation_points.sort(key=lambda x: x[0])
+
+ # Index of the next point to interpolate towards
+ next_point = 1
+ next_elevation_point = 1
+ elevations = []
+ azimuths = []
+ # For azimuth interpolation
+ for i in range(batch_size):
+ # Find the interpolated azimuth for the current frame
+ while next_point < len(azimuth_points) and i >= azimuth_points[next_point][0]:
+ next_point += 1
+ if next_point == len(azimuth_points):
+ next_point -= 1
+ prev_point = max(next_point - 1, 0)
+
+ if azimuth_points[next_point][0] != azimuth_points[prev_point][0]:
+ fraction = (i - azimuth_points[prev_point][0]) / (azimuth_points[next_point][0] - azimuth_points[prev_point][0])
+ # Apply the ease function to the fraction
+ if interpolation == "ease_in":
+ fraction = ease_in(fraction)
+ elif interpolation == "ease_out":
+ fraction = ease_out(fraction)
+ elif interpolation == "ease_in_out":
+ fraction = ease_in_out(fraction)
+
+ interpolated_azimuth = linear_interpolate(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction)
+ else:
+ interpolated_azimuth = azimuth_points[prev_point][1]
+
+ # Interpolate the elevation
+ next_elevation_point = 1
+ while next_elevation_point < len(elevation_points) and i >= elevation_points[next_elevation_point][0]:
+ next_elevation_point += 1
+ if next_elevation_point == len(elevation_points):
+ next_elevation_point -= 1
+ prev_elevation_point = max(next_elevation_point - 1, 0)
+
+ if elevation_points[next_elevation_point][0] != elevation_points[prev_elevation_point][0]:
+ fraction = (i - elevation_points[prev_elevation_point][0]) / (elevation_points[next_elevation_point][0] - elevation_points[prev_elevation_point][0])
+ # Apply the ease function to the fraction
+ if interpolation == "ease_in":
+ fraction = ease_in(fraction)
+ elif interpolation == "ease_out":
+ fraction = ease_out(fraction)
+ elif interpolation == "ease_in_out":
+ fraction = ease_in_out(fraction)
+
+ interpolated_elevation = linear_interpolate(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction)
+ else:
+ interpolated_elevation = elevation_points[prev_elevation_point][1]
+
+ azimuths.append(interpolated_azimuth)
+ elevations.append(interpolated_elevation)
+
+ #print("azimuths", azimuths)
+ #print("elevations", elevations)
+
+ # Structure the final output
+ final_positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
+ final_negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t),"elevation": elevations, "azimuth": azimuths}]]
+
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+ return (final_positive, final_negative, {"samples": latent})
+
+class LoadResAdapterNormalization:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "resadapter_path": (folder_paths.get_filename_list("checkpoints"), )
+ }
+ }
+
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "load_res_adapter"
+ CATEGORY = "KJNodes/experimental"
+
+ def load_res_adapter(self, model, resadapter_path):
+ print("ResAdapter: Checking ResAdapter path")
+ resadapter_full_path = folder_paths.get_full_path("checkpoints", resadapter_path)
+ if not os.path.exists(resadapter_full_path):
+ raise Exception("Invalid model path")
+ else:
+ print("ResAdapter: Loading ResAdapter normalization weights")
+ from comfy.utils import load_torch_file
+ prefix_to_remove = 'diffusion_model.'
+ model_clone = model.clone()
+ norm_state_dict = load_torch_file(resadapter_full_path)
+ new_values = {key[len(prefix_to_remove):]: value for key, value in norm_state_dict.items() if key.startswith(prefix_to_remove)}
+ print("ResAdapter: Attempting to add patches with ResAdapter weights")
+ try:
+ for key in model.model.diffusion_model.state_dict().keys():
+ if key in new_values:
+ original_tensor = model.model.diffusion_model.state_dict()[key]
+ new_tensor = new_values[key].to(model.model.diffusion_model.dtype)
+ if original_tensor.shape == new_tensor.shape:
+ model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
+ else:
+ print("ResAdapter: No match for key: ",key)
+ except Exception:
+ raise Exception("Could not patch model, this way of patching was added to ComfyUI on March 3rd 2024, is your ComfyUI up to date?")
+ print("ResAdapter: Added resnet normalization patches")
+ return (model_clone, )
+
+class Superprompt:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "instruction_prompt": ("STRING", {"default": 'Expand the following prompt to add more detail', "multiline": True}),
+ "prompt": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
+ "max_new_tokens": ("INT", {"default": 128, "min": 1, "max": 4096, "step": 1}),
+ }
+ }
+
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "process"
+ CATEGORY = "KJNodes/text"
+ DESCRIPTION = """
+# SuperPrompt
+A T5 model fine-tuned on the SuperPrompt dataset for
+upsampling text prompts to more detailed descriptions.
+Meant to be used as a pre-generation step for text-to-image
+models that benefit from more detailed prompts.
+https://huggingface.co/roborovski/superprompt-v1
+"""
+
+ def process(self, instruction_prompt, prompt, max_new_tokens):
+ device = model_management.get_torch_device()
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ checkpoint_path = os.path.join(script_directory, "models","superprompt-v1")
+ if not os.path.exists(checkpoint_path):
+ print(f"Downloading model to: {checkpoint_path}")
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id="roborovski/superprompt-v1",
+ local_dir=checkpoint_path,
+ local_dir_use_symlinks=False)
+ tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False)
+
+ model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, device_map=device)
+ model.to(device)
+ input_text = instruction_prompt + ": " + prompt
+
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
+ outputs = model.generate(input_ids, max_new_tokens=max_new_tokens)
+ out = (tokenizer.decode(outputs[0]))
+ out = out.replace('<pad>', '')
+ out = out.replace('