import math

import numpy as np
import torch
@torch.no_grad()
def automerge(tensor, threshold):
    """Merge runs of adjacent token embeddings whose cosine similarity to the
    running mean is at least `threshold`. The per-batch token counts must match
    for the final stack; in practice this is used with batch size 1."""
    (batchsize, slices, dim) = tensor.shape
    newTensor = []
    for batch in range(batchsize):
        tokens = []
        lastEmbed = tensor[batch, 0, :]
        merge = [lastEmbed]
        for i in range(1, slices):
            tok = tensor[batch, i, :]
            cosine = torch.dot(tok, lastEmbed) / torch.sqrt(torch.dot(tok, tok) * torch.dot(lastEmbed, lastEmbed))
            if cosine >= threshold:
                # Similar enough: fold the token into the current run.
                merge.append(tok)
                lastEmbed = torch.stack(merge).mean(dim=0)
            else:
                # Dissimilar: flush the merged run and start a new one.
                tokens.append(lastEmbed)
                merge = [tok]
                lastEmbed = tok
        tokens.append(lastEmbed)  # flush the final run
        newTensor.append(torch.stack(tokens))
    return torch.stack(newTensor)
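
# A hedged usage sketch (automerge is defined here but not called by the
# nodes below; values are illustrative):
#
#   emb = torch.randn(1, 8, 64)            # (batch, tokens, dim)
#   merged = automerge(emb, threshold=0.9)
#   # merged has shape (1, k, 64) with k <= 8: runs of adjacent tokens whose
#   # cosine similarity stays >= 0.9 collapse into their mean.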
STRENGTHS = ["highest", "high", "medium", "low", "lowest"]
STRENGTHS_VALUES = [1, 2, 3, 4, 5]
class StyleModelApplySimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "image_strength": (STRENGTHS, {"default": "medium"})
                             }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, image_strength):
        # Map the named strength to a downsampling factor: the stronger the
        # requested image influence, the fewer tokens are discarded.
        stren = STRENGTHS.index(image_strength)
        downsampling_factor = STRENGTHS_VALUES[stren]
        mode = "area" if downsampling_factor == 3 else "bicubic"
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        if downsampling_factor > 1:
            # The vision tokens form a square grid; shrink it spatially.
            (b, t, h) = cond.shape
            m = int(np.sqrt(t))
            cond = torch.nn.functional.interpolate(cond.view(b, m, m, h).transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode=mode)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
        c = []
        for t in conditioning:
            # Append the image tokens after the text tokens of each entry.
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, )
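
# Shape sketch for the node above (the 27x27 grid assumes the stock Flux
# Redux vision encoder, which emits 729 image tokens): "medium" selects a
# downsampling factor of 3, so the grid shrinks to 9x9 = 81 tokens before
# being concatenated onto the text conditioning; "highest" (factor 1) keeps
# all 729 tokens and therefore the strongest image influence.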
def standardizeMask(mask):
    # Normalize a ComfyUI mask of shape (H, W) or (B, H, W) to NCHW.
    if mask is None:
        return None
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    return mask
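
# Shape sketch: ComfyUI masks arrive as (H, W) or (B, H, W); both become
# NCHW so torch.nn.functional.interpolate can consume them, e.g.
#   standardizeMask(torch.ones(512, 512)).shape     # (1, 1, 512, 512)
#   standardizeMask(torch.ones(2, 512, 512)).shape  # (2, 1, 512, 512)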
def crop(img, mask, box, desiredSize):
    (ox, oy, w, h) = box
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic").view(-1, h, w, 1)
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True)
    # Cut the desiredSize window at offset (ox, oy) out of the resized image
    # (NCWH layout after the transpose) and out of the mask (NHWC layout).
    img = img[:, :, ox:(desiredSize + ox), oy:(desiredSize + oy)].transpose(1, -1)
    mask = None if mask is None else mask[:, oy:(desiredSize + oy), ox:(desiredSize + ox), :]
    return (img, mask)
def letterbox(img, mask, w, h, desiredSize):
    (b, oh, ow, c) = img.shape
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True).transpose(1, -1)
    # Paste the resized image into a black desiredSize x desiredSize canvas.
    canvas = torch.zeros(size=(b, desiredSize, desiredSize, c))
    offsetx = (desiredSize - w) // 2
    offsety = (desiredSize - h) // 2
    canvas[:, offsety:(offsety + h), offsetx:(offsetx + w), :] += img
    img = canvas
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic")
        canvas = torch.zeros(size=(b, 1, desiredSize, desiredSize))
        canvas[:, :, offsety:(offsety + h), offsetx:(offsetx + w)] += mask
        mask = canvas.view(b, 1, desiredSize, desiredSize)
    return (img, mask)
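
# Worked example (numbers are illustrative): letterboxing a 640x480 image
# to desiredSize=384 first resizes the long edge to fit (w=384, h=288),
# then pastes the result into a black 384x384 canvas with
# (384 - 288) // 2 = 48 px bands above and below.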
def getBoundingBox(mask, w, h, relativeMargin, desiredSize):
    # Find the tight box around the nonzero mask region, pad it by a
    # relative margin, then grow it (within the image bounds) until its
    # larger edge is at least desiredSize, keeping the box roughly centered.
    mask = mask.view(h, w)
    marginW = math.ceil(relativeMargin * w)
    marginH = math.ceil(relativeMargin * h)
    indices = torch.nonzero(mask, as_tuple=False)
    y_min, x_min = indices.min(dim=0).values
    y_max, x_max = indices.max(dim=0).values
    x_min = max(0, x_min.item() - marginW)
    y_min = max(0, y_min.item() - marginH)
    x_max = min(w, x_max.item() + marginW)
    y_max = min(h, y_max.item() + marginH)
    box_width = x_max - x_min
    box_height = y_max - y_min
    larger_edge = max(box_width, box_height, desiredSize)
    if box_width < larger_edge:
        delta = larger_edge - box_width
        left_space = x_min
        right_space = w - x_max
        expand_left = min(delta // 2, left_space)
        expand_right = min(delta - expand_left, right_space)
        # If one side ran out of room, push the remainder to the other side.
        expand_left += min(delta - (expand_left + expand_right), left_space - expand_left)
        x_min -= expand_left
        x_max += expand_right
    if box_height < larger_edge:
        delta = larger_edge - box_height
        top_space = y_min
        bottom_space = h - y_max
        expand_top = min(delta // 2, top_space)
        expand_bottom = min(delta - expand_top, bottom_space)
        expand_top += min(delta - (expand_top + expand_bottom), top_space - expand_top)
        y_min -= expand_top
        y_max += expand_bottom
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)
    return x_min, y_min, x_max, y_max
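
# Worked example (illustrative numbers): on a 1024x768 mask with
# relativeMargin=0.1, the margins are ceil(0.1 * 1024) = 103 px horizontally
# and ceil(0.1 * 768) = 77 px vertically; the padded box is then widened and
# heightened (where the image has room) toward max(box_w, box_h, desiredSize)
# and finally clamped to the image bounds.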
def patchifyMask(mask, patchSize=14):
    # Reduce a pixel mask to one value per vision patch: a patch counts as
    # kept if any pixel inside it is nonzero (max pooling).
    if mask is None:
        return mask
    (b, imgSize, _, _) = mask.shape
    toks = imgSize // patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask.view(b, imgSize, imgSize)).view(b, toks, toks, 1)
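
# Token-grid sketch: with the defaults used below (384 px input, 14 px
# patches) this yields 384 // 14 = 27 patches per side, matching the
# 27 x 27 = 729 vision tokens produced by the Redux encoder.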
def prepareImageAndMask(visionEncoder, image, mask, mode, autocrop_margin, desiredSize=384):
    mode = IMAGE_MODES.index(mode)
    (B, H, W, C) = image.shape
    if mode == 0:  # center crop (square)
        imgsize = min(H, W)
        ratio = desiredSize / imgsize
        (w, h) = (round(W * ratio), round(H * ratio))
        image, mask = crop(image, standardizeMask(mask), ((w - desiredSize) // 2, (h - desiredSize) // 2, w, h), desiredSize)
    elif mode == 1:  # keep aspect ratio (letterbox)
        if mask is None:
            mask = torch.ones(size=(B, H, W))
        imgsize = max(H, W)
        ratio = desiredSize / imgsize
        (w, h) = (round(W * ratio), round(H * ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    elif mode == 2:  # autocrop with mask (a mask input is required here)
        (bx, by, bx2, by2) = getBoundingBox(mask, W, H, autocrop_margin, desiredSize)
        image = image[:, by:by2, bx:bx2, :]
        mask = mask[:, by:by2, bx:bx2]
        imgsize = max(bx2 - bx, by2 - by)
        ratio = desiredSize / imgsize
        (w, h) = (round((bx2 - bx) * ratio), round((by2 - by) * ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    return (image, mask)
def processMask(mask, imgSize=384, patchSize=14):
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        b = 1
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    scalingFactor = imgSize / min(h, w)
    # Scale so the short edge matches imgSize.
    mask = torch.nn.functional.interpolate(mask, size=(round(h * scalingFactor), round(w * scalingFactor)), mode="bicubic")
    # Center-crop the long edge down to imgSize.
    horizontalBorder = (mask.shape[3] - imgSize) // 2
    verticalBorder = (mask.shape[2] - imgSize) // 2
    mask = mask[:, :, verticalBorder:(verticalBorder + imgSize), horizontalBorder:(horizontalBorder + imgSize)].view(b, imgSize, imgSize)
    toks = imgSize // patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask).view(b, toks, toks, 1)
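
# Note: processMask is a standalone variant of the pipeline above (scale the
# short edge to imgSize, center-crop to a square, then pool to patches); the
# nodes below use prepareImageAndMask + patchifyMask instead.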
IMAGE_MODES = [
    "center crop (square)",
    "keep aspect ratio",
    "autocrop with mask"
]
class ReduxAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision": ("CLIP_VISION", ),
                             "image": ("IMAGE",),
                             "downsampling_factor": ("INT", {"default": 3, "min": 1, "max": 9}),
                             "downsampling_function": (["nearest", "bilinear", "bicubic", "area", "nearest-exact"], {"default": "area"}),
                             "mode": (IMAGE_MODES, {"default": "center crop (square)"}),
                             "weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                             },
                "optional": {
                             "mask": ("MASK", ),
                             "autocrop_margin": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01})
                }}

    RETURN_TYPES = ("CONDITIONING", "IMAGE", "MASK")
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision, image, style_model, conditioning, downsampling_factor, downsampling_function, mode, weight, mask=None, autocrop_margin=0.0):
        image, masko = prepareImageAndMask(clip_vision, image, mask, mode, autocrop_margin)
        clip_vision_output = clip_vision.encode_image(image)
        mask = patchifyMask(masko)
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        (b, t, h) = cond.shape
        m = int(np.sqrt(t))
        if downsampling_factor > 1:
            # Shrink the square token grid; zero out masked-off tokens first
            # so they do not bleed into kept tokens during interpolation.
            cond = cond.view(b, m, m, h)
            if mask is not None:
                cond = cond * mask
            cond = torch.nn.functional.interpolate(cond.transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode=downsampling_function)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
            # The mask itself is always downsampled with "area" interpolation.
            mask = None if mask is None else torch.nn.functional.interpolate(mask.view(b, m, m, 1).transpose(1, -1), size=(m // downsampling_factor, m // downsampling_factor), mode="area").transpose(-1, 1)
        cond = cond * (weight * weight)
        c = []
        if mask is not None:
            # Drop tokens that fall entirely outside the mask, then pad each
            # batch entry to the longest surviving token count.
            mask = (mask > 0).reshape(b, -1)
            max_len = mask.sum(dim=1).max().item()
            padded_embeddings = torch.zeros((b, max_len, h), dtype=cond.dtype, device=cond.device)
            for i in range(b):
                filtered = cond[i][mask[i]]
                padded_embeddings[i, :filtered.size(0)] = filtered
            cond = padded_embeddings
        for t in conditioning:
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, image, masko)
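
# Usage sketch (assumes a standard ComfyUI Flux Redux workflow): wire a
# CLIPVisionLoader into clip_vision, a StyleModelLoader into style_model,
# the prompt's CONDITIONING into conditioning, and the reference IMAGE (plus
# an optional MASK) into the node; the augmented CONDITIONING then feeds the
# sampler, while the IMAGE/MASK outputs preview what was actually encoded.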
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "StyleModelApplySimple": StyleModelApplySimple,
    "ReduxAdvanced": ReduxAdvanced
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "StyleModelApplySimple": "Apply style model (simple)",
    "ReduxAdvanced": "Apply Redux model (advanced)"
}