Spaces:
Sleeping
Sleeping
Commit
·
302f20e
1
Parent(s):
920ece4
up
Browse files- .idea/.gitignore +0 -8
- .idea/RemoveFurnitureV1.iml +0 -8
- .idea/inspectionProfiles/Project_Default.xml +0 -15
- .idea/inspectionProfiles/profiles_settings.xml +0 -6
- .idea/misc.xml +0 -7
- .idea/modules.xml +0 -8
- .idea/vcs.xml +0 -6
- app.py +39 -65
- explanation.py +0 -51
- models.py +3 -7
- pipelines.py +2 -3
- preprocessing.py +3 -83
- segmentation.py +1 -2
- utils.py +27 -0
.idea/.gitignore
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
# Default ignored files
|
2 |
-
/shelf/
|
3 |
-
/workspace.xml
|
4 |
-
# Editor-based HTTP Client requests
|
5 |
-
/httpRequests/
|
6 |
-
# Datasource local storage ignored files
|
7 |
-
/dataSources/
|
8 |
-
/dataSources.local.xml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/RemoveFurnitureV1.iml
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
-
<module type="PYTHON_MODULE" version="4">
|
3 |
-
<component name="NewModuleRootManager">
|
4 |
-
<content url="file://$MODULE_DIR$" />
|
5 |
-
<orderEntry type="inheritedJdk" />
|
6 |
-
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
-
</component>
|
8 |
-
</module>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/inspectionProfiles/Project_Default.xml
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
<component name="InspectionProjectProfileManager">
|
2 |
-
<profile version="1.0">
|
3 |
-
<option name="myName" value="Project Default" />
|
4 |
-
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
-
<option name="ignoredPackages">
|
6 |
-
<value>
|
7 |
-
<list size="2">
|
8 |
-
<item index="0" class="java.lang.String" itemvalue="opencv_python" />
|
9 |
-
<item index="1" class="java.lang.String" itemvalue="skimage" />
|
10 |
-
</list>
|
11 |
-
</value>
|
12 |
-
</option>
|
13 |
-
</inspection_tool>
|
14 |
-
</profile>
|
15 |
-
</component>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/inspectionProfiles/profiles_settings.xml
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
<component name="InspectionProjectProfileManager">
|
2 |
-
<settings>
|
3 |
-
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
-
<version value="1.0" />
|
5 |
-
</settings>
|
6 |
-
</component>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/misc.xml
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
-
<project version="4">
|
3 |
-
<component name="Black">
|
4 |
-
<option name="sdkName" value="Python 3.12" />
|
5 |
-
</component>
|
6 |
-
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
|
7 |
-
</project>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/modules.xml
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
-
<project version="4">
|
3 |
-
<component name="ProjectModuleManager">
|
4 |
-
<modules>
|
5 |
-
<module fileurl="file://$PROJECT_DIR$/.idea/RemoveFurnitureV1.iml" filepath="$PROJECT_DIR$/.idea/RemoveFurnitureV1.iml" />
|
6 |
-
</modules>
|
7 |
-
</component>
|
8 |
-
</project>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.idea/vcs.xml
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
-
<project version="4">
|
3 |
-
<component name="VcsDirectoryMappings">
|
4 |
-
<mapping directory="" vcs="Git" />
|
5 |
-
</component>
|
6 |
-
</project>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,22 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
-
import io
|
3 |
-
from PIL import Image
|
4 |
import numpy as np
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
from models import make_image_controlnet, make_inpainting
|
8 |
-
from preprocessing import get_mask
|
9 |
-
|
10 |
-
def image_to_byte_array(image: Image) -> bytes:
|
11 |
-
# BytesIO is a fake file stored in memory
|
12 |
-
imgByteArr = io.BytesIO()
|
13 |
-
# image.save expects a file as a argument, passing a bytes io ins
|
14 |
-
image.save(imgByteArr, format='png') # image.format
|
15 |
-
# Turn the BytesIO object back into a bytes object
|
16 |
-
imgByteArr = imgByteArr.getvalue()
|
17 |
-
return imgByteArr
|
18 |
-
|
19 |
-
def predict(input_img1,
|
20 |
input_img2,
|
21 |
positive_prompt,
|
22 |
negative_prompt,
|
@@ -24,24 +11,15 @@ def predict(input_img1,
|
|
24 |
resolution
|
25 |
):
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
print("predict")
|
30 |
-
# bla bla
|
31 |
-
# input_img1 = Image.fromarray(input_img1)
|
32 |
-
# input_img2 = Image.fromarray(input_img2)
|
33 |
-
|
34 |
-
# setResoluton(resolution)
|
35 |
HEIGHT = resolution
|
36 |
WIDTH = resolution
|
37 |
-
# WIDTH = resolution
|
38 |
-
# HEIGHT = resolution
|
39 |
|
40 |
input_img1 = input_img1.resize((resolution, resolution))
|
41 |
input_img2 = input_img2.resize((resolution, resolution))
|
42 |
|
43 |
canvas_mask = np.array(input_img2)
|
44 |
-
mask = get_mask(canvas_mask)
|
45 |
|
46 |
print(input_img1, mask, positive_prompt, negative_prompt)
|
47 |
|
@@ -58,41 +36,37 @@ def predict(input_img1,
|
|
58 |
|
59 |
return retList
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
)
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
# ).launch(share=True)
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
import numpy as np
|
3 |
+
from models import make_inpainting
|
4 |
+
import utils
|
5 |
|
6 |
+
def removeFurniture(input_img1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
input_img2,
|
8 |
positive_prompt,
|
9 |
negative_prompt,
|
|
|
11 |
resolution
|
12 |
):
|
13 |
|
14 |
+
print("removeFurniture")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
HEIGHT = resolution
|
16 |
WIDTH = resolution
|
|
|
|
|
17 |
|
18 |
input_img1 = input_img1.resize((resolution, resolution))
|
19 |
input_img2 = input_img2.resize((resolution, resolution))
|
20 |
|
21 |
canvas_mask = np.array(input_img2)
|
22 |
+
mask = utils.get_mask(canvas_mask)
|
23 |
|
24 |
print(input_img1, mask, positive_prompt, negative_prompt)
|
25 |
|
|
|
36 |
|
37 |
return retList
|
38 |
|
39 |
+
def segmentation(image):
|
40 |
+
return image
|
41 |
+
|
42 |
+
def upscale(image):
|
43 |
+
return image
|
44 |
+
|
45 |
+
with gr.Blocks() as app:
|
46 |
+
with gr.Row():
|
47 |
+
|
48 |
+
gr.Button("FurnituRemove").click(removeFurniture,
|
49 |
+
inputs=[gr.Image(label="img", type="pil"),
|
50 |
+
gr.Image(label="mask", type="pil"),
|
51 |
+
gr.Textbox(label="positive_prompt",value="empty room"),
|
52 |
+
gr.Textbox(label="negative_prompt",value=""),
|
53 |
+
gr.Number(label="num_of_images",value=2),
|
54 |
+
gr.Number(label="resolution",value=512)
|
55 |
+
],
|
56 |
+
outputs=[
|
57 |
+
gr.Image(),
|
58 |
+
gr.Image(),
|
59 |
+
gr.Image(),
|
60 |
+
gr.Image(),
|
61 |
+
gr.Image(),
|
62 |
+
gr.Image(),
|
63 |
+
gr.Image(),
|
64 |
+
gr.Image(),
|
65 |
+
gr.Image(),
|
66 |
+
gr.Image()])
|
67 |
+
|
68 |
+
gr.Button("Segmentation").click(segmentation, inputs=gr.Image(type="pil"), outputs=gr.Image())
|
69 |
+
|
70 |
+
gr.Button("Upscale").click(upscale, inputs=gr.Image(type="pil"), outputs=gr.Image())
|
71 |
+
|
72 |
+
app.launch(debug=True)
|
|
|
|
|
|
|
|
explanation.py
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
|
3 |
-
def make_inpainting_explanation():
|
4 |
-
with st.expander("Explanation inpainting", expanded=False):
|
5 |
-
st.write("In the inpainting mode, you can draw regions on the input image that you want to regenerate. "
|
6 |
-
"This can be useful to remove unwanted objects from the image or to improve the consistency of the image."
|
7 |
-
)
|
8 |
-
st.image("content/inpainting_sidebar.png", caption="Image before inpainting, note the ornaments on the wall", width=500)
|
9 |
-
st.write("You can find drawing options in the sidebar. There are two modes: freedraw and polygon. Freedraw allows the user to draw with a pencil of a certain width. "
|
10 |
-
"Polygon allows the user to draw a polygon by clicking on the image to add a point. The polygon is closed by right clicking.")
|
11 |
-
|
12 |
-
st.write("### Example inpainting")
|
13 |
-
st.write("In the example below, the ornaments on the wall are removed. The inpainting is done by drawing a mask on the image.")
|
14 |
-
st.image("content/inpainting_before.jpg", caption="Image before inpainting, note the ornaments on the wall")
|
15 |
-
st.image("content/inpainting_after.png", caption="Image before inpainting, note the ornaments on the wall")
|
16 |
-
|
17 |
-
def make_regeneration_explanation():
|
18 |
-
with st.expander("Explanation object regeneration"):
|
19 |
-
st.write("In this object regeneration mode, the model calculates which objects occur in the image. "
|
20 |
-
"The user can then select which objects can be regenerated by the controlnet model by adding them in the multiselect box. "
|
21 |
-
"All the object classes that are not selected will remain the same as in the original image."
|
22 |
-
)
|
23 |
-
st.write("### Example object regeneration")
|
24 |
-
st.write("In the example below, the room consists of various objects such as wall, ceiling, floor, lamp, bed, ... "
|
25 |
-
"In the multiselect box, all the objects except for 'lamp', 'bed and 'table' are selected to be regenerated. "
|
26 |
-
)
|
27 |
-
st.image("content/regen_example.png", caption="Room where all concepts except for 'bed', 'lamp', 'table' are regenerated")
|
28 |
-
|
29 |
-
def make_segmentation_explanation():
|
30 |
-
with st.expander("Segmentation mode", expanded=False):
|
31 |
-
st.write("In the segmentation mode, the user can use his imagination and the paint brush to place concepts in the image. "
|
32 |
-
"In the left sidebar, you can first find the high level category of the concept you want to add, such as 'lighting', 'floor', .. "
|
33 |
-
"After selecting the category, you can select the specific concept you want to add in the 'Choose a color' dropdown. "
|
34 |
-
"This will change the color of the paint brush, which you can then use to draw on the input image. "
|
35 |
-
"The model will then regenerate the image with the concepts you have drawn and leave the rest of the image unchanged. "
|
36 |
-
)
|
37 |
-
st.image("content/sidebar segmentation.png", caption="Sidebar with segmentation options", width=300)
|
38 |
-
st.write("You can choose the freedraw mode which gives you a pencil of a certain (chosen) width or the polygon mode. With the polygon mode you can click to add a point to the polygon and close the polygon by right clicking. ")
|
39 |
-
st.write("Important: "
|
40 |
-
"it's not easy to draw a good segmentation mask. This is because you need to keep in mind the perspective of the room and the exact "
|
41 |
-
"shape of the object you want to draw within this perspective. Controlnet will follow your segmentation mask pretty well, so "
|
42 |
-
"a non-natural object shape will sometimes result in weird outputs. However, give it a try and see what you can do! "
|
43 |
-
)
|
44 |
-
st.image("content/segmentation window.png", caption="Example of a segmentation mask drawn on the input image to add a window to the room")
|
45 |
-
st.write("Tip: ")
|
46 |
-
st.write("In the concepts dropdown, you can select 'keep background' (which is a white color). Everything drawn in this color will use "
|
47 |
-
"the original underlying segmentation mask. This can be useful to help with generating other objects, since you give the model a some "
|
48 |
-
"freedom to generate outside the object borders."
|
49 |
-
)
|
50 |
-
st.image("content/keep background 1.png", caption="Image with a poster drawn on the wall.")
|
51 |
-
st.image("content/keep background 2.png", caption="Image with a poster drawn on the wall surrounded by 'keep background'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models.py
CHANGED
@@ -2,17 +2,13 @@
|
|
2 |
import logging
|
3 |
from typing import List, Tuple, Dict
|
4 |
|
5 |
-
|
6 |
import torch
|
7 |
-
import gc
|
8 |
-
import time
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
-
from PIL import ImageFilter
|
12 |
|
13 |
from diffusers import ControlNetModel, UniPCMultistepScheduler
|
14 |
|
15 |
-
# from config import WIDTH, HEIGHT
|
16 |
from palette import ade_palette
|
17 |
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
|
18 |
from helpers import flush, postprocess_image_masking, convolution
|
@@ -48,7 +44,7 @@ def make_image_controlnet(image: np.ndarray,
|
|
48 |
mask_image_postproc = convolution(mask_image)
|
49 |
|
50 |
|
51 |
-
|
52 |
generated_image = pipe(
|
53 |
prompt=positive_prompt,
|
54 |
negative_prompt=negative_prompt,
|
@@ -90,7 +86,7 @@ def make_inpainting(positive_prompt: str,
|
|
90 |
flush()
|
91 |
retList=[]
|
92 |
for x in range(num_of_images):
|
93 |
-
|
94 |
resp = pipe(image=image,
|
95 |
mask_image=mask_image,
|
96 |
prompt=positive_prompt,
|
|
|
2 |
import logging
|
3 |
from typing import List, Tuple, Dict
|
4 |
|
5 |
+
|
6 |
import torch
|
|
|
|
|
7 |
import numpy as np
|
8 |
from PIL import Image
|
|
|
9 |
|
10 |
from diffusers import ControlNetModel, UniPCMultistepScheduler
|
11 |
|
|
|
12 |
from palette import ade_palette
|
13 |
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
|
14 |
from helpers import flush, postprocess_image_masking, convolution
|
|
|
44 |
mask_image_postproc = convolution(mask_image)
|
45 |
|
46 |
|
47 |
+
|
48 |
generated_image = pipe(
|
49 |
prompt=positive_prompt,
|
50 |
negative_prompt=negative_prompt,
|
|
|
86 |
flush()
|
87 |
retList=[]
|
88 |
for x in range(num_of_images):
|
89 |
+
|
90 |
resp = pipe(image=image,
|
91 |
mask_image=mask_image,
|
92 |
prompt=positive_prompt,
|
pipelines.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import logging
|
2 |
from typing import List, Tuple, Dict
|
3 |
|
4 |
-
|
5 |
import torch
|
6 |
import gc
|
7 |
import time
|
@@ -105,7 +105,7 @@ class SDPipeline:
|
|
105 |
|
106 |
|
107 |
|
108 |
-
|
109 |
def get_controlnet():
|
110 |
"""Method to load the controlnet model
|
111 |
Returns:
|
@@ -116,7 +116,6 @@ def get_controlnet():
|
|
116 |
|
117 |
|
118 |
|
119 |
-
@st.cache_resource(max_entries=5)
|
120 |
def get_inpainting_pipeline():
|
121 |
"""Method to load the inpainting pipeline
|
122 |
Returns:
|
|
|
1 |
import logging
|
2 |
from typing import List, Tuple, Dict
|
3 |
|
4 |
+
|
5 |
import torch
|
6 |
import gc
|
7 |
import time
|
|
|
105 |
|
106 |
|
107 |
|
108 |
+
|
109 |
def get_controlnet():
|
110 |
"""Method to load the controlnet model
|
111 |
Returns:
|
|
|
116 |
|
117 |
|
118 |
|
|
|
119 |
def get_inpainting_pipeline():
|
120 |
"""Method to load the inpainting pipeline
|
121 |
Returns:
|
preprocessing.py
CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple
|
|
4 |
|
5 |
import numpy as np
|
6 |
from PIL import Image, ImageFilter
|
7 |
-
|
8 |
|
9 |
from config import COLOR_RGB
|
10 |
# from enhance_config import ENHANCE_SETTINGS
|
@@ -51,85 +51,5 @@ def preprocess_seg_mask(canvas_seg, real_seg: Image.Image = None) -> Tuple[np.nd
|
|
51 |
return mask, image_seg
|
52 |
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
Args:
|
57 |
-
image_mask (np.ndarray): segmentation mask
|
58 |
-
Returns:
|
59 |
-
np.ndarray: mask
|
60 |
-
"""
|
61 |
-
# average the colors of the segmentation masks
|
62 |
-
average_color = np.mean(image_mask, axis=(2))
|
63 |
-
mask = average_color[:, :] > 0
|
64 |
-
if mask.sum() > 0:
|
65 |
-
mask = mask * 1
|
66 |
-
return mask
|
67 |
-
|
68 |
-
|
69 |
-
# def get_image() -> np.ndarray:
|
70 |
-
#
|
71 |
-
# """Get the image from the session state.
|
72 |
-
# Returns:
|
73 |
-
# np.ndarray: image
|
74 |
-
# """
|
75 |
-
# if 'initial_image' in st.session_state and st.session_state['initial_image'] is not None:
|
76 |
-
# initial_image = st.session_state['initial_image']
|
77 |
-
# if isinstance(initial_image, Image.Image):
|
78 |
-
# return np.array(initial_image.resize((WIDTH, HEIGHT)))
|
79 |
-
# else:
|
80 |
-
# return np.array(Image.fromarray(initial_image).resize((WIDTH, HEIGHT)))
|
81 |
-
# else:
|
82 |
-
# return None
|
83 |
-
|
84 |
-
|
85 |
-
# def make_enhance_config(segmentation, objects=None):
|
86 |
-
"""Make the enhance config for the segmentation image.
|
87 |
-
"""
|
88 |
-
info = ENHANCE_SETTINGS[objects]
|
89 |
-
|
90 |
-
segmentation = np.array(segmentation)
|
91 |
-
|
92 |
-
if 'replace' in info:
|
93 |
-
replace_color = info['replace']
|
94 |
-
mask = np.zeros(segmentation.shape)
|
95 |
-
for color in info['colors']:
|
96 |
-
mask[np.all(segmentation == color, axis=-1)] = [1, 1, 1]
|
97 |
-
segmentation[np.all(segmentation == color, axis=-1)] = replace_color
|
98 |
-
|
99 |
-
if info['inverse'] is False:
|
100 |
-
mask = np.zeros(segmentation.shape)
|
101 |
-
for color in info['colors']:
|
102 |
-
mask[np.all(segmentation == color, axis=-1)] = [1, 1, 1]
|
103 |
-
else:
|
104 |
-
mask = np.ones(segmentation.shape)
|
105 |
-
for color in info['colors']:
|
106 |
-
mask[np.all(segmentation == color, axis=-1)] = [0, 0, 0]
|
107 |
-
|
108 |
-
st.session_state['positive_prompt'] = info['positive_prompt']
|
109 |
-
st.session_state['negative_prompt'] = info['negative_prompt']
|
110 |
-
|
111 |
-
if info['inpainting'] is True:
|
112 |
-
mask = mask.astype(np.uint8)
|
113 |
-
mask = Image.fromarray(mask)
|
114 |
-
mask = mask.filter(ImageFilter.GaussianBlur(radius=13))
|
115 |
-
mask = mask.filter(ImageFilter.MaxFilter(size=9))
|
116 |
-
mask = np.array(mask)
|
117 |
-
|
118 |
-
mask[mask < 0.1] = 0
|
119 |
-
mask[mask >= 0.1] = 1
|
120 |
-
mask = mask.astype(np.uint8)
|
121 |
-
|
122 |
-
conditioning = dict(
|
123 |
-
mask_image=mask,
|
124 |
-
positive_prompt=info['positive_prompt'],
|
125 |
-
negative_prompt=info['negative_prompt'],
|
126 |
-
)
|
127 |
-
else:
|
128 |
-
conditioning = dict(
|
129 |
-
mask_image=mask,
|
130 |
-
controlnet_conditioning_image=segmentation,
|
131 |
-
positive_prompt=info['positive_prompt'],
|
132 |
-
negative_prompt=info['negative_prompt'],
|
133 |
-
strength=info['strength']
|
134 |
-
)
|
135 |
-
return conditioning, info['inpainting']
|
|
|
4 |
|
5 |
import numpy as np
|
6 |
from PIL import Image, ImageFilter
|
7 |
+
|
8 |
|
9 |
from config import COLOR_RGB
|
10 |
# from enhance_config import ENHANCE_SETTINGS
|
|
|
51 |
return mask, image_seg
|
52 |
|
53 |
|
54 |
+
|
55 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
segmentation.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import logging
|
2 |
from typing import List, Tuple, Dict
|
3 |
|
4 |
-
import streamlit as st
|
5 |
import torch
|
6 |
import gc
|
7 |
import numpy as np
|
@@ -18,7 +17,7 @@ def flush():
|
|
18 |
gc.collect()
|
19 |
torch.cuda.empty_cache()
|
20 |
|
21 |
-
|
22 |
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
|
23 |
"""Method to load the segmentation pipeline
|
24 |
Returns:
|
|
|
1 |
import logging
|
2 |
from typing import List, Tuple, Dict
|
3 |
|
|
|
4 |
import torch
|
5 |
import gc
|
6 |
import numpy as np
|
|
|
17 |
gc.collect()
|
18 |
torch.cuda.empty_cache()
|
19 |
|
20 |
+
|
21 |
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
|
22 |
"""Method to load the segmentation pipeline
|
23 |
Returns:
|
utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
def image_to_byte_array(image: Image) -> bytes:
|
6 |
+
# BytesIO is a fake file stored in memory
|
7 |
+
imgByteArr = io.BytesIO()
|
8 |
+
# image.save expects a file as a argument, passing a bytes io ins
|
9 |
+
image.save(imgByteArr, format='png') # image.format
|
10 |
+
# Turn the BytesIO object back into a bytes object
|
11 |
+
imgByteArr = imgByteArr.getvalue()
|
12 |
+
return imgByteArr
|
13 |
+
|
14 |
+
|
15 |
+
def get_mask(image_mask: np.ndarray) -> np.ndarray:
|
16 |
+
"""Get the mask from the segmentation mask.
|
17 |
+
Args:
|
18 |
+
image_mask (np.ndarray): segmentation mask
|
19 |
+
Returns:
|
20 |
+
np.ndarray: mask
|
21 |
+
"""
|
22 |
+
# average the colors of the segmentation masks
|
23 |
+
average_color = np.mean(image_mask, axis=(2))
|
24 |
+
mask = average_color[:, :] > 0
|
25 |
+
if mask.sum() > 0:
|
26 |
+
mask = mask * 1
|
27 |
+
return mask
|