JiminHeo commited on
Commit
d73a015
·
1 Parent(s): 80db475
Files changed (1) hide show
  1. app.py +37 -36
app.py CHANGED
@@ -9,48 +9,53 @@ import threading
9
  from queue import Queue
10
  import os
11
 
 
12
  image_queue = Queue()
13
  sampling_queue = Queue()
14
 
15
-
16
  st.title("Mask Your Own Inpaint")
17
 
18
  @st.cache_data
19
  def load_images():
 
20
  data = np.load("data/sflckr_all_images.npz")
21
- images = data["images"]
22
  return images
23
 
 
24
  if "random_idx" not in st.session_state:
25
  st.session_state.random_idx = None
26
 
 
27
  images = load_images()
28
  if st.button("Random Pick"):
29
  st.session_state.random_idx = random.randint(0, images.shape[0] - 1)
30
 
31
  def make_square(img, target_size=300):
32
- size = max(img.size)
 
33
  background = Image.new("RGB", (size, size), (255, 255, 255))
34
  offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
35
  background.paste(img, offset)
36
  return background.resize((target_size, target_size))
37
 
38
  def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
 
39
  vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
40
 
41
-
42
  if st.session_state.random_idx is not None:
43
  img_array = images[st.session_state.random_idx]
44
-
45
  img_pil = Image.fromarray(img_array)
46
- img_pil = make_square(img_pil, target_size=300)
47
-
48
 
 
49
  col1, col2 = st.columns(2)
50
  with col1:
51
  st.write("Draw your mask on the image below:")
52
  canvas_result = st_canvas(
53
- fill_color="rgba(255, 0, 0, 0.3)",
54
  stroke_width=50,
55
  stroke_color="black",
56
  background_image=img_pil,
@@ -61,66 +66,63 @@ if st.session_state.random_idx is not None:
61
  key="canvas"
62
  )
63
 
64
-
65
  if canvas_result.image_data is not None:
66
- mask = canvas_result.image_data[:, :, 3]
67
  binary_mask = (mask > 128).astype(np.uint8) * 255
68
 
 
69
  with col2:
70
  st.write("Masked Image")
71
- st.image(binary_mask, caption="Binary Mask", width=300)
72
 
73
- mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.ANTIALIAS)
74
  mask_array = 255 - np.array(mask_image)
75
- mask_array = np.expand_dims(mask_array, axis=(0, 1))
76
 
 
77
  if st.button("inpaint"):
78
  st.write("Please wait...")
 
 
79
  inpaint_thread = threading.Thread(target=run_inpainting, args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue))
80
  inpaint_thread.start()
81
 
82
- img_left, img_right = st.columns(2)
 
83
  img_left_placeholder = img_left.empty()
84
  img_right_placeholder = img_right.empty()
85
- with img_left:
86
- img_left_placeholder.image(img_pil, caption=f"True Image", width=300)
87
- seg_image_path = f"results/{st.session_state.random_idx}/input.png"
88
-
89
- while True:
90
- if os.path.exists(seg_image_path):
91
- with img_right:
92
- img_right_image = Image.open(seg_image_path)
93
- img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
94
- break
95
- time.sleep(0.5)
96
 
 
 
 
 
 
97
 
98
- # Set up progress tracking
99
  expected_updates = 100
100
  progress_bar = st.progress(0)
101
  st.write("Fitting in progress")
102
  displayed_images = 0
103
 
 
104
  col_left, col_right = st.columns(2)
105
  left_placeholder = col_left.empty()
106
  right_placeholder = col_right.empty()
107
 
108
-
109
  while displayed_images < expected_updates:
110
  if not image_queue.empty():
111
- img = image_queue.get() # Get the next image from the queue
112
-
113
  if displayed_images % 2 == 0:
114
  left_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
115
  else:
116
  right_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
117
-
118
- # Update progress bar
119
  displayed_images += 1
120
  progress_bar.progress(displayed_images / expected_updates)
121
-
122
  time.sleep(0.3)
123
 
 
124
  expected_updates = 10
125
  s_progress_bar = st.progress(0)
126
  displayed_images = 0
@@ -128,19 +130,18 @@ if st.session_state.random_idx is not None:
128
  sample_left, sample_right = st.columns(2)
129
  sleft_placeholder = sample_left.empty()
130
  sright_placeholder = sample_right.empty()
 
131
  while displayed_images < expected_updates:
132
  if not sampling_queue.empty():
133
- img = sampling_queue.get()
134
-
135
  if displayed_images % 2 == 0:
136
  sleft_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
137
  else:
138
  sright_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
139
-
140
  displayed_images += 1
141
  s_progress_bar.progress(displayed_images / expected_updates)
142
-
143
  time.sleep(0.3)
144
 
 
145
  inpaint_thread.join()
146
  st.success("Inpainting completed!")
 
9
  from queue import Queue
10
  import os
11
 
12
+ # Initialize queues to handle in-progress images for fitting and sampling
13
  image_queue = Queue()
14
  sampling_queue = Queue()
15
 
16
+ # Set up the app title
17
  st.title("Mask Your Own Inpaint")
18
 
19
  @st.cache_data
20
  def load_images():
21
+ """Loads images from an NPZ file only once to cache."""
22
  data = np.load("data/sflckr_all_images.npz")
23
+ images = data["images"]
24
  return images
25
 
26
+ # Initialize session state to hold the selected random image index
27
  if "random_idx" not in st.session_state:
28
  st.session_state.random_idx = None
29
 
30
+ # Load images and add button to randomly select one
31
  images = load_images()
32
  if st.button("Random Pick"):
33
  st.session_state.random_idx = random.randint(0, images.shape[0] - 1)
34
 
35
  def make_square(img, target_size=300):
36
+ """Pads and resizes the image to a square."""
37
+ size = max(img.size)
38
  background = Image.new("RGB", (size, size), (255, 255, 255))
39
  offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
40
  background.paste(img, offset)
41
  return background.resize((target_size, target_size))
42
 
43
  def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
44
+ """Starts inpainting and sends images to the queue."""
45
  vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
46
 
47
+ # Only proceed if a random image has been selected
48
  if st.session_state.random_idx is not None:
49
  img_array = images[st.session_state.random_idx]
 
50
  img_pil = Image.fromarray(img_array)
51
+ img_pil = make_square(img_pil, target_size=512) # Resized to 512x512
 
52
 
53
+ # Set up drawing canvas and display columns
54
  col1, col2 = st.columns(2)
55
  with col1:
56
  st.write("Draw your mask on the image below:")
57
  canvas_result = st_canvas(
58
+ fill_color="rgba(255, 0, 0, 0.3)",
59
  stroke_width=50,
60
  stroke_color="black",
61
  background_image=img_pil,
 
66
  key="canvas"
67
  )
68
 
69
+ # Process the mask if drawn
70
  if canvas_result.image_data is not None:
71
+ mask = canvas_result.image_data[:, :, 3]
72
  binary_mask = (mask > 128).astype(np.uint8) * 255
73
 
74
+ # Show the binary mask in the right column
75
  with col2:
76
  st.write("Masked Image")
77
+ st.image(binary_mask, caption="Binary Mask", width=300)
78
 
79
+ mask_image = Image.fromarray(binary_mask)
80
  mask_array = 255 - np.array(mask_image)
81
+ mask_array = np.expand_dims(mask_array, axis=(0, 1))
82
 
83
+ # Inpaint button action
84
  if st.button("inpaint"):
85
  st.write("Please wait...")
86
+
87
+ # Start inpainting in a separate thread
88
  inpaint_thread = threading.Thread(target=run_inpainting, args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue))
89
  inpaint_thread.start()
90
 
91
+ # Display initial image and segmentation map
92
+ img_left, img_right = st.columns(2)
93
  img_left_placeholder = img_left.empty()
94
  img_right_placeholder = img_right.empty()
95
+ img_left_placeholder.image(img_pil, caption="True Image", width=300)
 
 
 
 
 
 
 
 
 
 
96
 
97
+ seg_image_path = f"results/{st.session_state.random_idx}/input.png"
98
+ while not os.path.exists(seg_image_path):
99
+ time.sleep(0.5)
100
+ img_right_image = Image.open(seg_image_path)
101
+ img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
102
 
103
+ # Progress tracking for fitting
104
  expected_updates = 100
105
  progress_bar = st.progress(0)
106
  st.write("Fitting in progress")
107
  displayed_images = 0
108
 
109
+ # Alternating display for fitting updates
110
  col_left, col_right = st.columns(2)
111
  left_placeholder = col_left.empty()
112
  right_placeholder = col_right.empty()
113
 
 
114
  while displayed_images < expected_updates:
115
  if not image_queue.empty():
116
+ img = image_queue.get()
 
117
  if displayed_images % 2 == 0:
118
  left_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
119
  else:
120
  right_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
 
 
121
  displayed_images += 1
122
  progress_bar.progress(displayed_images / expected_updates)
 
123
  time.sleep(0.3)
124
 
125
+ # Progress tracking for sampling
126
  expected_updates = 10
127
  s_progress_bar = st.progress(0)
128
  displayed_images = 0
 
130
  sample_left, sample_right = st.columns(2)
131
  sleft_placeholder = sample_left.empty()
132
  sright_placeholder = sample_right.empty()
133
+
134
  while displayed_images < expected_updates:
135
  if not sampling_queue.empty():
136
+ img = sampling_queue.get()
 
137
  if displayed_images % 2 == 0:
138
  sleft_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
139
  else:
140
  sright_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
 
141
  displayed_images += 1
142
  s_progress_bar.progress(displayed_images / expected_updates)
 
143
  time.sleep(0.3)
144
 
145
+ # Wait for inpainting to finish
146
  inpaint_thread.join()
147
  st.success("Inpainting completed!")