Spaces:
Sleeping
Sleeping
File size: 5,398 Bytes
c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d 18be823 d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d 18be823 29b76c6 c1b628d d73a015 c1b628d 18be823 d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d d73a015 c1b628d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import random
import threading
import time
from queue import Empty, Queue

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

import vipainting
# Hand-off channels from the background inpainting worker to this script:
# one carries images produced while fitting, the other images from sampling.
image_queue, sampling_queue = Queue(), Queue()

# Page header
st.title("Mask Your Own Inpaint")
@st.cache_data
def load_images(path="data/sflckr_all_images.npz"):
    """Load the ``images`` array from an NPZ archive.

    Decorated with ``st.cache_data`` so the archive is read once and the
    result is reused across Streamlit reruns.

    Args:
        path: NPZ file to read. Defaults to ``data/sflckr_all_images.npz``,
            the file this app has always used.

    Returns:
        The array stored under the ``"images"`` key of the archive.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open.
    # Indexing reads the array fully into memory, so it is safe to return
    # it after the context manager has closed the file.
    with np.load(path) as data:
        return data["images"]
# Index of the currently chosen image; stays None until the user picks one.
if "random_idx" not in st.session_state:
    st.session_state["random_idx"] = None

# Dataset of candidate images (cached across reruns by load_images).
images = load_images()

# "Random Pick" draws a new index uniformly from [0, number of images).
if st.button("Random Pick"):
    st.session_state["random_idx"] = random.randrange(len(images))
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Thread target that runs ``vipainting.vipaint`` on the chosen image.

    All four arguments are forwarded positionally, unchanged.

    Args:
        random_idx: dataset index of the image being inpainted.
        mask_array: mask built from the user's drawing.
        image_queue: queue the UI drains for fitting-stage images.
        sampling_queue: queue the UI drains for sampling-stage images.

    Returns None. Results reach the page only through the two queues (and,
    presumably, files that ``vipaint`` writes under ``results/``).
    """
    job_args = (random_idx, mask_array, image_queue, sampling_queue)
    vipainting.vipaint(*job_args)
def _wait_for_file(path, worker, poll_s=0.5):
    """Poll until *path* exists; return False if *worker* exits without creating it.

    Checking the worker's liveness means a crashed background thread cannot
    leave the page waiting forever.
    """
    while not os.path.exists(path):
        if not worker.is_alive():
            # The worker may have written the file just before exiting.
            return os.path.exists(path)
        time.sleep(poll_s)
    return True


def _stream_updates(source, slots, progress_bar, total, label, worker, poll_s=0.3):
    """Display up to *total* images taken from *source*, alternating two slots.

    Image ``i`` (0-based) goes to ``slots[i % 2]`` with caption
    ``f"{label} {i + 1}"``, and the progress bar is set to ``(i + 1) / total``.

    Returns:
        The number of images shown. This is less than *total* only when
        *worker* has exited and *source* has been drained.
    """
    shown = 0
    while shown < total:
        try:
            img = source.get(timeout=poll_s)
        except Empty:
            # Keep waiting while the worker can still produce images; stop
            # once it has exited and nothing is left to drain.
            if not worker.is_alive() and source.empty():
                break
            continue
        slots[shown % 2].image(img, caption=f"{label} {shown + 1}", width=300)
        shown += 1
        progress_bar.progress(shown / total)
    return shown


def _run_inpaint_job(random_idx, mask_array, img_pil):
    """Start the inpainting worker thread and stream its output onto the page."""
    st.write("Please wait...")
    # The model runs in a background thread and hands intermediate images
    # back through the module-level queues, which this script thread drains.
    inpaint_thread = threading.Thread(
        target=run_inpainting,
        args=(random_idx, mask_array, image_queue, sampling_queue),
    )
    inpaint_thread.start()

    # Display initial image and segmentation map
    img_left, img_right = st.columns(2)
    img_left_placeholder = img_left.empty()
    img_right_placeholder = img_right.empty()
    img_left_placeholder.image(img_pil, caption="True Image", width=300)

    # NOTE(review): assumes the worker writes this file. A copy left by an
    # earlier run for the same index is picked up immediately, and a file
    # that is still being written may fail to open - confirm with vipainting.
    seg_image_path = f"results/{random_idx}/input.png"
    if not _wait_for_file(seg_image_path, inpaint_thread):
        inpaint_thread.join()
        st.error("Inpainting stopped before producing a segmentation map.")
        return
    with Image.open(seg_image_path) as seg_file:
        # copy() loads the pixel data so the file handle can be closed here.
        seg_image = seg_file.copy()
    img_right_placeholder.image(seg_image, caption="Segmentation Map", width=300)

    # Fitting stage: 100 progress images shown alternately in two columns.
    fit_total = 100
    progress_bar = st.progress(0)
    st.write("Fitting in progress")
    col_left, col_right = st.columns(2)
    fit_shown = _stream_updates(
        image_queue,
        (col_left.empty(), col_right.empty()),
        progress_bar,
        fit_total,
        "Progress Update",
        inpaint_thread,
    )

    # Sampling stage: 10 sample images, same layout.
    sample_total = 10
    s_progress_bar = st.progress(0)
    st.write("Sampling in progress")
    sample_left, sample_right = st.columns(2)
    sample_shown = _stream_updates(
        sampling_queue,
        (sample_left.empty(), sample_right.empty()),
        s_progress_bar,
        sample_total,
        "Sampling Update",
        inpaint_thread,
    )

    # Wait for inpainting to finish
    inpaint_thread.join()
    if fit_shown == fit_total and sample_shown == sample_total:
        st.success("Inpainting completed!")
    else:
        st.error("Inpainting ended before all progress updates arrived.")


# Only proceed if a random image has been selected
if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = Image.fromarray(img_array)

    # Drawing canvas on the left, mask preview on the right.
    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    # Process the mask if drawn
    if canvas_result.image_data is not None:
        # Channel 3 is the alpha of the canvas drawing; presumably it is
        # opaque only where the user drew - TODO confirm against st_canvas.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255
        # Show the binary mask in the right column
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # NEAREST keeps the upscaled mask strictly 0/255. Pillow's default
        # resize filter for this mode would blur stroke edges into grey values.
        mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.NEAREST)
        # Invert (drawn -> 0, untouched -> 255) and add two leading axes,
        # giving shape (1, 1, 512, 512).
        mask_array = np.expand_dims(255 - np.array(mask_image), axis=(0, 1))

        # Inpaint button action
        if st.button("inpaint"):
            _run_inpaint_job(st.session_state.random_idx, mask_array, img_pil)
|