Alexander McKinney committed · Commit b4542eb · Parent(s): 04bf3ab

interface example

need to change to blocks, so we can compute segmentation once, diffusion
once. Only repeated components are on CPU.
unsure how to resolve onclick canvas, need to check what canvas can do.
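As a rough illustration of the Blocks refactor the message describes (segment once, then inpaint repeatedly against the cached result), a minimal sketch might look like the following. This is hypothetical code, not part of the commit; `segment_fn` and `inpaint_fn` are placeholder stubs standing in for the segmentation and inpainting steps:

import gradio as gr

def segment_fn(image):
    # placeholder: the real version would run DETR panoptic segmentation once
    # and return the panoptic id map plus the detected class names
    return None, "no classes (stub)"

def inpaint_fn(seg_result, image, prompt, indices):
    # placeholder: the real version would build a mask from the cached
    # segmentation and call StableDiffusionInpaintPipeline
    return image

with gr.Blocks() as demo:
    seg_state = gr.State()                      # caches the segmentation between steps

    input_image = gr.Image(type="pil")
    segment_btn = gr.Button("Segment")
    class_box = gr.Textbox(interactive=False)   # detected classes, one per line

    prompt = gr.Textbox(label="Inpainting prompt")
    indices = gr.Textbox(label="Comma-separated segment indices")
    inpaint_btn = gr.Button("Inpaint")
    output_image = gr.Image(type="pil")

    # segmentation runs once; its result is stored in seg_state
    segment_btn.click(segment_fn, inputs=[input_image], outputs=[seg_state, class_box])
    # diffusion reuses the cached segmentation instead of recomputing it
    inpaint_btn.click(inpaint_fn,
                      inputs=[seg_state, input_image, prompt, indices],
                      outputs=[output_image])

demo.launch()

The point of `gr.State` in this sketch is that the expensive DETR forward pass runs only on the Segment click, while the diffusion step can be re-run with different prompts and indices against the cached result.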
app.py
CHANGED
@@ -12,6 +12,18 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
 
 from diffusers import StableDiffusionInpaintPipeline
 
+# TODO: maybe need to port to `Blocks` system
+# allegedly provides:
+# Have multi-step interfaces, in which the output of one model becomes the
+# input to the next model, or have more flexible data flows in general.
+
+# and:
+# Change a component’s properties (for example, the choices in a dropdown) or its visibility based on user input
+# https://huggingface.co/course/chapter9/7?fw=pt
+
+torch.inference_mode()
+torch.no_grad()
+
 def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
     feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
     model = DetrForSegmentation.from_pretrained(model_name)
@@ -29,9 +41,6 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
 def get_device(try_cuda=True):
     return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
 
-def greet(name):
-    return "Hello " + name + "!"
-
 def min_pool(x: torch.Tensor, kernel_size: int):
     pad_size = (kernel_size - 1) // 2
     return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)
@@ -47,55 +56,105 @@ def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
     mask = mask.bool().squeeze().numpy()
     return mask
 
-# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-# iface.launch()
 device = get_device()
 
 feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
-
-
-url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(url, stream=True).raw)
-
-# prepare image for the model
-inputs = feature_extractor(images=image, return_tensors="pt").to(device)
-
-# forward pass
-outputs = segmentation_model(**inputs)
-
-processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
-result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
-
-panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
-panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
-
-panoptic_seg_id = rgb_to_id(panoptic_seg)
-
-print(result['segments_info'])
-
-# cat_mask = (panoptic_seg_id == 1) | (panoptic_seg_id == 5)
-cat_mask = (panoptic_seg_id == 5)
-cat_mask = clean_mask(cat_mask)
-
-masked_image = np.array(image).copy()
-masked_image[cat_mask] = 0
-
-masked_image = Image.fromarray(masked_image)
-masked_image.save('masked_cat.png')
+# segmentation_model = segmentation_model.to(device)
 
 pipe = load_diffusion_pipeline()
 pipe = pipe.to(device)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+# TODO: potentially use `gr.Gallery` to display different masks
+def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
+    mask_indices = [int(i) for i in mask_indices.split(',')]
+    inputs = feature_extractor(images=image, return_tensors="pt")
+    outputs = segmentation_model(**inputs)
+
+    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
+    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
+
+    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
+    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
+
+    class_str = '\n'.join(segmentation_cfg.id2label[s['category_id']] for s in result['segments_info'])
+
+    panoptic_seg_id = rgb_to_id(panoptic_seg)
+
+    if len(mask_indices) > 0:
+        mask = (panoptic_seg_id == mask_indices[0])
+        for idx in mask_indices[1:]:
+            mask = mask | (panoptic_seg_id == idx)
+        mask = clean_mask(mask, min_kernel=min_kernel, max_kernel=max_kernel)
+
+    masked_image = np.array(image).copy()
+    masked_image[mask] = 0
+
+    masked_image = Image.fromarray(masked_image).resize(image.size)
+    mask = Image.fromarray(mask.astype(np.uint8) * 255).resize(image.size)
+
+    if num_diffusion_steps == 0:
+        return masked_image, masked_image, class_str
+
+    STABLE_DIFFUSION_SMALL_EDGE = 512
+
+    assert masked_image.size == mask.size
+    w, h = masked_image.size
+    is_width_larger = w > h
+    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)
+
+    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
+    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)
+
+    new_width += 8 - (new_width % 8) if is_width_larger else 0
+    new_height += 0 if is_width_larger else 8 - (new_height % 8)
+
+    mask = mask.convert("RGB").resize((new_width, new_height))
+    masked_image = masked_image.convert("RGB").resize((new_width, new_height))
+
+    inpainted_image = pipe(
+        height=new_height,
+        width=new_width,
+        prompt=prompt,
+        image=masked_image,
+        mask_image=mask,
+        num_inference_steps=num_diffusion_steps
+    ).images[0]
+
+    return masked_image, inpainted_image, class_str
+
+
+# iface_segmentation = gr.Interface(
+#     fn=fn_segmentation,
+#     inputs=[
+#         "text",
+#         "text",
+#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg"),
+#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
+#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
+#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
+#     ],
+#     outputs=["text", gr.Image(type="pil"), gr.Image(type="pil"), "number", "text"]
+# )
+
+# iface_diffusion = gr.Interface(
+#     fn=fn_diffusion,
+#     inputs=["text", gr.Image(type='pil'), gr.Image(type='pil'), "number", "text"],
+#     outputs=[gr.Image(), gr.Image(), gr.Textbox()]
+# )
+
+# iface = gr.Series(
+#     iface_segmentation, iface_diffusion,
+iface = gr.Interface(
+    fn=fn_segmentation_diffusion,
+    inputs=[
+        "text",
+        "text",
+        gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
+        gr.Slider(minimum=1, maximum=99, value=23, step=2),
+        gr.Slider(minimum=1, maximum=99, value=5, step=2),
+        gr.Slider(minimum=0, maximum=100, value=50, step=1),
+    ],
+    outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
+)
+
+iface.launch()
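A note on the added `torch.inference_mode()` / `torch.no_grad()` lines: both return context managers (also usable as decorators), so calling them as bare module-level statements and discarding the result does not disable gradient tracking for the code that follows. A minimal sketch of the usual patterns, assuming the intent was to run both models without autograd:

import torch

# Used as a decorator, autograd is disabled for every call to the function:
@torch.inference_mode()
def run_segmentation(model, inputs):
    return model(**inputs)

# Used as a context manager, autograd is disabled only inside the block:
# with torch.no_grad():
#     outputs = segmentation_model(**inputs)

# A process-wide switch, closest to what the bare calls appear to intend:
# torch.set_grad_enabled(False)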
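The `min_pool` helper visible in the second hunk builds a minimum filter by negating the input around `max_pool2d`; on a 0/1 mask that is a morphological erosion, and `clean_mask` presumably pairs it with a plain max-pool (dilation) to remove speckle, though only its last two lines appear in this diff. A small self-contained check of the erosion behaviour:

import torch

# min_pool as defined in app.py: a minimum filter built from max_pool2d
def min_pool(x: torch.Tensor, kernel_size: int):
    pad_size = (kernel_size - 1) // 2
    return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)

# On a float 0/1 mask, min-pooling shrinks the foreground:
mask = torch.zeros(1, 1, 5, 5)
mask[:, :, 1:4, 1:4] = 1.0     # a 3x3 blob
eroded = min_pool(mask, 3)
print(eroded[0, 0])            # only the centre pixel of the blob stays 1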
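The resizing block in `fn_segmentation_diffusion` scales the image so its shorter edge is 512 and then rounds the longer edge up to a multiple of 8, since Stable Diffusion's VAE/UNet operate on latents downsampled by a factor of 8. As written, `8 - (x % 8)` adds a full 8 pixels even when the edge is already divisible by 8. A worked example for a hypothetical 640x480 input:

STABLE_DIFFUSION_SMALL_EDGE = 512
w, h = 640, 480                                   # width is the larger edge
resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / h    # 512 / 480, about 1.067
new_width = int(w * resize_ratio)                 # 682
new_width += 8 - (new_width % 8)                  # 688, the next multiple of 8
new_height = STABLE_DIFFUSION_SMALL_EDGE          # 512
print(new_width, new_height)                      # 688 512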