NeuralFalcon committed
Commit d62e696 · verified · 1 Parent(s): 863515e

Upload 7 files

Files changed (7)
  1. app.py +86 -0
  2. helper.py +376 -0
  3. packages.txt +1 -0
  4. pixelwise_estimator.py +114 -0
  5. requirement.txt +7 -0
  6. soft_foreground_segmenter.py +78 -0
  7. utils.py +163 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ from gradio_imageslider import ImageSlider
+ from helper import create_transparent_foreground, remove_background_batch_images, remove_background_from_video
+ from soft_foreground_segmenter import SoftForegroundSegmenter
+ foreground_model = "foreground-segmentation-model-vitl16_384.onnx"
+ foreground_segmenter = SoftForegroundSegmenter(onnx_model=foreground_model)
+
+ def process_image(image_path):
+     original, transparent, output_image_path = create_transparent_foreground(image_path, foreground_segmenter)
+     return (original, transparent), output_image_path
+
+ def ui1():
+     with gr.Blocks() as demo:
+         gr.Markdown("## 🪄 Remove Background From Image")
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="filepath", label="Upload Image")
+                 btn = gr.Button("Remove Background")
+             with gr.Column():
+                 image_slider = ImageSlider(label="Before vs After", position=0.5)
+                 save_path_box = gr.File(label="Download Transparent Image")
+
+         btn.click(
+             fn=process_image,
+             inputs=image_input,
+             outputs=[image_slider, save_path_box]
+         )
+         gr.Examples(
+             examples=[["./assets/cat.png"], ["./assets/girl.jpg"], ["./assets/dog.jpg"]],
+             inputs=[image_input],
+             outputs=[image_slider, save_path_box],
+             fn=process_image,
+             cache_examples=True,
+         )
+
+     return demo
+
+
+ def process_uploaded_images(uploaded_images):
+     return remove_background_batch_images(uploaded_images, foreground_segmenter)
+
+ def ui2():
+     with gr.Blocks() as demo:
+         gr.Markdown("## 🪄 Remove Background From Bulk Images")
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Multiple Images")
+                 submit_btn = gr.Button("Remove Backgrounds")
+             with gr.Column():
+                 zip_output = gr.File(label="Download ZIP")
+
+         submit_btn.click(fn=process_uploaded_images, inputs=image_input, outputs=zip_output)
+     return demo
+
+
+ def process_video(video_file):
+     output_path = remove_background_from_video(video_file, foreground_segmenter)
+     return output_path  # absolute or relative path to the processed video
+
+ def ui3():
+     # --- Gradio Interface ---
+     with gr.Blocks() as demo:
+         gr.Markdown("## 🎥 Remove Background From Video")
+
+         with gr.Row():
+             with gr.Column():
+                 input_video = gr.Video(label="Upload Video (.mp4)")
+                 run_btn = gr.Button("Remove Background")
+             with gr.Column():
+                 output_video = gr.Video(label="Green Screen Video")
+
+         run_btn.click(fn=process_video, inputs=input_video, outputs=output_video)
+         # gr.Examples(
+         #     examples=[["./assets/video.mp4"]],
+         #     inputs=[input_video],
+         #     outputs=[output_video],
+         #     fn=process_video,
+         #     cache_examples=True,
+         # )
+     return demo
+
+ demo1 = ui1()
+ demo2 = ui2()
+ demo3 = ui3()
+ demo = gr.TabbedInterface(
+     [demo1, demo2, demo3],
+     ["Remove Background From Image", "Remove Background From Bulk Images", "Remove Background From Video"],
+     title="Microsoft DAViD Background Removal",
+ )
+ demo.queue().launch(debug=True, share=True)
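Note: for quick sanity checks outside the Gradio UI, the same helpers can be called directly. A minimal sketch, assuming the ONNX model file named above sits in the working directory and ./assets/cat.png exists (it is one of the bundled examples):

# Minimal sketch: one image through the pipeline, no Gradio (paths as assumed above)
from helper import create_transparent_foreground
from soft_foreground_segmenter import SoftForegroundSegmenter

segmenter = SoftForegroundSegmenter(onnx_model="foreground-segmentation-model-vitl16_384.onnx")
original_rgb, rgba, saved_path = create_transparent_foreground("./assets/cat.png", segmenter)
print("Transparent PNG written to:", saved_path)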
helper.py ADDED
@@ -0,0 +1,376 @@
+ import cv2
+ import numpy as np
+ import os
+ import shutil
+ import subprocess
+ import glob
+ from tqdm.auto import tqdm
+ import uuid
+ import re
+ from zipfile import ZipFile
+
+ gpu = False
+ os.makedirs("./results", exist_ok=True)
+
+ def apply_green_screen(image_path, save_path, foreground_segmenter):
+     """
+     Replaces the background of the input image with green using a segmentation model.
+
+     Args:
+         image_path (str): Path to the input image.
+         save_path (str): Path to save the composited result; skipped if falsy.
+         foreground_segmenter (SoftForegroundSegmenter): Initialized segmentation model.
+
+     Returns:
+         np.ndarray: The green-screen composited image (BGR).
+     """
+
+     # Load image with alpha if available
+     image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+     if image is None:
+         raise FileNotFoundError(f"Image not found: {image_path}")
+
+     # Remove transparency if present
+     if image.ndim == 3 and image.shape[2] == 4:
+         image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+
+     # Convert to RGB for the model
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # Get segmentation mask
+     mask = foreground_segmenter.estimate_foreground_segmentation(image_rgb)
+
+     # Normalize and convert mask to 0-255 uint8
+     if mask.max() <= 1.0:
+         mask = (mask * 255).astype(np.uint8)
+     else:
+         mask = mask.astype(np.uint8)
+
+     if mask.ndim == 2:
+         mask_gray = mask
+     elif mask.shape[2] == 1:
+         mask_gray = mask[:, :, 0]
+     else:
+         mask_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+
+     _, binary_mask = cv2.threshold(mask_gray, 128, 255, cv2.THRESH_BINARY)
+
+     # Create green background
+     green_bg = np.full_like(image_rgb, (0, 255, 0), dtype=np.uint8)
+
+     # Create 3-channel mask
+     mask_3ch = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2BGR)
+
+     # Composite: foreground from image, background as green
+     output_rgb = np.where(mask_3ch == 255, image_rgb, green_bg)
+
+     # Convert back to BGR for OpenCV
+     output_bgr = cv2.cvtColor(output_rgb, cv2.COLOR_RGB2BGR)
+
+     # Save if path is given
+     if save_path:
+         cv2.imwrite(save_path, output_bgr)
+
+     return output_bgr
+
+
+ def create_transparent_foreground(image_path, foreground_segmenter):
+     uid = uuid.uuid4().hex[:8].upper()
+     base_name = os.path.splitext(os.path.basename(image_path))[0]
+     base_name = re.sub(r'[^a-zA-Z\s]', '', base_name)
+     base_name = base_name.strip().replace(" ", "_").replace("__", "_")
+     save_path = f"./results/{base_name}_{uid}.png"
+     save_path = os.path.abspath(save_path)
+
+     image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+     if image is None:
+         raise FileNotFoundError(f"Image not found: {image_path}")
+     if image.ndim == 3 and image.shape[2] == 4:
+         image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     mask = foreground_segmenter.estimate_foreground_segmentation(image_rgb)
+
+     if mask.max() <= 1.0:
+         mask = (mask * 255).astype(np.uint8)
+     else:
+         mask = mask.astype(np.uint8)
+
+     if mask.ndim == 3 and mask.shape[2] == 3:
+         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+
+     _, alpha = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)
+     rgba_image = np.dstack((image_rgb, alpha))
+     cv2.imwrite(save_path, cv2.cvtColor(rgba_image, cv2.COLOR_RGBA2BGRA))
+
+     return image_rgb, rgba_image, save_path
+
+
+ def remove_background_batch_images(img_list, foreground_segmenter):
+     # Create unique temp directory
+     uid = uuid.uuid4().hex[:8].upper()
+     temp_dir = os.path.abspath(f"./results/bg_removed_{uid}")
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Process each image
+     for image_path in tqdm(img_list, desc="Removing Backgrounds"):
+         _, _, save_path = create_transparent_foreground(image_path, foreground_segmenter)
+         shutil.move(save_path, os.path.join(temp_dir, os.path.basename(save_path)))
+
+     # Create zip file
+     zip_path = f"{temp_dir}.zip"
+     with ZipFile(zip_path, 'w') as zipf:
+         for root, _, files in os.walk(temp_dir):
+             for file in files:
+                 file_path = os.path.join(root, file)
+                 arcname = os.path.relpath(file_path, start=temp_dir)
+                 zipf.write(file_path, arcname=arcname)
+     # shutil.rmtree(temp_dir)
+     return os.path.abspath(zip_path)
+
+ def get_sorted_paths(directory, extension="png"):
+     """
+     Returns full paths of all images with the given extension, sorted numerically by filename (without extension).
+     """
+     extension = extension.lstrip(".").lower()
+     pattern = os.path.join(directory, f"*.{extension}")
+     files = glob.glob(pattern)
+     files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+     return files
+
+
+ def extract_all_frames_ffmpeg_gpu(video_path, output_dir="frames", extension="png", use_gpu=True):
+     """
+     Extracts all frames from a video using ffmpeg, with optional GPU acceleration.
+     Returns a sorted list of full paths to the extracted frames.
+     """
+     if os.path.exists(output_dir):
+         shutil.rmtree(output_dir)
+     os.makedirs(output_dir, exist_ok=True)
+
+     extension = extension.lstrip(".")
+     output_pattern = os.path.join(output_dir, f"%05d.{extension}")
+
+     command = [
+         "ffmpeg", "-i", video_path, output_pattern
+     ]
+     if use_gpu:
+         command.insert(1, "cuda")
+         command.insert(1, "-hwaccel")
+
+     print("Running command:", " ".join(command))
+     subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+     return get_sorted_paths(output_dir, extension)
+
+
+ def green_screen_batch(frames, foreground_segmenter, output_dir="green_screen_frames"):
+     """
+     Applies a green-screen background to a batch of frames and saves the results.
+
+     Args:
+         frames (List[str]): List of image paths.
+         foreground_segmenter (SoftForegroundSegmenter): Initialized segmentation model.
+         output_dir (str): Directory to save green-screened output.
+     """
+     if os.path.exists(output_dir):
+         shutil.rmtree(output_dir)
+     os.makedirs(output_dir, exist_ok=True)
+     green_screen_frames = []
+     for frame in tqdm(frames, desc="Processing green screen frames"):
+         save_image_path = os.path.join(output_dir, os.path.basename(frame))
+         apply_green_screen(
+             frame,
+             save_image_path,
+             foreground_segmenter
+         )
+         green_screen_frames.append(save_image_path)
+     return green_screen_frames
+
+
+ def green_screen_video_maker(original_video, green_screen_frames, batch_size=100):
+     """
+     Creates video chunks from green-screen frames based on the original video's properties.
+
+     Args:
+         original_video (str): Path to the original video file (to read FPS and frame size).
+         green_screen_frames (List[str]): List of green-screen frame paths.
+         batch_size (int): Number of frames per chunked video.
+     """
+     temp_folder = "temp_video"
+     if os.path.exists(temp_folder):
+         shutil.rmtree(temp_folder)
+     os.makedirs(temp_folder, exist_ok=True)
+
+     # Get video info from the original video
+     cap = cv2.VideoCapture(original_video)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     cap.release()
+
+     total_frames = len(green_screen_frames)
+     num_chunks = (total_frames + batch_size - 1) // batch_size  # Ceiling division
+
+     for chunk_idx in tqdm(range(num_chunks), desc="Processing video chunks"):
+         chunk_path = os.path.join(temp_folder, f"{chunk_idx+1}.mp4")
+         out = cv2.VideoWriter(chunk_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+         start_idx = chunk_idx * batch_size
+         end_idx = min(start_idx + batch_size, total_frames)
+
+         for frame_path in green_screen_frames[start_idx:end_idx]:
+             frame = cv2.imread(frame_path)
+             frame = cv2.resize(frame, (width, height))  # Ensure matching resolution
+             out.write(frame)
+
+         out.release()
+
+
+ def merge_video_chunks(output_path="final_video.mp4", temp_folder="temp_video", use_gpu=True):
+     """
+     Merges all video chunks from temp_folder into a single final video.
+     """
+     os.makedirs("./results", exist_ok=True)
+     output_path = f"../results/{output_path}"  # relative to temp_folder
+     file_list_path = os.path.join(temp_folder, "chunks.txt")
+     chunk_files = sorted(
+         [f for f in os.listdir(temp_folder) if f.lower().endswith("mp4")],
+         key=lambda x: int(os.path.splitext(x)[0])
+     )
+
+     with open(file_list_path, "w") as f:
+         for chunk in chunk_files:
+             f.write(f"file '{chunk}'\n")  # No './' prefix
+
+     ffmpeg_cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", "chunks.txt"]
+
+     if use_gpu:
+         ffmpeg_cmd += ["-c:v", "h264_nvenc", "-preset", "fast"]
+     else:
+         ffmpeg_cmd += ["-c", "copy"]
+
+     ffmpeg_cmd.append(output_path)
+
+     # Run from inside temp_folder, so chunks.txt and the mp4 files are local
+     subprocess.run(ffmpeg_cmd, cwd=temp_folder, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+
+ def extract_audio_from_video(video_path, output_audio_path="output_audio.wav", format="wav", sample_rate=16000, channels=1):
+     """
+     Extracts audio from a video file using ffmpeg.
+
+     Args:
+         video_path (str): Path to the input video file.
+         output_audio_path (str): Path to save the extracted audio (e.g., .wav or .mp3).
+         format (str): 'wav' or 'mp3'.
+         sample_rate (int): Sampling rate in Hz (e.g., 16000 for ASR models).
+         channels (int): Number of audio channels (1=mono, 2=stereo).
+     """
+     # Ensure the output directory exists
+     os.makedirs(os.path.dirname(output_audio_path) or ".", exist_ok=True)
+
+     # Build ffmpeg command
+     if format.lower() == "wav":
+         command = [
+             "ffmpeg", "-y",           # Overwrite output
+             "-i", video_path,         # Input video
+             "-vn",                    # Disable video
+             "-ac", str(channels),     # Audio channels (1 = mono)
+             "-ar", str(sample_rate),  # Audio sample rate
+             "-acodec", "pcm_s16le",   # WAV codec
+             output_audio_path
+         ]
+     elif format.lower() == "mp3":
+         command = [
+             "ffmpeg", "-y",
+             "-i", video_path,
+             "-vn",
+             "-ac", str(channels),
+             "-ar", str(sample_rate),
+             "-acodec", "libmp3lame",  # MP3 codec
+             output_audio_path
+         ]
+     else:
+         raise ValueError("Unsupported format. Use 'wav' or 'mp3'.")
+
+     # Run command silently
+     subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+ def add_audio(video_path, audio_path, output_path, use_gpu=False):
+     """
+     Replaces the audio of a video with a new audio track.
+
+     Args:
+         video_path (str): Path to the video file.
+         audio_path (str): Path to the audio file.
+         output_path (str): Path where the final video will be saved.
+         use_gpu (bool): If True, use GPU-accelerated video encoding.
+     """
+     os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+
+     command = [
+         "ffmpeg", "-y",       # Overwrite without asking
+         "-i", video_path,     # Input video
+         "-i", audio_path,     # Input audio
+         "-map", "0:v:0",      # Use video from first input
+         "-map", "1:a:0",      # Use audio from second input
+         "-shortest"           # Trim to the shortest stream (audio/video)
+     ]
+
+     if use_gpu:
+         command += ["-c:v", "h264_nvenc", "-preset", "fast"]
+     else:
+         command += ["-c:v", "copy"]
+
+     command += ["-c:a", "aac", "-b:a", "192k", output_path]
+
+     subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+
+ def remove_background_from_video(uploaded_video_path, foreground_segmenter):
+     # Generate a single UUID to use for all related files
+     uid = uuid.uuid4().hex[:8].upper()
+
+     # Define all output paths using that UUID
+     base_name = os.path.splitext(os.path.basename(uploaded_video_path))[0]
+     base_name = re.sub(r'[^a-zA-Z\s]', '', base_name)
+     base_name = base_name.strip().replace(" ", "_")
+
+     temp_video_path = f"./results/{base_name}_chunks_{uid}.mp4"
+     audio_path = f"./results/{base_name}_audio_{uid}.wav"
+     final_output_path = f"./results/{base_name}_final_{uid}.mp4"
+
+     # Step 1: Extract frames
+     frames = extract_all_frames_ffmpeg_gpu(
+         video_path=uploaded_video_path,
+         output_dir="frames",
+         extension="png",
+         use_gpu=gpu
+     )
+
+     # Step 2: Remove background (green screen)
+     green_screen_frames = green_screen_batch(frames, foreground_segmenter)
+
+     # Step 3: Rebuild video from frames
+     green_screen_video_maker(uploaded_video_path, green_screen_frames, batch_size=100)
+
+     # Step 4: Merge video chunks
+     merge_video_chunks(output_path=os.path.basename(temp_video_path), use_gpu=gpu)
+
+     # Step 5: Extract original audio
+     extract_audio_from_video(uploaded_video_path, output_audio_path=audio_path)
+
+     # Step 6: Add audio back (follow the module-level gpu flag so CPU-only hosts
+     # do not request the h264_nvenc encoder)
+     add_audio(
+         video_path=temp_video_path,
+         audio_path=audio_path,
+         output_path=final_output_path,
+         use_gpu=gpu
+     )
+
+     return os.path.abspath(final_output_path)
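End to end, the pipeline is: extract frames, composite each frame onto green, re-encode in chunks, concat the chunks, then remux the original audio. A minimal driver sketch, assuming the model file is local, ffmpeg is on PATH, and input.mp4 is a placeholder path:

# Hypothetical driver ("input.mp4" is a placeholder, not part of this repo)
from helper import remove_background_from_video
from soft_foreground_segmenter import SoftForegroundSegmenter

segmenter = SoftForegroundSegmenter(onnx_model="foreground-segmentation-model-vitl16_384.onnx")
final_video = remove_background_from_video("input.mp4", segmenter)
print("Green-screen video with original audio:", final_video)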
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
pixelwise_estimator.py ADDED
@@ -0,0 +1,114 @@
+ # Copied From https://github.com/microsoft/DAViD/blob/main/runtime/pixelwise_estimator.py
+ """Runtime core for pixelwise estimators.
+
+ Copyright (c) Microsoft Corporation.
+
+ MIT License
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import numpy as np
+ from onnxruntime import InferenceSession
+ from utils import ONNX_EP, ModelNotFoundError, prepare_image_for_model, preprocess_img
+
+
+ class RuntimeSession(InferenceSession):
+     """The runtime session."""
+
+     def __init__(self, onnx_model: Union[str, Path], providers: Optional[list[str]] = None) -> None:
+         """Create a runtime session.
+
+         Args:
+             onnx_model: The path to the onnx model.
+             providers: Optional list of ONNX execution providers to use, defaults to [GPU, CPU].
+         """
+         super().__init__(str(onnx_model), providers=providers or ONNX_EP)
+         self.onnx_model_path: Path = Path(onnx_model)
+
+     @property
+     def input_name(self) -> str:
+         """Get the name of the input tensor."""
+         return self.get_inputs()[0].name
+
+     def __call__(self, x: np.ndarray) -> list[np.ndarray]:
+         """Run the model on the input tensor."""
+         x = x.astype(np.float32)
+         return self.run(None, {self.input_name: x})
+
+
+ class PixelwiseEstimator:
+     """Given an input image, estimates the pixelwise (dense) output (e.g., normal map, depth map, etc.)."""
+
+     def __init__(self, onnx_model: Union[str, Path], providers: Optional[list[str]] = None):
+         """Creates a pixelwise estimator.
+
+         Arguments:
+             onnx_model: Path to an ONNX model.
+             providers: Optional list of ONNX execution providers to use, defaults to [GPU, CPU].
+
+         Raises:
+             TypeError: If onnx_model is not a string or Path.
+             ModelNotFoundError: If the model file does not exist.
+             ModelError: If the provided model has an undeclared or incorrect roi type.
+         """
+         if not isinstance(onnx_model, (str, Path)):
+             raise TypeError(f"onnx_model should be a string or Path, got {type(onnx_model)}")
+         onnx_model = Path(onnx_model)
+         if not onnx_model.exists():
+             raise ModelNotFoundError(f"model {onnx_model} does not exist")
+
+         self.onnx_model = onnx_model
+
+         self.roi_size = 512
+
+         self.onnx_sess = RuntimeSession(str(onnx_model), providers=providers)
+
+     @staticmethod
+     def inference(input_img: np.ndarray, onnx_sess: RuntimeSession) -> np.ndarray:
+         """Predict the pixelwise (dense) map given an input image.
+
+         Args:
+             input_img: Input image.
+             onnx_sess: ONNX inference session.
+
+         Returns:
+             Predicted output map.
+         """
+         input_tensor = onnx_sess.get_inputs()[0]
+         input_name = input_tensor.name
+         input_shape = input_tensor.shape
+         input_img = np.transpose(input_img, (2, 0, 1)).reshape(1, *input_shape[1:])  # HWC to BCHW
+         pred_onnx = onnx_sess.run(None, {input_name: input_img.astype(np.float32)})
+
+         return pred_onnx
+
+     def _estimate_dense_map(self, image: np.ndarray) -> tuple[np.ndarray]:
+         """Estimating dense maps from image input."""
+         if not isinstance(image, np.ndarray):
+             raise TypeError(f"Image should be a numpy array, got {type(image)}")
+
+         image_bgr = preprocess_img(image)
+         processed_image, metadata = prepare_image_for_model(image_bgr, self.roi_size)
+         output = self.inference(processed_image, self.onnx_sess)
+
+         return output, metadata
+ 
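Note that inference() returns the raw list of onnxruntime outputs; soft_foreground_segmenter.py below unpacks it as mask = output[0][0]. A minimal sketch of RuntimeSession in isolation, assuming a hypothetical single-input model at "model.onnx" that accepts a 1x3x512x512 tensor:

# Sketch only: "model.onnx" and the input shape are assumptions, not part of this repo
import numpy as np
from pixelwise_estimator import RuntimeSession

sess = RuntimeSession("model.onnx")
x = np.zeros((1, 3, 512, 512), dtype=np.float32)  # BCHW; the real shape comes from sess.get_inputs()
outputs = sess(x)                                 # list of np.ndarray, one per model output
print([o.shape for o in outputs])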
requirement.txt ADDED
@@ -0,0 +1,7 @@
+ numpy==2.2.6
+ onnx==1.18.0
+ onnxruntime-gpu==1.22.0
+ opencv-python==4.12.0.88
+ opencv-python-headless==4.12.0.88
+ gradio>=5.38.2
+ gradio_imageslider==0.0.20
soft_foreground_segmenter.py ADDED
@@ -0,0 +1,78 @@
+ # Copied From https://github.com/microsoft/DAViD/blob/main/runtime/soft_foreground_segmenter.py
+ """This module provides a SoftForegroundSegmenter which segments the foreground human subjects from the background.
+
+ Copyright (c) Microsoft Corporation.
+
+ MIT License
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import cv2
+ import numpy as np
+ from pixelwise_estimator import PixelwiseEstimator
+ from utils import composite_model_output_to_image
+
+
+ class SoftForegroundSegmenter(PixelwiseEstimator):
+     """Estimates the soft foreground segmentation mask of humans in an image."""
+
+     def __init__(
+         self,
+         onnx_model: Union[str, Path],
+         providers: Optional[list[str]] = None,
+         binarization_threshold: Optional[float] = None,
+     ):
+         """Creates a soft foreground segmenter to segment the foreground human subjects in an image.
+
+         Arguments:
+             onnx_model: A path to an ONNX model.
+             providers: Optional list of ONNX execution providers to use, defaults to [GPU, CPU].
+             binarization_threshold: Threshold above which the mask is considered foreground. When None, the mask is returned as is.
+
+         Raises:
+             TypeError: if onnx_model is not a string or Path.
+             ModelNotFoundError: if the model file does not exist.
+         """
+         super().__init__(
+             onnx_model,
+             providers=providers,
+         )
+         self.binarization_threshold = binarization_threshold
+
+     def estimate_foreground_segmentation(self, image: np.ndarray) -> np.ndarray:
+         """Predict the soft foreground/background segmentation given an input image."""
+         mask, metadata = self._estimate_dense_map(image)
+         mask = mask[0][0]
+         mask = np.transpose(mask, (1, 2, 0))
+
+         # Post-process to get the final segmentation mask and composite it onto the original size
+         segmented_image = composite_model_output_to_image(mask, metadata, interp_mode=cv2.INTER_CUBIC)
+
+         # Clip the mask to [0, 1]
+         segmented_image = np.clip(segmented_image, 0, 1)
+
+         # Apply threshold if binarization_threshold is set
+         if self.binarization_threshold:
+             return ((segmented_image > self.binarization_threshold) * 1).astype(np.uint8)
+
+         return segmented_image
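For reference, this mirrors how helper.py drives the segmenter (the model filename is the one app.py expects, and ./assets/girl.jpg is one of the bundled examples):

# Usage sketch: soft mask for one image (paths as assumed above)
import cv2
from soft_foreground_segmenter import SoftForegroundSegmenter

segmenter = SoftForegroundSegmenter(onnx_model="foreground-segmentation-model-vitl16_384.onnx")
bgr = cv2.imread("./assets/girl.jpg")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
mask = segmenter.estimate_foreground_segmentation(rgb)  # float mask in [0, 1], same HxW as input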
utils.py ADDED
@@ -0,0 +1,163 @@
+ # Copied From https://github.com/microsoft/DAViD/blob/main/runtime/utils.py
+ """Utility classes and functions for image processing and ROI operations.
+
+ Copyright (c) Microsoft Corporation.
+
+ MIT License
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ import cv2
+ import numpy as np
+
+ ONNX_EP = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ UINT8_MAX = np.iinfo(np.uint8).max
+ UINT16_MAX = np.iinfo(np.uint16).max
+
+
+ class ImageFormatError(Exception):
+     """Exception raised for invalid image formats."""
+
+     pass
+
+
+ class ModelNotFoundError(Exception):
+     """Exception raised when model file is not found."""
+
+     pass
+
+
+ def preprocess_img(img: np.ndarray) -> np.ndarray:
+     """Preprocesses a BGR image for a DNN, converting to float if needed and normalizing to [0, 1].
+
+     Normalization of uint images is done by dividing by the brightest possible value (e.g. 255 for uint8).
+
+     Arguments:
+         img: The image to preprocess; can be uint8, uint16, float16, float32 or float64.
+
+     Returns:
+         The preprocessed image in np.float32 format.
+
+     Raises:
+         ImageFormatError: If the image is not three channels or not uint8, uint16, float16, float32 or float64.
+     """
+     if img.ndim != 3 or img.shape[2] != 3:
+         raise ImageFormatError(f"image must be 3 channels, got shape: {img.shape}")
+     if img.dtype not in [np.uint8, np.uint16, np.float16, np.float32, np.float64]:  # noqa: PLR6201
+         raise ImageFormatError("image must be uint8, uint16, float16, float32 or float64")
+
+     if img.dtype == np.uint8:
+         img = img.astype(np.float32) / UINT8_MAX
+     if img.dtype == np.uint16:
+         img = img.astype(np.float32) / UINT16_MAX
+     img = np.clip(img, 0, 1)
+     return img.astype(np.float32)
+
+
+ def prepare_image_for_model(image: np.ndarray, roi_size: int = 512) -> tuple[np.ndarray, dict]:
+     """Prepare any input image for model inference by resizing to roi_size x roi_size.
+
+     This function takes an image of any size and prepares it for a model that expects
+     a square input (e.g., 512x512). It handles aspect ratio preservation by padding
+     with replicated border values.
+
+     Args:
+         image: Input image of any size.
+         roi_size: Target size for the model (default 512).
+
+     Returns:
+         tuple: (preprocessed_image, metadata_dict)
+             - preprocessed_image: Image resized to roi_size x roi_size
+             - metadata_dict: Contains information needed to composite back to original size
+     """
+     # Get original shape
+     original_shape = image.shape[:2]  # (height, width)
+
+     # Calculate padding to make the image square
+     if original_shape[0] < original_shape[1]:
+         pad_h = (original_shape[1] - original_shape[0]) // 2
+         pad_w = 0
+         pad_h_extra = original_shape[1] - original_shape[0] - pad_h
+         pad_w_extra = 0
+     elif original_shape[0] > original_shape[1]:
+         pad_w = (original_shape[0] - original_shape[1]) // 2
+         pad_h = 0
+         pad_w_extra = original_shape[0] - original_shape[1] - pad_w
+         pad_h_extra = 0
+     else:
+         pad_h = pad_w = pad_h_extra = pad_w_extra = 0
+
+     # Pad the image to make it square
+     padded_image = cv2.copyMakeBorder(
+         image,
+         top=pad_h,
+         bottom=pad_h_extra,
+         left=pad_w,
+         right=pad_w_extra,
+         borderType=cv2.BORDER_REPLICATE,
+     )
+
+     square_shape = padded_image.shape[:2]
+
+     while padded_image.shape[1] > roi_size * 3 and padded_image.shape[0] > roi_size * 3:
+         padded_image = cv2.pyrDown(padded_image)
+
+     resized_image = cv2.resize(padded_image, (roi_size, roi_size), interpolation=cv2.INTER_LINEAR)
+
+     metadata = {
+         "original_shape": original_shape,
+         "square_shape": square_shape,
+         "original_padding": (pad_h, pad_w, pad_h_extra, pad_w_extra),
+     }
+
+     return resized_image, metadata
+
+
+ def composite_model_output_to_image(
+     model_output: np.ndarray, metadata: dict, interp_mode: int = cv2.INTER_NEAREST
+ ) -> np.ndarray:
+     """Composite model output back to the original image size.
+
+     Takes the model output (which should be roi_size x roi_size) and composites it
+     back to the original image dimensions using the metadata from prepare_image_for_model.
+
+     Args:
+         model_output: Output from the model (roi_size x roi_size).
+         metadata: Metadata dict returned from prepare_image_for_model.
+         interp_mode: Interpolation mode for resizing (default INTER_NEAREST for discrete outputs).
+
+     Returns:
+         np.ndarray: Output composited to original image size.
+     """
+     pad_h, pad_w, pad_h_extra, pad_w_extra = metadata["original_padding"]
+
+     # Resize the entire model output back to the square shape
+     square_shape = metadata["square_shape"]
+     resized_to_square = cv2.resize(model_output, (square_shape[1], square_shape[0]), interpolation=interp_mode)
+
+     # Remove the padding to get back to original dimensions
+     if pad_h > 0 or pad_h_extra > 0:
+         final_output = resized_to_square[pad_h : square_shape[0] - pad_h_extra, :]
+     elif pad_w > 0 or pad_w_extra > 0:
+         final_output = resized_to_square[:, pad_w : square_shape[1] - pad_w_extra]
+     else:
+         final_output = resized_to_square
+
+     return final_output
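The pad, downscale, and resize steps in prepare_image_for_model are exactly undone by composite_model_output_to_image. A minimal round-trip sketch with a stand-in mask (the shapes shown follow from the padding arithmetic above):

# Round-trip sketch: 720x1280 input is padded to 1280x1280, resized to 512x512,
# and the (fake) model output is composited back to 720x1280
import numpy as np
import cv2
from utils import prepare_image_for_model, composite_model_output_to_image

img = np.zeros((720, 1280, 3), dtype=np.uint8)            # non-square input
roi, meta = prepare_image_for_model(img, roi_size=512)    # (512, 512, 3) plus padding metadata
fake_mask = np.ones((512, 512), dtype=np.float32)         # stand-in for a model output
restored = composite_model_output_to_image(fake_mask, meta, interp_mode=cv2.INTER_CUBIC)
print(roi.shape, restored.shape)                          # (512, 512, 3) (720, 1280)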