	Migrated from GitHub
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete changeset.
- .gitattributes +1 -0
- LICENSE.txt +21 -0
- ORIGINAL_README.md +79 -0
- assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 +0 -0
- assets/result_clr_scale4_pexels-zdmit-6780091.mp4 +0 -0
- blender/blender_render_human_ortho.py +837 -0
- blender/check_render.py +46 -0
- blender/count.py +44 -0
- blender/distribute.py +149 -0
- blender/rename_smpl_files.py +25 -0
- blender/render.sh +4 -0
- blender/render_human.py +88 -0
- blender/render_single.sh +7 -0
- blender/utils.py +128 -0
- configs/inference-768-6view.yaml +72 -0
- configs/remesh.yaml +18 -0
- configs/train-768-6view-onlyscan_face.yaml +145 -0
- configs/train-768-6view-onlyscan_face_smplx.yaml +154 -0
- core/opt.py +197 -0
- core/remesh.py +359 -0
- econdataset.py +370 -0
- examples/02986d0998ce01aa0aa67a99fbd1e09a.png +0 -0
- examples/16171.png +0 -0
- examples/26d2e846349647ff04c536816e0e8ca1.png +0 -0
- examples/30755.png +0 -0
- examples/3930.png +0 -0
- examples/4656716-3016170581.png +0 -0
- examples/663dcd6db19490de0b790da430bd5681.png +3 -0
- examples/7332.png +0 -0
- examples/85891251f52a2399e660a63c2a7fdf40.png +0 -0
- examples/a689a48d23d6b8d58d67ff5146c6e088.png +0 -0
- examples/b0d178743c7e3e09700aaee8d2b1ec47.png +0 -0
- examples/case5.png +0 -0
- examples/d40776a1e1582179d97907d36f84d776.png +0 -0
- examples/durant.png +0 -0
- examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png +0 -0
- examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png +0 -0
- examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png +0 -0
- examples/pexels-barbara-olsen-7869640.png +0 -0
- examples/pexels-julia-m-cameron-4145040.png +0 -0
- examples/pexels-marta-wave-6437749.png +0 -0
- examples/pexels-photo-6311555-removebg.png +0 -0
- examples/pexels-zdmit-6780091.png +0 -0
- inference.py +221 -0
- lib/__init__.py +0 -0
- lib/common/__init__.py +0 -0
- lib/common/cloth_extraction.py +182 -0
- lib/common/config.py +218 -0
- lib/common/imutils.py +364 -0
- lib/common/render.py +398 -0
    	
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/663dcd6db19490de0b790da430bd5681.png filter=lfs diff=lfs merge=lfs -text
    	
LICENSE.txt ADDED

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
ORIGINAL_README.md ADDED

@@ -0,0 +1,79 @@
+# PSHuman
+
+This is the official implementation of *PSHuman: Photorealistic Single-image 3D Human Reconstruction using Cross-Scale Multiview Diffusion*.
+
+### [Project Page](https://penghtyx.github.io/PSHuman/) | [Arxiv](https://arxiv.org/pdf/2409.10141) | [Weights](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views)
+
+https://github.com/user-attachments/assets/b62e3305-38a7-4b51-aed8-1fde967cca70
+
+https://github.com/user-attachments/assets/76100d2e-4a1a-41ad-815c-816340ac6500
+
+
+Given a single image of a clothed person, **PSHuman** reconstructs detailed geometry and realistic 3D human appearance across various poses within one minute.
+
+### 📝 Update
+- __[2024.11.30]__: Released the SMPL-free [version](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views), which does not require an SMPL condition for multiview generation and performs well on humans in general poses.
+
+
+### Installation
+```
+conda create -n pshuman python=3.10
+conda activate pshuman
+
+# torch
+pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+
+# other dependencies
+pip install -r requirement.txt
+```
+
+This project is also based on SMPLX. We borrowed the related models from [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) and reorganized them; they can be downloaded from [Onedrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD).
+
+
+
+### Inference
+1. Given a human image, we use [Clipdrop](https://github.com/xxlong0/Wonder3D?tab=readme-ov-file) or ```rembg``` to remove the background. For the latter, we provide a simple script.
+```
+python utils/remove_bg.py --path $DATA_PATH$
+```
+Then, put the RGBA images in ```$DATA_PATH$```.
+
+2. Run [inference.py](inference.py); the textured mesh and rendered video will be saved in ```out```.
+```
+CUDA_VISIBLE_DEVICES=$GPU python inference.py --config configs/inference-768-6view.yaml \
+    pretrained_model_name_or_path='pengHTYX/PSHuman_Unclip_768_6views' \
+    validation_dataset.crop_size=740 \
+    with_smpl=false \
+    validation_dataset.root_dir=$DATA_PATH$ \
+    seed=600 \
+    num_views=7 \
+    save_mode='rgb'
+
+```
+You can adjust ```crop_size``` (720 or 740) and ```seed``` (42 or 600) to obtain the best results for some cases.
+
+### Training
+For data preparation and preprocessing, please refer to our [paper](https://arxiv.org/pdf/2409.10141). Once the data is ready, start training by running
+```
+bash scripts/train_768.sh
+```
+You should modify some parameters, such as ```data_common.root_dir``` and ```data_common.object_list```.
+
+### Related projects
+We collect code from the following projects. Many thanks for the contributions from the open-source community!
+
+[ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) recover a human mesh from a single human image.
+[Era3D](https://github.com/pengHTYX/Era3D) and [Unique3D](https://github.com/AiuniAI/Unique3D) generate consistent multiview images from a single color image.
+[Continuous-Remeshing](https://github.com/Profactor/continuous-remeshing) for inverse rendering.
+
+
+### Citation
+If you find this codebase useful, please consider citing our work.
+```
+@article{li2024pshuman,
+  title={PSHuman: Photorealistic Single-view Human Reconstruction using Cross-Scale Diffusion},
+  author={Li, Peng and Zheng, Wangguandong and Liu, Yuan and Yu, Tao and Li, Yangguang and Qi, Xingqun and Li, Mengfei and Chi, Xiaowei and Xia, Siyu and Xue, Wei and others},
+  journal={arXiv preprint arXiv:2409.10141},
+  year={2024}
+}
+```
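The `utils/remove_bg.py` helper that the README's inference step calls is not among the files shown in this truncated view. As a rough illustration only (the actual script in the repository may differ), a minimal background-removal script built on the `rembg` and `Pillow` packages, with a `--path` argument matching the README command, could look like this:

```python
# Hypothetical sketch of a rembg-based background remover; not the repository's actual utils/remove_bg.py.
import argparse
import glob
import os

from PIL import Image
from rembg import remove  # pip install rembg


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True, help="directory containing the input RGB images")
    args = parser.parse_args()

    patterns = ("*.png", "*.jpg", "*.jpeg")
    files = [f for p in patterns for f in glob.glob(os.path.join(args.path, p))]
    for fn in files:
        rgb = Image.open(fn).convert("RGB")
        rgba = remove(rgb)  # returns an RGBA image with the background made transparent
        out = os.path.splitext(fn)[0] + "_rgba.png"
        rgba.save(out)
        print(f"saved {out}")


if __name__ == "__main__":
    main()
```

The resulting RGBA images then stay in `$DATA_PATH$`, which is the directory `inference.py` reads via `validation_dataset.root_dir`.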
    	
assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 ADDED

Binary file (320 kB).

assets/result_clr_scale4_pexels-zdmit-6780091.mp4 ADDED

Binary file (629 kB).
    	
blender/blender_render_human_ortho.py ADDED

@@ -0,0 +1,837 @@
+"""Blender script to render images of 3D models.
+
+This script is used to render images of 3D models. It takes in a list of paths
+to .glb files and renders images of each model. The images are from rotating the
+object around the origin. The images are saved to the output directory.
+
+Example usage:
+    blender -b -P blender_script.py -- \
+        --object_path my_object.glb \
+        --output_dir ./views \
+        --engine CYCLES \
+        --scale 0.8 \
+        --num_images 12 \
+        --camera_dist 1.2
+
+Here, input_model_paths.json is a json file containing a list of paths to .glb.
+"""
+import argparse
+import json
+import math
+import os
+import random
+import sys
+import time
+import glob
+import urllib.request
+import uuid
+from typing import Tuple
+from mathutils import Vector, Matrix
+os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import cv2
+import numpy as np
+from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
+
+import bpy
+from mathutils import Vector
+
+import OpenEXR
+import Imath
+from PIL import Image
+
+# import blenderproc as bproc
+
+bpy.app.debug_value=256
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--object_path",
+    type=str,
+    required=True,
+    help="Path to the object file",
+)
+parser.add_argument("--smpl_path", type=str, required=True, help="Path to the object file")
+parser.add_argument("--output_dir", type=str, default="/views_whole_sphere-test2")
+parser.add_argument(
+    "--engine", type=str, default="BLENDER_EEVEE", choices=["CYCLES", "BLENDER_EEVEE"]
+)
+parser.add_argument("--scale", type=float, default=1.0)
+parser.add_argument("--num_images", type=int, default=8)
+parser.add_argument("--random_images", type=int, default=3)
+parser.add_argument("--random_ortho", type=int, default=1)
+parser.add_argument("--device", type=str, default="CUDA")
+parser.add_argument("--resolution", type=int, default=512)
+
+
+argv = sys.argv[sys.argv.index("--") + 1 :]
+args = parser.parse_args(argv)
+
+
+
+print('===================', args.engine, '===================')
+
+context = bpy.context
+scene = context.scene
+render = scene.render
+
+cam = scene.objects["Camera"]
+cam.data.type = 'ORTHO'
+cam.data.ortho_scale = 1.
+cam.data.lens = 35
+cam.data.sensor_height = 32
+cam.data.sensor_width = 32
+
+cam_constraint = cam.constraints.new(type="TRACK_TO")
+cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
+cam_constraint.up_axis = "UP_Y"
+
+# setup lighting
+# bpy.ops.object.light_add(type="AREA")
+# light2 = bpy.data.lights["Area"]
+# light2.energy = 3000
+# bpy.data.objects["Area"].location[2] = 0.5
+# bpy.data.objects["Area"].scale[0] = 100
+# bpy.data.objects["Area"].scale[1] = 100
+# bpy.data.objects["Area"].scale[2] = 100
+
+render.engine = args.engine
+render.image_settings.file_format = "PNG"
+render.image_settings.color_mode = "RGBA"
+render.resolution_x = args.resolution
+render.resolution_y = args.resolution
+render.resolution_percentage = 100
+render.threads_mode = 'FIXED'  # use a fixed number of render threads
+render.threads = 32  # number of threads
+
+scene.cycles.device = "GPU"
+scene.cycles.samples = 128   # 128
+scene.cycles.diffuse_bounces = 1
+scene.cycles.glossy_bounces = 1
+scene.cycles.transparent_max_bounces = 3  # 3
+scene.cycles.transmission_bounces = 3   # 3
+# scene.cycles.filter_width = 0.01
+bpy.context.scene.cycles.adaptive_threshold = 0
+scene.cycles.use_denoising = True
+scene.render.film_transparent = True
+
+bpy.context.preferences.addons["cycles"].preferences.get_devices()
+# Set the device_type
+bpy.context.preferences.addons["cycles"].preferences.compute_device_type = 'CUDA' # or "OPENCL"
+bpy.context.scene.cycles.tile_size = 8192
+
+
+# eevee = scene.eevee
+# eevee.use_soft_shadows = True
+# eevee.use_ssr = True
+# eevee.use_ssr_refraction = True
+# eevee.taa_render_samples = 64
+# eevee.use_gtao = True
+# eevee.gtao_distance = 1
+# eevee.use_volumetric_shadows = True
+# eevee.volumetric_tile_size = '2'
+# eevee.gi_diffuse_bounces = 1
+# eevee.gi_cubemap_resolution = '128'
+# eevee.gi_visibility_resolution = '16'
+# eevee.gi_irradiance_smoothing = 0
+
+
+# for depth & normal
+context.view_layer.use_pass_normal = True
+context.view_layer.use_pass_z = True
+context.scene.use_nodes = True
+
+
+tree = bpy.context.scene.node_tree
+nodes = bpy.context.scene.node_tree.nodes
+links = bpy.context.scene.node_tree.links
+
+# Clear default nodes
+for n in nodes:
+    nodes.remove(n)
+
+# # Create input render layer node.
+render_layers = nodes.new('CompositorNodeRLayers')
+
+scale_normal = nodes.new(type="CompositorNodeMixRGB")
+scale_normal.blend_type = 'MULTIPLY'
+scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1)
+links.new(render_layers.outputs['Normal'], scale_normal.inputs[1])
+bias_normal = nodes.new(type="CompositorNodeMixRGB")
+bias_normal.blend_type = 'ADD'
+bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0)
+links.new(scale_normal.outputs[0], bias_normal.inputs[1])
+normal_file_output = nodes.new(type="CompositorNodeOutputFile")
+normal_file_output.label = 'Normal Output'
+links.new(bias_normal.outputs[0], normal_file_output.inputs[0])
+
+normal_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
+normal_file_output.format.color_mode = "RGB"  # default is "BW"
+
+depth_file_output = nodes.new(type="CompositorNodeOutputFile")
+depth_file_output.label = 'Depth Output'
+links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0])
+depth_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
+depth_file_output.format.color_mode = "RGB"  # default is "BW"
+
+def prepare_depth_outputs():
+    tree = bpy.context.scene.node_tree
+    links = tree.links
+    render_node = tree.nodes['Render Layers']
+    depth_out_node = tree.nodes.new(type="CompositorNodeOutputFile")
+    depth_map_node = tree.nodes.new(type="CompositorNodeMapRange")
+    depth_out_node.base_path = ''
+    depth_out_node.format.file_format = 'OPEN_EXR'
+    depth_out_node.format.color_depth = '32'
+
+    depth_map_node.inputs[1].default_value = 0.54
+    depth_map_node.inputs[2].default_value = 1.96
+    depth_map_node.inputs[3].default_value = 0
+    depth_map_node.inputs[4].default_value = 1
+    depth_map_node.use_clamp = True
+    links.new(render_node.outputs[2],depth_map_node.inputs[0])
+    links.new(depth_map_node.outputs[0], depth_out_node.inputs[0])
+    return depth_out_node, depth_map_node
+
+depth_file_output, depth_map_node = prepare_depth_outputs()
+
+
+def exr_to_png(exr_path):
+    depth_path = exr_path.replace('.exr', '.png')
+    exr_image = OpenEXR.InputFile(exr_path)
+    dw = exr_image.header()['dataWindow']
+    (width, height) = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
+
+    def read_exr(s, width, height):
+        mat = np.fromstring(s, dtype=np.float32)
+        mat = mat.reshape(height, width)
+        return mat
+
+    dmap, _, _ = [read_exr(s, width, height) for s in exr_image.channels('BGR', Imath.PixelType(Imath.PixelType.FLOAT))]
+    dmap = np.clip(np.asarray(dmap,np.float64),a_max=1.0, a_min=0.0) * 65535
+    dmap = Image.fromarray(dmap.astype(np.uint16))
+    dmap.save(depth_path)
+    exr_image.close()
+    # os.system('rm {}'.format(exr_path))
+
+def extract_depth(directory):
+    fns = glob.glob(f'{directory}/*.exr')
+    for fn in fns: exr_to_png(fn)
+    os.system(f'rm {directory}/*.exr')
+
+def sample_point_on_sphere(radius: float) -> Tuple[float, float, float]:
+    theta = random.random() * 2 * math.pi
+    phi = math.acos(2 * random.random() - 1)
+    return (
+        radius * math.sin(phi) * math.cos(theta),
+        radius * math.sin(phi) * math.sin(theta),
+        radius * math.cos(phi),
+    )
+
+def sample_spherical(radius=3.0, maxz=3.0, minz=0.):
+    correct = False
+    while not correct:
+        vec = np.random.uniform(-1, 1, 3)
+        vec[2] = np.abs(vec[2])
+        vec = vec / np.linalg.norm(vec, axis=0) * radius
+        if maxz > vec[2] > minz:
+            correct = True
+    return vec
+
+def sample_spherical(radius_min=1.5, radius_max=2.0, maxz=1.6, minz=-0.75):
+    correct = False
+    while not correct:
+        vec = np.random.uniform(-1, 1, 3)
+#         vec[2] = np.abs(vec[2])
+        radius = np.random.uniform(radius_min, radius_max, 1)
+        vec = vec / np.linalg.norm(vec, axis=0) * radius[0]
+        if maxz > vec[2] > minz:
+            correct = True
+    return vec
+
+def randomize_camera():
+    elevation = random.uniform(0., 90.)
+    azimuth = random.uniform(0., 360)
+    distance = random.uniform(0.8, 1.6)
+    return set_camera_location(elevation, azimuth, distance)
+
+def set_camera_location(elevation, azimuth, distance):
+    # from https://blender.stackexchange.com/questions/18530/
+    x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
+    camera = bpy.data.objects["Camera"]
+    camera.location = x, y, z
+
+    direction = - camera.location
+    rot_quat = direction.to_track_quat('-Z', 'Y')
+    camera.rotation_euler = rot_quat.to_euler()
+    return camera
+
+def set_camera_mvdream(azimuth, elevation, distance):
+    # theta, phi = np.deg2rad(azimuth), np.deg2rad(elevation)
+    azimuth, elevation = np.deg2rad(azimuth), np.deg2rad(elevation)
+    point = (
+        distance * math.cos(azimuth) * math.cos(elevation),
+        distance * math.sin(azimuth) * math.cos(elevation),
+        distance * math.sin(elevation),
+    )
+    camera = bpy.data.objects["Camera"]
+    camera.location = point
+
+    direction = -camera.location
+    rot_quat = direction.to_track_quat('-Z', 'Y')
+    camera.rotation_euler = rot_quat.to_euler()
+    return camera
+
+def reset_scene() -> None:
+    """Resets the scene to a clean state.
+
+    Returns:
+        None
+    """
+    # delete everything that isn't part of a camera or a light
+    for obj in bpy.data.objects:
+        if obj.type not in {"CAMERA", "LIGHT"}:
+            bpy.data.objects.remove(obj, do_unlink=True)
+
+    # delete all the materials
+    for material in bpy.data.materials:
+        bpy.data.materials.remove(material, do_unlink=True)
+
+    # delete all the textures
+    for texture in bpy.data.textures:
+        bpy.data.textures.remove(texture, do_unlink=True)
+
+    # delete all the images
+    for image in bpy.data.images:
+        bpy.data.images.remove(image, do_unlink=True)
+def process_ply(obj):
+    # obj = bpy.context.selected_objects[0]
+
+    # create a new material
+    material = bpy.data.materials.new(name="VertexColors")
+    material.use_nodes = True
+    obj.data.materials.append(material)
+
+    # get the material's node tree
+    nodes = material.node_tree.nodes
+    links = material.node_tree.links
+
+    # remove the existing 'Principled BSDF' node
+    principled_bsdf_node = nodes.get("Principled BSDF")
+    if principled_bsdf_node:
+        nodes.remove(principled_bsdf_node)
+
+    # create a new 'Emission' node
+    emission_node = nodes.new(type="ShaderNodeEmission")
+    emission_node.location = 0, 0
+
+    # create an 'Attribute' node
+    attribute_node = nodes.new(type="ShaderNodeAttribute")
+    attribute_node.location = -300, 0
+    attribute_node.attribute_name = "Col"  # name of the vertex color attribute
+
+    # get the 'Material Output' node
+    output_node = nodes.get("Material Output")
+
+    # connect the nodes
+    links.new(attribute_node.outputs["Color"], emission_node.inputs["Color"])
+    links.new(emission_node.outputs["Emission"], output_node.inputs["Surface"])
+
+# # load the glb model
+# def load_object(object_path: str) -> None:
+
+#     if object_path.endswith(".glb"):
+#         bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
+#     elif object_path.endswith(".fbx"):
+#         bpy.ops.import_scene.fbx(filepath=object_path)
+#     elif object_path.endswith(".obj"):
+#         bpy.ops.import_scene.obj(filepath=object_path)
+#     elif object_path.endswith(".ply"):
+#         bpy.ops.import_mesh.ply(filepath=object_path)
+#         obj = bpy.context.selected_objects[0]
+#         obj.rotation_euler[0] = 1.5708
+#         # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
+#         process_ply(obj)
+#     else:
+#         raise ValueError(f"Unsupported file type: {object_path}")
+
+
+
+def scene_bbox(
+    single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
+) -> Tuple[Vector, Vector]:
+    """Returns the bounding box of the scene.
+
+    Taken from Shap-E rendering script
+    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
+
+    Args:
+        single_obj (Optional[bpy.types.Object], optional): If not None, only computes
+            the bounding box for the given object. Defaults to None.
+        ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
+            to False.
+
+    Raises:
+        RuntimeError: If there are no objects in the scene.
+
+    Returns:
+        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
+    """
+    bbox_min = (math.inf,) * 3
+    bbox_max = (-math.inf,) * 3
+    found = False
+    for obj in get_scene_meshes() if single_obj is None else [single_obj]:
+        found = True
+        for coord in obj.bound_box:
+            coord = Vector(coord)
+            if not ignore_matrix:
+                coord = obj.matrix_world @ coord
+            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
+            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
+
+    if not found:
+        raise RuntimeError("no objects in scene to compute bounding box for")
+
+    return Vector(bbox_min), Vector(bbox_max)
+
+
+def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
+    """Returns all root objects in the scene.
+
+    Yields:
+        Generator[bpy.types.Object, None, None]: Generator of all root objects in the
+            scene.
+    """
+    for obj in bpy.context.scene.objects.values():
+        if not obj.parent:
+            yield obj
+
+
+def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
+    """Returns all meshes in the scene.
+
+    Yields:
+        Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
+    """
+    for obj in bpy.context.scene.objects.values():
+        if isinstance(obj.data, (bpy.types.Mesh)):
+            yield obj
+
+
+# Build intrinsic camera parameters from Blender camera data
+#
+# See notes on this in
+# blender.stackexchange.com/questions/15102/what-is-blenders-camera-projection-matrix-model
+def get_calibration_matrix_K_from_blender(camd):
+    f_in_mm = camd.lens
+    scene = bpy.context.scene
+    resolution_x_in_px = scene.render.resolution_x
+    resolution_y_in_px = scene.render.resolution_y
+    scale = scene.render.resolution_percentage / 100
+    sensor_width_in_mm = camd.sensor_width
+    sensor_height_in_mm = camd.sensor_height
+    pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
+    if (camd.sensor_fit == 'VERTICAL'):
+        # the sensor height is fixed (sensor fit is horizontal),
+        # the sensor width is effectively changed with the pixel aspect ratio
+        s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio
+        s_v = resolution_y_in_px * scale / sensor_height_in_mm
+    else: # 'HORIZONTAL' and 'AUTO'
+        # the sensor width is fixed (sensor fit is horizontal),
+        # the sensor height is effectively changed with the pixel aspect ratio
+        pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
+        s_u = resolution_x_in_px * scale / sensor_width_in_mm
+        s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm
+
+    # Parameters of intrinsic calibration matrix K
+    alpha_u = f_in_mm * s_u
+    alpha_v = f_in_mm * s_v
+    u_0 = resolution_x_in_px * scale / 2
+    v_0 = resolution_y_in_px * scale / 2
+    skew = 0 # only use rectangular pixels
+
+    K = Matrix(
+        ((alpha_u, skew,    u_0),
+        (    0  , alpha_v, v_0),
+        (    0  , 0,        1 )))
+    return K
+
+
+def get_calibration_matrix_K_from_blender_for_ortho(camd, ortho_scale):
+    scene = bpy.context.scene
+    resolution_x_in_px = scene.render.resolution_x
+    resolution_y_in_px = scene.render.resolution_y
+    scale = scene.render.resolution_percentage / 100
+    pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
+
+    fx = resolution_x_in_px / ortho_scale
+    fy = resolution_y_in_px / ortho_scale / pixel_aspect_ratio
+
+    cx = resolution_x_in_px / 2
+    cy = resolution_y_in_px / 2
+
+    K = Matrix(
+        ((fx, 0, cx),
+        (0, fy, cy),
+        (0 , 0, 1)))
+    return K
+
+
+def get_3x4_RT_matrix_from_blender(cam):
+    bpy.context.view_layer.update()
+    location, rotation = cam.matrix_world.decompose()[0:2]
+    R = np.asarray(rotation.to_matrix())
+    t = np.asarray(location)
+
+    cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
+    R = R.T
+    t = -R @ t
+    R_world2cv = cam_rec @ R
+    t_world2cv = cam_rec @ t
+
+    RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
+    return RT
+
| 495 | 
            +
            def delete_invisible_objects() -> None:
         | 
| 496 | 
            +
                """Deletes all invisible objects in the scene.
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                Returns:
         | 
| 499 | 
            +
                    None
         | 
| 500 | 
            +
                """
         | 
| 501 | 
            +
                bpy.ops.object.select_all(action="DESELECT")
         | 
| 502 | 
            +
                for obj in scene.objects:
         | 
| 503 | 
            +
                    if obj.hide_viewport or obj.hide_render:
         | 
| 504 | 
            +
                        obj.hide_viewport = False
         | 
| 505 | 
            +
                        obj.hide_render = False
         | 
| 506 | 
            +
                        obj.hide_select = False
         | 
| 507 | 
            +
                        obj.select_set(True)
         | 
| 508 | 
            +
                bpy.ops.object.delete()
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                # Delete invisible collections
         | 
| 511 | 
            +
                invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
         | 
| 512 | 
            +
                for col in invisible_collections:
         | 
| 513 | 
            +
                    bpy.data.collections.remove(col)
         | 
| 514 | 
            +
             | 
| 515 | 
            +
             | 
| 516 | 
            +
            def normalize_scene():
         | 
| 517 | 
            +
                """Normalizes the scene by scaling and translating it to fit in a unit cube centered
         | 
| 518 | 
            +
                at the origin.
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                Mostly taken from the Point-E / Shap-E rendering script
         | 
| 521 | 
            +
                (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
         | 
| 522 | 
            +
                but fixed to handle multiple root objects (see the bug report:
         | 
| 523 | 
            +
                https://github.com/openai/shap-e/pull/60).
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                Returns:
         | 
| 526 | 
            +
                    None
         | 
| 527 | 
            +
                """
         | 
| 528 | 
            +
                if len(list(get_scene_root_objects())) > 1:
         | 
| 529 | 
            +
                    print('scene has more than one root object; grouping them under a single empty')
         | 
| 530 | 
            +
                    # create an empty object to be used as a parent for all root objects
         | 
| 531 | 
            +
                    parent_empty = bpy.data.objects.new("ParentEmpty", None)
         | 
| 532 | 
            +
                    bpy.context.scene.collection.objects.link(parent_empty)
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    # parent all root objects to the empty object
         | 
| 535 | 
            +
                    for obj in get_scene_root_objects():
         | 
| 536 | 
            +
                        if obj != parent_empty:
         | 
| 537 | 
            +
                            obj.parent = parent_empty
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                bbox_min, bbox_max = scene_bbox()
         | 
| 540 | 
            +
                dxyz = bbox_max - bbox_min
         | 
| 541 | 
            +
                dist = np.sqrt(dxyz[0]**2+ dxyz[1]**2+dxyz[2]**2)
         | 
| 542 | 
            +
                scale = 1 / dist
         | 
| 543 | 
            +
                for obj in get_scene_root_objects():
         | 
| 544 | 
            +
                    obj.scale = obj.scale * scale
         | 
| 545 | 
            +
             | 
| 546 | 
            +
                # Apply scale to matrix_world.
         | 
| 547 | 
            +
                bpy.context.view_layer.update()
         | 
| 548 | 
            +
                bbox_min, bbox_max = scene_bbox()
         | 
| 549 | 
            +
                offset = -(bbox_min + bbox_max) / 2
         | 
| 550 | 
            +
                for obj in get_scene_root_objects():
         | 
| 551 | 
            +
                    obj.matrix_world.translation += offset
         | 
| 552 | 
            +
                bpy.ops.object.select_all(action="DESELECT")
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                # unparent the camera
         | 
| 555 | 
            +
                bpy.data.objects["Camera"].parent = None
         | 
| 556 | 
            +
                return scale, offset
         | 
| 557 | 
            +
             | 
| 558 | 
            +
            def download_object(object_url: str) -> str:
         | 
| 559 | 
            +
                """Download the object and return the path."""
         | 
| 560 | 
            +
                # uid = uuid.uuid4()
         | 
| 561 | 
            +
                uid = object_url.split("/")[-1].split(".")[0]
         | 
| 562 | 
            +
                tmp_local_path = os.path.join("tmp-objects", f"{uid}.glb" + ".tmp")
         | 
| 563 | 
            +
                local_path = os.path.join("tmp-objects", f"{uid}.glb")
         | 
| 564 | 
            +
                # wget the file and put it in local_path
         | 
| 565 | 
            +
                os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True)
         | 
| 566 | 
            +
                urllib.request.urlretrieve(object_url, tmp_local_path)
         | 
| 567 | 
            +
                os.rename(tmp_local_path, local_path)
         | 
| 568 | 
            +
                # get the absolute path
         | 
| 569 | 
            +
                local_path = os.path.abspath(local_path)
         | 
| 570 | 
            +
                return local_path
         | 
| 571 | 
            +
             | 
| 572 | 
            +
             | 
| 573 | 
            +
            def render_and_save(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
         | 
| 574 | 
            +
                # print(view_id)
         | 
| 575 | 
            +
                # render the image
         | 
| 576 | 
            +
                render_path = os.path.join(args.output_dir, 'image', f"{view_id:03d}.png")
         | 
| 577 | 
            +
                scene.render.filepath = render_path
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                if not ortho:
         | 
| 580 | 
            +
                    cam.data.lens = len_val
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                depth_map_node.inputs[1].default_value = distance - 1
         | 
| 583 | 
            +
                depth_map_node.inputs[2].default_value = distance + 1
         | 
| 584 | 
            +
                depth_file_output.base_path = os.path.join(args.output_dir, object_uid, 'depth')
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                depth_file_output.file_slots[0].path = f"{view_id:03d}"
         | 
| 587 | 
            +
                normal_file_output.file_slots[0].path = f"{view_id:03d}"
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                if not os.path.exists(os.path.join(args.output_dir,  'normal', f"{view_id:03d}.png")):
         | 
| 590 | 
            +
                    bpy.ops.render.render(write_still=True)
         | 
| 591 | 
            +
             | 
| 592 | 
            +
               
         | 
| 593 | 
            +
                if os.path.exists(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr")):
         | 
| 594 | 
            +
                    os.rename(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr"),
         | 
| 595 | 
            +
                              os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}.exr"))
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                if os.path.exists(os.path.join(args.output_dir,  'normal', f"{view_id:03d}0001.exr")):
         | 
| 598 | 
            +
                    normal = cv2.imread(os.path.join(args.output_dir,  'normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
         | 
| 599 | 
            +
                    normal_uint16 = (normal * 65535).astype(np.uint16)
         | 
| 600 | 
            +
                    cv2.imwrite(os.path.join(args.output_dir,  'normal', f"{view_id:03d}.png"), normal_uint16)
         | 
| 601 | 
            +
                    os.remove(os.path.join(args.output_dir,  'normal', f"{view_id:03d}0001.exr"))
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                # save camera KRT matrix
         | 
| 604 | 
            +
                if ortho:
         | 
| 605 | 
            +
                    K = get_calibration_matrix_K_from_blender_for_ortho(cam.data, ortho_scale=cam.data.ortho_scale)
         | 
| 606 | 
            +
                else:
         | 
| 607 | 
            +
                    K = get_calibration_matrix_K_from_blender(cam.data)
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                RT = get_3x4_RT_matrix_from_blender(cam)
         | 
| 610 | 
            +
                para_path = os.path.join(args.output_dir, 'camera', f"{view_id:03d}.npy")
         | 
| 611 | 
            +
                # np.save(RT_path, RT)
         | 
| 612 | 
            +
                paras = {}
         | 
| 613 | 
            +
                paras['intrinsic'] = np.array(K, np.float32)
         | 
| 614 | 
            +
                paras['extrinsic'] = np.array(RT, np.float32)
         | 
| 615 | 
            +
                paras['fov'] = cam.data.angle
         | 
| 616 | 
            +
                paras['azimuth'] = azimuth
         | 
| 617 | 
            +
                paras['elevation'] = elevation
         | 
| 618 | 
            +
                paras['distance'] = distance
         | 
| 619 | 
            +
                paras['focal'] = cam.data.lens
         | 
| 620 | 
            +
                paras['sensor_width'] = cam.data.sensor_width
         | 
| 621 | 
            +
                paras['near'] = distance - 1
         | 
| 622 | 
            +
                paras['far'] = distance + 1
         | 
| 623 | 
            +
                paras['camera'] = 'persp' if not ortho else 'ortho'
         | 
| 624 | 
            +
                np.save(para_path, paras)
         | 
| 625 | 
            +
             | 
| 626 | 
            +
            def render_and_save_smpl(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
         | 
| 627 | 
            +
             | 
| 628 | 
            +
             | 
| 629 | 
            +
                if not ortho:
         | 
| 630 | 
            +
                    cam.data.lens = len_val
         | 
| 631 | 
            +
             | 
| 632 | 
            +
                render_path = os.path.join(args.output_dir, 'smpl_image', f"{view_id:03d}.png")
         | 
| 633 | 
            +
                scene.render.filepath = render_path
         | 
| 634 | 
            +
                
         | 
| 635 | 
            +
                normal_file_output.file_slots[0].path = f"{view_id:03d}"
         | 
| 636 | 
            +
                if not os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png")):
         | 
| 637 | 
            +
                    bpy.ops.render.render(write_still=True)
         | 
| 638 | 
            +
             | 
| 639 | 
            +
                if os.path.exists(os.path.join(args.output_dir,  'smpl_normal', f"{view_id:03d}0001.exr")):
         | 
| 640 | 
            +
                    normal = cv2.imread(os.path.join(args.output_dir,  'smpl_normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
         | 
| 641 | 
            +
                    normal_uint16 = (normal * 65535).astype(np.uint16)
         | 
| 642 | 
            +
                    cv2.imwrite(os.path.join(args.output_dir,  'smpl_normal', f"{view_id:03d}.png"), normal_uint16)
         | 
| 643 | 
            +
                    os.remove(os.path.join(args.output_dir,  'smpl_normal', f"{view_id:03d}0001.exr"))
         | 
| 644 | 
            +
             | 
| 645 | 
            +
             | 
| 646 | 
            +
             | 
| 647 | 
            +
            def scene_meshes():
         | 
| 648 | 
            +
                for obj in bpy.context.scene.objects.values():
         | 
| 649 | 
            +
                    if isinstance(obj.data, (bpy.types.Mesh)):
         | 
| 650 | 
            +
                        yield obj
         | 
| 651 | 
            +
             | 
| 652 | 
            +
            def load_object(object_path: str) -> None:
         | 
| 653 | 
            +
                """Loads a glb model into the scene."""
         | 
| 654 | 
            +
                if object_path.endswith(".glb"):
         | 
| 655 | 
            +
                    bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
         | 
| 656 | 
            +
                elif object_path.endswith(".fbx"):
         | 
| 657 | 
            +
                    bpy.ops.import_scene.fbx(filepath=object_path)
         | 
| 658 | 
            +
                elif object_path.endswith(".obj"):
         | 
| 659 | 
            +
                    bpy.ops.import_scene.obj(filepath=object_path)
         | 
| 660 | 
            +
                    obj = bpy.context.selected_objects[0]
         | 
| 661 | 
            +
                    obj.rotation_euler[0] = 6.28319
         | 
| 662 | 
            +
                    # obj.rotation_euler[2] = 1.5708
         | 
| 663 | 
            +
                elif object_path.endswith(".ply"):
         | 
| 664 | 
            +
                    bpy.ops.import_mesh.ply(filepath=object_path)
         | 
| 665 | 
            +
                    obj = bpy.context.selected_objects[0]
         | 
| 666 | 
            +
                    obj.rotation_euler[0] = 1.5708
         | 
| 667 | 
            +
                    obj.rotation_euler[2] = 1.5708
         | 
| 668 | 
            +
                    # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
         | 
| 669 | 
            +
                    process_ply(obj)
         | 
| 670 | 
            +
                else:
         | 
| 671 | 
            +
                    raise ValueError(f"Unsupported file type: {object_path}")
         | 
| 672 | 
            +
             | 
| 673 | 
            +
            def save_images(object_file: str, smpl_file: str) -> None:
         | 
| 674 | 
            +
                """Saves rendered images of the object in the scene."""
         | 
| 675 | 
            +
                object_uid = '' # os.path.basename(object_file).split(".")[0]
         | 
| 676 | 
            +
            #     # if we have already rendered this object, skip it
         | 
| 677 | 
            +
                if os.path.exists(os.path.join(args.output_dir,  'meta.npy')): return
         | 
| 678 | 
            +
                os.makedirs(args.output_dir, exist_ok=True)
         | 
| 679 | 
            +
                os.makedirs(os.path.join(args.output_dir,  'camera'), exist_ok=True)
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                reset_scene()
         | 
| 682 | 
            +
                load_object(object_file)
         | 
| 683 | 
            +
                
         | 
| 684 | 
            +
                lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
         | 
| 685 | 
            +
                for light in lights:
         | 
| 686 | 
            +
                    bpy.data.objects.remove(light, do_unlink=True)
         | 
| 687 | 
            +
                
         | 
| 688 | 
            +
            #     bproc.init()
         | 
| 689 | 
            +
                
         | 
| 690 | 
            +
                world_tree = bpy.context.scene.world.node_tree
         | 
| 691 | 
            +
                back_node = world_tree.nodes['Background']
         | 
| 692 | 
            +
                env_light = 0.5
         | 
| 693 | 
            +
                back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0])
         | 
| 694 | 
            +
                back_node.inputs['Strength'].default_value = 1.0
         | 
| 695 | 
            +
                
         | 
| 696 | 
            +
                #Make light just directional, disable shadows.
         | 
| 697 | 
            +
                light_data = bpy.data.lights.new(name=f'Light', type='SUN')
         | 
| 698 | 
            +
                light = bpy.data.objects.new(name=f'Light', object_data=light_data)
         | 
| 699 | 
            +
                bpy.context.collection.objects.link(light)
         | 
| 700 | 
            +
                light = bpy.data.lights['Light']
         | 
| 701 | 
            +
                light.use_shadow = False
         | 
| 702 | 
            +
                # Possibly disable specular shading:
         | 
| 703 | 
            +
                light.specular_factor = 1.0
         | 
| 704 | 
            +
                light.energy = 5.0
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                #Add another light source so stuff facing away from light is not completely dark
         | 
| 707 | 
            +
                light_data = bpy.data.lights.new(name=f'Light2', type='SUN')
         | 
| 708 | 
            +
                light = bpy.data.objects.new(name=f'Light2', object_data=light_data)
         | 
| 709 | 
            +
                bpy.context.collection.objects.link(light)
         | 
| 710 | 
            +
                light2 = bpy.data.lights['Light2']
         | 
| 711 | 
            +
                light2.use_shadow = False
         | 
| 712 | 
            +
                light2.specular_factor = 1.0
         | 
| 713 | 
            +
                light2.energy = 3 #0.015
         | 
| 714 | 
            +
                bpy.data.objects['Light2'].rotation_euler = bpy.data.objects['Light2'].rotation_euler
         | 
| 715 | 
            +
                bpy.data.objects['Light2'].rotation_euler[0] += 180
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                #Add another light source so stuff facing away from light is not completely dark
         | 
| 718 | 
            +
                light_data = bpy.data.lights.new(name=f'Light3', type='SUN')
         | 
| 719 | 
            +
                light = bpy.data.objects.new(name=f'Light3', object_data=light_data)
         | 
| 720 | 
            +
                bpy.context.collection.objects.link(light)
         | 
| 721 | 
            +
                light3 = bpy.data.lights['Light3']
         | 
| 722 | 
            +
                light3.use_shadow = False
         | 
| 723 | 
            +
                light3.specular_factor = 1.0
         | 
| 724 | 
            +
                light3.energy = 3 #0.015
         | 
| 725 | 
            +
                bpy.data.objects['Light3'].rotation_euler = bpy.data.objects['Light3'].rotation_euler
         | 
| 726 | 
            +
                bpy.data.objects['Light3'].rotation_euler[0] += 90
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                #Add another light source so stuff facing away from light is not completely dark
         | 
| 729 | 
            +
                light_data = bpy.data.lights.new(name=f'Light4', type='SUN')
         | 
| 730 | 
            +
                light = bpy.data.objects.new(name=f'Light4', object_data=light_data)
         | 
| 731 | 
            +
                bpy.context.collection.objects.link(light)
         | 
| 732 | 
            +
                light4 = bpy.data.lights['Light4']
         | 
| 733 | 
            +
                light4.use_shadow = False
         | 
| 734 | 
            +
                light4.specular_factor = 1.0
         | 
| 735 | 
            +
                light4.energy = 3 #0.015
         | 
| 736 | 
            +
                bpy.data.objects['Light4'].rotation_euler = bpy.data.objects['Light4'].rotation_euler
         | 
| 737 | 
            +
                bpy.data.objects['Light4'].rotation_euler[0] += -90
         | 
| 738 | 
            +
                
         | 
| 739 | 
            +
                scale, offset = normalize_scene()
         | 
| 740 | 
            +
                
         | 
| 741 | 
            +
             | 
| 742 | 
            +
                try:
         | 
| 743 | 
            +
                    # some objects' normals are affected by textures
         | 
| 744 | 
            +
                    mesh_objects = [obj for obj in scene_meshes()]
         | 
| 745 | 
            +
                    main_bsdf_name = 'BsdfPrincipled'
         | 
| 746 | 
            +
                    normal_name = 'Normal'
         | 
| 747 | 
            +
                    for obj in mesh_objects:
         | 
| 748 | 
            +
                        for mat in obj.data.materials:
         | 
| 749 | 
            +
                            for node in mat.node_tree.nodes:
         | 
| 750 | 
            +
                                if main_bsdf_name in node.bl_idname:
         | 
| 751 | 
            +
                                    principled_bsdf = node
         | 
| 752 | 
            +
                                    # remove links; we don't want to add normal textures
         | 
| 753 | 
            +
                                    if principled_bsdf.inputs[normal_name].links:
         | 
| 754 | 
            +
                                        mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
         | 
| 755 | 
            +
                except Exception:
         | 
| 756 | 
            +
                    print("don't know why")
         | 
| 757 | 
            +
                # create an empty object to track
         | 
| 758 | 
            +
                empty = bpy.data.objects.new("Empty", None)
         | 
| 759 | 
            +
                scene.collection.objects.link(empty)
         | 
| 760 | 
            +
                cam_constraint.target = empty
         | 
| 761 | 
            +
             | 
| 762 | 
            +
                subject_width = 1.0
         | 
| 763 | 
            +
                
         | 
| 764 | 
            +
                normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'normal')
         | 
| 765 | 
            +
                for i in range(args.num_images):
         | 
| 766 | 
            +
                    # change the camera to orthogonal
         | 
| 767 | 
            +
                    cam.data.type = 'ORTHO'
         | 
| 768 | 
            +
                    cam.data.ortho_scale = subject_width
         | 
| 769 | 
            +
                    distance = 1.5
         | 
| 770 | 
            +
                    azimuth = i * 360 / args.num_images
         | 
| 771 | 
            +
                    bpy.context.view_layer.update()
         | 
| 772 | 
            +
                    set_camera_mvdream(azimuth, 0, distance)
         | 
| 773 | 
            +
                    render_and_save(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
         | 
| 774 | 
            +
                extract_depth(os.path.join(args.output_dir, object_uid, 'depth'))
         | 
| 775 | 
            +
            #     ####  smpl
         | 
| 776 | 
            +
                reset_scene()
         | 
| 777 | 
            +
                load_object(smpl_file)
         | 
| 778 | 
            +
             | 
| 779 | 
            +
                lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
         | 
| 780 | 
            +
                for light in lights:
         | 
| 781 | 
            +
                    bpy.data.objects.remove(light, do_unlink=True)
         | 
| 782 | 
            +
                
         | 
| 783 | 
            +
                scale, offset = normalize_scene()
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                try:
         | 
| 786 | 
            +
                    # some objects' normals are affected by textures
         | 
| 787 | 
            +
                    mesh_objects = [obj for obj in scene_meshes()]
         | 
| 788 | 
            +
                    main_bsdf_name = 'BsdfPrincipled'
         | 
| 789 | 
            +
                    normal_name = 'Normal'
         | 
| 790 | 
            +
                    for obj in mesh_objects:
         | 
| 791 | 
            +
                        for mat in obj.data.materials:
         | 
| 792 | 
            +
                            for node in mat.node_tree.nodes:
         | 
| 793 | 
            +
                                if main_bsdf_name in node.bl_idname:
         | 
| 794 | 
            +
                                    principled_bsdf = node
         | 
| 795 | 
            +
                                    # remove links; we don't want to add normal textures
         | 
| 796 | 
            +
                                    if principled_bsdf.inputs[normal_name].links:
         | 
| 797 | 
            +
                                        mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
         | 
| 798 | 
            +
                except Exception:
         | 
| 799 | 
            +
                    print("don't know why")
         | 
| 800 | 
            +
                # create an empty object to track
         | 
| 801 | 
            +
                empty = bpy.data.objects.new("Empty", None)
         | 
| 802 | 
            +
                scene.collection.objects.link(empty)
         | 
| 803 | 
            +
                cam_constraint.target = empty
         | 
| 804 | 
            +
             | 
| 805 | 
            +
                subject_width = 1.0
         | 
| 806 | 
            +
                
         | 
| 807 | 
            +
                normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'smpl_normal')
         | 
| 808 | 
            +
                for i in range(args.num_images):
         | 
| 809 | 
            +
                    # change the camera to orthogonal
         | 
| 810 | 
            +
                    cam.data.type = 'ORTHO'
         | 
| 811 | 
            +
                    cam.data.ortho_scale = subject_width
         | 
| 812 | 
            +
                    distance = 1.5
         | 
| 813 | 
            +
                    azimuth = i * 360 / args.num_images
         | 
| 814 | 
            +
                    bpy.context.view_layer.update()
         | 
| 815 | 
            +
                    set_camera_mvdream(azimuth, 0, distance)
         | 
| 816 | 
            +
                    render_and_save_smpl(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
         | 
| 817 | 
            +
                    
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                np.save(os.path.join(args.output_dir, object_uid, 'meta.npy'), np.asarray([scale, offset[0], offset[1], offset[2]], np.float32))
         | 
| 820 | 
            +
             | 
| 821 | 
            +
             | 
| 822 | 
            +
            if __name__ == "__main__":
         | 
| 823 | 
            +
                try:
         | 
| 824 | 
            +
                    start_i = time.time()
         | 
| 825 | 
            +
                    if args.object_path.startswith("http"):
         | 
| 826 | 
            +
                        local_path = download_object(args.object_path)
         | 
| 827 | 
            +
                    else:
         | 
| 828 | 
            +
                        local_path = args.object_path
         | 
| 829 | 
            +
                    save_images(local_path, args.smpl_path)
         | 
| 830 | 
            +
                    end_i = time.time()
         | 
| 831 | 
            +
                    print("Finished", local_path, "in", end_i - start_i, "seconds")
         | 
| 832 | 
            +
                    # delete the object if it was downloaded
         | 
| 833 | 
            +
                    if args.object_path.startswith("http"):
         | 
| 834 | 
            +
                        os.remove(local_path)
         | 
| 835 | 
            +
                except Exception as e:
         | 
| 836 | 
            +
                    print("Failed to render", args.object_path)
         | 
| 837 | 
            +
                    print(e)
         | 
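Note on consuming these outputs: each rendered view stores its camera in camera/NNN.npy as a pickled dict holding the intrinsic matrix K and the 3x4 world-to-camera matrix RT built by the functions above. A minimal sketch (not part of the repository; the output path and the test point are placeholders) of reading one of these files back and projecting a point of the normalized scene into pixel coordinates:

    import numpy as np

    params = np.load("output/camera/000.npy", allow_pickle=True).item()  # dict written by render_and_save()
    K = params["intrinsic"]    # 3x3; for the ortho views fx = resolution_x / ortho_scale
    RT = params["extrinsic"]   # 3x4 world-to-camera, OpenCV convention

    p_world = np.array([0.0, 0.0, 0.0, 1.0])   # homogeneous point in the normalized scene
    p_cam = RT @ p_world                        # camera-space coordinates

    if params["camera"] == "ortho":
        # orthographic: pixel coordinates depend only on x/y, no perspective divide
        u, v, _ = K @ np.array([p_cam[0], p_cam[1], 1.0])
    else:
        # perspective: pinhole projection followed by the divide by depth
        u, v, w = K @ p_cam
        u, v = u / w, v / w
    print(u, v)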
    	
        blender/check_render.py
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from tqdm import tqdm
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            from icecream import ic
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def check_render(dataset, st=None, end=None):
         | 
| 8 | 
            +
                total_lists = []
         | 
| 9 | 
            +
                with open(dataset+'.json', 'r') as f:
         | 
| 10 | 
            +
                    glb_list = json.load(f)
         | 
| 11 | 
            +
                    for x in glb_list:
         | 
| 12 | 
            +
                        total_lists.append(x.split('/')[-2] )
         | 
| 13 | 
            +
                
         | 
| 14 | 
            +
                if st is not None:
         | 
| 15 | 
            +
                    end = min(end, len(total_lists))
         | 
| 16 | 
            +
                    total_lists = total_lists[st:end]
         | 
| 17 | 
            +
                    glb_list = glb_list[st:end]
         | 
| 18 | 
            +
                
         | 
| 19 | 
            +
                save_dir = '/data/lipeng/human_8view_with_smplx/'+dataset
         | 
| 20 | 
            +
                unrendered = set(total_lists) - set(os.listdir(save_dir))
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                num_finish = 0
         | 
| 23 | 
            +
                num_failed = len(unrendered)
         | 
| 24 | 
            +
                failed_case = []
         | 
| 25 | 
            +
                for case in os.listdir(save_dir):
         | 
| 26 | 
            +
                    if not os.path.exists(os.path.join(save_dir, case, 'smpl_normal', '007.png')):                
         | 
| 27 | 
            +
                        failed_case.append(case)
         | 
| 28 | 
            +
                        num_failed += 1
         | 
| 29 | 
            +
                    else:
         | 
| 30 | 
            +
                        num_finish += 1
         | 
| 31 | 
            +
                ic(num_failed)
         | 
| 32 | 
            +
                ic(num_finish)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
                need_render = []
         | 
| 36 | 
            +
                for full_path in glb_list:
         | 
| 37 | 
            +
                    for case in failed_case:
         | 
| 38 | 
            +
                        if case in full_path:
         | 
| 39 | 
            +
                            need_render.append(full_path)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                with open('need_render.json', 'w') as f:
         | 
| 42 | 
            +
                    json.dump(need_render, f, indent=4)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            if __name__ == '__main__':
         | 
| 45 | 
            +
                dataset = 'THuman2.1'
         | 
| 46 | 
            +
                check_render(dataset)
         | 
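check_render.py only probes for the last SMPL-X normal map (smpl_normal/007.png). A stricter per-case check against the layout written by blender_render_human_ortho.py is sketched below; the case path and the 8-view count are assumptions (implied by the 007.png probe), and the depth buffers are skipped because they are post-processed by extract_depth:

    import os

    case_dir = "/data/lipeng/human_8view_with_smplx/THuman2.1/1618"  # placeholder case folder
    num_views = 8                                                    # assumed: views 000..007

    missing = []
    for i in range(num_views):
        for sub, ext in [("image", "png"), ("normal", "png"), ("camera", "npy"),
                         ("smpl_image", "png"), ("smpl_normal", "png")]:
            path = os.path.join(case_dir, sub, f"{i:03d}.{ext}")
            if not os.path.exists(path):
                missing.append(path)
    if not os.path.exists(os.path.join(case_dir, "meta.npy")):
        missing.append(os.path.join(case_dir, "meta.npy"))
    print(f"{len(missing)} files missing")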
    	
        blender/count.py
    ADDED
    
    | @@ -0,0 +1,44 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            def find_files(directory, extensions):
         | 
| 4 | 
            +
                results = []
         | 
| 5 | 
            +
                for foldername, subfolders, filenames in os.walk(directory):
         | 
| 6 | 
            +
                    for filename in filenames:
         | 
| 7 | 
            +
                        if filename.endswith(extensions):
         | 
| 8 | 
            +
                            file_path = os.path.abspath(os.path.join(foldername, filename))
         | 
| 9 | 
            +
                            results.append(file_path)
         | 
| 10 | 
            +
                return results
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def count_customhumans(root):
         | 
| 13 | 
            +
                directory_path = ['CustomHumans/mesh']
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                extensions = ('.ply', '.obj')
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                lists = []
         | 
| 18 | 
            +
                for dataset_path in directory_path:
         | 
| 19 | 
            +
                    dir = os.path.join(root, dataset_path)
         | 
| 20 | 
            +
                    file_paths = find_files(dir, extensions)
         | 
| 21 | 
            +
                    # import pdb;pdb.set_trace()
         | 
| 22 | 
            +
                    dataset_name = dataset_path.split('/')[0]
         | 
| 23 | 
            +
                    for file_path in file_paths:
         | 
| 24 | 
            +
                        lists.append(file_path.replace(root, ""))
         | 
| 25 | 
            +
                with open(f'{dataset_name}.json', 'w') as f:
         | 
| 26 | 
            +
                    json.dump(lists, f, indent=4)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def count_thuman21(root):
         | 
| 29 | 
            +
                directory_path = ['THuman2.1/mesh']
         | 
| 30 | 
            +
                extensions = ('.ply', '.obj')
         | 
| 31 | 
            +
                lists = []
         | 
| 32 | 
            +
                for dataset_path in directory_path:
         | 
| 33 | 
            +
                    dir = os.path.join(root, dataset_path)
         | 
| 34 | 
            +
                    file_paths = find_files(dir, extensions)
         | 
| 35 | 
            +
                    dataset_name = dataset_path.split('/')[0]
         | 
| 36 | 
            +
                    for file_path in file_paths:
         | 
| 37 | 
            +
                        lists.append(file_path.replace(root, ""))
         | 
| 38 | 
            +
                with open(f'{dataset_name}.json', 'w') as f:
         | 
| 39 | 
            +
                    json.dump(lists, f, indent=4)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            if __name__ == '__main__':
         | 
| 42 | 
            +
                root = '/data/lipeng/human_scan/'  
         | 
| 43 | 
            +
                # count_customhumans(root)
         | 
| 44 | 
            +
                count_thuman21(root)
         | 
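count.py stores paths relative to the scan root, and the downstream scripts recover the case identifier from the second-to-last path component. A tiny illustration of that mapping, as applied by parse_obj_list in distribute.py below (the entry and save_dir values are examples):

    import os

    entry = "THuman2.1/mesh/1618/1618.obj"   # example element of THuman2.1.json
    save_dir = "/data/lipeng/human_8view"    # default --save_dir of distribute.py

    case_id = entry.split("/")[-2]           # "1618"
    print(os.path.join(save_dir, "THuman2.1", case_id))
    # -> /data/lipeng/human_8view/THuman2.1/1618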
    	
        blender/distribute.py
    ADDED
    
    | @@ -0,0 +1,149 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import glob
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import multiprocessing
         | 
| 4 | 
            +
            import shutil
         | 
| 5 | 
            +
            import subprocess
         | 
| 6 | 
            +
            import time
         | 
| 7 | 
            +
            from dataclasses import dataclass
         | 
| 8 | 
            +
            from typing import Optional
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import boto3
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            from glob import glob
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import argparse
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            parser = argparse.ArgumentParser(description='distributed rendering')
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            parser.add_argument('--workers_per_gpu', type=int, default=10,
         | 
| 21 | 
            +
                                help='number of workers per gpu.')
         | 
| 22 | 
            +
            parser.add_argument('--input_models_path', type=str, default='/data/lipeng/human_scan/',
         | 
| 23 | 
            +
                                help='Path to a json file containing a list of 3D object files.')
         | 
| 24 | 
            +
            parser.add_argument('--num_gpus', type=int, default=-1,
         | 
| 25 | 
            +
                                help='number of gpus to use. -1 means all available gpus.')
         | 
| 26 | 
            +
            parser.add_argument('--gpu_list',nargs='+', type=int, 
         | 
| 27 | 
            +
                                help='the available gpus')
         | 
| 28 | 
            +
            parser.add_argument('--resolution', type=int, default=512,
         | 
| 29 | 
            +
                                help='rendering resolution in pixels')
         | 
| 30 | 
            +
            parser.add_argument('--random_images', type=int, default=0)
         | 
| 31 | 
            +
            parser.add_argument('--start_i', type=int, default=0,
         | 
| 32 | 
            +
                                help='the index of first object to be rendered.')
         | 
| 33 | 
            +
            parser.add_argument('--end_i', type=int, default=-1,
         | 
| 34 | 
            +
                                help='the index of the last object to be rendered.')
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            parser.add_argument('--data_dir', type=str, default='/data/lipeng/human_scan/',
         | 
| 37 | 
            +
                                help='Root directory that contains the scan meshes.')
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            parser.add_argument('--json_path', type=str, default='2K2K.json')
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            parser.add_argument('--save_dir', type=str, default='/data/lipeng/human_8view',
         | 
| 42 | 
            +
                                help='Directory where the rendered multi-view images are written.')
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            parser.add_argument('--ortho_scale', type=float, default=1.,
         | 
| 45 | 
            +
                                help='ortho rendering usage; how large the object is')
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            args = parser.parse_args()
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def parse_obj_list(xs):
         | 
| 51 | 
            +
                cases = []
         | 
| 52 | 
            +
                # print(xs[:2])
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                for x in xs:
         | 
| 55 | 
            +
                    if 'THuman3.0' in x:
         | 
| 56 | 
            +
                        # print(apath)
         | 
| 57 | 
            +
                        splits = x.split('/')
         | 
| 58 | 
            +
                        x = os.path.join('THuman3.0', splits[-2])
         | 
| 59 | 
            +
                    elif 'THuman2.1' in x:
         | 
| 60 | 
            +
                        splits = x.split('/')
         | 
| 61 | 
            +
                        x = os.path.join('THuman2.1', splits[-2])
         | 
| 62 | 
            +
                    elif 'CustomHumans' in x:
         | 
| 63 | 
            +
                        splits = x.split('/')
         | 
| 64 | 
            +
                        x = os.path.join('CustomHumans', splits[-2])
         | 
| 65 | 
            +
                    elif '1M' in x:
         | 
| 66 | 
            +
                        splits = x.split('/')
         | 
| 67 | 
            +
                        x = os.path.join('2K2K', splits[-2])
         | 
| 68 | 
            +
                    elif 'realistic_8k_model' in x:
         | 
| 69 | 
            +
                        splits = x.split('/')
         | 
| 70 | 
            +
                        x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
         | 
| 71 | 
            +
                    cases.append(f'{args.save_dir}/{x}') 
         | 
| 72 | 
            +
                return  cases
         | 
| 73 | 
            +
                 
         | 
| 74 | 
            +
             | 
| 75 | 
            +
            with open(args.json_path, 'r') as f:
         | 
| 76 | 
            +
                glb_list = json.load(f)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            # glb_list = ['THuman2.1/mesh/1618/1618.obj']
         | 
| 79 | 
            +
            # glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
         | 
| 80 | 
            +
            # glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj'] 
         | 
| 81 | 
            +
            # glb_list = ['1M/01968/01968.ply', '1M/00103/00103.ply']
         | 
| 82 | 
            +
            # glb_list = ['realistic_8k_model/01aab099a2fe4af7be120110a385105d.glb']
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            total_num_glbs = len(glb_list)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def worker(
         | 
| 89 | 
            +
                queue: multiprocessing.JoinableQueue,
         | 
| 90 | 
            +
                count: multiprocessing.Value,
         | 
| 91 | 
            +
                gpu: int,
         | 
| 92 | 
            +
                s3: Optional[boto3.client],
         | 
| 93 | 
            +
            ) -> None:
         | 
| 94 | 
            +
                print("Worker started")
         | 
| 95 | 
            +
                while True:
         | 
| 96 | 
            +
                    case, save_p = queue.get()
         | 
| 97 | 
            +
                    src_path = os.path.join(args.data_dir, case)
         | 
| 98 | 
            +
                    smpl_path = src_path.replace('mesh', 'smplx', 1)
         | 
| 99 | 
            +
                    
         | 
| 100 | 
            +
                    command = ('blender -b -P blender_render_human_ortho.py'
         | 
| 101 | 
            +
                    f' -- --object_path {src_path}'
         | 
| 102 | 
            +
                    f' --smpl_path {smpl_path}'
         | 
| 103 | 
            +
                    f' --output_dir {save_p} --engine CYCLES'
         | 
| 104 | 
            +
                    f' --resolution {args.resolution}'
         | 
| 105 | 
            +
                    f' --random_images {args.random_images}'
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
                   
         | 
| 108 | 
            +
                    print(command)
         | 
| 109 | 
            +
                    subprocess.run(command, shell=True)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    with count.get_lock():
         | 
| 112 | 
            +
                        count.value += 1
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    queue.task_done()
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            if __name__ == "__main__":
         | 
| 118 | 
            +
                # args = tyro.cli(Args)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                s3 = None
         | 
| 121 | 
            +
                queue = multiprocessing.JoinableQueue()
         | 
| 122 | 
            +
                count = multiprocessing.Value("i", 0)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                # Start worker processes on each of the GPUs
         | 
| 125 | 
            +
                for gpu_i in range(args.num_gpus):
         | 
| 126 | 
            +
                    for worker_i in range(args.workers_per_gpu):
         | 
| 127 | 
            +
                        worker_i = gpu_i * args.workers_per_gpu + worker_i
         | 
| 128 | 
            +
                        process = multiprocessing.Process(
         | 
| 129 | 
            +
                            target=worker, args=(queue, count, args.gpu_list[gpu_i], s3)
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
                        process.daemon = True
         | 
| 132 | 
            +
                        process.start()
         | 
| 133 | 
            +
                    
         | 
| 134 | 
            +
                # Add items to the queue
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                save_dirs = parse_obj_list(glb_list)
         | 
| 137 | 
            +
                args.end_i = len(save_dirs) if args.end_i > len(save_dirs) or args.end_i==-1 else args.end_i
         | 
| 138 | 
            +
                
         | 
| 139 | 
            +
                for case_sub, save_dir in zip(glb_list[args.start_i:args.end_i], save_dirs[args.start_i:args.end_i]):
         | 
| 140 | 
            +
                    queue.put([case_sub, save_dir])
         | 
| 141 | 
            +
             | 
| 142 | 
            +
               
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                # Wait for all tasks to be completed
         | 
| 145 | 
            +
                queue.join()
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                # Add sentinels to the queue to stop the worker processes
         | 
| 148 | 
            +
                for i in range(args.num_gpus * args.workers_per_gpu):
         | 
| 149 | 
            +
                    queue.put(None)
         | 
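A hypothetical driver for the distributed renderer (GPU ids, worker counts and paths are examples, not repository defaults): launch a full pass over an object list, then, once check_render.py has written need_render.json, re-queue only the failed cases:

    import subprocess

    # full pass over the list produced by count.py
    subprocess.run(
        ["python", "distribute.py",
         "--json_path", "THuman2.1.json",
         "--gpu_list", "0", "1",          # example GPU ids
         "--num_gpus", "2",
         "--workers_per_gpu", "4",
         "--data_dir", "/data/lipeng/human_scan/",
         "--save_dir", "/data/lipeng/human_8view"],
        check=True,
    )

    # second pass restricted to the cases flagged by check_render.py
    subprocess.run(
        ["python", "distribute.py",
         "--json_path", "need_render.json",
         "--gpu_list", "0", "1",
         "--num_gpus", "2",
         "--workers_per_gpu", "4"],
        check=True,
    )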
    	
        blender/rename_smpl_files.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from tqdm import tqdm
         | 
| 3 | 
            +
            from glob import glob
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def rename_customhumans():
         | 
| 6 | 
            +
                root = '/data/lipeng/human_scan/CustomHumans/smplx/'
         | 
| 7 | 
            +
                file_paths = glob(os.path.join(root, '*/*_smpl.obj'))
         | 
| 8 | 
            +
                for file_path in tqdm(file_paths):
         | 
| 9 | 
            +
                    new_path = file_path.replace('_smpl', '')
         | 
| 10 | 
            +
                    os.rename(file_path, new_path)  
         | 
| 11 | 
            +
                    
         | 
| 12 | 
            +
            def rename_thuman21():
         | 
| 13 | 
            +
                root = '/data/lipeng/human_scan/THuman2.1/smplx/'
         | 
| 14 | 
            +
                file_paths = glob(os.path.join(root, '*/*.obj'))
         | 
| 15 | 
            +
                for file_path in tqdm(file_paths):
         | 
| 16 | 
            +
                    obj_name = file_path.split('/')[-2]
         | 
| 17 | 
            +
                    folder_name = os.path.dirname(file_path)
         | 
| 18 | 
            +
                    new_path = os.path.join(folder_name, obj_name+'.obj')
         | 
| 19 | 
            +
                    # print(new_path)
         | 
| 20 | 
            +
                    # print(file_path)
         | 
| 21 | 
            +
                    os.rename(file_path, new_path)  
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            if __name__ == '__main__':
         | 
| 24 | 
            +
                rename_thuman21()
         | 
| 25 | 
            +
                rename_customhumans()
         | 
    	
        blender/render.sh
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #### install environment
         | 
| 2 | 
            +
            # ~/pkgs/blender-3.6.4/3.6/python/bin/python3.10 -m pip install openexr opencv-python
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            python render_human.py
         | 
    	
blender/render_human.py
ADDED
@@ -0,0 +1,88 @@
import os
import json
import math
from concurrent.futures import ProcessPoolExecutor
import threading
from tqdm import tqdm

# from glcontext import egl
# egl.create_context()
# exit(0)

LOCAL_RANK = 0

num_processes = 4
NODE_RANK = int(os.getenv("SLURM_PROCID"))
WORLD_SIZE = 1
NODE_NUM=1
# NODE_RANK = int(os.getenv("SLURM_NODEID"))
IS_MAIN = False
if NODE_RANK == 0 and LOCAL_RANK == 0:
    IS_MAIN = True

GLOBAL_RANK = NODE_RANK * (WORLD_SIZE//NODE_NUM) + LOCAL_RANK


# json_path = "object_lists/Thuman2.0.json"
# json_path = "object_lists/THuman3.0.json"
json_path = "object_lists/CustomHumans.json"
data_dir = '/aifs4su/mmcode/lipeng'
save_dir = '/aifs4su/mmcode/lipeng/human_8view_new'
def parse_obj_list(x):
    if 'THuman3.0' in x:
        # print(apath)
        splits = x.split('/')
        x = os.path.join('THuman3.0', splits[-2])
    elif 'Thuman2.0' in x:
        splits = x.split('/')
        x = os.path.join('Thuman2.0', splits[-2])
    elif 'CustomHumans' in x:
        splits = x.split('/')
        x = os.path.join('CustomHumans', splits[-2])
        # print(splits[-2])
    elif '1M' in x:
        splits = x.split('/')
        x = os.path.join('2K2K', splits[-2])
    elif 'realistic_8k_model' in x:
        splits = x.split('/')
        x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
    return f'{save_dir}/{x}'

with open(json_path, 'r') as f:
    glb_list = json.load(f)

# glb_list = ['Thuman2.0/0011/0011.obj']
# glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
# glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj']
# glb_list = ['realistic_8k_model/1d41f2a72f994306b80e632f1cc8233f.glb']

total_num_glbs = len(glb_list)

num_glbs_local = int(math.ceil(total_num_glbs / WORLD_SIZE))
start_idx = GLOBAL_RANK * num_glbs_local
end_idx = start_idx + num_glbs_local
# print(start_idx, end_idx)
local_glbs = glb_list[start_idx:end_idx]
if IS_MAIN:
    pbar = tqdm(total=len(local_glbs))
    lock = threading.Lock()

def process_human(glb_path):
    src_path = os.path.join(data_dir, glb_path)
    save_path = parse_obj_list(glb_path)
    # print(save_path)
    command = ('blender -b -P blender_render_human_script.py'
        f' -- --object_path {src_path}'
        f' --output_dir {save_path} ')
        # 1>/dev/null
    # print(command)
    os.system(command)

    if IS_MAIN:
        with lock:
            pbar.update(1)

with ProcessPoolExecutor(max_workers=num_processes) as executor:
    executor.map(process_human, local_glbs)
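Note: render_human.py shards the scan list across launcher processes; GLOBAL_RANK selects this process's slice of glb_list. A minimal sketch of the slicing arithmetic, with hypothetical rank and world-size values (the script itself reads NODE_RANK from SLURM_PROCID and defaults to a single process):

# Illustration only: the per-rank slice computed by render_human.py, with assumed values.
import math
glb_list = [f"scan_{i}" for i in range(10)]      # stand-in for the JSON object list
WORLD_SIZE, NODE_NUM, LOCAL_RANK = 2, 2, 0       # assumed: 2 nodes, 1 process per node
for NODE_RANK in range(NODE_NUM):
    GLOBAL_RANK = NODE_RANK * (WORLD_SIZE // NODE_NUM) + LOCAL_RANK
    num_local = int(math.ceil(len(glb_list) / WORLD_SIZE))
    start, end = GLOBAL_RANK * num_local, (GLOBAL_RANK + 1) * num_local
    print(GLOBAL_RANK, glb_list[start:end])      # rank 0 -> scans 0-4, rank 1 -> scans 5-9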
    	
blender/render_single.sh
ADDED
@@ -0,0 +1,7 @@
# debug single sample
blender -b -P blender_render_human_ortho.py \
         -- --object_path /data/lipeng/human_scan/THuman2.1/mesh/0011/0011.obj \
         --smpl_path /data/lipeng/human_scan/THuman2.1/smplx/0011/0011.obj \
         --output_dir debug --engine CYCLES \
         --resolution 768 \
         --random_images 0
    	
blender/utils.py
ADDED
@@ -0,0 +1,128 @@
import datetime
import pytz
import traceback
from torchvision.utils import make_grid
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
import json
import os
from tqdm import tqdm
import cv2
import imageio
def get_time_for_log():
    return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime(
        "%Y%m%d %H:%M:%S")


def get_trace_for_log():
    return str(traceback.format_exc())

def make_grid_(imgs, save_file, nrow=10, pad_value=1):
    if isinstance(imgs, list):
        if isinstance(imgs[0], Image.Image):
            imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs]
        elif isinstance(imgs[0], np.ndarray):
            imgs = [torch.from_numpy(img/255.) for img in imgs]
        imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2)
    if isinstance(imgs, np.ndarray):
        imgs = torch.from_numpy(imgs)

    img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value)
    img_grid = img_grid.permute(1, 2, 0).numpy()
    img_grid = (img_grid * 255).astype(np.uint8)
    img_grid = Image.fromarray(img_grid)
    img_grid.save(save_file)

def draw_caption(img, text, pos, size=100, color=(128, 128, 128)):
    draw = ImageDraw.Draw(img)
    # font = ImageFont.truetype(size= size)
    font = ImageFont.load_default()
    font = font.font_variant(size=size)
    draw.text(pos, text, color, font=font)
    return img


def txt2json(txt_file, json_file):
    with open(txt_file, 'r') as f:
        items = f.readlines()
        items = [x.strip() for x in items]

    with open(json_file, 'w') as f:
        json.dump(items, f)  # items is already a plain list; the original .tolist() call would raise AttributeError

def process_thuman_texture():
    path = '/aifs4su/mmcode/lipeng/Thuman2.0'
    cases = os.listdir(path)
    for case in tqdm(cases):
        mtl = os.path.join(path, case, 'material0.mtl')
        with open(mtl, 'r') as f:
            lines = f.read()
            lines = lines.replace('png', 'jpeg')
        with open(mtl, 'w') as f:
            f.write(lines)

#### for debug
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"


def get_intrinsic_from_fov(fov, H, W, bs=-1):
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)

def read_data(data_dir, i):
    """
    Return:
    rgb: (H, W, 3) torch.float32
    depth: (H, W, 1) torch.float32
    mask: (H, W, 1) torch.float32
    c2w: (4, 4) torch.float32
    intrinsic: (3, 3) torch.float32
    """
    background_color = torch.tensor([0.0, 0.0, 0.0])

    rgb_name = os.path.join(data_dir, f'render_%04d.webp' % i)
    depth_name = os.path.join(data_dir, f'depth_%04d.exr' % i)

    img = torch.from_numpy(
                np.asarray(
                    Image.fromarray(imageio.v2.imread(rgb_name))
                    .convert("RGBA")
                )
                / 255.0
            ).float()
    mask = img[:, :, -1:]
    rgb = img[:, :, :3] * mask + background_color[
        None, None, :
    ] * (1 - mask)

    depth = torch.from_numpy(
        cv2.imread(depth_name, cv2.IMREAD_UNCHANGED)[..., 0, None]
    )
    mask[depth > 100.0] = 0.0
    depth[~(mask > 0.5)] = 0.0  # set invalid depth to 0

    meta_path = os.path.join(data_dir, 'meta.json')
    with open(meta_path, 'r') as f:
        meta = json.load(f)

    c2w = torch.as_tensor(
                meta['locations'][i]["transform_matrix"],
                dtype=torch.float32,
            )

    H, W = rgb.shape[:2]
    fovy = meta["camera_angle_x"]
    intrinsic = get_intrinsic_from_fov(fovy, H=H, W=W)

    return rgb, depth, mask, c2w, intrinsic
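Note: get_intrinsic_from_fov above builds a square-pixel pinhole matrix with focal_length = 0.5 * H / tan(0.5 * fov). A quick standalone check with assumed numbers (the FOV and resolution here are illustrative, not values taken from the rendering configs):

# Illustration only: reproduces the intrinsic layout returned by get_intrinsic_from_fov.
import numpy as np
fov, H, W = np.deg2rad(50.0), 768, 768            # assumed vertical FOV and image size
f = 0.5 * H / np.tan(0.5 * fov)                   # ~823.5 pixels for these numbers
K = np.array([[f, 0.0, W / 2.0],
              [0.0, f, H / 2.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)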
    	
configs/inference-768-6view.yaml
ADDED
@@ -0,0 +1,72 @@
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1-unclip'
revision: null

num_views: 7
with_smpl: false
validation_dataset:
  prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
  root_dir: 'examples/shhq'
  num_views: ${num_views}
  bg_color: 'white'
  img_wh: [768, 768]
  num_validation_samples: 1000
  crop_size: 740
  margin_size: 50
  smpl_folder: 'smpl_image_pymaf'


save_dir: 'mv_results'
save_mode: 'rgba' # 'concat', 'rgba', 'rgb'
seed: 42
validation_batch_size: 1
dataloader_num_workers: 1
local_rank: -1

pipe_kwargs:
  num_views: ${num_views}

validation_guidance_scales: 3.0
pipe_validation_kwargs:
  num_inference_steps: 40
  eta: 1.0

validation_grid_nrow: ${num_views}

unet_from_pretrained_kwargs:
  unclip: true
  sdxl: false
  num_views: ${num_views}
  sample_size: 96
  zero_init_conv_in: false # modify

  projection_camera_embeddings_input_dim: 2 # 2 for elevation and 6 for focal_length
  zero_init_camera_projection: false
  num_regress_blocks: 3

  cd_attention_last: false
  cd_attention_mid: false
  multiview_attention: true
  sparse_mv_attention: true
  selfattn_block: self_rowwise
  mvcd_attention: true

recon_opt:
  res_path: out
  save_glb: False
  # camera setting
  num_view: 6
  scale: 4
  mode: ortho
  resolution: 1024
  cam_path: 'mvdiffusion/data/six_human_pose'
  # optimization
  iters: 700
  clr_iters: 200
  debug: false
  snapshot_step: 50
  lr_clr: 2e-3
  gpu_id: 0

  replace_hand: false

enable_xformers_memory_efficient_attention: true
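Note: the ${num_views} entries in this and the following configs are interpolation references to the top-level num_views key. Assuming the project loads these YAML files with OmegaConf (suggested by the ${...} syntax, not verified here), the references resolve on access:

# Sketch only; OmegaConf usage is an assumption based on the interpolation syntax.
from omegaconf import OmegaConf
cfg = OmegaConf.load("configs/inference-768-6view.yaml")
print(cfg.pipe_kwargs.num_views)        # -> 7, resolved from the top-level num_views
print(cfg.validation_dataset.img_wh)    # -> [768, 768]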
    	
configs/remesh.yaml
ADDED
@@ -0,0 +1,18 @@
res_path: out
save_glb: False
imgs_path: examples/debug
mv_path: ./
# camera setting
num_view: 6
scale: 4
mode: ortho
resolution: 1024
cam_path: 'mvdiffusion/data/six_human_pose'
# optimization
iters: 700
clr_iters: 200
debug: false
snapshot_step: 50
lr_clr: 2e-3
gpu_id: 0
replace_hand: false
    	
configs/train-768-6view-onlyscan_face.yaml
ADDED
@@ -0,0 +1,145 @@
pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
pretrained_unet_path: null
revision: null
with_smpl: false
data_common:
  root_dir: /aifs4su/mmcode/lipeng/human_8view_new/
  predict_relative_views: [0, 1, 2, 4, 6, 7]
  num_validation_samples: 8
  img_wh: [768, 768]
  read_normal: true
  read_color: true
  read_depth: false
  exten: .png
  prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
  object_list:
  - data_lists/human_only_scan.json
  invalid_list:
  -
train_dataset:
  root_dir: ${data_common.root_dir}
  azi_interval: 45.0
  random_views: 3
  predict_relative_views: ${data_common.predict_relative_views}
  bg_color: three_choices
  object_list: ${data_common.object_list}
  invalid_list: ${data_common.invalid_list}
  img_wh: ${data_common.img_wh}
  validation: false
  num_validation_samples: ${data_common.num_validation_samples}
  read_normal: ${data_common.read_normal}
  read_color: ${data_common.read_color}
  read_depth: ${data_common.read_depth}
  load_cache: false
  exten: ${data_common.exten}
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  side_views_rate: 0.3
  elevation_list: null
validation_dataset:
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  root_dir: examples/debug
  num_views: ${num_views}
  bg_color: white
  img_wh: ${data_common.img_wh}
  num_validation_samples: 1000
  crop_size: 740
validation_train_dataset:
  root_dir: ${data_common.root_dir}
  azi_interval: 45.0
  random_views: 3
  predict_relative_views: ${data_common.predict_relative_views}
  bg_color: white
  object_list: ${data_common.object_list}
  invalid_list: ${data_common.invalid_list}
  img_wh: ${data_common.img_wh}
  validation: false
  num_validation_samples: ${data_common.num_validation_samples}
  read_normal: ${data_common.read_normal}
  read_color: ${data_common.read_color}
  read_depth: ${data_common.read_depth}
  num_samples: ${data_common.num_validation_samples}
  load_cache: false
  exten: ${data_common.exten}
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  elevation_list: null
output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5
checkpoint_prefix: ../human_checkpoint_backup/
seed: 42
train_batch_size: 2
validation_batch_size: 1
validation_train_batch_size: 1
max_train_steps: 30000
gradient_accumulation_steps: 2
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: piecewise_constant
step_rules: 1:2000,0.5
lr_warmup_steps: 10
snr_gamma: 5.0
use_8bit_adam: false
allow_tf32: true
use_ema: true
dataloader_num_workers: 32
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 0.01
adam_epsilon: 1.0e-08
max_grad_norm: 1.0
prediction_type: null
logging_dir: logs
vis_dir: vis
mixed_precision: fp16
report_to: wandb
local_rank: 0
checkpointing_steps: 2500
checkpoints_total_limit: 2
resume_from_checkpoint: latest
enable_xformers_memory_efficient_attention: true
validation_steps: 2500
validation_sanity_check: true
tracker_project_name: PSHuman
trainable_modules: null


use_classifier_free_guidance: true
condition_drop_rate: 0.05
scale_input_latents: true
regress_elevation: false
regress_focal_length: false
elevation_loss_weight: 1.0
focal_loss_weight: 0.0
pipe_kwargs:
  num_views: ${num_views}
pipe_validation_kwargs:
  eta: 1.0

unet_from_pretrained_kwargs:
  unclip: true
  num_views: ${num_views}
  sample_size: 96
  zero_init_conv_in: true
  regress_elevation: ${regress_elevation}
  regress_focal_length: ${regress_focal_length}
  num_regress_blocks: 2
  camera_embedding_type: e_de_da_sincos
  projection_camera_embeddings_input_dim: 2
  zero_init_camera_projection: true # modified
  init_mvattn_with_selfattn: false
  cd_attention_last: false
  cd_attention_mid: false
  multiview_attention: true
  sparse_mv_attention: true
  selfattn_block: self_rowwise
  mvcd_attention: true
  addition_downsample: false
  use_face_adapter: false

validation_guidance_scales:
- 3.0
validation_grid_nrow: ${num_views}
camera_embedding_lr_mult: 1.0
plot_pose_acc: false
num_views: 7
pred_type: joint
drop_type: drop_as_a_whole
    	
configs/train-768-6view-onlyscan_face_smplx.yaml
ADDED
@@ -0,0 +1,154 @@
pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
pretrained_unet_path: null
revision: null
with_smpl: true
data_common:
  root_dir: /aifs4su/mmcode/lipeng/human_8view_with_smplx/
  predict_relative_views: [0, 1, 2, 4, 6, 7]
  num_validation_samples: 8
  img_wh: [768, 768]
  read_normal: true
  read_color: true
  read_depth: false
  exten: .png
  prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
  object_list:
  - data_lists/human_only_scan_with_smplx.json  # modified
  invalid_list:
  -
  with_smpl: ${with_smpl}

train_dataset:
  root_dir: ${data_common.root_dir}
  azi_interval: 45.0
  random_views: 0
  predict_relative_views: ${data_common.predict_relative_views}
  bg_color: three_choices
  object_list: ${data_common.object_list}
  invalid_list: ${data_common.invalid_list}
  img_wh: ${data_common.img_wh}
  validation: false
  num_validation_samples: ${data_common.num_validation_samples}
  read_normal: ${data_common.read_normal}
  read_color: ${data_common.read_color}
  read_depth: ${data_common.read_depth}
  load_cache: false
  exten: ${data_common.exten}
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  side_views_rate: 0.3
  elevation_list: null
  with_smpl: ${with_smpl}

validation_dataset:
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  root_dir: examples/debug
  num_views: ${num_views}
  bg_color: white
  img_wh: ${data_common.img_wh}
  num_validation_samples: 1000
  margin_size: 10
  # crop_size: 720

validation_train_dataset:
  root_dir: ${data_common.root_dir}
  azi_interval: 45.0
  random_views: 0
  predict_relative_views: ${data_common.predict_relative_views}
  bg_color: white
  object_list: ${data_common.object_list}
  invalid_list: ${data_common.invalid_list}
  img_wh: ${data_common.img_wh}
  validation: false
  num_validation_samples: ${data_common.num_validation_samples}
  read_normal: ${data_common.read_normal}
  read_color: ${data_common.read_color}
  read_depth: ${data_common.read_depth}
  num_samples: ${data_common.num_validation_samples}
  load_cache: false
  exten: ${data_common.exten}
  prompt_embeds_path: ${data_common.prompt_embeds_path}
  elevation_list: null
  with_smpl: ${with_smpl}

output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5-smplx
checkpoint_prefix: ../human_checkpoint_backup/
seed: 42
train_batch_size: 2
validation_batch_size: 1
validation_train_batch_size: 1
max_train_steps: 30000
gradient_accumulation_steps: 2
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: piecewise_constant
step_rules: 1:2000,0.5
lr_warmup_steps: 10
snr_gamma: 5.0
use_8bit_adam: false
allow_tf32: true
use_ema: true
dataloader_num_workers: 32
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 0.01
adam_epsilon: 1.0e-08
max_grad_norm: 1.0
prediction_type: null
logging_dir: logs
vis_dir: vis
mixed_precision: fp16
report_to: wandb
local_rank: 0
checkpointing_steps: 5000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
enable_xformers_memory_efficient_attention: true
validation_steps: 2500
validation_sanity_check: true
tracker_project_name: PSHuman
trainable_modules: null

use_classifier_free_guidance: true
condition_drop_rate: 0.05
scale_input_latents: true
regress_elevation: false
regress_focal_length: false
elevation_loss_weight: 1.0
focal_loss_weight: 0.0
pipe_kwargs:
  num_views: ${num_views}
pipe_validation_kwargs:
  eta: 1.0

unet_from_pretrained_kwargs:
  unclip: true
  num_views: ${num_views}
  sample_size: 96
  zero_init_conv_in: true
  regress_elevation: ${regress_elevation}
  regress_focal_length: ${regress_focal_length}
  num_regress_blocks: 2
  camera_embedding_type: e_de_da_sincos
  projection_camera_embeddings_input_dim: 2
  zero_init_camera_projection: true # modified
  init_mvattn_with_selfattn: false
  cd_attention_last: false
  cd_attention_mid: false
  multiview_attention: true
  sparse_mv_attention: true
  selfattn_block: self_rowwise
  mvcd_attention: true
  addition_downsample: false
  use_face_adapter: false
  in_channels: 12


validation_guidance_scales:
- 3.0
validation_grid_nrow: ${num_views}
camera_embedding_lr_mult: 1.0
plot_pose_acc: false
num_views: 7
pred_type: joint
drop_type: drop_as_a_whole
    	
core/opt.py
ADDED
@@ -0,0 +1,197 @@
from copy import deepcopy
import time
import torch
import torch_scatter
from core.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges

@torch.no_grad()
def remesh(
        vertices_etc:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
        min_edgelen:torch.Tensor, #V
        max_edgelen:torch.Tensor, #V
        flip:bool,
        max_vertices=1e6
        ):

    # dummies
    vertices_etc,faces = prepend_dummies(vertices_etc,faces)
    vertices = vertices_etc[:,:3] #V,3
    nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
    min_edgelen = torch.concat((nan_tensor,min_edgelen))
    max_edgelen = torch.concat((nan_tensor,max_edgelen))

    # collapse
    edges,face_to_edge = calc_edges(faces) #E,2 F,3
    edge_length = calc_edge_length(vertices,edges) #E
    face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
    vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
    face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
    priority = face_collapse.float() + shortness
    vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)

    # split
    if vertices.shape[0]<max_vertices:
        edges,face_to_edge = calc_edges(faces) #E,2 F,3
        vertices = vertices_etc[:,:3] #V,3
        edge_length = calc_edge_length(vertices,edges) #E
        splits = edge_length > max_edgelen[edges].mean(dim=-1)
        vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)

    vertices_etc,faces = pack(vertices_etc,faces)
    vertices = vertices_etc[:,:3]

    if flip:
        edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
        flip_edges(vertices,faces,edges,edge_to_face,with_border=False)

    return remove_dummies(vertices_etc,faces)

def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
    """lerp with adam's bias correction"""
    c_prev = 1-weight**(step-1)
    c = 1-weight**step
    a_weight = weight*c_prev/c
    b_weight = (1-weight)/c
    a.mul_(a_weight).add_(b, alpha=b_weight)


class MeshOptimizer:
    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""

    def __init__(self,
            vertices:torch.Tensor, #V,3
            faces:torch.Tensor, #F,3
            lr=0.3, #learning rate
            betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
            gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
            nu_ref=0.3, #reference velocity for edge length controller
            edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
            edge_len_tol=.5, #edge length tolerance for split and collapse
            gain=.2,  #gain value for edge length controller
            laplacian_weight=.02, #for laplacian smoothing/regularization
            ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
            grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
            remesh_interval=1, #larger intervals are faster but with worse mesh quality
            local_edgelen=True, #set to False to use a global scalar reference edge length instead
            remesh_milestones= [500], #list of steps at which to remesh
            # total_steps=1000, #total number of steps
            ):
        self._vertices = vertices
        self._faces = faces
        self._lr = lr
        self._betas = betas
        self._gammas = gammas
        self._nu_ref = nu_ref
        self._edge_len_lims = edge_len_lims
        self._edge_len_tol = edge_len_tol
        self._gain = gain
        self._laplacian_weight = laplacian_weight
        self._ramp = ramp
        self._grad_lim = grad_lim
        # self._remesh_interval = remesh_interval
        # self._remseh_milestones = [ for remesh_milestones]
        self._local_edgelen = local_edgelen
        self._step = 0
        self._start = time.time()

        V = self._vertices.shape[0]
        # prepare continuous tensor for all vertex-based data
        self._vertices_etc = torch.zeros([V,9],device=vertices.device)
        self._split_vertices_etc()
        self.vertices.copy_(vertices) #initialize vertices
        self._vertices.requires_grad_()
        self._ref_len.fill_(edge_len_lims[1])

    @property
    def vertices(self):
        return self._vertices

    @property
    def faces(self):
        return self._faces

    def _split_vertices_etc(self):
        self._vertices = self._vertices_etc[:,:3]
        self._m2 = self._vertices_etc[:,3]
        self._nu = self._vertices_etc[:,4]
        self._m1 = self._vertices_etc[:,5:8]
        self._ref_len = self._vertices_etc[:,8]

        with_gammas = any(g!=0 for g in self._gammas)
        self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]

    def zero_grad(self):
        self._vertices.grad = None

    @torch.no_grad()
    def step(self):

        eps = 1e-8

        self._step += 1
        # spatial smoothing
        edges,_ = calc_edges(self._faces) #E,2
        E = edges.shape[0]
        edge_smooth = self._smooth[edges] #E,2,S
        neighbor_smooth = torch.zeros_like(self._smooth) #V,S
        torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
        #apply optional smoothing of m1,m2,nu
        if self._gammas[0]:
            self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
| 143 | 
            +
                    if self._gammas[1]:
         | 
| 144 | 
            +
                        self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
         | 
| 145 | 
            +
                    if self._gammas[2]:
         | 
| 146 | 
            +
                        self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    #add laplace smoothing to gradients
         | 
| 149 | 
            +
                    laplace = self._vertices - neighbor_smooth[:,:3]
         | 
| 150 | 
            +
                    grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    #gradient clipping
         | 
| 153 | 
            +
                    if self._step>1:
         | 
| 154 | 
            +
                        grad_lim = self._m1.abs().mul_(self._grad_lim)
         | 
| 155 | 
            +
                        grad.clamp_(min=-grad_lim,max=grad_lim)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    # moment updates
         | 
| 158 | 
            +
                    lerp_unbiased(self._m1, grad, self._betas[0], self._step)
         | 
| 159 | 
            +
                    lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
         | 
| 162 | 
            +
                    speed = velocity.norm(dim=-1) #V
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    if self._betas[2]:
         | 
| 165 | 
            +
                        lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
         | 
| 166 | 
            +
                    else:
         | 
| 167 | 
            +
                        self._nu.copy_(speed) #V
         | 
| 168 | 
            +
                    # update vertices
         | 
| 169 | 
            +
                    ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
         | 
| 170 | 
            +
                    self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    # update target edge length
         | 
| 173 | 
            +
                    if self._step < 500:
         | 
| 174 | 
            +
                        self._remesh_interval = 4
         | 
| 175 | 
            +
                    elif self._step < 800:
         | 
| 176 | 
            +
                        self._remesh_interval = 2
         | 
| 177 | 
            +
                    else:
         | 
| 178 | 
            +
                        self._remesh_interval = 1 
         | 
| 179 | 
            +
                         
         | 
| 180 | 
            +
                    if self._step % self._remesh_interval == 0:
         | 
| 181 | 
            +
                        if self._local_edgelen:
         | 
| 182 | 
            +
                            len_change = (1 + (self._nu - self._nu_ref) * self._gain)
         | 
| 183 | 
            +
                        else:
         | 
| 184 | 
            +
                            len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
         | 
| 185 | 
            +
                        self._ref_len *= len_change
         | 
| 186 | 
            +
                        self._ref_len.clamp_(*self._edge_len_lims)
         | 
| 187 | 
            +
                        
         | 
| 188 | 
            +
                def remesh(self, flip:bool=True)->tuple[torch.Tensor,torch.Tensor]:
         | 
| 189 | 
            +
                    min_edge_len = self._ref_len * (1 - self._edge_len_tol)
         | 
| 190 | 
            +
                    max_edge_len = self._ref_len * (1 + self._edge_len_tol)
         | 
| 191 | 
            +
                        
         | 
| 192 | 
            +
                    self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    self._split_vertices_etc()
         | 
| 195 | 
            +
                    self._vertices.requires_grad_()
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    return self._vertices, self._faces
         | 
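The class docstring above describes the intended usage: call opt.step() each iteration, then opt.remesh() to pick up the updated topology. Below is a minimal sketch of that loop; the class name MeshOptimizer, the module path core.opt, and the placeholder sphere-fitting objective are illustrative assumptions and are not part of this diff.

import torch
from core.opt import MeshOptimizer  # assumed class name / module path

# a tiny tetrahedron as a stand-in starting mesh
vertices = torch.tensor([[0., 0., 1.], [1., 0., -1.], [-1., 1., -1.], [-1., -1., -1.]])
faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=torch.long)

opt = MeshOptimizer(vertices, faces)
for it in range(1000):
    opt.zero_grad()
    # placeholder objective: pull every vertex onto the unit sphere
    loss = ((opt.vertices.norm(dim=-1) - 1.0) ** 2).mean()
    loss.backward()
    opt.step()                      # Adam-like update plus edge-length controller
    vertices, faces = opt.remesh()  # split/collapse/flip edges toward the target length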
    	
core/remesh.py
ADDED
@@ -0,0 +1,359 @@
import torch
import torch.nn.functional as tfunc
import torch_scatter

def prepend_dummies(
        vertices:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
    )->tuple[torch.Tensor,torch.Tensor]:
    """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
    V,D = vertices.shape
    vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
    faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
    return vertices,faces

def remove_dummies(
        vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
        faces:torch.Tensor, #F,3 long - first face all zeros
    )->tuple[torch.Tensor,torch.Tensor]:
    """remove dummy elements added with prepend_dummies()"""
    return vertices[1:],faces[1:]-1


def calc_edges(
        faces: torch.Tensor,  # F,3 long - first face may be dummy with all zeros
        with_edge_to_face: bool = False
    ) -> tuple[torch.Tensor, ...]:
    """
    returns tuple of
    - edges E,2 long, 0 for unused, lower vertex index first
    - face_to_edge F,3 long
    - (optional) edge_to_face shape=E,[left,right],[face,side]

    o-<-----e1     e0,e1...edge, e0<e1
    |      /A      L,R....left and right face
    |  L /  |      both triangles ordered counter clockwise
    |  / R  |      normals pointing out of screen
    V/      |
    e0---->-o
    """

    F = faces.shape[0]

    # make full edges, lower vertex index first
    face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F,3,2
    full_edges = face_edges.reshape(F*3,2)
    sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 TODO min/max faster?

    # make unique edges
    edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
    E = edges.shape[0]
    face_to_edge = full_to_unique.reshape(F,3) #F,3

    if not with_edge_to_face:
        return edges, face_to_edge

    is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
    edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
    scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
    edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
    edge_to_face[0] = 0
    return edges, face_to_edge, edge_to_face

def calc_edge_length(
        vertices:torch.Tensor, #V,3 first may be dummy
        edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
        )->torch.Tensor: #E

    full_vertices = vertices[edges] #E,2,3
    a,b = full_vertices.unbind(dim=1) #E,3
    return torch.norm(a-b,p=2,dim=-1)

def calc_face_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """
         n
         |
         c0     corners ordered counterclockwise when
        / \     looking onto surface (in neg normal direction)
      c1---c2
    """
    full_vertices = vertices[faces] #F,C=3,3
    v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
    face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
    if normalize:
        face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) #TODO inplace?
    return face_normals #F,3

def calc_vertex_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        face_normals:torch.Tensor=None, #F,3, not normalized
        )->torch.Tensor: #V,3

    F = faces.shape[0]

    if face_normals is None:
        face_normals = calc_face_normals(vertices,faces)

    vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
    vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
    vertex_normals = vertex_normals.sum(dim=1) #V,3
    return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)

def calc_face_ref_normals(
        faces:torch.Tensor, #F,3 long, 0 for unused
        vertex_normals:torch.Tensor, #V,3 first unused
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """calculate reference normals for face flip detection"""
    full_normals = vertex_normals[faces] #F,C=3,3
    ref_normals = full_normals.sum(dim=1) #F,3
    if normalize:
        ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
    return ref_normals

def pack(
        vertices:torch.Tensor, #V,3 first unused and nan
        faces:torch.Tensor, #F,3 long, 0 for unused
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
    """removes unused elements in vertices and faces"""
    V = vertices.shape[0]

    # remove unused faces
    used_faces = faces[:,0]!=0
    used_faces[0] = True
    faces = faces[used_faces] #sync

    # remove unused vertices
    used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
    used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') #TODO int faster?
    used_vertices = used_vertices.any(dim=1)
    used_vertices[0] = True
    vertices = vertices[used_vertices] #sync

    # update used faces
    ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
    V1 = used_vertices.sum()
    ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
    faces = ind[faces]

    return vertices,faces

def split_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        splits, #E bool
        pack_faces:bool=True,
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    #   c2                    c2               c...corners = faces
    #    . .                   . .             s...side_vert, 0 means no split
    #    .   .                 .N2 .           S...shrunk_face
    #    .     .               .     .         Ni...new_faces
    #   s2      s1           s2|c2...s1|c1
    #    .        .            .     .  .
    #    .          .          . S .      .
    #    .            .        . .     N1    .
    #   c0...(s0=0)....c1    s0|c0...........c1
    #
    # pseudo-code:
    #   S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
    #   split = side_vert!=0 example:[False,True,True]
    #   N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
    #   N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
    #   N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]

    V = vertices.shape[0]
    F = faces.shape[0]
    S = splits.sum().item() #sync

    if S==0:
        return vertices,faces

    edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
    edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
    side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
    split_edges = edges[splits] #S,2 sync

    #vertices
    split_vertices = vertices[split_edges].mean(dim=1) #S,3
    vertices = torch.concat((vertices,split_vertices),dim=0)

    #faces
    side_split = side_vert!=0 #F,3
    shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
    new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
    faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
    if pack_faces:
        mask = faces[:,0]!=0
        mask[0] = True
        faces = faces[mask] #F',3 sync

    return vertices,faces

def collapse_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        priorities:torch.Tensor, #E float
        stable:bool=False, #only for unit testing
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    V = vertices.shape[0]

    # check spacing
    _,order = priorities.sort(stable=stable) #E
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    edge_rank = rank #E
    for i in range(3):
        torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
        edge_rank,_ = vert_rank[edges].max(dim=-1) #E
    candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2

    # check connectivity
    vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    vert_connections[candidates[:,0]] = 1 #start
    edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one edge from start
    vert_connections[candidates] = 0 #clear start and end
    edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
    collapses = candidates[vert_connections[candidates[:,1]] <= 2] #E", not more than two connections between start and end

    # mean vertices
    vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) #TODO dim?

    # update faces
    dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
    dest[collapses[:,1]] = dest[collapses[:,0]]
    faces = dest[faces] #F,3 TODO optimize?
    c0,c1,c2 = faces.unbind(dim=-1)
    collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
    faces[collapsed] = 0

    return vertices,faces

def calc_face_collapses(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        edge_length:torch.Tensor, #E
        face_normals:torch.Tensor, #F,3
        vertex_normals:torch.Tensor, #V,3 first unused
        min_edge_length:torch.Tensor=None, #V
        area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
        shortest_probability = 0.8
        )->torch.Tensor: #E edges to collapse

    E = edges.shape[0]
    F = faces.shape[0]

    # face flips
    ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
    face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F

    # small faces
    if min_edge_length is not None:
        min_face_length = min_edge_length[faces].mean(dim=-1) #F
        min_area = min_face_length**2 * area_ratio #F
        face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
        face_collapses[0] = False

    # faces to edges
    face_length = edge_length[face_to_edge] #F,3

    if shortest_probability<1:
        #select shortest edge with shortest_probability chance
        randlim = round(2/(1-shortest_probability))
        rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
        sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
        local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
    else:
        local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face

    edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
    edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
    edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) #TODO legal for bool?

    return edge_collapses.bool()

def flip_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
        edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
        edge_to_face:torch.Tensor, #E,[left,right],[face,side]
        with_border:bool=True, #handle border edges (D=4 instead of D=6)
        with_normal_check:bool=True, #check face normal flips
        stable:bool=False, #only for unit testing
        ):
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
    edge_is_inside = neighbors.all(dim=-1) #E

    if with_border:
        # inside vertices should have D=6, border vertices D=4, so we subtract 2 for all inside vertices
        # need to use float for masks in order to use scatter(reduce='multiply')
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
        vertex_degree -= 2 * vertex_is_inside #V long

    neighbor_degrees = vertex_degree[neighbors] #E,LR=2
    edge_degrees = vertex_degree[edges] #E,2
    #
    # loss = Sum_over_affected_vertices((new_degree-6)**2)
    # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
    #                   + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
    #             = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
    #
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
    candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
    loss_change = loss_change[candidates] #E'
    if loss_change.shape[0]==0:
        return

    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
    _,order = loss_change.sort(descending=True, stable=stable) #E'
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1) #V
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
    flip = rank==neighborhood_rank #E'

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   |  L /  |      both triangles ordered counter clockwise
        #   |  / R  |      normals pointing out of screen
        #   V/      |
        #   e0---->-cr
        v = vertices[edges_neighbors] #E",4,3
        v = v - v[:,0:1] #make relative to e0
        e1 = v[:,1]
        cl = v[:,2]
        cr = v[:,3]
        n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
        flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
        flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face

    flip_edges_neighbors = edges_neighbors[flip] #E",4
    flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
    flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
    faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
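Everything in this file relies on the dummy-element convention documented in prepend_dummies(): vertex 0 and face 0 are reserved placeholders so that index 0 can mean "unused" in scatter operations. A small round-trip sketch using only the functions defined above; the one-triangle mesh is made up for illustration, and the import path core.remesh simply follows the file name in this diff.

import torch
from core.remesh import (prepend_dummies, remove_dummies, calc_edges,
                         calc_edge_length, calc_vertex_normals)

verts = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = torch.tensor([[0, 1, 2]], dtype=torch.long)

v, f = prepend_dummies(verts, faces)     # row 0 of v is NaN, row 0 of f is (0,0,0)
edges, face_to_edge = calc_edges(f)      # E,2 edges; edge (0,0) is the dummy entry
lengths = calc_edge_length(v, edges)     # per-edge lengths (the dummy edge is NaN)
normals = calc_vertex_normals(v, f)[1:]  # drop the dummy row to align with verts
verts, faces = remove_dummies(v, f)      # back to the plain representation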
    	
econdataset.py
ADDED
@@ -0,0 +1,370 @@
| 1 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
         | 
| 4 | 
            +
            # holder of all proprietary rights on this computer program.
         | 
| 5 | 
            +
            # You can only use this computer program if you have closed
         | 
| 6 | 
            +
            # a license agreement with MPG or you get the right to use the computer
         | 
| 7 | 
            +
            # program from someone who is authorized to grant you that right.
         | 
| 8 | 
            +
            # Any use of the computer program without a valid license is prohibited and
         | 
| 9 | 
            +
            # liable to prosecution.
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
            # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
         | 
| 12 | 
            +
            # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
         | 
| 13 | 
            +
            # for Intelligent Systems. All rights reserved.
         | 
| 14 | 
            +
            #
         | 
| 15 | 
            +
            # Contact: [email protected]
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam
         | 
| 18 | 
            +
            from lib.pixielib.utils.config import cfg as pixie_cfg
         | 
| 19 | 
            +
            from lib.pixielib.pixie import PIXIE
         | 
| 20 | 
            +
            import lib.smplx as smplx
         | 
| 21 | 
            +
            # from lib.pare.pare.core.tester import PARETester
         | 
| 22 | 
            +
            from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis
         | 
| 23 | 
            +
            from lib.pymaf.utils.imutils import process_image
         | 
| 24 | 
            +
            from lib.common.imutils import econ_process_image
         | 
| 25 | 
            +
            from lib.pymaf.core import path_config
         | 
| 26 | 
            +
            from lib.pymaf.models import pymaf_net
         | 
| 27 | 
            +
            from lib.common.config import cfg
         | 
| 28 | 
            +
            from lib.common.render import Render
         | 
| 29 | 
            +
            from lib.dataset.body_model import TetraSMPLModel
         | 
| 30 | 
            +
            from lib.dataset.mesh_util import get_visibility
         | 
| 31 | 
            +
            from utils.smpl_util import SMPLX
         | 
| 32 | 
            +
            import os.path as osp
         | 
| 33 | 
            +
            import os
         | 
| 34 | 
            +
            import torch
         | 
| 35 | 
            +
            import numpy as np
         | 
| 36 | 
            +
            import random
         | 
| 37 | 
            +
            from termcolor import colored
         | 
| 38 | 
            +
            from PIL import ImageFile
         | 
| 39 | 
            +
            from torchvision.models import detection
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            class SMPLDataset():
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def __init__(self, cfg, device):
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    random.seed(1993)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.image_dir = cfg['image_dir']
         | 
| 52 | 
            +
                    self.seg_dir = cfg['seg_dir']
         | 
| 53 | 
            +
                    self.hps_type = cfg['hps_type']
         | 
| 54 | 
            +
                    self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
         | 
| 55 | 
            +
                    self.smpl_gender = 'neutral'
         | 
| 56 | 
            +
                    self.colab = cfg['colab']
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self.device = device
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    keep_lst = [f"{self.image_dir}/{i}" for i in  sorted(os.listdir(self.image_dir))]
         | 
| 61 | 
            +
                    img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp']
         | 
| 62 | 
            +
                    keep_lst = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    self.subject_list = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    if self.colab:
         | 
| 67 | 
            +
                        self.subject_list = [self.subject_list[0]]
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # smpl related
         | 
| 70 | 
            +
                    self.smpl_data = SMPLX()
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    # smpl-smplx correspondence
         | 
| 73 | 
            +
                    self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
         | 
| 74 | 
            +
                    self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68]
         | 
| 75 | 
            +
                    self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(model_path=self.smpl_data.
         | 
| 76 | 
            +
                                                                                      model_dir,
         | 
| 77 | 
            +
                                                                                      gender=smpl_gender,
         | 
| 78 | 
            +
                                                                                      model_type=smpl_type,
         | 
| 79 | 
            +
                                                                                      ext='npz')
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Load SMPL model
         | 
| 82 | 
            +
                    self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device)
         | 
| 83 | 
            +
                    self.faces = self.smpl_model.faces
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    if self.hps_type == 'pymaf':
         | 
| 86 | 
            +
                        self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
         | 
| 87 | 
            +
                        self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True)
         | 
| 88 | 
            +
                        self.hps.eval()
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    elif self.hps_type == 'pare':
         | 
| 91 | 
            +
                        self.hps = PARETester(path_config.CFG, path_config.CKPT).model
         | 
| 92 | 
            +
                    elif self.hps_type == 'pixie':
         | 
| 93 | 
            +
                        self.hps = PIXIE(config=pixie_cfg, device=self.device)
         | 
| 94 | 
            +
                        self.smpl_model = self.hps.smplx
         | 
| 95 | 
            +
                    elif self.hps_type == 'hybrik':
         | 
| 96 | 
            +
                        smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
         | 
| 97 | 
            +
                        self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG,
         | 
| 98 | 
            +
                                                     smpl_path=smpl_path,
         | 
| 99 | 
            +
                                                     data_path=path_config.hybrik_data_dir)
         | 
| 100 | 
            +
                        self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'),
         | 
| 101 | 
            +
                                                 strict=False)
         | 
| 102 | 
            +
                        self.hps.to(self.device)
         | 
| 103 | 
            +
                    elif self.hps_type == 'bev':
         | 
| 104 | 
            +
                        try:
         | 
| 105 | 
            +
                            import bev
         | 
| 106 | 
            +
                        except:
         | 
| 107 | 
            +
                            print('Could not find bev, installing via pip install --upgrade simple-romp')
         | 
| 108 | 
            +
                            os.system('pip install simple-romp==1.0.3')
         | 
| 109 | 
            +
                            import bev
         | 
| 110 | 
            +
                        settings = bev.main.default_settings
         | 
| 111 | 
            +
                        # change the argparse settings of bev here if you prefer other settings.
         | 
| 112 | 
            +
                        settings.mode = 'image'
         | 
| 113 | 
            +
                        settings.GPU = int(str(self.device).split(':')[1])
         | 
| 114 | 
            +
                        settings.show_largest = True
         | 
| 115 | 
            +
                        # settings.show = True # uncommit this to show the original BEV predictions
         | 
| 116 | 
            +
                        self.hps = bev.BEV(settings)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    self.detector=detection.maskrcnn_resnet50_fpn(pretrained=True)
         | 
| 119 | 
            +
                    self.detector.eval()
         | 
| 120 | 
            +
                    print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    self.render = Render(size=512, device=device)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
    def __len__(self):
        return len(self.subject_list)

    def compute_vis_cmap(self, smpl_verts, smpl_faces):

        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
        smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            'smpl_verts': smpl_verts.unsqueeze(0)
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale):

        smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
        tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz')
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])

        verts = np.concatenate([smpl_model.verts, smpl_model.verts_added],
                               axis=0) * scale.item() + trans.detach().cpu().numpy()
        faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts,
                       ((0, pad_v_num),
                        (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant',
                       constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

        voxel_dict = {
            'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num': torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
        }

        return voxel_dict

    def __getitem__(self, index):

        img_path = self.subject_list[index]
        img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
        print(img_name)
        # smplx_param_path = f'./data/thuman2/smplx/{img_name[:-2]}.pkl'
        # smplx_param = np.load(smplx_param_path, allow_pickle=True)

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.hps_type, 512, self.device)

            data_dict = {
                'name': img_name,
                'image': img_icon.to(self.device).unsqueeze(0),
                'ori_image': img_ori,
                'mask': img_mask,
                'uncrop_param': uncrop_param
            }

        else:
            img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
                img_path,
                self.hps_type,
                512,
                self.device,
                seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
            data_dict = {
                'name': img_name,
                'image': img_icon.to(self.device).unsqueeze(0),
                'ori_image': img_ori,
                'mask': img_mask,
                'uncrop_param': uncrop_param,
                'segmentations': segmentations
            }

        arr_dict = econ_process_image(img_path, self.hps_type, True, 512, self.detector)
        data_dict['hands_visibility'] = arr_dict['hands_visibility']

        with torch.no_grad():
            # import ipdb; ipdb.set_trace()
            preds_dict = self.hps.forward(img_hps)

        data_dict['smpl_faces'] = torch.Tensor(self.faces.astype(np.int64)).long().unsqueeze(0).to(
            self.device)

        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            data_dict['betas'] = output['pred_shape']
            data_dict['body_pose'] = output['rotmat'][:, 1:]
            data_dict['global_orient'] = output['rotmat'][:, 0:1]
            data_dict['smpl_verts'] = output['verts']    # not sure whether the scale is the same
            data_dict["type"] = "smpl"

        elif self.hps_type == 'pare':
            data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
            data_dict["type"] = "smpl"

        elif self.hps_type == 'pixie':
            data_dict.update(preds_dict)
            data_dict['body_pose'] = preds_dict['body_pose']
            data_dict['global_orient'] = preds_dict['global_pose']
            data_dict['betas'] = preds_dict['shape']
            data_dict['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]
            data_dict["type"] = "smplx"

        elif self.hps_type == 'hybrik':
            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2
            data_dict["type"] = "smpl"

        elif self.hps_type == 'bev':
            data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to(
                self.device).float()
            pred_thetas = batch_rodrigues(
                torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
            data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
            data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to(
                self.device).float()
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1
            data_dict["type"] = "smpl"

        data_dict['scale'] = scale
        data_dict['trans'] = torch.tensor([tranX, tranY, 0.0]).unsqueeze(0).to(self.device).float()

        # data_dict info (key-shape):
        # scale, tranX, tranY - tensor.float
        # betas - [1, 10] / [1, 200]
        # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
        # global_orient - [1, 1, 3, 3]
        # smpl_verts - [1, 6890, 3] / [1, 10475, 3]

        # from rot_mat to rot_6d for better optimization
        N_body = data_dict["body_pose"].shape[1]
        data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1)
        data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1)

        return data_dict

    def render_normal(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])

    def visualize_alignment(self, data):

        import vedo
        import trimesh

        if self.hps_type != 'pixie':
            smpl_out = self.smpl_model(betas=data['betas'],
                                       body_pose=data['body_pose'],
                                       global_orient=data['global_orient'],
                                       pose2rot=False)
            smpl_verts = ((smpl_out.vertices + data['trans']) *
                          data['scale']).detach().cpu().numpy()[0]
        else:
            smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
                                               expression_params=data['exp'],
                                               body_pose=data['body_pose'],
                                               global_pose=data['global_orient'],
                                               jaw_pose=data['jaw_pose'],
                                               left_hand_pose=data['left_hand_pose'],
                                               right_hand_pose=data['right_hand_pose'])

            smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0]

        smpl_verts *= np.array([1.0, -1.0, -1.0])
        faces = data['smpl_faces'][0].detach().cpu().numpy()

        image_P = data['image']
        image_F, image_B = self.render_normal(smpl_verts, faces)

        # create plot
        vp = vedo.Plotter(title="", size=(1500, 1500))
        vis_list = []

        image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)

        vis_list.append(
            vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos(
                -1.0, -1.0, 1.0))
        vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5))
        vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0))

        # create a mesh
        mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
        mesh.visual.vertex_colors = [200, 200, 0]
        vis_list.append(mesh)

        vp.show(*vis_list, bg="white", axes=1, interactive=True)

if __name__ == '__main__':

    cfg.merge_from_file("./configs/icon-filter.yaml")
    cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')

    cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    device = torch.device('cuda:0')

    dataset = SMPLDataset(
        {
            'image_dir': "./examples",
            'has_det': True,  # w/ or w/o detection
            'hps_type': 'bev'  # pymaf/pare/pixie/hybrik/bev
        },
        device)

    for i in range(len(dataset)):
        dataset.visualize_alignment(dataset[i])
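
The __getitem__ method above keeps only the first two columns of each 3x3 rotation before optimization ("from rot_mat to rot_6d"), and compute_voxel_verts maps the optimized pose back with the repository's rot6d_to_rotmat. As a reference, here is a minimal, self-contained sketch of that continuous 6-D rotation representation and its Gram-Schmidt inverse. The helper names (rotmat_to_rot6d, rot6d_to_rotmat_gs) and the column-concatenated flattening order are assumptions made for illustration; the exact layout in this codebase is whatever its own rot6d_to_rotmat expects.

import torch
import torch.nn.functional as F

def rotmat_to_rot6d(R):
    # (..., 3, 3) -> (..., 6): keep the first two columns and concatenate them
    # (illustrative flattening order, not necessarily the repo's)
    return R[..., :, :2].transpose(-1, -2).reshape(*R.shape[:-2], 6)

def rot6d_to_rotmat_gs(x):
    # (..., 6) -> (..., 3, 3): Gram-Schmidt on the two stored columns,
    # third column recovered by a cross product
    a1, a2 = x[..., 0:3], x[..., 3:6]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)

# round trip on a small batch of rotations
R = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(rot6d_to_rotmat_gs(rotmat_to_rot6d(R)), R, atol=1e-6)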
    	
        examples/02986d0998ce01aa0aa67a99fbd1e09a.png
    ADDED

        examples/16171.png
    ADDED

        examples/26d2e846349647ff04c536816e0e8ca1.png
    ADDED

        examples/30755.png
    ADDED

        examples/3930.png
    ADDED

        examples/4656716-3016170581.png
    ADDED

        examples/663dcd6db19490de0b790da430bd5681.png
    ADDED (Git LFS)

        examples/7332.png
    ADDED

        examples/85891251f52a2399e660a63c2a7fdf40.png
    ADDED

        examples/a689a48d23d6b8d58d67ff5146c6e088.png
    ADDED

        examples/b0d178743c7e3e09700aaee8d2b1ec47.png
    ADDED

        examples/case5.png
    ADDED

        examples/d40776a1e1582179d97907d36f84d776.png
    ADDED

        examples/durant.png
    ADDED

        examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png
    ADDED

        examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png
    ADDED

        examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png
    ADDED

        examples/pexels-barbara-olsen-7869640.png
    ADDED

        examples/pexels-julia-m-cameron-4145040.png
    ADDED

        examples/pexels-marta-wave-6437749.png
    ADDED

        examples/pexels-photo-6311555-removebg.png
    ADDED

        examples/pexels-zdmit-6780091.png
    ADDED
    	
        inference.py
    ADDED
    
    @@ -0,0 +1,221 @@
import argparse
import os
from typing import Dict, Optional, Tuple, List
from omegaconf import OmegaConf
from PIL import Image
from dataclasses import dataclass
from collections import defaultdict
import torch
import torch.utils.checkpoint
from torchvision.utils import make_grid, save_image
from accelerate.utils import set_seed
from tqdm.auto import tqdm
import torch.nn.functional as F
from einops import rearrange
from rembg import remove, new_session
import pdb
from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
from econdataset import SMPLDataset
from reconstruct import ReMesh

providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kSameAsRequested',
        'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'HEURISTIC',
    })
]
session = new_session(providers=providers)

weight_dtype = torch.float16

def tensor_to_numpy(tensor):
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int
    # save_single_views: bool
    save_mode: str
    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: float
    validation_grid_nrow: int

    num_views: int
    enable_xformers_memory_efficient_attention: bool
    with_smpl: Optional[bool]

    recon_opt: Dict


def convert_to_numpy(tensor):
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

def convert_to_pil(tensor):
    return Image.fromarray(convert_to_numpy(tensor))

def save_image(tensor, fp):
    ndarr = convert_to_numpy(tensor)
    # pdb.set_trace()
    save_image_numpy(ndarr, fp)
    return ndarr

def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)
    im.save(fp)

def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    images_cond, pred_cat = [], defaultdict(list)
    for case_id, batch in tqdm(enumerate(dataloader)):
        images_cond.append(batch['imgs_in'][:, 0])

        imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")  # (B*Nv, 3, H, W)
        if cfg.with_smpl:
            smpl_in = torch.cat([batch['smpl_imgs_in']] * 2, dim=0)
            smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
        else:
            smpl_in = None

        normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")

        with torch.autocast("cuda"):
            # B*Nv images
            guidance_scale = cfg.validation_guidance_scales
            unet_out = pipeline(
                imgs_in, None, prompt_embeds=prompt_embeddings,
                dino_feature=None, smpl_in=smpl_in,
                generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1,
                **cfg.pipe_validation_kwargs
            )

            out = unet_out.images
            bsz = out.shape[0] // 2

            normals_pred = out[:bsz]
            images_pred = out[bsz:]
            if cfg.save_mode == 'concat':  ## save concatenated color and normal---------------------
                pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1))  # b, 3, h, w
                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
                os.makedirs(cur_dir, exist_ok=True)
                for i in range(bsz // num_views):
                    scene = batch['filename'][i].split('.')[0]

                    img_in_ = images_cond[-1][i].to(out.device)
                    vis_ = [img_in_]
                    for j in range(num_views):
                        idx = i * num_views + j
                        normal = normals_pred[idx]
                        color = images_pred[idx]

                        vis_.append(color)
                        vis_.append(normal)

                    out_filename = f"{cur_dir}/{scene}.png"
                    vis_ = torch.stack(vis_, dim=0)
                    vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
                    save_image(vis_, out_filename)
            elif cfg.save_mode == 'rgb':
                for i in range(bsz // num_views):
                    scene = batch['filename'][i].split('.')[0]

                    img_in_ = images_cond[-1][i].to(out.device)
                    normals, colors = [], []
                    for j in range(num_views):
                        idx = i * num_views + j
                        normal = normals_pred[idx]
                        if j == 0:
                            color = imgs_in[0].to(out.device)
                        else:
                            color = images_pred[idx]
                        if j in [3, 4]:
                            normal = torch.flip(normal, dims=[2])
                            color = torch.flip(color, dims=[2])

                        colors.append(color)
                        if j == 6:
                            normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
                        normals.append(normal)

                        ## save color and normal---------------------
                        # normal_filename = f"normals_{view}_masked.png"
                        # rgb_filename = f"color_{view}_masked.png"
                        # save_image(normal, os.path.join(scene_dir, normal_filename))
                        # save_image(color, os.path.join(scene_dir, rgb_filename))
                    normals[0][:, :256, 256:512] = normals[-1]

                    colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
                    normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
        pose = econdata.__getitem__(case_id)
        carving.optimize_case(scene, pose, colors, normals)
        torch.cuda.empty_cache()


def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to('cuda')
    return pipeline

def main(
    cfg: TestConfig
):

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        set_seed(cfg.seed)
    pipeline = load_pshuman_pipeline(cfg)

    if cfg.with_smpl:
        from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
    else:
        from mvdiffusion.data.single_image_dataset import SingleImageDataset

    # Get the dataset
    validation_dataset = SingleImageDataset(
        **cfg.validation_dataset
    )
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers
    )
    dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')

    carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
    run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args, extras = parser.parse_known_args()
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    cfg = load_config(args.config, cli_args=extras)
    schema = OmegaConf.structured(TestConfig)
    cfg = OmegaConf.merge(schema, cfg)
    main(cfg)
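
inference.py validates its YAML config against the TestConfig dataclass by merging the parsed file into an OmegaConf structured schema (OmegaConf.structured(TestConfig) followed by OmegaConf.merge). Below is a minimal sketch of that pattern with a trimmed stand-in schema and a plain dict in place of the parsed YAML; the class name MiniTestConfig and all field values are placeholders for illustration, not the project's actual configuration.

from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf

@dataclass
class MiniTestConfig:                      # trimmed, hypothetical stand-in for TestConfig
    pretrained_model_name_or_path: str = "path/to/pretrained_model"   # placeholder value
    save_dir: str = "outputs"
    seed: Optional[int] = None
    validation_batch_size: int = 1
    num_views: int = 6
    with_smpl: Optional[bool] = True

schema = OmegaConf.structured(MiniTestConfig)
# inference.py gets the right-hand side from load_config(args.config, cli_args=extras);
# a plain dict stands in for the parsed YAML here.
cfg = OmegaConf.merge(schema, {"seed": 42, "num_views": 7, "with_smpl": True})
print(cfg.seed, cfg.num_views)   # merged values, type-checked against the dataclass fields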
    	
        lib/__init__.py
    ADDED
            File without changes

        lib/common/__init__.py
    ADDED
            File without changes

        lib/common/cloth_extraction.py
    ADDED
    @@ -0,0 +1,182 @@
            import numpy as np
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import itertools
         | 
| 5 | 
            +
            import trimesh
         | 
| 6 | 
            +
            from matplotlib.path import Path
         | 
| 7 | 
            +
            from collections import Counter
         | 
| 8 | 
            +
            from sklearn.neighbors import KNeighborsClassifier
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def load_segmentation(path, shape):
         | 
| 12 | 
            +
                """
         | 
| 13 | 
            +
                Get a segmentation mask for a given image
         | 
| 14 | 
            +
                Arguments:
         | 
| 15 | 
            +
                    path: path to the segmentation json file
         | 
| 16 | 
            +
                    shape: shape of the output mask
         | 
| 17 | 
            +
                Returns:
         | 
| 18 | 
            +
                    Returns a segmentation mask
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                with open(path) as json_file:
         | 
| 21 | 
            +
                    dict = json.load(json_file)
         | 
| 22 | 
            +
                    segmentations = []
         | 
| 23 | 
            +
                    for key, val in dict.items():
         | 
| 24 | 
            +
                        if not key.startswith('item'):
         | 
| 25 | 
            +
                            continue
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                        # Each item can have multiple polygons. Combine them to one
         | 
| 28 | 
            +
                        # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
         | 
| 29 | 
            +
                        # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                        coordinates = []
         | 
| 32 | 
            +
                        for segmentation_coord in val['segmentation']:
         | 
| 33 | 
            +
                            # The format before is [x1,y1, x2, y2, ....]
         | 
| 34 | 
            +
                            x = segmentation_coord[::2]
         | 
| 35 | 
            +
                            y = segmentation_coord[1::2]
         | 
| 36 | 
            +
                            xy = np.vstack((x, y)).T
         | 
| 37 | 
            +
                            coordinates.append(xy)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                        segmentations.append({
         | 
| 40 | 
            +
                            'type': val['category_name'],
         | 
| 41 | 
            +
                            'type_id': val['category_id'],
         | 
| 42 | 
            +
                            'coordinates': coordinates
         | 
| 43 | 
            +
                        })
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    return segmentations
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def smpl_to_recon_labels(recon, smpl, k=1):
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
         | 
| 51 | 
            +
                Arguments:
         | 
| 52 | 
            +
                    recon: trimesh object (fully clothed model)
         | 
| 53 | 
            +
                    shape: trimesh object (smpl model)
         | 
| 54 | 
            +
                    k: number of nearest neighbours to use
         | 
| 55 | 
            +
                Returns:
         | 
| 56 | 
            +
                    Returns a dictionary containing the bodypart and the corresponding indices
         | 
| 57 | 
            +
                """
         | 
| 58 | 
            +
                smpl_vert_segmentation = json.load(
         | 
| 59 | 
            +
                    open(
         | 
| 60 | 
            +
                        os.path.join(os.path.dirname(__file__),
         | 
| 61 | 
            +
                                     'smpl_vert_segmentation.json')))
         | 
| 62 | 
            +
                n = smpl.vertices.shape[0]
         | 
| 63 | 
            +
                y = np.array([None] * n)
         | 
| 64 | 
            +
                for key, val in smpl_vert_segmentation.items():
         | 
| 65 | 
            +
                    y[val] = key
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                classifier = KNeighborsClassifier(n_neighbors=1)
         | 
| 68 | 
            +
                classifier.fit(smpl.vertices, y)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                y_pred = classifier.predict(recon.vertices)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                recon_labels = {}
         | 
| 73 | 
            +
                for key in smpl_vert_segmentation.keys():
         | 
| 74 | 
            +
                    recon_labels[key] = list(
         | 
| 75 | 
            +
                        np.argwhere(y_pred == key).flatten().astype(int))
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                return recon_labels
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def extract_cloth(recon, segmentation, K, R, t, smpl=None):
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                Extract a portion of a mesh using 2d segmentation coordinates
         | 
| 83 | 
            +
                Arguments:
         | 
| 84 | 
            +
                    recon: fully clothed mesh
         | 
| 85 | 
            +
                    seg_coord: segmentation coordinates in 2D (NDC)
         | 
| 86 | 
            +
                    K: intrinsic matrix of the projection
         | 
| 87 | 
            +
                    R: rotation matrix of the projection
         | 
| 88 | 
            +
                    t: translation vector of the projection
         | 
| 89 | 
            +
                Returns:
         | 
| 90 | 
            +
                    Returns a submesh using the segmentation coordinates
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                seg_coord = segmentation['coord_normalized']
         | 
| 93 | 
            +
                mesh = trimesh.Trimesh(recon.vertices, recon.faces)
         | 
| 94 | 
            +
                extrinsic = np.zeros((3, 4))
         | 
| 95 | 
            +
                extrinsic[:3, :3] = R
         | 
| 96 | 
            +
                extrinsic[:, 3] = t
         | 
| 97 | 
            +
                P = K[:3, :3] @ extrinsic
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                P_inv = np.linalg.pinv(P)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                # Each segmentation can contain multiple polygons
         | 
| 102 | 
            +
                # We need to check them separately
         | 
| 103 | 
            +
                points_so_far = []
         | 
| 104 | 
            +
                faces = recon.faces
         | 
| 105 | 
            +
                for polygon in seg_coord:
         | 
| 106 | 
            +
                    n = len(polygon)
         | 
| 107 | 
            +
                    coords_h = np.hstack((polygon, np.ones((n, 1))))
         | 
| 108 | 
            +
                    # Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates
         | 
| 109 | 
            +
                    XYZ = P_inv @ coords_h[:, :, None]
         | 
| 110 | 
            +
                    XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
         | 
| 111 | 
            +
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]

        p = Path(XYZ[:, :2])

        grid = p.contains_points(recon.vertices[:, :2])
        indices = np.argwhere(grid)
        points_so_far += list(indices.flatten())

    if smpl is not None:
        num_verts = recon.vertices.shape[0]
        recon_labels = smpl_to_recon_labels(recon, smpl)
        body_parts_to_remove = [
            'rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
            'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand'
        ]
        type_id = segmentation['type_id']

        # Remove additional body parts that are most likely not part of the
        # segmentation but might intersect it (e.g. a hand in front of the torso).
        # Category ids follow https://github.com/switchablenorms/DeepFashion2
        # Short sleeve clothes
        if type_id in (1, 3, 10):
            body_parts_to_remove += ['leftForeArm', 'rightForeArm']
        # No sleeves at all or lower body clothes
        elif type_id in (5, 6, 12, 13, 8, 9):
            body_parts_to_remove += [
                'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm'
            ]
        # Shorts
        elif type_id == 7:
            body_parts_to_remove += [
                'leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm',
                'leftArm', 'rightArm'
            ]

        verts_to_remove = list(
            itertools.chain.from_iterable(
                [recon_labels[part] for part in body_parts_to_remove]))

        label_mask = np.zeros(num_verts, dtype=bool)
        label_mask[verts_to_remove] = True

        seg_mask = np.zeros(num_verts, dtype=bool)
        seg_mask[points_so_far] = True

        # Remove points that belong to other body parts: if a vertex selected by
        # the 2D segmentation is also labelled as a body part to remove, drop it.
        extra_verts_to_remove = np.logical_and(seg_mask, label_mask)

        combine_mask = np.zeros(num_verts, dtype=bool)
        combine_mask[points_so_far] = True
        combine_mask[extra_verts_to_remove] = False

        all_indices = np.argwhere(combine_mask).flatten()

    i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
    i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
    i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]

    faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
    mask = np.zeros(len(recon.faces), dtype=bool)
    if len(faces_to_keep) > 0:
        mask[faces_to_keep] = True

        mesh.update_faces(mask)
        mesh.remove_unreferenced_vertices()

        # mesh.rezero()

        return mesh

    return None
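The vertex-mask bookkeeping above is easy to misread, so here is a minimal, self-contained numpy sketch of the same selection logic, using toy vertex counts and made-up indices (nothing below comes from the real pipeline):

# Toy illustration of how seg_mask and label_mask combine above.
import numpy as np

num_verts = 10
points_so_far = [1, 2, 3, 7]     # vertices inside the 2D segmentation polygon(s)
verts_to_remove = [3, 7, 8]      # vertices labelled as hands / feet / head, etc.

label_mask = np.zeros(num_verts, dtype=bool)
label_mask[verts_to_remove] = True
seg_mask = np.zeros(num_verts, dtype=bool)
seg_mask[points_so_far] = True

combine_mask = seg_mask & ~np.logical_and(seg_mask, label_mask)
print(np.argwhere(combine_mask).flatten())    # [1 2] -> vertices kept for the cloth patch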
    	
        lib/common/config.py
    ADDED
    
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from yacs.config import CfgNode as CN
import os

_C = CN(new_allowed=True)

# needed by trainer
_C.name = 'default'
_C.gpus = [0]
_C.test_gpus = [1]
_C.root = "./data/"
_C.ckpt_dir = './data/ckpt/'
_C.resume_path = ''
_C.normal_path = ''
_C.corr_path = ''
_C.results_path = './data/results/'
_C.projection_mode = 'orthogonal'
_C.num_views = 1
_C.sdf = False
_C.sdf_clip = 5.0

_C.lr_G = 1e-3
_C.lr_C = 1e-3
_C.lr_N = 2e-4
_C.weight_decay = 0.0
_C.momentum = 0.0
_C.optim = 'Adam'
_C.schedule = [5, 10, 15]
_C.gamma = 0.1

_C.overfit = False
_C.resume = False
_C.test_mode = False
_C.test_uv = False
_C.draw_geo_thres = 0.60
_C.num_sanity_val_steps = 2
_C.fast_dev = 0
_C.get_fit = False
_C.agora = False
_C.optim_cloth = False
_C.optim_body = False
_C.mcube_res = 256
_C.clean_mesh = True
_C.remesh = False

_C.batch_size = 4
_C.num_threads = 8

_C.num_epoch = 10
_C.freq_plot = 0.01
_C.freq_show_train = 0.1
_C.freq_show_val = 0.2
_C.freq_eval = 0.5
_C.accu_grad_batch = 4

_C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']

_C.net = CN()
_C.net.gtype = 'HGPIFuNet'
_C.net.ctype = 'resnet18'
_C.net.classifierIMF = 'MultiSegClassifier'
_C.net.netIMF = 'resnet18'
_C.net.norm = 'group'
_C.net.norm_mlp = 'group'
_C.net.norm_color = 'group'
_C.net.hg_down = 'conv128'  # 'ave_pool'
_C.net.num_views = 1

# kernel_size, stride, dilation, padding

_C.net.conv1 = [7, 2, 1, 3]
_C.net.conv3x3 = [3, 1, 1, 1]

_C.net.num_stack = 4
_C.net.num_hourglass = 2
_C.net.hourglass_dim = 256
_C.net.voxel_dim = 32
_C.net.resnet_dim = 120
_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
_C.net.res_layers = [2, 3, 4]
_C.net.filter_dim = 256
_C.net.smpl_dim = 3

_C.net.cly_dim = 3
_C.net.soft_dim = 64
_C.net.z_size = 200.0
_C.net.N_freqs = 10
_C.net.geo_w = 0.1
_C.net.norm_w = 0.1
_C.net.dc_w = 0.1
_C.net.C_cat_to_G = False

_C.net.skip_hourglass = True
_C.net.use_tanh = False
_C.net.soft_onehot = True
_C.net.no_residual = False
_C.net.use_attention = False

_C.net.prior_type = "sdf"
_C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
_C.net.use_filter = True
_C.net.use_cc = False
_C.net.use_PE = False
_C.net.use_IGR = False
_C.net.in_geo = ()
_C.net.in_nml = ()

_C.dataset = CN()
_C.dataset.root = ''
_C.dataset.set_splits = [0.95, 0.04]
_C.dataset.types = [
    "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
]
_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
_C.dataset.rp_type = "pifu900"
_C.dataset.th_type = 'train'
_C.dataset.input_size = 512
_C.dataset.rotation_num = 3
_C.dataset.num_sample_ray = 128  # volume rendering
_C.dataset.num_precomp = 10  # Number of segmentation classifiers
_C.dataset.num_multiseg = 500  # Number of categories per classifier
_C.dataset.num_knn = 10  # for loss/error
_C.dataset.num_knn_dis = 20  # for accuracy
_C.dataset.num_verts_max = 20000
_C.dataset.zray_type = False
_C.dataset.online_smpl = False
_C.dataset.noise_type = ['z-trans', 'pose', 'beta']
_C.dataset.noise_scale = [0.0, 0.0, 0.0]
_C.dataset.num_sample_geo = 10000
_C.dataset.num_sample_color = 0
_C.dataset.num_sample_seg = 0
_C.dataset.num_sample_knn = 10000

_C.dataset.sigma_geo = 5.0
_C.dataset.sigma_color = 0.10
_C.dataset.sigma_seg = 0.10
_C.dataset.thickness_threshold = 20.0
_C.dataset.ray_sample_num = 2
_C.dataset.semantic_p = False
_C.dataset.remove_outlier = False

_C.dataset.train_bsize = 1.0
_C.dataset.val_bsize = 1.0
_C.dataset.test_bsize = 1.0


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()


# Alternatively, provide a way to import the defaults as
# a global singleton:
cfg = _C  # users can `from config import cfg`

# cfg = get_cfg_defaults()
# cfg.merge_from_file('./configs/example.yaml')

# # Now override from a list (opts could come from the command line)
# opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
# cfg.merge_from_list(opts)


def update_cfg(cfg_file):
    # cfg = get_cfg_defaults()
    _C.merge_from_file(cfg_file)
    # return cfg.clone()
    return _C


def parse_args(args):
    cfg_file = args.cfg_file
    if args.cfg_file is not None:
        cfg = update_cfg(args.cfg_file)
    else:
        cfg = get_cfg_defaults()

    # if args.misc is not None:
    #     cfg.merge_from_list(args.misc)

    return cfg


def parse_args_extend(args):
    if args.resume:
        if not os.path.exists(args.log_dir):
            raise ValueError(
                'Experiment is set to resume mode, but log directory does not exist.'
            )

        # load log's cfg
        cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
        cfg = update_cfg(cfg_file)

        if args.misc is not None:
            cfg.merge_from_list(args.misc)
    else:
        parse_args(args)
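A minimal usage sketch of the config module above (the YAML path is the illustrative one already mentioned in the file's own comments; any config file with matching keys would do):

from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()                         # clone of the defaults defined above
cfg.merge_from_file('./configs/example.yaml')    # hypothetical YAML overriding some defaults
print(cfg.name, cfg.net.mlp_dim, cfg.dataset.input_size)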
    	
        lib/common/imutils.py
    ADDED
    
@@ -0,0 +1,364 @@
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
import mediapipe as mp
import torch
import numpy as np
import torch.nn.functional as F
from PIL import Image
from lib.pymafx.core import constants
from rembg import remove
# from rembg.session_factory import new_session
from torchvision import transforms
from kornia.geometry.transform import get_affine_matrix2d, warp_affine


def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
    all_ops = []
    if res is not None:
        all_ops.append(transforms.Resize(size=res))
    if not is_tensor:
        all_ops.append(transforms.ToTensor())
    if mean is not None and std is not None:
        all_ops.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(all_ops)


def get_affine_matrix_wh(w1, h1, w2, h2):

    transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)
    center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
    scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0)
    M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))

    return M


def get_affine_matrix_box(boxes, w2, h2):

    # boxes [left, top, right, bottom]
    width = boxes[:, 2] - boxes[:, 0]    # (N,)
    height = boxes[:, 3] - boxes[:, 1]    # (N,)
    center = torch.tensor(
        [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
    ).T    # (N,2)
    scale = torch.min(torch.tensor([w2 / width, h2 / height]),
                      dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9    # (N,2)
    transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1)    # (N,2)
    M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.] * transl.shape[0]))

    return M


def load_img(img_file):

    if img_file.endswith("exr"):
        img = cv2.imread(img_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    else:
        img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)

    # considering non 8-bit image
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if not img_file.endswith("png"):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)

    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]


def get_keypoints(image):
    def collect_xyv(x, body=True):
        lmk = x.landmark
        all_lmks = []
        for i in range(len(lmk)):
            visibility = lmk[i].visibility if body else 1.0
            all_lmks.append(torch.Tensor([lmk[i].x, lmk[i].y, lmk[i].z, visibility]))
        return torch.stack(all_lmks).view(-1, 4)

    mp_holistic = mp.solutions.holistic

    with mp_holistic.Holistic(
        static_image_mode=True,
        model_complexity=2,
    ) as holistic:
        results = holistic.process(image)

    fake_kps = torch.zeros(33, 4)

    result = {}
    result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
    result["lhand"] = collect_xyv(
        results.left_hand_landmarks, False
    ) if results.left_hand_landmarks else fake_kps
    result["rhand"] = collect_xyv(
        results.right_hand_landmarks, False
    ) if results.right_hand_landmarks else fake_kps
    result["face"] = collect_xyv(
        results.face_landmarks, False
    ) if results.face_landmarks else fake_kps

    return result


def get_pymafx(image, landmarks):

    # image [3,512,512]

    item = {
        'img_body':
            F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
    }

    for part in ['lhand', 'rhand', 'face']:
        kp2d = landmarks[part]
        kp2d_valid = kp2d[kp2d[:, 3] > 0.]
        if len(kp2d_valid) > 0:
            bbox = [
                min(kp2d_valid[:, 0]),
                min(kp2d_valid[:, 1]),
                max(kp2d_valid[:, 0]),
                max(kp2d_valid[:, 1])
            ]
            center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
            scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2

        # handle invalid part keypoints
        if len(kp2d_valid) < 1 or scale_part < 0.01:
            center_part = [0, 0]
            scale_part = 0.5
            kp2d[:, 3] = 0

        center_part = torch.tensor(center_part).float()

        theta_part = torch.zeros(1, 2, 3)
        theta_part[:, 0, 0] = scale_part
        theta_part[:, 1, 1] = scale_part
        theta_part[:, :, -1] = center_part

        grid = F.affine_grid(theta_part, torch.Size([1, 3, 224, 224]), align_corners=False)
        img_part = F.grid_sample(image.unsqueeze(0), grid, align_corners=False).squeeze(0).float()

        item[f'img_{part}'] = img_part

        theta_i_inv = torch.zeros_like(theta_part)
        theta_i_inv[:, 0, 0] = 1. / theta_part[:, 0, 0]
        theta_i_inv[:, 1, 1] = 1. / theta_part[:, 1, 1]
        theta_i_inv[:, :, -1] = -theta_part[:, :, -1] / theta_part[:, 0, 0].unsqueeze(-1)
        item[f'{part}_theta_inv'] = theta_i_inv[0]

    return item


def remove_floats(mask):

    # 1. find all the contours
    # 2. fillPoly "True" for the largest one
    # 3. fillPoly "False" for its children

    new_mask = np.zeros(mask.shape)
    cnts, hier = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    cnt_index = sorted(range(len(cnts)), key=lambda k: cv2.contourArea(cnts[k]), reverse=True)
    body_cnt = cnts[cnt_index[0]]
    childs_cnt_idx = np.where(np.array(hier)[0, :, -1] == cnt_index[0])[0]
    childs_cnt = [cnts[idx] for idx in childs_cnt_idx]
    cv2.fillPoly(new_mask, [body_cnt], 1)
    cv2.fillPoly(new_mask, childs_cnt, 0)

    return new_mask


def econ_process_image(img_file, hps_type, single, input_res, detector):

    img_raw, (in_height, in_width) = load_img(img_file)
    tgt_res = input_res * 2
    M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
    img_square = warp_affine(
        img_raw,
        M_square[:, :2], (tgt_res, ) * 2,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )

    # detection for bbox
    predictions = detector(img_square / 255.)[0]

    if single:
        top_score = predictions["scores"][predictions["labels"] == 1].max()
        human_ids = torch.where(predictions["scores"] == top_score)[0]
    else:
        human_ids = torch.logical_and(predictions["labels"] == 1,
                                      predictions["scores"] > 0.9).nonzero().squeeze(1)

    boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
    masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()

    M_crop = get_affine_matrix_box(boxes, input_res, input_res)

    img_icon_lst = []
    img_crop_lst = []
    img_hps_lst = []
    img_mask_lst = []
    landmark_lst = []
    hands_visibility_lst = []
    img_pymafx_lst = []

    uncrop_param = {
        "ori_shape": [in_height, in_width],
        "box_shape": [input_res, input_res],
        "square_shape": [tgt_res, tgt_res],
        "M_square": M_square,
        "M_crop": M_crop
    }

    for idx in range(len(boxes)):

        # mask out the pixels of others
        if len(masks) > 1:
            mask_detection = (masks[np.arange(len(masks)) != idx]).max(axis=0)
        else:
            mask_detection = masks[0] * 0.

        img_square_rgba = torch.cat(
            [img_square.squeeze(0).permute(1, 2, 0),
             torch.tensor(mask_detection < 0.4) * 255],
            dim=2
        )

        img_crop = warp_affine(
            img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
            M_crop[idx:idx + 1, :2], (input_res, ) * 2,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)

        # get accurate person segmentation mask
        img_rembg = remove(img_crop)  # post_process_mask=True)
        img_mask = remove_floats(img_rembg[:, :, [3]])

        mean_icon = std_icon = (0.5, 0.5, 0.5)
        img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
        img_icon = transform_to_tensor(512, mean_icon, std_icon)(
            Image.fromarray(img_np)
        ) * torch.tensor(img_mask).permute(2, 0, 1)
        img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN,
                                      constants.IMG_NORM_STD)(Image.fromarray(img_np))

        landmarks = get_keypoints(img_np)

        # get hands visibility
        hands_visibility = [True, True]
        if landmarks['lhand'][:, -1].mean() == 0.:
            hands_visibility[0] = False
        if landmarks['rhand'][:, -1].mean() == 0.:
            hands_visibility[1] = False
        hands_visibility_lst.append(hands_visibility)

        if hps_type == 'pymafx':
            img_pymafx_lst.append(
                get_pymafx(
                    transform_to_tensor(512, constants.IMG_NORM_MEAN,
                                        constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks
                )
            )

        img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
        img_icon_lst.append(img_icon)
        img_hps_lst.append(img_hps)
        img_mask_lst.append(torch.tensor(img_mask[..., 0]))
        landmark_lst.append(landmarks['body'])

    # required image tensors / arrays

    # img_icon  (tensor): (-1, 1),          [3,512,512]
    # img_hps   (tensor): (-2.11, 2.44),    [3,224,224]

    # img_np    (array): (0, 255),          [512,512,3]
    # img_rembg (array): (0, 255),          [512,512,4]
    # img_mask  (array): (0, 1),            [512,512,1]
    # img_crop  (array): (0, 255),          [512,512,4]

    return_dict = {
        "img_icon": torch.stack(img_icon_lst).float(),    # [N, 3, res, res]
        "img_crop": torch.stack(img_crop_lst).float(),    # [N, 4, res, res]
        "img_hps": torch.stack(img_hps_lst).float(),    # [N, 3, res, res]
        "img_raw": img_raw,    # [1, 3, H, W]
        "img_mask": torch.stack(img_mask_lst).float(),    # [N, res, res]
        "uncrop_param": uncrop_param,
        "landmark": torch.stack(landmark_lst),    # [N, 33, 4]
        "hands_visibility": hands_visibility_lst,
    }

    img_pymafx = {}

    if len(img_pymafx_lst) > 0:
        for idx in range(len(img_pymafx_lst)):
            for key in img_pymafx_lst[idx].keys():
                if key not in img_pymafx.keys():
                    img_pymafx[key] = [img_pymafx_lst[idx][key]]
                else:
                    img_pymafx[key] += [img_pymafx_lst[idx][key]]

        for key in img_pymafx.keys():
            img_pymafx[key] = torch.stack(img_pymafx[key]).float()

        return_dict.update({"img_pymafx": img_pymafx})

    return return_dict


def blend_rgb_norm(norms, data):

    # norms [N, 3, res, res]
    masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
    norm_mask = F.interpolate(
        torch.cat([norms, masks], dim=1).detach(),
        size=data["uncrop_param"]["box_shape"],
        mode="bilinear",
        align_corners=False
    )
    final = data["img_raw"].type_as(norm_mask)

    for idx in range(len(norms)):

        norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0
        mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1)

        norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
        mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)

        final = final * (1.0 - mask_ori) + norm_ori * mask_ori

    return final.detach().cpu()


def unwrap(image, uncrop_param, idx):

    device = image.device

    img_square = warp_affine(
        image,
        torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device),
        uncrop_param["square_shape"],
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )

    img_ori = warp_affine(
        img_square,
        torch.inverse(uncrop_param["M_square"])[:, :2].to(device),
        uncrop_param["ori_shape"],
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )

    return img_ori
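A minimal sketch of the preprocessing helpers above, using a square dummy image and hypothetical sizes (it assumes the repo's dependencies such as mediapipe and rembg are installed, since they are imported at module level):

from PIL import Image
from lib.common.imutils import transform_to_tensor, get_affine_matrix_wh

to_icon = transform_to_tensor(512, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img = Image.new("RGB", (512, 512))     # stand-in for a cropped, square person image
tensor = to_icon(img)                  # [3, 512, 512], roughly in (-1, 1)

# Affine matrix that letterboxes a 300x400 input onto a 1024x1024 square canvas.
M_square = get_affine_matrix_wh(300, 400, 1024, 1024)
print(tensor.shape, M_square.shape)    # torch.Size([3, 512, 512]) torch.Size([1, 3, 3])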
    	
        lib/common/render.py
    ADDED
    
    | @@ -0,0 +1,398 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from pytorch3d.renderer import (
    BlendParams,
    blending,
    look_at_view_transform,
    FoVOrthographicCameras,
    PointLights,
    RasterizationSettings,
    PointsRasterizationSettings,
    PointsRenderer,
    AlphaCompositor,
    PointsRasterizer,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    TexturesVertex,
)
from pytorch3d.structures import Meshes
from lib.dataset.mesh_util import get_visibility, get_visibility_color

import lib.common.render_utils as util
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import cv2
import math
from termcolor import colored

def image2vid(images, vid_path):
    """Write a list of same-sized PIL images to an XVID video at len(images)/5 fps."""
    w, h = images[0].size
    videodims = (w, h)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    video = cv2.VideoWriter(vid_path, fourcc, len(images) / 5.0, videodims)
    for image in images:
        video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    video.release()

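# Usage sketch (illustrative only; `frames` and the output path are placeholders,
# and the glob import is omitted):
#
#   frames = [Image.open(p) for p in sorted(glob.glob("out/frame_*.png"))]
#   image2vid(frames, "out/preview.avi")
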
def query_color(verts, faces, image, device, predicted_color):
    """Query per-vertex colors from the input image.

    Vertices visible from the front view are colored by sampling the image at
    their projected (x, y) location; occluded vertices fall back to the
    network-predicted color.

    Args:
        verts ([N, 3]): query vertices
        faces ([M, 3]): mesh faces
        image ([1, 3, H, W]): full input image, values in [-1, 1]
        predicted_color ([N, 3]): fallback per-vertex colors, values in [0, 1]

    Returns:
        torch.Tensor ([N, 3]): per-vertex colors in [0, 255]
    """

    verts = verts.float().to(device)
    faces = faces.long().to(device)
    predicted_color = predicted_color.to(device)
    (xy, z) = verts.split([2, 1], dim=1)
    visibility = get_visibility_color(xy, z, faces[:, [0, 2, 1]]).flatten()
    uv = xy.unsqueeze(0).unsqueeze(2)  # [1, N, 1, 2]
    uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
    colors = (torch.nn.functional.grid_sample(
        image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
    colors[visibility == 0.0] = (predicted_color * 255.0)[visibility == 0.0]

    return colors.detach().cpu()

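# Usage sketch (illustrative; variable names are placeholders, the image is
# expected in the [-1, 1] range, and trimesh is not imported by this module):
#
#   verts, faces = mesh.verts_packed(), mesh.faces_packed()     # [N, 3], [M, 3]
#   vert_rgb = query_color(verts, faces, in_image, device,
#                          predicted_color=pred_rgb)            # [N, 3], 0..255
#   trimesh.Trimesh(verts.cpu(), faces.cpu(),
#                   vertex_colors=vert_rgb.numpy()).export("textured.ply")
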
class cleanShader(torch.nn.Module):
    """Shader that skips lighting and only softmax-blends the sampled textures."""

    def __init__(self, device="cpu", cameras=None, blend_params=None):
        super().__init__()
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs):
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = ("Cameras must be specified either at initialization "
                   "or in the forward pass of cleanShader")
            raise ValueError(msg)

        # blend the sampled per-face textures into an RGBA image
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        images = blending.softmax_rgb_blend(texels,
                                            fragments,
                                            blend_params,
                                            znear=-256,
                                            zfar=256)

        return images

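# Usage sketch (this mirrors the "clean_mesh" branch of Render.init_renderer below;
# `camera`, `rasterizer`, `mesh`, and `device` are placeholders built elsewhere):
#
#   renderer = MeshRenderer(
#       rasterizer=rasterizer,
#       shader=cleanShader(device=device, cameras=camera,
#                          blend_params=BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))),
#   )
#   rgba = renderer(mesh)   # [B, H, W, 4], textures blended without lighting
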
class Render:

    def __init__(self, size=512, device=torch.device("cuda:0")):
        self.device = device
        self.size = size

        # camera setting
        self.dis = 100.0
        self.scale = 100.0
        self.mesh_y_center = 0.0

        self.reload_cam()

        self.type = "color"

        self.mesh = None
        self.deform_mesh = None
        self.pcd = None
        self.renderer = None
        self.meshRas = None

        self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)

    def reload_cam(self):
        # six orthographic viewpoints: +z, +x, -z, -x, top (+y), bottom (-y)
        self.cam_pos = [
            (0, self.mesh_y_center, self.dis),
            (self.dis, self.mesh_y_center, 0),
            (0, self.mesh_y_center, -self.dis),
            (-self.dis, self.mesh_y_center, 0),
            (0, self.mesh_y_center + self.dis, 0),
            (0, self.mesh_y_center - self.dis, 0),
        ]

    def get_camera(self, cam_id):

        # the top/bottom cameras (ids 4 and 5) look along the y-axis,
        # so they need a different up vector than the side views
        if cam_id in (4, 5):
            R, T = look_at_view_transform(
                eye=[self.cam_pos[cam_id]],
                at=((0, self.mesh_y_center, 0), ),
                up=((0, 0, 1), ),
            )
        else:
            R, T = look_at_view_transform(
                eye=[self.cam_pos[cam_id]],
                at=((0, self.mesh_y_center, 0), ),
                up=((0, 1, 0), ),
            )

        camera = FoVOrthographicCameras(
            device=self.device,
            R=R,
            T=T,
            znear=100.0,
            zfar=-100.0,
            max_y=100.0,
            min_y=-100.0,
            max_x=100.0,
            min_x=-100.0,
            scale_xyz=(self.scale * np.ones(3), ),
        )

        return camera

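    # Usage sketch (illustrative): cam_id indexes self.cam_pos, so 0-3 are the
    # side views and 4/5 are top/bottom.
    #
    #   cam = self.get_camera(0)                        # front orthographic camera
    #   self.init_renderer(cam, "clean_mesh", "gray")
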
    def init_renderer(self, camera, type="clean_mesh", bg="gray"):

        if "mesh" in type:

            # rasterizer
            self.raster_settings_mesh = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4) * 1e-7,
                faces_per_pixel=30,
            )
            self.meshRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_mesh)

        if bg == "black":
            blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
        elif bg == "white":
            blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
        elif bg == "gray":
            blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))

        if type == "ori_mesh":

            lights = PointLights(
                device=self.device,
                ambient_color=((0.8, 0.8, 0.8), ),
                diffuse_color=((0.2, 0.2, 0.2), ),
                specular_color=((0.0, 0.0, 0.0), ),
                location=[[0.0, 200.0, 0.0]],
            )

            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=SoftPhongShader(
                    device=self.device,
                    cameras=camera,
                    lights=None,
                    blend_params=blendparam,
                ),
            )

        if type == "silhouette":
            self.raster_settings_silhouette = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
                faces_per_pixel=50,
                cull_backfaces=True,
            )

            self.silhouetteRas = MeshRasterizer(
                cameras=camera,
                raster_settings=self.raster_settings_silhouette)
            self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
                                         shader=SoftSilhouetteShader())

        if type == "pointcloud":
            self.raster_settings_pcd = PointsRasterizationSettings(
                image_size=self.size, radius=0.006, points_per_pixel=10)

            self.pcdRas = PointsRasterizer(
                cameras=camera, raster_settings=self.raster_settings_pcd)
            self.renderer = PointsRenderer(
                rasterizer=self.pcdRas,
                compositor=AlphaCompositor(background_color=(0, 0, 0)),
            )

        if type == "clean_mesh":

            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=cleanShader(device=self.device,
                                   cameras=camera,
                                   blend_params=blendparam),
            )

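    # Usage sketch (illustrative): the renderer type and background must be ones
    # handled above ("ori_mesh", "silhouette", "pointcloud", "clean_mesh").
    #
    #   self.init_renderer(self.get_camera(0), "silhouette")
    #   mask = self.renderer(self.meshes[0])[0, :, :, 3]   # soft alpha in [0, 1]
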
    def VF2Mesh(self, verts, faces, vertex_texture=None):
        """Wrap (verts, faces) into a pytorch3d Meshes object.

        Without vertex_texture, vertices are colored by their normals mapped
        to [0, 1]; otherwise the given per-vertex texture is used.
        """
        if not torch.is_tensor(verts):
            verts = torch.tensor(verts)
        if not torch.is_tensor(faces):
            faces = torch.tensor(faces)

        if verts.ndimension() == 2:
            verts = verts.unsqueeze(0).float()
        if faces.ndimension() == 2:
            faces = faces.unsqueeze(0).long()

        verts = verts.to(self.device)
        faces = faces.to(self.device)
        if vertex_texture is not None:
            vertex_texture = vertex_texture.to(self.device)

        mesh = Meshes(verts, faces).to(self.device)

        if vertex_texture is None:
            mesh.textures = TexturesVertex(
                verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5)
        else:
            mesh.textures = TexturesVertex(
                verts_features=vertex_texture.unsqueeze(0))
        return mesh

    def load_meshes(self, verts, faces, offset=None, vertex_texture=None):
        """Load one or more meshes into the pytorch3d renderer.

        Args:
            verts ([N, 3] tensor/array, or a list of them): vertices
            faces ([M, 3] tensor/array, or a list of them): faces
            offset ([N, 3], optional): per-vertex offset added to verts
            vertex_texture ([N, 3], optional): per-vertex colors
        """
        if offset is not None:
            verts = verts + offset

        if isinstance(verts, list):
            self.meshes = []
            for V, F in zip(verts, faces):
                if vertex_texture is None:
                    self.meshes.append(self.VF2Mesh(V, F))
                else:
                    self.meshes.append(self.VF2Mesh(V, F, vertex_texture))
        else:
            if vertex_texture is None:
                self.meshes = [self.VF2Mesh(verts, faces)]
            else:
                self.meshes = [self.VF2Mesh(verts, faces, vertex_texture)]

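    # Usage sketch (illustrative; `verts`/`faces`/`rgb` are placeholders, with the
    # mesh expressed in the same +/-100 world space as the cameras above):
    #
    #   render.load_meshes(verts, faces)                       # normal-colored
    #   render.load_meshes(verts, faces, vertex_texture=rgb)   # rgb: [N, 3] in [0, 1]
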
    def get_depth_map(self, cam_ids=[0, 2]):

        depth_maps = []
        for cam_id in cam_ids:
            self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
            fragments = self.meshRas(self.meshes[0])
            depth_map = fragments.zbuf[..., 0].squeeze(0)
            if cam_id == 2:
                depth_map = torch.fliplr(depth_map)
            depth_maps.append(depth_map)

        return depth_maps

    def get_rgb_image(self, cam_ids=[0, 2], bg='gray'):

        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), "clean_mesh", bg)
                # render, keep the RGB channels, and rescale from [0, 1] to [-1, 1]
                rendered_img = (self.renderer(
                    self.meshes[0])[0:1, :, :, :3].permute(0, 3, 1, 2) - 0.5) * 2.0
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[3])
                images.append(rendered_img)

        return images

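    # Usage sketch (illustrative): with the default cam_ids=[0, 2] both methods
    # return front and back views, and the back view is mirrored to align with the front.
    #
    #   front_back_rgb = render.get_rgb_image()   # list of [1, 3, H, W] in [-1, 1]
    #   front_back_z = render.get_depth_map()     # list of [H, W] z-buffers
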
    def get_rendered_video(self, images, save_path):

        # temporarily replace the six fixed cameras with a 360-degree orbit
        self.cam_pos = []
        for angle in range(360):
            self.cam_pos.append((
                100.0 * math.cos(np.pi / 180 * angle),
                self.mesh_y_center,
                100.0 * math.sin(np.pi / 180 * angle),
            ))

        old_shape = np.array(images[0].shape[:2])
        new_shape = np.around(
            (self.size / old_shape[0]) * old_shape).astype(int)

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video = cv2.VideoWriter(save_path, fourcc, 10,
                                (self.size * len(self.meshes) +
                                 new_shape[1] * len(images), self.size))

        pbar = tqdm(range(len(self.cam_pos)))
        pbar.set_description(
            colored(f"exporting video {os.path.basename(save_path)}...",
                    "blue"))
        for cam_id in pbar:
            self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")

            img_lst = [
                np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(
                    np.uint8)[:, :, [2, 1, 0]] for img in images
            ]

            for mesh in self.meshes:
                rendered_img = ((self.renderer(mesh)[0, :, :, :3] *
                                 255.0).detach().cpu().numpy().astype(
                                     np.uint8))

                img_lst.append(rendered_img)
            final_img = np.concatenate(img_lst, axis=1)
            video.write(final_img)

        video.release()
        self.reload_cam()

    def get_silhouette_image(self, cam_ids=[0, 2]):

        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), "silhouette")
                rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images
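
# Minimal end-to-end sketch (assumptions: a CUDA device, pytorch3d installed, and
# an icosphere standing in for a real scan; nothing below is required by the module).
if __name__ == "__main__":
    from pytorch3d.utils import ico_sphere

    render = Render(size=512, device=torch.device("cuda:0"))
    sphere = ico_sphere(level=3)
    # scale the unit sphere into the +/-100 orthographic volume used above
    verts = sphere.verts_packed() * 80.0
    faces = sphere.faces_packed()
    render.load_meshes(verts, faces)

    rgb_views = render.get_rgb_image(cam_ids=[0, 1, 2, 3])   # four side views
    depth_front_back = render.get_depth_map(cam_ids=[0, 2])
    print([v.shape for v in rgb_views], [d.shape for d in depth_front_back])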