# Author: Alexander McKinney
# Commit 7d008e4: "full example in blocks" (10.7 kB)
import io
import requests
import numpy as np
import torch
from PIL import Image
from skimage.measure import block_reduce
from typing import List
from functools import reduce
import gradio as gr
from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
from transformers.models.detr.feature_extraction_detr import rgb_to_id
from diffusers import StableDiffusionInpaintPipeline
# NOTE: this demo uses gradio's `Blocks` API (see `demo` at the bottom of the
# file), which makes it possible to:
#   - build multi-step interfaces, in which the output of one model becomes
#     the input to the next, or have more flexible data flows in general;
#   - change a component's properties (for example, the choices in a dropdown)
#     or its visibility based on user input.
# https://huggingface.co/course/chapter9/7?fw=pt
# This app only runs inference, so there is no reason to record autograd
# graphs. `torch.set_grad_enabled(False)` called as a plain function switches
# grad mode off; merely constructing `torch.no_grad()` / `torch.inference_mode()`
# objects without entering them does nothing.
# NOTE: PyTorch's grad mode is thread-local, so this only covers the thread
# that imports this module; code executed on other threads must disable
# gradients itself.
torch.set_grad_enabled(False)
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR panoptic-segmentation feature extractor, model and config.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        A ``(feature_extractor, model, config)`` tuple. ``config`` is the
        configuration object already attached to ``model``.
    """
    feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
    model = DetrForSegmentation.from_pretrained(model_name)
    # `from_pretrained` has already loaded this checkpoint's config onto the
    # model; reuse it instead of fetching and parsing it a second time.
    return feature_extractor, model, model.config
def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
    """Load the Stable Diffusion inpainting pipeline with half-precision weights.

    The ``fp16`` revision of ``model_name`` is requested and its weights are
    loaded as ``torch.float16``. The pipeline is not moved to any device
    here; that is left to the caller.

    NOTE(review): float16 weights are generally meant for GPU execution;
    running this pipeline on CPU may fail or be very slow — confirm.
    """
    pretrained_kwargs = dict(revision='fp16', torch_dtype=torch.float16)
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(model_name, **pretrained_kwargs)
    return pipeline
def get_device(try_cuda=True):
    """Choose the torch device to run on.

    Returns a ``cuda`` device when ``try_cuda`` is truthy and CUDA is
    available; otherwise a ``cpu`` device.
    """
    if not try_cuda or not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device('cuda')
def min_pool(x: torch.Tensor, kernel_size: int):
    """Stride-1 2-D min pooling (erosion, when ``x`` holds a 0/1 mask).

    ``x`` is expected to be a floating-point tensor that ``max_pool2d``
    accepts, e.g. ``(N, C, H, W)``. The borders are padded with ``+inf`` so
    out-of-bounds pixels never win the minimum. Padding is ``(k - 1) // 2``
    before and ``k // 2`` after each spatial axis, so the output has the same
    spatial size as the input for both odd and even ``kernel_size`` (for odd
    kernels this is the usual centred window).
    """
    kernel_size = int(kernel_size)
    before, after = (kernel_size - 1) // 2, kernel_size // 2
    # F.pad lists the last dimension first: (W_before, W_after, H_before, H_after).
    padded = torch.nn.functional.pad(x, (before, after, before, after), value=float('inf'))
    # min(x) == -max(-x)
    return -torch.nn.functional.max_pool2d(-padded, kernel_size, stride=1)
def max_pool(x: torch.Tensor, kernel_size: int):
    """Stride-1 2-D max pooling (dilation, when ``x`` holds a 0/1 mask).

    ``x`` is expected to be a floating-point tensor that ``max_pool2d``
    accepts, e.g. ``(N, C, H, W)``. The borders are padded with ``-inf`` so
    out-of-bounds pixels never win the maximum. Padding is ``(k - 1) // 2``
    before and ``k // 2`` after each spatial axis, so the output has the same
    spatial size as the input for both odd and even ``kernel_size`` (for odd
    kernels this is the usual centred window).
    """
    kernel_size = int(kernel_size)
    before, after = (kernel_size - 1) // 2, kernel_size // 2
    # F.pad lists the last dimension first: (W_before, W_after, H_before, H_after).
    padded = torch.nn.functional.pad(x, (before, after, before, after), value=float('-inf'))
    return torch.nn.functional.max_pool2d(padded, kernel_size, stride=1)
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
    """Drop small specks from a mask, then grow the remaining regions.

    The mask is min-pooled with a ``min_kernel`` window (a pixel survives
    only if its whole window is set) and the result is max-pooled with a
    ``max_kernel`` window.

    Args:
        mask: array-like mask supporting ``mask[None, None]``; nonzero = set.
        max_kernel: window size of the growing (max-pool) step.
        min_kernel: window size of the speck-removal (min-pool) step.

    Returns:
        The cleaned mask as a squeezed boolean NumPy array.
    """
    # Two leading singleton axes (batch, channel) for the pooling helpers.
    batched = torch.Tensor(mask[None, None]).float()
    cleaned = max_pool(min_pool(batched, min_kernel), max_kernel)
    return cleaned.bool().squeeze().numpy()
# Model setup, performed once at import time. Only the diffusion pipeline is
# moved to `device`; the segmentation model is left where `from_pretrained`
# loaded it (its inputs are not moved to `device` either).
device = get_device()
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
pipe = load_diffusion_pipeline().to(device)
def fn_segmentation(image, max_kernel, min_kernel):
    """Run DETR panoptic segmentation on ``image`` and reset the mask UI.

    Args:
        image: input PIL image.
        max_kernel: dilation slider value; currently unused here.
        min_kernel: erosion slider value; currently unused here.

    Returns:
        A 4-tuple of:
          - a list with one mask per predicted segment (``uint8``, 0 or 255,
            same height/width as ``image``);
          - a checkbox-group update whose choices are
            ``"<segment id>:<label>"`` strings, with nothing selected;
          - an image update holding an all-zero mask of the image's size;
          - an image update holding the unmodified input image.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only. Grad mode is thread-local, so disable it right here
    # instead of relying on a setting made on some other thread.
    with torch.no_grad():
        outputs = segmentation_model(**inputs)
    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    # The PNG encodes each pixel's segment id as an RGB colour (decoded by
    # `rgb_to_id`). Resize with nearest-neighbour sampling: an interpolating
    # filter would blend neighbouring colours into wrong or nonexistent ids.
    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(image.size, resample=Image.NEAREST)
    panoptic_seg_id = rgb_to_id(np.array(panoptic_seg, dtype=np.uint8))

    segments = result['segments_info']
    raw_masks = [(panoptic_seg_id == s['id']).astype(np.uint8) * 255 for s in segments]
    checkbox_choices = [f"{s['id']}:{segmentation_cfg.id2label[s['category_id']]}" for s in segments]
    # New masks invalidate any previous selection (which could even refer to
    # segments that no longer exist), so clear it together with the mask
    # preview reset below.
    checkbox_group = gr.CheckboxGroup.update(choices=checkbox_choices, value=[])
    return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
def fn_clean(masks, max_kernel, min_kernel):
    """Min-pool then max-pool every mask independently.

    Each mask is wrapped as a float tensor with two leading singleton axes,
    min-pooled with ``min_kernel``, max-pooled with ``max_kernel``, squeezed
    back and cast to ``uint8``.

    Returns:
        A new list of cleaned masks, in the same order as ``masks``.
    """
    def _clean_one(mask):
        batched = torch.FloatTensor(mask)[None, None]
        pooled = max_pool(min_pool(batched, min_kernel), max_kernel)
        return pooled.squeeze().numpy().astype(np.uint8)

    return [_clean_one(mask) for mask in masks]
def fn_update_mask(
    image: Image.Image,
    masks: List[np.ndarray],
    masks_enabled: List[str],
    max_kernel: int,
    min_kernel: int,
):
    """Union the selected segment masks, clean the result, and black it out.

    Args:
        image: the input PIL image.
        masks: per-segment masks produced by the segmentation step, or
            ``None``/empty if no segmentation has been computed yet.
        masks_enabled: selected checkbox values, ``"<segment id>:<label>"``.
        max_kernel: dilation window size passed to ``clean_mask``.
        min_kernel: erosion window size passed to ``clean_mask``.

    Returns:
        ``(mask, masked_image)``: the cleaned combined mask as a ``uint8``
        array (0 or 255) and a copy of ``image`` with the masked pixels set to
        0. Returns ``(None, None)`` when there is no image.
    """
    if image is None:
        return None, None

    masked_image = np.array(image)  # np.array copies, so `image` is untouched
    # Start from an empty mask sized like the image, so this also works when
    # no masks exist yet (e.g. a slider moved before masks were computed).
    combined_mask = np.zeros(masked_image.shape[:2], dtype=bool)
    if masks:
        for choice in masks_enabled or []:
            # NOTE(review): the segment id doubles as a list index, which
            # assumes ids run 0..len(masks)-1 in order — confirm against
            # the segmentation post-processing.
            segment_id = int(choice.split(':')[0])
            combined_mask |= masks[segment_id].astype(bool)

    # Slider values may arrive as floats; the pooling kernels need ints.
    combined_mask = clean_mask(combined_mask, int(max_kernel), int(min_kernel))
    masked_image[combined_mask] = 0
    return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
def stable_diffusion_size(width: int, height: int, small_edge: int = 512):
    """Return the ``(width, height)`` to feed an image to Stable Diffusion.

    The shorter side is scaled to ``small_edge`` and the longer side keeps the
    aspect ratio (floored, computed exactly with integers); a square image is
    treated like a portrait one. Both sides are then rounded *up* to the next
    multiple of 8, leaving sides that already are multiples of 8 unchanged.
    """
    if width > height:
        new_width, new_height = width * small_edge // height, small_edge
    else:
        new_width, new_height = small_edge, height * small_edge // width
    # `-n % 8` is the distance up to the next multiple of 8 (0 if n is one).
    return new_width + (-new_width % 8), new_height + (-new_height % 8)


def fn_diffusion(prompt: str, masked_image: 'Image.Image', mask: np.ndarray, num_diffusion_steps: int):
    """Inpaint the masked region of ``masked_image`` as described by ``prompt``.

    Both images are resized to the size given by ``stable_diffusion_size``,
    run through the inpainting pipeline, and the result is resized back to
    the size of ``masked_image``.

    Args:
        prompt: text prompt for the pipeline.
        masked_image: PIL image whose masked pixels have been blacked out.
        mask: the mask as a NumPy array (it comes from a ``type='numpy'``
            image component); converted to an RGB PIL image before use.
        num_diffusion_steps: number of denoising steps; cast to int because
            slider values may arrive as floats.

    Returns:
        The inpainted PIL image at the original size of ``masked_image``.
    """
    w, h = masked_image.size
    new_width, new_height = stable_diffusion_size(w, h)
    mask_pil = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
    image_pil = masked_image.convert("RGB").resize((new_width, new_height))
    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        num_inference_steps=int(num_diffusion_steps),
    ).images[0]
    return inpainted_image.resize((w, h))
def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
    """Segment ``image``, mask the chosen segments and inpaint them in one call.

    Args:
        prompt: text prompt for the inpainting pipeline.
        mask_indices: comma-separated segment ids to mask, e.g. ``"1,3"``.
            Blank entries are ignored, so an empty string masks nothing.
        image: input PIL image.
        max_kernel: dilation window size passed to ``clean_mask``.
        min_kernel: erosion window size passed to ``clean_mask``.
        num_diffusion_steps: number of denoising steps; ``0`` skips diffusion.

    Returns:
        ``(masked_image, inpainted_image, class_str)``, where ``class_str``
        lists the predicted label of every segment, one per line. When
        ``num_diffusion_steps == 0`` the masked image (at the input size) is
        returned in both image slots; otherwise both images are at the
        resized Stable Diffusion working size.
    """
    selected_ids = [int(tok) for tok in mask_indices.split(',') if tok.strip()]

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only. Grad mode is thread-local, so disable it right here.
    with torch.no_grad():
        outputs = segmentation_model(**inputs)
    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
    class_str = '\n'.join(segmentation_cfg.id2label[s['category_id']] for s in result['segments_info'])

    # The PNG encodes segment ids as RGB colours (decoded by `rgb_to_id`);
    # nearest-neighbour resizing keeps them exact instead of blending
    # neighbouring ids into wrong ones.
    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(image.size, resample=Image.NEAREST)
    panoptic_seg_id = rgb_to_id(np.array(panoptic_seg, dtype=np.uint8))

    mask = np.zeros(panoptic_seg_id.shape, dtype=bool)
    for idx in selected_ids:
        mask |= panoptic_seg_id == idx
    mask = clean_mask(mask, min_kernel=min_kernel, max_kernel=max_kernel)

    masked_array = np.array(image)  # np.array copies, so `image` is untouched
    masked_array[mask] = 0
    masked_image = Image.fromarray(masked_array)
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
    if num_diffusion_steps == 0:
        return masked_image, masked_image, class_str

    # Scale the shorter side to 512 (exact integer arithmetic for the longer
    # side), then round both sides up to a multiple of 8; `-n % 8` is 0 when
    # n already is a multiple of 8.
    STABLE_DIFFUSION_SMALL_EDGE = 512
    w, h = masked_image.size
    if w > h:
        new_width, new_height = w * STABLE_DIFFUSION_SMALL_EDGE // h, STABLE_DIFFUSION_SMALL_EDGE
    else:
        new_width, new_height = STABLE_DIFFUSION_SMALL_EDGE, h * STABLE_DIFFUSION_SMALL_EDGE // w
    new_width += -new_width % 8
    new_height += -new_height % 8

    mask_image = mask_image.convert("RGB").resize((new_width, new_height))
    masked_image = masked_image.convert("RGB").resize((new_width, new_height))
    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=masked_image,
        mask_image=mask_image,
        num_inference_steps=int(num_diffusion_steps),
    ).images[0]
    return masked_image, inpainted_image, class_str
# iface_segmentation = gr.Interface(
# fn=fn_segmentation,
# inputs=[
# "text",
# "text",
# gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg"),
# gr.Slider(minimum=1, maximum=99, value=23, step=2),
# gr.Slider(minimum=1, maximum=99, value=5, step=2),
# gr.Slider(minimum=0, maximum=100, value=50, step=1),
# ],
# outputs=["text", gr.Image(type="pil"), gr.Image(type="pil"), "number", "text"]
# )
# iface_diffusion = gr.Interface(
# fn=fn_diffusion,
# inputs=["text", gr.Image(type='pil'), gr.Image(type='pil'), "number", "text"],
# outputs=[gr.Image(), gr.Image(), gr.Textbox()]
# )
# iface = gr.Series(
# iface_segmentation, iface_diffusion,
# iface = gr.Interface(
# fn=fn_segmentation_diffusion,
# inputs=[
# "text",
# "text",
# gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
# gr.Slider(minimum=1, maximum=99, value=23, step=2),
# gr.Slider(minimum=1, maximum=99, value=5, step=2),
# gr.Slider(minimum=0, maximum=100, value=50, step=1),
# ],
# outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
# )
# iface = gr.Interface(
# fn=fn_segmentation,
# inputs=[
# gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
# gr.Slider(minimum=1, maximum=99, value=23, step=2),
# gr.Slider(minimum=1, maximum=99, value=5, step=2),
# ],
# outputs=gr.Gallery()
# )
# iface.launch()
# Gradio Blocks UI: compute panoptic masks for the input image, choose
# segments with the checkboxes (tuning the mask clean-up with the two
# sliders), then inpaint the masked region from a text prompt.
with gr.Blocks() as demo:
    input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil')
    bt_masks = gr.Button("Compute Masks")

    with gr.Row():
        mask_image = gr.Image(type='numpy')
        masked_image = gr.Image(type='pil')
    # Holds the per-segment masks returned by `fn_segmentation`.
    mask_storage = gr.State()

    with gr.Row():
        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2)
        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2)
    mask_checkboxes = gr.CheckboxGroup(interactive=True)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.")
            steps_slider = gr.Slider(minimum=1, maximum=100, value=50)
            bt_diffusion = gr.Button("Run Diffusion")
        inpainted_image = gr.Image(type='pil')

    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])

    # Every mask-related control re-runs the same mask update.
    mask_update_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
    mask_update_outputs = [mask_image, masked_image]
    for trigger in (max_slider, min_slider, mask_checkboxes):
        trigger.change(fn_update_mask, inputs=list(mask_update_inputs), outputs=list(mask_update_outputs))

    bt_diffusion.click(fn_diffusion, inputs=[prompt, masked_image, mask_image, steps_slider], outputs=inpainted_image)

demo.launch()