import io

import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
from transformers.models.detr.feature_extraction_detr import rgb_to_id
from diffusers import StableDiffusionInpaintPipeline

# TODO: consider porting to the `Blocks` API, which (per the Gradio course) supports:
#   - multi-step interfaces, in which the output of one model becomes the
#     input to the next model, and more flexible data flows in general;
#   - changing a component's properties (for example, the choices in a dropdown)
#     or its visibility based on user input.
# https://huggingface.co/course/chapter9/7?fw=pt
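# A rough, untested sketch of what such a Blocks port might look like (the split
# into separate `fn_segmentation` / `fn_diffusion` steps below is hypothetical):
#
#   with gr.Blocks() as demo:
#       image_in = gr.Image(type='pil')
#       classes_out = gr.Textbox(interactive=False)
#       mask_indices_in = gr.Textbox()
#       inpainted_out = gr.Image()
#       gr.Button("Segment").click(fn_segmentation, [image_in], [classes_out])
#       gr.Button("Inpaint").click(fn_diffusion, [mask_indices_in, image_in], [inpainted_out])
#   demo.launch()
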
# All model calls below are inference-only, so disable gradient tracking globally.
torch.set_grad_enabled(False)

def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
    model = DetrForSegmentation.from_pretrained(model_name)
    cfg = DetrConfig.from_pretrained(model_name)
    return feature_extractor, model, cfg

def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
    # The fp16 weights are intended for GPU inference; fall back to full precision on CPU.
    if torch.cuda.is_available():
        return StableDiffusionInpaintPipeline.from_pretrained(
            model_name,
            revision='fp16',
            torch_dtype=torch.float16
        )
    return StableDiffusionInpaintPipeline.from_pretrained(model_name)

def get_device(try_cuda=True):
    return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')

def min_pool(x: torch.Tensor, kernel_size: int):
    # Erosion: a minimum filter, implemented as max-pooling of the negated input.
    pad_size = (kernel_size - 1) // 2
    return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)

def max_pool(x: torch.Tensor, kernel_size: int):
    # Dilation: a plain maximum filter.
    pad_size = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)

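# `clean_mask` is a simple morphological clean-up: erode with `min_kernel` to drop
# small speckles, then dilate with `max_kernel` so the final mask comfortably
# covers the selected object for inpainting.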
def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
    mask = torch.as_tensor(mask[None, None], dtype=torch.float32)
    mask = min_pool(mask, min_kernel)
    mask = max_pool(mask, max_kernel)
    mask = mask.bool().squeeze().numpy()
    return mask

device = get_device()
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
# segmentation_model = segmentation_model.to(device)
pipe = load_diffusion_pipeline()
pipe = pipe.to(device)

# TODO: potentially use `gr.Gallery` to display different masks
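# A possible (untested) direction: render one binary mask preview per segment and
# return them as a list of PIL images to a `gr.Gallery(label="Segment masks")`
# output, so users can pick segments visually instead of typing ids.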
def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
    # Parse the comma-separated segment ids, ignoring empty entries.
    mask_indices = [int(i) for i in mask_indices.split(',') if i.strip()]

    # Run panoptic segmentation on the input image.
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = segmentation_model(**inputs)
    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
    class_str = '\n'.join(
        f"{s['id']}: {segmentation_cfg.id2label[s['category_id']]}"
        for s in result['segments_info']
    )
    panoptic_seg_id = rgb_to_id(panoptic_seg)

    # Union of all selected segments; an empty selection yields an empty mask.
    mask = np.zeros(panoptic_seg_id.shape, dtype=bool)
    for idx in mask_indices:
        mask = mask | (panoptic_seg_id == idx)
    mask = clean_mask(mask, min_kernel=min_kernel, max_kernel=max_kernel)

    masked_image = np.array(image).copy()
    masked_image[mask] = 0
    masked_image = Image.fromarray(masked_image).resize(image.size)
    mask = Image.fromarray(mask.astype(np.uint8) * 255).resize(image.size)

    if num_diffusion_steps == 0:
        return masked_image, masked_image, class_str

    STABLE_DIFFUSION_SMALL_EDGE = 512
    assert masked_image.size == mask.size
    w, h = masked_image.size
    is_width_larger = w > h
    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)
    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)
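    # Stable Diffusion's VAE downsamples by a factor of 8, so both dimensions
    # must be rounded up to a multiple of 8 before calling the pipeline.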
    new_width += 8 - (new_width % 8) if is_width_larger else 0
    new_height += 0 if is_width_larger else 8 - (new_height % 8)

    mask = mask.convert("RGB").resize((new_width, new_height))
    masked_image = masked_image.convert("RGB").resize((new_width, new_height))

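    # Inpaint the masked region: white pixels in `mask` are regenerated from the
    # prompt, black pixels are kept from `masked_image`.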
    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=masked_image,
        mask_image=mask,
        num_inference_steps=num_diffusion_steps
    ).images[0]

    return masked_image, inpainted_image, class_str

# iface_segmentation = gr.Interface(
#     fn=fn_segmentation,
#     inputs=[
#         "text",
#         "text",
#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg"),
#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
#     ],
#     outputs=["text", gr.Image(type="pil"), gr.Image(type="pil"), "number", "text"]
# )
# iface_diffusion = gr.Interface(
#     fn=fn_diffusion,
#     inputs=["text", gr.Image(type='pil'), gr.Image(type='pil'), "number", "text"],
#     outputs=[gr.Image(), gr.Image(), gr.Textbox()]
# )
# iface = gr.Series(
#     iface_segmentation, iface_diffusion,
iface = gr.Interface(
    fn=fn_segmentation_diffusion,
    inputs=[
        "text",
        "text",
        gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
        gr.Slider(minimum=1, maximum=99, value=23, step=2),
        gr.Slider(minimum=1, maximum=99, value=5, step=2),
        gr.Slider(minimum=0, maximum=100, value=50, step=1),
    ],
    outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
)
iface.launch()