# Author: Alexander McKinney
# Commit 8cd1abb: fixes bug loading new image with different masks and cleans up code
# (file viewer residue: raw | history blame | 7.28 kB)
import io
import requests
import numpy as np
import torch
from PIL import Image
from skimage.measure import block_reduce
from typing import List, Optional
from functools import reduce
import gradio as gr
from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
from transformers.models.detr.feature_extraction_detr import rgb_to_id
from diffusers import StableDiffusionInpaintPipeline
# Disable autograd: this app only runs inference. Calling `torch.no_grad()` or
# `torch.inference_mode()` without `with` merely builds a context-manager object
# and has no effect, so the global switch is used instead.
# NOTE: grad mode is thread-local in PyTorch, so this covers work done on the
# importing thread (e.g. model loading); callbacks executed on other threads
# must disable gradients themselves.
torch.set_grad_enabled(False)
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR panoptic-segmentation components for ``model_name``.

    Returns:
        A ``(feature_extractor, model, config)`` tuple, loaded in that order.
    """
    loaded = (
        DetrFeatureExtractor.from_pretrained(model_name),
        DetrForSegmentation.from_pretrained(model_name),
        DetrConfig.from_pretrained(model_name),
    )
    return loaded
def load_diffusion_pipeline(
    model_name: str = 'runwayml/stable-diffusion-inpainting',
    torch_dtype: Optional[torch.dtype] = None,
):
    """Load the Stable Diffusion inpainting pipeline from the ``fp16`` revision.

    Args:
        model_name: Hub id or local path of the inpainting checkpoint.
        torch_dtype: dtype to load the weights in. When ``None`` (default),
            float16 is used if CUDA is available and float32 otherwise, because
            half-precision inference is poorly supported on CPU and the app
            falls back to CPU when no GPU is present.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline`` (not yet moved to a device).
    """
    if torch_dtype is None:
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        revision='fp16',
        torch_dtype=torch_dtype,
    )
def get_device(try_cuda=True):
    """Return the CUDA device if ``try_cuda`` is truthy and CUDA is usable, else CPU."""
    if not try_cuda or not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device('cuda')
def min_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window minimum (grayscale erosion) over the last two dims.

    Computed as ``-max_pool2d(-x)`` with stride 1 and ``(kernel_size - 1) // 2``
    padding, so an odd ``kernel_size`` keeps the spatial size unchanged.
    ``max_pool2d`` pads with -inf, so padded cells never win the minimum.
    """
    half_window = (kernel_size - 1) // 2
    negated_max = torch.nn.functional.max_pool2d(
        torch.neg(x), kernel_size, stride=(1, 1), padding=half_window
    )
    return torch.neg(negated_max)
def max_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window maximum (grayscale dilation) over the last two dims.

    Uses stride 1 and ``(kernel_size - 1) // 2`` padding, so an odd
    ``kernel_size`` keeps the spatial size unchanged.
    """
    half_window = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(
        x,
        kernel_size,
        stride=(1, 1),
        padding=half_window,
    )
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
    """Remove small specks from a 2-D mask, then grow what is left.

    The mask is first eroded with a ``min_kernel`` window (regions narrower
    than the window vanish) and then dilated with a ``max_kernel`` window.

    Args:
        mask: 2-D numpy array; any nonzero value counts as "inside".
        max_kernel: Dilation window size (odd values keep the shape).
        min_kernel: Erosion window size (odd values keep the shape).

    Returns:
        Boolean numpy array with the cleaned mask.
    """
    batched = torch.Tensor(mask[None, None]).float()
    eroded = min_pool(batched, min_kernel)
    dilated = max_pool(eroded, max_kernel)
    return dilated.bool().squeeze().numpy()
# Module-level singletons shared by the UI callbacks: compute device,
# DETR segmentation components, and the inpainting pipeline on that device.
device = get_device(try_cuda=True)
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
pipe = load_diffusion_pipeline().to(device)
def fn_segmentation(image, max_kernel, min_kernel):
    """Run DETR panoptic segmentation and build one binary mask per segment.

    Args:
        image: Input PIL image.
        max_kernel: Unused; accepted because the "Compute Masks" handler passes
            the slider values.
        min_kernel: Unused; see ``max_kernel``.

    Returns:
        A 4-tuple ``(raw_masks, checkbox_update, blank_mask_update, image_update)``:
        ``raw_masks`` is a list of uint8 arrays (values 0 or 255) of shape
        (image.height, image.width), one per segment. The checkbox update
        offers one ``"<position in raw_masks>:<class label>"`` choice per segment
        and clears the selection. The mask preview is reset to blank, and the
        masked-image preview is reset to the input image.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping. Grad mode is thread-local in
    # PyTorch, so it is disabled here rather than relying on an import-time switch.
    with torch.no_grad():
        outputs = segmentation_model(**inputs)
    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    # The PNG encodes segment ids as RGB colours. NEAREST resampling keeps those
    # ids intact; Pillow's default interpolating filter would blend colours at
    # segment borders into ids belonging to a different segment (or to none).
    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(
        (image.width, image.height), Image.NEAREST
    )
    panoptic_seg_id = rgb_to_id(np.array(panoptic_seg, dtype=np.uint8))

    segments_info = result['segments_info']
    raw_masks = [
        (panoptic_seg_id == s['id']).astype(np.uint8) * 255 for s in segments_info
    ]
    # The numeric prefix is the mask's position in raw_masks; the selection is
    # later mapped back to a mask by indexing raw_masks with it.
    checkbox_choices = [
        f"{index}:{segmentation_cfg.id2label[s['category_id']]}"
        for index, s in enumerate(segments_info)
    ]
    # Clear any previous selection, since the mask preview is reset to blank below.
    checkbox_group = gr.CheckboxGroup.update(choices=checkbox_choices, value=[])
    return (
        raw_masks,
        checkbox_group,
        gr.Image.update(value=np.zeros((image.height, image.width))),
        gr.Image.update(value=image),
    )
def fn_clean(masks, max_kernel, min_kernel):
    """Erode then dilate every mask in ``masks``.

    Each mask is eroded with a ``min_kernel`` window and then dilated with a
    ``max_kernel`` window (stride 1, ``(k - 1) // 2`` padding).

    Returns:
        A new list of 2-D ``np.uint8`` arrays, one per input mask.
    """
    def _erode_then_dilate(arr):
        t = torch.FloatTensor(arr)[None, None]
        t = -torch.nn.functional.max_pool2d(
            -t, min_kernel, (1, 1), padding=(min_kernel - 1) // 2
        )
        t = torch.nn.functional.max_pool2d(
            t, max_kernel, (1, 1), padding=(max_kernel - 1) // 2
        )
        return t.squeeze().numpy().astype(np.uint8)

    return [_erode_then_dilate(m) for m in masks]
def fn_update_mask(
    image: Image.Image,
    masks: List[np.ndarray],
    masks_enabled: List[str],
    max_kernel: int,
    min_kernel: int,
):
    """Combine the selected segment masks, clean them, and black out the image.

    Args:
        image: Input PIL image, or None when the input component is cleared.
        masks: Per-segment masks (nonzero = inside) of shape
            (image.height, image.width), or None before any masks were computed.
        masks_enabled: Selected checkbox labels of the form ``"<index>:<label>"``,
            where ``<index>`` is a position in ``masks``.
        max_kernel: Dilation window passed to ``clean_mask``.
        min_kernel: Erosion window passed to ``clean_mask``.

    Returns:
        ``(mask, masked_image)``: a uint8 mask with values 0/255, and a PIL copy
        of ``image`` with the masked pixels set to 0. Returns ``(None, None)``
        when ``image`` is None.
    """
    if image is None:
        return None, None

    image_shape = (image.height, image.width)
    # Nothing selected, no masks computed yet, no segments detected, or masks
    # left over from a previously loaded image of a different size: the mask is
    # empty, so the "masked" image is just an unmodified copy of the input.
    if not masks or not masks_enabled or masks[0].shape != image_shape:
        empty_mask = np.zeros(image_shape, dtype=np.uint8)
        return empty_mask, Image.fromarray(np.array(image))

    selected = [int(label.split(':')[0]) for label in masks_enabled]
    combined_mask = np.zeros(masks[0].shape, dtype=bool)
    for index in selected:
        combined_mask |= masks[index] > 0
    # Slider values are cast defensively; torch pooling requires int kernel sizes.
    combined_mask = clean_mask(combined_mask, int(max_kernel), int(min_kernel))

    masked_image = np.array(image)  # np.array copies, so the input stays untouched
    masked_image[combined_mask] = 0
    return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
def _stable_diffusion_size(width: int, height: int, small_edge: int = 512):
    """Return ``(new_width, new_height)`` for running Stable Diffusion.

    The shorter edge is scaled to ``small_edge``, keeping the aspect ratio.
    Both edges are then rounded UP to a multiple of 8, which the pipeline
    requires; already-aligned sizes are left unchanged.
    """
    if width > height:
        ratio = small_edge / height
        new_width, new_height = int(width * ratio), small_edge
    else:
        ratio = small_edge / width
        new_width, new_height = small_edge, int(height * ratio)
    new_width = -(-new_width // 8) * 8
    new_height = -(-new_height // 8) * 8
    return new_width, new_height


def fn_diffusion(
    prompt: str,
    masked_image: Image.Image,
    mask: np.ndarray,
    num_diffusion_steps: int,
    guidance_scale: float,
    negative_prompt: Optional[str] = None,
):
    """Inpaint the masked region of ``masked_image`` guided by ``prompt``.

    Args:
        prompt: Text prompt for the inpainting pipeline.
        masked_image: PIL image whose masked pixels have been blacked out.
        mask: Mask array as produced by the mask preview (nonzero = repaint).
        num_diffusion_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        negative_prompt: Optional negative prompt; None or "" means none.

    Returns:
        The inpainted PIL image, resized back to ``masked_image``'s size.

    Raises:
        ValueError: if the mask or masked image has not been computed yet.
    """
    if mask is None or masked_image is None:
        raise ValueError("Compute and select a mask before running diffusion.")
    # An empty textbox arrives as "" and the default is None; both mean "none".
    if not negative_prompt:
        negative_prompt = None

    STABLE_DIFFUSION_SMALL_EDGE = 512
    w, h = masked_image.size
    new_width, new_height = _stable_diffusion_size(w, h, STABLE_DIFFUSION_SMALL_EDGE)

    mask_input = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
    image_input = masked_image.convert("RGB").resize((new_width, new_height))
    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=image_input,
        mask_image=mask_input,
        # The steps slider declares no `step`, so a non-integer value may arrive.
        num_inference_steps=int(num_diffusion_steps),
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    ).images[0]
    return inpainted_image.resize((w, h))
# Gradio UI. Components render in the order they are created inside the
# `with` contexts, so statement order here defines the page layout.
demo = gr.Blocks()
with demo:
    # Source image (PIL), pre-filled from a COCO val2017 image URL.
    input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
    bt_masks = gr.Button("Compute Masks")
    with gr.Row():
        # Mask preview as a numpy array; it is also the `mask` input of fn_diffusion.
        mask_image = gr.Image(type='numpy', label="Diffusion Mask")
        # Input image with the selected region blacked out; fn_diffusion input.
        masked_image = gr.Image(type='pil', label="Masked Image")
    # Holds the per-segment masks returned by fn_segmentation between events.
    mask_storage = gr.State()
    with gr.Row():
        # Odd window sizes only (1..99, step 2): max_slider is the dilation
        # window (max_kernel), min_slider the erosion window (min_kernel).
        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
    # One "<id>:<label>" choice per detected segment; filled by fn_segmentation.
    mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
        with gr.Column():
            # NOTE(review): no `step` given here; confirm whether non-integer
            # step counts can reach fn_diffusion.
            steps_slider = gr.Slider(minimum=1, maximum=100, value=50, label="Inference Steps")
            guidance_slider = gr.Slider(minimum=0.0, maximum=50.0, value=7.5, step=0.1, label="Guidance Scale")
    bt_diffusion = gr.Button("Run Diffusion")
    inpainted_image = gr.Image(type='pil', label="Inpainted Image")
    # Shared wiring: any change to kernels or selection rebuilds the mask previews.
    update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
    update_mask_outputs = [mask_image, masked_image]
    # A new image invalidates the old segments: clear the choices and selection.
    input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
    max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
    min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
    mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
    bt_diffusion.click(fn_diffusion, inputs=[
        prompt,
        masked_image,
        mask_image,
        steps_slider,
        guidance_slider,
        negative_prompt
    ], outputs=inpainted_image)
# Start the web server (runs at import time; this module is the app entry point).
demo.launch()