Commit 5b2ab1c
Parent(s): 8f5320e
Initial commit

Files changed:
- app.py +23 -0
- examples/notebooks/demo_controlnet.ipynb +0 -0
- examples/notebooks/demo_sam.ipynb +0 -0
- requirements.txt +4 -0
- src/designgenie/__init__.py +1 -0
- src/designgenie/data/__init__.py +1 -0
- src/designgenie/data/image_folder.py +69 -0
- src/designgenie/interfaces/__init__.py +1 -0
- src/designgenie/interfaces/gradio_interface.py +240 -0
- src/designgenie/models/__init__.py +6 -0
- src/designgenie/models/diffusion/__init__.py +15 -0
- src/designgenie/models/diffusion/controlnet.py +160 -0
- src/designgenie/models/diffusion/controlnet_inpaint.py +125 -0
- src/designgenie/models/segmentation/__init__.py +16 -0
- src/designgenie/models/segmentation/maskformer.py +122 -0
- src/designgenie/pipelines/__init__.py +1 -0
- src/designgenie/pipelines/inpaint_pipeline.py +81 -0
- src/designgenie/utils/__init__.py +6 -0
- src/designgenie/utils/helper.py +55 -0
- src/designgenie/utils/segmentation_utils.py +287 -0
app.py
ADDED
@@ -0,0 +1,23 @@
from src.designgenie.interfaces import GradioApp


# def run_pipeline():
#     pipe = InpaintPipeline(
#         segmentation_model_name="mask2former",
#         diffusion_model_name="controlnet_inpaint",
#         control_model_name="mlsd",
#         images_root="/home/nader/Projects/DesignGenie/assets/images/",
#         prompts_path="/home/nader/Projects/DesignGenie/assets/prompts.txt",
#         image_size=(768, 512),
#         image_extensions=(".jpg", ".jpeg", ".png", ".webp"),
#     )

#     results = pipe.run()

#     for i, images in enumerate(results):
#         for j, image in enumerate(images):
#             image.save(f"./assets/results/result_{i}_{j}.png")

if __name__ == "__main__":
    app = GradioApp()
    app.interface.launch(share=True)
examples/notebooks/demo_controlnet.ipynb
ADDED
The diff for this file is too large to render.
examples/notebooks/demo_sam.ipynb
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio
torch
transformers
diffusers
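Note that the modules added in this commit also import controlnet_aux, cv2, numpy, torchvision, requests, and wandb, and the diffusion pipelines call enable_xformers_memory_efficient_attention(); those packages are not pinned here. A fuller requirements sketch (package names inferred from the imports, versions unpinned, not part of this commit) might look like:

    gradio
    torch
    torchvision
    transformers
    diffusers
    controlnet_aux
    opencv-python
    numpy
    requests
    wandb
    xformers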
src/designgenie/__init__.py
ADDED
@@ -0,0 +1 @@
from .pipelines import InpaintPipeline
src/designgenie/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .image_folder import ImageFolderDataset
src/designgenie/data/image_folder.py
ADDED
@@ -0,0 +1,69 @@
from typing import Any, List, Optional, Tuple, Union
import os
from PIL import Image
from random import randint, choices

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from diffusers.utils import load_image


class ImageFolderDataset(Dataset):
    """Dataset class for loading images and prompts from a folder and file path.

    Args:
        images_root (str):
            Path to the folder containing images.
        prompts_path (str):
            Path to the file containing prompts.
        image_size (Tuple[int, int]):
            Size of the images to be loaded.
        extensions (Tuple[str]):
            Tuple of valid image extensions.
    """

    def __init__(
        self,
        images_root: str,
        prompts_path: Optional[str] = None,
        image_size: Tuple[int, int] = (512, 512),
        extensions: Tuple[str] = (".jpg", ".jpeg", ".png", ".webp"),
    ) -> None:
        super().__init__()
        self.image_size = image_size

        self.images_paths, self.prompts = self._make_dataset(
            images_root=images_root, extensions=extensions, prompts_path=prompts_path
        )

        self.to_tensor = transforms.ToTensor()

    def _make_dataset(
        self,
        images_root: str,
        extensions: Tuple[str],
        prompts_path: Optional[str] = None,
    ) -> Tuple[List[str], Union[None, List[str]]]:
        images_paths = []
        for root, _, fnames in sorted(os.walk(images_root)):
            for fname in sorted(fnames):
                if fname.lower().endswith(extensions):
                    images_paths.append(os.path.join(root, fname))

        if prompts_path is not None:
            with open(prompts_path, "r") as f:
                prompts = f.readlines()
        else:
            prompts = None

        return images_paths, prompts

    def __len__(self) -> int:
        return len(self.images_paths)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, Union[None, str]]:
        image = load_image(self.images_paths[idx]).resize(self.image_size)
        prompt = self.prompts[idx] if self.prompts is not None else None

        return self.to_tensor(image), prompt
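As a rough usage sketch (the paths below are placeholders, not part of the commit), the dataset pairs each resized image tensor with the prompt at the same line index, matching how InpaintPipeline.build_data_loader wires it up later in this commit:

    from torch.utils.data import DataLoader
    from src.designgenie.data import ImageFolderDataset

    # Folder of .jpg/.png images plus an optional prompts file with one prompt per line.
    dataset = ImageFolderDataset(
        images_root="assets/images/",
        prompts_path="assets/prompts.txt",
        image_size=(768, 512),
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    for image_tensor, prompt in loader:
        # PIL resize takes (width, height), so the tensor is [1, 3, 512, 768] here.
        print(image_tensor.shape, prompt)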
src/designgenie/interfaces/__init__.py
ADDED
@@ -0,0 +1 @@
from .gradio_interface import GradioApp
src/designgenie/interfaces/gradio_interface.py
ADDED
@@ -0,0 +1,240 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch

from ..models import create_diffusion_model, create_segmentation_model
from ..utils import (
    get_masked_images,
    visualize_segmentation_map,
    get_masks_from_segmentation_map,
)


# points color and marker
COLOR = (255, 0, 0)


@dataclass
class AppState:
    """A class to store the memory state of the Gradio App."""

    original_image: Image.Image = None
    predicted_semantic_map: torch.Tensor = None
    input_coordinates: List[int] = field(default_factory=list)
    n_outputs: int = 2


class GradioApp:
    def __init__(self):
        self._interface = self.build_interface()
        self.state = AppState()

        self.segmentation_model = None
        self.diffusion_model = None

    @property
    def interface(self):
        return self._interface

    def _segment_input(self, image: Image.Image, model_name: str) -> Image.Image:
        """Segment the input image using the given model."""
        if self.segmentation_model is None:
            self.segmentation_model = create_segmentation_model(
                segmentation_model_name=model_name
            )

        predicted_semantic_map = self.segmentation_model.process([image])[0]
        self.state.predicted_semantic_map = predicted_semantic_map

        segmentation_map = visualize_segmentation_map(predicted_semantic_map, image)
        return segmentation_map

    def _generate_outputs(
        self,
        prompt: str,
        model_name: str,
        n_outputs: int,
        inference_steps: int,
        strength: float,
        guidance_scale: float,
        eta: float,
    ) -> Image.Image:
        if self.diffusion_model is None:
            self.diffusion_model = create_diffusion_model(
                diffusion_model_name="controlnet_inpaint", control_model_name=model_name
            )

        control_image = self.diffusion_model.generate_control_images(
            images=[self.state.original_image]
        )[0]

        image_mask, masked_control_image = get_masked_images(
            control_image,
            self.state.predicted_semantic_map,
            self.state.input_coordinates,
        )

        outputs = self.diffusion_model.process(
            images=[self.state.original_image],
            prompts=[prompt],
            mask_images=[image_mask],
            control_images=[masked_control_image],
            negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            n_outputs=n_outputs,
        )

        return (
            *outputs["output_images"][0],
            control_image,
            image_mask,
        )

    def image_change(self, input_image):
        input_image = input_image.resize((768, 512))
        self.state.original_image = input_image
        return input_image

    def clear_coordinates(self):
        self.state.input_coordinates = []

    def get_coordinates(self, event: gr.SelectData, input_image: Image.Image):
        w, h = tuple(event.index)
        self.state.input_coordinates.append((h, w))
        print(self.state.input_coordinates)

        return Image.fromarray(
            cv2.drawMarker(
                np.asarray(input_image), event.index, COLOR, markerSize=20, thickness=5
            )
        )

    def build_interface(self):
        """Builds the Gradio interface for the DesignGenie app."""
        with gr.Blocks() as designgenie_interface:
            # --> App Header <--
            with gr.Row():
                # --> Description <--
                with gr.Column():
                    gr.Markdown(
                        """
                        # DesignGenie

                        An AI copilot for home interior design. It identifies various sections of your home and generates personalized designs for the selected sections using ControlNet and StableDiffusion.
                        """
                    )
                # --> Model Selection <--
                with gr.Column():
                    with gr.Row():
                        segmentation_model = gr.Dropdown(
                            choices=["mask2former", "maskformer"],
                            label="Segmentation Model",
                            value="mask2former",
                            interactive=True,
                        )
                        controlnet_model = gr.Dropdown(
                            choices=["mlsd", "soft_edge", "hed", "scribble"],
                            label="Controlnet Module",
                            value="mlsd",
                            interactive=True,
                        )

            # --> Model Parameters <--
            with gr.Accordion(label="Parameters", open=False):
                with gr.Column():
                    gr.Markdown("### Stable Diffusion Parameters")
                    with gr.Row():
                        with gr.Column():
                            inference_steps = gr.Number(
                                value=30, label="Number of inference steps."
                            )
                            strength = gr.Number(value=1.0, label="Strength.")
                        with gr.Column():
                            guidance_scale = gr.Number(value=7.5, label="Guidance scale.")
                            eta = gr.Number(value=0.0, label="Eta.")

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    # --> Input Image and Segmentation <--
                    input_image = gr.Image(label="Input Image", type="pil")
                    input_image.select(
                        self.get_coordinates,
                        inputs=[input_image],
                        outputs=[input_image],
                    )
                    input_image.upload(
                        self.image_change, inputs=[input_image], outputs=[input_image]
                    )

                    with gr.Row():
                        gr.Markdown(
                            """
                            1. Select your input image.
                            2. Click on the `Segment Image` button.
                            3. Choose the segments that you want to redesign by simply clicking on the image.
                            """
                        )
                        with gr.Column():
                            segment_btn = gr.Button(
                                value="Segment Image", variant="primary"
                            )
                            clear_btn = gr.Button(value="Clear")

                    segment_btn.click(
                        self._segment_input,
                        inputs=[input_image, segmentation_model],
                        outputs=input_image,
                    )
                    clear_btn.click(self.clear_coordinates)

                    # --> Prompt and Num Outputs <--
                    text = gr.Textbox(
                        label="Text prompt (optional)",
                        info="You can describe how the model should redesign the selected segments of your home.",
                    )
                    num_outputs = gr.Slider(
                        value=3,
                        minimum=1,
                        maximum=5,
                        step=1,
                        interactive=True,
                        label="Number of Generated Outputs",
                        info="Number of design outputs you want the model to generate.",
                    )

                    submit_btn = gr.Button(value="Submit", variant="primary")

                with gr.Column():
                    with gr.Tab(label="Output Images"):
                        output_images = [
                            gr.Image(
                                interactive=False, label=f"Output Image {i}", type="pil"
                            )
                            for i in range(3)
                        ]

                    with gr.Tab(label="Control Images"):
                        control_labels = ["Control Image", "Generated Mask"]
                        control_images = [
                            gr.Image(interactive=False, label=label, type="pil")
                            for label in control_labels
                        ]

                    submit_btn.click(
                        self._generate_outputs,
                        inputs=[
                            text,
                            controlnet_model,
                            num_outputs,
                            inference_steps,
                            strength,
                            guidance_scale,
                            eta,
                        ],
                        outputs=output_images + control_images,
                    )

        return designgenie_interface
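For reference, get_coordinates above relies on Gradio's image select event: the clicked pixel arrives in gr.SelectData.index. A standalone sketch of that pattern (hypothetical demo, not part of this commit) looks like:

    import gradio as gr

    clicks = []

    def on_select(evt: gr.SelectData):
        # evt.index is the (x, y) pixel that was clicked on the image.
        x, y = evt.index
        clicks.append((y, x))  # stored as (row, col), mirroring AppState.input_coordinates
        return f"Selected pixels so far: {clicks}"

    with gr.Blocks() as demo:
        img = gr.Image(type="pil", label="Click to pick segments")
        log = gr.Textbox(label="Coordinates")
        img.select(on_select, inputs=None, outputs=log)

    if __name__ == "__main__":
        demo.launch()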
src/designgenie/models/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .diffusion import (
    StableDiffusionControlNet,
    StableDiffusionControlNetInpaint,
    create_diffusion_model,
)
from .segmentation import MaskFormer, Mask2Former, create_segmentation_model
src/designgenie/models/diffusion/__init__.py
ADDED
@@ -0,0 +1,15 @@
from .controlnet import StableDiffusionControlNet
from .controlnet_inpaint import StableDiffusionControlNetInpaint

DIFFUSION_MODELS = {
    "controlnet": StableDiffusionControlNet,
    "controlnet_inpaint": StableDiffusionControlNetInpaint,
}


def create_diffusion_model(diffusion_model_name: str, **kwargs):
    assert (
        diffusion_model_name in DIFFUSION_MODELS.keys()
    ), "Diffusion model name must be one of " + ", ".join(DIFFUSION_MODELS.keys())

    return DIFFUSION_MODELS[diffusion_model_name](**kwargs)
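The factory simply dispatches on the registry key and forwards keyword arguments to the selected class constructor, so a call like the following (which downloads the model weights on first use) is equivalent to instantiating StableDiffusionControlNetInpaint directly:

    from src.designgenie.models.diffusion import create_diffusion_model

    model = create_diffusion_model(
        diffusion_model_name="controlnet_inpaint",  # or "controlnet"
        control_model_name="mlsd",
    )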
src/designgenie/models/diffusion/controlnet.py
ADDED
@@ -0,0 +1,160 @@
from typing import Any, List, Optional, Tuple, Union
import itertools
from PIL import Image
import numpy as np
import torch

from controlnet_aux import MLSDdetector, PidiNetDetector, HEDdetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)


MODEL_DICT = {
    "mlsd": {
        "name": "lllyasviel/Annotators",
        "detector": MLSDdetector,
        "model": "lllyasviel/control_v11p_sd15_mlsd",
    },
    "soft_edge": {
        "name": "lllyasviel/Annotators",
        "detector": PidiNetDetector,
        "model": "lllyasviel/control_v11p_sd15_softedge",
    },
    "hed": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/sd-controlnet-hed",
    },
    "scribble": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/control_v11p_sd15_scribble",
    },
}


class StableDiffusionControlNet:
    """ControlNet pipeline for generating images from prompts.

    Args:
        control_model_name (str):
            Name of the controlnet processor.
        sd_model_name (str):
            Name of the StableDiffusion model.
    """

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-v1-5",
    ) -> None:
        self.processor = MODEL_DICT[control_model_name]["detector"].from_pretrained(
            MODEL_DICT[control_model_name]["name"]
        )
        self.pipe = self.create_pipe(
            sd_model_name=sd_model_name, control_model_name=control_model_name
        )

    def _repeat(self, items: List[Any], n: int) -> List[Any]:
        """Repeat items in a list n times.

        Args:
            items (List[Any]): List of items to be repeated.
            n (int): Number of repetitions.

        Returns:
            List[Any]: List of repeated items.
        """
        return list(
            itertools.chain.from_iterable(itertools.repeat(item, n) for item in items)
        )

    def generate_control_images(self, images: List[Image.Image]) -> List[Image.Image]:
        """Generate control images from input images.

        Args:
            images (List[Image.Image]): Input images.

        Returns:
            List[Image.Image]: Control images.
        """
        return [self.processor(image) for image in images]

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> StableDiffusionControlNetPipeline:
        """Create a StableDiffusionControlNetPipeline.

        Args:
            sd_model_name (str): StableDiffusion model name.
            control_model_name (str): Name of the ControlNet module.

        Returns:
            StableDiffusionControlNetPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )

        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()

        return pipe

    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        n_outputs: Optional[int] = 1,
        num_inference_steps: Optional[int] = 30,
    ) -> List[List[Image.Image]]:
        """Generate images from `prompts` using `control_images` and `negative_prompt`.

        Args:
            images (List[Image.Image]): Input images.
            prompts (List[str]): List of prompts.
            negative_prompt (Optional[str], optional): Negative prompt. Defaults to None.
            n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1.
            num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30.

        Returns:
            List[List[Image.Image]]
        """

        control_images = self.generate_control_images(images)

        assert len(prompts) == len(
            control_images
        ), "Number of prompts and input images must be equal."

        if n_outputs > 1:
            prompts = self._repeat(prompts, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)

        generator = [
            torch.Generator(device="cuda").manual_seed(int(i))
            for i in np.random.randint(len(prompts), size=len(prompts))
        ]

        output = self.pipe(
            prompts,
            image=control_images,
            negative_prompt=[negative_prompt] * len(prompts),
            num_inference_steps=num_inference_steps,
            generator=generator,
        )

        output_images = [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images))
        ]

        return output_images
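A rough end-to-end sketch of this class (file path and prompt are placeholders; a CUDA device is assumed, since the generators above are created with device="cuda"):

    from diffusers.utils import load_image
    from src.designgenie.models.diffusion import StableDiffusionControlNet

    model = StableDiffusionControlNet(control_model_name="mlsd")
    room = load_image("assets/images/room.jpg")  # placeholder path

    # Returns one list of output images per input image, with n_outputs images per list.
    outputs = model.process(
        images=[room],
        prompts=["a scandinavian living room, bright, minimal"],
        negative_prompt="lowres, worst quality",
        n_outputs=2,
    )
    outputs[0][0].save("result.png")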
src/designgenie/models/diffusion/controlnet_inpaint.py
ADDED
@@ -0,0 +1,125 @@
from typing import List, Optional, Tuple, Union
from PIL import Image
import numpy as np
import torch

from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    UniPCMultistepScheduler,
)

from .controlnet import StableDiffusionControlNet, MODEL_DICT


class StableDiffusionControlNetInpaint(StableDiffusionControlNet):
    """StableDiffusion with ControlNet model for inpainting images based on prompts.

    Args:
        control_model_name (str):
            Name of the controlnet processor.
        sd_model_name (str):
            Name of the StableDiffusion model.
    """

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting",
    ) -> None:
        super().__init__(
            control_model_name=control_model_name,
            sd_model_name=sd_model_name,
        )

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> StableDiffusionControlNetInpaintPipeline:
        """Create a StableDiffusionControlNetInpaintPipeline.

        Args:
            sd_model_name (str): StableDiffusion model name.
            control_model_name (str): Name of the ControlNet module.

        Returns:
            StableDiffusionControlNetInpaintPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )

        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()

        return pipe

    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        mask_images: List[Image.Image],
        control_images: Optional[List[Image.Image]] = None,
        negative_prompt: Optional[str] = None,
        n_outputs: Optional[int] = 1,
        num_inference_steps: Optional[int] = 30,
        strength: Optional[float] = 1.0,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
    ) -> List[List[Image.Image]]:
        """Inpaint images based on `prompts` using `control_images` and `mask_images`.

        Args:
            images (List[Image.Image]): Input images.
            prompts (List[str]): List of prompts.
            mask_images (List[Image.Image]): List of mask images.
            control_images (Optional[List[Image.Image]], optional): List of control images. Defaults to None.
            negative_prompt (Optional[str], optional): Negative prompt. Defaults to None.
            n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1.
            num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30.

        Returns:
            List[List[Image.Image]]
        """

        if control_images is None:
            control_images = self.generate_control_images(images)

        assert len(prompts) == len(
            control_images
        ), "Number of prompts and input images must be equal."

        if n_outputs > 1:
            prompts = self._repeat(prompts, n=n_outputs)
            images = self._repeat(images, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)
            mask_images = self._repeat(mask_images, n=n_outputs)

        generator = [
            torch.Generator(device="cuda").manual_seed(int(i))
            for i in np.random.randint(max(len(prompts), 16), size=len(prompts))
        ]

        output = self.pipe(
            prompts,
            image=images,
            control_image=control_images,
            mask_image=mask_images,
            negative_prompt=[negative_prompt] * len(prompts),
            num_inference_steps=num_inference_steps,
            generator=generator,
        )

        output_images = [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images) // n_outputs)
        ]

        return {
            "output_images": output_images,
            "control_images": control_images,
            "mask_images": mask_images,
        }
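An inpainting usage sketch (placeholder paths; CUDA assumed as in the parent class). Note that, unlike the parent class, process() here returns a dict:

    from diffusers.utils import load_image
    from src.designgenie.models.diffusion import StableDiffusionControlNetInpaint

    model = StableDiffusionControlNetInpaint(control_model_name="mlsd")
    room = load_image("assets/images/room.jpg")
    mask = load_image("assets/masks/wall_mask.png")  # white = repaint, black = keep

    result = model.process(
        images=[room],
        prompts=["exposed brick wall"],
        mask_images=[mask],
        n_outputs=2,
    )
    result["output_images"][0][0].save("inpainted.png")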
src/designgenie/models/segmentation/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .maskformer import MaskFormer, Mask2Former

SEGMENTATION_MODEL_DICT = {
    "maskformer": MaskFormer,
    "mask2former": Mask2Former,
}


def create_segmentation_model(segmentation_model_name: str, **kwargs):
    assert (
        segmentation_model_name in SEGMENTATION_MODEL_DICT.keys()
    ), "Segmentation model name must be one of " + ", ".join(
        SEGMENTATION_MODEL_DICT.keys()
    )

    return SEGMENTATION_MODEL_DICT[segmentation_model_name](**kwargs)
src/designgenie/models/segmentation/maskformer.py
ADDED
@@ -0,0 +1,122 @@
from typing import Any, List, Optional, Tuple, Union
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
    MaskFormerImageProcessor,
    MaskFormerForInstanceSegmentation,
)


class MaskFormer:
    """MaskFormer semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the MaskFormer model. Defaults to "large".
    """

    def __init__(self, model_size: Optional[str] = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"

        self.processor = MaskFormerImageProcessor.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )
        self.model = MaskFormerForInstanceSegmentation.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )

    def process(self, images: List[Image.Image]):
        inputs = self.processor(images=images, return_tensors="pt")
        outputs = self.model(**inputs)
        # model predicts class_queries_logits of shape `(batch_size, num_queries)`
        # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
        class_queries_logits = outputs.class_queries_logits
        masks_queries_logits = outputs.masks_queries_logits

        # you can pass them to processor for postprocessing
        # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
        predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[images[0].size[::-1]] * len(images)
        )

        return predicted_semantic_maps


class Mask2Former(MaskFormer):
    """Mask2Former semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the Mask2Former model. Defaults to "large".
    """

    def __init__(self, model_size: Optional[str] = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"
        self.processor = AutoImageProcessor.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )


# class ADESegmentation:
#     def __init__(self, model_name: str):
#         self.processor = MODEL_DICT[model_name]["processor"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )
#         self.model = MODEL_DICT[model_name]["model"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )

#     def predict(self, image: Image.Image):
#         inputs = processor(images=image, return_tensors="pt")
#         outputs = model(**inputs)
#         # model predicts class_queries_logits of shape `(batch_size, num_queries)`
#         # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
#         class_queries_logits = outputs.class_queries_logits
#         masks_queries_logits = outputs.masks_queries_logits

#         # you can pass them to processor for postprocessing
#         # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
#         predicted_semantic_maps = processor.post_process_semantic_segmentation(
#             outputs, target_sizes=[image.size[::-1]]
#         )

#         return predicted_semantic_maps

#     def get_mask(self, predicted_semantic_maps, class_id: int):
#         masks, labels, obj_names = get_masks_from_segmentation_map(
#             predicted_semantic_maps[0]
#         )

#         mask = masks[labels.index(ID)]
#         object_mask = np.logical_not(mask).astype(int)

#         mask = torch.Tensor(mask).repeat(3, 1, 1)
#         object_mask = torch.Tensor(object_mask).repeat(3, 1, 1)

#         return mask, object_mask

#     def get_PIL_mask(self, predicted_semantic_maps, class_id: int):
#         mask, object_mask = self.get_mask(predicted_semantic_maps[0], class_id=class_id)

#         mask = transforms.ToPILImage()(mask)
#         object_mask = transforms.ToPILImage()(object_mask)

#         return mask, object_mask

#     def get_PIL_segmentation_map(self, predicted_semantic_maps):
#         return visualize_segmentation_map(predicted_semantic_maps[0])
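A usage sketch (placeholder path; weights download on first run): build the segmenter through the factory from the package __init__ above and get one ADE20K semantic map per image:

    from diffusers.utils import load_image
    from src.designgenie.models import create_segmentation_model

    segmenter = create_segmentation_model(
        segmentation_model_name="mask2former", model_size="large"
    )
    room = load_image("assets/images/room.jpg")  # placeholder path

    semantic_maps = segmenter.process([room])
    # Each entry is a (height, width) tensor of ADE20K class ids.
    print(semantic_maps[0].shape)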
src/designgenie/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
from .inpaint_pipeline import InpaintPipeline
src/designgenie/pipelines/inpaint_pipeline.py
ADDED
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image

from ..data import ImageFolderDataset
from ..models import create_diffusion_model, create_segmentation_model
from ..utils import get_masked_images


class InpaintPipeline:
    def __init__(
        self,
        segmentation_model_name: str,
        diffusion_model_name: str,
        control_model_name: str,
        images_root: str,
        prompts_path: Optional[str] = None,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting",
        image_size: Optional[Tuple[int, int]] = (512, 512),
        image_extensions: Optional[Tuple[str]] = (".jpg", ".jpeg", ".png", ".webp"),
        segmentation_model_size: Optional[str] = "large",
    ):
        self.segmentation_model = create_segmentation_model(
            segmentation_model_name=segmentation_model_name,
            model_size=segmentation_model_size,
        )

        self.diffusion_model = create_diffusion_model(
            diffusion_model_name=diffusion_model_name,
            control_model_name=control_model_name,
            sd_model_name=sd_model_name,
        )

        self.data_loader = self.build_data_loader(
            images_root=images_root,
            prompts_path=prompts_path,
            image_size=image_size,
            image_extensions=image_extensions,
        )

    def build_data_loader(
        self,
        images_root: str,
        prompts_path: Optional[str] = None,
        image_size: Optional[Tuple[int, int]] = (512, 512),
        image_extensions: Optional[Tuple[str]] = (".jpg", ".jpeg", ".png", ".webp"),
        batch_size: Optional[int] = 1,
    ) -> DataLoader:
        dataset = ImageFolderDataset(
            images_root, prompts_path, image_size, image_extensions
        )
        data_loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=8
        )

        return data_loader

    def run(self, data_loader: Optional[DataLoader] = None) -> List[Dict[str, Any]]:
        if data_loader is not None:
            self.data_loader = data_loader

        results = []
        for idx, (images, prompts) in enumerate(self.data_loader):
            images = [to_pil_image(img) for img in images]

            semantic_maps = self.segmentation_model.process(images)

            object_masks = [
                get_object_mask(seg_map, class_id=0) for seg_map in semantic_maps
            ]

            outputs = self.diffusion_model.process(
                images=images,
                prompts=[prompts[0]],
                mask_images=object_masks,
                negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            )
            results += outputs["output_images"]

        return results
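For context, the commented-out run_pipeline() block in app.py above exercises this class; a trimmed sketch of the same flow (placeholder paths, not part of this commit) is:

    from src.designgenie.pipelines import InpaintPipeline

    pipe = InpaintPipeline(
        segmentation_model_name="mask2former",
        diffusion_model_name="controlnet_inpaint",
        control_model_name="mlsd",
        images_root="assets/images/",
        prompts_path="assets/prompts.txt",
        image_size=(768, 512),
    )
    # run() returns a list of lists of PIL images, one inner list per input image.
    results = pipe.run()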
src/designgenie/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .segmentation_utils import (
    get_masked_images,
    visualize_segmentation_map,
    get_masks_from_segmentation_map,
)
from .helper import WandBLogger, parser
src/designgenie/utils/helper.py
ADDED
@@ -0,0 +1,55 @@
from typing import List, Tuple, Dict, Union, Any, Optional
import argparse
from PIL import Image
import wandb


class WandBLogger:
    def __init__(self, config: Dict[str, Any]):
        assert "wandb_project" in config, "Missing `wandb_project` in config"
        self.wandb = wandb.init(
            project=config.wandb_project, name=config.exp_name, config=config
        )

    def log_scalars(self, logs: Dict[str, Union[int, float]]):
        self.wandb.log(logs)

    def log_images(self, logs: Dict[str, List[Image.Image]]):
        wandb.log(
            {
                key: [wandb.Image(image, caption=key) for image in images]
                for key, images in logs.items()
            }
        )


def parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--segmentation_model", type=str, default="mask2former")
    parser.add_argument("--controlnet_name", type=str, default="hed")
    parser.add_argument(
        "--sd_model", type=str, default="runwayml/stable-diffusion-v1-5"
    )
    parser.add_argument(
        "--images_path",
        type=str,
        default="/home/nader/DesignGenie/assets/images/",
    )
    parser.add_argument(
        "--prompts_path",
        type=str,
        default="/home/nader/DesignGenie/assets/prompts.txt",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="monochrome, lowres, bad anatomy, worst quality, low quality",
    )
    parser.add_argument("--num_inference_steps", type=int, default=20)
    parser.add_argument("--n_outputs", type=int, default=4)
    parser.add_argument("--wandb_project", type=str, default="DesignGenie")
    parser.add_argument("--wandb", type=int, default=1)
    parser.add_argument("--exp_name", type=str, default="demo")
    args = parser.parse_args()

    return args
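Since WandBLogger reads config.wandb_project and config.exp_name as attributes, the argparse Namespace returned by parser() is the natural config object; a minimal sketch (assumes a logged-in wandb account):

    from src.designgenie.utils import WandBLogger, parser

    args = parser()                    # parses the CLI flags defined above
    logger = WandBLogger(config=args)  # requires wandb credentials
    logger.log_scalars({"step": 0, "guidance_scale": 7.5})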
src/designgenie/utils/segmentation_utils.py
ADDED
@@ -0,0 +1,287 @@
from typing import Any, List, Optional, Tuple, Union
from functools import reduce
import numpy as np
from PIL import Image
import requests
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image


def visualize_segmentation_map(
    semantic_map: torch.Tensor, original_image: Image.Image
) -> Image.Image:
    """
    Visualizes a segmentation map by overlaying it on the original image.

    Args:
        semantic_map (torch.Tensor): Segmentation map tensor.
        original_image (Image.Image): Original image.

    Returns:
        Image.Image: Overlay image with segmentation map.
    """
    # Convert to RGB
    color_seg = np.zeros(
        (semantic_map.shape[0], semantic_map.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[semantic_map == label, :] = color
    # Convert to BGR
    color_seg = color_seg[..., ::-1]

    # Show image + mask
    img = np.array(original_image) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)

    return Image.fromarray(img)


def get_masks_from_segmentation_map(
    semantic_map: torch.Tensor,
) -> Tuple[List[np.array], List[int], List[str]]:
    """
    Extracts masks, labels, and object names from a segmentation map.

    Args:
        semantic_map (torch.Tensor): Segmentation map tensor.

    Returns:
        Tuple[List[np.array], List[int], List[str]]: Tuple containing masks, labels, and object names.
    """
    masks = []
    labels = []
    obj_names = []
    for label, color in enumerate(np.array(ade_palette())):
        mask = np.ones(
            (semantic_map.shape[0], semantic_map.shape[1]), dtype=np.uint8
        )  # height, width
        indices = semantic_map == label
        mask[indices] = 0

        if indices.sum() > 0:
            masks.append(mask)
            labels.append(label)
            obj_names.append(ADE_LABELS[str(label)])

    return masks, labels, obj_names


def get_mask_from_coordinates(
    segmentation_maps: List[np.array], coordinates: Tuple[int, int]
):
    """
    Retrieves a mask from a list of segmentation maps based on given coordinates.

    Args:
        segmentation_maps (List[np.array]): List of segmentation maps.
        coordinates (Tuple[int, int]): Coordinates to filter the masks.

    Returns:
        np.array: Combined mask from the segmentation maps.
    """
    masks = []
    for seg_map in segmentation_maps:
        for coordinate in coordinates:
            if seg_map[coordinate] == 0:
                masks.append(seg_map)

    return reduce(np.multiply, masks)


def get_masked_images(
    control_image: Image.Image,
    semantic_map: torch.Tensor,
    coordinates: List[Tuple[int, int]],
    return_tensors: bool = False,
) -> Union[torch.Tensor, Image.Image]:
    """
    Retrieves masked images based on given control image, segmentation map, and coordinates.

    Args:
        control_image (Image.Image): Control image.
        semantic_map (torch.Tensor): Segmentation map tensor.
        coordinates (List[Tuple[int, int]]): List of coordinates.
        return_tensors (bool, optional): Whether to return masked images as tensors. Defaults to False.

    Returns:
        Union[torch.Tensor, Image.Image]: Masked image tensor or PIL image.
    """
    masks, labels, obj_names = get_masks_from_segmentation_map(semantic_map)

    mask = get_mask_from_coordinates(masks, coordinates)

    mask_image = np.logical_not(mask).astype(int)
    mask_image = torch.Tensor(mask_image).repeat(3, 1, 1)

    mask = torch.Tensor(mask).repeat(3, 1, 1)
    control_image = transforms.ToTensor()(control_image)
    masked_control_image = transforms.ToPILImage()(mask * control_image)

    if not return_tensors:
        mask_image = to_pil_image(mask_image)

    return mask_image, masked_control_image


ADE_LABELS = requests.get(
    "https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json"
).json()


def ade_palette():
    """ADE20K palette that maps each class to RGB values."""
    return [
        [120, 120, 120],
        [180, 120, 120],
        [6, 230, 230],
        [80, 50, 50],
        [4, 200, 3],
        [120, 120, 80],
        [140, 140, 140],
        [204, 5, 255],
        [230, 230, 230],
        [4, 250, 7],
        [224, 5, 255],
        [235, 255, 7],
        [150, 5, 61],
        [120, 120, 70],
        [8, 255, 51],
        [255, 6, 82],
        [143, 255, 140],
        [204, 255, 4],
        [255, 51, 7],
        [204, 70, 3],
        [0, 102, 200],
        [61, 230, 250],
        [255, 6, 51],
        [11, 102, 255],
        [255, 7, 71],
        [255, 9, 224],
        [9, 7, 230],
        [220, 220, 220],
        [255, 9, 92],
        [112, 9, 255],
        [8, 255, 214],
        [7, 255, 224],
        [255, 184, 6],
        [10, 255, 71],
        [255, 41, 10],
        [7, 255, 255],
        [224, 255, 8],
        [102, 8, 255],
        [255, 61, 6],
        [255, 194, 7],
        [255, 122, 8],
        [0, 255, 20],
        [255, 8, 41],
        [255, 5, 153],
        [6, 51, 255],
        [235, 12, 255],
        [160, 150, 20],
        [0, 163, 255],
        [140, 140, 140],
        [250, 10, 15],
        [20, 255, 0],
        [31, 255, 0],
        [255, 31, 0],
        [255, 224, 0],
        [153, 255, 0],
        [0, 0, 255],
        [255, 71, 0],
        [0, 235, 255],
        [0, 173, 255],
        [31, 0, 255],
        [11, 200, 200],
        [255, 82, 0],
        [0, 255, 245],
        [0, 61, 255],
        [0, 255, 112],
        [0, 255, 133],
        [255, 0, 0],
        [255, 163, 0],
        [255, 102, 0],
        [194, 255, 0],
        [0, 143, 255],
        [51, 255, 0],
        [0, 82, 255],
        [0, 255, 41],
        [0, 255, 173],
        [10, 0, 255],
        [173, 255, 0],
        [0, 255, 153],
        [255, 92, 0],
        [255, 0, 255],
        [255, 0, 245],
        [255, 0, 102],
        [255, 173, 0],
        [255, 0, 20],
        [255, 184, 184],
        [0, 31, 255],
        [0, 255, 61],
        [0, 71, 255],
        [255, 0, 204],
        [0, 255, 194],
        [0, 255, 82],
        [0, 10, 255],
        [0, 112, 255],
        [51, 0, 255],
        [0, 194, 255],
        [0, 122, 255],
        [0, 255, 163],
        [255, 153, 0],
        [0, 255, 10],
        [255, 112, 0],
        [143, 255, 0],
        [82, 0, 255],
        [163, 255, 0],
        [255, 235, 0],
        [8, 184, 170],
        [133, 0, 255],
        [0, 255, 92],
        [184, 0, 255],
        [255, 0, 31],
        [0, 184, 255],
        [0, 214, 255],
        [255, 0, 112],
        [92, 255, 0],
        [0, 224, 255],
        [112, 224, 255],
        [70, 184, 160],
        [163, 0, 255],
        [153, 0, 255],
        [71, 255, 0],
        [255, 0, 163],
        [255, 204, 0],
        [255, 0, 143],
        [0, 255, 235],
        [133, 255, 0],
        [255, 0, 235],
        [245, 0, 255],
        [255, 0, 122],
        [255, 245, 0],
        [10, 190, 212],
        [214, 255, 0],
        [0, 204, 255],
        [20, 0, 255],
        [255, 255, 0],
        [0, 153, 255],
        [0, 41, 255],
        [0, 255, 204],
        [41, 0, 255],
        [41, 255, 0],
        [173, 0, 255],
        [0, 245, 255],
        [71, 0, 255],
        [122, 0, 255],
        [0, 255, 184],
        [0, 92, 255],
        [184, 255, 0],
        [0, 133, 255],
        [255, 214, 0],
        [25, 194, 194],
        [102, 255, 0],
        [92, 0, 255],
    ]
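A sketch tying these utilities together (placeholder path; the clicked coordinate is hypothetical and is given as (row, col), matching how GradioApp.get_coordinates stores clicks):

    from diffusers.utils import load_image
    from src.designgenie.models import create_segmentation_model
    from src.designgenie.utils import (
        get_masked_images,
        get_masks_from_segmentation_map,
        visualize_segmentation_map,
    )

    image = load_image("assets/images/room.jpg").resize((768, 512))
    segmenter = create_segmentation_model(segmentation_model_name="mask2former")
    semantic_map = segmenter.process([image])[0]

    overlay = visualize_segmentation_map(semantic_map, image)  # color overlay for display
    masks, labels, names = get_masks_from_segmentation_map(semantic_map)
    print(names)  # ADE20K class names present in the image

    # Build the inpainting mask for the region under a clicked pixel; the raw image
    # stands in for a ControlNet control image here.
    mask_image, masked_control = get_masked_images(
        control_image=image, semantic_map=semantic_map, coordinates=[(100, 200)]
    )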