File size: 5,398 Bytes
c1b628d
 
 
 
 
 
 
 
 
 
 
d73a015
c1b628d
 
 
d73a015
c1b628d
 
 
 
d73a015
c1b628d
18be823
d73a015
c1b628d
 
d73a015
c1b628d
 
 
d73a015
c1b628d
 
 
 
 
d73a015
c1b628d
 
d73a015
c1b628d
 
 
 
d73a015
c1b628d
 
 
 
d73a015
c1b628d
 
 
 
 
 
 
 
 
 
d73a015
c1b628d
d73a015
c1b628d
 
d73a015
c1b628d
 
d73a015
c1b628d
18be823
29b76c6
 
c1b628d
d73a015
c1b628d
18be823
 
d73a015
c1b628d
 
d73a015
 
c1b628d
 
 
d73a015
 
c1b628d
 
d73a015
c1b628d
d73a015
 
 
 
 
c1b628d
d73a015
c1b628d
 
 
 
 
d73a015
c1b628d
 
 
 
 
 
d73a015
c1b628d
 
 
 
 
 
 
 
d73a015
c1b628d
 
 
 
 
 
 
d73a015
c1b628d
 
d73a015
c1b628d
 
 
 
 
 
 
 
d73a015
c1b628d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import random
import vipainting
import time
import threading
from queue import Queue
import os

# Hand-off channels between the background inpainting worker and this script:
# the first carries fitting-progress images, the second sampling-progress images.
image_queue, sampling_queue = Queue(), Queue()

# Page heading
st.title("Mask Your Own Inpaint")

@st.cache_data
def load_images(path="data/sflckr_all_images.npz"):
    """Load the image stack stored under the ``images`` key of an ``.npz`` archive.

    The result is cached by Streamlit, so the archive is read only once per
    distinct *path*.

    Args:
        path: Location of the ``.npz`` archive. It must contain an ``images``
            entry. Defaults to the archive the app was built around.

    Returns:
        The ``images`` array from the archive. Callers index it along the
        first axis, one image per entry. Presumably this is (N, H, W, C) uint8
        suitable for ``Image.fromarray``; TODO confirm.
    """
    # np.load on an .npz returns an NpzFile that holds the zip archive open.
    # Reading the entry materialises the array in memory, so the archive can
    # be closed (by the context manager) before returning.
    with np.load(path) as data:
        return data["images"]

# Persist the chosen dataset index in session state; None means "nothing picked yet".
if "random_idx" not in st.session_state:
    st.session_state["random_idx"] = None

# Fetch the (cached) image stack, then let the user draw a random index from it.
images = load_images()
if st.button("Random Pick"):
    image_count = images.shape[0]
    st.session_state.random_idx = random.randrange(image_count)

def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Thread target that delegates the whole job to ``vipainting.vipaint``.

    The arguments are forwarded positionally and unchanged: the selected image
    index, the prepared mask array, and the two queues the UI polls for
    fitting and sampling progress images.
    """
    job_args = (random_idx, mask_array, image_queue, sampling_queue)
    vipainting.vipaint(*job_args)

def _wait_for_file(path, worker, poll_interval=0.5):
    """Block until *path* exists or the *worker* thread has exited.

    Returns True when the file is present. Returns False when the worker
    finished without creating it, so the caller never spins forever.
    """
    while not os.path.exists(path):
        if not worker.is_alive():
            # Re-check once: the worker may have written the file just before exiting.
            return os.path.exists(path)
        time.sleep(poll_interval)
    return True


def _stream_updates(update_queue, worker, placeholders, progress_bar,
                    caption_prefix, expected_updates, poll_interval=0.3):
    """Display images from *update_queue*, alternating between two *placeholders*.

    Stops after *expected_updates* images, or earlier once *worker* has exited
    and the queue is drained. Returns the number of images actually shown.
    """
    from queue import Empty

    shown = 0
    while shown < expected_updates:
        try:
            img = update_queue.get(timeout=poll_interval)
        except Empty:
            # Everything a finished thread put() is already in the queue, so
            # "worker dead and queue empty" means no more updates will arrive.
            if not worker.is_alive() and update_queue.empty():
                break
            continue
        placeholders[shown % 2].image(
            img, caption=f"{caption_prefix} {shown + 1}", width=300
        )
        shown += 1
        progress_bar.progress(shown / expected_updates)
    return shown


# Only proceed if a random image has been selected
if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = Image.fromarray(img_array)

    # Set up drawing canvas and display columns
    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    # Process the mask if drawn
    if canvas_result.image_data is not None:
        # Channel 3 of the canvas data (presumably RGBA alpha) marks drawn
        # strokes; threshold it into a 0/255 mask.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255

        # Show the binary mask in the right column
        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # Resize to 512x512, invert so drawn pixels become 0, and add two
        # leading singleton axes -> shape (1, 1, 512, 512).
        # NOTE(review): the default resize filter interpolates, so pixels on the
        # stroke edges are no longer strictly 0/255. Confirm whether vipaint
        # expects a hard mask (Image.NEAREST would keep it binary).
        mask_image = Image.fromarray(binary_mask).resize((512, 512))
        mask_array = 255 - np.array(mask_image)
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        # Inpaint button action
        if st.button("inpaint"):
            st.write("Please wait...")

            # Run inpainting in a background thread. It reports progress
            # through the two queues polled below.
            inpaint_thread = threading.Thread(
                target=run_inpainting,
                args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue),
            )
            inpaint_thread.start()

            # Display initial image and segmentation map
            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            img_left_placeholder.image(img_pil, caption="True Image", width=300)

            # NOTE(review): a file left over from an earlier run for the same
            # index would be shown immediately. Confirm that vipaint always
            # rewrites it.
            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            if not _wait_for_file(seg_image_path, inpaint_thread):
                st.error("Inpainting stopped before producing the segmentation map.")
                st.stop()
            with Image.open(seg_image_path) as seg_file:
                # Load the pixels now so the file handle is released immediately.
                seg_image = seg_file.copy()
            img_right_placeholder.image(seg_image, caption="Segmentation Map", width=300)

            # Progress tracking for fitting (updates alternate between two columns)
            fit_expected = 100
            progress_bar = st.progress(0)
            st.write("Fitting in progress")
            col_left, col_right = st.columns(2)
            fit_shown = _stream_updates(
                image_queue,
                inpaint_thread,
                (col_left.empty(), col_right.empty()),
                progress_bar,
                "Progress Update",
                fit_expected,
            )

            # Progress tracking for sampling
            sample_expected = 10
            s_progress_bar = st.progress(0)
            st.write("Sampling in progress")
            sample_left, sample_right = st.columns(2)
            sample_shown = _stream_updates(
                sampling_queue,
                inpaint_thread,
                (sample_left.empty(), sample_right.empty()),
                s_progress_bar,
                "Sampling Update",
                sample_expected,
            )

            # Wait for inpainting to finish. Only report success if every
            # expected update arrived.
            inpaint_thread.join()
            if fit_shown == fit_expected and sample_shown == sample_expected:
                st.success("Inpainting completed!")
            else:
                st.error(
                    f"Inpainting ended early: received {fit_shown}/{fit_expected} "
                    f"fitting and {sample_shown}/{sample_expected} sampling updates."
                )