Spaces:
Sleeping
Sleeping
File size: 5,196 Bytes
c1b628d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import random
import vipainting
import time
import threading
from queue import Queue
import os
# Hand-off channels between the background inpainting thread and the UI:
# ``image_queue`` is drained while the "Fitting in progress" section is shown,
# ``sampling_queue`` while the "Sampling in progress" section is shown.
image_queue, sampling_queue = Queue(), Queue()

st.title("Mask Your Own Inpaint")
@st.cache_data
def load_images():
    """Load the image stack shipped with the app.

    Returns:
        The array stored under the ``"images"`` key of
        ``data/sflckr_all_images.npz``.
    """
    # For ``.npz`` archives ``np.load`` returns an ``NpzFile`` that keeps the
    # archive open; the context manager closes it as soon as the requested
    # array has been read into memory.
    with np.load("data/sflckr_all_images.npz") as data:
        return data["images"]
# The selected image index has to survive Streamlit reruns, so it lives in
# session state; ``None`` means that nothing has been picked yet.
if "random_idx" not in st.session_state:
    st.session_state["random_idx"] = None

images = load_images()

if st.button("Random Pick"):
    # Draw a uniformly random index into the first axis of ``images``.
    last_index = images.shape[0] - 1
    st.session_state["random_idx"] = random.randint(0, last_index)
def make_square(img, target_size=300):
    """Pad *img* to a square on a white background and resize it.

    The image is centred on a white RGB canvas whose side equals the longer
    edge of *img*; that canvas is then resized to
    ``target_size`` x ``target_size`` pixels.

    Args:
        img: Image to pad; must expose ``size`` as ``(width, height)``.
        target_size: Side length, in pixels, of the returned image.

    Returns:
        The padded and resized image.
    """
    width, height = img.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    # Centre the original image on the canvas.
    top_left = ((side - width) // 2, (side - height) // 2)
    canvas.paste(img, top_left)
    return canvas.resize((target_size, target_size))
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Forward all arguments, unchanged and in order, to ``vipainting.vipaint``.

    Used as the target of the background thread started by the "inpaint"
    button.
    """
    vipaint_args = (random_idx, mask_array, image_queue, sampling_queue)
    vipainting.vipaint(*vipaint_args)
# NOTE(review): the reviewed copy of this script had lost its indentation; the
# nesting below is reconstructed from the statements' data dependencies.
# Confirm the intended placement of the "inpaint" button relative to the columns.

_SEG_POLL_SECONDS = 0.5      # pause between checks for the segmentation map
_UPDATE_POLL_SECONDS = 0.3   # pause between polls of an update queue
_FITTING_UPDATES = 100       # images shown during the fitting phase
_SAMPLING_UPDATES = 10       # images shown during the sampling phase


def _run_worker(errors, *args):
    """Run ``run_inpainting(*args)`` and record any exception in *errors*."""
    try:
        run_inpainting(*args)
    except Exception as exc:
        # Thread boundary: remember the failure so the UI can report it, then
        # re-raise so the default thread traceback is still printed.
        errors.append(exc)
        raise


def _stream_updates(updates, worker, expected, title, caption_prefix):
    """Display images from *updates*, alternating between two columns.

    Stops once *expected* images have been shown, or once *worker* has exited
    and *updates* is empty, so a failed worker cannot make the page wait
    forever.

    Args:
        updates: Queue the worker thread puts images on.
        worker: The ``threading.Thread`` producing the images.
        expected: Number of images that completes the progress bar.
        title: Text written below the progress bar.
        caption_prefix: Caption prefix; the 1-based image count is appended.

    Returns:
        The number of images that were displayed.
    """
    progress_bar = st.progress(0)
    st.write(title)
    col_left, col_right = st.columns(2)
    slots = (col_left.empty(), col_right.empty())

    shown = 0
    while shown < expected:
        # Sample liveness BEFORE looking at the queue: if the worker had
        # already exited here, everything it produced is visible below.
        worker_done = not worker.is_alive()
        if not updates.empty():
            img = updates.get()
            # Even-numbered updates go left, odd-numbered ones go right.
            slots[shown % 2].image(
                img, caption=f"{caption_prefix} {shown + 1}", width=300
            )
            shown += 1
            progress_bar.progress(shown / expected)
        elif worker_done:
            break
        time.sleep(_UPDATE_POLL_SECONDS)
    return shown


if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = make_square(Image.fromarray(img_array), target_size=300)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    if canvas_result.image_data is not None:
        # Channel 3 is presumably the alpha channel of an RGBA canvas image
        # (TODO confirm); values above 128 count as painted.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # ``Image.ANTIALIAS`` was removed in Pillow 10; ``Image.LANCZOS`` is
        # the same filter under its current name.
        mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.LANCZOS)
        # Invert the mask (painted pixels become 0) and add two leading axes,
        # giving an array of shape (1, 1, 512, 512).
        mask_array = 255 - np.array(mask_image)
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        if st.button("inpaint"):
            st.write("Please wait...")
            worker_errors = []
            inpaint_thread = threading.Thread(
                target=_run_worker,
                args=(
                    worker_errors,
                    st.session_state.random_idx,
                    mask_array,
                    image_queue,
                    sampling_queue,
                ),
            )
            inpaint_thread.start()

            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            img_left_placeholder.image(img_pil, caption="True Image", width=300)

            # Wait for the segmentation map to appear on disk, but give up if
            # the worker exits without ever producing it.
            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            while True:
                worker_done = not inpaint_thread.is_alive()
                if os.path.exists(seg_image_path):
                    with Image.open(seg_image_path) as seg_image:
                        img_right_placeholder.image(
                            seg_image, caption="Segmentation Map", width=300
                        )
                    break
                if worker_done:
                    img_right_placeholder.warning(
                        "Segmentation map was not produced."
                    )
                    break
                time.sleep(_SEG_POLL_SECONDS)

            _stream_updates(
                image_queue,
                inpaint_thread,
                _FITTING_UPDATES,
                "Fitting in progress",
                "Progress Update",
            )
            _stream_updates(
                sampling_queue,
                inpaint_thread,
                _SAMPLING_UPDATES,
                "Sampling in progress",
                "Sampling Update",
            )

            inpaint_thread.join()
            if worker_errors:
                st.error(f"Inpainting failed: {worker_errors[0]!r}")
            else:
                st.success("Inpainting completed!")
|