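# ComfyUI custom nodes for applying Redux-style style models to conditioning, with
# adjustable strength/downsampling and optional mask-based region control.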
import numpy as np
import torch
import comfy
import folder_paths
import nodes
import os
import math
import re
import safetensors
import glob
from collections import namedtuple
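
# Greedily merge adjacent tokens along the sequence dimension: while a token's cosine
# similarity to the current running average stays above `threshold`, it is folded into
# that average; otherwise the average is emitted and a new run starts.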
def automerge(tensor, threshold):
    (batchsize, slices, dim) = tensor.shape
    newTensor = []
    for batch in range(batchsize):
        tokens = []
        lastEmbed = tensor[batch, 0, :]
        merge = [lastEmbed]
        tokens.append(lastEmbed)
        for i in range(1, slices):
            tok = tensor[batch, i, :]
            cosine = torch.dot(tok, lastEmbed) / torch.sqrt(torch.dot(tok, tok) * torch.dot(lastEmbed, lastEmbed))
            if cosine >= threshold:
                merge.append(tok)
                lastEmbed = torch.stack(merge).mean(dim=0)
            else:
                tokens.append(lastEmbed)
                merge = []
                lastEmbed = tok
        newTensor.append(torch.stack(tokens))
    return torch.stack(newTensor)
STRENGTHS = ["highest", "high", "medium", "low", "lowest"]
STRENGTHS_VALUES = [1, 2, 3, 4, 5]
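
# ComfyUI node: applies a style model to existing conditioning. The named image strength
# maps to a downsampling factor for the image tokens (higher strength = less downsampling).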
class StyleModelApplySimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "image_strength": (STRENGTHS, {"default": "medium"})
                             }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, image_strength):
        stren = STRENGTHS.index(image_strength)
        downsampling_factor = STRENGTHS_VALUES[stren]
        mode = "area" if downsampling_factor == 3 else "bicubic"
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        if downsampling_factor > 1:
            (b, t, h) = cond.shape
            m = int(np.sqrt(t))
            # Reshape the token sequence back into a square grid and downsample it.
            cond = torch.nn.functional.interpolate(cond.view(b, m, m, h).transpose(1, -1), size=(m//downsampling_factor, m//downsampling_factor), mode=mode)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
        c = []
        for t in conditioning:
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, )
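
# Bring a mask to shape (batch, 1, H, W), whether it arrives as (H, W) or (B, H, W).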
def standardizeMask(mask):
    if mask is None:
        return None
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    return mask
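
# Resize the image (and mask) to (w, h) and cut out a desiredSize x desiredSize window
# at offset (ox, oy) of the given box.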
def crop(img, mask, box, desiredSize):
    (ox, oy, w, h) = box
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic").view(-1, h, w, 1)
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True)
    return (img[:, :, ox:(desiredSize+ox), oy:(desiredSize+oy)].transpose(1, -1),
            None if mask is None else mask[:, oy:(desiredSize+oy), ox:(desiredSize+ox), :])
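
# Resize the image (and mask) to (w, h) and paste it centered onto a zero-filled
# desiredSize x desiredSize canvas, preserving the aspect ratio.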
def letterbox(img, mask, w, h, desiredSize):
    (b, oh, ow, c) = img.shape
    img = torch.nn.functional.interpolate(img.transpose(-1, 1), size=(w, h), mode="bicubic", antialias=True).transpose(1, -1)
    letterbox = torch.zeros(size=(b, desiredSize, desiredSize, c))
    offsetx = (desiredSize-w)//2
    offsety = (desiredSize-h)//2
    letterbox[:, offsety:(offsety+h), offsetx:(offsetx+w), :] += img
    img = letterbox
    if mask is not None:
        mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bicubic")
        letterbox = torch.zeros(size=(b, 1, desiredSize, desiredSize))
        letterbox[:, :, offsety:(offsety+h), offsetx:(offsetx+w)] += mask
        mask = letterbox.view(b, 1, desiredSize, desiredSize)
    return (img, mask)
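
# Find the bounding box of the mask's non-zero region, pad it by a relative margin,
# and expand it (within image bounds) toward a square of at least desiredSize per edge.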
def getBoundingBox(mask, w, h, relativeMargin, desiredSize):
    mask = mask.view(h, w)
    marginW = math.ceil(relativeMargin * w)
    marginH = math.ceil(relativeMargin * h)
    indices = torch.nonzero(mask, as_tuple=False)
    y_min, x_min = indices.min(dim=0).values
    y_max, x_max = indices.max(dim=0).values
    x_min = max(0, x_min.item() - marginW)
    y_min = max(0, y_min.item() - marginH)
    x_max = min(w, x_max.item() + marginW)
    y_max = min(h, y_max.item() + marginH)
    box_width = x_max - x_min
    box_height = y_max - y_min
    larger_edge = max(box_width, box_height, desiredSize)
    if box_width < larger_edge:
        delta = larger_edge - box_width
        left_space = x_min
        right_space = w - x_max
        expand_left = min(delta // 2, left_space)
        expand_right = min(delta - expand_left, right_space)
        expand_left += min(delta - (expand_left+expand_right), left_space-expand_left)
        x_min -= expand_left
        x_max += expand_right
    if box_height < larger_edge:
        delta = larger_edge - box_height
        top_space = y_min
        bottom_space = h - y_max
        expand_top = min(delta // 2, top_space)
        expand_bottom = min(delta - expand_top, bottom_space)
        expand_top += min(delta - (expand_top+expand_bottom), top_space-expand_top)
        y_min -= expand_top
        y_max += expand_bottom
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)
    return x_min, y_min, x_max, y_max
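
# Downsample a pixel mask to the vision patch grid (one value per patchSize x patchSize patch)
# via max pooling.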
def patchifyMask(mask, patchSize=14):
    if mask is None:
        return mask
    (b, imgSize, imgSize, _) = mask.shape
    toks = imgSize//patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask.view(b, imgSize, imgSize)).view(b, toks, toks, 1)
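
# Bring the input image (and optional mask) to the square resolution expected by the
# vision encoder, according to the selected IMAGE_MODES entry.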
def prepareImageAndMask(visionEncoder, image, mask, mode, autocrop_margin, desiredSize=384):
    mode = IMAGE_MODES.index(mode)
    (B, H, W, C) = image.shape
    if mode == 0:  # center crop (square)
        imgsize = min(H, W)
        ratio = desiredSize/imgsize
        (w, h) = (round(W*ratio), round(H*ratio))
        image, mask = crop(image, standardizeMask(mask), ((w - desiredSize)//2, (h - desiredSize)//2, w, h), desiredSize)
    elif mode == 1:  # keep aspect ratio (letterbox)
        if mask is None:
            mask = torch.ones(size=(B, H, W))
        imgsize = max(H, W)
        ratio = desiredSize/imgsize
        (w, h) = (round(W*ratio), round(H*ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    elif mode == 2:  # autocrop with mask
        (bx, by, bx2, by2) = getBoundingBox(mask, W, H, autocrop_margin, desiredSize)
        image = image[:, by:by2, bx:bx2, :]
        mask = mask[:, by:by2, bx:bx2]
        imgsize = max(bx2-bx, by2-by)
        ratio = desiredSize/imgsize
        (w, h) = (round((bx2-bx)*ratio), round((by2-by)*ratio))
        image, mask = letterbox(image, standardizeMask(mask), w, h, desiredSize)
    return (image, mask)
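
# Scale a mask so its shorter edge matches imgSize, center-crop it to a square,
# and pool it down to the patch grid.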
def processMask(mask, imgSize=384, patchSize=14):
    if len(mask.shape) == 2:
        (h, w) = mask.shape
        b = 1
        mask = mask.view(1, 1, h, w)
    elif len(mask.shape) == 3:
        (b, h, w) = mask.shape
        mask = mask.view(b, 1, h, w)
    scalingFactor = imgSize/min(h, w)
    # scale
    mask = torch.nn.functional.interpolate(mask, size=(round(h*scalingFactor), round(w*scalingFactor)), mode="bicubic")
    # crop
    horizontalBorder = (imgSize-mask.shape[3])//2
    verticalBorder = (imgSize-mask.shape[2])//2
    mask = mask[:, :, verticalBorder:(verticalBorder+imgSize), horizontalBorder:(horizontalBorder+imgSize)].view(b, imgSize, imgSize)
    toks = imgSize//patchSize
    return torch.nn.MaxPool2d(kernel_size=(patchSize, patchSize), stride=patchSize)(mask).view(b, toks, toks, 1)
IMAGE_MODES = [
    "center crop (square)",
    "keep aspect ratio",
    "autocrop with mask"
]
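
# ComfyUI node: full-control variant of the Redux apply step. It exposes the downsampling
# factor and interpolation function directly, supports restricting the image tokens to a
# masked region, and returns the preprocessed image and mask alongside the new conditioning.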
class ReduxAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision": ("CLIP_VISION", ),
                             "image": ("IMAGE",),
                             "downsampling_factor": ("INT", {"default": 3, "min": 1, "max": 9}),
                             "downsampling_function": (["nearest", "bilinear", "bicubic", "area", "nearest-exact"], {"default": "area"}),
                             "mode": (IMAGE_MODES, {"default": "center crop (square)"}),
                             "weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                             },
                "optional": {
                             "mask": ("MASK", ),
                             "autocrop_margin": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01})
                }}

    RETURN_TYPES = ("CONDITIONING", "IMAGE", "MASK")
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision, image, style_model, conditioning, downsampling_factor, downsampling_function, mode, weight, mask=None, autocrop_margin=0.0):
        # Preprocess the image (and optional mask) to the encoder resolution, then encode it.
        image, masko = prepareImageAndMask(clip_vision, image, mask, mode, autocrop_margin)
        clip_vision_output, mask = (clip_vision.encode_image(image), patchifyMask(masko))
        mode = "area"  # interpolation mode reused below for downsampling the mask
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        (b, t, h) = cond.shape
        m = int(np.sqrt(t))
        if downsampling_factor > 1:
            # Reshape the tokens into a square grid, zero out tokens outside the mask,
            # and downsample both grid and mask with the chosen interpolation function.
            cond = cond.view(b, m, m, h)
            if mask is not None:
                cond = cond*mask
            cond = torch.nn.functional.interpolate(cond.transpose(1, -1), size=(m//downsampling_factor, m//downsampling_factor), mode=downsampling_function)
            cond = cond.transpose(1, -1).reshape(b, -1, h)
            mask = None if mask is None else torch.nn.functional.interpolate(mask.view(b, m, m, 1).transpose(1, -1), size=(m//downsampling_factor, m//downsampling_factor), mode=mode).transpose(-1, 1)
        cond = cond*(weight*weight)
        c = []
        if mask is not None:
            # Keep only tokens that fall inside the mask, padding each batch item to the longest kept sequence.
            mask = (mask > 0).reshape(b, -1)
            max_len = mask.sum(dim=1).max().item()
            padded_embeddings = torch.zeros((b, max_len, h), dtype=cond.dtype, device=cond.device)
            for i in range(b):
                filtered = cond[i][mask[i]]
                padded_embeddings[i, :filtered.size(0)] = filtered
            cond = padded_embeddings
        for t in conditioning:
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, image, masko)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "StyleModelApplySimple": StyleModelApplySimple,
    "ReduxAdvanced": ReduxAdvanced
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "StyleModelApplySimple": "Apply style model (simple)",
    "ReduxAdvanced": "Apply Redux model (advanced)"
}