Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from streamlit_drawable_canvas import st_canvas | |
| from PIL import Image | |
| import numpy as np | |
| import random | |
| import vipainting | |
| import time | |
| import threading | |
| from queue import Queue | |
| import os | |
# Two channels filled by the inpainting worker and drained by the UI:
# one for fitting-stage images, one for sampling-stage images.
image_queue, sampling_queue = Queue(), Queue()

# Page heading
st.title("Mask Your Own Inpaint")
def load_images():
    """Read the ``images`` array out of ``data/sflckr_all_images.npz``.

    Returns:
        The array stored under the ``images`` key of the archive.

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If the archive has no ``images`` entry.
    """
    # NOTE(review): nothing here caches the result, so the archive is re-read
    # on every call. If "load once" is the intent, wrap this function with
    # Streamlit's caching decorator -- confirm against the deployed version.
    #
    # An NpzFile keeps its underlying file open until it is closed; the
    # context manager releases the handle as soon as the array has been read
    # into memory.
    with np.load("data/sflckr_all_images.npz") as archive:
        return archive["images"]
# Make sure the session always has a slot for the chosen image index;
# None means "nothing picked yet".
if "random_idx" not in st.session_state:
    st.session_state["random_idx"] = None

# Fetch the image stack, then let the user draw a random index into it.
images = load_images()
if st.button("Random Pick"):
    st.session_state["random_idx"] = random.randrange(images.shape[0])
def make_square(img, target_size=300):
    """Centre *img* on a white square canvas and scale it to *target_size*.

    The canvas side equals the longer edge of *img*, so the picture is padded
    (never cropped) before being resized to ``(target_size, target_size)``.
    """
    width, height = img.size
    side = max(width, height)

    # White RGB square large enough to hold the whole picture.
    canvas = Image.new("RGB", (side, side), (255, 255, 255))

    # Split the leftover space evenly on both sides of each axis.
    left = (side - width) // 2
    top = (side - height) // 2
    canvas.paste(img, (left, top))

    return canvas.resize((target_size, target_size))
def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Forward all four arguments, unchanged, to ``vipainting.vipaint``.

    Intended as the target of a worker thread; presumably ``vipaint`` pushes
    intermediate images onto the two queues -- confirm against ``vipainting``.
    """
    vipainting.vipaint(
        random_idx,
        mask_array,
        image_queue,
        sampling_queue,
    )
# Number of images the UI expects from each stage of the worker.
FITTING_UPDATES = 100
SAMPLING_UPDATES = 10


def _wait_for_file(path, worker, poll_seconds=0.5):
    """Poll until *path* exists, giving up once *worker* has exited.

    Returns:
        True when the file is present, False when the worker thread finished
        without the file ever appearing (so the caller does not wait forever).
    """
    while not os.path.exists(path):
        if not worker.is_alive():
            # The worker may have written the file just before exiting.
            return os.path.exists(path)
        time.sleep(poll_seconds)
    return True


def _show_updates(updates, expected, status_text, caption_prefix, worker,
                  poll_seconds=0.3):
    """Display images from *updates* in two alternating columns.

    Creates a progress bar, a status line and two side-by-side placeholders,
    then shows each image taken from the queue, alternating left/right.

    Args:
        updates: Queue the worker thread puts images on.
        expected: Number of images that make the progress bar reach 100%.
        status_text: Text written under the progress bar.
        caption_prefix: Caption prefix; the 1-based update number is appended.
        worker: Thread producing the images; used to detect an early exit.
        poll_seconds: Pause between queue polls.

    Returns:
        The number of images actually displayed. It is smaller than
        *expected* only when the worker exited before sending them all.
    """
    progress = st.progress(0)
    st.write(status_text)
    left_col, right_col = st.columns(2)
    slots = (left_col.empty(), right_col.empty())

    shown = 0
    while shown < expected:
        if not updates.empty():
            frame = updates.get()
            slots[shown % 2].image(
                frame, caption=f"{caption_prefix} {shown + 1}", width=300
            )
            shown += 1
            progress.progress(shown / expected)
        elif not worker.is_alive() and updates.empty():
            # Worker is gone and nothing is pending (re-checked after the
            # liveness test to avoid dropping a last-moment item): stop
            # instead of spinning forever.
            break
        time.sleep(poll_seconds)
    return shown


# Only proceed if a random image has been selected
if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = Image.fromarray(img_array)
    img_pil = make_square(img_pil, target_size=512)  # Resized to 512x512

    # Set up drawing canvas and display columns
    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    # Process the mask if drawn
    if canvas_result.image_data is not None:
        # Channel index 3 of the canvas data (presumably alpha of an RGBA
        # image -- confirm); values above 128 count as "drawn".
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255

        # Show the binary mask in the right column
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # Invert (drawn pixels become 0) and add two leading axes.
        mask_image = Image.fromarray(binary_mask)
        mask_array = 255 - np.array(mask_image)
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        # Inpaint button action
        if st.button("inpaint"):
            st.write("Please wait...")

            # Start inpainting in a separate thread
            inpaint_thread = threading.Thread(
                target=run_inpainting,
                args=(
                    st.session_state.random_idx,
                    mask_array,
                    image_queue,
                    sampling_queue,
                ),
            )
            inpaint_thread.start()

            # Display initial image and segmentation map
            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            img_left_placeholder.image(img_pil, caption="True Image", width=300)

            # NOTE(review): a file left over from an earlier run with the same
            # index would satisfy this wait immediately -- confirm whether the
            # worker clears the results directory first.
            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            if _wait_for_file(seg_image_path, inpaint_thread):
                img_right_image = Image.open(seg_image_path)
                img_right_placeholder.image(
                    img_right_image, caption="Segmentation Map", width=300
                )
            else:
                st.error("Inpainting stopped before producing the segmentation map.")

            # Progress tracking for fitting, then for sampling
            fitted = _show_updates(
                image_queue,
                FITTING_UPDATES,
                "Fitting in progress",
                "Progress Update",
                inpaint_thread,
            )
            sampled = _show_updates(
                sampling_queue,
                SAMPLING_UPDATES,
                "Sampling in progress",
                "Sampling Update",
                inpaint_thread,
            )

            # Wait for inpainting to finish
            inpaint_thread.join()
            if fitted == FITTING_UPDATES and sampled == SAMPLING_UPDATES:
                st.success("Inpainting completed!")
            else:
                st.warning(
                    "Inpainting stopped early: received "
                    f"{fitted}/{FITTING_UPDATES} fitting and "
                    f"{sampled}/{SAMPLING_UPDATES} sampling updates."
                )