File size: 5,196 Bytes
c1b628d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import random
import vipainting
import time
import threading
from queue import Queue
import os

# Channels the background inpainting thread uses to hand intermediate images
# back to the UI: `image_queue` feeds the "Fitting in progress" display and
# `sampling_queue` feeds the "Sampling in progress" display.
image_queue, sampling_queue = Queue(), Queue()

st.title("Mask Your Own Inpaint")

@st.cache_data
def load_images(path="data/sflckr_all_images.npz"):
    """Load and return the array stored under the ``images`` key of an ``.npz`` file.

    Cached with ``st.cache_data`` so Streamlit reruns do not re-read the archive;
    the cache is keyed on ``path``.

    Args:
        path: Location of the ``.npz`` archive. Defaults to the bundle this app
            was originally built around.

    Returns:
        The ``images`` array. Callers index its first axis to pick one image and
        pass that entry to ``PIL.Image.fromarray`` -- presumably (N, H, W, 3)
        uint8; TODO confirm against the data file.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open and
    # reads members lazily; the context manager closes the file once the
    # array has been materialised.
    with np.load(path) as archive:
        return archive["images"]

# Remember which dataset image is selected across Streamlit reruns;
# None means the user has not picked one yet.
if "random_idx" not in st.session_state:
    st.session_state["random_idx"] = None

images = load_images()

if st.button("Random Pick"):
    # randrange(n) draws uniformly from 0..n-1, exactly like randint(0, n - 1).
    st.session_state["random_idx"] = random.randrange(images.shape[0])

def make_square(img, target_size=300):
    """Pad *img* onto a centred white square canvas, then resize it.

    The canvas side equals the longer edge of the input, so nothing is cropped;
    the padded square is finally scaled to ``target_size`` x ``target_size``.
    """
    width, height = img.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    # Integer division centres the image; any odd leftover pixel goes right/bottom.
    canvas.paste(img, ((side - width) // 2, (side - height) // 2))
    return canvas.resize((target_size, target_size))

def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
    """Run one inpainting job; used as the target of a background thread.

    Delegates entirely to ``vipainting.vipaint``. The UI polls the two queues
    for intermediate images, so presumably ``vipaint`` puts them there --
    verify against the vipainting module. Returns None.
    """
    job_args = (random_idx, mask_array, image_queue, sampling_queue)
    vipainting.vipaint(*job_args)


def _wait_for_image(path, worker, poll_interval=0.5):
    """Poll until *path* holds a fully readable image and return it.

    Returns None if *worker* has exited and the file still cannot be read, so
    the caller never spins forever on a job that crashed.
    """
    while True:
        # Sample liveness *before* touching the file: if the worker had already
        # exited at this point, whatever is on disk now is final.
        worker_done = not worker.is_alive()
        if os.path.exists(path):
            try:
                image = Image.open(path)
                image.load()  # force a full decode so a half-written file is caught here
                return image
            except OSError:
                pass  # most likely still being written by the worker; retry
        if worker_done:
            return None
        time.sleep(poll_interval)


def _stream_updates(update_queue, worker, expected_updates, progress_bar,
                    left_slot, right_slot, caption_prefix, poll_interval=0.3):
    """Show up to *expected_updates* images from *update_queue*, alternating columns.

    Even-numbered updates go to *left_slot*, odd ones to *right_slot*, and
    *progress_bar* advances by one step per image. Stops early once *worker*
    has exited and the queue is drained, instead of waiting forever for
    updates that can no longer arrive. Returns the number of images shown.
    """
    shown = 0
    while shown < expected_updates:
        if not update_queue.empty():
            img = update_queue.get()
            slot = left_slot if shown % 2 == 0 else right_slot
            slot.image(img, caption=f"{caption_prefix} {shown + 1}", width=300)
            shown += 1
            progress_bar.progress(shown / expected_updates)
        elif not worker.is_alive() and update_queue.empty():
            # The worker is gone, so nothing more can be enqueued.
            break
        time.sleep(poll_interval)
    return shown


def _run_recording_errors(errors, *args):
    """Thread target: call run_inpainting(*args), appending any exception to *errors*."""
    try:
        run_inpainting(*args)
    except Exception as exc:
        errors.append(exc)
        raise  # keep the default threading.excepthook traceback on stderr


if st.session_state.random_idx is not None:
    img_array = images[st.session_state.random_idx]
    img_pil = make_square(Image.fromarray(img_array), target_size=300)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Draw your mask on the image below:")
        canvas_result = st_canvas(
            fill_color="rgba(255, 0, 0, 0.3)",
            stroke_width=50,
            stroke_color="black",
            background_image=img_pil,
            update_streamlit=True,
            width=300,
            height=300,
            drawing_mode="freedraw",
            key="canvas",
        )

    if canvas_result.image_data is not None:
        # Channel 3 is taken as the painted area -- assumes image_data is RGBA
        # with strokes opaque on a transparent background; TODO confirm.
        mask = canvas_result.image_data[:, :, 3]
        binary_mask = (mask > 128).astype(np.uint8) * 255

        with col2:
            st.write("Masked Image")
            st.image(binary_mask, caption="Binary Mask", width=300)

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # and exists in every Pillow version.
        mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.LANCZOS)
        # Invert: pixels with alpha > 128 become 0, the rest 255 (resampling may
        # leave intermediate values along edges). Then add two leading singleton
        # axes, giving shape (1, 1, 512, 512).
        mask_array = 255 - np.array(mask_image)
        mask_array = np.expand_dims(mask_array, axis=(0, 1))

        if st.button("inpaint"):
            st.write("Please wait...")
            worker_errors = []
            inpaint_thread = threading.Thread(
                target=_run_recording_errors,
                args=(worker_errors, st.session_state.random_idx, mask_array,
                      image_queue, sampling_queue),
            )
            inpaint_thread.start()

            img_left, img_right = st.columns(2)
            img_left_placeholder = img_left.empty()
            img_right_placeholder = img_right.empty()
            # Placeholders are already bound to their columns; no `with` needed.
            img_left_placeholder.image(img_pil, caption="True Image", width=300)

            seg_image_path = f"results/{st.session_state.random_idx}/input.png"
            seg_image = _wait_for_image(seg_image_path, inpaint_thread)
            if seg_image is None:
                inpaint_thread.join()
                detail = f": {worker_errors[0]}" if worker_errors else ""
                st.error(f"Inpainting stopped before producing {seg_image_path}{detail}")
                st.stop()
            img_right_placeholder.image(seg_image, caption="Segmentation Map", width=300)

            # Fitting phase: up to 100 intermediate images from image_queue.
            progress_bar = st.progress(0)
            st.write("Fitting in progress")
            col_left, col_right = st.columns(2)
            _stream_updates(image_queue, inpaint_thread, 100, progress_bar,
                            col_left.empty(), col_right.empty(), "Progress Update")

            # Sampling phase: up to 10 images from sampling_queue.
            s_progress_bar = st.progress(0)
            st.write("Sampling in progress")
            sample_left, sample_right = st.columns(2)
            _stream_updates(sampling_queue, inpaint_thread, 10, s_progress_bar,
                            sample_left.empty(), sample_right.empty(), "Sampling Update")

            inpaint_thread.join()
            if worker_errors:
                st.error(f"Inpainting failed: {worker_errors[0]}")
            else:
                st.success("Inpainting completed!")