Spaces:
Running
on
Zero
Running
on
Zero
| import torch.nn.functional as F | |
| import comfy | |
| from .model_patch import add_model_patch_option, patch_model_function_wrapper | |
class RAUNet:
    """ComfyUI node: clone a MODEL and apply the RAUNet patches to the clone.

    ``du_start``/``du_end`` bound the sampling steps during which the
    Downsample layer is reconfigured; ``xa_start``/``xa_end`` bound the steps
    during which the cross-attention pooling/upsampling patches run.
    """

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "model_update"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs.

        Declared as a classmethod because ComfyUI's custom-node convention is
        to call ``INPUT_TYPES()`` on the class itself (no instance); as a plain
        function that call would fail for lack of the ``s`` argument.
        """
        return {"required":
                    {
                        "model": ("MODEL",),
                        "du_start": ("INT", {"default": 0, "min": 0, "max": 10000}),
                        "du_end": ("INT", {"default": 4, "min": 0, "max": 10000}),
                        "xa_start": ("INT", {"default": 4, "min": 0, "max": 10000}),
                        "xa_end": ("INT", {"default": 10, "min": 0, "max": 10000}),
                    },
                }

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Return a 1-tuple with a RAUNet-patched clone of ``model``.

        The patches are attached to the object returned by ``model.clone()``,
        not to the ``model`` argument itself.
        """
        model = model.clone()
        add_raunet_patch(model,
                         du_start,
                         du_end,
                         xa_start,
                         xa_end)
        return (model,)
# This is main patch function
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Attach the RAUNet hooks to ``model`` and record their step windows.

    Three hooks are registered on ``model``, which is modified in place
    (``RAUNet.model_update`` passes a clone):

    * ``raunet_forward`` through ``patch_model_function_wrapper``. It
      reconfigures the Downsample layer at ``input_blocks[6][0]`` (SDXL) or
      ``input_blocks[3][0]`` (otherwise). NOTE(review): presumably the
      wrapper invokes it once per model call; confirm in ``model_patch``.
    * ``in_xattn_patch`` as the input-block patch.
    * ``out_xattn_patch`` as the output-block patch.

    The four window bounds are stored under ``['model_patch']['raunet']`` of
    the dict returned by ``add_model_patch_option``. The hooks read them
    back from ``transformer_options['model_patch']`` together with the
    current ``step``.

    Args:
        model: model patcher to modify.
        du_start, du_end: steps ``du_start <= step < du_end`` during which
            the Downsample conv uses stride 4 / padding 2 / dilation 2.
        xa_start, xa_end: step window read by the cross-attention patches.
    """
    def raunet_forward(model, x, timesteps, transformer_options, control):
        # Here ``model`` is the network object holding ``input_blocks``;
        # it is not the patcher passed to add_raunet_patch.
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return
        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']
        # The layer RAUNet modifies lives at a different input-block index per
        # architecture; only the index for the detected architecture is read.
        block = model.input_blocks[6 if is_SDXL else 3][0]
        if type(block) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return
        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return
        step = mp['step']
        ro = mp['raunet']
        window_start = ro['du_start']
        window_end = ro['du_end']
        if window_start <= step < window_end:
            # Inside the window: stride 4 with a 2x-dilated kernel.
            block.op.stride = (4, 4)
            block.op.padding = (2, 2)
            block.op.dilation = (2, 2)
        else:
            # Outside the window: stride 2 / padding 1 / dilation 1
            # (presumably the layer's unpatched configuration — confirm).
            block.op.stride = (2, 2)
            block.op.padding = (1, 1)
            block.op.dilation = (1, 1)

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    to = add_model_patch_option(model)
    # Reuse an existing 'raunet' dict if one is already present.
    ro = to['model_patch'].setdefault('raunet', {})
    ro['du_start'] = du_start
    ro['du_end'] = du_end
    ro['xa_start'] = xa_start
    ro['xa_end'] = xa_end
def in_xattn_patch(h, transformer_options):
    """Average-pool ``h`` by 2x2 at input block 4 inside the RAUNet xa window.

    The tensor is returned untouched when:

    * the current block is not ``("input", 4)``;
    * the patch options are missing (a diagnostic is printed); or
    * the step lies outside ``xa_start <= step < xa_end``.

    Args:
        h: hidden-state tensor; assumed (batch, channels, H, W) — TODO confirm.
        transformer_options: dict with a ``"block"`` key and, normally, a
            ``'model_patch'`` dict carrying ``'step'`` and ``'raunet'``.

    Returns:
        The pooled tensor, or ``h`` itself when the patch does not apply.
    """
    # The same block id is used for both SDXL and SD15.
    if transformer_options["block"] != ("input", 4):
        return h

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h
    patch_options = transformer_options['model_patch']

    if 'raunet' not in patch_options:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current_step = patch_options['step']
    raunet_options = patch_options['raunet']
    # Read both bounds before comparing so a missing key always raises.
    window_start, window_end = raunet_options['xa_start'], raunet_options['xa_end']

    if window_start <= current_step < window_end:
        return F.avg_pool2d(h, kernel_size=(2, 2))
    return h
def out_xattn_patch(h, hsp, transformer_options):
    """Resize ``h`` to ``hsp``'s spatial size at the RAUNet output block.

    The patch applies only while ``xa_start <= step < xa_end``, and only at
    ``("output", 5)`` on SDXL or ``("output", 8)`` otherwise. In that case
    ``h`` is resized with bicubic interpolation to the last two dimensions
    of ``hsp``. This presumably undoes the 2x2 pooling done by
    ``in_xattn_patch`` — confirm.

    Args:
        h: hidden-state tensor. Assumed 4-D (batch, channels, H, W), since
            bicubic interpolation requires 4-D input.
        hsp: tensor whose last two dimensions give the target size; it is
            returned unchanged.
        transformer_options: dict with a ``"block"`` key and, normally, a
            ``'model_patch'`` dict carrying ``'step'``, ``'SDXL'`` and
            ``'raunet'``.

    Returns:
        Always the pair ``(h, hsp)``, with ``h`` possibly resized.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp
    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        # Every other exit returns (h, hsp); a bare tensor here would break a
        # caller that unpacks the pair.
        return h, hsp
    step = mp['step']
    is_SDXL = mp['SDXL']
    ro = mp['raunet']
    xa_start = ro['xa_start']
    xa_end = ro['xa_end']
    target_block = ("output", 5) if is_SDXL else ("output", 8)
    if transformer_options["block"] != target_block:
        # wrong block
        return h, hsp
    if step < xa_start or step >= xa_end:
        return h, hsp
    # Resize to hsp's exact size rather than doubling h, which the HiDiffusion
    # reference does. Doubling only matches for some input sizes, e.g. not when
    # the pooled height/width was odd.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')
    return h, hsp