import numpy as np
import torch
import comfy
import folder_paths
import nodes
import os
import math
import re
import safetensors
import glob
from collections import namedtuple

def automerge(tensor, threshold):
    # Collapse runs of consecutive token embeddings whose cosine similarity to the
    # running group mean reaches the threshold; dissimilar tokens start a new group.
    (batchsize, slices, dim) = tensor.shape
    newTensor = []
    for batch in range(batchsize):
        tokens = []
        lastEmbed = tensor[batch, 0, :]
        merge = [lastEmbed]
        tokens.append(lastEmbed)
        for i in range(1, slices):
            tok = tensor[batch, i, :]
            cosine = torch.dot(tok, lastEmbed) / torch.sqrt(torch.dot(tok, tok) * torch.dot(lastEmbed, lastEmbed))
            if cosine >= threshold:
                merge.append(tok)
                lastEmbed = torch.stack(merge).mean(dim=0)
            else:
                tokens.append(lastEmbed)
                merge = []
                lastEmbed = tok
        newTensor.append(torch.stack(tokens))
    return torch.stack(newTensor)
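# Illustrative sketch (automerge is not called elsewhere in this file): it expects a
# (batch, tokens, dim) embedding tensor; with a hypothetical random input:
#   example = torch.randn(1, 8, 16)             # 1 batch, 8 tokens, 16 dims
#   merged = automerge(example, threshold=0.9)  # -> (1, n_kept_tokens, 16)
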
# Preset strengths for the simple node; each maps to a downsampling factor below.
STRENGTHS = ["highest", "high", "medium", "low", "lowest"]
STRENGTHS_VALUES = [1, 2, 3, 4, 5]

class StyleModelApplySimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "image_strength": (STRENGTHS, {"default": "medium"})
                             }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, image_strength):
        stren = STRENGTHS.index(image_strength)
        downsampling_factor = STRENGTHS_VALUES[stren]
        mode = "area" if downsampling_factor == 3 else "bicubic"
        # Image tokens produced by the style model from the CLIP vision output.
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        if downsampling_factor > 1:
            # Reinterpret the token sequence as a square grid and downsample it.
            (b, t, h) = cond.shape
            m = int(np.sqrt(t))
            cond = torch.nn.functional.interpolate(cond.view(b, m, m, h).transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode=mode)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
        c = []
        for t in conditioning:
            # Append the image tokens to each text-conditioning entry.
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, )

def standardizeMask(mask):
    # Bring a mask into (batch, 1, height, width) layout.
    if mask is None:
        return None
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    return mask
def crop(img, mask, box, desiredSize):
    # Rescale the image (and mask) to (w, h), then cut a desiredSize x desiredSize window at (ox, oy).
    (ox, oy, w, h) = box
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic").view(-1, h, w, 1)
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True)
    return (img[:, :, ox:(desiredSize + ox), oy:(desiredSize + oy)].transpose(1, -1),
            None if mask is None else mask[:, oy:(desiredSize + oy), ox:(desiredSize + ox), :])
def letterbox(img, mask, w, h, desiredSize):
    # Rescale to (w, h) and paste centered onto a black desiredSize x desiredSize canvas.
    (b, oh, ow, c) = img.shape
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True).transpose(1, -1)
    canvas = torch.zeros(size=(b, desiredSize, desiredSize, c))
    offsetx = (desiredSize - w) // 2
    offsety = (desiredSize - h) // 2
    canvas[:, offsety:(offsety + h), offsetx:(offsetx + w), :] += img
    img = canvas
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic")
        canvas = torch.zeros(size=(b, 1, desiredSize, desiredSize))
        canvas[:, :, offsety:(offsety + h), offsetx:(offsetx + w)] += mask
        mask = canvas.view(b, 1, desiredSize, desiredSize)
    return (img, mask)
def getBoundingBox(mask, w, h, relativeMargin, desiredSize):
    # Bounding box of the nonzero mask region, padded by a relative margin and grown
    # to at least desiredSize per side while staying inside the image.
    mask = mask.view(h, w)
    marginW = math.ceil(relativeMargin * w)
    marginH = math.ceil(relativeMargin * h)
    indices = torch.nonzero(mask, as_tuple=False)
    y_min, x_min = indices.min(dim=0).values
    y_max, x_max = indices.max(dim=0).values
    x_min = max(0, x_min.item() - marginW)
    y_min = max(0, y_min.item() - marginH)
    x_max = min(w, x_max.item() + marginW)
    y_max = min(h, y_max.item() + marginH)
    box_width = x_max - x_min
    box_height = y_max - y_min
    larger_edge = max(box_width, box_height, desiredSize)
    if box_width < larger_edge:
        # Widen the box, preferring symmetric growth and spilling the remainder to
        # whichever side still has room.
        delta = larger_edge - box_width
        left_space = x_min
        right_space = w - x_max
        expand_left = min(delta // 2, left_space)
        expand_right = min(delta - expand_left, right_space)
        expand_left += min(delta - (expand_left + expand_right), left_space - expand_left)
        x_min -= expand_left
        x_max += expand_right
    if box_height < larger_edge:
        delta = larger_edge - box_height
        top_space = y_min
        bottom_space = h - y_max
        expand_top = min(delta // 2, top_space)
        expand_bottom = min(delta - expand_top, bottom_space)
        expand_top += min(delta - (expand_top + expand_bottom), top_space - expand_top)
        y_min -= expand_top
        y_max += expand_bottom
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)
    return x_min, y_min, x_max, y_max
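# Worked example (hypothetical numbers): a 100x100 mask whose nonzero blob covers
# rows/cols 10..19, with a 0.1 margin and desiredSize=64, pads the blob to a 29x29
# box and then grows it to the minimum edge length, clamped to the image:
#   m = torch.zeros(100, 100)
#   m[10:20, 10:20] = 1.0
#   getBoundingBox(m, 100, 100, relativeMargin=0.1, desiredSize=64)  # -> (0, 0, 64, 64)
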
def patchifyMask(mask, patchSize=14):
    # Reduce a pixel mask to one value per vision-transformer patch via max pooling.
    if mask is None:
        return mask
    (b, imgSize, imgSize, _) = mask.shape
    toks = imgSize // patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask.view(b, imgSize, imgSize)).view(b, toks, toks, 1)
def prepareImageAndMask(visionEncoder, image, mask, mode, autocrop_margin, desiredSize=384):
    mode = IMAGE_MODES.index(mode)
    (B, H, W, C) = image.shape
    if mode == 0:  # center crop (square)
        imgsize = min(H, W)
        ratio = desiredSize / imgsize
        (w, h) = (round(W * ratio), round(H * ratio))
        image, mask = crop(image, standardizeMask(mask), ((w - desiredSize) // 2, (h - desiredSize) // 2, w, h), desiredSize)
    elif mode == 1:  # keep aspect ratio (letterbox)
        if mask is None:
            mask = torch.ones(size=(B, H, W))
        imgsize = max(H, W)
        ratio = desiredSize / imgsize
        (w, h) = (round(W * ratio), round(H * ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    elif mode == 2:  # autocrop with mask
        (bx, by, bx2, by2) = getBoundingBox(mask, W, H, autocrop_margin, desiredSize)
        image = image[:, by:by2, bx:bx2, :]
        mask = mask[:, by:by2, bx:bx2]
        imgsize = max(bx2 - bx, by2 - by)
        ratio = desiredSize / imgsize
        (w, h) = (round((bx2 - bx) * ratio), round((by2 - by) * ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    return (image, mask)
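# Shape sketch (hypothetical input): a 512x768 RGB batch in "center crop (square)"
# mode is rescaled so its short edge reaches 384, then center-cropped to 384x384:
#   img = torch.rand(1, 512, 768, 3)
#   out, m = prepareImageAndMask(None, img, None, "center crop (square)", 0.0)
#   # out.shape == (1, 384, 384, 3); m is None
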
def processMask(mask, imgSize=384, patchSize=14):
    # Resize a raw mask to the vision input size, center-crop it, and pool it down
    # to one value per image patch.
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        b = 1
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    scalingFactor = imgSize / min(h, w)
    # scale
    mask = torch.nn.functional.interpolate(mask, size=(round(h * scalingFactor), round(w * scalingFactor)), mode="bicubic")
    # crop
    horizontalBorder = (imgSize - mask.shape[3]) // 2
    verticalBorder = (imgSize - mask.shape[2]) // 2
    mask = mask[:, :, verticalBorder:(verticalBorder + imgSize), horizontalBorder:(horizontalBorder + imgSize)].view(b, imgSize, imgSize)
    toks = imgSize // patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask).view(b, toks, toks, 1)

IMAGE_MODES = [
    "center crop (square)",
    "keep aspect ratio",
    "autocrop with mask"
]

class ReduxAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision": ("CLIP_VISION", ),
                             "image": ("IMAGE",),
                             "downsampling_factor": ("INT", {"default": 3, "min": 1, "max": 9}),
                             "downsampling_function": (["nearest", "bilinear", "bicubic", "area", "nearest-exact"], {"default": "area"}),
                             "mode": (IMAGE_MODES, {"default": "center crop (square)"}),
                             "weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                             },
                "optional": {
                             "mask": ("MASK", ),
                             "autocrop_margin": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01})
                }}

    RETURN_TYPES = ("CONDITIONING", "IMAGE", "MASK")
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision, image, style_model, conditioning, downsampling_factor, downsampling_function, mode, weight, mask=None, autocrop_margin=0.0):
        # Crop or letterbox the image (and mask), then encode it with the CLIP vision model.
        image, masko = prepareImageAndMask(clip_vision, image, mask, mode, autocrop_margin)
        clip_vision_output, mask = (clip_vision.encode_image(image), patchifyMask(masko))
        mode = "area"
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        (b, t, h) = cond.shape
        m = int(np.sqrt(t))
        if downsampling_factor > 1:
            # Reinterpret the token sequence as a square grid, apply the patch mask,
            # and downsample both with the chosen interpolation function.
            cond = cond.view(b, m, m, h)
            if mask is not None:
                cond = cond * mask
            cond = torch.nn.functional.interpolate(cond.transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode=downsampling_function)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
            mask = None if mask is None else torch.nn.functional.interpolate(mask.view(b, m, m, 1).transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode=mode).transpose(-1, 1)
        # Scale the image tokens by weight squared.
        cond = cond * (weight * weight)
        c = []
        if mask is not None:
            # Keep only tokens whose downsampled mask is positive and pad each batch
            # entry to the longest surviving sequence.
            mask = (mask > 0).reshape(b, -1)
            max_len = mask.sum(dim=1).max().item()
            padded_embeddings = torch.zeros((b, max_len, h), dtype=cond.dtype, device=cond.device)
            for i in range(b):
                filtered = cond[i][mask[i]]
                padded_embeddings[i, :filtered.size(0)] = filtered
            cond = padded_embeddings
        for t in conditioning:
            # Append the image tokens to each conditioning entry.
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, image, masko)

# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "StyleModelApplySimple": StyleModelApplySimple,
    "ReduxAdvanced": ReduxAdvanced
}

# A dictionary that contains the human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "StyleModelApplySimple": "Apply style model (simple)",
    "ReduxAdvanced": "Apply Redux model (advanced)"
}