JiminHeo commited on
Commit
c1b628d
·
1 Parent(s): 07293d0

first commit

Browse files
Files changed (2) hide show
  1. app.py +146 -0
  2. requirements.txt +115 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_drawable_canvas import st_canvas
3
+ from PIL import Image
4
+ import numpy as np
5
+ import random
6
+ import vipainting
7
+ import time
8
+ import threading
9
+ from queue import Queue
10
+ import os
11
+
12
+ image_queue = Queue()
13
+ sampling_queue = Queue()
14
+
15
+
16
+ st.title("Mask Your Own Inpaint")
17
+
18
+ @st.cache_data
19
+ def load_images():
20
+ data = np.load("data/sflckr_all_images.npz")
21
+ images = data["images"]
22
+ return images
23
+
24
+ if "random_idx" not in st.session_state:
25
+ st.session_state.random_idx = None
26
+
27
+ images = load_images()
28
+ if st.button("Random Pick"):
29
+ st.session_state.random_idx = random.randint(0, images.shape[0] - 1)
30
+
31
+ def make_square(img, target_size=300):
32
+ size = max(img.size)
33
+ background = Image.new("RGB", (size, size), (255, 255, 255))
34
+ offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
35
+ background.paste(img, offset)
36
+ return background.resize((target_size, target_size))
37
+
38
+ def run_inpainting(random_idx, mask_array, image_queue, sampling_queue):
39
+ vipainting.vipaint(random_idx, mask_array, image_queue, sampling_queue)
40
+
41
+
42
+ if st.session_state.random_idx is not None:
43
+ img_array = images[st.session_state.random_idx]
44
+
45
+ img_pil = Image.fromarray(img_array)
46
+ img_pil = make_square(img_pil, target_size=300)
47
+
48
+
49
+ col1, col2 = st.columns(2)
50
+ with col1:
51
+ st.write("Draw your mask on the image below:")
52
+ canvas_result = st_canvas(
53
+ fill_color="rgba(255, 0, 0, 0.3)",
54
+ stroke_width=50,
55
+ stroke_color="black",
56
+ background_image=img_pil,
57
+ update_streamlit=True,
58
+ width=300,
59
+ height=300,
60
+ drawing_mode="freedraw",
61
+ key="canvas"
62
+ )
63
+
64
+
65
+ if canvas_result.image_data is not None:
66
+ mask = canvas_result.image_data[:, :, 3]
67
+ binary_mask = (mask > 128).astype(np.uint8) * 255
68
+
69
+ with col2:
70
+ st.write("Masked Image")
71
+ st.image(binary_mask, caption="Binary Mask", width=300)
72
+
73
+ mask_image = Image.fromarray(binary_mask).resize((512, 512), Image.ANTIALIAS)
74
+ mask_array = 255 - np.array(mask_image)
75
+ mask_array = np.expand_dims(mask_array, axis=(0, 1))
76
+
77
+ if st.button("inpaint"):
78
+ st.write("Please wait...")
79
+ inpaint_thread = threading.Thread(target=run_inpainting, args=(st.session_state.random_idx, mask_array, image_queue, sampling_queue))
80
+ inpaint_thread.start()
81
+
82
+ img_left, img_right = st.columns(2)
83
+ img_left_placeholder = img_left.empty()
84
+ img_right_placeholder = img_right.empty()
85
+ with img_left:
86
+ img_left_placeholder.image(img_pil, caption=f"True Image", width=300)
87
+ seg_image_path = f"results/{st.session_state.random_idx}/input.png"
88
+
89
+ while True:
90
+ if os.path.exists(seg_image_path):
91
+ with img_right:
92
+ img_right_image = Image.open(seg_image_path)
93
+ img_right_placeholder.image(img_right_image, caption="Segmentation Map", width=300)
94
+ break
95
+ time.sleep(0.5)
96
+
97
+
98
+ # Set up progress tracking
99
+ expected_updates = 100
100
+ progress_bar = st.progress(0)
101
+ st.write("Fitting in progress")
102
+ displayed_images = 0
103
+
104
+ col_left, col_right = st.columns(2)
105
+ left_placeholder = col_left.empty()
106
+ right_placeholder = col_right.empty()
107
+
108
+
109
+ while displayed_images < expected_updates:
110
+ if not image_queue.empty():
111
+ img = image_queue.get() # Get the next image from the queue
112
+
113
+ if displayed_images % 2 == 0:
114
+ left_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
115
+ else:
116
+ right_placeholder.image(img, caption=f"Progress Update {displayed_images + 1}", width=300)
117
+
118
+ # Update progress bar
119
+ displayed_images += 1
120
+ progress_bar.progress(displayed_images / expected_updates)
121
+
122
+ time.sleep(0.3)
123
+
124
+ expected_updates = 10
125
+ s_progress_bar = st.progress(0)
126
+ displayed_images = 0
127
+ st.write("Sampling in progress")
128
+ sample_left, sample_right = st.columns(2)
129
+ sleft_placeholder = sample_left.empty()
130
+ sright_placeholder = sample_right.empty()
131
+ while displayed_images < expected_updates:
132
+ if not sampling_queue.empty():
133
+ img = sampling_queue.get()
134
+
135
+ if displayed_images % 2 == 0:
136
+ sleft_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
137
+ else:
138
+ sright_placeholder.image(img, caption=f"Sampling Update {displayed_images + 1}", width=300)
139
+
140
+ displayed_images += 1
141
+ s_progress_bar.progress(displayed_images / expected_updates)
142
+
143
+ time.sleep(0.3)
144
+
145
+ inpaint_thread.join()
146
+ st.success("Inpainting completed!")
requirements.txt ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ streamlit-drawable-canvas
3
+ pillow
4
+ numpy
5
+ # Core Packages
6
+ numpy==1.24.4
7
+ pillow==9.5.0
8
+ torch==2.4.1
9
+ torchvision==0.8.1
10
+ opencv-python==4.10.0.84
11
+ opencv-python-headless==4.10.0.84
12
+ tqdm==4.66.5
13
+ scipy==1.10.1
14
+ pandas==2.0.3
15
+ matplotlib==3.7.5
16
+ streamlit==1.39.0
17
+ streamlit-drawable-canvas==0.9.3
18
+
19
+ # PyPI Packages
20
+ absl-py==2.1.0
21
+ aiohttp==3.10.10
22
+ aiohappyeyeballs==2.4.3
23
+ aiosignal==1.3.1
24
+ albumentations==1.4.18
25
+ altair==5.4.1
26
+ async-timeout==4.0.3
27
+ attrs==24.2.0
28
+ blinker==1.8.2
29
+ cachetools==5.5.0
30
+ charset-normalizer==3.4.0
31
+ click==8.1.7
32
+ contourpy==1.1.1
33
+ diffusers==0.31.0
34
+ docker-pycreds==0.4.0
35
+ einops==0.8.0
36
+ filelock==3.16.1
37
+ fonttools==4.54.1
38
+ fsspec==2024.10.0
39
+ gitdb==4.0.11
40
+ gitpython==3.1.43
41
+ google-auth==2.35.0
42
+ google-auth-oauthlib==1.0.0
43
+ grpcio==1.67.0
44
+ huggingface-hub==0.26.1
45
+ idna==3.10
46
+ imageio==2.35.1
47
+ importlib-metadata==8.5.0
48
+ importlib-resources==6.4.5
49
+ invisible-watermark==0.2.0
50
+ jinja2==3.1.4
51
+ jsonschema==4.23.0
52
+ jsonschema-specifications==2023.12.1
53
+ kiwisolver==1.4.7
54
+ kornia==0.6.4
55
+ markdown==3.7
56
+ markdown-it-py==3.0.0
57
+ matplotlib==3.7.5
58
+ mdurl==0.1.2
59
+ mpmath==1.3.0
60
+ multidict==6.1.0
61
+ networkx==3.1
62
+ nvidia-cublas-cu12==12.1.3.1
63
+ nvidia-cuda-cupti-cu12==12.1.105
64
+ nvidia-cuda-nvrtc-cu12==12.1.105
65
+ nvidia-cuda-runtime-cu12==12.1.105
66
+ nvidia-cudnn-cu12==9.1.0.70
67
+ nvidia-cufft-cu12==11.0.2.54
68
+ nvidia-curand-cu12==10.3.2.106
69
+ nvidia-cusolver-cu12==11.4.5.107
70
+ nvidia-cusparse-cu12==12.1.0.106
71
+ nvidia-nccl-cu12==2.20.5
72
+ nvidia-nvjitlink-cu12==12.6.77
73
+ nvidia-nvtx-cu12==12.1.105
74
+ oauthlib==3.2.2
75
+ omegaconf==2.3.0
76
+ packaging==24.1
77
+ pkgutil-resolve-name==1.3.10
78
+ protobuf==3.20.1
79
+ psutil==6.1.0
80
+ pyarrow==17.0.0
81
+ pydeck==0.9.1
82
+ pydeprecate==0.3.2
83
+ pygments==2.18.0
84
+ pyparsing==3.1.4
85
+ python-dateutil==2.9.0.post0
86
+ pytorch-lightning==1.6.5
87
+ pyyaml==6.0.2
88
+ referencing==0.35.1
89
+ regex==2024.9.11
90
+ requests==2.32.3
91
+ requests-oauthlib==2.0.0
92
+ rich==13.9.3
93
+ rsa==4.9
94
+ safetensors==0.4.5
95
+ scikit-image==0.21.0
96
+ sentry-sdk==2.17.0
97
+ setproctitle==1.3.3
98
+ smmap==5.0.1
99
+ sympy==1.13.3
100
+ taming-transformers-rom1504==0.0.6
101
+ tenacity==9.0.0
102
+ tensorboard==2.14.0
103
+ tensorboard-data-server==0.7.2
104
+ tifffile==2023.7.10
105
+ tokenizers==0.12.1
106
+ toml==0.10.2
107
+ torchmetrics==0.6.0
108
+ transformers==4.19.2
109
+ triton==3.0.0
110
+ urllib3==2.2.3
111
+ wandb==0.18.5
112
+ watchdog==4.0.2
113
+ werkzeug==3.0.6
114
+ yarl==1.15.2
115
+ zipp==3.20.2