"""Streamlit demo: pick a random image, draw a mask on it, and stream the
progress images produced by ``vipainting.vipaint`` while it runs."""
import os
import random
import threading
import time
from queue import Empty, Queue

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

import vipainting

# Number of images the UI waits for from each stage before moving on.
FITTING_UPDATES = 100
SAMPLING_UPDATES = 10
# Seconds to block on a queue / between file-existence checks.
QUEUE_POLL_S = 0.3
FILE_POLL_S = 0.5
DISPLAY_WIDTH = 300
MASK_SIZE = (512, 512)

# Initialize queues to handle in-progress images for fitting and sampling
image_queue = Queue()
sampling_queue = Queue()

# Set up the app title
st.title("Mask Your Own Inpaint")


@st.cache_data
def load_images():
    """Load the ``images`` array from the NPZ archive (cached by Streamlit)."""
    # NpzFile holds the archive open; the context manager closes it.
    with np.load("data/sflckr_all_images.npz") as data:
        return data["images"]


def run_inpainting(random_idx, mask_array, image_queue, sampling_queue, errors=None):
    """Run ``vipainting.vipaint``; it is handed both queues to publish images into.

    Intended as a thread target. If *errors* is a list, any exception raised by
    ``vipaint`` is appended to it (so the UI thread can report it) instead of
    being propagated and lost in the worker thread.
    """
    try:
        vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
    except Exception as exc:
        if errors is None:
            raise
        errors.append(exc)


def _wait_for_file(path, worker):
    """Block until *path* exists; return False if *worker* exits without creating it."""
    while not os.path.exists(path):
        if not worker.is_alive():
            return os.path.exists(path)
        time.sleep(FILE_POLL_S)
    return True


def _stream_updates(update_queue, expected, left_slot, right_slot, progress_bar, label, worker):
    """Display up to *expected* images from *update_queue*, alternating two slots.

    Stops early once *worker* has exited and the queue is drained, so a crashed
    or short-running worker cannot hang the page. Returns the number shown.
    """
    shown = 0
    while shown < expected:
        try:
            img = update_queue.get(timeout=QUEUE_POLL_S)
        except Empty:
            if worker.is_alive():
                continue
            # Worker is gone, so no further puts; pick up anything it enqueued
            # just before exiting, otherwise we are done.
            try:
                img = update_queue.get_nowait()
            except Empty:
                break
        slot = left_slot if shown % 2 == 0 else right_slot
        slot.image(img, caption=f"{label} Update {shown + 1}", width=DISPLAY_WIDTH)
        shown += 1
        progress_bar.progress(shown / expected)
    return shown


# Initialize session state to hold the selected random image index
if "random_idx" not in st.session_state:
    st.session_state.random_idx = None

# Load images and add button to randomly select one
images = load_images()
if st.button("Random Pick"):
    st.session_state.random_idx = random.randint(0, images.shape[0] - 1)

# Only proceed if a random image has been selected
if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = Image.fromarray(img_array)

    # Set up drawing canvas and display columns
    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    # Process the mask if drawn
    if canvas_result.image_data is not None:
        # Alpha channel of the drawing layer; presumably non-zero only where
        # the user drew a stroke.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255

        # Show the binary mask in the right column
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=DISPLAY_WIDTH)

        # NEAREST keeps the mask strictly 0/255; the default resampling filter
        # would blur stroke edges into intermediate grey values.
        mask_image = Image.fromarray(binary_mask).resize(MASK_SIZE, Image.NEAREST)
        # Invert: drawn pixels -> 0, untouched pixels -> 255.
        mask_array = 255 - np.array(mask_image)
        # Shape becomes (1, 1, 512, 512); presumably batch/channel axes that
        # vipaint expects — confirm against vipainting.
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        # Inpaint button action
        if st.button("inpaint"):
            st.write("Please wait...")

            # Start inpainting in a separate thread; it only talks to the UI via
            # the queues (and the errors list), never via st.* calls.
            errors = []
            inpaint_thread = threading.Thread(
                target=run_inpainting,
                args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue),
                kwargs={"errors": errors},
            )
            inpaint_thread.start()

            # Display initial image and segmentation map
            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            img_left_placeholder.image(img_pil, caption="True Image", width=DISPLAY_WIDTH)

            # NOTE(review): a file left over from an earlier run with the same
            # index satisfies this check immediately — confirm that is intended.
            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            if _wait_for_file(seg_image_path, inpaint_thread):
                with Image.open(seg_image_path) as seg_image:
                    img_right_placeholder.image(
                        seg_image, caption="Segmentation Map", width=DISPLAY_WIDTH
                    )

            # Progress tracking for fitting
            progress_bar = st.progress(0)
            st.write("Fitting in progress")
            col_left, col_right = st.columns(2)
            _stream_updates(
                image_queue, FITTING_UPDATES,
                col_left.empty(), col_right.empty(),
                progress_bar, "Progress", inpaint_thread,
            )

            # Progress tracking for sampling
            s_progress_bar = st.progress(0)
            st.write("Sampling in progress")
            sample_left, sample_right = st.columns(2)
            _stream_updates(
                sampling_queue, SAMPLING_UPDATES,
                sample_left.empty(), sample_right.empty(),
                s_progress_bar, "Sampling", inpaint_thread,
            )

            # Wait for inpainting to finish
            inpaint_thread.join()
            if errors:
                st.error(f"Inpainting failed: {errors[0]}")
            else:
                st.success("Inpainting completed!")