Update models/depth_normal_pipeline_clip.py
models/depth_normal_pipeline_clip.py CHANGED
@@ -79,7 +79,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
                  match_input_res:bool =True,
                  batch_size:int = 0,
                  domain: str = "indoor",
-                 seed: int = 0,
+                 #seed: int = 0,
                  color_map: str="Spectral",
                  show_progress_bar:bool = True,
                  ensemble_kwargs: Dict = None,
@@ -148,7 +148,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
                 input_rgb=batched_image,
                 num_inference_steps=denoising_steps,
                 domain=domain,
-                seed=seed,
+                #seed=seed,
                 show_pbar=show_progress_bar,
             )
             depth_pred_ls.append(depth_pred_raw.detach().clone())
@@ -232,7 +232,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
     def single_infer(self,input_rgb:torch.Tensor,
                      num_inference_steps:int,
                      domain:str,
-                     seed: int,
+                     #seed: int,
                      show_pbar:bool,):

         device = input_rgb.device
@@ -245,8 +245,8 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
         rgb_latent = self.encode_RGB(input_rgb)

         # Initial depth map (Guassian noise)
-        if seed >= 0:
-            torch.manual_seed(0)
+        #if seed >= 0:
+            #torch.manual_seed(0)
         geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
         rgb_latent = rgb_latent.repeat(2,1,1,1)