# Alexander McKinney
# Experimenting with a segmentation-mask and inpainting pipeline.
# Commit: 04bf3ab
import io
import requests
import numpy as np
import torch
from PIL import Image
from skimage.measure import block_reduce
import gradio as gr
from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
from transformers.models.detr.feature_extraction_detr import rgb_to_id
from diffusers import StableDiffusionInpaintPipeline
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR panoptic pieces for *model_name*.

    Returns a ``(feature_extractor, model, config)`` triple, each obtained
    via the corresponding ``from_pretrained`` factory.
    """
    extractor = DetrFeatureExtractor.from_pretrained(model_name)
    seg_model = DetrForSegmentation.from_pretrained(model_name)
    config = DetrConfig.from_pretrained(model_name)
    return extractor, seg_model, config
def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
    """Build a Stable Diffusion inpainting pipeline in half precision.

    Pulls the ``fp16`` revision of *model_name* with ``torch.float16`` weights.
    """
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        revision='fp16',
        torch_dtype=torch.float16,
    )
    return pipeline
def get_device(try_cuda=True):
    """Pick a torch device: CUDA when requested and available, else CPU."""
    use_cuda = try_cuda and torch.cuda.is_available()
    return torch.device('cuda') if use_cuda else torch.device('cpu')
def greet(name):
    """Return a greeting string for *name*."""
    return f"Hello {name}!"
def min_pool(x: torch.Tensor, kernel_size: int):
    """Stride-1 minimum pooling, expressed as a max-pool of the negation.

    Padding of ``(kernel_size - 1) // 2`` keeps the spatial size unchanged
    for odd kernels; max_pool2d pads with -inf, so the border padding never
    wins the (negated) maximum.
    """
    padding = (kernel_size - 1) // 2
    pooled = torch.nn.functional.max_pool2d(x.neg(), kernel_size, (1, 1), padding=padding)
    return pooled.neg()
def max_pool(x: torch.Tensor, kernel_size: int):
    """Stride-1 maximum pooling that preserves spatial size for odd kernels."""
    padding = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=padding)
def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
    """Denoise a 2-D boolean mask with a morphological open-then-dilate.

    The erosion (min-pool) removes small speckles; the larger dilation
    (max-pool) grows the surviving regions outward. Returns a boolean
    numpy array of the same 2-D shape.
    """
    # Add batch and channel dims so the pooling ops accept the mask.
    tensor = torch.Tensor(mask[None, None]).float()
    eroded = min_pool(tensor, min_kernel)
    dilated = max_pool(eroded, max_kernel)
    return dilated.bool().squeeze().numpy()
# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch()

device = get_device()
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
model = segmentation_model.to(device)

# Demo image: COCO val2017 "two cats on a sofa".
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Prepare the image and run panoptic segmentation.
inputs = feature_extractor(images=image, return_tensors="pt").to(device)
outputs = segmentation_model(**inputs)
processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

# Decode the panoptic PNG back to per-pixel segment ids at the original size.
panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
panoptic_seg_id = rgb_to_id(panoptic_seg)
print(result['segments_info'])

# Segment id 5 corresponds to one of the cats (see the segments_info printout).
# cat_mask = (panoptic_seg_id == 1) | (panoptic_seg_id == 5)
cat_mask = (panoptic_seg_id == 5)
cat_mask = clean_mask(cat_mask)

# Black out the masked region and save an intermediate preview.
masked_image = np.array(image).copy()
masked_image[cat_mask] = 0
masked_image = Image.fromarray(masked_image)
masked_image.save('masked_cat.png')

pipe = load_diffusion_pipeline()
pipe = pipe.to(device)
print(cat_mask)

# Resize so the height becomes 512 and the width is a multiple of 8 (a
# Stable Diffusion requirement). Derive the ratio from the actual image
# dimensions rather than hard-coding 640x480, and round the width UP to
# the next multiple of 8 without adding 8 when it is already aligned.
resize_ratio = 512 / image.height
new_width = int(image.width * resize_ratio)
new_width += (-new_width) % 8
print(new_width)

cat_mask = Image.fromarray(cat_mask.astype(np.uint8) * 255).convert("RGB").resize((new_width, 512))
masked_image = masked_image.resize((new_width, 512))

prompt = "Two cats on the sofa together."
inpainted_image = pipe(height=512, width=new_width, prompt=prompt, image=masked_image, mask_image=cat_mask).images[0]
inpainted_image.save('inpaint_cat.png')