Spaces:
Sleeping
Sleeping
app
Browse files
app.py
CHANGED
@@ -9,48 +9,53 @@ import threading
|
|
9 |
from queue import Queue
|
10 |
import os
|
11 |
|
|
|
12 |
image_queue = Queue()
|
13 |
sampling_queue = Queue()
|
14 |
|
15 |
-
|
16 |
st.title("Mask Your Own Inpaint")
|
17 |
|
18 |
@st.cache_data
|
19 |
def load_images():
|
|
|
20 |
data = np.load("data/sflckr_all_images.npz")
|
21 |
-
images = data["images"]
|
22 |
return images
|
23 |
|
|
|
24 |
if "random_idx" not in st.session_state:
|
25 |
st.session_state.random_idx = None
|
26 |
|
|
|
27 |
images = load_images()
|
28 |
if st.button("Random Pick"):
|
29 |
st.session_state.random_idx = random.randint(0, images.shape[0] - 1)
|
30 |
|
31 |
def make_square(img, target_size=300):
|
32 |
-
|
|
|
33 |
background = Image.new("RGB", (size, size), (255, 255, 255))
|
34 |
offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
|
35 |
background.paste(img, offset)
|
36 |
return background.resize((target_size, target_size))
|
37 |
|
38 |
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
|
|
|
39 |
vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
|
40 |
|
41 |
-
|
42 |
if st.session_state.random_idx is not None:
|
43 |
img_array = images[st.session_state.random_idx]
|
44 |
-
|
45 |
img_pil = Image.fromarray(img_array)
|
46 |
-
img_pil = make_square(img_pil, target_size=
|
47 |
-
|
48 |
|
|
|
49 |
col1, col2 = st.columns(2)
|
50 |
with col1:
|
51 |
st.write("Draw your mask on the image below:")
|
52 |
canvas_result = st_canvas(
|
53 |
-
fill_color="rgba(255, 0, 0, 0.3)",
|
54 |
stroke_width=50,
|
55 |
stroke_color="black",
|
56 |
background_image=img_pil,
|
@@ -61,66 +66,63 @@ if st.session_state.random_idx is not None:
|
|
61 |
key="canvas"
|
62 |
)
|
63 |
|
64 |
-
|
65 |
if canvas_result.image_data is not None:
|
66 |
-
mask = canvas_result.image_data[:, :, 3]
|
67 |
binary_mask = (mask > 128).astype(np.uint8) * 255
|
68 |
|
|
|
69 |
with col2:
|
70 |
st.write("Masked Image")
|
71 |
-
st.image(binary_mask, caption="Binary Mask", width=300)
|
72 |
|
73 |
-
mask_image = Image.fromarray(binary_mask)
|
74 |
mask_array = 255 - np.array(mask_image)
|
75 |
-
mask_array = np.expand_dims(mask_array, axis=(0, 1))
|
76 |
|
|
|
77 |
if st.button("inpaint"):
|
78 |
st.write("Please wait...")
|
|
|
|
|
79 |
inpaint_thread = threading.Thread(target=run_inpainting, args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue))
|
80 |
inpaint_thread.start()
|
81 |
|
82 |
-
|
|
|
83 |
img_left_placeholder = img_left.empty()
|
84 |
img_right_placeholder = img_right.empty()
|
85 |
-
|
86 |
-
img_left_placeholder.image(img_pil, caption=f"True Image", width=300)
|
87 |
-
seg_image_path = f"results/{st.session_state.random_idx}/input.png"
|
88 |
-
|
89 |
-
while True:
|
90 |
-
if os.path.exists(seg_image_path):
|
91 |
-
with img_right:
|
92 |
-
img_right_image = Image.open(seg_image_path)
|
93 |
-
img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
|
94 |
-
break
|
95 |
-
time.sleep(0.5)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
#
|
99 |
expected_updates = 100
|
100 |
progress_bar = st.progress(0)
|
101 |
st.write("Fitting in progress")
|
102 |
displayed_images = 0
|
103 |
|
|
|
104 |
col_left, col_right = st.columns(2)
|
105 |
left_placeholder = col_left.empty()
|
106 |
right_placeholder = col_right.empty()
|
107 |
|
108 |
-
|
109 |
while displayed_images < expected_updates:
|
110 |
if not image_queue.empty():
|
111 |
-
img = image_queue.get()
|
112 |
-
|
113 |
if displayed_images % 2 == 0:
|
114 |
left_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
|
115 |
else:
|
116 |
right_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
|
117 |
-
|
118 |
-
# Update progress bar
|
119 |
displayed_images += 1
|
120 |
progress_bar.progress(displayed_images / expected_updates)
|
121 |
-
|
122 |
time.sleep(0.3)
|
123 |
|
|
|
124 |
expected_updates = 10
|
125 |
s_progress_bar = st.progress(0)
|
126 |
displayed_images = 0
|
@@ -128,19 +130,18 @@ if st.session_state.random_idx is not None:
|
|
128 |
sample_left, sample_right = st.columns(2)
|
129 |
sleft_placeholder = sample_left.empty()
|
130 |
sright_placeholder = sample_right.empty()
|
|
|
131 |
while displayed_images < expected_updates:
|
132 |
if not sampling_queue.empty():
|
133 |
-
img = sampling_queue.get()
|
134 |
-
|
135 |
if displayed_images % 2 == 0:
|
136 |
sleft_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
|
137 |
else:
|
138 |
sright_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
|
139 |
-
|
140 |
displayed_images += 1
|
141 |
s_progress_bar.progress(displayed_images / expected_updates)
|
142 |
-
|
143 |
time.sleep(0.3)
|
144 |
|
|
|
145 |
inpaint_thread.join()
|
146 |
st.success("Inpainting completed!")
|
|
|
9 |
from queue import Queue
|
10 |
import os
|
11 |
|
12 |
+
# Initialize queues to handle in-progress images for fitting and sampling
|
13 |
image_queue = Queue()
|
14 |
sampling_queue = Queue()
|
15 |
|
16 |
+
# Set up the app title
|
17 |
st.title("Mask Your Own Inpaint")
|
18 |
|
19 |
@st.cache_data
|
20 |
def load_images():
|
21 |
+
"""Loads images from an NPZ file only once to cache."""
|
22 |
data = np.load("data/sflckr_all_images.npz")
|
23 |
+
images = data["images"]
|
24 |
return images
|
25 |
|
26 |
+
# Initialize session state to hold the selected random image index
|
27 |
if "random_idx" not in st.session_state:
|
28 |
st.session_state.random_idx = None
|
29 |
|
30 |
+
# Load images and add button to randomly select one
|
31 |
images = load_images()
|
32 |
if st.button("Random Pick"):
|
33 |
st.session_state.random_idx = random.randint(0, images.shape[0] - 1)
|
34 |
|
35 |
def make_square(img, target_size=300):
|
36 |
+
"""Pads and resizes the image to a square."""
|
37 |
+
size = max(img.size)
|
38 |
background = Image.new("RGB", (size, size), (255, 255, 255))
|
39 |
offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
|
40 |
background.paste(img, offset)
|
41 |
return background.resize((target_size, target_size))
|
42 |
|
43 |
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
|
44 |
+
"""Starts inpainting and sends images to the queue."""
|
45 |
vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
|
46 |
|
47 |
+
# Only proceed if a random image has been selected
|
48 |
if st.session_state.random_idx is not None:
|
49 |
img_array = images[st.session_state.random_idx]
|
|
|
50 |
img_pil = Image.fromarray(img_array)
|
51 |
+
img_pil = make_square(img_pil, target_size=512) # Resized to 512x512
|
|
|
52 |
|
53 |
+
# Set up drawing canvas and display columns
|
54 |
col1, col2 = st.columns(2)
|
55 |
with col1:
|
56 |
st.write("Draw your mask on the image below:")
|
57 |
canvas_result = st_canvas(
|
58 |
+
fill_color="rgba(255, 0, 0, 0.3)",
|
59 |
stroke_width=50,
|
60 |
stroke_color="black",
|
61 |
background_image=img_pil,
|
|
|
66 |
key="canvas"
|
67 |
)
|
68 |
|
69 |
+
# Process the mask if drawn
|
70 |
if canvas_result.image_data is not None:
|
71 |
+
mask = canvas_result.image_data[:, :, 3]
|
72 |
binary_mask = (mask > 128).astype(np.uint8) * 255
|
73 |
|
74 |
+
# Show the binary mask in the right column
|
75 |
with col2:
|
76 |
st.write("Masked Image")
|
77 |
+
st.image(binary_mask, caption="Binary Mask", width=300)
|
78 |
|
79 |
+
mask_image = Image.fromarray(binary_mask)
|
80 |
mask_array = 255 - np.array(mask_image)
|
81 |
+
mask_array = np.expand_dims(mask_array, axis=(0, 1))
|
82 |
|
83 |
+
# Inpaint button action
|
84 |
if st.button("inpaint"):
|
85 |
st.write("Please wait...")
|
86 |
+
|
87 |
+
# Start inpainting in a separate thread
|
88 |
inpaint_thread = threading.Thread(target=run_inpainting, args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue))
|
89 |
inpaint_thread.start()
|
90 |
|
91 |
+
# Display initial image and segmentation map
|
92 |
+
img_left, img_right = st.columns(2)
|
93 |
img_left_placeholder = img_left.empty()
|
94 |
img_right_placeholder = img_right.empty()
|
95 |
+
img_left_placeholder.image(img_pil, caption="True Image", width=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
+
seg_image_path = f"results/{st.session_state.random_idx}/input.png"
|
98 |
+
while not os.path.exists(seg_image_path):
|
99 |
+
time.sleep(0.5)
|
100 |
+
img_right_image = Image.open(seg_image_path)
|
101 |
+
img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
|
102 |
|
103 |
+
# Progress tracking for fitting
|
104 |
expected_updates = 100
|
105 |
progress_bar = st.progress(0)
|
106 |
st.write("Fitting in progress")
|
107 |
displayed_images = 0
|
108 |
|
109 |
+
# Alternating display for fitting updates
|
110 |
col_left, col_right = st.columns(2)
|
111 |
left_placeholder = col_left.empty()
|
112 |
right_placeholder = col_right.empty()
|
113 |
|
|
|
114 |
while displayed_images < expected_updates:
|
115 |
if not image_queue.empty():
|
116 |
+
img = image_queue.get()
|
|
|
117 |
if displayed_images % 2 == 0:
|
118 |
left_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
|
119 |
else:
|
120 |
right_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
|
|
|
|
|
121 |
displayed_images += 1
|
122 |
progress_bar.progress(displayed_images / expected_updates)
|
|
|
123 |
time.sleep(0.3)
|
124 |
|
125 |
+
# Progress tracking for sampling
|
126 |
expected_updates = 10
|
127 |
s_progress_bar = st.progress(0)
|
128 |
displayed_images = 0
|
|
|
130 |
sample_left, sample_right = st.columns(2)
|
131 |
sleft_placeholder = sample_left.empty()
|
132 |
sright_placeholder = sample_right.empty()
|
133 |
+
|
134 |
while displayed_images < expected_updates:
|
135 |
if not sampling_queue.empty():
|
136 |
+
img = sampling_queue.get()
|
|
|
137 |
if displayed_images % 2 == 0:
|
138 |
sleft_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
|
139 |
else:
|
140 |
sright_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
|
|
|
141 |
displayed_images += 1
|
142 |
s_progress_bar.progress(displayed_images / expected_updates)
|
|
|
143 |
time.sleep(0.3)
|
144 |
|
145 |
+
# Wait for inpainting to finish
|
146 |
inpaint_thread.join()
|
147 |
st.success("Inpainting completed!")
|