Spaces:
Runtime error
Runtime error
Alexander McKinney
commited on
Commit
·
d16d053
1
Parent(s):
8cd1abb
adds comments to code
Browse files
app.py
CHANGED
@@ -17,6 +17,7 @@ from diffusers import StableDiffusionInpaintPipeline
|
|
17 |
torch.inference_mode()
|
18 |
torch.no_grad()
|
19 |
|
|
|
20 |
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
|
21 |
feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
22 |
model = DetrForSegmentation.from_pretrained(model_name)
|
@@ -24,6 +25,7 @@ def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic
|
|
24 |
|
25 |
return feature_extractor, model, cfg
|
26 |
|
|
|
27 |
def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
|
28 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
29 |
model_name,
|
@@ -31,6 +33,7 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
|
|
31 |
torch_dtype=torch.float16
|
32 |
)
|
33 |
|
|
|
34 |
def get_device(try_cuda=True):
|
35 |
return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
|
36 |
|
@@ -42,6 +45,7 @@ def max_pool(x: torch.Tensor, kernel_size: int):
|
|
42 |
pad_size = (kernel_size - 1) // 2
|
43 |
return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
|
44 |
|
|
|
45 |
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
46 |
mask = torch.Tensor(mask[None, None]).float()
|
47 |
mask = min_pool(mask, min_kernel)
|
@@ -49,13 +53,14 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
|
49 |
mask = mask.bool().squeeze().numpy()
|
50 |
return mask
|
51 |
|
52 |
-
device = get_device()
|
53 |
|
54 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
55 |
-
|
56 |
pipe = load_diffusion_pipeline()
|
|
|
|
|
57 |
pipe = pipe.to(device)
|
58 |
|
|
|
59 |
def fn_segmentation(image, max_kernel, min_kernel):
|
60 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
61 |
outputs = segmentation_model(**inputs)
|
@@ -81,17 +86,7 @@ def fn_segmentation(image, max_kernel, min_kernel):
|
|
81 |
|
82 |
return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
|
83 |
|
84 |
-
|
85 |
-
out = []
|
86 |
-
for m in masks:
|
87 |
-
m = torch.FloatTensor(m)[None, None]
|
88 |
-
m = min_pool(m, min_kernel)
|
89 |
-
m = max_pool(m, max_kernel)
|
90 |
-
m = m.squeeze().numpy().astype(np.uint8)
|
91 |
-
out.append(m)
|
92 |
-
|
93 |
-
return out
|
94 |
-
|
95 |
def fn_update_mask(
|
96 |
image: Image,
|
97 |
masks: List[np.array],
|
@@ -108,6 +103,7 @@ def fn_update_mask(
|
|
108 |
|
109 |
return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
|
110 |
|
|
|
111 |
def fn_diffusion(
|
112 |
prompt: str,
|
113 |
masked_image: Image,
|
@@ -118,6 +114,9 @@ def fn_diffusion(
|
|
118 |
):
|
119 |
if len(negative_prompt) == 0:
|
120 |
negative_prompt = None
|
|
|
|
|
|
|
121 |
STABLE_DIFFUSION_SMALL_EDGE = 512
|
122 |
|
123 |
w, h = masked_image.size
|
@@ -133,6 +132,7 @@ def fn_diffusion(
|
|
133 |
mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
|
134 |
masked_image = masked_image.convert("RGB").resize((new_width, new_height))
|
135 |
|
|
|
136 |
inpainted_image = pipe(
|
137 |
height=new_height,
|
138 |
width=new_width,
|
@@ -144,6 +144,7 @@ def fn_diffusion(
|
|
144 |
negative_prompt=negative_prompt
|
145 |
).images[0]
|
146 |
|
|
|
147 |
inpainted_image = inpainted_image.resize((w, h))
|
148 |
|
149 |
return inpainted_image
|
@@ -151,21 +152,24 @@ def fn_diffusion(
|
|
151 |
demo = gr.Blocks()
|
152 |
|
153 |
with demo:
|
|
|
154 |
input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
|
155 |
|
|
|
156 |
bt_masks = gr.Button("Compute Masks")
|
157 |
-
|
158 |
with gr.Row():
|
159 |
mask_image = gr.Image(type='numpy', label="Diffusion Mask")
|
160 |
masked_image = gr.Image(type='pil', label="Masked Image")
|
161 |
mask_storage = gr.State()
|
162 |
|
|
|
163 |
with gr.Row():
|
164 |
max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
|
165 |
min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
|
166 |
|
167 |
mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
|
168 |
|
|
|
169 |
with gr.Row():
|
170 |
with gr.Column():
|
171 |
prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
|
@@ -180,14 +184,19 @@ with demo:
|
|
180 |
update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
|
181 |
update_mask_outputs = [mask_image, masked_image]
|
182 |
|
|
|
183 |
input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
|
184 |
|
|
|
185 |
bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
|
186 |
|
|
|
|
|
187 |
max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
188 |
min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
189 |
mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
190 |
|
|
|
191 |
bt_diffusion.click(fn_diffusion, inputs=[
|
192 |
prompt,
|
193 |
masked_image,
|
|
|
17 |
torch.inference_mode()
|
18 |
torch.no_grad()
|
19 |
|
20 |
+
# Load segmentation models
|
21 |
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
|
22 |
feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
23 |
model = DetrForSegmentation.from_pretrained(model_name)
|
|
|
25 |
|
26 |
return feature_extractor, model, cfg
|
27 |
|
28 |
+
# Load diffusion pipeline
|
29 |
def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
|
30 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
31 |
model_name,
|
|
|
33 |
torch_dtype=torch.float16
|
34 |
)
|
35 |
|
36 |
+
# Device helper
|
37 |
def get_device(try_cuda=True):
|
38 |
return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
|
39 |
|
|
|
45 |
pad_size = (kernel_size - 1) // 2
|
46 |
return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
|
47 |
|
48 |
+
# Apply min-max pooling to clean up mask
|
49 |
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
50 |
mask = torch.Tensor(mask[None, None]).float()
|
51 |
mask = min_pool(mask, min_kernel)
|
|
|
53 |
mask = mask.bool().squeeze().numpy()
|
54 |
return mask
|
55 |
|
|
|
56 |
|
57 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
|
|
58 |
pipe = load_diffusion_pipeline()
|
59 |
+
|
60 |
+
device = get_device()
|
61 |
pipe = pipe.to(device)
|
62 |
|
63 |
+
# Callback function that runs segmentation and updates CheckboxGroup
|
64 |
def fn_segmentation(image, max_kernel, min_kernel):
|
65 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
66 |
outputs = segmentation_model(**inputs)
|
|
|
86 |
|
87 |
return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
|
88 |
|
89 |
+
# Callback function that updates the displayed mask based on selected checkboxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
def fn_update_mask(
|
91 |
image: Image,
|
92 |
masks: List[np.array],
|
|
|
103 |
|
104 |
return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
|
105 |
|
106 |
+
# Callback function that runs diffusion given the current image, mask and prompt.
|
107 |
def fn_diffusion(
|
108 |
prompt: str,
|
109 |
masked_image: Image,
|
|
|
114 |
):
|
115 |
if len(negative_prompt) == 0:
|
116 |
negative_prompt = None
|
117 |
+
|
118 |
+
# Resize image to a more stable diffusion friendly format.
|
119 |
+
# TODO: remove magic number
|
120 |
STABLE_DIFFUSION_SMALL_EDGE = 512
|
121 |
|
122 |
w, h = masked_image.size
|
|
|
132 |
mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
|
133 |
masked_image = masked_image.convert("RGB").resize((new_width, new_height))
|
134 |
|
135 |
+
# Run diffusion
|
136 |
inpainted_image = pipe(
|
137 |
height=new_height,
|
138 |
width=new_width,
|
|
|
144 |
negative_prompt=negative_prompt
|
145 |
).images[0]
|
146 |
|
147 |
+
# Resize back to the original size
|
148 |
inpainted_image = inpainted_image.resize((w, h))
|
149 |
|
150 |
return inpainted_image
|
|
|
152 |
demo = gr.Blocks()
|
153 |
|
154 |
with demo:
|
155 |
+
# Input image control
|
156 |
input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
|
157 |
|
158 |
+
# Combined mask controls
|
159 |
bt_masks = gr.Button("Compute Masks")
|
|
|
160 |
with gr.Row():
|
161 |
mask_image = gr.Image(type='numpy', label="Diffusion Mask")
|
162 |
masked_image = gr.Image(type='pil', label="Masked Image")
|
163 |
mask_storage = gr.State()
|
164 |
|
165 |
+
# Mask editing controls
|
166 |
with gr.Row():
|
167 |
max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
|
168 |
min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
|
169 |
|
170 |
mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
|
171 |
|
172 |
+
# Diffusion controls and output
|
173 |
with gr.Row():
|
174 |
with gr.Column():
|
175 |
prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
|
|
|
184 |
update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
|
185 |
update_mask_outputs = [mask_image, masked_image]
|
186 |
|
187 |
+
# Clear checkbox group on input image change
|
188 |
input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
|
189 |
|
190 |
+
# Segmentation button callback
|
191 |
bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
|
192 |
|
193 |
+
# Update mask callbacks
|
194 |
+
# TODO: can we replace this with `mask_image.change`? Not sure if it will actively update.
|
195 |
max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
196 |
min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
197 |
mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
|
198 |
|
199 |
+
# Diffusion button callback
|
200 |
bt_diffusion.click(fn_diffusion, inputs=[
|
201 |
prompt,
|
202 |
masked_image,
|