Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Add fast-foreground-estimation in masking image.
Browse files
    	
        app.py
    CHANGED
    
    | @@ -23,6 +23,40 @@ torch.jit.script = lambda f: f | |
| 23 |  | 
| 24 | 
             
            device = "cuda" if torch.cuda.is_available() else "CPU"
         | 
| 25 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 26 |  | 
| 27 | 
             
            def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
         | 
| 28 | 
             
                image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
         | 
| @@ -114,19 +148,16 @@ def predict(images, resolution, weights_file): | |
| 114 | 
             
                    if device == 'cuda':
         | 
| 115 | 
             
                        scaled_pred_tensor = scaled_pred_tensor.cpu()
         | 
| 116 |  | 
| 117 | 
            -
                    #  | 
| 118 | 
            -
                     | 
| 119 | 
            -
             | 
| 120 | 
            -
                     | 
| 121 | 
            -
                    image_pil = image_pil.resize(pred.shape[::-1])
         | 
| 122 | 
            -
                    pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
         | 
| 123 | 
            -
                    image_masked = (pred * np.array(image_pil)).astype(np.uint8)
         | 
| 124 |  | 
| 125 | 
             
                    torch.cuda.empty_cache()
         | 
| 126 |  | 
| 127 | 
             
                    if tab_is_batch:
         | 
| 128 | 
             
                        save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
         | 
| 129 | 
            -
                         | 
| 130 | 
             
                        save_paths.append(save_file_path)
         | 
| 131 |  | 
| 132 | 
             
                if tab_is_batch:
         | 
|  | |
| 23 |  | 
| 24 | 
             
            device = "cuda" if torch.cuda.is_available() else "CPU"
         | 
| 25 |  | 
| 26 | 
            +
            ### image_proc.py
         | 
| 27 | 
            +
            def refine_foreground(image, mask, r=90):
         | 
| 28 | 
            +
                if mask.size != image.size:
         | 
| 29 | 
            +
                    mask = mask.resize(image.size)
         | 
| 30 | 
            +
                image = np.array(image) / 255.0
         | 
| 31 | 
            +
                mask = np.array(mask) / 255.0
         | 
| 32 | 
            +
                estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
         | 
| 33 | 
            +
                image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
         | 
| 34 | 
            +
                return image_masked
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
         | 
| 38 | 
            +
                # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
         | 
| 39 | 
            +
                alpha = alpha[:, :, None]
         | 
| 40 | 
            +
                F, blur_B = FB_blur_fusion_foreground_estimator(
         | 
| 41 | 
            +
                    image, image, image, alpha, r)
         | 
| 42 | 
            +
                return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
         | 
| 46 | 
            +
                if isinstance(image, Image.Image):
         | 
| 47 | 
            +
                    image = np.array(image) / 255.0
         | 
| 48 | 
            +
                blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                blurred_FA = cv2.blur(F * alpha, (r, r))
         | 
| 51 | 
            +
                blurred_F = blurred_FA / (blurred_alpha + 1e-5)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
         | 
| 54 | 
            +
                blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
         | 
| 55 | 
            +
                F = blurred_F + alpha * \
         | 
| 56 | 
            +
                    (image - alpha * blurred_F - (1 - alpha) * blurred_B)
         | 
| 57 | 
            +
                F = np.clip(F, 0, 1)
         | 
| 58 | 
            +
                return F, blurred_B
         | 
| 59 | 
            +
             | 
| 60 |  | 
| 61 | 
             
            def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
         | 
| 62 | 
             
                image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
         | 
|  | |
| 148 | 
             
                    if device == 'cuda':
         | 
| 149 | 
             
                        scaled_pred_tensor = scaled_pred_tensor.cpu()
         | 
| 150 |  | 
| 151 | 
            +
                    # Show Results
         | 
| 152 | 
            +
                    pred_pil = transforms.ToPILImage()(pred)
         | 
| 153 | 
            +
                    image_masked = refine_foreground(image, pred_pil)
         | 
| 154 | 
            +
                    image_masked.putalpha(pred_pil.resize(image.size))
         | 
|  | |
|  | |
|  | |
| 155 |  | 
| 156 | 
             
                    torch.cuda.empty_cache()
         | 
| 157 |  | 
| 158 | 
             
                    if tab_is_batch:
         | 
| 159 | 
             
                        save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
         | 
| 160 | 
            +
                        image_masked.save(save_file_path)
         | 
| 161 | 
             
                        save_paths.append(save_file_path)
         | 
| 162 |  | 
| 163 | 
             
                if tab_is_batch:
         | 
