File size: 3,363 Bytes
04bf3ab
 
 
 
 
 
 
3e7b7cc
 
04bf3ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e7b7cc
 
 
04bf3ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import io
import requests
import numpy as np
import torch
from PIL import Image
from skimage.measure import block_reduce

import gradio as gr

from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
from transformers.models.detr.feature_extraction_detr import rgb_to_id

from diffusers import StableDiffusionInpaintPipeline

def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR feature extractor, segmentation model and config.

    Each object is created by calling ``from_pretrained(model_name)`` on the
    corresponding ``transformers`` class.

    Args:
        model_name: Checkpoint identifier forwarded to every
            ``from_pretrained`` call.

    Returns:
        A ``(feature_extractor, model, cfg)`` tuple, in that order.
    """
    # Order matters only for the unpacking below; it mirrors the return order.
    component_classes = (DetrFeatureExtractor, DetrForSegmentation, DetrConfig)
    extractor, segmenter, config = (
        component.from_pretrained(model_name) for component in component_classes
    )
    return extractor, segmenter, config

def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
    """Load a Stable Diffusion inpainting pipeline in half precision.

    Args:
        model_name: Checkpoint identifier forwarded to
            ``StableDiffusionInpaintPipeline.from_pretrained``.

    Returns:
        The pipeline object returned by ``from_pretrained``, requested from
        the ``fp16`` revision with ``torch.float16`` as its dtype.
    """
    half_precision_options = {
        'revision': 'fp16',
        'torch_dtype': torch.float16,
    }
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        model_name, **half_precision_options
    )
    return pipeline

def get_device(try_cuda=True):
    """Pick the torch device to run on.

    Args:
        try_cuda: When truthy, prefer CUDA if ``torch.cuda.is_available()``
            reports it is usable.

    Returns:
        ``torch.device('cuda')`` when CUDA was requested and is available,
        otherwise ``torch.device('cpu')``.
    """
    if try_cuda and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def greet(name):
    """Return the greeting ``"Hello <name>!"`` for the given string *name*."""
    pieces = ("Hello ", name, "!")
    return "".join(pieces)

def min_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window minimum, implemented as a negated max pool.

    The window moves with stride 1 and the input is padded by
    ``(kernel_size - 1) // 2`` on each side, so an odd ``kernel_size`` keeps
    the spatial size unchanged.

    Args:
        x: Tensor accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Side length of the square pooling window.

    Returns:
        Tensor holding the minimum of ``x`` over each window.
    """
    half_window = (kernel_size - 1) // 2
    # min(x) == -max(-x): negate, max-pool, negate back.
    pooled_negative = torch.nn.functional.max_pool2d(
        x.neg(),
        kernel_size=kernel_size,
        stride=1,
        padding=half_window,
    )
    return pooled_negative.neg()

def max_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window maximum with stride 1.

    The input is padded by ``(kernel_size - 1) // 2`` on each side, so an odd
    ``kernel_size`` keeps the spatial size unchanged.

    Args:
        x: Tensor accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Side length of the square pooling window.

    Returns:
        Tensor holding the maximum of ``x`` over each window.
    """
    half_window = (kernel_size - 1) // 2
    pooled = torch.nn.functional.max_pool2d(
        x,
        kernel_size=kernel_size,
        stride=1,
        padding=half_window,
    )
    return pooled

def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
    """Clean up a 2-D binary mask with a min pool followed by a max pool.

    The mask is first shrunk with ``min_pool`` (which removes regions smaller
    than the window) and then grown with ``max_pool``. With the default sizes
    the grow step uses the larger window, so surviving regions end up larger
    than they started.

    Args:
        mask: 2-D array-like mask that supports ``mask[None, None]`` indexing
            (e.g. a NumPy boolean array).
        min_kernel: Window size for the min-pool step.
        max_kernel: Window size for the max-pool step.

    Returns:
        A 2-D boolean NumPy array. For odd kernel sizes it has the same
        height and width as ``mask``.
    """
    # Add batch and channel axes so the last two axes are pooled as spatial
    # dimensions; as_tensor replaces the legacy torch.Tensor constructor.
    mask = torch.as_tensor(mask[None, None], dtype=torch.float32)
    mask = min_pool(mask, min_kernel)
    mask = max_pool(mask, max_kernel)
    # Drop exactly the two axes added above. A bare squeeze() would also
    # collapse a genuine size-1 height or width and return a 1-D mask.
    mask = mask.bool()[0, 0].numpy()
    return mask

# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = get_device()

feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
model = segmentation_model.to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the wait and fail loudly on an HTTP error instead of trying to decode
# an error page as an image. Reading `content` (rather than the raw socket
# stream) also lets requests undo any transfer encoding first.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))

# prepare image for the model
inputs = feature_extractor(images=image, return_tensors="pt").to(device)

# forward pass (inference only, so autograd bookkeeping is not needed)
with torch.no_grad():
    outputs = segmentation_model(**inputs)

processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

# The PNG encodes one segment id per pixel as a colour. It must be resized
# with nearest-neighbour sampling: an interpolating filter would blend the
# colours at segment borders and produce ids that belong to no segment (or
# to the wrong one).
panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(
    (image.width, image.height), resample=Image.NEAREST
)
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)

panoptic_seg_id = rgb_to_id(panoptic_seg)

print(result['segments_info'])

# Segment id 5 is hard-coded for this particular image; check the printed
# segments_info when switching to a different picture.
# cat_mask = (panoptic_seg_id == 1) | (panoptic_seg_id == 5)
cat_mask = (panoptic_seg_id == 5)
cat_mask = clean_mask(cat_mask)

# Black out the selected segment and keep a copy on disk for inspection.
masked_image = np.array(image).copy()
masked_image[cat_mask] = 0

masked_image = Image.fromarray(masked_image)
masked_image.save('masked_cat.png')

# NOTE(review): the pipeline is loaded in float16; half-precision inference
# may not be supported on a CPU-only machine -- confirm before running there.
pipe = load_diffusion_pipeline()
pipe = pipe.to(device)

print(cat_mask)

# Scale the image to a height of 512 while keeping its aspect ratio, then
# round the width UP to the next multiple of 8 (the pipeline is called with
# dimensions divisible by 8). Ceiling division leaves a width that is already
# a multiple of 8 untouched instead of adding a further 8 pixels.
target_height = 512
new_width = int(image.width * target_height / image.height)
new_width = -(-new_width // 8) * 8

print(new_width)
cat_mask = Image.fromarray(cat_mask.astype(np.uint8) * 255).convert("RGB").resize((new_width, target_height))
masked_image = masked_image.resize((new_width, target_height))

prompt = "Two cats on the sofa together."
inpainted_image = pipe(
    height=target_height,
    width=new_width,
    prompt=prompt,
    image=masked_image,
    mask_image=cat_mask,
).images[0]
inpainted_image.save('inpaint_cat.png')