"""Demo: mask one panoptic segment of an image and inpaint it with Stable Diffusion.

The script downloads a sample COCO image, runs DETR panoptic segmentation on
it, builds a cleaned-up binary mask for a single segment, blacks that region
out, and asks a Stable Diffusion inpainting pipeline to fill it in again.
"""
import io

import gradio as gr
import numpy as np
import requests
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.measure import block_reduce
from transformers import DetrConfig, DetrFeatureExtractor, DetrForSegmentation
from transformers.models.detr.feature_extraction_detr import rgb_to_id

# Sample image used by the demo.
IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Seconds to wait for the image server before giving up.
REQUEST_TIMEOUT = 30
# Id of the panoptic segment to erase. Segment ids are assigned per image, so
# this value is only meaningful for IMAGE_URL (see the printed segments_info).
TARGET_SEGMENT_ID = 5
# Height of the image handed to the inpainting pipeline.
TARGET_HEIGHT = 512
# The inpainting pipeline is given a width rounded up to a multiple of this.
SIZE_MULTIPLE = 8
PROMPT = "Two cats on the sofa together."


def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR feature extractor, segmentation model and config.

    Args:
        model_name: Hub id or local path passed to each ``from_pretrained``.

    Returns:
        A ``(feature_extractor, model, config)`` tuple.
    """
    feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
    model = DetrForSegmentation.from_pretrained(model_name)
    cfg = DetrConfig.from_pretrained(model_name)
    return feature_extractor, model, cfg


def load_diffusion_pipeline(
    model_name: str = 'runwayml/stable-diffusion-inpainting',
    *,
    half_precision: bool = True,
):
    """Load the Stable Diffusion inpainting pipeline.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        half_precision: When true (the default, matching the previous
            behaviour) load the ``fp16`` revision with ``torch.float16``
            weights. Pass ``False`` to load the default revision in the
            library's default precision, e.g. when running on CPU.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline``.
    """
    if not half_precision:
        return StableDiffusionInpaintPipeline.from_pretrained(model_name)
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name, revision='fp16', torch_dtype=torch.float16
    )


def get_device(try_cuda=True):
    """Return the CUDA device if requested and available, otherwise the CPU."""
    return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')


def greet(name):
    """Return a greeting for *name* (placeholder callback for a Gradio UI)."""
    return "Hello " + name + "!"


def min_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window minimum with stride 1, computed as ``-max_pool(-x)``.

    The padding of ``(kernel_size - 1) // 2`` keeps the spatial size unchanged
    for odd ``kernel_size``; an even ``kernel_size`` yields an output that is
    one element smaller along each spatial dimension.
    """
    pad_size = (kernel_size - 1) // 2
    return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)


def max_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window maximum with stride 1.

    The padding of ``(kernel_size - 1) // 2`` keeps the spatial size unchanged
    for odd ``kernel_size``; an even ``kernel_size`` yields an output that is
    one element smaller along each spatial dimension.
    """
    pad_size = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)


def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
    """Remove small specks from a 2-D binary mask, then grow what remains.

    The mask is min-pooled with ``min_kernel`` (regions narrower than the
    window disappear) and then max-pooled with ``max_kernel`` (surviving
    regions expand).

    Args:
        mask: 2-D array-like of booleans or 0/1 values.
        min_kernel: Window size of the min-pool step.
        max_kernel: Window size of the max-pool step.

    Returns:
        A 2-D boolean ``numpy`` array.
    """
    # max_pool2d expects (batch, channel, H, W), hence the two leading axes.
    pooled = torch.as_tensor(mask, dtype=torch.float32)[None, None]
    pooled = min_pool(pooled, min_kernel)
    pooled = max_pool(pooled, max_kernel)
    # Index the two added axes away explicitly; a bare squeeze() would also
    # drop a spatial axis of size 1.
    return pooled[0, 0].bool().cpu().numpy()


def _round_up(value: int, multiple: int) -> int:
    """Return the smallest multiple of *multiple* that is >= *value*."""
    return -(-value // multiple) * multiple


def _download_image(url: str) -> Image.Image:
    """Fetch *url* and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB")


def main():
    """Run the segmentation -> masking -> inpainting demo end to end."""
    # TODO: expose this through a Gradio UI, e.g.
    #   gr.Interface(fn=greet, inputs="text", outputs="text").launch()
    device = get_device()
    feature_extractor, segmentation_model, _segmentation_cfg = load_segmentation_models()
    segmentation_model.to(device)

    image = _download_image(IMAGE_URL)

    # Prepare the image for the model.
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = segmentation_model(**inputs)

    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    # The PNG encodes segment ids in its RGB channels, so it must be resized
    # with nearest-neighbour sampling: any interpolating filter blends colours
    # at segment borders and decodes to ids that do not exist.
    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(
        (image.width, image.height), resample=Image.NEAREST
    )
    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
    panoptic_seg_id = rgb_to_id(panoptic_seg)
    print(result['segments_info'])

    # NOTE: use (panoptic_seg_id == 1) | (panoptic_seg_id == 5) to mask both
    # segments instead of only TARGET_SEGMENT_ID.
    cat_mask = clean_mask(panoptic_seg_id == TARGET_SEGMENT_ID)

    masked_pixels = np.array(image).copy()
    masked_pixels[cat_mask] = 0
    masked_image = Image.fromarray(masked_pixels)
    masked_image.save('masked_cat.png')

    # Half precision is only requested when running on CUDA.
    pipe = load_diffusion_pipeline(half_precision=device.type == 'cuda')
    pipe = pipe.to(device)
    print(cat_mask)

    # Scale to TARGET_HEIGHT keeping the aspect ratio of the actual image, then
    # round the width up to the next multiple of SIZE_MULTIPLE.
    new_width = _round_up(int(image.width * TARGET_HEIGHT / image.height), SIZE_MULTIPLE)
    print(new_width)

    mask_image = (
        Image.fromarray(cat_mask.astype(np.uint8) * 255)
        .convert("RGB")
        .resize((new_width, TARGET_HEIGHT))
    )
    masked_image = masked_image.resize((new_width, TARGET_HEIGHT))

    inpainted_image = pipe(
        height=TARGET_HEIGHT,
        width=new_width,
        prompt=PROMPT,
        image=masked_image,
        mask_image=mask_image,
    ).images[0]
    inpainted_image.save('inpaint_cat.png')


if __name__ == "__main__":
    main()