# JiminHeo's picture
# first commit
# c1b628d
# raw
# history blame
# 5.2 kB
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import random
import vipainting
import time
import threading
from queue import Queue
import os
# Queues handed to the inpainting worker thread; the UI loops further down
# read intermediate images from them. `image_queue` feeds the
# "Fitting in progress" display and `sampling_queue` feeds the
# "Sampling in progress" display.
image_queue = Queue()
sampling_queue = Queue()
# Page heading.
st.title("Mask Your Own Inpaint")
@st.cache_data
def load_images():
    """Load the image stack stored in ``data/sflckr_all_images.npz``.

    The archive is opened in a ``with`` block so its underlying file handle
    is closed as soon as the array has been read, instead of staying open
    for the lifetime of the (cached) return value.

    Returns:
        The array stored under the ``"images"`` key of the archive.
    """
    with np.load("data/sflckr_all_images.npz") as archive:
        # Indexing an NpzFile reads the member into memory, so the returned
        # array stays valid after the archive is closed.
        return archive["images"]
# Create the session slot for the selected image index on first use.
if "random_idx" not in st.session_state:
    st.session_state.random_idx = None

images = load_images()

# Clicking the button selects a uniformly random index into the image stack.
pick_requested = st.button("Random Pick")
if pick_requested:
    image_count = images.shape[0]
    st.session_state.random_idx = random.randrange(0, image_count)
def make_square(img, target_size=300):
    """Centre *img* on a white square canvas and resize it.

    The canvas side equals the longer edge of ``img``; the image is pasted
    in the middle and the result is resized to
    ``(target_size, target_size)``.

    Args:
        img: Image object exposing ``size`` as ``(width, height)``.
        target_size: Edge length, in pixels, of the returned square image.

    Returns:
        The padded and resized RGB image.
    """
    width, height = img.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    left = (side - width) // 2
    top = (side - height) // 2
    canvas.paste(img, (left, top))
    return canvas.resize((target_size, target_size))
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Thread entry point: forward all arguments to ``vipainting.vipaint``.

    Args:
        random_idx: Index of the selected image.
        mask_array: Mask array built from the user's drawing.
        image_queue: Queue passed through to ``vipaint`` (read by the
            fitting-progress display).
        sampling_queue: Queue passed through to ``vipaint`` (read by the
            sampling-progress display).
    """
    call_args = (random_idx, mask_array, image_queue, sampling_queue)
    vipainting.vipaint(*call_args)
def _show_queue_updates(updates, worker, expected_updates, status_text, caption_prefix):
    """Display images arriving on *updates* in two alternating columns.

    Draws a progress bar, the *status_text* line and two side-by-side
    placeholders, then shows each image taken from the queue, alternating
    left/right. The loop ends once *expected_updates* images were shown, or
    earlier if *worker* has exited and the queue is drained (so a crashed or
    short-running worker cannot hang the page).

    Args:
        updates: Queue the worker thread puts images on.
        worker: The ``threading.Thread`` producing the images.
        expected_updates: Number of images that corresponds to 100 % progress.
        status_text: Text written under the progress bar.
        caption_prefix: Prefix of each image caption; the 1-based update
            number is appended.

    Returns:
        The number of images that were displayed.
    """
    progress_bar = st.progress(0)
    st.write(status_text)
    col_left, col_right = st.columns(2)
    placeholders = (col_left.empty(), col_right.empty())

    displayed_images = 0
    while displayed_images < expected_updates:
        # Sample liveness BEFORE looking at the queue: if the worker was
        # already dead here, everything it produced is already enqueued, so
        # an empty queue afterwards really means "no more updates".
        worker_finished = not worker.is_alive()
        if not updates.empty():
            img = updates.get()
            placeholders[displayed_images % 2].image(
                img,
                caption=f"{caption_prefix} {displayed_images + 1}",
                width=300,
            )
            displayed_images += 1
            progress_bar.progress(displayed_images / expected_updates)
        elif worker_finished:
            break
        time.sleep(0.3)
    return displayed_images


if st.session_state.random_idx is not None:
    # Show the selected image, padded to a 300x300 square, as canvas background.
    img_array = images[st.session_state.random_idx]
    img_pil = Image.fromarray(img_array)
    img_pil = make_square(img_pil, target_size=300)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas"
        )

    if canvas_result.image_data is not None:
        # The alpha channel of the canvas marks where the user has drawn.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter and is available in both old and new Pillow releases.
        mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.LANCZOS)
        # Invert the mask (drawn pixels become 0) and add two leading axes.
        mask_array = 255 - np.array(mask_image)
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        if st.button("inpaint"):
            st.write("Please wait...")
            inpaint_thread = threading.Thread(
                target=run_inpainting,
                args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue),
            )
            inpaint_thread.start()

            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            with img_left:
                img_left_placeholder.image(img_pil, caption="True Image", width=300)

            # Wait for the worker to write its input image, but give up if
            # the worker exits without ever producing the file.
            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            while True:
                worker_finished = not inpaint_thread.is_alive()
                if os.path.exists(seg_image_path):
                    with img_right:
                        img_right_image = Image.open(seg_image_path)
                        img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
                    break
                if worker_finished:
                    st.warning(f"Inpainting worker exited before writing {seg_image_path}.")
                    break
                time.sleep(0.5)

            # Fitting phase: up to 100 intermediate images.
            fit_expected = 100
            fit_shown = _show_queue_updates(
                image_queue, inpaint_thread, fit_expected,
                "Fitting in progress", "Progress Update",
            )

            # Sampling phase: up to 10 intermediate images.
            sample_expected = 10
            sample_shown = _show_queue_updates(
                sampling_queue, inpaint_thread, sample_expected,
                "Sampling in progress", "Sampling Update",
            )

            inpaint_thread.join()
            if fit_shown == fit_expected and sample_shown == sample_expected:
                st.success("Inpainting completed!")
            else:
                # The worker stopped before delivering every expected update;
                # do not claim success in that case.
                st.warning(
                    "Inpainting worker stopped early: received "
                    f"{fit_shown}/{fit_expected} fitting and "
                    f"{sample_shown}/{sample_expected} sampling updates."
                )